# The tuned lens 🔎
A tuned lens allows us to peek at the iterative computations that a transformer is using to compute the next token.

A lens into a transformer with $n$ layers allows you to replace the last $m$ layers of the model with an [affine transformation](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (we call these affine adapters).

This essentially skips over these last few layers and lets you see the best prediction that can be made from the model's representations, i.e. the residual stream, at layer $n - m$. Since the representations may be rotated, shifted, or stretched from layer to layer, it's useful to train the lens's affine adapters specifically on each layer. This training is what differentiates this method from simpler approaches that decode the residual stream of the network directly using the unembedding layer, i.e. the logit lens. We explain this process in more detail in a forthcoming paper.

In [1]:
from tuned_lens.nn.lenses import TunedLens, LogitLens
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the pretrained 'gpt2' causal language model and its matching tokenizer.
model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# Load the tuned lens saved under the name "gpt2".
tuned_lens = TunedLens.load("gpt2")
# Build the logit lens from the model itself, for comparison with the tuned lens.
logit_lens = LogitLens(model)

TunedLens.load: ignoring config key 'sublayers'
TunedLens.load: ignoring config key 'orthogonal'
TunedLens.load: ignoring config key 'rank'
TunedLens.load: ignoring config key 'include_final'
TunedLens.load: ignoring config key 'identity_init'


In [3]:
from tuned_lens.plotting import plot_lens
import ipywidgets as widgets
from plotly import graph_objects as go



def make_plot(lens, text, statistic, token_range):
    """Build the lens visualisation widget for a piece of text.

    Args:
        lens: The lens object handed to ``plot_lens``.
        text: The raw text to tokenize and analyse.
        statistic: The statistic identifier forwarded to ``plot_lens``.
        token_range: An indexable pair; element 0 is used as ``start_pos``
            and element 1 as ``end_pos``. A non-positive element 1 is
            replaced by ``None`` (presumably "no upper bound" for
            ``plot_lens`` -- confirm against its documentation).

    Returns:
        A ``go.FigureWidget`` wrapping the result of ``plot_lens``, or a
        ``widgets.Text`` prompt when the text tokenizes to nothing.
    """
    token_ids = tokenizer.encode(text, return_tensors="pt")

    # Nothing to plot when the first (only) row holds no tokens.
    if not len(token_ids[0]):
        return widgets.Text("Please enter some text.")

    first_pos = token_range[0]
    last_pos = token_range[1]
    if not last_pos > 0:
        last_pos = None

    figure = plot_lens(
        model,
        tokenizer,
        lens,
        input_ids=token_ids,
        start_pos=first_pos,
        end_pos=last_pos,
        statistic=statistic,
    )
    return go.FigureWidget(figure)

# Style dict shared by the dropdowns and the slider below.
# NOTE(review): 'description_width': 'initial' presumably keeps the longer
# description labels from being truncated -- confirm against ipywidgets docs.
style = {'description_width': 'initial'}
# Dropdown of (label, value) pairs selecting the statistic to plot.
statistic_wdg = widgets.Dropdown(
    options=[
        ('Entropy', 'entropy'),
        ('Cross Entropy', 'ce'), 
        ('Forward KL', 'forward_kl'),
    ],
    description='Select Statistic:',
    style=style,
)
# Free-form text input, pre-filled with an example sentence.
text_wdg = widgets.Textarea(
    description="Input Text",
    value="We propose a new simple network architecture, the Transformer",
)
# Dropdown whose values are the lens objects themselves.
lens_wdg = widgets.Dropdown(
    options=[('Tuned Lens', tuned_lens), ('Logit Lens', logit_lens)],
    description='Select Lens:',
    style=style,
)

# Range slider over token positions. max=1 is only a placeholder here;
# update_token_range() overwrites it with the token count of the text.
token_range_wdg = widgets.IntRangeSlider(
    description='Token Range',
    min=0, 
    max=1, 
    step=1, 
    style=style,
)


def update_token_range(*args):
    """Set the token-range slider's maximum to the current text's token count.

    Any positional arguments are accepted and ignored, so the function can be
    called directly or registered as a widget observer callback.
    """
    token_ids = tokenizer.encode(text_wdg.value)
    token_range_wdg.max = len(token_ids)

# Set the slider's maximum from the default text before selecting a range.
update_token_range()

# Start with the whole token range selected.
token_range_wdg.value = [0, token_range_wdg.max]
# Recompute the slider's maximum whenever the text area's value changes.
text_wdg.observe(update_token_range, 'value')

# Variant of interact configured with manual=True and a 'Run Lens' label.
# NOTE(review): with manual=True, ipywidgets is expected to call make_plot only
# when the button is pressed rather than on every widget change -- confirm.
interact = widgets.interact.options(manual_name='Run Lens', manual=True)

# Bind each widget to the make_plot argument of the same name.
plot = interact(
    make_plot,
    text=text_wdg,
    statistic=statistic_wdg,
    lens=lens_wdg,
    token_range=token_range_wdg,
)

interactive(children=(Dropdown(description='Select Lens:', options=(('Tuned Lens', TunedLens(
  (extra_layers)…