In [None]:
import torch, torch.nn as nn, torch.nn.functional as torchFun
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "google/gemma-3-270m-it"  # start with a small checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# A pad token is required to batch inputs of different lengths. Some tokenizers
# (e.g. GPT-2's) ship without one, so fall back to the end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# output_hidden_states=True makes the forward pass return the hidden states, which
# later cells read via `out.hidden_states[-1]`. `.eval()` puts the model in inference mode.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_hidden_states=True).to(device).eval()


In [None]:
# Download the two raw text corpora into the current working directory.
# NOTE(review): later cells glob *.txt files under ./data, while these downloads are written
# next to the notebook -- confirm the files are copied/moved into ./data before running on.
!curl "https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt" > "shakespeare.txt"
!curl "https://gist.githubusercontent.com/MattIPv4/045239bc27b16b2bcf7a3a9a4648c08a/raw/2411e31293a35f3e565f61e7490a806d4720ea7e/bee%2520movie%2520script" > "beeMovie.txt"

In [None]:
# import kagglehub

# # Download latest version
# path = kagglehub.dataset_download("wikimedia-foundation/wikipedia-structured-contents")

# print("Path to dataset files:", path)

In [None]:
import os, glob, random, torch
from torch.utils.data import IterableDataset, DataLoader

class LineByLineTextIterable(IterableDataset):
    """Stream non-empty, stripped lines from every ``*.txt`` file under a folder.

    Lines are emitted in an approximately random order through a fixed-size
    shuffle buffer, so the corpus never has to be held in memory at once.
    When the dataset is consumed by several DataLoader workers, the line stream
    is split round-robin between them so that each line is produced exactly once
    per pass instead of once per worker.

    Args:
        folder: Root directory searched recursively for ``*.txt`` files.
        tokenizer: Stored on the instance; not used by this class itself.
        maxLen: Stored on the instance; not used by this class itself.
        shuffleBuffer: Number of lines kept in the shuffle buffer. Larger values
            shuffle better but use more memory; values <= 0 disable shuffling.
        seed: Base seed for the file order and the buffer shuffle.
    """

    def __init__(self, folder="data", tokenizer=None, maxLen=256, shuffleBuffer=8192, seed=1337):
        super().__init__()
        self.folder = folder
        self.tokenizer = tokenizer
        self.maxLen = maxLen
        self.shuffleBuffer = shuffleBuffer
        self.seed = seed

        # Sorted so the starting order is identical on every machine; it is
        # shuffled deterministically (by seed) in line_iterator().
        self.files = sorted(glob.glob(os.path.join(folder, "**/*.txt"), recursive=True))

    def line_iterator(self):
        """Yield every non-empty stripped line, visiting files in a seeded random order.

        The file order depends only on ``self.seed``, so every worker sees the
        same sequence of lines; ``__iter__`` relies on that to shard the stream.
        """
        rng = random.Random(self.seed)
        files = self.files[:]
        rng.shuffle(files)

        for fp in files:
            with open(fp, "r", encoding="utf-8", errors="ignore") as f:
                for line in f:
                    s = line.strip()
                    if s:
                        yield s

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        if worker is None:
            workerId, numWorkers = 0, 1
        else:
            workerId, numWorkers = worker.id, worker.num_workers
        # Each worker gets its own shuffle seed (worker 0 / no worker uses self.seed).
        seed = self.seed + workerId * 10_000

        lines = self.line_iterator()
        if numWorkers > 1:
            # Round-robin sharding: worker w keeps lines w, w + numWorkers, ...
            # Without this every worker would replay the entire corpus.
            lines = (s for idx, s in enumerate(lines) if idx % numWorkers == workerId)

        if self.shuffleBuffer <= 0:
            # Nothing to buffer: pass the lines through in stream order.
            yield from lines
            return

        rng = random.Random(seed)
        buf = []

        for s in lines:
            if len(buf) < self.shuffleBuffer:
                buf.append(s)
                continue
            # Buffer is full: emit a random resident line and put the new one in its slot.
            j = rng.randrange(self.shuffleBuffer)
            out = buf[j]
            buf[j] = s
            yield out

        # Drain whatever is left in the buffer, in random order.
        rng.shuffle(buf)
        yield from buf

