## Terminology prerequisites

*computation-in-superposition (CiS)*: <br>
A model expressing CiS performs more computations than it has neurons, and takes advantage of superposition to perform better with sparser inputs.

*compressed-computation (CC)*: <br>
A model expressing CC performs more computations than it has neurons.

*naive loss*: <br>
The loss of a model that performs `n_neuron` computations perfectly and outputs zero for all remaining features, i.e. a baseline loss for the case where each neuron performs exactly one computation. A model performing CiS or CC must have a loss lower than the naive loss.

## Summary

In further investigations of the Toy Model of Compressed Computation (TMCC) from the Attribution-based Parameter Decomposition (APD) paper, we find that the model may be doing something different from what the authors originally described.

Here are some points that support this:

1. We find that both the embedding matrix and the residual stream are required in order for the model to perform below naive loss.

2. Importantly, if we train a TMCC model called `EmbResMlp`, create a new model without the embed and unembed matrices called `ResMlp`, and transplant: <br>
`einsum(EmbResMlp.We, EmbResMlp.W1, "emb feat, neur emb -> neur feat")` -> `ResMlp.W1` <br>
`einsum(EmbResMlp.W2, EmbResMlp.Wu, "emb neur, feat emb -> feat neur")` -> `ResMlp.W2` <br>
`ResMlp` never does better than the naive loss.

3. If we train `EmbResMlp` on L1 loss instead of L2 loss, it computes exactly `n_neuron` features perfectly and outputs 0 for the remaining features.

This suggests that `EmbResMlp` performs better than the naive loss not by performing meaningful computations on features, but instead by using the combination of the embedding matrix and residual stream to effectively cancel out noise or interference in the input space.

In [1]:
"""Notebook settings and imports."""

%load_ext autoreload
%autoreload 2
# %flow mode reactive

import os

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Union

import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
import torch as t

from einops import asnumpy, einsum, rearrange, reduce, repeat, pack, parse_shape, unpack
from einops.layers.torch import Rearrange, Reduce
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from matplotlib import ticker as mticker
from plotly import express as px
from plotly import graph_objects as go
from plotly import io as pio
from rich import print as rprint
from scipy.stats import describe, kstest, mannwhitneyu, sem
from torch import nn, optim, Tensor
from torch.nn import functional as F
from tqdm.notebook import tqdm

from toy_cis.models import CisConfig, Cis
from toy_cis.plot import plot_weight_bars, plot_input_output_response
from toy_cis.util import threshold_matrix, in_out_response

In [2]:
"""Notebook-wide settings: KMP_DUPLICATE_LIB_OK=TRUE (avoids MKL errors when plotting with mpl) and plot font size."""

# Tolerate a duplicate OpenMP runtime instead of aborting (the MKL error seen when plotting).
os.environ.update(KMP_DUPLICATE_LIB_OK="True")
# Default font size for every subsequent matplotlib figure.
plt.rcParams["font.size"] = 14

In [3]:
"""Pick the torch device (CUDA when available, otherwise CPU) and print its name."""

# To force CPU (small toy models may be faster there), use `device = t.device("cpu")` instead.
if t.cuda.is_available():
    device = t.device("cuda")
    device_name = t.cuda.get_device_name(0)
else:
    device = t.device("cpu")
    device_name = "cpu"
print(f"{device_name=}")

device_name='NVIDIA GeForce RTX 3090'


## Transplant from res-embed-mlp to mlp

In [5]:
"""Create functions for generating batches, loss, and training."""

def gen_batch(
    model: Cis,
    batch_sz: int,
    sparsity: float | Float[Tensor, "inst feat"],
    res_factor: float,
    device: t.device
) -> (
    tuple[Float[Tensor, "batch inst feat"], Float[Tensor, "batch inst feat"]]
):
    """Sample a sparse input batch together with its ReLU(+residual) target.

    Every entry is first drawn uniformly from [-1, 1) and is then kept
    independently with probability ``1 - sparsity`` (zeroed otherwise).

    Args:
        model: supplies ``cfg.n_instances`` and ``cfg.n_feat`` for the shape.
        batch_sz: number of samples in the batch.
        sparsity: probability that an entry is zero; a float, or a tensor
            broadcastable against ``(batch, inst, feat)``.
        res_factor: coefficient of the linear term in the target.
        device: device on which the tensors are created.

    Returns:
        ``(x, y)`` with ``y = relu(x) + res_factor * x``, both shaped
        ``(batch_sz, n_instances, n_feat)``.
    """
    shape = (batch_sz, model.cfg.n_instances, model.cfg.n_feat)
    # Map U[0, 1) onto U[-1, 1).
    vals = t.rand(*shape, device=device) * 2 - 1
    # Keep each entry with probability (1 - sparsity).
    keep = t.rand(*shape, device=device) < (1 - sparsity)
    vals *= keep
    target = t.relu(vals) + (res_factor * vals)
    return vals, target

