In [1]:
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model

# Load pre-trained model and tokenizer
model_name = "gpt2-medium"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Set the model to evaluation mode
model.eval()

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.mps.is_available():
    device = 'mps'

print("Device: " + device)

model.to(device)

Device: mps


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)

In [14]:
import pandas as pd
import matplotlib.pyplot as plt

input = open("women_in_tech.txt", "r").read()
input_ids = tokenizer.encode(input, return_tensors="pt").to(device)
target_ids = input_ids.clone()[0, 1:]
input_ids = input_ids[:, :-1]

def get_top_token_ids(rep):
    prob = torch.softmax(rep @ model.transformer.wte.weight.T, dim=-1)
    indices = torch.argmax(prob, dim=-1)
    return indices

num_tokens = input_ids.shape[1]
num_layers = len(model.transformer.h)
rep = torch.zeros((num_layers, num_tokens), device=device, dtype=torch.int64)
def hook(_, args, output, idx):
    token_idx = output[0].shape[1]
    output_vec = output[0][0,:,:]
    rep[idx] = get_top_token_ids(output_vec)
    
hooks = []
for i, h in enumerate(model.transformer.h):
    hk = h.register_forward_hook(lambda module, args, output, idx=i: hook(module, args, output, idx))
    hooks.append(hk)
    
try:
    logits = model(input_ids).logits[0]
except Exception as e:
    print("Error in model call: ",e)
    
for h in hooks:
    h.remove()
    
output_ids = torch.argmax(logits, dim=-1)
stablised_layer = torch.argmin((rep - output_ids) ** 2, dim=0).to('cpu')
output = [tokenizer.decode(int(i)) for i in output_ids]
ip = [tokenizer.decode(int(i)) for i in input_ids[0]]
target = [tokenizer.decode(int(i)) for i in target_ids]
pd.DataFrame({"input": ip,"output": output, "layer": stablised_layer, "target": target})
# pd.DataFrame({'layer':stablised_layer}).hist()

Unnamed: 0,input,output,layer,target
0,My,friend,0,wife
1,wife,and,7,is
2,is,a,15,working
3,working,on,19,at
4,at,a,13,a
...,...,...,...,...
281,learning,.,21,and
282,and,machine,14,natural
283,natural,language,5,language
284,language,processing,12,processing