def collate_tokenize(batch, tokenizer, maxLen=256):
    """Tokenize a list of raw strings into a single padded batch.

    Args:
        batch: Strings to tokenize together.
        tokenizer: Callable invoked as ``tokenizer(batch, ...)``; its result must
            expose ``.items()``.
        maxLen: Upper bound on sequence length; longer inputs are truncated.

    Returns:
        A plain ``dict`` with the tokenizer's output entries. Sequences are padded
        only up to the longest one in this batch.
    """
    encoded = tokenizer(
        batch,
        padding="longest",
        truncation=True,
        max_length=maxLen,
        return_tensors="pt",
    )
    return dict(encoded.items())



In [None]:
class SimpleTextDataset(Dataset):
    """Map-style dataset that tokenizes one stored string per index.

    Args:
        strings: Indexable collection of raw text samples.
        tokenizer: Callable invoked as ``tokenizer(text, ...)``; its result must
            expose ``.items()`` with tensor values.
        maxLen: Every sample is truncated/padded to exactly this many tokens.
    """

    def __init__(self, strings, tokenizer, maxLen=256):
        self.strings = strings
        self.tokenizer = tokenizer
        self.maxLen = maxLen

    def __len__(self):
        return len(self.strings)

    def __getitem__(self, i):
        encoded = self.tokenizer(
            self.strings[i],
            padding="max_length",
            truncation=True,
            max_length=self.maxLen,
            return_tensors="pt",
        )
        # The tokenizer returns a leading batch axis of size 1; drop it so the
        # DataLoader can stack the samples itself.
        sample = {}
        for key, tensor in encoded.items():
            sample[key] = tensor.squeeze(0)
        return sample

# Read the downloaded Shakespeare file. Decode explicitly (and leniently) instead of
# relying on the platform default encoding, matching the other file readers in this notebook.
with open("shakespeare.txt", "r", encoding="utf-8", errors="ignore") as f:
    text = f.read()

marker="""1609

THE SONNETS

by William Shakespeare"""

# Drop everything up to and including the header marker, then split into passages
# on runs of two blank lines.
# Might want to get rid of the passage numbers later
markerPos = text.find(marker)
if markerPos == -1:
    # str.find returns -1 when the marker is missing; slicing with that value would
    # silently cut the text at an arbitrary offset, so keep the whole file instead.
    print("Warning: header marker not found in shakespeare.txt; keeping the full file contents.")
    shakespeareBody = text
else:
    shakespeareBody = text[markerPos + len(marker):]
shakespeare = shakespeareBody.strip().split('\n\n\n')

# dataset = SimpleTextDataset(shakespeare, tokenizer)

# Streaming alternative to the in-memory dataset: yields shuffled lines from ./data.
streamingDS = LineByLineTextIterable(
    folder="data",
    tokenizer=tokenizer,
    maxLen=256,
    shuffleBuffer=8192,  # tune: larger = better shuffle, more RAM
    seed=1337,
)

import os, glob

def read_all_lines_from_folder(folder="data"):
    """Load every non-empty line from all ``*.txt`` files under ``folder``.

    Files are discovered recursively and visited in sorted path order. Each line
    is stripped of surrounding whitespace; lines that end up empty are skipped.

    Returns:
        A list of stripped lines, in file order then line order.
    """
    pattern = os.path.join(folder, "**/*.txt")
    collected = []
    for path in sorted(glob.glob(pattern, recursive=True)):
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            collected.extend(stripped for stripped in map(str.strip, handle) if stripped)
    return collected

import random

allLines = read_all_lines_from_folder("data")
if not allLines:
    # An empty dataset otherwise only fails later, inside the DataLoader, with an unrelated-looking error.
    raise FileNotFoundError("No non-empty lines found in any *.txt file under 'data/'; put the text corpora there first.")

# Shuffle the lines once, reproducibly, and then iterate the dataset in order.
# That still gives a random sample of the corpus, but keeps batch position and
# dataset index in step: the r-th sample served by `dataloader` is `dataset.strings[r]`,
# which is what the feature-card printout uses to look up exemplar text.
random.Random(1337).shuffle(allLines)
dataset = SimpleTextDataset(allLines, tokenizer, maxLen=256)
dataloader = DataLoader(dataset, batch_size=8, shuffle=False, drop_last=True)


