In [None]:
#!pip install tuned-lens

In [None]:
#!pip uninstall -y tuned-lens

In [None]:

!pip install plotly
!pip install ipywidgets
!pip install nbformat

In [None]:
import torch
import sys
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import mamba_ssm
#sys.path.append('tuned-lens/') <- local tuned-lens with the required changes
from tuned_lens.nn.lenses import TunedLens, LogitLens
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from tuned_lens.scripts import train_loop
from tuned_lens.scripts import ingredients
from pathlib import Path




In [None]:
from transformers.configuration_utils import PretrainedConfig
from transformers import PreTrainedModel
from collections import namedtuple
from dataclasses import dataclass, field


# Lower-case on purpose: the Hugging Face style `MambaConfig` defined below
# would otherwise clash with this plain dataclass.
@dataclass
class mambaConfig:
    """Backbone settings from which ``MambaModel`` builds its ``MambaLMHeadModel``.

    Every field carries a default, so ``mambaConfig()`` needs no arguments.
    """

    # Model geometry.
    d_model: int = 768
    n_layer: int = 24
    vocab_size: int = 50280

    # No extra SSM options are supplied by default.
    ssm_cfg: dict = None

    # Flags handed to the backbone as-is; their exact semantics are defined
    # by mamba_ssm, not by this notebook.
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8


class MambaConfig(PretrainedConfig):
    """Hugging Face style configuration for the wrapped Mamba model.

    Args:
        dmodel: Stored on the config as ``hidden_size``.
        vocab_size: Stored on the config as ``vocab_size``.
        n_layer: Stored on the config as ``num_hidden_layers``.
        **kwargs: Passed through untouched to ``PretrainedConfig.__init__``.
    """

    model_type = "mamba"
    attribute_map = {"max_position_embeddings": "context_length"}

    def __init__(self, dmodel=768, vocab_size=50280, n_layer=24, **kwargs):
        # The three attributes are assigned before the base initialiser runs,
        # so anything it does with **kwargs happens afterwards.
        self.vocab_size = vocab_size
        self.hidden_size = dmodel
        self.num_hidden_layers = n_layer

        super().__init__(**kwargs)

# Scratch space filled by MambaModel.activation_hook during a forward pass.
# Keyed by module, so (dicts being insertion-ordered) the values come out in
# the order the hooked modules ran: embedding first, then each backbone layer.
activations = {}

# Return type of MambaModel.forward; built once instead of on every call.
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "hidden_states"])


class MambaModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``MambaLMHeadModel``.

    ``forward`` returns the logits together with the outputs of the embedding
    and of every backbone layer, captured with forward hooks.
    """

    config_class = MambaConfig
    base_model_prefix = "model"
    name_or_path = "mamba"
    name = "mamba-130m"

    def activation_hook(self, module, input, output):
        """Forward hook that records ``module``'s output in ``activations``.

        If the module returns a tuple/list, only its first element is stored.
        """
        # Test the container type rather than ``len(output) > 1``: a plain
        # tensor with batch size > 1 also has len > 1, and indexing it would
        # silently keep only the first batch element.
        # NOTE(review): if a layer returns (hidden_states, residual), the
        # residual part is dropped here -- confirm this is the stream the
        # lens is supposed to see.
        if isinstance(output, (tuple, list)):
            output = output[0]
        activations[module] = output

    def __init__(self, config: MambaConfig, tokenizer=None):
        """Build the wrapped model from ``config``.

        Args:
            config: Supplies ``hidden_size``, ``num_hidden_layers`` and
                ``vocab_size`` for the backbone.
            tokenizer: Tokenizer returned by ``load``. When omitted, the
                notebook-level ``tokenizer`` variable is used if it exists
                (the original behaviour), otherwise ``None``.
        """
        super().__init__(config)
        if tokenizer is None:
            # Backward compatible with MambaModel(config): pick up the global.
            tokenizer = globals().get("tokenizer")
        self.tokenizer = tokenizer
        self.dmodel = config.hidden_size
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        # Instantiate the backbone config from *this* config; previously the
        # mambaConfig class object was passed, so `config` was ignored and the
        # dataclass defaults were always used.
        self.model = MambaLMHeadModel(
            mambaConfig(
                d_model=config.hidden_size,
                n_layer=config.num_hidden_layers,
                vocab_size=config.vocab_size,
            )
        )
        self.model.name_or_path = "mamba"

    def load_state_dict(self, state_dict, strict=False):
        """Load ``state_dict`` into the wrapped ``MambaLMHeadModel``.

        ``strict`` defaults to ``False``, so key mismatches do not raise; the
        result of the underlying call is returned so callers can inspect it.
        """
        return self.model.load_state_dict(state_dict, strict=strict)

    def hook_intermediate(self):
        """Attach ``activation_hook`` to the embedding and every backbone layer.

        Returns:
            The list of hook handles; call ``.remove()`` on each to detach.
        """
        hooked_modules = [self.model.backbone.embedding, *self.model.backbone.layers]
        return [
            module.register_forward_hook(self.activation_hook)
            for module in hooked_modules
        ]

    def forward(self, input_ids, output_hidden_states=True):
        """Run the model and optionally collect intermediate activations.

        Args:
            input_ids: Token ids passed straight to the wrapped model.
            output_hidden_states: When truthy, hook the embedding and layers
                for the duration of this call.

        Returns:
            ``CausalLMOutput(logits, hidden_states)`` where ``hidden_states``
            is a list (empty when nothing was recorded).
        """
        activations.clear()
        handles = self.hook_intermediate() if output_hidden_states else []
        try:
            logits = self.model(input_ids).logits
        finally:
            # Detach the hooks again; registering a fresh set on every call
            # without removing them made hooks pile up across forwards.
            for handle in handles:
                handle.remove()
        hidden_states = list(activations.values())
        return CausalLMOutput(logits=logits, hidden_states=hidden_states)

    def get_output_embeddings(self):
        """Return the wrapped model's ``lm_head`` module."""
        return self.model.lm_head

    def load(self, device):
        """Move the model to ``device`` and return ``(model, tokenizer)``."""
        return self.to(device), self.tokenizer


In [None]:
# Device handle for the notebook (the call below still passes the literal "cuda").
device = torch.device('cuda')

# Pretrained half-precision checkpoint; in this cell it only serves as the
# source of weights for the wrapper built below.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)

# `tokenizer` has to exist before MambaModel(config) is constructed, because
# MambaModel.__init__ picks up this module-level name.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
config = MambaConfig(dmodel=768, vocab_size=50280, n_layer=24)
mamba= MambaModel(config)
# Copy the pretrained weights into the wrapper's inner model
# (MambaModel.load_state_dict defaults to strict=False).
mamba.load_state_dict(model.state_dict())


In [None]:
# Dataset: TinyStories, "train" split, reading the "text" column.
train_data = ingredients.Data(["roneneldan/TinyStories"])
train_data.split = "train"
train_data.text_column = "text"

# Default optimizer settings and a per-GPU batch size of 1.
optimizer = ingredients.Optimizer()
distributer = ingredients.Distributed(per_gpu_batch_size=1)

# Path handed to the trainer (presumably its output directory) and the KL loss.
p = Path("next")
loss = train_loop.LossChoice.KL

train = train_loop.Train(
    mamba,
    train_data,
    optimizer,
    distributer,
    p,
    wandb="Lens",
    loss=loss,
)


In [None]:
# Run the training job configured in the previous cell.
train.execute()

In [None]:

tuned_lens = TunedLens.from_model(mamba).cuda()

# Deserialize onto the CPU first so the checkpoint loads regardless of the
# device it was saved from; load_state_dict then copies the values into the
# translators' existing parameters.
# NOTE(review): the training cell passes Path("next") to the trainer while
# this reads from "future/" -- confirm this is the intended checkpoint.
state = torch.load("future/params.pt", map_location="cpu")
tuned_lens.layer_translators.load_state_dict(state)
logit_lens = LogitLens.from_model(mamba).cuda()

