In [1]:
import transformer_lens
import torch
import pandas as pd

## Loading the Annotated Dataset

In [2]:
# Read the annotated word list and preview its first rows.
annotated_csv_path = 'gender_data.csv'
gendered_data = pd.read_csv(annotated_csv_path)
gendered_data.head()

Unnamed: 0,Word,Gender,Category
0,Abigail,Female,Name
1,Adeline,Female,Name
2,Admiral,Male,Common Noun
3,Adventuress,Female,Common Noun
4,Alethea,Female,Name


In [3]:
# Dataset dimensions as (rows, columns).
gendered_data.shape

(312, 3)

## Loading Pre-trained Model

In [4]:
# Load the pre-trained "tiny-stories-1L-21M" model as a TransformerLens HookedTransformer.
model_name = "tiny-stories-1L-21M"  # name handed to from_pretrained; change it to load a different model
model = transformer_lens.HookedTransformer.from_pretrained(model_name)



Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


In [5]:
# Sample a handful of short completions of a fixed prompt as a quick sanity check
# that the model loaded correctly. Feel free to play with the prompt to make it
# more interesting.
NUM_COMPLETIONS = 5
SAMPLING_TEMPERATURE = 0.5
MAX_NEW_TOKENS = 50

# Generation samples from the model's distribution; seeding torch's RNG makes the
# printed stories repeatable across runs on the same setup.
torch.manual_seed(0)

for _ in range(NUM_COMPLETIONS):
    display(
        model.generate(
            "Once upon a time",
            stop_at_eos=False,  # avoids a bug on MPS
            temperature=SAMPLING_TEMPERATURE,
            verbose=False,
            max_new_tokens=MAX_NEW_TOKENS,
        )
    )

'Once upon a time, there was a girl named Lucy. She was very pretty and loved to play with her toys. One day, she went to the park with her mommy.\n\nAt the park, Lucy saw a big, tall tree. She wanted to'

'Once upon a time, there was a little girl named Lucy. She was three years old and loved to explore. One day, she went for a walk in the woods.\n\nAs she was walking, she saw something sparkly in the grass. It was a'

"Once upon a time, there was a boy named Sam. Sam was three years old and loved to play with his toys. Every day he would go to his room and do his homework.\n\nOne day, Sam's mom asked him to clean his room. Sam"

'Once upon a time, there was a girl named Lucy. She was three years old and loved to explore. One day, she decided to go on a journey.\n\nAs she walked through the forest, she saw something that made her stop and look. It was'

'Once upon a time, there was a big bird. He was a very happy bird and he loved to fly. One day, he saw a big, juicy apple on the ground. He wanted to eat it, but it was too high up. So, he flew'

## Tokenizing the prompts and Setting Up Hooks

In [6]:
# Tokenize every word as its own prompt; padding=True pads shorter words so the
# batch is a single rectangular tensor.
# NOTE(review): the tokenizer is called directly, so no BOS token appears to be
# prepended — confirm this is intended (HookedTransformer.to_tokens is the
# library's own tokenization entry point and may behave differently).
prompts = gendered_data['Word'].tolist()
encoded_prompts = model.tokenizer(prompts, return_tensors="pt", padding=True)
batch_tokens = encoded_prompts['input_ids']
batch_tokens.shape # batch dimension, time dimension

torch.Size([312, 4])

In [7]:
# Collected MLP outputs; the forward hook below appends one tensor per hooked call.
mlp_activations = []


def capture_mlp_activations(module, ip, output):
    """Forward hook that stores a detached CPU copy of a module's output.

    Args:
        module: the module the hook fired on (unused).
        ip: the inputs the module was called with (unused).
        output: the tensor the module produced.

    The copy is appended to whatever list the global name ``mlp_activations``
    is bound to at call time, so rebinding that name to a fresh list resets
    the collection. Nothing is returned.
    """
    snapshot = output.detach().cpu()
    mlp_activations.append(snapshot)

# Register a forward hook on every block's MLP so each forward pass records its output.
# register_forward_hook returns a removable handle. Keeping the handles lets a re-run
# of this cell replace the previous hooks instead of stacking duplicates, which would
# record every activation more than once.
for stale_handle in globals().get("mlp_hook_handles", []):
    stale_handle.remove()

