# The tuned lens 🔎
A tuned lens allows us to peek at the iterative computations that a transformer uses to compute the next token.

A lens into a transformer with $n$ layers allows you to replace the last $m$ layers of the model with an [affine transformation](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (we call these affine translators).

This essentially skips over these last few layers and lets you see the best prediction that can be made from the model's representations, i.e. the residual stream, at layer $n - m$. Since the representations may be rotated, shifted, or stretched from layer to layer, it's useful to train the lens's affine translators specifically on each layer. This training is what differentiates this method from simpler approaches that decode the residual stream of the network directly using the unembedding layer, i.e. the logit lens. We explain this process, along with more applications of the method, in [the paper](https://arxiv.org/abs/2303.08112).
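
To make the difference between the two lenses concrete, here is a minimal pure-Python sketch. It is an illustration under simplifying assumptions, not the library's implementation: the helper names (`matvec`, `softmax`, `logit_lens`, `tuned_lens`), the toy 2-dimensional hidden state, the 3-token vocabulary, and the hand-picked rotation used as the translator are all hypothetical, and the final layer norm that a real model applies before unembedding is omitted.

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def softmax(logits):
    """Turn logits into a probability distribution."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logit_lens(hidden, unembed):
    """Decode the hidden state directly with the unembedding matrix."""
    return softmax(matvec(unembed, hidden))


def tuned_lens(hidden, weight, bias, unembed):
    """Apply an affine translator (weight @ hidden + bias), then unembed."""
    translated = [a + b for a, b in zip(matvec(weight, hidden), bias)]
    return softmax(matvec(unembed, translated))


# Toy setup (all values are made up): 3 tokens, 2-dimensional residual stream.
unembed = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
hidden = [2.0, 0.0]

# Pretend the intermediate layer stores its information rotated by 90 degrees
# relative to the final layer; a trained translator would learn to undo that.
weight = [[0.0, -1.0], [1.0, 0.0]]
bias = [0.0, 0.0]

logit_probs = logit_lens(hidden, unembed)
tuned_probs = tuned_lens(hidden, weight, bias, unembed)

print("logit lens:", [round(p, 3) for p in logit_probs])
print("tuned lens:", [round(p, 3) for p in tuned_probs])
```

In this toy case the logit lens puts most of its mass on token 0, while the translated hidden state favours token 1. In practice the translator weights are learned per layer rather than chosen by hand.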

You can find the complete set of pretrained lenses on [the Hugging Face space](https://huggingface.co/spaces/AlignmentResearch/tuned-lens/tree/main/lens).

## Usage
Since the tuned lens produces a distribution of predictions, we need to provide a summary statistic to plot in order to visualize its output. The default is simply [entropy](https://en.wikipedia.org/wiki/Entropy_(information_theory)), but you can also choose the [cross entropy](https://en.wikipedia.org/wiki/Cross_entropy) with the target token, or the [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between the model's predictions and the tuned lens's predictions. You can also hover over a token to see more of the distribution, i.e. the top 10 most probable tokens and their probabilities.
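
As a rough guide to what these three statistics measure, the sketch below computes them for a single token position using only the standard library. The function names and the two example distributions are hypothetical, natural logarithms are assumed, and the direction of the KL divergence (model relative to lens) is an assumption about what the plot reports.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of one predicted distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def cross_entropy(probs, target_index):
    """Negative log-probability assigned to the target token."""
    return -math.log(probs[target_index])


def kl_divergence(p, q):
    """KL(p || q) in nats; assumes q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Made-up distributions over a 3-token vocabulary.
model_probs = [0.7, 0.2, 0.1]  # final-layer prediction of the model
lens_probs = [0.5, 0.3, 0.2]   # prediction decoded from an earlier layer
target_index = 0               # index of the true next token

print("entropy:      ", entropy(lens_probs))
print("cross entropy:", cross_entropy(lens_probs, target_index))
print("KL divergence:", kl_divergence(model_probs, lens_probs))
```

Low entropy means the lens is confident at that layer, low cross entropy means it is confident about the correct token, and a KL divergence near zero means the lens already agrees with the model's final prediction.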

## Examples
Some interesting examples you can try.

### Copy paste
```
Copy: A!2j!#u&NGApS&MkkHe8Gm!#
Paste: A!2j!#u&NGApS&MkkHe8Gm!#
```

### Trivial in-context learning
```
inc 1 2
inc 4 5
inc 13 
```

#### Addition
```
add 1 1 2
add 3 4 7
add 13 2 
```

In [None]:
#!pip install tuned-lens

In [None]:
#!pip uninstall -y tuned-lens

In [10]:

!pip install plotly
!pip install ipywidgets
!pip install nbformat


In [3]:
import torch
import sys
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import mamba_ssm
sys.path.append('/scratch/gpaulo/workspace/tuned-lens/')
from tuned_lens.nn.lenses import TunedLens, LogitLens
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.scripts import train_loop
from tuned_lens.scripts import ingredients
from pathlib import Path




Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
from transformers.configuration_utils import PretrainedConfig
from transformers import PreTrainedModel
from collections import namedtuple
from dataclasses import dataclass, field



@dataclass
class mambaConfig:

    d_model: int = 768
    n_layer: int = 24
    vocab_size: int = 50280
    ssm_cfg: dict = None
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8



class MambaConfig(PretrainedConfig):
    model_type = "mamba"
    attribute_map = {"max_position_embeddings": "context_length"}

    def __init__(
        self,
        dmodel=768,
        vocab_size=50280,
        n_layer=24,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = dmodel
        self.num_hidden_layers = n_layer
        
        super().__init__(**kwargs)

activations = {}
class MambaModel(PreTrainedModel):
    config_class = MambaConfig
    base_model_prefix = "model"
    name_or_path = "mamba"
    name = "mamba-130m"
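    def activation_hook(self, module, input, output):
        # Blocks return a tuple; keep its first element. Checking the type
        # (rather than len) avoids slicing a plain tensor along its batch dimension.
        if isinstance(output, tuple):
            output = output[0]
        activations[module] = output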
    
    def __init__(self, config: MambaConfig):
        super().__init__(config)
        self.tokenizer=tokenizer
        self.dmodel = config.hidden_size
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers   
        self.model = MambaLMHeadModel(mambaConfig())  # pass a config instance, not the class
        self.model.name_or_path = "mamba"
    def load_state_dict(self, state_dict, strict=False):
        self.model.load_state_dict(state_dict, strict=strict)

    def hook_intermediate(self):
        # Register the hooks only once; re-registering on every forward
        # pass would stack duplicate hooks on each module.
        if getattr(self, "_hooks_registered", False):
            return
        activation_hook = self.activation_hook
        self.model.backbone.embedding.register_forward_hook(activation_hook)
        for layer in self.model.backbone.layers:
            layer.register_forward_hook(activation_hook)
        self._hooks_registered = True

    def forward(self, input_ids, output_hidden_states=True):
        activations.clear()
        if output_hidden_states:
            self.hook_intermediate()
        outputs = self.model(input_ids).logits
        # Dicts preserve insertion order: embedding first, then each layer in turn.
        hidden_states = list(activations.values())
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "hidden_states"])
        return CausalLMOutput(logits=outputs, hidden_states=hidden_states)

    def get_output_embeddings(self):
        return self.model.lm_head

    def load(self,device):
        return self.to(device), self.tokenizer


In [8]:
device = torch.device('cuda')

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
config = MambaConfig(dmodel=768, vocab_size=50280, n_layer=24)
mamba= MambaModel(config)
mamba.load_state_dict(model.state_dict())


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [37]:
train_data = ingredients.Data(["roneneldan/TinyStories"])
optimizer = ingredients.Optimizer()
distributer = ingredients.Distributed(per_gpu_batch_size=1)
p = Path("next")
loss=train_loop.LossChoice.CE
train_data.split = "train"
train_data.text_column="text"
train = train_loop.Train(mamba,train_data,optimizer,distributer,p,wandb="Lens",token_shift=1,loss=loss)


In [38]:
train.execute()


Repo card metadata block was not found. Setting CardData to empty.

No checkpoint directory found. Snapshotting is disabled.




Run history (wandb):
bias_norm/0.ffn,▁▂▂▂▃▃▃▄▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇███████████████
bias_norm/1.ffn,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇█████████████
bias_norm/10.ffn,▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇████████
bias_norm/11.ffn,▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██████████
bias_norm/12.ffn,▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████████
bias_norm/13.ffn,▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇███████████
bias_norm/14.ffn,▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇████████████
bias_norm/15.ffn,▁▁▁▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██████████
bias_norm/16.ffn,▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█████████
bias_norm/17.ffn,▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████████

Run summary (wandb):
bias_norm/0.ffn,1.28184
bias_norm/1.ffn,3.26128
bias_norm/10.ffn,2.11979
bias_norm/11.ffn,1.37605
bias_norm/12.ffn,1.56379
bias_norm/13.ffn,0.91041
bias_norm/14.ffn,0.69378
bias_norm/15.ffn,0.85864
bias_norm/16.ffn,0.97149
bias_norm/17.ffn,1.01607



Training:   0%|          | 0/32000 [00:00<?, ?it/s]

In [35]:

tuned_lens = TunedLens.from_model(mamba).cuda()

state = torch.load("future/params.pt")
tuned_lens.layer_translators.load_state_dict(state)
logit_lens = LogitLens.from_model(mamba).cuda()

In [36]:
from tuned_lens.plotting import PredictionTrajectory
import ipywidgets as widgets
from plotly import graph_objects as go


def make_plot(lens, text, layer_stride, statistic, token_range):
    input_ids = tokenizer.encode(text)
    targets = input_ids[1:] + [tokenizer.eos_token_id]

    if len(input_ids) == 0:
        return widgets.Text("Please enter some text.")
    
    if (token_range[0] == token_range[1]):
        return widgets.Text("Please provide valid token range.")
    pred_traj = PredictionTrajectory.from_lens_and_model(
        lens=lens,
        model=mamba,
        input_ids=input_ids,
        tokenizer=tokenizer,
        targets=targets,
    ).slice_sequence(slice(*token_range))

    return getattr(pred_traj, statistic)().stride(layer_stride).figure(
        title=f"{lens.__class__.__name__} ({mamba.name_or_path}) {statistic}",
    )

style = {'description_width': 'initial'}
statistic_wdg = widgets.Dropdown(
    options=[
        ('Entropy', 'entropy'),
        ('Cross Entropy', 'cross_entropy'),
        ('Forward KL', 'forward_kl'),
    ],
    description='Select Statistic:',
    style=style,
)
text_wdg = widgets.Textarea(
    description="Input Text",
    value="it was the best of times, it was the worst of times",
)
lens_wdg = widgets.Dropdown(
    options=[('Tuned Lens', tuned_lens), ('Logit Lens', logit_lens)],
    description='Select Lens:',
    style=style,
)

layer_stride_wdg = widgets.BoundedIntText(
    value=2,
    min=1,
    max=10,
    step=1,
    description='Layer Stride:',
    disabled=False
)

token_range_wdg = widgets.IntRangeSlider(
    description='Token Range',
    min=0,
    max=1,
    step=1,
    style=style,
)


def update_token_range(*args):
    token_range_wdg.max = len(tokenizer.encode(text_wdg.value))

update_token_range()

token_range_wdg.value = [0, token_range_wdg.max]
text_wdg.observe(update_token_range, 'value')

interact = widgets.interact.options(manual_name='Run Lens', manual=True)

plot = interact(
    make_plot,
    text=text_wdg,
    statistic=statistic_wdg,
    lens=lens_wdg,
    layer_stride=layer_stride_wdg,
    token_range=token_range_wdg,
)
