In [29]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model

# Checkpoint shared by the tokenizer and the language model.
model_name = "gpt2-medium"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Load the LM-head model and switch it to evaluation mode in one step
# (nn.Module.eval() returns the module itself).
model = GPT2LMHeadModel.from_pretrained(model_name).eval()

# Short alias for the transformer body; the cells below index t.h[...] and call t.ln_f.
t: GPT2Model = model.transformer

In [30]:
# Value Vector Projection

def get_top_k_tokens(rep, k=5):
    """Decode the k vocabulary tokens that a hidden-space vector scores highest.

    The vector is multiplied by the token-embedding matrix
    (``model.transformer.wte.weight``), the scores are soft-maxed, and the
    ``k`` largest entries are decoded with the global ``tokenizer``.

    Args:
        rep: 1-D tensor whose length matches the second dimension of the
            embedding matrix (required by the ``weight @ rep`` product).
        k: Number of tokens to return. Defaults to 5.

    Returns:
        A list of ``k`` decoded token strings, highest score first.
    """
    prob = torch.softmax(model.transformer.wte.weight @ rep, dim=-1)
    # Honour the caller's k; the previous version always asked topk for 5.
    _, indices = torch.topk(prob, k)
    return [tokenizer.decode(i) for i in indices]

def get_value_vector_tokens(layer, dim):
    """Return the top tokens for one row of a block's MLP output projection.

    Args:
        layer: Index into ``t.h`` selecting the transformer block.
        dim: Row index into that block's ``mlp.c_proj.weight``.

    Returns:
        Whatever ``get_top_k_tokens`` returns for the selected row
        (called with its default ``k``).
    """
    value_vector = t.h[layer].mlp.c_proj.weight[dim]
    return get_top_k_tokens(value_vector)

# Show the tokens that row 2940 of block 17's MLP output projection scores highest.
get_value_vector_tokens(layer=17, dim=2940)

['cold', ' colder', ' precipitation', ' frost', 'clone']

In [10]:
# Display block 17's MLP c_fc module to check its dimensions.
t.h[17].mlp.c_fc

Conv1D(nf=4096, nx=1024)

In [60]:
# Display the activation module used by block 0's MLP.
t.h[0].mlp.act

NewGELUActivation()

In [76]:
# This cell does three things on a single forward pass:
#   1. Before/after layer representation: top tokens of each block's input and output.
#   2. Dominant sub-updates: the MLP dimensions with the largest norm-scaled coefficients.
#   3. Intervention: optionally overwrite selected c_fc output dimensions.

# `interventions` maps a layer index to the list of c_fc output dimensions to
# overwrite at that layer. Example (disabled):
# interventions = {
#     10: [3141],
#     17: [115]
# }

# Empty mapping: no layer is intervened on.
interventions = {}

def hook(_, args, output, idx):
    """Forward hook: print the top tokens of a block's input and output.

    Takes the last-position vector of batch element 0 from ``args[0]`` and
    from ``output[0]``, passes each through ``t.ln_f``, and prints the tokens
    ``get_top_k_tokens`` returns for it.

    Args:
        _: The hooked module (unused).
        args: Positional inputs of the module; ``args[0]`` is indexed as
            ``[0, -1, :]``.
        output: Module output; ``output[0]`` is indexed as ``[0, -1, :]``.
        idx: Layer index supplied by the registering lambda (unused here).
    """
    # Slice both vectors up front, before anything is printed.
    state_before, state_after = args[0][0, -1, :], output[0][0, -1, :]
    for label, state in (("Input", state_before), ("Output", state_after)):
        print(f"{label}: {get_top_k_tokens(t.ln_f(state))}")
    
def proj_hook(module, args, output, idx, top_n=10):
    """Forward hook: print the dominant sub-updates of an MLP projection.

    For the last position of batch element 0, each input coefficient's
    absolute value is multiplied by the L2 norm of the matching row of
    ``module.weight``; the ``top_n`` largest products are printed as
    ``L<idx>D<dim>: <value>``.

    Args:
        module: The hooked module; its ``weight`` rows are the value vectors.
        args: Positional inputs of the module; ``args[0]`` is indexed as
            ``[0, -1, :]``.
        output: Module output (unused).
        idx: Layer index, used in the printed labels.
        top_n: How many entries to report. Defaults to 10.
    """
    coeff_vec = args[0][0, -1, :]
    value_norms = torch.linalg.norm(module.weight.data, dim=1)
    scaled_coefs = torch.absolute(coeff_vec) * value_norms

    # intervene_hook prints the layer header for layers listed in
    # `interventions`, so only print it here for the remaining layers.
    if idx not in interventions:
        print(f"\n------Layer {idx}------")

    # torch.topk selects and orders the largest entries in one call instead of
    # sorting every dimension as a Python list of 0-d tensors.
    k = min(top_n, scaled_coefs.numel())
    top_vals, top_dims = torch.topk(scaled_coefs, k)
    subupdates = [
        f"L{idx}D{dim}: {val:.2f}"
        for dim, val in zip(top_dims.tolist(), top_vals.tolist())
    ]
    print(f"Dominant sub updates: {subupdates}")
    