In [None]:
from tuned_lens.plotting import PredictionTrajectory
import ipywidgets as widgets
from plotly import graph_objects as go


def make_plot(lens, text, layer_stride, statistic, token_range):
    """Build the figure for one lens statistic over a slice of the input text.

    Args:
        lens: Lens object handed to ``PredictionTrajectory.from_lens_and_model``.
        text: Raw input text; tokenized with the notebook-level ``tokenizer``.
        layer_stride: Passed to ``.stride(...)`` on the computed statistic.
        statistic: Name of the ``PredictionTrajectory`` method to call.
        token_range: Pair of indices; unpacked into ``slice(*token_range)``.

    Returns:
        The figure produced by the statistic, or a ``widgets.Text`` carrying a
        message when the text is empty or the range has equal endpoints.
    """
    token_ids = tokenizer.encode(text)
    if len(token_ids) == 0:
        return widgets.Text("Please enter some text.")

    range_start, range_stop = token_range[0], token_range[1]
    if range_start == range_stop:
        return widgets.Text("Please provide valid token range.")

    # Targets are the input shifted left by one position, closed with EOS.
    next_token_ids = token_ids[1:] + [tokenizer.eos_token_id]

    trajectory = PredictionTrajectory.from_lens_and_model(
        lens=lens,
        model=mamba,
        input_ids=token_ids,
        tokenizer=tokenizer,
        targets=next_token_ids,
    )
    window = trajectory.slice_sequence(slice(*token_range))

    strided_stat = getattr(window, statistic)().stride(layer_stride)
    return strided_stat.figure(
        title=f"{lens.__class__.__name__} ({mamba.name_or_path}) {statistic}",
    )

# Style shared by the labelled controls below.
style = {"description_width": "initial"}

# (label, value) pairs offered by the two dropdowns.
_statistic_choices = [
    ("Entropy", "entropy"),
    ("Cross Entropy", "cross_entropy"),
    ("Forward KL", "forward_kl"),
]
_lens_choices = [
    ("Tuned Lens", tuned_lens),
    ("Logit Lens", logit_lens),
]

# Free-text input that gets tokenized and fed to the model.
text_wdg = widgets.Textarea(
    value="it was the best of times, it was the worst of times",
    description="Input Text",
)

# Which PredictionTrajectory statistic to plot.
statistic_wdg = widgets.Dropdown(
    description="Select Statistic:",
    options=_statistic_choices,
    style=style,
)

# Which lens to decode the hidden states with.
lens_wdg = widgets.Dropdown(
    description="Select Lens:",
    options=_lens_choices,
    style=style,
)

# Layer stride for the plot, bounded to 1..10.
layer_stride_wdg = widgets.BoundedIntText(
    description="Layer Stride:",
    value=2,
    min=1,
    max=10,
    step=1,
    disabled=False,
)

# Token window to display; `max` starts at 1 and is adjusted to the text later.
token_range_wdg = widgets.IntRangeSlider(
    description="Token Range",
    min=0,
    max=1,
    step=1,
    style=style,
)


def update_token_range(*args):
    """Set the token-range slider's maximum to the token count of the text box.

    Any positional arguments are accepted and ignored, so the function works
    both when called directly and when used as a widget observer.
    """
    token_count = len(tokenizer.encode(text_wdg.value))
    token_range_wdg.max = token_count

# Size the slider for the initial text first, so that the full range can be
# selected on the next line.
update_token_range()

# Start with every token selected, then re-run update_token_range whenever the
# text box value changes.
token_range_wdg.value = [0, token_range_wdg.max]
text_wdg.observe(update_token_range, 'value')

# Manual mode: interact is configured with a button labelled "Run Lens".
interact = widgets.interact.options(manual_name='Run Lens', manual=True)

# Bind each make_plot argument to its widget.
plot = interact(
    make_plot,
    text=text_wdg,
    statistic=statistic_wdg,
    lens=lens_wdg,
    layer_stride=layer_stride_wdg,
    token_range=token_range_wdg,
)