## Logit lens

Trying logit lens

In [1]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch

import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer, FactoredMatrix
import einops

import tqdm.auto as tqdm
from functools import partial

In [2]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
def vis_attn_patterns(model, text, layers, compact=True):
    str_tokens = model.to_str_tokens(text)
    logits, cache = model.run_with_cache(text, remove_batch_dim=True)

    if compact:
        for layer in layers:
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))
    
    else:
        for layer in layers:
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern))

In [4]:
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
text = "The capital of France is Paris and has a population of 12"
answer = " million"

logits, cache = model.run_with_cache(text)
print(logits.shape)

torch.Size([1, 13, 50257])


In [6]:
model.generate(text, max_new_tokens=10,do_sample=False)

  0%|          | 0/10 [00:00<?, ?it/s]

'The capital of France is Paris and has a population of 12 million.\n\nThe capital of France is Paris'

In [7]:
acc_resid_stack = cache.accumulated_resid()
scaled_resid_stack = cache.apply_ln_to_stack(acc_resid_stack)
print(scaled_resid_stack.shape)
Unembedding_matrix = model.W_U
#print(Unembedding_matrix.shape)
#n_layers, _, n_dim = scaled_resid_stack.shape

logit_lens_final_logits = einops.einsum(scaled_resid_stack, Unembedding_matrix,
                                        "n_layer ... pos n_dim, n_dim vocab -> n_layer pos vocab")
print(logit_lens_final_logits.shape)

torch.Size([13, 1, 13, 768])
torch.Size([13, 13, 50257])


In [8]:
print(len(logit_lens_final_logits))
logits_final = []
inter_tokens = []
indices_all = []
for layer in range(model.cfg.n_layers):
    print(f'Layer {layer}:')
    probs = logit_lens_final_logits[layer].softmax(dim=-1)
    #print(probs.shape)
    top_prob, indices = probs.topk(1)
    #print(top_prob)
    indices_all.append(indices)
    logits_final.append([logit_lens_final_logits[layer, i, index].item() for i, index in enumerate(indices)])
    #print(indices)
    top_tokens = [model.to_string(index.item()) for index in indices]
    inter_tokens.append(top_tokens)
    for token, prob in zip(top_tokens, top_prob):
        print(f'{token}: {prob.item():.4f}  ', end='')
    print()

13
Layer 0:
���: 0.0000  theless: 0.0000   capital: 0.0001   livest: 0.0000   France: 0.0000   destro: 0.0000   Paris: 0.0001   mathemat: 0.0000   destro: 0.0000   destro: 0.0000   population: 0.0000   destro: 0.0000   destro: 0.0000  
Layer 1:
Loading: 0.0006  ories: 0.0001   capital: 0.0007   course: 0.0001   Marse: 0.0002   unlikely: 0.0001   Paris: 0.0006   then: 0.0001   been: 0.0001   few: 0.0001   population: 0.0002   course: 0.0001  th: 0.0002  
Layer 2:

: 0.0038  ories: 0.0001   capital: 0.0011   course: 0.0001   Connection: 0.0002   unlikely: 0.0001   Hilton: 0.0010   then: 0.0001   been: 0.0002   few: 0.0001   population: 0.0002   course: 0.0001  th: 0.0003  
Layer 3:

: 0.0768  ories: 0.0001   capital: 0.0020   course: 0.0001   Connection: 0.0003   now: 0.0001   Hilton: 0.0019   then: 0.0001   been: 0.0002   huge: 0.0002   explosion: 0.0002   course: 0.0001  th: 0.0003  
Layer 4:

: 0.1057  resa: 0.0001   capital: 0.0023   Hope: 0.0002   Alps: 0.0004   now: 0.0002   Hilton

In [None]:
logit_diff_results = torch.zeros(model.cfg.n_layers, len(model.to_str_tokens(text)))
#print(logit_diff_results.shape)
print((model.to_tokens(text)[1:]))
answer_tokens = model.to_tokens(text).tolist()[0][1:] + model.to_tokens(answer).tolist()[0][1:]
answers_str_tokens = [f'{model.to_string(answer)}_{ids}' for ids, answer in enumerate(answer_tokens)]
print(answers_str_tokens)
print(answer_tokens)
for layer in range(model.cfg.n_layers):
    #print(layer)
    for position in range(len(logit_lens_final_logits[0])):
        logit_diff_results[layer][position] = logits[0, position, answer_tokens[position]] - logit_lens_final_logits[layer, position, indices_all[layer][position][0]]
logits_final = np.array(logits_final)
answers_str_tokens = np.array(answers_str_tokens)
print(len(answers_str_tokens))
print(answers_str_tokens)
print(logits_final.shape)
fig = px.imshow(logits_final[:,1:], labels=dict(x="Tokens", y="Layers", color="Logits"),x=answers_str_tokens[1:], aspect="auto")
fig.update_traces(text=inter_tokens, texttemplate="%{text}")

tensor([], size=(0, 13), dtype=torch.int64)
['The_0', ' capital_1', ' of_2', ' France_3', ' is_4', ' Paris_5', ' and_6', ' has_7', ' a_8', ' population_9', ' of_10', ' 12_11', ' million_12']
[464, 3139, 286, 4881, 318, 6342, 290, 468, 257, 3265, 286, 1105, 1510]
13
['The_0' ' capital_1' ' of_2' ' France_3' ' is_4' ' Paris_5' ' and_6'
 ' has_7' ' a_8' ' population_9' ' of_10' ' 12_11' ' million_12']
(12, 13)
----
Tokens
----


: 