def intervene_hook(module, args, output, idx):
    """Forward hook: overwrite selected output dimensions at intervened layers.

    Does nothing (returns ``None``) unless ``idx`` is a key of the global
    ``interventions`` mapping. Otherwise, for the last position of batch
    element 0, it computes ``|output| * row-norm of t.h[idx].mlp.c_proj.weight``
    per dimension, takes the maximum, and writes that maximum into every
    dimension listed in ``interventions[idx]``.

    Args:
        module: The hooked module (unused; the projection is looked up via ``t``).
        args: Positional inputs of the module (unused).
        output: Module output, indexed as ``[0, -1, :]`` and modified in place.
        idx: Layer index supplied by the registering lambda.

    Returns:
        The modified ``output`` tensor for intervened layers, otherwise ``None``.
    """
    if idx not in interventions:
        return
    target_dims = interventions[idx]

    print(f"\n------Layer {idx}------")
    # Build the labels outside the f-string: reusing the same quote character
    # inside a nested f-string is a SyntaxError before Python 3.12.
    labels = [f"L{idx}D{dim}" for dim in target_dims]
    print(f"Intervention(s) at layer {idx}: {labels}")

    # The coefficients are used exactly as the hooked module emitted them;
    # t.h[idx].mlp.act is not applied to them here.
    coeff_vec = output[0, -1, :]
    c_proj = t.h[idx].mlp.c_proj
    value_norms = torch.linalg.norm(c_proj.weight.data, dim=1)
    scaled_coefs = torch.absolute(coeff_vec) * value_norms

    # NOTE(review): the injected value is the largest norm-scaled magnitude,
    # written into a raw coefficient slot - confirm this is the intended scale.
    max_coeff = torch.max(scaled_coefs)
    for dim in target_dims:
        output[0, -1, dim] = max_coeff
    return output
        
        

# Handles of every hook registered below, so they can all be detached again.
hooks = []
try:
    # Register three hooks per transformer block. `idx=i` binds the layer index
    # at lambda creation time so each hook reports its own layer.
    for i, layer in enumerate(t.h[:]):
        # Block-level hook: top tokens of the block's input and output.
        h1 = layer.register_forward_hook(
            lambda module, args, output, idx=i: hook(module, args, output, idx)
        )
        # c_proj hook: dominant sub-updates of the MLP.
        h2 = layer.mlp.c_proj.register_forward_hook(
            lambda module, args, output, idx=i: proj_hook(module, args, output, idx)
        )
        # c_fc hook: optional intervention on the MLP coefficients.
        h3 = layer.mlp.c_fc.register_forward_hook(
            lambda module, args, output, idx=i: intervene_hook(module, args, output, idx)
        )

        hooks.append(h1)
        hooks.append(h2)
        hooks.append(h3)

    try:
        # Run the model to get outputs and capture intermediate representations.
        # NOTE: `input` shadows the builtin; the name is kept in case later
        # cells read it.
        input = tokenizer.encode("My wife is working as a", return_tensors="pt")
        with torch.no_grad():
            outputs = model(input)
        logits = outputs.logits
        generated_ids = torch.argmax(logits, dim=-1)
        # Greedy prediction at the last position of the first sequence.
        generated_text = tokenizer.decode(generated_ids[0][-1])
        print(f"\nGenerated next token: {generated_text}")
    except Exception as e:
        # Best-effort run: report the failure (with its type) and carry on.
        print(f"{type(e).__name__}: {e}")
finally:
    # Always detach the hooks, including on KeyboardInterrupt or a failure
    # during registration; otherwise they stay on the model and stack up the
    # next time this cell is run.
    for h in hooks:
        h.remove()


------Layer 0------
Dominant sub updates: ['L0D366: 21.29', 'L0D1198: 19.19', 'L0D4055: 16.59', 'L0D798: 13.89', 'L0D1254: 12.90', 'L0D284: 10.89', 'L0D2121: 9.05', 'L0D3969: 7.66', 'L0D1619: 7.31', 'L0D2938: 6.82']
Input: [' unden', ' helicop', ' streng', ' enthusi', ' notor']
Output: [' completely', ' "', ' fully', ' particularly', ' certain']

------Layer 1------
Dominant sub updates: ['L1D3460: 9.68', 'L1D736: 7.85', 'L1D51: 4.03', 'L1D676: 3.57', 'L1D1922: 3.43', 'L1D1091: 3.00', 'L1D2945: 2.57', 'L1D2023: 2.36', 'L1D3026: 2.22', 'L1D3205: 1.98']
Input: [' completely', ' "', ' fully', ' particularly', ' certain']
Output: [' particularly', ' "', ' single', ' completely', ' fully']

------Layer 2------
Dominant sub updates: ['L2D609: 2.94', 'L2D2718: 2.57', 'L2D2520: 2.20', 'L2D3131: 2.19', 'L2D3524: 2.00', 'L2D3857: 1.71', 'L2D2102: 1.18', 'L2D2107: 1.12', 'L2D1651: 1.07', 'L2D1789: 1.01']
Input: [' particularly', ' "', ' single', ' completely', ' fully']
Output: [' particularly',