# dataloader = DataLoader(
#     streamingDS,
#     batch_size=8,                  # tune for GPU throughput
#     shuffle=False,                 # must be False for IterableDataset
#     num_workers=2,                 # >0 to read files in parallel
#     pin_memory=True,               # good if using CUDA
#     persistent_workers=True,       # keeps workers alive across iterations
#     collate_fn=lambda batch: collate_tokenize(batch, tokenizer, maxLen=256),
#     drop_last=True,
# )

In [None]:
# This is where we get the interesting bits
@torch.no_grad()
def collect_activations(dataloader, takeLastToken=True, maxBatches=50):
    """Run batches through the language model and gather final-layer hidden states.

    Uses the notebook-level ``model`` and ``device``.

    Args:
        dataloader: Iterable of dict batches whose values are tensors accepted as
            keyword arguments by ``model`` (e.g. ``input_ids``, ``attention_mask``).
        takeLastToken: If True, keep one vector per sequence: the hidden state of
            the last non-padding token. If False, keep one vector per non-padding
            token, flattened across the batch.
        maxBatches: Stop after this many batches.

    Returns:
        A CPU tensor of shape (numVectors, hiddenDim).

    Raises:
        ValueError: If no batch was processed (empty dataloader or maxBatches <= 0).
    """
    outputActivations = []  # one CPU tensor of activations per processed batch
    for i, batch in enumerate(dataloader):
        if i >= maxBatches:
            break  # cap the amount of data we process
        # Move every tensor of the batch onto the model's device.
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        hiddenStates = out.hidden_states[-1]  # last entry of the hidden-state tuple: (batch, seq, dim)
        # TODO: optionally pass the hidden states through a normalization function;
        # this might help with scaling artifacts (more accurate representation of what the model "wants to say").
        # NOTE(review): `model.transformer.ln_f` is the GPT-2 attribute path -- confirm the
        # equivalent final-norm module (and whether it is already applied) for the loaded model.
        attentionMask = batch.get("attention_mask")
        if takeLastToken:
            # TODO: Change the layer that we're looking at and see if there's any interesting activations there
            # TODO: Change the token we're grabbing, as there's a high chance the last token is punctuation
            # TODO: Randomly break up text. The last token may be punctuation heavy
            if attentionMask is None:
                selected = hiddenStates[:, -1, :]
            else:
                # Position -1 is a pad token for right-padded rows, so locate the last
                # position whose mask is 1. argmax on the flipped mask returns the first 1
                # from the right, which makes this correct for left and right padding alike.
                seqLen = attentionMask.size(1)
                lastReal = seqLen - 1 - attentionMask.flip(dims=[1]).long().argmax(dim=1)
                rows = torch.arange(hiddenStates.size(0), device=hiddenStates.device)
                selected = hiddenStates[rows, lastReal, :]
        else:
            if attentionMask is None:
                selected = hiddenStates.reshape(-1, hiddenStates.size(-1))
            else:
                # Keep real tokens only; padding positions carry no text information.
                selected = hiddenStates[attentionMask.bool()]

        # detach() drops any autograd history and cpu() frees accelerator memory;
        # only the raw values are needed from here on.
        outputActivations.append(selected.detach().cpu())

    if not outputActivations:
        raise ValueError("collect_activations processed no batches: the dataloader is empty or maxBatches <= 0.")
    return torch.cat(outputActivations, dim=0)

# Gather one activation vector per sequence from at most 50 batches of the dataloader.
activations = collect_activations(dataloader, takeLastToken=True, maxBatches=50)
# Size of the last dimension of the collected activations; later cells size the SAE from it.
modelDims = activations.shape[-1]
# Last expression of the cell: displays the shape of the activation tensor.
activations.shape

In [None]:
# TODO: Data whitening
# Write only if needed later
# print(activations[0].sort())


In [None]:
# import matplotlib.pyplot as plt

