In [1]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell executes, so edits to the local project modules
# imported below take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pprint import pprint
import torch
import sys
import yaml
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from Utils import TransformerSampler, ModelConfig, GenerationConfig
from model import Model
from statedict_mapping import convert_model_weights
# Seed torch's global RNG so any random draws made later are repeatable.
torch.manual_seed(1337)

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

Using device: cuda


In [3]:
# Directory holding the pretrained TinyStories checkpoint. The default is the
# original machine-specific path; set TINYSTORIES_MODEL_DIR to run elsewhere
# without editing the notebook.
MODEL_DIR = os.environ.get(
    "TINYSTORIES_MODEL_DIR",
    "/home/ubuntu/MechInter/OneLayerModel/Models/TinyStories-1L-21M",
)

sampler = TransformerSampler()

# The Hugging Face model is loaded only as a source of pretrained weights, so
# it gets its own name instead of sharing `model` with the custom model below.
hf_model = AutoModelForCausalLM.from_pretrained(MODEL_DIR).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# Reuse the EOS token as the pad token so prompts can be padded into a batch.
tokenizer.pad_token = tokenizer.eos_token

with open('OneLM.yaml', 'r') as f:
    config_dict = yaml.safe_load(f)
# Built after the file is closed: it only needs the parsed dict.
model_config = ModelConfig(config_dict)

# Remap the Hugging Face weights into the key layout that `Model` loads.
modified_state_dict = convert_model_weights(model_config, hf_model)
# Drop the notebook's reference to the source model, as the original rebinding
# of `model` did.
del hf_model

# A freshly constructed torch module starts in training mode; switch to eval
# mode because every later cell only runs inference / activation capture.
# NOTE(review): assumes `Model` is a torch.nn.Module (it is already used via
# `.to()` and `.load_state_dict()`).
model = Model(model_config).to(device).eval()
model.load_state_dict(modified_state_dict)

<All keys matched successfully>

In [4]:
# Number of decoding steps to run for every prompt.
MAX_NEW_TOKENS = 15

prompts = [
    'Once upon a time',
    # 'In a galaxy far, far away',
    'The quick brown fox',
]

# Tokenize onto the device chosen at the top of the notebook rather than a
# hard-coded 'cuda', so the cell also runs on CPU-only machines.
tokens = tokenizer(prompts, return_tensors='pt', padding=True, padding_side='left', truncation=True).to(device=device)
input_ids = tokens['input_ids']
attention_mask = tokens['attention_mask']

# Inference only: skip autograd bookkeeping so no graph is built per step.
with torch.no_grad():
    for _ in range(MAX_NEW_TOKENS):
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # if running inference using transformers
        # logits = logits.logits

        # NOTE(review): temperature=0 presumably selects greedy decoding in
        # TransformerSampler -- confirm against Utils.
        next_token = sampler.decode(logits, temperature=0).to(device=device)
        print(tokenizer.batch_decode(next_token))

        # Append the new token to every sequence and mark it as attended.
        # `new_ones` inherits dtype and device from the existing mask.
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=1)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
            dim=-1,
        )
# print(tokenizer.batch_decode(input_ids))


[' there', ' was']
[' was', ' filled']
[' an', ' with']
[' enthusiastic', ' wonder']
[' elephant', ' when']
['.', ' he']
[' He', ' saw']
[' loved', ' the']
[' to', ' colorful']
[' play', ' thing']
[' in', '.']
[' the', ' He']
[' water', ' wondered']
[' all', ' what']
[' day', ' it']


In [5]:
# Define hook function outside the class to avoid method binding issues
def create_activation_hook(collector):
    """Factory function to create a hook that captures activations.

    Args:
        collector: Object with a list-like ``activations`` attribute; every
            captured tensor is appended to it.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature. On each call it detaches ``output``, moves it to the CPU,
        casts it to float32, flattens every dimension except the last into
        rows, and appends the resulting 2-D tensor to
        ``collector.activations``.
    """
    def hook(module, inputs, output):
        activations = output.detach().cpu().float()
        # reshape (not view): view raises on non-contiguous outputs such as
        # transposed or sliced tensors, while reshape copies only when needed.
        flat_activations = activations.reshape(-1, activations.size(-1))
        collector.activations.append(flat_activations)
    return hook