def loss_fn(y, y_true, i):
    """Importance-weighted mean squared error.

    Args:
        y: model output; the einops pattern requires exactly three axes
            (``batch inst feat``).
        y_true: target, same shape as ``y``.
        i: feature importance, a scalar or anything broadcastable to ``y``.

    Returns:
        Scalar mean of ``(y - y_true) ** 2 * i`` over every entry.
    """
    residual = y - y_true
    weighted_sq_err = residual ** 2 * i
    return reduce(weighted_sq_err, "batch inst feat -> ", "mean")

def train(
    model: Cis,
    batch_sz: int,
    feat_sparsity: float | Float[Tensor, "inst feat"],
    feat_importance: float | Float[Tensor, "inst feat"],
    res_factor: float,
    loss_fn: Callable,
    optimizer: optim.Optimizer,
    n_steps: int,
    logging_freq: int,
    device: t.device
) -> List[float]:
    """Train ``model`` for ``n_steps`` steps on freshly sampled batches.

    Each step draws a new batch with ``gen_batch``, runs ``model(x, res_factor)``,
    computes ``loss_fn(y, y_true, feat_importance)`` and takes one optimizer step.

    Args:
        model: model to train in place.
        batch_sz: samples per step.
        feat_sparsity: probability that a feature is zero (passed to ``gen_batch``).
        feat_importance: per-feature loss weight (passed to ``loss_fn``).
        res_factor: residual coefficient used for both the targets and the model.
        loss_fn: callable ``(y, y_true, importance) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
        n_steps: number of optimization steps.
        logging_freq: record the loss every this many steps. Values < 1 (e.g.
            ``n_steps // 10`` when ``n_steps < 10``) are treated as 1 instead of
            raising ZeroDivisionError.
        device: device on which batches are generated.

    Returns:
        Loss values (Python floats) recorded at every ``logging_freq``-th step and
        at the final step.
    """
    # Guard against `step % 0` when a caller computes the frequency as n_steps // k.
    log_every = max(1, logging_freq)
    losses = []

    pbar = tqdm(range(n_steps), desc="Training")
    for step in pbar:
        x, y_true = gen_batch(model, batch_sz, feat_sparsity, res_factor, device)
        y = model(x, res_factor)
        loss = loss_fn(y, y_true, feat_importance)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log progress
        if step % log_every == 0 or (step + 1 == n_steps):
            # Read the value once: every `.item()` forces a device->host sync.
            loss_val = loss.item()
            losses.append(loss_val)
            pbar.set_postfix({"loss": f"{loss_val:.6f}"})

    return losses

@t.no_grad()
def eval_model(
    model: Cis,
    batch_sz: int,
    feat_sparsity: float | Float[Tensor, "inst feat"],
    feat_importance: float | Float[Tensor, "inst feat"],
    res_factor: float,
    n_batches: int,
    device: t.device,
    *,
    eval_loss_fn: Optional[Callable] = None,
) -> Float[Tensor, "n_batches"]:
    """Evaluate ``model`` on ``n_batches`` freshly sampled batches without gradients.

    Args:
        model: model to evaluate; called as ``model(x, res_factor)``.
        batch_sz: samples per batch.
        feat_sparsity: probability that a feature is zero (passed to ``gen_batch``).
        feat_importance: per-feature loss weight (passed to the loss function).
        res_factor: residual coefficient used for both the targets and the model.
        n_batches: number of batches to evaluate.
        device: device on which batches are generated.
        eval_loss_fn: optional ``(y, y_true, importance) -> scalar`` loss, e.g. the
            loss the model was trained with. Defaults to the module-level
            ``loss_fn``, resolved at call time as before.

    Returns:
        1-D tensor with one loss per batch. It is allocated with ``t.zeros`` using
        torch's default device/dtype, independent of ``device``.
    """
    # Keep the previous behaviour (global `loss_fn`) unless a loss is supplied explicitly.
    compute_loss = loss_fn if eval_loss_fn is None else eval_loss_fn
    losses = t.zeros(n_batches)

    for b in range(n_batches):
        x, y_true = gen_batch(model, batch_sz, feat_sparsity, res_factor, device)
        y = model(x, res_factor)
        losses[b] = compute_loss(y, y_true, feat_importance)

    return losses

In [6]:
"""Create, train, and evaluate the model that has embedding/unembedding matrices (We_and_Wu=True)."""