mlp_hook_handles = [
    model.blocks[layer_index].mlp.register_forward_hook(capture_mlp_activations)
    for layer_index in range(model.cfg.n_layers)
]

## Model Token Context Testing

In [8]:
# Use the first row of the token batch as the example sequence.
eg_token = batch_tokens[0]

In [9]:
token_vectors = []
# Rebind to a fresh list so the hook's next appends start from empty.
mlp_activations = []

In [10]:
# Put each token id in its own length-1 sequence (shape [n_tokens, 1]) so every
# token is processed with no preceding tokens in its sequence.
# Despite the name, this tensor holds token ids, not embedding vectors.
no_context_embedding = eg_token.reshape(-1, 1)
no_context_embedding

tensor([[ 4826],
        [  328],
        [  603],
        [50256]])

In [11]:
# Forward pass run only to trigger the registered hooks; the model's return value is discarded.
with torch.no_grad():
    model(no_context_embedding)

In [16]:
# First tensor captured during the no-context run: one length-1 sequence per token.
no_context_acts = mlp_activations[0]
print(no_context_acts.shape)  # [n_tokens, 1, width]
exp1 = no_context_acts.squeeze(1)  # drop the length-1 sequence axis
print(exp1.shape)  # [n_tokens, width]

torch.Size([4, 1, 1024])
torch.Size([4, 1024])


In [21]:
# Run the same token ids as a single sequence (the "with previous token context" case).
context_embedding = eg_token.clone()
mlp_activations = []  # reset so index 0 afterwards is this run's capture
print(mlp_activations)  # empty before the forward pass
with torch.no_grad():
    model(context_embedding)
print(mlp_activations)  # now holds whatever the hooks captured during this pass

[]
[tensor([[[-0.4515,  0.4032, -0.1441,  ...,  0.1154,  0.5834,  0.2643],
         [ 0.3742,  0.0929, -0.2444,  ...,  0.3709,  1.2639,  0.4193],
         [ 0.0167,  0.6710, -0.4912,  ...,  0.0497,  1.1304,  0.1132],
         [-0.8192,  0.3767, -0.4370,  ...,  0.0478,  0.3139, -0.1156]]])]


In [26]:
# First tensor captured during the with-context run: a single sequence of all tokens.
with_context_acts = mlp_activations[0]
print(with_context_acts.shape)  # [1, n_tokens, width]
exp2 = with_context_acts.squeeze(0)  # drop the size-1 batch axis
print(exp2.shape)  # [n_tokens, width]

torch.Size([1, 4, 1024])
torch.Size([4, 1024])


In [29]:
# Compare the activation vector at the last position between the two runs.
# NOTE(review): the last id in eg_token is 50256, which looks like padding —
# confirm that comparing this position (rather than the last real token) is intended.
print('Without Previous Token Context:',exp1[-1])
print('With Previous Token Context:',exp2[-1])

Without Previous Token Context: tensor([-1.8401, -0.4987,  0.7532,  ..., -0.8744,  0.6383, -0.2679])
With Previous Token Context: tensor([-0.8192,  0.3767, -0.4370,  ...,  0.0478,  0.3139, -0.1156])


## Computing Activations for Gendered Words

In [38]:
# Forward pass over every word so the registered hooks capture the MLP activations.
mlp_activations = []
with torch.no_grad():
    model(batch_tokens)

# Print out the collected MLP activations
for i, activation in enumerate(mlp_activations):
    print(f"Layer {i} MLP activations: {activation.shape}")

# Average each word's activations over its real tokens only. The batch was padded to
# a common length, so a plain mean over the position axis would mix padding-position
# activations into every word shorter than the longest one.
# Assumes no real word token shares the pad token's id — TODO confirm.
pad_token_id = model.tokenizer.pad_token_id
layer_activations = mlp_activations[0]  # [n_words, n_positions, width]
real_token_mask = (batch_tokens != pad_token_id).unsqueeze(-1).to(layer_activations)
real_token_counts = real_token_mask.sum(dim=1).clamp(min=1)  # avoid 0/0 for an all-padding row
activations = (layer_activations * real_token_mask).sum(dim=1) / real_token_counts
print('Final Activations:', activations.shape)

Layer 0 MLP activations: torch.Size([312, 4, 1024])
Final Activations: torch.Size([312, 1024])
