# The tuned lens 🔎
A tuned lens allows us to peek at the iterative computations that a transformer uses to compute the next token.

A lens into a transformer with $n$ layers allows you to replace the last $m$ layers of the model with an [affine transformation](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (we call these affine translators).

This essentially skips over these last few layers and lets you see the best prediction that can be made from the model's representations, i.e. the residual stream, at layer $n - m$. Since the representations may be rotated, shifted, or stretched from layer to layer, it's useful to train the lens's affine translators specifically on each layer. This training is what differentiates this method from simpler approaches that decode the residual stream of the network directly using the unembedding layer, i.e. the logit lens. We explain this process, along with more applications of the method, in [the paper](https://arxiv.org/abs/2303.08112).
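
In symbols, the difference between the two lenses can be sketched as follows. This is a summary under assumed notation, not something defined in this notebook: $h_\ell$ is the residual stream at layer $\ell = n - m$, $W_U$ is the unembedding matrix, and $A_\ell$, $b_\ell$ are the weight and bias of the affine translator trained for layer $\ell$.

$$
\begin{aligned}
\mathrm{LogitLens}(h_\ell) &= \mathrm{LayerNorm}(h_\ell)\, W_U \\
\mathrm{TunedLens}_\ell(h_\ell) &= \mathrm{LogitLens}(A_\ell h_\ell + b_\ell)
\end{aligned}
$$

Under this reading, the logit lens is the special case where $A_\ell$ is the identity and $b_\ell = 0$, so the only thing training adds is the per-layer pair $(A_\ell, b_\ell)$.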

You can find the complete set of pretrained lenses on [the Hugging Face space](https://huggingface.co/spaces/AlignmentResearch/tuned-lens/tree/main/lens).

## Usage
Since the tuned lens produces a distribution of predictions, we need to provide a summary statistic to plot in order to visualize its output. The default is simply [entropy](https://en.wikipedia.org/wiki/Entropy_(information_theory)), but you can also choose the [cross entropy](https://en.wikipedia.org/wiki/Cross_entropy) with the target token, or the [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between the model's predictions and the tuned lens' predictions. You can also hover over a token to see more of the distribution, i.e. the top 10 most probable tokens and their probabilities.
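
To make the three statistics concrete, here is a minimal, dependency-free sketch of how each one reduces a distribution to a single number. The function names and the toy probabilities are made up for illustration and are not part of the `tuned_lens` API. The direction of the KL term, model predictions first and lens predictions second, is our assumption about what `forward_kl` means.

```python
import math


def entropy(p):
    """Shannon entropy of a distribution, in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def cross_entropy(p, target):
    """Cross entropy against a one-hot target: -log p[target]."""
    return -math.log(p[target])


def kl_divergence(p, q):
    """KL(p || q) in nats; assumes q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Toy 4-token vocabulary with made-up probabilities.
model_probs = [0.7, 0.2, 0.05, 0.05]  # prediction of the full model
lens_probs = [0.4, 0.3, 0.2, 0.1]     # prediction of the lens at an earlier layer
target = 0                            # index of the true next token

stats = {
    "entropy": entropy(lens_probs),
    "ce": cross_entropy(lens_probs, target),
    # Assumed direction: KL(model || lens).
    "forward_kl": kl_divergence(model_probs, lens_probs),
}
print(stats)
```

Entropy only needs the lens distribution, cross entropy also needs the target token, and the KL divergence needs the model's final prediction, which is why the three plots can tell different stories about the same layer.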

## Examples
Some interesting examples you can try.

### Copy paste
```
Copy: A!2j!#u&NGApS&MkkHe8Gm!#
Paste: A!2j!#u&NGApS&MkkHe8Gm!#
```

### Trivial in-context learning
```
inc 1 2
inc 4 5
inc 13 
```

#### Addition
```
add 1 1 2
add 3 4 7
add 13 2 
```

In [1]:
# !pip install tuned-lens==0.0.3


In [2]:
import os

import torch
from tuned_lens.nn.lenses import TunedLens, LogitLens
from transformers import AutoModelForCausalLM, AutoTokenizer

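windows = False
try:
    # Enable custom widgets when running inside Google Colab.
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    # Not running in Colab; record whether we are on Windows instead.
    if os.name == 'nt':
        windows = True

# bfloat16 is natively supported on GPUs with compute capability 8.0 or higher;
# otherwise fall back to float16. Guard the query so it does not fail without CUDA.
bfloat16_available = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
floatDtype = torch.bfloat16 if bfloat16_available else torch.float16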

In [3]:
device = torch.device('cuda')
loadLens = True
modelName = 'EleutherAI/pythia-1b'
LensName = "outputs\\test-1b\\"
# modelName = 'sshleifer/tiny-gpt2'
# To try a different model / lens, check that the lens is available, then modify this code.
model = AutoModelForCausalLM.from_pretrained(modelName, low_cpu_mem_usage=True, torch_dtype=floatDtype)
model.eval()
model.requires_grad_(False)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(modelName)
if loadLens:
    tuned_lens = TunedLens.load(LensName, map_location=device)
else:
    tuned_lens = TunedLens(model)
tuned_lens = tuned_lens.to(device)
# logit_lens = LogitLens(model)
logit_lens = None

In [None]:
from torchinfo import summary
print(summary(model))
print(summary(tuned_lens))

Layer (type:depth-idx)                             Param #
GPTNeoXForCausalLM                                 --
├─GPTNeoXModel: 1-1                                --
│    └─Embedding: 2-1                              (103,022,592)
│    └─ModuleList: 2-2                             --
│    │    └─GPTNeoXLayer: 3-1                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-2                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-3                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-4                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-5                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-6                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-7                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-8                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-9                      (50,358,272)
│    │    └─GPTNeoXLayer: 3-10                     (50,358,272)
│    │    └─GPTNeoXLayer: 3-11                     (50,358,272)
│    │    

In [None]:
from tuned_lens.plotting import plot_lens, get_lens_stream
import ipywidgets as widgets
from plotly import graph_objects as go


def make_plot(lens, text, layer_stride, statistic, token_range):
    input_ids = tokenizer.encode(text, return_tensors="pt")

    if len(input_ids[0]) == 0:
        return widgets.Text("Please enter some text.")
    
    if token_range[0] == token_range[1]:
        return widgets.Text("Please provide a valid token range.")

    return go.FigureWidget(
        plot_lens(
            model,
            tokenizer,
            lens,
            layer_stride=layer_stride,
            input_ids=input_ids,
            start_pos=token_range[0],
            end_pos=token_range[1] if token_range[1] > 0 else None,
            statistic=statistic,
        )
    )

style = {'description_width': 'initial'}
statistic_wdg = widgets.Dropdown(
    options=[
        ('Entropy', 'entropy'),
        ('Cross Entropy', 'ce'),
        ('Forward KL', 'forward_kl'),
    ],
    description='Select Statistic:',
    style=style,
)
text_wdg = widgets.Textarea(
    description="Input Text",
    value="it was the best of times, it was the worst of times",
)
lens_wdg = widgets.Dropdown(
    options=[('Tuned Lens', tuned_lens), ('Logit Lens', logit_lens)],
    description='Select Lens:',
    style=style,
)

layer_stride_wdg = widgets.BoundedIntText(
    value=2,
    min=1,
    max=model.config.num_hidden_layers//2,
    step=1,
    description='Layer Stride:',
    disabled=False
)

token_range_wdg = widgets.IntRangeSlider(
    description='Token Range',
    min=0,
    max=1,
    step=1,
    style=style,
)


def update_token_range(*args):
    token_range_wdg.max = len(tokenizer.encode(text_wdg.value))

update_token_range()

token_range_wdg.value = [0, token_range_wdg.max]
text_wdg.observe(update_token_range, 'value')

interact = widgets.interact.options(manual_name='Run Lens', manual=True)

with torch.autocast("cuda", dtype=floatDtype):
    plot = interact(
        make_plot,
        text=text_wdg,
        statistic=statistic_wdg,
        lens=lens_wdg,
        layer_stride=layer_stride_wdg,
        token_range=token_range_wdg,
    )


In [None]:
import json
from torch.utils.data import DataLoader
from tuned_lens.scripts.train_loop_latent import RelatedCollator
from tuned_lens.scripts.lens_latent import RelatedDataset
# Open related.json to gain access to the dataset.
dataDir = "datasets/"
with open(dataDir + "related.json", "r") as f:
    relatedJson = json.load(f)
        
dataset = RelatedDataset(relatedJson, tokenizer)
collator = RelatedCollator(tokenizer, dataset.pad_to_multiple_of)
dl = DataLoader(dataset, batch_size=1, collate_fn=collator, shuffle=True)

In [None]:
for prompt, response, related in dl:
    hidden_lps, responseOutput = get_lens_stream(
        model,
        tokenizer,
        tuned_lens,
        layer_stride=1,
        input_ids=prompt["input_ids"],
        input_att_mask=prompt["attention_mask"],
        response_ids=response["input_ids"],
        response_att_mask=response["attention_mask"],
        related_ids=related["input_ids"],
        related_att_mask=related["attention_mask"],
        start_pos=0,
        end_pos=None,
        statistic="ce",
    )
    print(hidden_lps)
    print(responseOutput)
    break  # Only inspect the first batch.

torch.Size([7])
torch.Size([1, 7])


TypeError: where() received an invalid combination of arguments - got (bool, int, int), but expected one of:
 * (Tensor condition)
 * (Tensor condition, Tensor input, Tensor other, *, Tensor out)
 * (Tensor condition, Number self, Tensor other)
      didn't match because some of the arguments have invalid types: (!bool!, !int!, !int!)
 * (Tensor condition, Tensor input, Number other)
      didn't match because some of the arguments have invalid types: (!bool!, !int!, !int!)
 * (Tensor condition, Number self, Number other)
      didn't match because some of the arguments have invalid types: (!bool!, !int!, !int!)


In [None]:
# tuned-lens train latent gpt2 related --loss ce --token-shift 0 --output outputs/test
# tuned-lens train latent EleutherAI/pythia-1b related --loss ce --token-shift 0 --output outputs/test-1b-long --tokens-per-step 8192 --wandb latent-lens-long
# tuned-lens train latent EleutherAI/pythia-6.9b related --loss ce --token-shift 0 --output outputs/test --tokens-per-step 1024 --wandb latent-lens-test