In [248]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model

# Load the pre-trained checkpoint named by model_name together with its
# matching tokenizer; both objects are reused by the cells below.
model_name = "gpt2-medium"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Put the model in evaluation mode. Being the last expression of the cell,
# this call also echoes the module tree as the cell output.
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)

In [249]:
# Value Vector Projection

def get_top_k_tokens(rep, k=5):
    """Project a hidden-size vector onto the vocabulary and decode the best tokens.

    The vector is scored against every row of the token-embedding matrix
    (``model.transformer.wte.weight``), the scores are soft-maxed, and the
    ``k`` highest-scoring token ids are decoded with the module-level
    ``tokenizer``.

    Args:
        rep: Tensor multiplied on the right of the embedding matrix; the
            callers in this notebook pass a single 1-D hidden vector.
        k: Number of tokens to return. Defaults to 5.

    Returns:
        A list of ``k`` decoded token strings, best-scoring first.
    """
    # Only decoded strings leave this function, so no autograd graph is needed
    # even when `rep` is a slice of a trainable weight.
    with torch.no_grad():
        prob = torch.softmax(model.transformer.wte.weight @ rep, dim=-1)
        # Honour the caller's k (the count used to be hard-coded to 5).
        indices = torch.topk(prob, k).indices
    return [tokenizer.decode(i) for i in indices]

def get_value_vector_tokens(layer, dim):
    """Decode the top tokens for one row of an MLP output projection.

    Selects row ``dim`` of ``mlp.c_proj.weight`` in transformer block
    ``layer`` and passes it to ``get_top_k_tokens``.

    Args:
        layer: Index into ``model.transformer.h``.
        dim: Row index into that block's ``mlp.c_proj.weight``.

    Returns:
        Whatever ``get_top_k_tokens`` returns for that row.
    """
    # Go through `model` directly: the `t` shorthand is only created in a
    # later cell, so relying on it breaks a fresh top-to-bottom run.
    value_vector = model.transformer.h[layer].mlp.c_proj.weight[dim]
    return get_top_k_tokens(value_vector)

# Example: tokens best matching row 366 of the first block's MLP output projection.
get_value_vector_tokens(layer=0, dim=366)

[' behav', 'ngth', 'EStreamFrame', ' disg', 'ften']

In [251]:
# Before and After Layer Representation

# Shorthand for the transformer submodule under the LM head; the code in
# this cell reads its blocks (t.h) and final LayerNorm (t.ln_f).
t: GPT2Model = model.transformer

def hook(_, args, output, idx):
    """Forward hook that prints the top tokens entering and leaving a block.

    Args:
        _: The hooked module (unused).
        args: Positional inputs of the forward call; ``args[0]`` is indexed
            as ``[0, -1, :]`` to pick the vector to decode.
        output: Forward result; ``output[0]`` is indexed the same way.
        idx: Layer number shown in the printed header.

    Returns:
        None, so the hooked module's output is left untouched.
    """
    print(f"\n------Layer {idx}------")
    # First batch element, last sequence position, on both sides of the block.
    entering, leaving = args[0][0, -1, :], output[0][0, -1, :]

    def top_tokens(hidden):
        # Apply the final LayerNorm before projecting onto the vocabulary.
        return get_top_k_tokens(t.ln_f(hidden))

    print(f"Input: {top_tokens(entering)}")
    print(f"Output: {top_tokens(leaving)}")

# Attach one forward hook per transformer block and keep every handle so the
# hooks can be detached again after the forward pass.
hooks = []
for layer_idx, block in enumerate(t.h):

    def _report(module, args, output, idx=layer_idx):
        # The default argument freezes this iteration's index; a plain
        # closure would see only the final value of the loop variable.
        return hook(module, args, output, idx)

    hooks.append(block.register_forward_hook(_report))
    
# Run one forward pass with the hooks attached, then always detach them.
# `finally` (rather than `except Exception`) guarantees removal even when the
# run is interrupted from the keyboard, so a re-run never prints each layer
# twice, and real errors surface with a full traceback instead of one line.
try:
    # Run the model to get outputs and capture intermediate representations
    # NOTE(review): `input` shadows the builtin; the name is kept because
    # cells outside this view may still read it.
    input = tokenizer.encode("My wife is working as a", return_tensors="pt")
    with torch.no_grad():
        outputs = model(input)
    logits = outputs.logits
    # Greedy choice at every position; only the last position is decoded.
    generated_ids = torch.argmax(logits, dim=-1)
    generated_text = tokenizer.decode(generated_ids[0][-1])
    print(f"\nGenerated next token: {generated_text}")
finally:
    # Remove the hooks
    for h in hooks:
        h.remove()


------Layer 0------
Input: [' unden', ' helicop', ' streng', ' enthusi', ' notor']
Output: [' completely', ' "', ' fully', ' particularly', ' certain']

------Layer 1------
Input: [' completely', ' "', ' fully', ' particularly', ' certain']
Output: [' particularly', ' "', ' single', ' completely', ' fully']

------Layer 2------
Input: [' particularly', ' "', ' single', ' completely', ' fully']
Output: [' particularly', ' single', ' "', ' piece', ' very']

------Layer 3------
Input: [' particularly', ' single', ' "', ' piece', ' very']
Output: [' separate', ' very', ' single', ' well', ' particularly']

------Layer 4------
Input: [' separate', ' very', ' single', ' well', ' particularly']
Output: [' separate', ' member', ' part', ' single', ' well']

------Layer 5------
Input: [' separate', ' member', ' part', ' single', ' well']
Output: [' member', ' part', ' separate', ' host', ' single']

------Layer 6------
Input: [' member', ' part', ' separate', ' host', ' single']
Output: [' par