# for act in activations[:10]:
#     actNormed = model.transformer.ln_f(act).detach().numpy()
#     # layerNorm = torch.nn.LayerNorm(act.size(-1), elementwise_affine=False).to(device).eval()
#     # actNormed = layerNorm(act).cpu().detach().numpy()
#     # print(torch.allclose(act, actNormed)) #<- Outputs False, so layerNorm *should* be doing something


#     (neuronActivations, indices) = zip(*enumerate(actNormed))
#     plt.bar(range(len(act)), act, width=10)
# plt.xlabel("Indices")
# plt.ylabel("Neuron Activations")
# plt.show()
# # What's happening around neurons 400-500 and ~300???

In [None]:
# Building the magic to make sense of the interesting bits
class _TiedDecoder(nn.Module):
    """Bias-free linear decoder whose weight is always the transpose of an encoder's weight."""

    def __init__(self, encoder):
        super().__init__()
        # Bypass nn.Module registration so the encoder is not listed as a submodule
        # (and its parameters are not duplicated) under the decoder.
        object.__setattr__(self, "_encoder", encoder)

    @property
    def weight(self):
        # Computed on access, so it follows the encoder through training and .to(device).
        # Shape: (inDims, codeDims), the same layout as nn.Linear(codeDims, inDims).weight.
        return self._encoder.weight.t()

    def forward(self, s):
        return torchFun.linear(s, self.weight)


# Building the magic to make sense of the interesting bits
class SAE(nn.Module):
    """Sparse autoencoder: linear encoder -> ReLU (optionally top-k) -> linear decoder.

    Args:
        inDims: Size of the input vectors.
        codeDims: Number of dictionary features (size of the sparse code).
        tied: If True, the decoder reuses the encoder weight (transposed), which
            constrains the solution and can make features easier to identify.
        topk: If given, only the ``topk`` largest pre-activations per sample are
            kept ("hard" sparsity). If None, sparsity has to come from an L1
            penalty in the training loss ("soft" sparsity).

    In both modes ``self.decoder.weight`` has shape (inDims, codeDims); column j is
    the direction feature j writes into the input space.
    """

    def __init__(self, inDims, codeDims, tied=False, topk=None):
        super().__init__()
        # Encoder bias lets the model learn per-feature activation offsets.
        self.encoder = nn.Linear(inDims, codeDims, bias=True)
        self.tied = tied
        self.topk = topk
        if tied:
            # nn.Linear stores weights as (out, in): the encoder weight is
            # (codeDims, inDims) while a decoder needs (inDims, codeDims). Sharing the
            # same Parameter object would give the decoder the wrong shape, so the tied
            # decoder exposes the transpose instead.
            self.decoder = _TiedDecoder(self.encoder)
        else:
            # No decoder bias: reconstructions are pure combinations of feature
            # directions, which keeps feature vectors directly comparable.
            self.decoder = nn.Linear(codeDims, inDims, bias=False)

    def encode(self, x):
        """Map inputs ``x`` (..., inDims) to non-negative sparse codes (..., codeDims)."""
        s = self.encoder(x)
        if self.topk is not None:
            # Never ask for more entries than the code has.
            k = min(self.topk, s.size(-1))
            # Keep the k largest pre-activations per sample and zero out the rest.
            _, topkIndex = torch.topk(s, k, dim=-1)
            mask = torch.zeros_like(s).scatter_(-1, topkIndex, 1.0)
            s = s * mask
        # ReLU in both modes; without top-k, sparsity relies on the L1 penalty in the loss.
        return torchFun.relu(s)

    def forward(self, x):
        """Return ``(reconstruction, sparseCode)`` for inputs ``x``."""
        s = self.encode(x)
        xHat = self.decoder(s)
        return xHat, s

In [None]:
# SAE training function
from tqdm import trange

