# Comparing prediction trajectories

A prediction trajectory is the sequence of latent predictions produced by the tuned
lens, one per layer. You can think of each prediction as the lens's best guess, made
through that layer's affine translator, about the model's final token distribution.

Stacking these per-layer predictions gives a three-dimensional tensor, which we call the
prediction trajectory, with shape `(num_layers x sequence_length x vocab_size)`. These
distributions are typically stored in log space for numerical precision.
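To make the shape concrete, here is a minimal NumPy sketch in which random numbers stand in for real activations. It assumes a toy lens where each layer's hidden state passes through its own affine translator (`h + A[l] h + b[l]`) and then a shared unembedding matrix `W_U`; `A`, `b`, `W_U` and `toy_prediction_trajectory` are hypothetical names, and in the actual method the translated state also goes through the model's final layer norm before unembedding. Each layer's logits are turned into log-probabilities with a numerically stable log-softmax and stacked along a new layer axis.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def toy_prediction_trajectory(hidden, A, b, W_U):
    """Build a (num_layers, seq_len, vocab) array of log-probabilities.

    hidden: (num_layers, seq_len, d_model) hidden states (hypothetical)
    A:      (num_layers, d_model, d_model) per-layer translator weights
    b:      (num_layers, d_model) per-layer translator biases
    W_U:    (d_model, vocab) shared unembedding matrix
    """
    layers = []
    for l in range(hidden.shape[0]):
        h = hidden[l]                       # (seq_len, d_model)
        translated = h + h @ A[l].T + b[l]  # residual affine translator
        logits = translated @ W_U           # (seq_len, vocab)
        layers.append(log_softmax(logits))  # store in log space
    return np.stack(layers)                 # (num_layers, seq_len, vocab)

rng = np.random.default_rng(0)
num_layers, seq_len, d_model, vocab = 4, 5, 8, 10
traj = toy_prediction_trajectory(
    rng.normal(size=(num_layers, seq_len, d_model)),
    rng.normal(scale=0.1, size=(num_layers, d_model, d_model)),
    np.zeros((num_layers, d_model)),
    rng.normal(size=(d_model, vocab)),
)
print(traj.shape)  # (4, 5, 10)
```

Exponentiating any `traj[layer, position]` row recovers a distribution that sums to one, while the log-space values keep very unlikely tokens representable.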


First, though, we need to load the model, tokenizer, and tuned lens from the Hugging Face Hub.

```python
import torch
from tuned_lens.nn.lenses import TunedLens, LogitLens
# PredictionTrajectory is used below; import path assumed from the tuned-lens package layout
from tuned_lens.plotting import PredictionTrajectory
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device('cpu')
# To try a different model / lens, check that a pretrained lens is available, then modify this code
model = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-160m-deduped-v0')
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-160m-deduped-v0')
tuned_lens = TunedLens.load("pythia-160m-deduped-v0", map_location=device)
tuned_lens = tuned_lens.to(device)
```



Now let's prepare some interesting text to examine. Here we will use a quote from Tolkien
that has some nice repetition. It is also common enough that it was likely in the training
data, so modifying it will hopefully expose conflicts between the model's parametric
knowledge and its in-context learning.

```python
# The original epigram
input_ids_ring = tokenizer.encode(
    "One Ring to rule them all,\n"
    "One Ring to find them,\n"
    "One Ring to bring them all\n"
    "and in the darkness bind them"
)

# The same text with "Ring" replaced by "Model"
input_ids_model = tokenizer.encode(
    "One Model to rule them all,\n"
    "One Model to find them,\n"
    "One Model to bring them all\n"
    "and in the darkness bind them"
)

# Next-token targets: shift left by one and pad the final position with EOS
targets_ring = input_ids_ring[1:] + [tokenizer.eos_token_id]
targets_model = input_ids_model[1:] + [tokenizer.eos_token_id]
```
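The targets are just the inputs shifted left by one position, so position `i` is scored against token `i + 1`; the last position has no real next token and is paired with the end-of-sequence id. A minimal sketch with made-up token ids (`next_token_targets` and the value `0` for `EOS_ID` are hypothetical; the notebook uses `tokenizer.eos_token_id`):

```python
def next_token_targets(input_ids: list[int], eos_token_id: int) -> list[int]:
    """Pair each position with the token that follows it; pad the end with EOS."""
    return input_ids[1:] + [eos_token_id]

EOS_ID = 0  # hypothetical id; the notebook uses tokenizer.eos_token_id
ids = [11, 22, 33, 44]
targets = next_token_targets(ids, EOS_ID)
print(list(zip(ids, targets)))  # [(11, 22), (22, 33), (33, 44), (44, 0)]
```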

Let's check that the two tokenizations line up, so that each edit is indeed a one-token
substitution.

```python
print(tokenizer.convert_ids_to_tokens(input_ids_ring))
print(tokenizer.convert_ids_to_tokens(input_ids_model))
```

```text
['One', 'ĠRing', 'Ġto', 'Ġrule', 'Ġthem', 'Ġall', ',', 'Ċ', 'One', 'ĠRing', 'Ġto', 'Ġfind', 'Ġthem', ',', 'Ċ', 'One', 'ĠRing', 'Ġto', 'Ġbring', 'Ġthem', 'Ġall', 'Ċ', 'and', 'Ġin', 'Ġthe', 'Ġdarkness', 'Ġbind', 'Ġthem']
['One', 'ĠModel', 'Ġto', 'Ġrule', 'Ġthem', 'Ġall', ',', 'Ċ', 'One', 'ĠModel', 'Ġto', 'Ġfind', 'Ġthem', ',', 'Ċ', 'One', 'ĠModel', 'Ġto', 'Ġbring', 'Ġthem', 'Ġall', 'Ċ', 'and', 'Ġin', 'Ġthe', 'Ġdarkness', 'Ġbind', 'Ġthem']
```
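Rather than comparing the printed lists by eye, the check can be automated: collect every position where the tokens differ and confirm that each one is the `Ring`/`Model` swap. The lists below are copied from the output above (in the notebook you could pass the `convert_ids_to_tokens` results directly), and `substitution_positions` is a hypothetical helper, not part of tuned-lens.

```python
ring_tokens = ['One', 'ĠRing', 'Ġto', 'Ġrule', 'Ġthem', 'Ġall', ',', 'Ċ',
               'One', 'ĠRing', 'Ġto', 'Ġfind', 'Ġthem', ',', 'Ċ',
               'One', 'ĠRing', 'Ġto', 'Ġbring', 'Ġthem', 'Ġall', 'Ċ',
               'and', 'Ġin', 'Ġthe', 'Ġdarkness', 'Ġbind', 'Ġthem']
model_tokens = ['One', 'ĠModel', 'Ġto', 'Ġrule', 'Ġthem', 'Ġall', ',', 'Ċ',
                'One', 'ĠModel', 'Ġto', 'Ġfind', 'Ġthem', ',', 'Ċ',
                'One', 'ĠModel', 'Ġto', 'Ġbring', 'Ġthem', 'Ġall', 'Ċ',
                'and', 'Ġin', 'Ġthe', 'Ġdarkness', 'Ġbind', 'Ġthem']

def substitution_positions(a: list[str], b: list[str]) -> list[int]:
    """Return the indices where two equal-length token lists disagree."""
    if len(a) != len(b):
        raise ValueError("tokenizations have different lengths")
    return [i for i, (x, y) in enumerate(zip(a, b)) if x != y]

diffs = substitution_positions(ring_tokens, model_tokens)
print(diffs)                                              # [1, 9, 16]
print({(ring_tokens[i], model_tokens[i]) for i in diffs})  # {('ĠRing', 'ĠModel')}
```

Raising on a length mismatch matters here: if the substitution had split into several tokens, every later position would shift and a position-by-position comparison of the two trajectories would no longer be meaningful.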


Now let's examine the third line of Tolkien's epigram, which corresponds to the token slice from position 14 to position 21.

```python
predictition_traj_ring = PredictionTrajectory.from_lens_and_model(
    tuned_lens,
    model,
    start_pos=14,
    end_pos=21,
    tokenizer=tokenizer,
    input_ids=input_ids_ring,
    targets=targets_ring,
)
```
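To see which predictions this slice covers, here is a plain-Python sketch over the token strings printed earlier. It assumes `end_pos` is exclusive, like a Python slice; under that assumption the slice holds positions 14 to 20, whose inputs start at the newline ending the second line and whose next-token targets are exactly the third line plus its trailing newline. `EOS` is a placeholder string, not the tokenizer's real token.

```python
tokens = ['One', 'ĠRing', 'Ġto', 'Ġrule', 'Ġthem', 'Ġall', ',', 'Ċ',
          'One', 'ĠRing', 'Ġto', 'Ġfind', 'Ġthem', ',', 'Ċ',
          'One', 'ĠRing', 'Ġto', 'Ġbring', 'Ġthem', 'Ġall', 'Ċ',
          'and', 'Ġin', 'Ġthe', 'Ġdarkness', 'Ġbind', 'Ġthem']
EOS = '<eos>'                   # placeholder for the end-of-sequence token
targets = tokens[1:] + [EOS]    # same left shift as targets_ring

start_pos, end_pos = 14, 21     # assumed half-open: positions 14..20
for pos in range(start_pos, end_pos):
    # Input token at this position -> token the model should predict next
    print(pos, repr(tokens[pos]), '->', repr(targets[pos]))
```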

Now let's visualize the prediction trajectory for this slice of the sequence.
Note that the entire sequence is still fed to the model; we are only visualizing the
prediction trajectory for this particular slice.

```python
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.03,
    subplot_titles=("Entropy", "Forward KL", "Cross Entropy"),
)

fig.add_trace(
    predictition_traj_ring.entropy().heatmap(
        colorbar_y=0.86, colorbar_len=0.3, textfont={'size':10}
    ),
    row=1, col=1
)

fig.add_trace(
    predictition_traj_ring.forward_kl().heatmap(
        colorbar_y=0.5, colorbar_len=0.3, textfont={'size':10}
    ),
    row=2, col=1
)

fig.add_trace(
    predictition_traj_ring.cross_entropy().heatmap(
        colorbar_y=0.15, colorbar_len=0.3, textfont={'size':10}
    ),
    row=3, col=1
)

fig.update_layout(height=800, width=500, title_text="Tolkien's tokens visualized with the Tuned Lens")
fig
```
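Each panel reduces the `(num_layers x sequence_length x vocab_size)` trajectory to a `(num_layers x sequence_length)` grid. The NumPy sketch below shows one way to compute the three quantities from log-probabilities; it is not the library's implementation. It assumes "forward KL" means KL(p_final ‖ q_l), the divergence from the model's final output distribution to each layer's lens distribution, and it uses the last layer of a random toy trajectory as a stand-in for the model's output.

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def entropy(log_q: np.ndarray) -> np.ndarray:
    """Entropy in nats of each layer's distribution: -sum_v q log q."""
    return -(np.exp(log_q) * log_q).sum(axis=-1)

def forward_kl(log_p_final: np.ndarray, log_q: np.ndarray) -> np.ndarray:
    """KL(p_final || q_l) for every layer l; log_p_final broadcasts over layers."""
    p = np.exp(log_p_final)
    return (p * (log_p_final - log_q)).sum(axis=-1)

def cross_entropy(log_q: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """-log q_l(target) at each position, for every layer."""
    seq_len = log_q.shape[1]
    return -log_q[:, np.arange(seq_len), targets]

rng = np.random.default_rng(0)
traj = log_softmax(rng.normal(size=(4, 7, 10)))  # toy (layers, seq, vocab)
targets = rng.integers(0, 10, size=7)            # toy next-token ids
final = traj[-1]                                 # stand-in for the model's output

H, KL, CE = entropy(traj), forward_kl(final, traj), cross_entropy(traj, targets)
print(H.shape, KL.shape, CE.shape)  # (4, 7) (4, 7) (4, 7)
```

Because the toy reference is the trajectory's own last layer, the last row of `KL` is zero by construction; entropy is bounded by `log(vocab_size)`, and cross entropy grows as a layer assigns less probability to the actual next token.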

Okay, now let's look at the prediction trajectory for our modified sequence.

```python
predictition_traj_model = PredictionTrajectory.from_lens_and_model(
    tuned_lens,
    model,
    start_pos=14,
    end_pos=21,
    tokenizer=tokenizer,
    input_ids=input_ids_model,
    targets=targets_model,
)
```

```python
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.03,
    subplot_titles=("Entropy", "Forward KL", "Cross Entropy"),
)

fig.add_trace(
    predictition_traj_model.entropy().heatmap(
        colorbar_y=0.86, colorbar_len=0.3, textfont={'size':10}
    ),
    row=1, col=1
)

fig.add_trace(
    predictition_traj_model.forward_kl().heatmap(
        colorbar_y=0.5, colorbar_len=0.3, textfont={'size':10}
    ),
    row=2, col=1
)

fig.add_trace(
    predictition_traj_model.cross_entropy().heatmap(
        colorbar_y=0.15, colorbar_len=0.3, textfont={'size':10}
    ),
    row=3, col=1
)

fig.update_layout(height=800, width=500, title_text="Intervened tokens visualized with the Tuned Lens")
fig
```
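Beyond viewing the two figures side by side, one simple way to compare trajectories is to ask, for each position, from which layer onward the lens's top-1 guess already matches the final layer's top-1 guess. The sketch below is a hypothetical summary, not a tuned-lens feature: `first_agreeing_layer` and `toy_trajectory` are illustrative helpers, and the two tiny hand-built trajectories merely stand in for the `(num_layers x sequence_length x vocab_size)` log-probabilities behind the two runs.

```python
import numpy as np

def first_agreeing_layer(log_probs: np.ndarray) -> np.ndarray:
    """For each position, the earliest layer from which the top-1 token
    matches the final layer's top-1 token at every later layer.

    log_probs: (num_layers, seq_len, vocab) trajectory
    returns:   (seq_len,) array of layer indices
    """
    top1 = log_probs.argmax(axis=-1)   # (num_layers, seq_len)
    agrees = top1 == top1[-1]          # does each layer match the final guess?
    num_layers, seq_len = agrees.shape
    result = np.empty(seq_len, dtype=int)
    for pos in range(seq_len):
        # Walk backwards from the last layer while agreement is unbroken.
        layer = num_layers - 1
        while layer > 0 and agrees[layer - 1, pos]:
            layer -= 1
        result[pos] = layer
    return result

def toy_trajectory(top1_per_layer: np.ndarray, vocab: int) -> np.ndarray:
    """Hypothetical helper: log-probs whose argmax follows top1_per_layer."""
    num_layers, seq_len = top1_per_layer.shape
    probs = np.full((num_layers, seq_len, vocab), 0.1 / (vocab - 1))
    layer_idx, pos_idx = np.indices((num_layers, seq_len))
    probs[layer_idx, pos_idx, top1_per_layer] = 0.9
    return np.log(probs)

# Rows are layers, columns are positions; entries are each layer's top-1 id.
ring_traj = toy_trajectory(np.array([[2, 0], [1, 0], [1, 0]]), vocab=3)
model_traj = toy_trajectory(np.array([[2, 0], [2, 0], [1, 2]]), vocab=3)

print(first_agreeing_layer(ring_traj))   # [1 0]
print(first_agreeing_layer(model_traj))  # [2 2]
```

On real trajectories, a later agreement layer for the edited prompt would be one possible sign of the tension between memorized text and in-context evidence that motivated the substitution; whether it actually appears is an empirical question this sketch does not settle.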