In [2]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Core
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Hugging Face
from transformers import AutoTokenizer, AutoModelForCausalLM

# Checkpoint to probe. Start small; later swap to your target model.
MODEL_NAME = "gpt2"

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Padding needs a pad token; when the tokenizer defines none, reuse end-of-sequence.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# output_hidden_states=True makes forward passes return `hidden_states`,
# which the activation-collection cell reads.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_hidden_states=True)
model = model.to(device)
model = model.eval()  # inference mode: disables dropout and similar train-only behaviour


In [4]:
!curl "https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt" > "shakespeare.txt"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 5330k  100 5330k    0     0  17.0M      0 --:--:-- --:--:-- --:--:-- 17.1M


In [16]:
class SimpleTextDataset(Dataset):
    """Map-style dataset that tokenizes one raw string per item, on demand.

    Args:
        strings: Indexable, sized collection of texts; item ``i`` is built from ``strings[i]``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt", truncation=True,
            max_length=maxLen, padding="max_length")``; its result must expose ``.items()``.
        maxLen: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, strings, tokenizer, maxLen=256):
        self.strings, self.tokenizer, self.maxLen = strings, tokenizer, maxLen

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.strings)

    def __getitem__(self, i):
        """Tokenize text ``i`` and return its fields with dimension 0 squeezed."""
        tokenizerOptions = dict(
            return_tensors="pt",
            truncation=True,
            max_length=self.maxLen,
            padding="max_length",
        )
        encoded = self.tokenizer(self.strings[i], **tokenizerOptions)

        # The tokenizer is called on a single string, so each field is squeezed on
        # dim 0 to let the DataLoader add the batch dimension itself.
        sample = {}
        for field, tensor in encoded.items():
            sample[field] = tensor.squeeze(0)
        return sample
    
with open("shakespeare.txt", "r") as f:
    text = f.read()

# Heading that closes the file's header; everything after it is kept.
marker="""1609

THE SONNETS

by William Shakespeare"""

# Strip the header so only the text remains. str.find returns -1 when the marker is
# missing (e.g. a failed or partial download), which would silently slice from a
# meaningless offset, so fail loudly instead.
markerStart = text.find(marker)
if markerStart == -1:
    raise ValueError("Header marker not found in shakespeare.txt; is the download complete?")
body = text[markerStart + len(marker):].strip()

# Passages are separated by two blank lines. Chunks with no visible text carry
# nothing worth probing, so they are dropped.
# Might want to get rid of the passage numbers later
shakespeare = [passage for passage in body.split('\n\n\n') if passage.strip()]

dataset = SimpleTextDataset(shakespeare, tokenizer)

# A dedicated, seeded generator makes the shuffle order repeatable across kernel restarts.
SHUFFLE_SEED = 0
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [None]:
# This is where we get the interesting bits
@torch.no_grad()
def collect_activations(dataloader, takeLastToken=True, maxBatches=50):
    """Run batches through the global ``model`` and gather final hidden states.

    Relies on the notebook-level ``model`` and ``device`` being defined.

    Args:
        dataloader: Iterable of dict batches whose tensor values are passed as
            ``model(**batch)``. If a batch has an ``"attention_mask"`` entry, non-zero
            marks real tokens and zero marks padding; without one, every position is
            treated as real.
        takeLastToken: If True, keep one vector per sequence: the hidden state at its
            last non-padding position. If False, keep one vector per non-padding position.
        maxBatches: Upper bound on the number of batches consumed.

    Returns:
        CPU tensor of shape (numVectors, hiddenSize). Sequences that consist only of
        padding contribute no rows.

    Raises:
        RuntimeError: If no batch was processed (empty dataloader or maxBatches <= 0).
    """
    outputActivations = []  # The eventual feature activations
    for i, batch in enumerate(dataloader):
        if i >= maxBatches:
            break  # Make sure we don't get lost in the sauce

        # Q: What does this next line mean?
        # A: Move all the tensors from the dataloader batch to {device}
        batch = {k: v.to(device) for k, v in batch.items()}

        # Where the magic happens: pass the batch through the model.
        out = model(**batch)
        # Last entry of the per-layer hidden states, indexed below as (batch, position, hidden).
        # NOTE(review): for Hugging Face GPT-2 this entry appears to be taken after the
        # final layer norm rather than the raw residual stream - confirm it is the layer you want.
        hiddenStates = out.hidden_states[-1]

        # Boolean (batch, position) map of real tokens. Padding positions must be ignored:
        # with sequences padded out to a fixed length, position -1 is usually a pad token.
        attentionMask = batch.get("attention_mask")
        if attentionMask is None:
            realTokens = torch.ones(
                hiddenStates.shape[:2], dtype=torch.bool, device=hiddenStates.device
            )
        else:
            realTokens = attentionMask.to(hiddenStates.device).bool()

        if takeLastToken:
            # Index of the last real token per sequence. Taking the largest masked
            # position works for both right- and left-padded batches.
            positions = torch.arange(hiddenStates.size(1), device=hiddenStates.device)
            lastIdx = (realTokens.long() * positions).argmax(dim=1)
            rows = torch.arange(hiddenStates.size(0), device=hiddenStates.device)
            selected = hiddenStates[rows, lastIdx]
            # Drop sequences with no real token at all; their "last token" is padding.
            selected = selected[realTokens.any(dim=1)]
            # TODO: Randomly break up text. The last token may be punctuation heavy
        else:
            # One row per real token, padding positions excluded.
            selected = hiddenStates[realTokens]

        # Q: What does detach() do?
        # A: It returns a tensor that is cut off from the computation graph. Under
        #    @torch.no_grad() no graph is recorded anyway, so this is only a safeguard.
        #    .cpu() moves the result off the accelerator so batches don't pile up there.
        outputActivations.append(selected.detach().cpu())

    if not outputActivations:
        raise RuntimeError(
            "collect_activations processed no batches: the dataloader is empty or maxBatches <= 0"
        )
    return torch.cat(outputActivations, dim=0)

# Sweep up to 200 batches, keeping a single token's activation per sequence
# instead of every position, then display the size of the stacked result.
activations = collect_activations(
    dataloader,
    takeLastToken=True,
    maxBatches=200,
)
activations.shape

torch.Size([1240, 768])

In [None]:
# Format activations for SAE


In [None]:
# Define SAE


In [None]:
# Run activations through SAE

