<a href="https://colab.research.google.com/github/TarasKutsyk/TMS_Workshop/blob/main/TMS_Workshop_Part_2_SAEs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sparse Autoencoders
Part of ML in PL 2025 Workshop: From Superposition to Sparse Autoencoders: Understanding Neural Feature Representations by P. Wielopolski & T. Kutsyk.

Based on [ARENA notebook](https://github.com/callummcdougall/ARENA_3.0/blob/main/chapter1_transformer_interp/exercises/part31_superposition_and_saes/1.3.1_Toy_Models_of_Superposition_%26_SAEs_exercises.ipynb).

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/headers/header-13-1.png" width="350">

# Introduction

In this section, you'll learn about sparse autoencoders (SAEs) and how they can help resolve superposition problems. You'll train a sparse autoencoder on the toy model setup from the earlier tutorial.

> ##### Learning Objectives
>
> - Understand sparse autoencoders and how they can disentangle features represented in superposition
> - Train your own SAE on toy models from earlier sections and visualize the feature reconstruction process


Sparse autoencoders represent a promising recent development in mechanistic interpretability, notably explored by Anthropic in their 2023 paper [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features/index.html). This approach is currently one of the most active areas of research in the field.

In the following exercises, you will:

- Build your own sparse autoencoder by implementing its architecture and loss function
- Train your SAE on the hidden activations of the `ToyModel` class from the earlier tutorial, which is reproduced below (note: this differs from Anthropic's approach, which trained SAEs on MLP layer outputs)
- Extract features from your trained SAE and verify they match your model's learned features


*Short Q&A*

<details>
<summary>What is an autoencoder, and what is it trained to do?</summary>

An autoencoder is a type of neural network that learns efficient encodings / representations of unlabelled data. It is trained to compress the input in some way to a **latent representation**, then map it back into the original input space, by minimizing the reconstruction loss between the input and the reconstructed input.

The "encoding" part usually refers to the latent space being lower-dimensional than the input. However, that's not always the case, as we'll see with sparse autoencoders.

<img src="https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/sae-diagram-2.png" width="900">

</details>

<details>
<summary>Why is the hidden dimension of our autoencoder larger than the number of activations, when we train an SAE on an MLP layer?</summary>

Usually the latent vector is a compressed representation of the input because it's lower-dimensional. However, it can still be a compressed representation even if it's higher dimensional, if we enforce a sparsity constraint on the latent vector (which in some sense reduces its effective dimensionality).

As for why we do this specifically for our autoencoder use case, it's because we're trying to recover features from superposition, in cases where there are **more features than neurons**. We're hoping our autoencoder learns an **overcomplete feature basis**.

</details>

<details>
<summary>Why does the L1 penalty encourage sparsity? (This isn't specifically mentioned in this paper, but it's an important thing to understand.)</summary>

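The gradient of an $L_2$ penalty shrinks in proportion to the value, so it makes values small but rarely exactly zero. The $L_1$ penalty has a gradient of constant magnitude, so it keeps pushing values all the way to zero. This is a well-known result in statistics, best illustrated below: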

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/l1-viz.png" width="450">

See [this Google ML page](https://developers.google.com/machine-learning/crash-course/regularization-for-sparsity/l1-regularization) for more of an explanation (it also has a nice out-of-context animation!).

</details>
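
To make this concrete, here is a minimal sketch (not part of the original exercises) that applies repeated gradient steps to a single scalar under each penalty. The function names, the learning rate, and the clipping at zero in the $L_1$ case are illustrative assumptions. The point is only that the $L_1$ step has constant size and reaches exactly zero, while the $L_2$ step shrinks along with the value and never quite gets there.

```python
def shrink_l1(v: float, lr: float, lam: float, steps: int) -> float:
    """Gradient steps on lam * |v|, clipping at zero so we don't overshoot past it."""
    for _ in range(steps):
        step = lr * lam  # the step size does not depend on v
        if abs(v) <= step:
            v = 0.0
        else:
            v -= step if v > 0 else -step
    return v


def shrink_l2(v: float, lr: float, lam: float, steps: int) -> float:
    """Gradient steps on lam * v**2, where each step is proportional to v."""
    for _ in range(steps):
        v -= lr * 2 * lam * v
    return v


l1_final = shrink_l1(0.5, lr=0.1, lam=1.0, steps=10)  # hits exactly zero
l2_final = shrink_l2(0.5, lr=0.1, lam=1.0, steps=10)  # small, but still positive
```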

## Setup (don't read, just run)

In [None]:
import os
import sys
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules

chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"
branch = "main"

# Install dependencies
try:
    import transformer_lens
except ImportError:
    %pip install einops datasets jaxtyping "sae-lens>=4.0.0,<5.0.0" tabulate eindex-callum transformer_lens==2.11.0
    %pip install --force-reinstall numpy pandas

# Get root directory, handling 3 different cases: (1) Colab, (2) notebook not in ARENA repo, (3) notebook in ARENA repo
root = (
    "/content"
    if IN_COLAB
    else "/root"
    if repo not in os.getcwd()
    else str(next(p for p in Path.cwd().parents if p.name == repo))
)

if Path(root).exists() and not Path(f"{root}/{chapter}").exists():
    if not IN_COLAB:
        !sudo apt-get install unzip
        %pip install jupyter ipython --upgrade

    if not os.path.exists(f"{root}/{chapter}"):
        !wget -P {root} https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/{branch}.zip
        !unzip {root}/{branch}.zip '{repo}-{branch}/{chapter}/exercises/*' -d {root}
        !mv {root}/{repo}-{branch}/{chapter} {root}/{chapter}
        !rm {root}/{branch}.zip
        !rmdir {root}/{repo}-{branch}


if f"{root}/{chapter}/exercises" not in sys.path:
    sys.path.append(f"{root}/{chapter}/exercises")

os.chdir(f"{root}/{chapter}/exercises")

In [None]:
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Literal

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import torch as t
from IPython.display import HTML, display
from jaxtyping import Float
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm

device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)

# Make sure exercises are in the path
chapter = "chapter1_transformer_interp"
section = "part31_superposition_and_saes"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section

import part31_superposition_and_saes.tests as tests
import part31_superposition_and_saes.utils as utils
from plotly_utils import imshow, line

MAIN = __name__ == "__main__"

### Problem setup

Recall the setup of our previous model:

$$
\begin{aligned}
h &= W x \\
x' &= \operatorname{ReLU}(W^T h + b)
\end{aligned}
$$

We're going to train our autoencoder to just take in the hidden state activations $h$, map them to a larger (overcomplete) hidden state $z$, then reconstruct the original hidden state $h$ from $z$.

$$
\begin{aligned}
z &= \operatorname{ReLU}(W_{enc}(h - b_{dec}) + b_{enc}) \\
h' &= W_{dec}z + b_{dec}
\end{aligned}
$$

Note the choice to have a different encoder and decoder weight matrix, rather than having them tied - we'll discuss this more later.

It's important not to get confused between the autoencoder and model's notation. Remember - the model takes in features $x$, maps them to **lower-dimensional** vectors $h$, and then reconstructs them as $x'$. The autoencoder takes in these hidden states $h$, maps them to a **higher-dimensional but sparse** vector $z$, and then reconstructs them as $h'$. Our hope is that the elements of $z$ correspond to the features of $x$.

Another note - the use of $b_{dec}$ here might seem weird, since we're subtracting it at the start then adding it back at the end. The way we're treating this term is as a **centralizing term for the hidden states**. It subtracts some learned mean vector from them so that $W_{enc}$ can act on centralized vectors, and then this term gets added back to the reconstructed hidden states at the end of the model.
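
One way to see this is to rewrite the equations above in terms of the centered hidden state (the symbol $\tilde{h}$ is introduced here purely as shorthand):

$$
\begin{aligned}
\tilde{h} &= h - b_{dec} \\
z &= \operatorname{ReLU}(W_{enc}\tilde{h} + b_{enc}) \\
h' - b_{dec} &= W_{dec}z
\end{aligned}
$$

So the decoder is reconstructing $\tilde{h}$ with no bias of its own, and the reconstruction error $h' - h$ is the same as $W_{dec}z - \tilde{h}$.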

In [None]:
def constant_lr(*_):
    return 1.0

@dataclass
class ToyModelConfig:
    n_inst: int
    n_features: int = 5
    d_hidden: int = 2
    n_correlated_pairs: int = 0
    n_anticorrelated_pairs: int = 0
    feat_mag_distn: Literal["unif", "normal"] = "unif"


class ToyModel(nn.Module):
    W: Float[Tensor, "inst d_hidden feats"]
    b_final: Float[Tensor, "inst feats"]

    # Our linear map (for a single instance) is x -> ReLU(W.T @ W @ x + b_final)

    def __init__(
        self,
        cfg: ToyModelConfig,
        feature_probability: float | Tensor = 0.01,
        importance: float | Tensor = 1.0,
        device=device,
    ):
        super(ToyModel, self).__init__()
        self.cfg = cfg

        if isinstance(feature_probability, float):
            feature_probability = t.tensor(feature_probability)
        self.feature_probability = feature_probability.to(device).broadcast_to(
            (cfg.n_inst, cfg.n_features)
        )
        if isinstance(importance, float):
            importance = t.tensor(importance)
        self.importance = importance.to(device).broadcast_to((cfg.n_inst, cfg.n_features))

        self.W = nn.Parameter(
            nn.init.xavier_normal_(t.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features)))
        )
        self.b_final = nn.Parameter(t.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def forward(
        self,
        features: Float[Tensor, "... inst feats"],
    ) -> Float[Tensor, "... inst feats"]:
        """
        Performs a single forward pass. For a single instance, this is given by:
            x -> ReLU(W.T @ W @ x + b_final)
        """
        h = einops.einsum(
            features, self.W, "... inst feats, inst hidden feats -> ... inst hidden"
        )
        out = einops.einsum(
            h, self.W, "... inst hidden, inst hidden feats -> ... inst feats"
        )
        return F.relu(out + self.b_final)

    def generate_batch(self, batch_size: int) -> Float[Tensor, "batch inst feats"]:
        """
        Generates a batch of data of shape (batch_size, n_instances, n_features).
        """
        batch_shape = (batch_size, self.cfg.n_inst, self.cfg.n_features)
        feat_mag = t.rand(batch_shape, device=self.W.device)
        feat_seeds = t.rand(batch_shape, device=self.W.device)
        return t.where(feat_seeds <= self.feature_probability, feat_mag, 0.0)

    def calculate_loss(
        self,
        out: Float[Tensor, "batch inst feats"],
        batch: Float[Tensor, "batch inst feats"],
    ) -> Float[Tensor, ""]:
        """
        Calculates the loss for a given batch, using this loss described in the Toy Models paper:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        Remember, `self.importance` will always have shape (n_inst, n_features).
        """
        error = self.importance * ((batch - out) ** 2)
        loss = einops.reduce(error, "batch inst feats -> inst", "mean").sum()
        return loss

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 5_000,
        log_freq: int = 50,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        """
        Optimizes the model using the given hyperparameters.
        """
        optimizer = t.optim.Adam(self.parameters(), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:
            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            # Optimize
            optimizer.zero_grad()
            batch = self.generate_batch(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item() / self.cfg.n_inst, lr=step_lr)


### Notation

The autoencoder's hidden activations go by many names. Sometimes they're called **neurons** (since they do have an activation function applied to them which makes them a privileged basis, like the neurons in an MLP layer). Sometimes they're called **features**, since the idea with SAEs is that these hidden activations are meant to refer to specific features in the data. However, the word feature is a bit [overloaded](https://www.lesswrong.com/posts/9Nkb389gidsozY9Tf/lewis-smith-s-shortform#fd64ALuWK8rXdLKz6) - ideally we want to use "feature" to refer to the attributes of the data itself - if our SAE's weights are randomly initialized, is it fair to call this a feature?!

For this reason, we'll be referring to the autoencoder's hidden activations as **SAE latents**. However, it's worth noting that people sometimes use "SAE features" or "neurons" instead, so try not to get confused (e.g. often people use "neuron resampling" to refer to the resampling of the weights in the SAE).

The new notation we'll adopt in this section is:

- `d_sae`, which is the number of activations in the SAE's hidden layer (i.e. the latent dimension). Note that we want the SAE latents to correspond to the original data features, which is why we'll need `d_sae >= n_features` (usually we'll have equality in this section).
- `d_in`, which is the SAE input dimension. This is the same as `d_hidden` from the previous sections because the SAE is reconstructing the model's hidden activations, however calling it `d_hidden` in the context of an SAE would be confusing. Usually in this section, we'll have `d_in = d_hidden = 2`, so we can visualize the results.

<details>
<summary>Question - in the formulas above (in the "Problem setup" section), what are the shapes of x, x', z, h, and h' ?</summary>

Ignoring batch and instance dimensions:

- `x` and `x'` are vectors of shape `(n_features,)`
- `z` is a vector of shape `(d_sae,)`
- `h` and `h'` are vectors of shape `(d_in,)`, which is equal to `d_hidden` from previous sections

Including batch and instance dimensions, each tensor has shape `(batch_size, n_inst, d)`, where `d` is the corresponding size listed above.

</details>

### SAE class

We've provided the `ToySAEConfig` class below. Its arguments are as follows:

- `n_inst`, which means the same as it does in your `ToyModel` class
- `d_in`, the input size to your SAE (equal to `d_hidden` of your `ToyModel` class)
- `d_sae`, the SAE's latent dimension size
- `sparsity_coeff`, which is used in your loss function
- `weight_normalize_eps`, which is added to the denominator whenever you normalize weights
- `tied_weights`, which is a boolean determining whether your encoder and decoder weights are tied

We've also given you the `ToySAE` class. Your job over the next exercise will be to fill in the `forward` method.

In [None]:
@dataclass
class ToySAEConfig:
    n_inst: int
    d_in: int
    d_sae: int
    sparsity_coeff: float = 0.2
    weight_normalize_eps: float = 1e-8
    tied_weights: bool = False


class ToySAE(nn.Module):
    W_enc: Float[Tensor, "inst d_in d_sae"]
    _W_dec: Float[Tensor, "inst d_sae d_in"] | None
    b_enc: Float[Tensor, "inst d_sae"]
    b_dec: Float[Tensor, "inst d_in"]

    def __init__(self, cfg: ToySAEConfig, model: ToyModel) -> None:
        super(ToySAE, self).__init__()

        assert cfg.d_in == model.cfg.d_hidden, "Model's hidden dim doesn't match SAE input dim"
        self.cfg = cfg
        self.model = model.requires_grad_(False)
        # Setting up all toy models as exactly the same.
        self.model.W.data[1:] = self.model.W.data[0]
        self.model.b_final.data[1:] = self.model.b_final.data[0]

        self.W_enc = nn.Parameter(nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_in, cfg.d_sae))))
        self._W_dec = (
            None
            if self.cfg.tied_weights
            else nn.Parameter(nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_sae, cfg.d_in))))
        )
        self.b_enc = nn.Parameter(t.zeros(cfg.n_inst, cfg.d_sae))
        self.b_dec = nn.Parameter(t.zeros(cfg.n_inst, cfg.d_in))

        self.to(device)

    @property
    def W_dec(self) -> Float[Tensor, "inst d_sae d_in"]:
        return self._W_dec if self._W_dec is not None else self.W_enc.transpose(-1, -2)

    @property
    def W_dec_normalized(self) -> Float[Tensor, "inst d_sae d_in"]:
        """Returns decoder weights, normalized over the autoencoder input dimension."""
        return self.W_dec / (self.W_dec.norm(dim=-1, keepdim=True) + self.cfg.weight_normalize_eps)

    def generate_batch(self, batch_size: int) -> Float[Tensor, "batch inst d_in"]:
        """
        Generates a batch of hidden activations from our model.
        """
        return einops.einsum(
            self.model.generate_batch(batch_size),
            self.model.W,
            "batch inst feats, inst d_in feats -> batch inst d_in",
        )

    def forward(
        self, h: Float[Tensor, "batch inst d_in"]
    ) -> tuple[
        dict[str, Float[Tensor, "batch inst"]],
        Float[Tensor, "batch inst"],
        Float[Tensor, "batch inst d_sae"],
        Float[Tensor, "batch inst d_in"],
    ]:
        """
        Forward pass on the autoencoder.

        Args:
            h: hidden layer activations of model

        Returns:
            loss_dict:       dict of different loss terms, each having shape (batch_size, n_inst)
            loss:            total loss (i.e. sum over terms of loss dict), same shape as loss terms
            acts_post:       autoencoder latent activations, after applying ReLU
            h_reconstructed: reconstructed autoencoder input
        """
        # You'll fill this in later
        raise NotImplementedError()

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 100,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
        resample_method: Literal["simple", None] = None,
        resample_freq: int = 2500,
        resample_window: int = 500,
        resample_scale: float = 0.5,
        hidden_sample_size: int = 256,
    ) -> list[dict[str, Any]]:
        """
        Optimizes the autoencoder using the given hyperparameters.

        Args:
            batch_size:         size of batches we pass through model & train autoencoder on
            steps:              number of optimization steps
            log_freq:           number of optimization steps between logging
            lr:                 learning rate
            lr_scale:           learning rate scaling function
            resample_method:    method for resampling dead latents
            resample_freq:      number of optimization steps between resampling dead latents
            resample_window:    number of steps needed for us to classify a latent as dead
            resample_scale:     scale factor for resampled latents
            hidden_sample_size: size of hidden value sample we add to the logs (for visualization)

        Returns:
            data_log:           list of dicts containing data we'll use for visualization
        """
        assert resample_window <= resample_freq

        optimizer = t.optim.Adam(self.parameters(), lr=lr)  # betas=(0.0, 0.999)
        frac_active_list = []
        progress_bar = tqdm(range(steps))

        # Create lists of dicts to store data we'll eventually be plotting
        data_log = []

        for step in progress_bar:
            # Resample dead latents
            if (resample_method is not None) and ((step + 1) % resample_freq == 0):
                frac_active_in_window = t.stack(frac_active_list[-resample_window:], dim=0)
                if resample_method == "simple":
                    self.resample_simple(frac_active_in_window, resample_scale)

            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            # Get a batch of hidden activations from the model
            with t.inference_mode():
                h = self.generate_batch(batch_size)

            # Optimize
            loss_dict, loss, acts, _ = self.forward(h)
            loss.mean(0).sum().backward()
            optimizer.step()
            optimizer.zero_grad()

            # Normalize decoder weights by modifying them directly (if not using tied weights)
            if not self.cfg.tied_weights:
                self.W_dec.data = self.W_dec_normalized.data

            # Calculate the mean sparsities over batch dim for each feature
            frac_active = (acts.abs() > 1e-8).float().mean(0)
            frac_active_list.append(frac_active)

            # Display progress bar, and log a bunch of values for creating plots / animations
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(
                    lr=step_lr,
                    loss=loss.mean(0).sum().item(),
                    frac_active=frac_active.mean().item(),
                    **{k: v.mean(0).sum().item() for k, v in loss_dict.items()},  # type: ignore
                )
                with t.inference_mode():
                    loss_dict, loss, acts, h_r = self.forward(
                        h := self.generate_batch(hidden_sample_size)
                    )
                data_log.append(
                    {
                        "steps": step,
                        "frac_active": (acts.abs() > 1e-8).float().mean(0).detach().cpu(),
                        "loss": loss.detach().cpu(),
                        "h": h.detach().cpu(),
                        "h_r": h_r.detach().cpu(),
                        **{name: param.detach().cpu() for name, param in self.named_parameters()},
                        **{name: loss_term.detach().cpu() for name, loss_term in loss_dict.items()},
                    }
                )

        return data_log

    @t.no_grad()
    def resample_simple(
        self,
        frac_active_in_window: Float[Tensor, "window inst d_sae"],
        resample_scale: float,
    ) -> None:
        """
        Resamples dead latents, by modifying the model's weights and biases inplace.

        Resampling method is:
            - For each dead neuron, generate a random vector of size (d_in,), and normalize these vecs
            - Set new values of W_dec and W_enc to be these normalized vectors, at each dead neuron
            - Set b_enc to be zero, at each dead neuron

        This function performs resampling over all instances at once, using batched operations.
        """
        # Get a tensor of dead latents
        dead_latents_mask = (frac_active_in_window < 1e-8).all(dim=0)  # [instances d_sae]
        n_dead = int(dead_latents_mask.int().sum().item())

        # Get our random replacement values of shape [n_dead d_in], and scale them
        replacement_values = t.randn((n_dead, self.cfg.d_in), device=self.W_enc.device)
        replacement_values_normed = replacement_values / (
            replacement_values.norm(dim=-1, keepdim=True) + self.cfg.weight_normalize_eps
        )

        # Change the corresponding values in W_enc, W_dec, and b_enc
        self.W_enc.data.transpose(-1, -2)[dead_latents_mask] = (
            resample_scale * replacement_values_normed
        )
        self.W_dec.data[dead_latents_mask] = replacement_values_normed
        self.b_enc.data[dead_latents_mask] = 0.0


<details>
<summary>Why might we want / not want to tie our weights?</summary>

In our `ToyModel` implementation, we used a weight and its transpose. You might think it also makes sense to have the encoder and decoder weights be transposed copies of each other, since intuitively both the encoder's and decoder's latent vectors are meant to represent some feature's "direction in the original model's hidden dimension".

The reason we might not want to tie weights is pretty subtle. The job of the encoder is in some sense to recover features from superposition, whereas the job of the decoder is just to represent that feature faithfully if present (since the goal of our SAE is to write the input as a linear combination of `W_dec` vectors) - this is why we generally see the decoder weights as the "true direction" for a feature, when weights are untied.

The diagram below might help illustrate this concept (if you want, you can replicate the results in this diagram using our toy model setup!).

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/w-dec-explained.png" width="700">

In simple settings like this toy model we might not benefit much from untying weights, and tying weights can actually help us avoid finding annoying local minima in our optimization. However, for most of these exercises we'll use untied weights in order to illustrate SAE concepts more clearly.

</details>
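
As a rough numerical illustration of the encoder/decoder distinction, consider two non-orthogonal feature directions in a 2D hidden space. This is a hypothetical two-feature example (the vectors and names are assumptions chosen for illustration, not values from the toy model): reading feature 1 with its own direction picks up interference from feature 2, whereas a different encoder vector can remove that interference while the decoder direction stays the same.

```python
import math


def dot(u, v):
    """Dot product of two equal-length sequences."""
    return sum(a * b for a, b in zip(u, v))


# Two unit-norm feature (decoder) directions, 60 degrees apart.
d1 = (1.0, 0.0)
d2 = (math.cos(math.pi / 3), math.sin(math.pi / 3))

# Tied weights: the encoder vector for feature 1 is d1 itself.
tied_signal = dot(d1, d1)        # response when feature 1 is active
tied_interference = dot(d1, d2)  # response when only feature 2 is active

# Untied weights: pick an encoder vector e1 with e1 . d1 = 1 and e1 . d2 = 0.
e1 = (1.0, -d2[0] / d2[1])
untied_signal = dot(e1, d1)
untied_interference = dot(e1, d2)
```

With more features than hidden dimensions, interference can't be cancelled exactly like this for every feature at once, so treat this only as intuition for why the encoder and decoder directions may differ.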

Also, note that we've defined `self.cfg` and `self.model` for you in the init function - in the latter case, we've frozen the model's weights (because when you train your SAE you don't want to track gradients in your base model), and we've also modified the model's weights so they all match the first instance (this is so we can more easily interpret our SAE plots we'll create when we finish training).

### Exercise - implement `forward`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵🔵🔵
>
> You should spend up to 25-40 minutes on this exercise.
> ```

You should calculate the autoencoder's hidden state activations as $z = \operatorname{ReLU}(W_{enc}(h - b_{dec}) + b_{enc})$, and then reconstruct the output as $h' = W_{dec}z + b_{dec}$. A few notes:

- The **first variable** we return is a `loss_dict`, which contains the loss tensors of shape `(batch_size, n_inst)` for both terms in our loss function (before multiplying by the L1 coefficient). This is used for logging, and it'll also be used later in our neuron resampling methods. For this architecture, your keys should be `"L_reconstruction"` and `"L_sparsity"`.
- The **second variable** we return is the `loss` term, which also has shape `(batch_size, n_inst)`, and is created by summing the losses in `loss_dict` (with sparsity loss multiplied by `cfg.sparsity_coeff`). When doing gradient descent, we'll average over the batch dimension & sum over the instance dimension (since we're training our instances independently & in parallel).
- The **third variable** we return is the hidden state activations `acts`, which are also used later for neuron resampling (as well as logging how many latents are active).
- The **fourth variable** we return is the reconstructed hidden states `h_reconstructed`, i.e. the autoencoder's actual output.

An important note regarding our loss term - the reconstruction loss is the squared difference between input & output **averaged** over the `d_in` dimension, but the sparsity penalty is the L1 norm of the hidden activations **summed** over the `d_sae` dimension. Can you see why we average one but sum the other?
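
Written out for a single input, the two terms and the total loss are as follows. This just restates the description above, using $\lambda$ as shorthand for `cfg.sparsity_coeff`:

$$
\begin{aligned}
L_{reconstruction} &= \frac{1}{d_{in}} \sum_{i=1}^{d_{in}} (h'_i - h_i)^2 \\
L_{sparsity} &= \sum_{j=1}^{d_{sae}} |z_j| \\
L &= L_{reconstruction} + \lambda L_{sparsity}
\end{aligned}
$$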

<details>
<summary>Hint</summary>

Suppose we averaged L1 loss too. Consider the gradients a single latent receives from the reconstruction loss and sparsity penalty - what do they look like in the limit of very large `d_sae`?

</details>

<details>
<summary>Answer - why we average L2 loss over <code>d_in</code> but sum L1 loss over <code>d_sae</code></summary>

Suppose for sake of argument we averaged L1 loss too. Imagine if we doubled the latent dimension, but kept all other SAE hyperparameters the same. The per-hidden-unit gradient from the reconstruction loss would still be the same (because changing a single hidden unit's encoder or decoder vector would have the same effect on the output as before), but the per-hidden-unit gradient from the sparsity penalty would have halved (because we're averaging the sparsity penalty over `d_sae`). This means that in the limit, the sparsity penalty wouldn't matter at all, and the only important thing would be getting zero reconstruction loss.

</details>
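
The same argument in symbols, as a sketch: write $\lambda$ for the sparsity coefficient and consider a single active latent $z_j > 0$. The gradient of the sparsity term with respect to that latent is

$$
\begin{aligned}
\frac{\partial}{\partial z_j}\left(\lambda \sum_{k=1}^{d_{sae}} |z_k|\right) &= \lambda && \text{(summed)} \\
\frac{\partial}{\partial z_j}\left(\frac{\lambda}{d_{sae}} \sum_{k=1}^{d_{sae}} |z_k|\right) &= \frac{\lambda}{d_{sae}} && \text{(averaged)}
\end{aligned}
$$

The summed version is independent of `d_sae`, whereas the averaged version goes to zero as `d_sae` grows.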

Note - make sure you're using `self.W_dec_normalized` rather than `self.W_dec` in your forward function. This is because if we're using tied weights then we won't be able to manually normalize `W_dec` inplace, but we still want to use the normalized version.

In [None]:
# Go back up and edit your `ToySAE.forward` method, then run the test below

tests.test_sae_forward(ToySAE)

<details><summary>Solution</summary>

```python
def forward(
    self: ToySAE, h: Float[Tensor, "batch inst d_in"]
) -> tuple[
    dict[str, Float[Tensor, "batch inst"]],
    Float[Tensor, "batch inst"],
    Float[Tensor, "batch inst d_sae"],
    Float[Tensor, "batch inst d_in"],
]:
    """
    Forward pass on the autoencoder.

    Args:
        h: hidden layer activations of model

    Returns:
        loss_dict:       dict of different loss terms, each dict value having shape (batch_size, n_inst)
        loss:            total loss (i.e. sum over terms of loss dict), same shape as loss_dict values
        acts_post:       autoencoder latent activations, after applying ReLU
        h_reconstructed: reconstructed autoencoder input
    """
    h_cent = h - self.b_dec

    # Compute latent (hidden layer) activations
    acts_pre = (
        einops.einsum(h_cent, self.W_enc, "batch inst d_in, inst d_in d_sae -> batch inst d_sae")
        + self.b_enc
    )
    acts_post = F.relu(acts_pre)

    # Compute reconstructed input
    h_reconstructed = (
        einops.einsum(
            acts_post, self.W_dec_normalized, "batch inst d_sae, inst d_sae d_in -> batch inst d_in"
        )
        + self.b_dec
    )

    # Compute loss terms
    L_reconstruction = (h_reconstructed - h).pow(2).mean(-1)
    L_sparsity = acts_post.abs().sum(-1)
    loss_dict = {"L_reconstruction": L_reconstruction, "L_sparsity": L_sparsity}
    loss = L_reconstruction + self.cfg.sparsity_coeff * L_sparsity

    return loss_dict, loss, acts_post, h_reconstructed


ToySAE.forward = forward
```
</details>

## Training your SAE

The `optimize` method has been given to you. A few notes on how it differs from your previous model:

- Before each optimization step, we implement **neuron resampling** - we'll get to this later.
- We have more logging, via the `data_log` list of dicts - we'll use this for visualization.
- [Anthropic's Feb 2024 update](https://transformer-circuits.pub/2024/feb-update/index.html#dict-learning-loss) describes using `betas=(0.0, 0.999)` for Adam, although they document it to work better specifically for large models. In the `optimize` method above this setting is only left as a comment, so Adam runs with its default betas unless you add the argument yourself.
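
The `lr_scale` argument of `optimize` takes any function of `(step, steps)` that returns a multiplier for the base learning rate, with `constant_lr` as the default. As a sketch, two alternative schedules with the same signature might look like this. The names `linear_lr` and `cosine_decay_lr` and their exact shapes are illustrative assumptions, not something this notebook defines:

```python
import math


def linear_lr(step: int, steps: int) -> float:
    """Multiplier that decays linearly from 1 at step 0 towards 0 at the end of training."""
    return 1.0 - step / steps


def cosine_decay_lr(step: int, steps: int) -> float:
    """Multiplier that follows a quarter cosine wave from 1 at step 0 to ~0 at the last step."""
    # max(..., 1) guards against division by zero when steps == 1
    return math.cos(0.5 * math.pi * step / max(steps - 1, 1))
```

You could then pass one of these in when training, e.g. `sae.optimize(lr_scale=cosine_decay_lr)`.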

First, let's define and train our model, and visualize model weights and the data returned from `sae.generate_batch` (which are the hidden state representations of our trained model, and will be used for training our SAE).

Note that we'll use a feature probability of 2.5% for all subsequent exercises.

In [None]:
d_hidden = d_in = 2
n_features = d_sae = 5
n_inst = 16

# Create a toy model, and train it to convergence
cfg = ToyModelConfig(n_inst=n_inst, n_features=n_features, d_hidden=d_hidden)
model = ToyModel(cfg=cfg, device=device, feature_probability=0.025)
model.optimize()

sae = ToySAE(cfg=ToySAEConfig(n_inst=n_inst, d_in=d_in, d_sae=d_sae), model=model)

h = sae.generate_batch(512)

utils.plot_features_in_2d(model.W[:8], title="Base model")
utils.plot_features_in_2d(
    einops.rearrange(h[:, :8], "batch inst d_in -> inst d_in batch"),
    title="Hidden state representation of a random batch of data",
)

Now, let's train our SAE, and visualize the first 8 instances! We've also created a function `animate_features_in_2d` which creates an animation of the training over time. If the inline display doesn't work, you might have to open the saved HTML file in your browser to see it.

In [None]:
data_log = sae.optimize(steps=20_000, resample_method=None)

utils.animate_features_in_2d(
    data_log,
    instances=list(range(8)),  # only plot the first 8 instances
    rows=["W_enc", "_W_dec"],
    filename=str(section_dir / "animation-training.html"),
    title="SAE on toy model",
)

# If this display code doesn't work, try saving & opening animation in your browser
with open(section_dir / "animation-training.html") as f:
    display(HTML(f.read()))

You should see that the autoencoder is generally successful at reconstructing the model's hidden states. Sometimes it learns the fully monosemantic solution (one latent per feature), but more often it learns a combination of **polysemantic latents** and **dead latents** (which never activate). Dead latents are a big problem because they don't receive any gradients during training, so the problem doesn't fix itself over time. You can check for dead latents by graphing the fraction of datapoints on which each latent is active over training, using the code below. You should find that:

1. Some latents are dead for most or all of training (with "fraction of datapoints active" being zero),
2. Some latents fire more frequently than the target feature probability of 2.5% (these are usually polysemantic, i.e. they fire on more than one feature),
3. Some latents fire approximately at or slightly below the target probability (these are usually monosemantic). If any of your instances above learned the full monosemantic solution (i.e. latents uniformly spaced around the 2D hidden dimension) then you should find that all 5 latents in that instance fall into this third category.
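
The three categories above can be sketched as a small helper. This is only an illustration: the function name `categorize_latents` and the relative tolerance `tol` are hypothetical choices, and the cutoff between "near target" and "over-firing" is a judgment call rather than anything defined by the notebook.

```python
def categorize_latents(frac_active, target_prob=0.025, tol=0.5):
    """Label each latent as 'dead', 'over-firing', or 'near-target'.

    frac_active: per-latent fraction of datapoints on which the latent fired.
    tol: hypothetical relative margin above target_prob before a latent
         counts as firing too often.
    """
    labels = []
    for f in frac_active:
        if f == 0.0:
            labels.append("dead")            # never activates
        elif f > target_prob * (1 + tol):
            labels.append("over-firing")     # likely polysemantic
        else:
            labels.append("near-target")     # likely monosemantic
    return labels


print(categorize_latents([0.0, 0.051, 0.024, 0.019, 0.026]))
# → ['dead', 'over-firing', 'near-target', 'near-target', 'near-target']
```

To apply this to the notebook's logs you would pass in the values for a single instance, e.g. taken from the last entry of `data_log` and converted to a plain list; the exact shape of the logged `frac_active` values is an assumption here.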

In [None]:
utils.frac_active_line_plot(
    frac_active=t.stack([data["frac_active"] for data in data_log]),
    title="Probability of SAE latents being active during training",
    avg_window=20,
)

## Resampling

From Anthropic's paper (replacing terminology "dead neurons" with "dead latents" in accordance with how we're using the term):

> Second, we found that over the course of training some latents cease to activate, even across a large number of datapoints. We found that “resampling” these dead latents during training gave better results by allowing the model to represent more features for a given autoencoder hidden layer dimension. Our resampling procedure is detailed in [Autoencoder Resampling](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling), but in brief we periodically check for latents which have not fired in a significant number of steps and reset the encoder weights on the dead latents to match data points that the autoencoder does not currently represent well.

We've implemented this procedure for you. Let's see the results!
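
To make the quoted idea concrete before running it, here is a simplified pure-Python sketch: find latents that never fired, then point each one's encoder weights at an input the autoencoder currently reconstructs poorly. The function name `resample_dead_latents`, the list-of-rows weight layout, the squared-loss sampling weights, and the `scale` factor are all assumptions made for illustration; this is not the code behind `resample_method="simple"`.

```python
import math
import random


def resample_dead_latents(W_enc, frac_active, data, recon_errors, rng, scale=0.2):
    """Re-point each dead latent's encoder row at a poorly reconstructed input.

    W_enc: list of d_sae encoder rows, each a list of d_in floats (hypothetical layout)
    frac_active: per-latent fraction of recent datapoints on which the latent fired
    data: list of input vectors; recon_errors: per-input reconstruction loss
    Returns the indices of the latents that were resampled.
    """
    dead = [i for i, f in enumerate(frac_active) if f == 0.0]
    # Inputs with larger loss are more likely to be chosen (assumed weighting)
    weights = [e ** 2 for e in recon_errors]
    for i in dead:
        x = rng.choices(data, weights=weights, k=1)[0]
        norm = math.sqrt(sum(val * val for val in x)) or 1.0
        # Use the input's direction, rescaled to a small norm
        W_enc[i] = [scale * val / norm for val in x]
    return dead


W_enc = [[1.0, 0.0], [0.0, 0.0]]
dead = resample_dead_latents(
    W_enc,
    frac_active=[0.03, 0.0],          # latent 1 never fired
    data=[[0.0, 2.0], [3.0, 0.0]],
    recon_errors=[1.0, 0.0],          # only the first input is badly reconstructed
    rng=random.Random(0),
)
print(dead, W_enc)
```

In this example only latent 1 is dead and only the first input has nonzero loss, so latent 1's encoder row ends up pointing along that input. A fuller implementation would also need to handle the decoder weights, biases, and optimizer state, which this sketch leaves out.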

Let's train a new SAE, this time with resampling enabled, and watch the animation to see how resampling has helped the training process. You should be able to see the resampled latents in red.

In [None]:
resampling_sae = ToySAE(cfg=ToySAEConfig(n_inst=n_inst, d_in=d_in, d_sae=d_sae), model=model)

resampling_data_log = resampling_sae.optimize(steps=20_000, resample_method="simple")

utils.animate_features_in_2d(
    resampling_data_log,
    rows=["W_enc", "_W_dec"],
    instances=list(range(8)),  # only plot the first 8 instances
    filename=str(section_dir / "animation-training-resampling.html"),
    color_resampled_latents=True,
    title="SAE on toy model (with resampling)",
)

# If this display code doesn't work, try saving & opening animation in your browser
with open(section_dir / "animation-training-resampling.html") as f:
    display(HTML(f.read()))

In [None]:
utils.frac_active_line_plot(
    frac_active=t.stack([data["frac_active"] for data in resampling_data_log]),
    title="Probability of SAE latents being active during training (with resampling)",
    avg_window=20,
)

Much better!

Now that we have pretty much full reconstruction on our features, let's visualize that reconstruction! The `animate_features_in_2d` function can also plot the hidden states and their reconstructions, and show how they evolve over time. Watching this evolution can help you understand what's going on, for example:

- The SAE often learns a non-sparse solution (e.g. 4 uniformly spaced polysemantic latents & 1 dead latent) before converging to the ideal solution.
    - Note, we also see something similar when training SAEs on LLMs: they first find a non-sparse solution with small reconstruction loss, before learning a more sparse solution (L0 goes down).
- Hovering over hidden states, you should observe some things:
    - Low-magnitude hidden states are often reconstructed as zero, because the SAE can't separate them from interference from other features.
    - Even for correctly reconstructed features, the magnitude of the reconstructed hidden state is generally smaller than that of the true hidden state - this is called **shrinkage**.
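
One way to build intuition for shrinkage is a one-latent toy calculation. Assume (this is a simplification, not the notebook's full loss) that a single latent with a unit-norm decoder direction reconstructs a scalar input `x`, and that the loss is the squared error plus an L1 penalty, `(x - a)**2 + l1_coeff * a` with `a >= 0`. Setting the derivative to zero gives the minimizer `a = x - l1_coeff / 2`, clipped at zero. The helper name `best_activation` is hypothetical.

```python
def best_activation(x, l1_coeff):
    """Minimizer over a >= 0 of (x - a)**2 + l1_coeff * a (closed form)."""
    return max(x - l1_coeff / 2, 0.0)


for x in [0.05, 0.5, 1.0]:
    a = best_activation(x, l1_coeff=0.2)
    print(f"true magnitude {x:.2f} -> reconstructed magnitude {a:.2f}")
```

Under these assumptions the loss-minimizing reconstruction always undershoots the true magnitude by `l1_coeff / 2`, which is one source of shrinkage. The calculation ignores interference between features, so it does not by itself account for everything you see in the animation.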

In [None]:
utils.animate_features_in_2d(
    resampling_data_log,
    rows=["W_enc", "h", "h_r"],
    instances=list(range(4)),  # plotting fewer instances for a smaller animation file size
    color_resampled_latents=True,
    filename=str(section_dir / "animation-training-reconstructions.html"),
    title="SAE on toy model (showing hidden states & reconstructions)",
)

# If this display code doesn't work, try saving & opening animation in your browser
with open(section_dir / "animation-training-reconstructions.html") as f:
    display(HTML(f.read()))

**That's all for the tutorial! Thank you for your attention!**