In [86]:
import torch, torch.nn as nn, torch.nn.functional as torchFun
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm


In [87]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MODEL_NAME = "google/gemma-3-270m-it"  # start small;
MODEL_NAME = "facebook/opt-125m" # Use a different small model
# MODEL_NAME = "gpt2" # Try another common small model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None: # REALLY IMPORTANT - otherwise GPT-2 needs inputs of the same length
    tokenizer.pad_token = tokenizer.eos_token
                                            # Output hidden states lets us see the last hidden layer
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_hidden_states=True).to(device).eval()

In [88]:
# If I'm not using my school's google drive account, that means it's free storage!
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [89]:
!ls /content/drive/MyDrive/enwiki_formatted_ds/enwiki_raw_ds/ | wc -l

3373


In [90]:
import os, glob
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

# ---------- Faster file reading (parallel, IO-bound so threads work well) ----------
def _read_one_file(fp):
    lines = []
    with open(fp, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            s = line.strip()
            if s:
                lines.append(s)
    return lines

def read_all_lines_from_folder(folder="data", max_workers=8):
    files = sorted(glob.glob(os.path.join(folder, "**/*.txt"), recursive=True))
    lines_per_file = []
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        for chunk in tqdm(ex.map(_read_one_file, files), total=len(files)):
            lines_per_file.append(chunk)
    # Flatten
    return list(chain.from_iterable(lines_per_file))

# ---------- Dataset returns raw strings; tokenize later in collate ----------
class SimpleTextDataset(Dataset):
    def __init__(self, strings):
        self.strings = strings

    def __len__(self):
        return len(self.strings)

    def __getitem__(self, i):
        return self.strings[i]

# ---------- Collate function: vectorized, batched tokenization ----------
def make_collate_fn(tokenizer, max_len=256):
    def collate_fn(batch_strings):
        # batch_strings: List[str]
        enc = tokenizer(
            batch_strings,
            return_tensors="pt",
            truncation=True,
            max_length=max_len,
            padding="max_length",  # keep same behavior as your original code
        )
        return enc
    return collate_fn

# ---------- Usage ----------
# If you're on GPU, pinned memory + more workers helps a lot.
pin = torch.cuda.is_available()
num_workers = max(1, os.cpu_count() // 2)  # adjust as you like

allLines = read_all_lines_from_folder(
    "/content/drive/MyDrive/enwiki_formatted_ds/enwiki_raw_ds",
    max_workers=8,  # tune based on your disk/VM
)

dataset = SimpleTextDataset(allLines)
collate_fn = make_collate_fn(tokenizer, max_len=256)

dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False, # WHATEVER YOU DO, DO NOT SHUFFLE
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=(num_workers > 0),
    prefetch_factor=2,  # bump to 4 if you have RAM
    collate_fn=collate_fn,
)


100%|██████████| 3373/3373 [02:25<00:00, 23.17it/s]


In [91]:
print(len(dataset))

33516409


In [92]:
# This is where we get the interesting bits

hiddenLayerToGrab = -1

@torch.no_grad()
def collect_activations(dataloader, takeLastToken=True, maxBatches=50):
    outputActivations = [] # The eventual feature activations
    for i, batch in tqdm(enumerate(dataloader)):
        if i >= maxBatches: break # Make sure we don't get lost in the sauce
        """
        Q: What does this next line mean?
        A: Move all the tensors from the dataloader batches to {device}
        """
        batch = {k: v.to(device) for k, v in batch.items()}
        # Where the magic happens
        # Pass the batch through the model
        out = model(**batch)
        hiddenStates = out.hidden_states[hiddenLayerToGrab]
        # TODO: Uncomment this next line and pass the hidden states through a normalization function
        # this might help with scaling artifacts (more accurate representation of what the model "wants to say")
        # Ensure LayerNorm is applied correctly to the last dimension
        hiddenStates = torch.nn.LayerNorm(hiddenStates.size(-1), elementwise_affine=False).to(hiddenStates.device)(hiddenStates)
        if takeLastToken:
            # TODO: Change the layer that we're looking at and see if there's any interesting activations there
            # TODO: Change the token we're grabbing, as there's a high chance the last token is punctuation
            lastHiddenState = hiddenStates[:, -1, :] # This is the last hidden state (final res stream)
                            # High in semantic data ^
            # TODO: Randomly break up text. The last token may be punctuation heavy
        else:
            lastHiddenState = hiddenStates.reshape(-1, hiddenStates.size(-1))

        """
        Q: What does detach() do?
        A: It pulls the tensor away from the computation graph
        Reason: That's all we need. If we don't, PyTorch will run backprop (don't need it)
        """

        outputActivations.append(lastHiddenState.detach().cpu())

    return torch.cat(outputActivations, dim=0)

activations = collect_activations(dataloader, takeLastToken=True, maxBatches=2500)
modelDims = activations.shape[-1]
activations.shape

2500it [02:52, 14.46it/s]


torch.Size([40000, 768])

In [93]:
# Building the magic to make sense of the interesting bits
class SAE(nn.Module):
    def __init__(self, inDims, codeDims, tied=False, topk=None):
        super().__init__()
        self.encoder = nn.Linear(inDims, codeDims, bias=True) # Bias is true because it's good at learning activation offsets
        self.decoder = nn.Linear(codeDims, inDims, bias=False) # Bias is false because we want representations to be just combinations of feature directions
        # Note: If we had bias true, some feature vectors could be offset by an arbitrary amount,
        # making it harder to compare veature vectors in some situations
        self.tied = tied # True can help improve improve how identifiable features are, constrains the solution
        self.topk = topk # Hard sparsity - keeps only largest k activations - used for fixed amount of active features
        if tied:
            self.decoder.weight = self.encoder.weight

    def encode(self, x):
        # x is the hidden layer we're passing in
        s = self.encoder(x)
        if self.topk is not None:
            # k = self.topk
            # Grab the top k values and their dimensions
            topkVals, topkIndex = torch.topk(s, self.topk, dim=-1)
            # Make a mask out of them
            mask = torch.zeros_like(s).scatter_(-1, topkIndex, 1.0)
            # Multiply the mask in and kill any features that aren't within the {self.topk} dimensions
            # Works as a "hard sparsity"
            s = s * mask
            s = torchFun.relu(s)
        else:
            # Relies on L1 penalty in loss - "soft sparsity"
            s = torchFun.relu(s)
        return s

    def forward(self, x):
        s = self.encode(x)
        xHat = self.decoder(s)
        return xHat, s

In [94]:
# SAE training function
from tqdm import trange

def train_sae(X, codeDims=4*modelDims, l1Strength=1e-3, epochs=5, batchSize=256, topk=None, learningRate=1e-3, tied=False):
    sae = SAE(inDims=modelDims, codeDims=codeDims, tied=tied, topk=topk).to(device)
    optimizer = torch.optim.Adam(sae.parameters(), lr=learningRate)

    dataset = torch.utils.data.TensorDataset(X) # X is CPU tensor
    dataLoader = DataLoader(dataset, batch_size=batchSize, shuffle=True, drop_last=True)

    pbar = trange(epochs, desc='Bar desc', leave=True)
    for ep in pbar:
        losses, reconstructionLosses, l1Losses = [], [], []
        for (batchActivations,) in dataLoader:
            batchActivations = batchActivations.to(device)
            xhat, s = sae(batchActivations)
            reconstructionLoss = torchFun.mse_loss(xhat, batchActivations)
            l1Penalty = s.abs().mean()
            loss = reconstructionLoss + l1Strength * l1Penalty

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            reconstructionLosses.append(reconstructionLoss.item())
            l1Losses.append(l1Penalty.item())
        pbar.set_description(f"ep {ep}: loss {sum(losses)/len(losses):.4f} | recon {sum(reconstructionLosses)/len(reconstructionLosses):.4f} | L1 {sum(l1Losses)/len(l1Losses):.4f}")
        pbar.refresh()

    return sae.eval()


In [95]:
# A helpful function to get the topk features
# This is what maximally activates features

@torch.no_grad()
def get_feature_topk(sae, X, k=20, batch=1024):
    indexes = []
    allScores = []
    for i in range(0, X.size(0), batch):
        batchActivations = X[i:i+batch].to(device) # xb is a batch of activations
        _, s = sae(batchActivations) # s is sparse code
        allScores.append(s.detach().cpu())
    scores = torch.cat(allScores, dim=0)

    for j in range(scores.size(1)):
        vals, topIndex = torch.topk(scores[:, j], k)
        indexes.append(topIndex)
    return indexes, scores

In [96]:
import random

# Some helper functions for analysis
@torch.no_grad()
def eval_reconstruction_stats(sae, Xeval, batch=2048):
    reconstructionSum, elements = 0.0, 0
    allScores = []
    print(len(Xeval), batch)
    for i in range(0, len(Xeval), batch):
        xb = Xeval[i:i+batch].to(device)
        xhat, score = sae(xb)
        # Calculate the mean squared error of the model over the evaluation ds
        reconstructionSum += torchFun.mse_loss(xhat, xb, reduction='sum').item()
        elements += xb.numel()
        allScores.append(score.detach().cpu())
    mse = reconstructionSum / elements
    variance = Xeval.pow(2).mean().item()
    r2 = max(0.0, 1.0 - mse / (variance + 1e-12))
    score = torch.cat(allScores, dim=0)
    return mse, r2, score

@torch.no_grad()
def calc_sparsity_metrics(encoderActivations):
    # Feature activation density
    density = (encoderActivations > 0).float().mean().item()
    # Part of dictionary that's unused
    deadRate = (encoderActivations.sum(dim=0) == 0).float().mean().item()
    # How strongly/how often a feature is used
    usage = encoderActivations.abs().sum(dim=0)
    # gini coefficient formula -> degree of variation in dataset
    if usage.sum() == 0:
        gini = 0.0
    else:
        u = torch.sort(usage)[0]
        n = u.numel()
        # normalized Gini (0 = equal usage, 1 = unequal)
        gini = ((2 * torch.arange(1, n + 1) - n - 1).float().to(u) * u).sum() / (n * u.sum() + 1e-12)
        gini = gini.item()
    # Note: Look at idea.md, working session 5
    # for tips on interpreting/fixing these v
    return density, deadRate, gini

@torch.no_grad()
def decoder_cosine_summary(sae, sampleLimit=2000):
    # weight = indims, codeDims
    decoder = sae.decoder.weight.detach().cpu().T
    dcode = decoder.size(0)
    m = min(sampleLimit, max(0, (dcode * (dcode - 1)) // 2))
    if m == 0:
        return {"mean": float('nan'), "p95": float('nan'), "max": float('nan')}
    pairs = set()
    while len(pairs) < m:
        i, j = random.randrange(dcode), random.randrange(dcode)
        if i < j: pairs.add((i, j))
    pairs = list(pairs)
    A = torchFun.normalize(decoder[[i for i,_ in pairs]], dim=1)
    B = torchFun.normalize(decoder[[j for _,j in pairs]], dim=1)
    cos = (A * B).sum(dim=1)
    cosSorted = torch.sort(cos).values
    return {
        "mean": cos.mean().item(),
        "p95": cosSorted[int(0.95*len(cosSorted))-1].item() if len(cosSorted) > 1 else cosSorted.item(),
        "max": cos.max().item()
    }


In [97]:
# Train and evaluate SAE
import math, random
import torch
import torch.nn as nn
import torch.nn.functional as torchFun
from torch.utils.data import DataLoader, TensorDataset


codeDimsMultiplier = 4      # A multiplier for the "number of concepts" a model can learn
l1Strength         = 1e-3   # The strength of the l1 "lens"
epochs             = 100    # Iterations to train for
batchSize          = 512    # Samples to process at the same time
topkFeatures       = 64    # Number of top features to keep
learningRate       = 1e-3   # Learning rate to iterate at
seed               = 4738   # Random number seed
printTopN          = 10      # How many examples to print
showFeatures       = 10      # Number of features to show
saveTopFeatures    = True   # Whether or not to save the top feature texts


torch.manual_seed(seed)
X_raw = activations.clone()
with torch.no_grad():
    X_mean = X_raw.mean(dim=0, keepdim=True)
X = (X_raw - X_mean)
N, d_in = X.shape
assert d_in == modelDims, f"Expected d_in == modelDims, got {d_in} vs {modelDims}"
codeDims = codeDimsMultiplier * modelDims

perm = torch.randperm(N)
valFrac = 0.1
nVal = max(1, int(N * valFrac))
valIdx, trainIdx = perm[:nVal], perm[nVal:]
X_train, X_val = X[trainIdx], X[valIdx]

# for topkFeatures in [32, 64, 96]:
if True: #This is dummy so I don't have to indent everything below
  sae = train_sae(
      X_train,
      codeDims=codeDims,
      l1Strength=l1Strength,
      epochs=epochs,
      batchSize=batchSize,
      topk=topkFeatures,
      learningRate=learningRate,
      tied=False,
  )


  mseVal, r2Val, S_val = eval_reconstruction_stats(sae, X_val)
  density, deadRate, gini = calc_sparsity_metrics(S_val)
  cosStats = decoder_cosine_summary(sae)

  print("\n=== SAE Evaluation (Validation Set) ===")
  print(f"Data samples: {len(dataset)}")
  print(f"Grabbed hidden layer: {hiddenLayerToGrab}")
  print(f"N_val: {X_val.size(0)}  |  modelDims: {modelDims}  |  codeDims: {codeDims}")
  print(f"Reconstruction MSE: {mseVal:.6f}")
  print(f"Reconstruction R^2: {r2Val:.4f}   (vs centered baseline)")
  print(f"Activation density (mean L0 fraction): {density:.4f}")
  print(f"Dead feature rate: {deadRate:.4f}")
  print(f"Feature-usage inequality (gini-like): {gini:.4f}")
  print(f"Decoder cosine summary: mean={cosStats['mean']:.3f}, p95={cosStats['p95']:.3f}, max={cosStats['max']:.3f}")


  if 'get_feature_topk' in globals():
      idxs, S_all = get_feature_topk(sae, X)
      featActivity = S_val.abs().sum(dim=0)
      kShow = min(showFeatures, featActivity.numel())
      topFeatIds = torch.topk(featActivity, k=kShow).indices.tolist()
      saveNumber = len(os.listdir('feature_text_output'))

      print("\n=== Feature cards (exemplar indices) ===")
      for j in topFeatIds:
          topIdx = idxs[j][:printTopN].tolist()
          print(f"\nFeature {j} — top {printTopN} exemplar rows: {topIdx}")
          try:
              if 'dataset' in globals() and hasattr(dataset, 'strings'):
                  for r in topIdx:
                      if 0 <= r < len(dataset.strings):
                          print("  •", dataset.strings[r][:200].replace("\n", " "))
                          if saveTopFeatures:
                              with open(f'feature_text_output/{MODEL_NAME.replace("/", "_")}_{saveNumber}_feature_{j}.txt', 'a') as f:
                                  f.write(dataset.strings[r].replace("\n", " ") + "\n")
          except Exception as e:
              print(e)
              pass

  with torch.no_grad():
      meanActive = (S_val > 0).float().sum(dim=1).float().mean().item()
  print(f"\nApprox mean active features per sample (val): {meanActive:.1f}")

  # Save the model for later loading

  import time
  if not os.path.exists('savedModels'):
      os.mkdir('savedModels')

  modelPath = os.path.join('savedModels', f'model_{len(os.listdir("savedModels"))}_{time.time()}.pth')
  torch.save(sae.state_dict(), modelPath)
  print(f"Model saved to {modelPath}")

  varsToWrite = [
      modelPath,
      codeDimsMultiplier,
      l1Strength,
      epochs,
      batchSize,
      topkFeatures,
      learningRate,
      len(dataset),
      hiddenLayerToGrab,
      modelDims,
      codeDims,
      mseVal,
      r2Val,
      density,
      deadRate,
      gini,
      cosStats['mean'],
      cosStats['p95'],
      cosStats['max'],
      meanActive,
  ]

  # Convert everything to string before joining
  stringToWrite = ",".join(str(v) for v in varsToWrite) + "\n"

  # Use 'w' to overwrite or 'a' to append
  with open('out.txt', 'a') as f:
      f.write(stringToWrite)

ep 99: loss 0.0148 | recon 0.0148 | L1 0.0228: 100%|██████████| 100/100 [00:46<00:00,  2.17it/s]


4000 2048

=== SAE Evaluation (Validation Set) ===
Data samples: 33516409
Grabbed hidden layer: -1
N_val: 4000  |  modelDims: 768  |  codeDims: 3072
Reconstruction MSE: 0.026124
Reconstruction R^2: 0.9416   (vs centered baseline)
Activation density (mean L0 fraction): 0.0208
Dead feature rate: 0.1514
Feature-usage inequality (gini-like): 0.9283
Decoder cosine summary: mean=0.005, p95=0.081, max=0.209

=== Feature cards (exemplar indices) ===

Feature 251 — top 10 exemplar rows: [21711, 33230, 28077, 39849, 1427, 33205, 1426, 30316, 18531, 7401]
  • The leaves, known as tējapattā or tejpatta (तेजपत्ता) in Hindi, tejpat (तेजपात/তেজপাত) in Nepali, Maithili and Assamese, tejpata (তেজপাতা) in Bengali, vazhanayila/edanayila (വഴനയില/എടനഇല) in Malayalam
  • Since 2003 Carpenter has been performing with leading musicians and orchestras in the United States and Europe. As the First Prize winner of the 2005 Philadelphia Orchestra Young Artists Competition, 
  • The neighborhood is mainly a reside