def train_sae(X, codeDims=None, l1Strength=1e-3, epochs=5, batchSize=256, topk=None, learningRate=1e-3, tied=False):
    """Train a sparse autoencoder on a matrix of activations.

    Uses the notebook-level ``device`` for the SAE and the training batches.

    Args:
        X: CPU tensor of activations, one sample per row.
        codeDims: Number of SAE features. Defaults to four times the input width.
        l1Strength: Weight of the L1 sparsity penalty on the codes.
        epochs: Number of passes over ``X``.
        batchSize: Mini-batch size.
        topk: Forwarded to ``SAE`` (hard top-k sparsity) or None.
        learningRate: Adam learning rate.
        tied: Forwarded to ``SAE`` (share encoder/decoder weights).

    Returns:
        The trained SAE, in eval mode.

    Raises:
        ValueError: If ``X`` has no rows.
    """
    if X.size(0) == 0:
        raise ValueError("train_sae needs at least one activation row to train on.")

    # Size the model from the data itself rather than from a notebook global that
    # may be stale (a default written as 4*modelDims is frozen when the def cell runs).
    inDims = X.size(-1)
    if codeDims is None:
        codeDims = 4 * inDims

    sae = SAE(inDims=inDims, codeDims=codeDims, tied=tied, topk=topk).to(device)
    optimizer = torch.optim.Adam(sae.parameters(), lr=learningRate)

    trainSet = torch.utils.data.TensorDataset(X)  # X is a CPU tensor
    # Keep the final partial batch: dropping it throws data away and leaves zero
    # batches (and a division by zero below) whenever len(X) < batchSize.
    dataLoader = DataLoader(trainSet, batch_size=batchSize, shuffle=True, drop_last=False)

    pbar = trange(epochs, desc='Bar desc', leave=True)
    for ep in pbar:
        losses, reconstructionLosses, l1Losses = [], [], []
        for (batchActivations,) in dataLoader:
            batchActivations = batchActivations.to(device)
            xhat, s = sae(batchActivations)
            reconstructionLoss = torchFun.mse_loss(xhat, batchActivations)
            l1Penalty = s.abs().mean()
            loss = reconstructionLoss + l1Strength * l1Penalty

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            reconstructionLosses.append(reconstructionLoss.item())
            l1Losses.append(l1Penalty.item())
        # Show the epoch's mean losses on the progress bar.
        pbar.set_description(f"ep {ep}: loss {sum(losses)/len(losses):.4f} | recon {sum(reconstructionLosses)/len(reconstructionLosses):.4f} | L1 {sum(l1Losses)/len(l1Losses):.4f}")
        pbar.refresh()

    return sae.eval()


In [None]:
# A helpful function to get the topk features
# This is what maximally activates features

@torch.no_grad()
def get_feature_topk(sae, X, k=20, batch=1024):
    """Find, for every SAE feature, the rows of ``X`` that activate it most strongly.

    Args:
        sae: Module whose forward returns ``(reconstruction, sparseCode)``.
        X: Activations, one sample per row.
        k: Number of top rows to report per feature (capped at the number of rows).
        batch: Number of rows encoded per forward pass.

    Returns:
        ``(indexes, scores)``: ``indexes[j]`` is a 1-D tensor with the row indices of
        the strongest activations of feature j, strongest first; ``scores`` is the
        CPU tensor of sparse codes for all rows, shape (numRows, numFeatures).

    Raises:
        ValueError: If ``X`` has no rows.
    """
    if X.size(0) == 0:
        raise ValueError("get_feature_topk needs at least one activation row.")

    # Run on whatever device the SAE lives on instead of relying on a notebook global.
    firstParam = next(sae.parameters(), None)
    saeDevice = firstParam.device if firstParam is not None else X.device

    allScores = []
    for i in range(0, X.size(0), batch):
        batchActivations = X[i:i+batch].to(saeDevice)
        _, s = sae(batchActivations)  # s is the sparse code of this batch
        allScores.append(s.detach().cpu())
    scores = torch.cat(allScores, dim=0)

    # Asking topk for more rows than exist raises, so cap k.
    k = min(k, scores.size(0))
    # One top-k along the row axis covers every feature at once: result is (k, numFeatures).
    topRows = torch.topk(scores, k, dim=0).indices
    indexes = list(topRows.unbind(dim=1))
    return indexes, scores

In [None]:
# Feature steering
# Lets us push features in directions (small + 1 = medium + 1 = large)
# Shows that a direction has meaning

