In [1]:
from copy import deepcopy
from typing import List
import sys
sys.path.append('../../tuned-lens')

import accelerate
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Subset
from mamba_ssm.ops.triton.layernorm import rms_norm_fn
from tqdm import tqdm
from tuned_lens.nn.lenses import Lens, LogitLens, TunedLens
from tuned_lens.scripts.ingredients import Model
import wandb
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    get_linear_schedule_with_warmup,
)
from tuned_lens.scripts.ingredients import Model
from tuned_lens.model_patches.mamba_model import PreMambaConfig, MambaModel, MambaTokenizer


In [2]:
# Teach the Auto* factories about the local Mamba patches: configs whose
# model_type is "mamba" map to PreMambaConfig, which in turn maps to
# MambaModel / MambaTokenizer.
AutoConfig.register("mamba", PreMambaConfig)
AutoModelForCausalLM.register(PreMambaConfig, MambaModel)
AutoTokenizer.register(PreMambaConfig, MambaTokenizer)

name = "state-spaces/mamba-130m"
model = AutoModelForCausalLM.from_pretrained(name).to("cuda")  # type: ignore

# NOTE(review): the tokenizer is taken from GPT-NeoX-20B rather than from
# `name` — presumably because the Mamba checkpoint ships without one; verify.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Two lenses to compare: a trained tuned lens loaded from disk, and the
# parameter-free logit lens built directly from the model.
tuned_lens = TunedLens.from_model_and_pretrained(
    model, "../../tuned-lens/my_lens/mamba/130m"
).to("cuda")
logit_lens = LogitLens.from_model(model)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
from tuned_lens.plotting import PredictionTrajectory
import ipywidgets as widgets
from plotly import graph_objects as go


def make_plot(lens, text, layer_stride, statistic, token_range):
    """Plot one prediction-trajectory statistic of ``lens`` applied to ``text``.

    Uses the notebook-level ``model`` and ``tokenizer``.

    Args:
        lens: Lens passed to ``PredictionTrajectory.from_lens_and_model``
            (the dropdown offers ``tuned_lens`` and ``logit_lens``).
        text: Input text, tokenized with the global ``tokenizer``.
        layer_stride: Forwarded to ``.stride()`` on the computed statistic
            (presumably keeps every ``layer_stride``-th layer).
        statistic: Name of a ``PredictionTrajectory`` method to call, e.g.
            ``"entropy"``, ``"cross_entropy"`` or ``"forward_kl"``.
        token_range: ``(start, stop)`` token positions, applied as
            ``slice(*token_range)`` to the trajectory.

    Returns:
        The object returned by ``.figure(...)``, or a ``widgets.Text`` holding
        an error message when ``text`` tokenizes to nothing or ``token_range``
        selects no positions (``start >= stop``).
    """
    input_ids = tokenizer.encode(text)
    if not input_ids:
        return widgets.Text("Please enter some text.")

    # Reject both empty and reversed ranges; either would slice the trajectory
    # down to zero positions. The slider cannot produce a reversed range, but a
    # direct call can.
    if token_range[0] >= token_range[1]:
        return widgets.Text("Please provide valid token range.")

    # Targets are the inputs shifted left by one, padded with EOS so that
    # every input position has a target.
    targets = input_ids[1:] + [tokenizer.eos_token_id]

    pred_traj = PredictionTrajectory.from_lens_and_model(
        lens=lens,
        model=model,
        input_ids=input_ids,
        tokenizer=tokenizer,
        targets=targets,
    ).slice_sequence(slice(*token_range))

    return getattr(pred_traj, statistic)().stride(layer_stride).figure(
        title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
    )

# Shared widget style: let each description use its natural width.
style = dict(description_width='initial')

# The selected value is the name of a PredictionTrajectory method;
# make_plot calls it via getattr.
statistic_wdg = widgets.Dropdown(
    description='Select Statistic:',
    style=style,
    options=[
        ('Entropy', 'entropy'),
        ('Cross Entropy', 'cross_entropy'),
        ('Forward KL', 'forward_kl'),
    ],
)

# Free-form text that gets tokenized and fed through the model.
text_wdg = widgets.Textarea(
    value="I start like this and end like that",
    description="Input Text",
)

# The dropdown values are the lens objects themselves.
lens_wdg = widgets.Dropdown(
    description='Select Lens:',
    style=style,
    options=[('Tuned Lens', tuned_lens), ('Logit Lens', logit_lens)],
)

# Layer stride for the plot, constrained to 1..10.
layer_stride_wdg = widgets.BoundedIntText(
    description='Layer Stride:',
    value=2,
    step=1,
    min=1,
    max=10,
    disabled=False,
)

# Token window; its max is resized later to the token count of the input text.
token_range_wdg = widgets.IntRangeSlider(
    description='Token Range',
    style=style,
    step=1,
    min=0,
    max=1,
)


def update_token_range(*args):
    """Resize ``token_range_wdg`` to the token count of ``text_wdg``'s text.

    ``*args`` absorbs the change notification when this is registered as an
    observer on ``text_wdg``; it is also called directly with no arguments.

    If the upper handle sat at the old maximum (i.e. "up to the end of the
    text"), it is moved to the new maximum. Otherwise tokens added by editing
    the text would silently fall outside the selected range. A handle the user
    placed elsewhere is left alone.
    """
    new_max = len(tokenizer.encode(text_wdg.value))
    lower, upper = token_range_wdg.value
    upper_was_pinned = upper == token_range_wdg.max

    token_range_wdg.max = new_max
    if upper_was_pinned:
        # Clamp the lower handle too, in case the text got shorter.
        token_range_wdg.value = (min(lower, new_max), new_max)



In [7]:
# Size the token-range slider to the default text, then select every token.
update_token_range()
token_range_wdg.value = (0, token_range_wdg.max)

# Keep the slider bounds in sync whenever the input text is edited.
text_wdg.observe(update_token_range, names='value')

# Manual mode: make_plot only runs when the "Run Lens" button is pressed,
# not on every widget change.
interact = widgets.interact.options(manual=True, manual_name='Run Lens')

plot = interact(
    make_plot,
    lens=lens_wdg,
    text=text_wdg,
    layer_stride=layer_stride_wdg,
    statistic=statistic_wdg,
    token_range=token_range_wdg,
)

interactive(children=(Dropdown(description='Select Lens:', options=(('Tuned Lens', TunedLens(
  (unembed): Une…

In [7]:
# Non-interactive figure for export. The token range is derived from the text
# rather than hard-coded, so editing the text cannot silently drop tokens.
export_text = "I start like this and end like that"
fig = make_plot(
    tuned_lens,
    export_text,
    2,
    "entropy",
    [0, len(tokenizer.encode(export_text))],
)
# make_plot reports bad input by returning a Text widget instead of a figure;
# fail here rather than at write_image.
assert not isinstance(fig, widgets.Text), f"make_plot failed: {fig.value}"

In [8]:
# Static export; the output format follows the file extension.
fig.write_image(file="tuned_lens.svg")