## Notebook to explore the performance of different toy models of superposition across various sparsity levels

Toy models: 
1. Simple - ReLU(x)
2. Simple - ReLU(x) (handcoded, Lucius)
3. Apollo: Residual + Embedding
4. Anthropic: abs(x)

Sparsity levels: 


Other info: 

In [1]:
"""Notebook settings and imports."""

%load_ext autoreload
%autoreload 2
# %flow mode reactive

import os

from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch as t

from einops import asnumpy, einsum, rearrange, reduce, repeat, pack, parse_shape, unpack
from einops.layers.torch import Rearrange, Reduce
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from plotly import express as px
from plotly import graph_objects as go
from plotly import io as pio
from rich import print as rprint
from torch import nn, optim
from torch.nn import functional as F
from tqdm.notebook import tqdm

from toy_cis import plot
from toy_cis.models import CisConfig, Cis
from toy_cis.util import threshold_matrix

In [2]:
"""Set KMP_DUPLICATE_LIB_OK=TRUE to avoid MKL errors when plotting with mpl"""

# Store the string "True" under KMP_DUPLICATE_LIB_OK in this process's environment.
# NOTE(review): this assignment runs after the numpy/torch/matplotlib imports in the first cell;
# presumably the OpenMP runtime still honors the flag when set this late -- confirm, or move the
# assignment ahead of those imports.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

In [3]:
"""Create functions for generating batches, loss, and training."""

# Create function for generating batch of `x` and `y_true` data
def gen_batch(model: Cis, batch_sz: int) -> (
    tuple[Float[t.Tensor, "batch inst feat"], Float[t.Tensor, "batch inst feat"]]
):
    """Generates a batch of x=(sparse feature vals on [-1, 1)) and y=ReLU(x).

    Args:
        model: Supplies the batch geometry through `model.cfg.n_instances` and
            `model.cfg.n_feat`, and the sparsity through `model.s`.
        batch_sz: Number of samples in the batch (size of the leading axis).

    Returns:
        A tuple `(x, y_true)`, each of shape `(batch_sz, n_instances, n_feat)`, where
        `y_true = F.relu(x)`.
    """
    shape = (batch_sz, model.cfg.n_instances, model.cfg.n_feat)

    # Randomly generate feature vals: `t.rand` samples [0, 1), so this maps onto [-1, 1).
    x = t.rand(shape) * 2 - 1
    # For each entry, randomly decide whether it is kept; an entry is kept when its uniform draw
    # falls below (1 - model.s), and every other entry is zeroed by the in-place multiply.
    is_active = t.rand(shape) < (1 - model.s)
    x *= is_active
    return x, F.relu(x)

def loss_fn(y, y_true, i):
    """Returns the mean, over every (batch, inst, feat) entry, of the squared error scaled by `i`.

    `i` multiplies the squared error elementwise before the mean is taken. `train` passes
    `model.i` for it (presumably per-feature importance weights -- TODO confirm).
    """
    squared_error = (y - y_true) ** 2
    weighted_error = squared_error * i
    # The empty right-hand side of the pattern collapses all three named axes into one scalar.
    return reduce(weighted_error, "batch inst feat -> ", "mean")

def train(
    model: Cis,
    batch_sz: int,
    loss_fn: Callable,
    optimizer: optim.Optimizer,
    n_steps: int,
    logging_freq: int,
) -> List[float]:
    """Trains the model for `n_steps` steps, logging loss every `logging_freq` steps.

    Args:
        model: Model to train; its `forward` produces predictions and its `i` attribute is
            handed to `loss_fn` as the third argument.
        batch_sz: Batch size passed to `gen_batch` on every step.
        loss_fn: Callable invoked as `loss_fn(y, y_true, model.i)`; must return a scalar tensor
            supporting `.backward()` and `.item()`.
        optimizer: Optimizer stepped once per training step.
        n_steps: Total number of training steps.
        logging_freq: A loss value is recorded on every step whose index is a multiple of this
            value, and additionally on the final step.

    Returns:
        The recorded loss values, as Python floats, in step order.
    """
    losses = []

    pbar = tqdm(range(n_steps), desc="Training")
    for step in pbar:
        x, y_true = gen_batch(model, batch_sz)
        # NOTE(review): `model(x)` would also run any registered module hooks; presumably `Cis`
        # is an `nn.Module` -- confirm before switching away from the direct `forward` call.
        y = model.forward(x)
        loss = loss_fn(y, y_true, model.i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log progress
        if step % logging_freq == 0 or (step + 1 == n_steps):
            # Read the scalar once and reuse it for both the history and the progress bar.
            loss_val = loss.item()
            losses.append(loss_val)
            pbar.set_postfix({"loss": f"{loss_val:.4f}"})

    return losses

In [None]:
# 1. Load model 1 = Simple ReLU(x)