@torch.no_grad()
def steer_with_feature(model, tokenizer, prompt, sae, featureId, epsilon=0.5):
    """Compare next-token logits before and after nudging the hidden state along one SAE feature.

    Args:
        model: Language model exposing ``hidden_states`` on its output and an ``lm_head``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Text to run through the model.
        sae: Trained SAE; ``sae.decoder.weight`` must be laid out as (inDims, codeDims).
        featureId: Index of the SAE feature to steer along.
        epsilon: Step size along the feature direction.

    Returns:
        ``(logits, logits_steered)`` for the last prompt position, each with the
        batch axis squeezed away.
    """
    # Run a pass through the LM, on whatever device the model lives on.
    modelDevice = next(model.parameters()).device
    encoder = tokenizer(prompt, return_tensors="pt").to(modelDevice)
    out = model(**encoder, output_hidden_states=True)
    H = out.hidden_states[-1][:, -1, :]  # final hidden state at the last position

    # The decoder weight is (inDims, codeDims): a feature's direction in hidden-state
    # space is a COLUMN. Indexing a row would return a codeDims-long vector that
    # cannot be added to H.
    d = sae.decoder.weight[:, featureId].to(device=H.device, dtype=H.dtype)

    # Push the hidden state along the feature direction.
    H_steered = H + epsilon * d

    # Project both versions to vocabulary logits for comparison.
    # NOTE(review): this applies lm_head directly; confirm the model adds no further
    # processing (e.g. logit soft-capping) between the final hidden state and its logits.
    logits = model.lm_head(H)
    logits_steered = model.lm_head(H_steered)
    return logits.squeeze(0), logits_steered.squeeze(0)



In [None]:
import random

# Some helper functions for analysis
@torch.no_grad()
def eval_reconstruction_stats(sae, Xeval, batch=2048):
    """Measure how well the SAE reconstructs a held-out activation matrix.

    Args:
        sae: Module whose forward returns ``(reconstruction, sparseCode)``.
        Xeval: Evaluation activations, one sample per row.
        batch: Number of rows per forward pass.

    Returns:
        ``(mse, r2, score)``: the mean squared reconstruction error per element;
        ``1 - mse / mean(Xeval**2)`` clamped at 0; and the CPU tensor of sparse
        codes for all rows. The R^2 baseline is the all-zero prediction, so it is
        a true "variance explained" only when ``Xeval`` has been mean-centered.

    Raises:
        ValueError: If ``Xeval`` has no rows.
    """
    if len(Xeval) == 0:
        raise ValueError("eval_reconstruction_stats needs at least one evaluation row.")

    # Run on whatever device the SAE lives on instead of relying on a notebook global.
    firstParam = next(sae.parameters(), None)
    saeDevice = firstParam.device if firstParam is not None else Xeval.device

    reconstructionSum, elements = 0.0, 0
    allScores = []
    for i in range(0, len(Xeval), batch):
        xb = Xeval[i:i+batch].to(saeDevice)
        xhat, score = sae(xb)
        # Accumulate the summed squared error so that batches of different sizes
        # are weighted correctly in the final mean.
        reconstructionSum += torchFun.mse_loss(xhat, xb, reduction='sum').item()
        elements += xb.numel()
        allScores.append(score.detach().cpu())
    mse = reconstructionSum / elements
    # Mean square of the data (equals its variance when the data is centered).
    variance = Xeval.pow(2).mean().item()
    r2 = max(0.0, 1.0 - mse / (variance + 1e-12))
    score = torch.cat(allScores, dim=0)
    return mse, r2, score