class ActivationCollector:
    """Collects activations from a specific layer during forward pass.

    A forward hook is attached to ``model.transformer_block[layer_idx].mlp.gelu``
    and every output of that module is appended to ``self.activations``.

    The collector can also be used as a context manager, in which case the
    hook is removed automatically when the ``with`` block exits.

    Args:
        model: Model exposing an indexable ``transformer_block`` whose
            entries have an ``mlp`` attribute with a ``gelu`` submodule.
        layer_idx: Index into ``model.transformer_block`` of the block to hook.
    """

    def __init__(self, model, layer_idx=0):
        self.model = model
        self.layer_idx = layer_idx
        self.activations = []
        self.hook = None
        self.register_hook()

    def __enter__(self):
        """Return the collector itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Detach the hook on exit; exceptions are never suppressed."""
        self.remove_hook()
        return False

    def register_hook(self):
        """Register forward hook on the activations in the MLP layer.

        Any hook this collector already holds is removed first, so calling
        this method repeatedly never leaves more than one hook attached and
        never records the same forward pass twice.
        """
        self.remove_hook()

        mlp_layer = self.model.transformer_block[self.layer_idx].mlp
        target_layer = mlp_layer.gelu

        # Use the factory function to create the hook
        hook_fn = create_activation_hook(self)
        self.hook = target_layer.register_forward_hook(hook_fn)

    def clear_activations(self):
        """Clear stored activations."""
        self.activations = []

    def get_activations(self):
        """Return all stored activations concatenated along dim 0.

        Returns:
            A single tensor, or ``None`` when nothing has been captured.
        """
        if not self.activations:
            return None
        return torch.cat(self.activations, dim=0)

    def remove_hook(self):
        """Remove the forward hook; safe to call more than once."""
        if self.hook is not None:
            self.hook.remove()
            # Forget the handle so the collector's state reflects that no
            # hook is attached any more.
            self.hook = None

# Re-running this cell would otherwise leave the previous collector's hook
# attached to `model`, so old collectors would keep capturing activations.
# Detach it before creating a replacement.
if 'collector' in globals():
    collector.remove_hook()

# Initialize activation collector for layer 1 (the first transformer block,
# i.e. layer_idx=0)
collector = ActivationCollector(model, layer_idx=0)
print("Activation collector initialized for layer 1")

Activation collector initialized for layer 1


In [6]:
# Start from an empty buffer so re-running this cell yields exactly one
# captured tensor instead of appending another copy on every run.
collector.clear_activations()

with torch.no_grad():
    # Use the device selected at the top of the notebook rather than a
    # hard-coded 'cuda', so the cell also runs on CPU-only machines.
    tokens = tokenizer(prompts, return_tensors='pt', padding=True, padding_side='left', truncation=True).to(device=device)
    input_ids = tokens['input_ids']
    attention_mask = tokens['attention_mask']
    # The return value is discarded: the forward pass is run only so the
    # registered hook fires and records the activations.
    _ = model(input_ids=input_ids, attention_mask=attention_mask)


In [11]:
# Show the token ids of the prompt batch produced by the most recent tokenizer call.
tokens['input_ids']

tensor([[ 7454,  2402,   257,   640],
        [  464,  2068,  7586, 21831]], device='cuda:0')

In [11]:
# Raw capture buffer: a list with one flattened 2-D tensor per forward call
# that reached the hooked module.
collector.activations

[tensor([[-0.1529, -0.1152, -0.1108,  ..., -0.0799, -0.0500, -0.1450],
         [ 0.1718, -0.0645, -0.1697,  ...,  0.0225, -0.0198, -0.1569],
         [-0.0180, -0.0658, -0.1339,  ..., -0.1189, -0.0227, -0.1316],
         ...,
         [ 0.1220, -0.0259, -0.1370,  ...,  0.0919,  0.0065, -0.1514],
         [ 0.3229, -0.0226, -0.0977,  ...,  0.0128, -0.0649,  0.0561],
         [ 1.8612, -0.0320, -0.1239,  ..., -0.1228, -0.1221, -0.1291]])]

In [7]:
# Number of rows in the first captured tensor. The hook flattens every
# dimension except the last into rows.
len(collector.activations[0])

8

In [9]:
# (rows, features) of the first captured tensor; the last dimension is the
# width of the hooked module's output.
collector.activations[0].shape

torch.Size([8, 4096])