<a href="https://colab.research.google.com/github/amantimalsina/Mamba-SAE/blob/main/training_a_sparse_autoencoder_for_mamba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A very basic SAE Training Tutorial

Please note that it is very easy for tutorial code to go stale, so please have a low bar for raising an issue in the

## Setup

In [1]:
try:
    # Only install packages when running in Colab; elsewhere, enable autoreload instead.
    import google.colab # type: ignore
    from google.colab import output
    %pip install sae-lens transformer-lens circuitsvis
except ImportError:
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Collecting sae-lens
  Downloading sae_lens-3.18.2-py3-none-any.whl.metadata (5.1 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.4.0-py3-none-any.whl.metadata (12 kB)
Collecting circuitsvis
  Downloading circuitsvis-1.43.2-py3-none-any.whl.metadata (2.3 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly<6.0.0,>=5.19.0 (from sae-lens)
  Downloading plotly-5.24.0-py3-none-any.whl.metadata (7.3 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_expr

In [7]:
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


# Model Selection and Evaluation (Feel Free to Skip)

We'll train an SAE on `state-spaces/mamba-370m`, a relatively small (370M-parameter) Mamba model. Before we get started, let's load the model with `mamba_lens`, whose `HookedMamba` class offers a TransformerLens-style interface (`run_with_cache`, `run_with_hooks`, `to_tokens`), and see what it can do.

TransformerLens gives us two functions that are useful here (and CircuitsVis provides a third):
1. `transformer_lens.utils.test_prompt` will help us see when the model can infer one token.
2. `HookedTransformer.generate` will help us see what happens when we sample from the model.
3. `circuitsvis.logits.token_log_probs` will help us visualize the log probs of tokens at several positions in a prompt.

In [8]:
%pip install git+https://github.com/Phylliida/MambaLens.git

Collecting git+https://github.com/Phylliida/MambaLens.git
  Cloning https://github.com/Phylliida/MambaLens.git to /tmp/pip-req-build-93jvm446
  Running command git clone --filter=blob:none --quiet https://github.com/Phylliida/MambaLens.git /tmp/pip-req-build-93jvm446
  Resolved https://github.com/Phylliida/MambaLens.git to commit 89faa6863b05642401f6c403f7a0149f1dd6ae1a
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: mamba_lens
  Building wheel for mamba_lens (pyproject.toml) ... done
  Created wheel for mamba_lens: filename=mamba_lens-0.0.4-py3-none-any.whl size=27707 sha256=25545a14b629fa066628f48d86192361071b45e2231339a42818b4c574a61de0
  Stored in directory: /tmp/pip-ephem-wheel-cache-vwnmgs57/wheels/65/42/45/740d4c9f216e098f81553897c14ead8421b0d6e0909b2ae333
Successfully built mamba_lens
Installing collected 

In [9]:
import mamba_lens
model = mamba_lens.HookedMamba.from_pretrained(
                              "state-spaces/mamba-370m",
                              device='cuda'
                              )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Moving model to device:  cuda


In [10]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Getting a vibe for the model by generating text

Let's start by generating some stories using the model.
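Greedy decoding feeds the model its own most likely next token, one position at a time. At each step only the prediction at the *last* position is new. Here is a minimal plain-Python sketch. It assumes a hypothetical `next_token_logits(tokens)` callback that stands in for a forward pass and returns the logits for the token after the sequence. `toy_model` is a made-up stand-in for a real model.

```python
from typing import Callable, List


def greedy_generate(
    next_token_logits: Callable[[List[int]], List[float]],
    prompt_tokens: List[int],
    max_new_tokens: int,
) -> List[int]:
    """Append the argmax token one step at a time (greedy decoding)."""
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        logits = next_token_logits(tokens)  # scores for the position after tokens[-1]
        next_token = max(range(len(logits)), key=lambda i: logits[i])
        tokens.append(next_token)
    return tokens


# Hypothetical toy "model": always predicts (last token + 1) mod 5.
def toy_model(tokens: List[int]) -> List[float]:
    target = (tokens[-1] + 1) % 5
    return [1.0 if i == target else 0.0 for i in range(5)]


print(greedy_generate(toy_model, [0], max_new_tokens=6))  # [0, 1, 2, 3, 4, 0, 1]
```

With `HookedMamba`, `next_token_logits` corresponds to taking `logits[0, -1]` from a forward pass over the current tokens. Taking the argmax at every position instead only returns the model's one-step predictions for the prompt, not a continuation of it.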

In [107]:
torch.cuda.empty_cache()

In [18]:
%pip install causal_conv1d mamba-ssm

Collecting mamba-ssm
  Downloading mamba_ssm-2.2.2.tar.gz (85 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: mamba-ssm
  Building wheel for mamba-ssm (setup.py) ... done
  Created wheel for mamba-ssm: filename=mamba_ssm-2.2.2-cp310-cp310-linux_x86_64.whl size=323988104 sha256=6b082468a6abb6f6bc50c99263f17c6c7f5a2e8f6b275ed7998b81fb25279229
  Stored in directory: /root/.cache/pip/wheels/57/7c/90/9f963468ecc3791e36e388f9e7b4a4e1e3f90fbb340055aa4d
Successfully built mamba-ssm
Installing collected packages: mamba-ssm
Successfully installed mamba-ssm-2.2.2


In [39]:
# Greedy decoding: append the most likely next token (the argmax at the last position), one step at a time.
prompt = "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,"
tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    for _ in range(30):
        logits, _ = model.run_with_cache(tokens,
                                         fast_ssm=True,
                                         fast_conv=True,
                                         warn_disabled_hooks=False
                                         )
        next_token = logits[0, -1].argmax(dim=-1)
        tokens = torch.cat([tokens, next_token.view(1, 1)], dim=1)
print(tokenizer.decode(tokens[0]))


### Spot checking model abilities with `transformer_lens.utils.test_prompt`

In [36]:
from transformer_lens.utils import test_prompt
# Test the model with a prompt
test_prompt(
    "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,",
    " Lily",
    model,
    prepend_space_to_answer=False,
)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 32.43 Prob: 60.78% Token: | she|
Top 1th token. Logit: 30.98 Prob: 14.16% Token: | Lily|
Top 2th token. Logit: 29.63 Prob:  3.67% Token: | the|
Top 3th token. Logit: 29.57 Prob:  3.47% Token: | her|
Top 4th token. Logit: 29.07 Prob:  2.11% Token: | there|
Top 5th token. Logit: 28.70 Prob:  1.45% Token: | a|
Top 6th token. Logit: 28.50 Prob:  1.18% Token: | one|
Top 7th token. Logit: 27.87 Prob:  0.63% Token: | you|
Top 8th token. Logit: 27.39 Prob:  0.39% Token: | it|
Top 9th token. Logit: 27.31 Prob:  0.36% Token: | everyone|


In the output above, we see that the model assigns ~61% probability to " she" being the next token and a ~14% chance to " Lily". Other names like Lucy or Anna are not highly ranked.
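Softmax depends only on the differences between logits. As a result, the probability ratio of two tokens is `exp` of their logit gap. Here is a quick plain-Python check against the " she" and " Lily" numbers printed above. The logits are rounded to two decimals, so the match is only approximate.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Adding a constant to every logit leaves the probabilities unchanged.
p = softmax([32.43, 30.98, 29.63])
q = softmax([2.43, 0.98, -0.37])
assert all(abs(a - b) < 1e-12 for a, b in zip(p, q))

# Logits and probabilities for " she" and " Lily" from the test_prompt output.
predicted_ratio = math.exp(32.43 - 30.98)
observed_ratio = 60.78 / 14.16
print(f"{predicted_ratio:.2f} vs {observed_ratio:.2f}")  # 4.26 vs 4.29
```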

### Exploring Model Capabilities with Log Probs

Looking at token rankings for a single prompt is interesting, but a much higher-throughput way to understand models is to look at the log probs of every token in a text. We can use the `circuitsvis` package to get a nice visualization that shows the tokenization and lets us hover over each token to see the top 5 tokens by log probability. Darker tokens are tokens where the model assigned a higher probability to the actual next token.
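At each position, the visualization shows the log probability the model assigned to the token that actually comes next. The prediction at position `t` is therefore scored against token `t + 1`. Here is a plain-Python sketch of that bookkeeping. It assumes `logits[t]` holds the model's scores for the token after position `t`. The toy logits are made up.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Stable log-softmax via the log-sum-exp trick."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


def next_token_log_probs(logits: List[List[float]], tokens: List[int]) -> List[float]:
    """Log prob of each actual next token; the last position has no target."""
    return [log_softmax(logits[t])[tokens[t + 1]] for t in range(len(tokens) - 1)]


# Toy example: vocabulary of 3, sequence of 3 tokens.
tokens = [0, 2, 1]
logits = [[0.0, 0.0, 0.0], [0.0, 5.0, 0.0], [1.0, 1.0, 1.0]]
print(next_token_log_probs(logits, tokens))  # two values: log(1/3), then close to 0
```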

In [40]:
import circuitsvis as cv  # optional dep, install with pip install circuitsvis

# Let's make a longer prompt and see the log probabilities of the tokens
example_prompt = """Hi, how are you doing this? I'm really enjoying your posts"""
# One forward pass is enough here; the activation cache isn't needed.
logits = model(example_prompt)
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    logits[0].log_softmax(dim=-1),
    model.to_string,
)
# hover on the output to see the result.

In [57]:
n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
print(model.cfg)

MambaCfg(d_model=1024, n_layers=48, vocab_size=50280, d_state=16, expand=2, dt_rank=64, d_conv=4, pad_vocab_size_multiple=8, conv_bias=True, bias=False, default_prepend_bos=True, tokenizer_prepends_bos=False, n_ctx=2048, device='cuda', initializer_cfg=MambaInitCfg(initializer_range=(0.02,), rescale_prenorm_residual=(True,), n_residuals_per_layer=(1,), dt_init=('random',), dt_scale=(1.0,), dt_min=(0.001,), dt_max=(0.1,), dt_init_floor=0.0001), d_inner=2048)


In [104]:
import os
import json
import random
# %%
import wandb
import pprint
# %%
import argparse
def arg_parse_update_cfg(default_cfg):
    """
    Helper function to take in a dictionary of arguments, convert these to command line arguments, look at what was passed in, and return an updated dictionary.

    If in Ipython, just returns with no changes
    """
    if get_ipython() is not None:
        # Is in IPython
        print("In IPython - skipped argparse")
        return default_cfg
    cfg = dict(default_cfg)
    parser = argparse.ArgumentParser()
    for key, value in default_cfg.items():
        if type(value) == bool:
            # argparse's type=bool treats any non-empty string (even "False") as True, so booleans become flags: passing --{key} flips the default value.
            if value:
                parser.add_argument(f"--{key}", action="store_false")
            else:
                parser.add_argument(f"--{key}", action="store_true")

        else:
            parser.add_argument(f"--{key}", type=type(value), default=value)
    args = parser.parse_args()
    parsed_args = vars(args)
    cfg.update(parsed_args)
    print("Updated config")
    print(json.dumps(cfg, indent=2))
    return cfg
default_cfg = {
    "seed": 49,
    "batch_size": 4096,
    "buffer_mult": 384,
    "lr": 1e-4,
    "num_tokens": int(2e9),
    "l1_coeff": 3e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    "dict_mult": 32,
    "seq_len": 128,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
    "model_name": "gelu-2l",
    "site": "mlp_out",
    "layer": 0,
    "device": "cuda:0"
}
site_to_size = {
    "mlp_out": 512,
    "post": 2048,
    "resid_pre": 512,
    "resid_mid": 512,
    "resid_post": 512,
}

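# NOTE: model_name, site, act_name and act_size below are still the gelu-2l values from the original
# script (its mlp_out width is 512). For Mamba they must match the hooked activation, e.g. act_size = d_model (1024) for a residual-stream hook.
cfg = {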
    "name": "Mamba-SAE",
    "model_name": "gelu-2l",
    "site": "mlp_out",
    "layer": 0,
    "dict_size": 32,
    "seed": 49,
    "batch_size": 4096,
    "buffer_mult": 384,
    "buffer_size": 512,
    "buffer_batches": 10,
    "model_batch_size": 1024,
    "lr": 1e-4,
    "num_tokens": int(2e9),
    "l1_coeff": 3e-4,
    "enc_dtype": "fp32",
    "act_size": 512,
    "device": "cuda:0",
    "beta1": 0.9,
    "beta2": 0.99,
    "remove_rare_dir": False,
    "act_name": "mlp_out"
}
def post_init_cfg(cfg):
    cfg["name"] = f"{cfg['model_name']}_{cfg['layer']}_{cfg['dict_size']}_{cfg['site']}"
post_init_cfg(cfg)
pprint.pprint(cfg)
# %%

{'act_name': 'mlp_out',
 'act_size': 512,
 'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 10,
 'buffer_mult': 384,
 'buffer_size': 512,
 'device': 'cuda:0',
 'dict_size': 32,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'layer': 0,
 'lr': 0.0001,
 'model_batch_size': 1024,
 'model_name': 'gelu-2l',
 'name': 'gelu-2l_0_32_mlp_out',
 'num_tokens': 2000000000,
 'remove_rare_dir': False,
 'seed': 49,
 'site': 'mlp_out'}
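The hard-coded sizes in `cfg` above do not fit together. A buffer of `buffer_size=512` rows cannot hold even one training batch of `batch_size=4096`. Meanwhile, a single refresh of `model_batch_size=1024` sequences of 128 tokens produces 131,072 activation rows. The sketch below derives these sizes from a few base settings, in the spirit of the original script. The function name `derive_sizes` is hypothetical, and the exact formulas (in particular the `* 16` for `model_batch_size`) are assumptions. `act_size` stands for the width of whichever activation is hooked (1024 for `mamba-370m`'s residual stream, per `d_model` above).

```python
from typing import Any, Dict


def derive_sizes(base: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in buffer and dictionary sizes from a few base settings (a sketch)."""
    cfg = dict(base)
    # Number of SAE latents: a multiple of the activation width.
    cfg["dict_size"] = cfg["act_size"] * cfg["dict_mult"]
    # Activation rows held in the buffer: many training batches' worth.
    cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
    # Sequences needed to fill the buffer (each sequence yields seq_len rows).
    cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]
    # Sequences per model forward pass (the multiplier of 16 is an assumption).
    cfg["model_batch_size"] = cfg["batch_size"] // cfg["seq_len"] * 16
    return cfg


sizes = derive_sizes({"act_size": 1024, "dict_mult": 32, "batch_size": 4096,
                      "buffer_mult": 384, "seq_len": 128})
print(sizes["dict_size"], sizes["buffer_size"], sizes["buffer_batches"], sizes["model_batch_size"])
# 32768 1572864 12288 512
```

With these values, one full refresh runs 24 forward passes of 512 sequences and produces exactly `buffer_size` rows. The bfloat16 buffer alone then takes 1,572,864 × 1024 × 2 bytes = 3 GiB.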


In [71]:
import numpy as np

SEED = cfg["seed"]
GENERATOR = torch.manual_seed(SEED)
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
np.random.seed(SEED)
random.seed(SEED)
torch.set_grad_enabled(True)

# TODO: Replace this with HookedMamba
#model = HookedTransformer.from_pretrained(cfg["model_name"]).to(DTYPES[cfg["enc_dtype"]]).to(cfg["device"])

n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
# %%
@torch.no_grad()
def get_acts(tokens, batch_size=1024):
    _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
    acts = cache[cfg["act_name"]]
    acts = acts.reshape(-1, acts.shape[-1])
    subsample = torch.randperm(acts.shape[0], generator=GENERATOR)[:batch_size]
    subsampled_acts = acts[subsample, :]
    return subsampled_acts, acts
# sub, acts = get_acts(torch.arange(20).reshape(2, 10), batch_size=3)
# sub.shape, acts.shape
# %%

In [81]:
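# `!cd` runs in a throwaway subshell and does not change the notebook's working directory; the %cd magic does.
%cd /workspace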

In [89]:
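from pathlib import Path

import torch.nn as nn
import torch.nn.functional as F
from transformer_lens import utils

SAVE_DIR = Path("/workspace/checkpoints")
SAVE_DIR.mkdir(parents=True, exist_ok=True)  # get_version() below lists this directory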
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])

    def forward(self, x):
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

    def get_version(self):
        # Checkpoints are saved as "<version>.pt"; skip the "<version>_cfg.json" files.
        version_list = [int(file.stem) for file in SAVE_DIR.iterdir() if file.suffix == ".pt"]
        if len(version_list):
            return 1+max(version_list)
        else:
            return 0

    def save(self):
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
            json.dump(cfg, f)
        print("Saved as version", version)

    @classmethod
    def load(cls, version):
        cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
        return self

    @classmethod
    def load_from_hf(cls, version):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47

        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self
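Without the PyTorch machinery, the autoencoder's forward pass is simple. It subtracts the decoder bias, encodes with a ReLU, and decodes. The loss is a squared-error reconstruction term plus an L1 penalty on the latent activations. `make_decoder_weights_and_grad_unit_norm` adds one more step: for each decoder row, it removes the gradient component parallel to that (unit-norm) row, so the gradient has no part that would change the row's length. Here is a plain-Python sketch of both for a single input vector, with no batching. The function names are illustrative and not part of the notebook.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def sae_forward(x: Vector, W_enc: Matrix, b_enc: Vector, W_dec: Matrix,
                b_dec: Vector, l1_coeff: float) -> Tuple[float, Vector, Vector]:
    """Return (loss, reconstruction, latent activations) for one input."""
    d_in, d_hidden = len(x), len(b_enc)
    x_cent = [x[i] - b_dec[i] for i in range(d_in)]
    acts = [max(0.0, sum(x_cent[i] * W_enc[i][j] for i in range(d_in)) + b_enc[j])
            for j in range(d_hidden)]  # ReLU encoder
    recon = [sum(acts[j] * W_dec[j][i] for j in range(d_hidden)) + b_dec[i]
             for i in range(d_in)]     # linear decoder
    l2 = sum((r - xi) ** 2 for r, xi in zip(recon, x))
    l1 = l1_coeff * sum(abs(a) for a in acts)
    return l2 + l1, recon, acts


def remove_parallel_grad(w_row: Vector, grad_row: Vector) -> Vector:
    """Project out the gradient component along a unit-norm decoder row."""
    dot = sum(w * g for w, g in zip(w_row, grad_row))
    return [g - dot * w for w, g in zip(w_row, grad_row)]


identity = [[1.0, 0.0], [0.0, 1.0]]
loss, recon, acts = sae_forward([1.0, 2.0], identity, [0.0, 0.0], identity, [0.0, 0.0], l1_coeff=0.1)
print(loss, recon, acts)                              # perfect reconstruction, so loss is just the L1 term
print(remove_parallel_grad([1.0, 0.0], [0.5, 0.7]))  # [0.0, 0.7]
```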

In [88]:
# %%
import einops
from datasets import load_dataset

def shuffle_data(all_tokens):
    print("Shuffled data")
    return all_tokens[torch.randperm(all_tokens.shape[0])]

loading_data_first_time = False
if loading_data_first_time:
    # This dataset was used with gelu-2l in the original script; check that it was tokenized
    # with the same tokenizer as the Mamba model before training on it.
    data = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train", cache_dir="/workspace/cache/")
    data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf")
    data.set_format(type="torch", columns=["tokens"])
    all_tokens = data["tokens"]


    all_tokens_reshaped = einops.rearrange(all_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128)
    all_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id
    all_tokens_reshaped = all_tokens_reshaped[torch.randperm(all_tokens_reshaped.shape[0])]
    torch.save(all_tokens_reshaped, "/workspace/data/c4_code_2b_tokens_reshaped.pt")
else:
    # data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf")
    all_tokens = torch.load("/workspace/data/c4_code_2b_tokens_reshaped.pt")
    all_tokens = shuffle_data(all_tokens)



Shuffled data

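The first-time preprocessing splits each row of the tokenized dataset into fixed-length sequences. It also overwrites the first token of every sequence with BOS, so each 128-token chunk starts the way the model expects. The `einops.rearrange` pattern `"batch (x seq_len) -> (batch x) seq_len"` amounts to this plain-Python sketch. The BOS id and the toy rows are made up.

```python
from typing import List


def chunk_with_bos(rows: List[List[int]], seq_len: int, bos_id: int) -> List[List[int]]:
    """Split each row into seq_len-sized chunks and set each chunk's first token to BOS."""
    chunks = []
    for row in rows:  # outer axis: batch
        assert len(row) % seq_len == 0, "row length must be a multiple of seq_len"
        for start in range(0, len(row), seq_len):  # inner axis: x
            chunk = row[start:start + seq_len]  # slicing copies, so `rows` is not modified
            chunk[0] = bos_id  # overwrites a real token, as the notebook does
            chunks.append(chunk)
    return chunks


rows = [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]
print(chunk_with_bos(rows, seq_len=3, bos_id=0))
# [[0, 2, 3], [0, 5, 6], [0, 8, 9], [0, 11, 12]]
```

Overwriting rather than prepending keeps every sequence exactly `seq_len` tokens long, at the cost of dropping one real token per chunk.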

In [90]:
# %%
class Buffer():
    """
    A data buffer that stores a large, shuffled pool of model activations for training the autoencoder. Once roughly half of the pool has been consumed, it automatically runs the model to refill half of it.
    """
    def __init__(self, cfg):
        self.buffer = torch.zeros((cfg["buffer_size"], cfg["act_size"]), dtype=torch.bfloat16, requires_grad=False).to(cfg["device"])
        self.cfg = cfg
        self.token_pointer = 0
        self.first = True
        self.refresh()

    @torch.no_grad()
    def refresh(self):
        self.pointer = 0
        with torch.autocast("cuda", torch.bfloat16):
            # num_batches counts sequences: each loop iteration runs the model on model_batch_size of them.
            if self.first:
                num_batches = self.cfg["buffer_batches"]
            else:
                num_batches = self.cfg["buffer_batches"]//2
            self.first = False
            for _ in range(0, num_batches, self.cfg["model_batch_size"]):
                tokens = all_tokens[self.token_pointer:self.token_pointer+self.cfg["model_batch_size"]].to(self.cfg["device"])
                # Use self.cfg (not the global cfg) so the buffer honours the config it was built with.
                _, cache = model.run_with_cache(tokens, stop_at_layer=self.cfg["layer"]+1, names_filter=self.cfg["act_name"])
                acts = cache[self.cfg["act_name"]].reshape(-1, self.cfg["act_size"])

                # print(tokens.shape, acts.shape, self.pointer, self.token_pointer)
                self.buffer[self.pointer: self.pointer+acts.shape[0]] = acts
                self.pointer += acts.shape[0]
                self.token_pointer += self.cfg["model_batch_size"]
                # if self.token_pointer > all_tokens.shape[0] - self.cfg["model_batch_size"]:
                #     self.token_pointer = 0

        self.pointer = 0
        self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).to(self.cfg["device"])]

    @torch.no_grad()
    def next(self):
        out = self.buffer[self.pointer:self.pointer+self.cfg["batch_size"]]
        self.pointer += self.cfg["batch_size"]
        if self.pointer > self.buffer.shape[0]//2 - self.cfg["batch_size"]:
            # print("Refreshing the buffer!")
            self.refresh()
        return out

# buffer.refresh()
 # %%

# %%

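The buffer hands out `batch_size` rows at a time from a shuffled pool. Once roughly half the pool has been consumed, it overwrites the first half with fresh activations and reshuffles everything, so each batch mixes rows from many different sequences and refreshes. Below is a toy plain-Python version of that refresh policy. `ToyBuffer` is hypothetical, and the `make_rows(n)` callback stands in for running the model; here it just hands out consecutive integers.

```python
import itertools
import random
from typing import Callable, List


class ToyBuffer:
    """Shuffled pool that refills its first half once about half has been used."""

    def __init__(self, size: int, batch_size: int,
                 make_rows: Callable[[int], List[int]], seed: int = 0):
        self.size, self.batch_size, self.make_rows = size, batch_size, make_rows
        self.rng = random.Random(seed)
        self.rows = make_rows(size)  # first fill: the whole buffer
        self.rng.shuffle(self.rows)
        self.pointer = 0
        self.refreshes = 0

    def refresh(self) -> None:
        half = self.size // 2
        self.rows[:half] = self.make_rows(half)  # overwrite the first half with fresh rows
        self.rng.shuffle(self.rows)              # mix fresh rows with the older half
        self.pointer = 0
        self.refreshes += 1

    def next(self) -> List[int]:
        out = self.rows[self.pointer:self.pointer + self.batch_size]
        self.pointer += self.batch_size
        if self.pointer > self.size // 2 - self.batch_size:  # same trigger as Buffer.next
            self.refresh()
        return out


counter = itertools.count()  # fake "activations": 0, 1, 2, ...
buf = ToyBuffer(size=16, batch_size=2, make_rows=lambda n: [next(counter) for _ in range(n)])
batches = [buf.next() for _ in range(10)]
print(buf.refreshes)  # 2
```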
In [91]:
from functools import partial

import tqdm

def replacement_hook(mlp_post, hook, encoder):
    mlp_post_reconstr = encoder(mlp_post)[1]
    return mlp_post_reconstr

def mean_ablate_hook(mlp_post, hook):
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    mlp_post[:] = 0.
    return mlp_post

@torch.no_grad()
def get_recons_loss(num_batches=5, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    loss_list = []
    for i in range(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], partial(replacement_hook, encoder=local_encoder))])
        # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], mean_ablate_hook)])
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], zero_ablate_hook)])
        loss_list.append((loss, recons_loss, zero_abl_loss))
    losses = torch.tensor(loss_list)
    loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()

    print(loss, recons_loss, zero_abl_loss)
    score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
    print(f"{score:.2%}")
    # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
    return score, loss, recons_loss, zero_abl_loss
# print(get_recons_loss())

# %%
# Frequency
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).to(cfg["device"])
    total = 0
    for i in tqdm.trange(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]

        _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
        acts = cache[cfg["act_name"]]
        acts = acts.reshape(-1, cfg["act_size"])

        hidden = local_encoder(acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total+=hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores==0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores
# %%
@torch.no_grad()
def re_init(indices, encoder):
    new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc)))
    new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec)))
    new_b_enc = (torch.zeros_like(encoder.b_enc))
    print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape)
    encoder.W_enc.data[:, indices] = new_W_enc[:, indices]
    encoder.W_dec.data[indices, :] = new_W_dec[indices, :]
    encoder.b_enc.data[indices] = new_b_enc[indices]

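`get_recons_loss` reports the fraction of the loss gap that the SAE recovers. Zero-ablating the hooked activation gives a worst-case loss and the clean model gives the best case. The score places the SAE-spliced loss between them: 100% means the reconstruction is as good as the original activation, and 0% means it is no better than zeroing it. `get_freqs` counts, for each latent, the fraction of tokens on which it fires. Here are both computations in plain Python. The function names and numbers are illustrative only.

```python
from typing import List


def loss_recovered(clean_loss: float, recons_loss: float, zero_abl_loss: float) -> float:
    """(zero_abl - recons) / (zero_abl - clean): 1.0 = perfect, 0.0 = as bad as zeroing."""
    return (zero_abl_loss - recons_loss) / (zero_abl_loss - clean_loss)


def firing_frequencies(latent_acts: List[List[float]]) -> List[float]:
    """Fraction of tokens (rows) on which each latent (column) is positive."""
    n_tokens, n_latents = len(latent_acts), len(latent_acts[0])
    return [sum(1 for row in latent_acts if row[j] > 0) / n_tokens for j in range(n_latents)]


print(f"{loss_recovered(clean_loss=3.0, recons_loss=3.5, zero_abl_loss=8.0):.2%}")  # 90.00%
freqs = firing_frequencies([[0.0, 1.2, 0.0], [0.3, 0.0, 0.0], [0.0, 0.7, 0.0], [0.0, 2.0, 0.0]])
print(freqs)  # [0.25, 0.75, 0.0]
dead_fraction = sum(1 for f in freqs if f == 0) / len(freqs)  # latents that never fired
```

`dead_fraction` is the same quantity that `get_freqs` prints as `(act_freq_scores==0).float().mean()`.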
In [108]:
# %%
encoder = AutoEncoder(cfg)
buffer = Buffer(cfg)
# Code used to remove the "rare freq direction", the shared direction among the ultra low frequency features.
# I experimented with removing it and retraining the autoencoder.
if cfg["remove_rare_dir"]:
    rare_freq_dir = torch.load("rare_freq_dir.pt")
    rare_freq_dir.requires_grad = False

# %%
try:
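    # The original script logged to entity="neelnanda-io" (the original author's W&B entity); without an entity, W&B uses your default one.
    wandb.init(project="autoencoder")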
    num_batches = cfg["num_tokens"] // cfg["batch_size"]
    # model_num_batches = cfg["model_batch_size"] * num_batches
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
    recons_scores = []
    act_freq_scores_list = []
    for i in tqdm.trange(num_batches):
        i = i % all_tokens.shape[0]
        acts = buffer.next()
        loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
        loss.backward()
        encoder.make_decoder_weights_and_grad_unit_norm()
        encoder_optim.step()
        encoder_optim.zero_grad()
        loss_dict = {"loss": loss.item(), "l2_loss": l2_loss.item(), "l1_loss": l1_loss.item()}
        del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts
        if (i) % 100 == 0:
            wandb.log(loss_dict)
            print(loss_dict)
        if (i) % 1000 == 0:
            x = (get_recons_loss(local_encoder=encoder))
            print("Reconstruction:", x)
            recons_scores.append(x[0])
            freqs = get_freqs(5, local_encoder=encoder)
            act_freq_scores_list.append(freqs)
            # histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies")
            wandb.log({
                "recons_score": x[0],
                "dead": (freqs==0).float().mean().item(),
                "below_1e-6": (freqs<1e-6).float().mean().item(),
                "below_1e-5": (freqs<1e-5).float().mean().item(),
            })
        if (i+1) % 30000 == 0:
            encoder.save()
            wandb.log({"reset_neurons": 0.0})
            freqs = get_freqs(50, local_encoder=encoder)
            to_be_reset = (freqs<10**(-5.5))
            print("Resetting neurons!", to_be_reset.sum())
            re_init(to_be_reset, encoder)
finally:
    encoder.save()

OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 GiB. GPU 0 has a total capacity of 39.56 GiB of which 3.24 GiB is free. Process 54264 has 36.32 GiB memory in use. Of the allocated memory 35.30 GiB is allocated by PyTorch, and 532.81 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
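The failed allocation is exactly 16 GiB. The traceback isn't shown, so we can't tell which tensor failed; the following is only a size check. A float32 tensor of shape `[model_batch_size, seq_len, d_inner, d_state]` matches that size exactly. This uses `model_batch_size=1024` and `seq_len=128` from the config, and `d_inner=2048` and `d_state=16` from the `MambaCfg` printed earlier. A per-token SSM state would have this shape. If that is the culprit, lowering `model_batch_size` is the first knob to try, since memory scales linearly with it. `tensor_bytes` is a hypothetical helper.

```python
def tensor_bytes(*shape: int, bytes_per_elem: int = 4) -> int:
    """Size in bytes of a dense tensor (float32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem


GiB = 2 ** 30
model_batch_size, seq_len, d_inner, d_state = 1024, 128, 2048, 16
size = tensor_bytes(model_batch_size, seq_len, d_inner, d_state)
print(size / GiB)  # 16.0

# With 128 sequences per forward pass instead, the same tensor would need 2 GiB.
print(tensor_bytes(128, seq_len, d_inner, d_state) / GiB)  # 2.0
```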