@torch.no_grad()
def calc_sparsity_metrics(encoderActivations):
    """Summarize how sparsely and how evenly the SAE features are used.

    Args:
        encoderActivations: Sparse codes, one sample per row and one feature per column.

    Returns:
        ``(density, deadRate, gini)`` as Python floats:
        density  - fraction of all entries that are > 0;
        deadRate - fraction of features whose column sums to exactly 0;
        gini     - normalized Gini coefficient of per-feature total absolute
                   activation (0 = equal usage, towards 1 = usage concentrated
                   in few features); 0.0 when nothing is active at all.
    """
    # Note: Look at idea.md, working session 5
    # for tips on interpreting/fixing these metrics
    activeMask = encoderActivations > 0
    density = activeMask.float().mean().item()

    columnTotals = encoderActivations.sum(dim=0)
    deadRate = (columnTotals == 0).float().mean().item()

    # How strongly/how often each feature is used.
    usage = encoderActivations.abs().sum(dim=0)
    gini = 0.0
    if usage.sum() != 0:
        ranked, _ = torch.sort(usage)
        n = ranked.numel()
        # Gini over ascending values u_1..u_n: sum((2i - n - 1) * u_i) / (n * sum(u))
        rankWeights = (2 * torch.arange(1, n + 1) - n - 1).float().to(ranked)
        gini = ((rankWeights * ranked).sum() / (n * ranked.sum() + 1e-12)).item()
    return density, deadRate, gini

@torch.no_grad()
def decoder_cosine_summary(sae, sampleLimit=2000, seed=None):
    """Summarize cosine similarity between pairs of SAE decoder feature directions.

    Args:
        sae: Object whose ``decoder.weight`` is laid out as (inDims, codeDims).
        sampleLimit: Maximum number of distinct feature pairs to evaluate.
        seed: Optional seed for the pair sampling. With None the global ``random``
            state is used, so repeated calls may pick different pairs.

    Returns:
        Dict with ``"mean"``, ``"p95"`` and ``"max"`` of the pairwise cosines;
        all NaN when there is no pair to evaluate.
    """
    # weight is (inDims, codeDims); transpose so that each row is one feature direction.
    decoder = sae.decoder.weight.detach().cpu().T
    dcode = decoder.size(0)
    totalPairs = (dcode * (dcode - 1)) // 2
    m = min(sampleLimit, max(0, totalPairs))
    if m <= 0:
        return {"mean": float('nan'), "p95": float('nan'), "max": float('nan')}

    if m == totalPairs:
        # Every pair is needed: enumerate the upper triangle directly rather than
        # waiting for random draws to hit each pair.
        firstIdx, secondIdx = torch.triu_indices(dcode, dcode, offset=1)
    else:
        rng = random if seed is None else random.Random(seed)
        pairs = set()
        # Rejection-sample distinct unordered pairs (i < j).
        while len(pairs) < m:
            i, j = rng.randrange(dcode), rng.randrange(dcode)
            if i < j:
                pairs.add((i, j))
        ordered = sorted(pairs)  # fixed order keeps the float reductions reproducible
        firstIdx = torch.tensor([i for i, _ in ordered])
        secondIdx = torch.tensor([j for _, j in ordered])

    A = torchFun.normalize(decoder[firstIdx], dim=1)
    B = torchFun.normalize(decoder[secondIdx], dim=1)
    cos = (A * B).sum(dim=1)
    cosSorted = torch.sort(cos).values
    return {
        "mean": cos.mean().item(),
        "p95": cosSorted[int(0.95*len(cosSorted))-1].item() if len(cosSorted) > 1 else cosSorted.item(),
        "max": cos.max().item()
    }


In [None]:
# Train and evaluate SAE
import math, random
import torch
import torch.nn as nn
import torch.nn.functional as torchFun
from torch.utils.data import DataLoader, TensorDataset


# --- Hyperparameters for this SAE run ---
codeDimsMultiplier = 8      # SAE features per model dimension (codeDims = multiplier * modelDims)
l1Strength         = 3e-2   # Weight of the L1 sparsity penalty in the training loss
epochs             = 100    # Number of passes over the training activations
batchSize          = 256    # Samples per optimizer step
topkFeatures       = 10    # Passed to train_sae as `topk`: features kept active per sample
learningRate       = 1e-3   # Adam learning rate
seed               = 4738   # Seed for torch's RNG (split below and SAE initialization)
printTopN          = 5      # How many exemplar rows to print per feature
showFeatures       = 5      # Number of features to show


# Seeds torch only; Python's `random` module is not seeded here.
torch.manual_seed(seed)
# Work on a copy so the collected activations themselves stay untouched.
X_raw = activations.clone()
with torch.no_grad():
    # Per-dimension mean over ALL samples (training and validation rows together).
    X_mean = X_raw.mean(dim=0, keepdim=True)