n_runs = 1  # each run rebinds the model and eval_loss_adj, so only the last run survives
min_loss = []  # NOTE(review): not used in the visible cells
layer_act_fns = [t.relu, lambda x: x]  # ReLU then identity; how Cis maps these to layers is assumed — confirm
batch_sz = 1024
n_feat = 100  # number of input/output features
n_hidden = 50  # presumably the number of MLP neurons (`n_neuron` in the summary)
We_dim = 1000  # width of the embedding space that We maps features into
feat_sparsity = 0.99  # probability that a feature is zero in a sample (see gen_batch)
feat_importance = 1  # weight on the squared error (see loss_fn)
res_factor = 0  # targets are relu(x) + res_factor * x (see gen_batch); 0 => pure ReLU
n_steps = 20000
logging_freq = n_steps // 10  # NOTE(review): this is 0 when n_steps < 10; make sure train() tolerates that

for run in range(n_runs):
    # Create model
    reluPlusX_res_embed_cfg = CisConfig(
        n_instances=1,
        n_feat=n_feat,
        n_hidden=n_hidden,
        act_fn=layer_act_fns,
        b1=None,
        b2=None,
        skip_cnx=False,
        We_and_Wu=True,
        We_dim=We_dim,
    )
    reluPlusX_res_embed_cis = Cis(reluPlusX_res_embed_cfg, device=device)

    # Train model
    optimizer = t.optim.Adam(reluPlusX_res_embed_cis.parameters(), lr=5e-4)

    losses = train(
        reluPlusX_res_embed_cis,
        batch_sz,
        feat_sparsity,
        feat_importance,
        res_factor,
        loss_fn,
        optimizer,
        n_steps,
        logging_freq,
        device
    )

    # Eval model: the mean per-batch loss is divided by the feature activation
    # probability (1 - feat_sparsity) to give the "adjusted" loss.
    eval_loss_adj = eval_model(
        reluPlusX_res_embed_cis, 
        batch_sz=10000, 
        feat_sparsity=feat_sparsity,
        feat_importance=feat_importance,
        res_factor=res_factor,
        n_batches=100,
        device=device
    ).mean().item() / (1 - feat_sparsity)

Training:   0%|          | 0/20000 [00:00<?, ?it/s]

In [7]:
print(f"{eval_loss_adj}")  # adjusted eval loss of the embed/unembed model

0.08319638436660164


In [17]:
"""Transplant the embed-model weights into an MLP without We/Wu, then evaluate it."""

# Size the target model from the same hyperparameters used for training instead of
# hard-coding them, so the transplant stays shape-consistent if the config changes.
# NOTE(review): the source model was trained with skip_cnx=False while this one uses
# skip_cnx=True — confirm this matches the intended `ResMlp` architecture.
relu_cis_cfg = CisConfig(
    n_instances=1,
    n_feat=n_feat,
    n_hidden=n_hidden,
    act_fn=layer_act_fns,
    b1=None,
    b2=None,
    skip_cnx=True,
    We_and_Wu=False,
)

relu_cis = Cis(relu_cis_cfg, device=device)

# Fold the embedding into the first layer and the unembedding into the second layer.
# `.squeeze()` drops the singleton instance dimension (n_instances=1).
transplant_w1 = einsum(
    reluPlusX_res_embed_cis.We.squeeze(),
    reluPlusX_res_embed_cis.W1.squeeze(),
    "emb feat, neur emb -> neur feat"
)
transplant_w2 = einsum(
    reluPlusX_res_embed_cis.W2.squeeze(),
    reluPlusX_res_embed_cis.Wu.squeeze(),
    "emb neur, feat emb -> feat neur"
)

# Write into instance 0 in place, outside autograd (instead of the discouraged `.data`).
with t.no_grad():
    relu_cis.W1[0].copy_(transplant_w1)
    relu_cis.W2[0].copy_(transplant_w2)

eval_loss_adj = eval_model(
    relu_cis,
    batch_sz=10000,
    feat_sparsity=feat_sparsity,
    feat_importance=feat_importance,
    res_factor=res_factor,
    n_batches=100,
    device=device
).mean().item() / (1 - feat_sparsity)

In [18]:
print(f"{eval_loss_adj}")  # adjusted eval loss of the transplanted MLP

0.08913635392673305


In [19]:
# Naive baseline: the model computes `n_hidden` of the `n_feat` features perfectly and
# outputs 0 for the rest. With x ~ U[-1, 1] on active features and res_factor == 0, each
# dropped feature costs E[relu(x)^2] = (1/2) * (1/3) = 1/6 whenever it is active.
# `eval_loss_adj` is the feature-mean loss divided by the activation probability
# (1 - feat_sparsity), so in the same units the baseline is
#     feat_importance * (n_feat - n_hidden) / n_feat / 6
# (assumes a scalar feat_importance). The earlier formula
# (n_feat - n_hidden) * (1 - feat_sparsity) / 6 only agreed because
# 1 - feat_sparsity == 1 / n_feat for this particular config.
naive_loss = feat_importance * (n_feat - n_hidden) / (6 * n_feat)
print(naive_loss)

0.08333333333333341


---
---