In [2]:
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model

# Checkpoint name shared by the model and its tokenizer.
model_name = "gpt2-medium"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Load the LM-head model and switch it to evaluation mode in one step;
# `.eval()` returns the module itself, so `model` is still the loaded model.
model = GPT2LMHeadModel.from_pretrained(model_name).eval()

# Short alias for the underlying transformer stack, used throughout the notebook.
t: GPT2Model = model.transformer

In [2]:
# 1. Value Vector Projection

def get_top_k_tokens(rep, k=5):
    """Decode the `k` vocabulary tokens that `rep` scores highest.

    `rep` is multiplied by the token-embedding matrix
    `model.transformer.wte.weight`, the scores are soft-maxed, and the `k`
    most probable token ids are decoded with `tokenizer`.

    Args:
        rep: vector to project; expected to be a 1-D tensor whose length
            matches the embedding width (required by the matmul below).
        k: number of tokens to return. Defaults to 5.

    Returns:
        A list of `k` decoded token strings, most probable first.
    """
    prob = torch.softmax(model.transformer.wte.weight @ rep, dim=-1)
    # Use the caller's `k`; it used to be hard-coded to 5, silently ignoring
    # the argument.
    top_ids = torch.topk(prob, k).indices
    return [tokenizer.decode(token_id) for token_id in top_ids]

def get_value_vector_tokens(layer, dim, k=5):
    """Return the top tokens for one row of a block's MLP output projection.

    Selects `t.h[layer].mlp.c_proj.weight[dim]` and hands it to
    `get_top_k_tokens`.

    Args:
        layer: index of the transformer block in `t.h`.
        dim: row of the `c_proj` weight to inspect.
            NOTE(review): this presumes axis 0 of `c_proj.weight` indexes the
            MLP hidden units (so each row has the embedding width) — confirm
            against the layer implementation in use.
        k: number of tokens to request. Defaults to 5.

    Returns:
        Whatever `get_top_k_tokens` returns for that row (a list of decoded
        token strings).
    """
    value_vector = t.h[layer].mlp.c_proj.weight[dim]
    return get_top_k_tokens(value_vector, k=k)

# Example: row 2940 of block 17's mlp.c_proj weight; the recorded output
# below is dominated by cold/weather-related tokens.
get_value_vector_tokens(layer=17, dim=2940)

['cold', ' colder', ' precipitation', ' frost', 'clone']

In [12]:
def get_top_k_token_indices(rep, k=5):
    """Return the ids of the `k` vocabulary tokens that `rep` scores highest.

    `rep` is multiplied by the token-embedding matrix
    `model.transformer.wte.weight`, the scores are soft-maxed, and the
    indices of the `k` largest probabilities are returned.

    Args:
        rep: vector to project; expected to be a 1-D tensor whose length
            matches the embedding width (required by the matmul below).
        k: number of token ids to return. Defaults to 5.

    Returns:
        A 1-D integer tensor of `k` token ids, most probable first.
    """
    prob = torch.softmax(model.transformer.wte.weight @ rep, dim=-1)
    # Use the caller's `k`; it used to be hard-coded to 5, silently ignoring
    # the argument.
    return torch.topk(prob, k).indices

In [49]:
# 2. Before and After Layer Representation
import numpy as np

# Total number of sequence positions tracked (prompt plus generated tokens).
max_tokens = 20

# rep[position, layer] receives the top-1 token id read out after that layer
# for that position. Ids are stored as floats; rows the hook never writes
# stay at zero.
rep = np.zeros(shape=(max_tokens, len(t.h)), dtype=np.float64)


def hook(_, args, output, idx):
    """Forward hook: record the top-1 token read out of one block's output.

    Takes the hidden state at the last sequence position, applies the final
    layer norm `t.ln_f`, ranks the vocabulary with `get_top_k_token_indices`
    and stores the best token id in `rep[seq_len, idx]`, where `seq_len` is
    the current sequence length (the slot the generation loop fills next).

    Args:
        _: the hooked module (unused).
        args: positional inputs of the hooked module (unused).
        output: the module's output; either a tuple/list whose first element
            is the hidden-state tensor, or that tensor itself.
        idx: column of `rep` to write, i.e. the index of the hooked block.

    Returns:
        None, so the module's output is passed on unchanged.
    """
    # Some implementations return the hidden states directly instead of
    # wrapping them in a tuple; indexing a bare tensor with [0] would drop
    # the batch axis and corrupt both the position and the vector below.
    hidden = output[0] if isinstance(output, (tuple, list)) else output

    # The indexing below requires a rank-3 tensor; presumably
    # (batch, seq, dim) — only batch element 0 is inspected.
    token_idx = hidden.shape[1]
    last_vec = hidden[0, -1, :]

    # Only the best token is stored, so ask for a single candidate.
    top_ids = get_top_k_token_indices(t.ln_f(last_vec), k=1)
    rep[token_idx][idx] = int(top_ids[0])
          

# Register one forward hook per transformer block, run greedy decoding, and
# always detach the hooks afterwards -- including when the cell is interrupted
# or a step fails -- so stale hooks never accumulate on the model between runs.
hooks = []
try:
    for layer_idx, block in enumerate(t.h):
        # Bind the index as a default argument so each hook keeps its own
        # value rather than the loop variable's final one.
        handle = block.register_forward_hook(
            lambda module, args, output, idx=layer_idx: hook(module, args, output, idx)
        )
        hooks.append(handle)

    # Named `input_ids` rather than `input` to avoid shadowing the builtin
    # for the rest of the kernel session.
    input_ids = tokenizer.encode("My wife is working as a", return_tensors="pt")
    prompt_len = input_ids.shape[1]
    if prompt_len > max_tokens:
        raise ValueError(
            f"prompt has {prompt_len} tokens but max_tokens is {max_tokens}"
        )

    # output_ids[pos] is the token id at each position: prompt ids first,
    # then the greedily generated ids.
    output_ids = np.zeros(max_tokens)
    output_ids[:prompt_len] = input_ids[0].numpy()

    for position in range(prompt_len, max_tokens):
        # Each forward pass re-runs the whole sequence, firing every hook once.
        with torch.no_grad():
            logits = model(input_ids).logits
        # Greedy decoding: only the last position's prediction is needed.
        next_id = logits[0, -1].argmax()
        input_ids = torch.cat([input_ids, next_id.reshape(1, 1)], dim=1)
        output_ids[position] = next_id.item()

    print(tokenizer.decode(input_ids[0]))
finally:
    # Remove the hooks
    for handle in hooks:
        handle.remove()


import pandas as pd
# A layer "has finished" for a position when its top-1 token id equals the id
# that was finally emitted there. Compare exactly instead of taking the argmin
# of a squared id distance, which reported layer 0 for rows that were never
# probed and would pick the numerically nearest id when no layer matched.
matches = rep == output_ids.reshape(-1, 1)

# First matching layer per position; -1 marks positions with no match, which
# is normally the case for prompt tokens because the hook never fills them.
finished = np.where(matches.any(axis=1), matches.argmax(axis=1), -1)

text = [tokenizer.decode(int(token_id)) for token_id in output_ids]
pd.DataFrame({"token": text, "layer": finished})

My wife is working as a nurse and I'm a teacher. We have a small house and we
