In [64]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model

# Checkpoint shared by the tokenizer and the language model
model_name = "gpt2-medium"

# Fetch the tokenizer and the LM-head model for that checkpoint.
# `.eval()` returns the module itself, so the model is switched to
# evaluation mode in the same statement that binds it.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name).eval()

# Shorthand for the transformer stack wrapped by the LM head
t: GPT2Model = model.transformer

In [65]:
# Value Vector Projection

def get_top_k_tokens(rep, k=5):
    """Decode the ``k`` vocabulary entries that score highest against ``rep``.

    ``rep`` is multiplied by the token-embedding matrix
    (``model.transformer.wte.weight``), the scores are soft-maxed, and the
    ``k`` most probable token ids are decoded with ``tokenizer``.

    Args:
        rep: 1-D vector to project onto the vocabulary
            (assumed to have the embedding width -- TODO confirm at call sites).
        k: Number of tokens to return. Clamped to the vocabulary size so an
            oversized request does not raise.

    Returns:
        List of decoded token strings, most probable first.
    """
    # Pure inspection: avoid building an autograd graph from the embedding parameter.
    with torch.no_grad():
        prob = torch.softmax(model.transformer.wte.weight @ rep, dim=-1)
        # Honour the caller's k (previously hard-coded to 5).
        k = min(k, prob.shape[-1])
        _, indices = torch.topk(prob, k)
    return [tokenizer.decode(i) for i in indices]

def get_value_vector_tokens(layer, dim):
    """Return the top tokens for one row of a layer's MLP output projection.

    Args:
        layer: Index into ``t.h`` selecting the transformer block.
        dim: Row index into that block's ``mlp.c_proj.weight``.

    Returns:
        Whatever ``get_top_k_tokens`` returns for the selected row
        (called with its default ``k``).
    """
    projection_weight = t.h[layer].mlp.c_proj.weight
    value_vector = projection_weight[dim]
    return get_top_k_tokens(value_vector)

# Inspect a single value vector: row 2940 of layer 17's MLP output projection.
get_value_vector_tokens(layer=17, dim=2940)

['cold', ' colder', ' precipitation', ' frost', 'clone']

In [75]:
# Before and After Layer Representation

def hook(_, args, output, idx):
    """Forward hook: print the top tokens before and after a layer.

    Takes the vector at index ``[0, -1, :]`` of the layer's first positional
    input and of ``output[0]`` (presumably batch item 0, last sequence
    position -- confirm against the hooked module), passes each through
    ``t.ln_f`` and prints the tokens ``get_top_k_tokens`` reports for it.

    Args:
        _: The hooked module (unused).
        args: Positional inputs of the module's forward call.
        output: Value returned by the module's forward call.
        idx: Layer index supplied by the registration lambda (unused here).
    """
    before = args[0][0, -1, :]
    after = output[0][0, -1, :]

    before_tokens = get_top_k_tokens(t.ln_f(before))
    print(f"Input: {before_tokens}")

    after_tokens = get_top_k_tokens(t.ln_f(after))
    print(f"Output: {after_tokens}")
    
def proj_hook(module, args, output, idx):
    """Forward hook: print the ten largest scaled coefficients feeding a projection.

    The coefficient vector is read at index ``[0, -1, :]`` of the module's
    first positional input (presumably batch item 0, last sequence position
    -- confirm against the hooked module). Each coefficient's absolute value
    is multiplied by the norm of the matching row of ``module.weight``, and
    the ten largest products are printed as ``(row_index, value)`` pairs in
    descending order.

    Args:
        module: Hooked module; only ``module.weight`` is read.
        args: Positional inputs of the module's forward call.
        output: Value returned by the module's forward call (unused).
        idx: Layer index shown in the printed header.
    """
    coeff_vec = args[0][0, -1, :]
    # One norm per weight row, i.e. per coefficient.
    value_norms = torch.linalg.norm(module.weight.data, dim=1)
    scaled_coefs = torch.absolute(coeff_vec) * value_norms
    print(f"\n------Layer {idx}------")
    # topk returns values sorted descending; this avoids materialising and
    # sorting a Python list of every 0-d tensor just to keep ten of them.
    top_k = min(10, scaled_coefs.numel())
    top_values, top_indices = torch.topk(scaled_coefs, top_k)
    # Iterating the values tensor yields 0-d tensors, so the printed form is
    # unchanged: [(index, tensor(value)), ...].
    subupdates = list(zip(top_indices.tolist(), top_values))
    print(f"Dominant sub updates: {subupdates}")

hooks = []
try:
    # Attach two forward hooks per transformer block: one on the block itself
    # (prints the representation before/after the block) and one on its MLP
    # output projection (prints the dominant sub-updates).
    # `idx=i` binds the current layer index at definition time; a plain
    # closure over `i` would see only its final value.
    for i, layer in enumerate(t.h):
        h1 = layer.register_forward_hook(
            lambda module, args, output, idx=i: hook(module, args, output, idx)
        )
        hooks.append(h1)
        h2 = layer.mlp.c_proj.register_forward_hook(
            lambda module, args, output, idx=i: proj_hook(module, args, output, idx)
        )
        hooks.append(h2)

    try:
        # Run the model to get outputs and capture intermediate representations
        # NOTE(review): `input` shadows the builtin; the name is kept because
        # other cells may read it.
        input = tokenizer.encode("My wife is working as a", return_tensors="pt")
        with torch.no_grad():
            outputs = model(input)
        logits = outputs.logits
        generated_ids = torch.argmax(logits, dim=-1)
        # Greedy prediction at the last position of the first sequence
        generated_text = tokenizer.decode(generated_ids[0][-1])
        print(f"\nGenerated next token: {generated_text}")
    except Exception as e:
        # Best-effort run: report the failure and fall through to cleanup.
        print(e)
finally:
    # Always remove the hooks -- including on KeyboardInterrupt or a failure
    # part-way through registration -- so re-running this cell does not stack
    # duplicate hooks on the model.
    for h in hooks:
        h.remove()


------Layer 0------
Dominant sub updates: [(366, tensor(21.2909)), (1198, tensor(19.1944)), (4055, tensor(16.5859)), (798, tensor(13.8921)), (1254, tensor(12.9030)), (284, tensor(10.8873)), (2121, tensor(9.0504)), (3969, tensor(7.6568)), (1619, tensor(7.3107)), (2938, tensor(6.8245))]
Input: [' unden', ' helicop', ' streng', ' enthusi', ' notor']
Output: [' completely', ' "', ' fully', ' particularly', ' certain']

------Layer 1------
Dominant sub updates: [(3460, tensor(9.6827)), (736, tensor(7.8470)), (51, tensor(4.0347)), (676, tensor(3.5699)), (1922, tensor(3.4282)), (1091, tensor(2.9955)), (2945, tensor(2.5733)), (2023, tensor(2.3635)), (3026, tensor(2.2193)), (3205, tensor(1.9815))]
Input: [' completely', ' "', ' fully', ' particularly', ' certain']
Output: [' particularly', ' "', ' single', ' completely', ' fully']

------Layer 2------
Dominant sub updates: [(609, tensor(2.9426)), (2718, tensor(2.5653)), (2520, tensor(2.2043)), (3131, tensor(2.1944)), (3524, tensor(2.0022)), (3