# Center the activations; the SAE is trained and evaluated on centered data.
X = (X_raw - X_mean)
N, d_in = X.shape
assert d_in == modelDims, f"Expected d_in == modelDims, got {d_in} vs {modelDims}"
codeDims = codeDimsMultiplier * modelDims

# Random 90/10 train/validation split; at least one row always goes to validation.
perm = torch.randperm(N)
valFrac = 0.1
nVal = max(1, int(N * valFrac))
valIdx, trainIdx = perm[:nVal], perm[nVal:]
X_train, X_val = X[trainIdx], X[valIdx]

# for l1Strength in [0, 1e-4, 3e-4, 1e-3]:
if True:
  # Train one SAE with the hyperparameters configured above.
  sae = train_sae(
      X_train,
      codeDims=codeDims,
      l1Strength=l1Strength,
      epochs=epochs,
      batchSize=batchSize,
      topk=topkFeatures,
      learningRate=learningRate,
      tied=False,
  )


  # Reconstruction quality, sparsity and decoder-geometry metrics on the held-out rows.
  mseVal, r2Val, S_val = eval_reconstruction_stats(sae, X_val)
  density, deadRate, gini = calc_sparsity_metrics(S_val)
  cosStats = decoder_cosine_summary(sae)

  print("\n=== SAE Evaluation (Validation Set) ===")
  print(f"N_val: {X_val.size(0)}  |  modelDims: {modelDims}  |  codeDims: {codeDims}")
  print(f"Reconstruction MSE: {mseVal:.6f}")
  print(f"Reconstruction R^2: {r2Val:.4f}   (vs centered baseline)")
  print(f"Activation density (mean L0 fraction): {density:.4f}")
  print(f"Dead feature rate: {deadRate:.4f}")
  print(f"Feature-usage inequality (gini-like): {gini:.4f}")
  print(f"Decoder cosine summary: mean={cosStats['mean']:.3f}, p95={cosStats['p95']:.3f}, max={cosStats['max']:.3f}")


  if 'get_feature_topk' in globals():
      # Exemplar rows are indices into X (all activations, train and validation).
      idxs, S_all = get_feature_topk(sae, X)
      # Rank features by their total activation on the validation rows.
      featActivity = S_val.abs().sum(dim=0)
      kShow = min(showFeatures, featActivity.numel())
      topFeatIds = torch.topk(featActivity, k=kShow).indices.tolist()

      print("\n=== Feature cards (exemplar indices) ===")
      for j in topFeatIds:
          topIdx = idxs[j][:printTopN].tolist()
          print(f"\nFeature {j} — top {printTopN} exemplar rows: {topIdx}")
          # NOTE(review): row r of X is looked up as dataset.strings[r]. That text is the
          # right one only if the activations were collected in dataset order, one per string.
          if 'dataset' in globals() and hasattr(dataset, 'strings'):
              for r in topIdx:
                  if 0 <= r < len(dataset.strings):
                      try:
                          print("  •", dataset.strings[r][:200].replace("\n", " "))
                      except Exception as err:
                          # Printing exemplars is best-effort, but say so instead of hiding the failure.
                          print(f"  (could not display exemplar row {r}: {err})")

  with torch.no_grad():
      meanActive = (S_val > 0).float().sum(dim=1).float().mean().item()
  print(f"\nApprox mean active features per sample (val): {meanActive:.1f}")

  # One CSV row per run: hyperparameters first, then the evaluation metrics.
  varsToWrite = [
      codeDimsMultiplier,
      l1Strength,
      epochs,
      batchSize,
      topkFeatures,
      learningRate,
      modelDims,
      codeDims,
      mseVal,
      r2Val,
      density,
      deadRate,
      gini,
      cosStats['mean'],
      cosStats['p95'],
      cosStats['max'],
      meanActive,
  ]

  # Convert everything to string before joining
  stringToWrite = ",".join(str(v) for v in varsToWrite) + "\n"

  # Append so that successive runs accumulate in the same file ('w' would overwrite).
  with open('out.txt', 'a') as f:
      f.write(stringToWrite)

