<a href="https://colab.research.google.com/github/amantimalsina/Mamba-SAE/blob/main/training_a_sparse_autoencoder_for_mamba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A very basic SAE Training Tutorial

Please note that it is very easy for tutorial code to go stale so please have a low bar for raising an issue in the

## Setup

In [1]:
%pip install sae-lens transformer-lens circuitsvis

# MambaLens:
!pip install git+https://github.com/Phylliida/MambaLens.git
# For faster inference in SSMs:
!pip install causal_conv1d mamba-ssm

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/pkg_resources/__init__.py", line 3108, in _dep_map
    return self.__dep_map
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/pkg_resources/__init__.py", line 2901, in __getattr__
    raise AttributeError(attr)
AttributeError: _DistInfoDistribution__dep_map

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 169, in exc_logging_wrapper
    status = run_func(*args)
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/req_command.py", line 242, in wrapper
    return func(self, options, args)
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/commands/install.py", line 377, in run
    requirement_set = resolver.resolve(
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/resolver.py", line 

In [24]:
# %%
import os
import json
import random
from pathlib import Path
import gc
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
# %%
import wandb
from tqdm import tqdm
import pprint
# %%
import argparse
# %%
from transformers import AutoTokenizer
import datasets
from datasets import load_dataset
# %%
from transformer_lens.utils import test_prompt
import circuitsvis as cv  # optional dep, install with pip install circuitsvis
from functools import partial
# %%
import mamba_lens

In [25]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


# Model Selection and Evaluation (Feel Free to Skip)

We'll use the runner to train an SAE on a TinyStories Model. This is a very small model so we can train an SAE on it quite quickly. Before we get started, let's load in the model with `transformer_lens` and see what it can do.

TransformerLens gives us 2 functions that are useful here (and circuits viz provides a third):
1. `transformer_lens.utils.test_prompt` will help us see when the model can infer one token.
2. `HookedTransformer.generate` will help us see what happens when we sample from the model.
3. `circuitsvis.logits.token_log_probs` will help us visualize the log probs of tokens at several positions in a prompt.

In [26]:
model = mamba_lens.HookedMamba.from_pretrained(
                              "state-spaces/mamba-130m",
                              device='cuda'
                              )
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Moving model to device:  cuda


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [27]:

n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
d_state = model.cfg.d_state
d_conv = model.cfg.d_conv
n_ctx = model.cfg.n_ctx
d_inner = model.cfg.d_inner
print(f"n_layers: {n_layers}")
print(f"d_model: {d_model}")
print(f"d_state: {d_state}")
print(f"d_conv: {d_conv}")
print(f"n_ctx: {n_ctx}")
print(f"d_inner: {d_inner}")
print(model.cfg)

n_layers: 24
d_model: 768
d_state: 16
d_conv: 4
n_ctx: 2048
d_inner: 1536
MambaCfg(d_model=768, n_layers=24, vocab_size=50280, d_state=16, expand=2, dt_rank=48, d_conv=4, pad_vocab_size_multiple=8, conv_bias=True, bias=False, default_prepend_bos=True, tokenizer_prepends_bos=False, n_ctx=2048, device='cuda', initializer_cfg=MambaInitCfg(initializer_range=(0.02,), rescale_prenorm_residual=(True,), n_residuals_per_layer=(1,), dt_init=('random',), dt_scale=(1.0,), dt_min=(0.001,), dt_max=(0.1,), dt_init_floor=0.0001), d_inner=1536)


In [28]:

print(model)

HookedMamba(
  (embedding): Embedding(50280, 768)
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-23): 24 x HookedMambaBlock(
      (hook_resid_pre): HookPoint()
      (hook_layer_input): HookPoint()
      (norm): RMSNorm()
      (hook_normalized_input): HookPoint()
      (skip_proj): Linear(in_features=768, out_features=1536, bias=False)
      (hook_skip): HookPoint()
      (in_proj): Linear(in_features=768, out_features=1536, bias=False)
      (hook_in_proj): HookPoint()
      (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
      (hook_conv): HookPoint()
      (hook_ssm_input): HookPoint()
      (W_delta_1): Linear(in_features=1536, out_features=48, bias=False)
      (W_delta_2): Linear(in_features=48, out_features=1536, bias=True)
      (W_B): Linear(in_features=1536, out_features=16, bias=False)
      (W_C): Linear(in_features=1536, out_features=16, bias=False)
      (hook_h_start): HookPoint()
      (hook_delta_1): HookPoint()
    

Let's start by generating some stories using the model.

In [29]:

torch.cuda.empty_cache()

In [30]:
# Feed it back to the model and keep predicting tokens:
prompt = "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,"
for i in range(2):
    tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    logits, activations = model.run_with_cache(tokens,
                                               fast_ssm=True,
                                               fast_conv=True,
                                               warn_disabled_hooks=False
                                               )
    generated_text = tokenizer.batch_decode(logits.argmax(dim=-1)[0])
    prompt += ' '.join(generated_text)
print(prompt)

Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure, again  a  time ,  there  was  a  man  girl  named  Alice  who  She  was  in  a  small  house  big  house  house 's  She  the  little , ,  she again  a  time ,  there  was  a  man  girl  named  Alice  who  She  was  in  a  small  house  big  house  house 's  She  the  little , ,  she , 
  little little ,  she she  was was  a a  little little  named who    named    L ice . who    was    was    a    a    big    girl    in    girl    in    s    big    was    little          little   


### Spot checking model abilities with `transformer_lens.utils.test_prompt`

In [31]:
# Test the model with a prompt
test_prompt(
    "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,",
    " Lily",
    model,
    prepend_space_to_answer=False,
)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 25.97 Prob: 63.23% Token: | she|
Top 1th token. Logit: 24.37 Prob: 12.73% Token: | Lily|
Top 2th token. Logit: 23.73 Prob:  6.74% Token: | the|
Top 3th token. Logit: 23.00 Prob:  3.25% Token: | her|
Top 4th token. Logit: 21.95 Prob:  1.14% Token: | a|
Top 5th token. Logit: 21.90 Prob:  1.08% Token: | there|
Top 6th token. Logit: 21.52 Prob:  0.74% Token: | they|
Top 7th token. Logit: 21.24 Prob:  0.56% Token: | it|
Top 8th token. Logit: 20.74 Prob:  0.34% Token: | he|
Top 9th token. Logit: 20.61 Prob:  0.30% Token: | all|


In the output above, we see that the model assigns ~ 70% probability to "she" being the next token, and a 13% chance to " Lily" being the next token. Other names like Lucy or Anna are not highly ranked.

### Exploring Model Capabilities with Log Probs

Looking at token ranking for a single prompt is interesting, but a much higher through way to understand models is to look at token log probs for all tokens in text. We can use the `circuits_vis` package to get a nice visualization where we can see tokenization, and hover to get the top5 tokens by log probability. Darker tokens are tokens where the model assigned a higher probability to the actual next token.

In [32]:
# Let's make a longer prompt and see the log probabilities of the tokens
example_prompt = """Hi, how are you doing this? I'm really enjoying your posts"""
logits, cache = model.run_with_cache(example_prompt)
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    model(example_prompt)[0].log_softmax(dim=-1),
    model.to_string,
)
# hover on the output to see the result.

In [33]:
"""
# Code used to remove the "rare freq direction", the shared direction among the ultra low frequency features.
# I experimented with removing it and retraining the autoencoder.
if cfg["remove_rare_dir"]:
    rare_freq_dir = torch.load("rare_freq_dir.pt")
    rare_freq_dir.requires_grad = False

# %%
"""
# Training cfg:
cfg = {
    # Data parameters
    "num_tokens": int(1e5),  # Total number of tokens to use
    "batch_size": 32,  # Batch size for training

    # Model parameters
    "act_name": "hook_norm",  # Name of the activation to extract
    "dict_size": 1536,#3072,
    "l1_coeff": 3e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    "dict_mult": 32,
    "seq_len": 128,
    "remove_rare_dir": False,
    "device": "cuda:0",
    "enc_dtype": "fp32",
    "seed": 16,
    "act_size": 768,
    "model_batch_size": 32,

    # Training parameters
    "num_epochs": 5,  # Number of epochs to train
    "lr": 1e-3,  # Learning rate
    "beta1": 0.9,  # Adam optimizer beta1
    "beta2": 0.999,  # Adam optimizer beta2

    # Regularization
    "l1_weight": 1e-5,  # L1 regularization weight
    "l2_weight": 1e-5,  # L2 regularization weight

    # Logging and checkpointing
    "log_every": 50,  # Log every N batches
    "eval_every": 100,  # Evaluate reconstruction every N batches
    "recons_every": 500,  # Reconstruct every N batches
    "save_every": 500,  # Save model every N batches
    "reset_freq_threshold": 10**(-5.5),  # Frequency threshold for resetting neurons

    # Wandb configuration
    "wandb_project": "mamba_autoencoder",
    "wandb_name": None,

    # Model specific (you might need to adjust these)
    "encoder_hidden_sizes": [512, 256],  # Hidden layer sizes for encoder
    "decoder_hidden_sizes": [256, 512],  # Hidden layer sizes for decoder
    "latent_dim": 64,  # Dimension of the latent space
}

# Dataset Preparation:

In [34]:
# %%
def shuffle_data(all_tokens):
    print("Shuffled data")
    return all_tokens[torch.randperm(all_tokens.shape[0])]

num_tokens=cfg['num_tokens']
loading_data_first_time = False
if loading_data_first_time:
    data = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train", cache_dir="/workspace/cache/")
    data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf")
    data.set_format(type="torch", columns=["tokens"])
    limited_tokens = data["tokens"][:num_tokens]
    print(limited_tokens.shape)


    limited_tokens_reshaped = einops.rearrange(limited_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128)
    limited_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id
    limited_tokens_reshaped = limited_tokens_reshaped[torch.randperm(limited_tokens_reshaped.shape[0])]
    torch.save(limited_tokens_reshaped, "/workspace/data/c4_code_2e5_tokens_reshaped.pt")

    print(f"Saved {limited_tokens_reshaped.shape[0]}")
else:
    #data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf")
    all_tokens = torch.load("/workspace/data/c4_code_2e5_limited_tokens_reshaped.pt")
    all_tokens = shuffle_data(all_tokens)

Shuffled data


In [35]:
all_tokens = shuffle_data(all_tokens[:num_tokens,])
print(all_tokens.shape)
print(all_tokens[0])

Shuffled data
torch.Size([100000, 128])
tensor([    0,    65, 33653,    65, 42524,    28, 21315,  4873,    65,  1636,
           75,    65,  9097,    16, 15983,    65,    88,  4753,    10,    88,
        13043,    65,   302,    14,  2176,    11, 23125,  1076, 11475, 16808,
        11916, 40056,  4078,  4021, 32048,  1525, 26132, 14769,  3451, 11916,
        16808,   804,  5166,    65,  6271,    65,  5238,    10,  1262,    14,
         8758,    31,    18,  2192,  5292,   603,  8758,  2224,   470,    28,
         7959,  1827,    16,  6271,    65,  5238,    65, 11628,   426,   338,
         7959,  1076, 23125,  6191,   426, 31681,    10,  3176,    31,  4874,
           10,  1262,    16,  6271,    65, 10649,    65,  3434,  1210, 23125,
          804,   279,    65, 14837,    65,  6271,    65,  5238,    10, 10649,
         2192,  7959,   278,  5015,    65,  9097,   426,  1827,    16,  6271,
           65,  1636,    75,    65,  3434,    61, 10649,    61,   298,  8092,
           65,  3361,   

In [36]:
from torch.utils.data import Dataset, DataLoader

In [37]:

class TokenDataset(Dataset):
    def __init__(self, tokens, max_tokens=int(2e5)):
        self.tokens = tokens[:max_tokens]

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return self.tokens[idx]

# Load your data
all_tokens = all_tokens.to(device)  # Move the data to the appropriate device

# Create dataset and dataloader
dataset = TokenDataset(all_tokens)
dataloader = DataLoader(dataset, batch_size=cfg["batch_size"], shuffle=True)

# Autoencoder Class

In [38]:
SAVE_DIR = Path("/workspace/checkpoints")
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])

    def forward(self, x):
        #print(f"x.shape: {x.shape}")
        x_cent = x - self.b_dec
        #print(f"x_cent.shape: {x_cent.shape}")
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        #print(f"acts.shape: {acts.shape}")
        x_reconstruct = acts @ self.W_dec + self.b_dec
        #print(f"x_reconstruct.shape: {x_reconstruct.shape}")
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean()
        #print(f"l2_loss: {l2_loss}")
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        #print(f"l1_loss: {l1_loss}")
        loss = l2_loss + l1_loss
        #print(f"loss: {loss}")
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

    def get_version(self):
        version_list = [int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)]
        if len(version_list):
            return 1+max(version_list)
        else:
            return 0

    def save(self):
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
            json.dump(cfg, f)
        print("Saved as version", version)

    @classmethod
    def load(cls, version):
        cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
        return self

    @classmethod
    def load_from_hf(cls, version):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47

        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self

## Utilities for Training

In [39]:
SEED = cfg["seed"]
GENERATOR = torch.manual_seed(SEED)
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
np.random.seed(SEED)
random.seed(SEED)
torch.set_grad_enabled(True)

n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
# %%
@torch.no_grad()
def get_sub_acts(tokens, batch_size=1024):
    outputs, cache = model.run_with_cache(tokens,
                                          fast_ssm=True,
                                          fast_conv=True,
                                          warn_disabled_hooks=False,
                                          names_filter=cfg["act_name"])
    acts = cache[cfg["act_name"]]
    del outputs, cache
    torch.cuda.empty_cache()
    gc.collect()
    acts = acts.reshape(-1, acts.shape[-1])
    subsample = torch.randperm(acts.shape[0], generator=GENERATOR)[:batch_size]
    subsampled_acts = acts[subsample, :]
    return subsampled_acts, acts
# sub, acts = get_acts(torch.arange(20).reshape(2, 10), batch_size=3)
# sub.shape, acts.shape
# %%
@torch.no_grad()
def get_acts(tokens, batch_size=1024):
    outputs, cache = model.run_with_cache(tokens,
                                          fast_ssm=True,
                                          fast_conv=True,
                                          warn_disabled_hooks=False,
                                          names_filter=cfg["act_name"])
    acts = cache[cfg["act_name"]]
    del outputs, cache
    torch.cuda.empty_cache()
    gc.collect()
    acts = acts.reshape(-1, acts.shape[-1])
    return acts

In [40]:
def replacement_hook(pre_linear, hook, encoder):
    pre_linear_reconstr = encoder(pre_linear)[1]
    return pre_linear_reconstr

def mean_ablate_hook(pre_linear, hook):
    pre_linear[:] = pre_linear.mean([0, 1])
    return pre_linear

def zero_ablate_hook(pre_linear, hook):
    pre_linear[:] = 0.
    return pre_linear

@torch.no_grad()
def get_recons_loss(num_batches=5, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    loss_list = []
    for i in range(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss",
                                           fast_ssm=True,
                                           fast_conv=True,
                                           fwd_hooks=[(cfg["act_name"], partial(replacement_hook, encoder=local_encoder))])
        # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], mean_ablate_hook)])
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss",
                                             fast_ssm=True,
                                             fast_conv=True,
                                             fwd_hooks=[(cfg["act_name"], zero_ablate_hook)])
        loss_list.append((loss, recons_loss, zero_abl_loss))
    losses = torch.tensor(loss_list)
    loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()

    print(loss, recons_loss, zero_abl_loss)
    score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
    print(f"{score:.2%}")
    # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
    return score, loss, recons_loss, zero_abl_loss
# print(get_recons_loss())

# %%
# Frequency
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).to(cfg["device"])
    total = 0
    for i in tqdm(range(num_batches)):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]

        _, cache = model.run_with_cache(tokens, names_filter=cfg["act_name"], fast_ssm=True, fast_conv=True, warn_disabled_hooks=False)
        acts = cache[cfg["act_name"]]
        acts = acts.reshape(-1, cfg["act_size"])

        hidden = local_encoder(acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total+=hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores==0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores
# %%
@torch.no_grad()
def re_init(indices, encoder):
    new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc)))
    new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec)))
    new_b_enc = (torch.zeros_like(encoder.b_enc))
    print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape)
    encoder.W_enc.data[:, indices] = new_W_enc[:, indices]
    encoder.W_dec.data[indices, :] = new_W_dec[indices, :]
    encoder.b_enc.data[indices] = new_b_enc[indices]

In [41]:
lim_tokens = all_tokens[:32]
print(lim_tokens.shape)

outputs, MambaActs = model.run_with_cache(
                                        lim_tokens,
                                        fast_ssm=True,
                                        fast_conv=True,
                                        warn_disabled_hooks=False
                                        )

acts = MambaActs[cfg["act_name"]]
print(acts.shape)

torch.Size([32, 128])


OutOfMemoryError: CUDA out of memory. Tried to allocate 24.00 MiB. GPU 0 has a total capacty of 14.75 GiB of which 1.06 MiB is free. Process 81013 has 7.59 GiB memory in use. Process 80106 has 7.16 GiB memory in use. Of the allocated memory 7.37 GiB is allocated by PyTorch, and 90.65 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
del outputs, MambaActs
torch.cuda.empty_cache()
gc.collect()

## Training Run

In [None]:
# %%
DTYPES = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16
} # Add this line to define DTYPES

encoder = AutoEncoder(cfg)

In [None]:
# Wandb Args:
wandb_args = {
  "wandb_project": "mamba-sae",
  "wandb_name": None,
}

try:
    wandb.init(
        project=wandb_args["wandb_project"],
        name=wandb_args["wandb_name"],
    )

    num_batches = len(dataloader)
    print(f"Number of tokens: {cfg['num_tokens']}")
    print(f"Batch size: {cfg['batch_size']}")
    print(f"Number of batches: {num_batches}")

    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
    recons_scores = []
    act_freq_scores_list = []

    for epoch in range(cfg["num_epochs"]):  # Add an outer loop for epochs if needed
        for i, tokens in enumerate(tqdm(dataloader)):
            tokens = tokens.to(device)

            #sub_acts = get_sub_acts(tokens, 16)[1]
            #acts = sub_acts
            outputs, cache = model.run_with_cache(tokens,
                                          fast_ssm=True,
                                          fast_conv=True,
                                          warn_disabled_hooks=False,
                                          names_filter=cfg["act_name"])
            acts = cache[cfg["act_name"]]
            del outputs, cache
            torch.cuda.empty_cache()
            gc.collect()

            acts = acts.reshape(-1, cfg["act_size"])

            loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
            loss.backward()

            encoder.make_decoder_weights_and_grad_unit_norm()
            encoder_optim.step()
            encoder_optim.zero_grad()

            loss_dict = {"loss": loss.item(), "l2_loss": l2_loss.item(), "l1_loss": l1_loss.item()}
            del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts
            torch.cuda.empty_cache()
            gc.collect()

            """
            print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
            print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
            print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
            """

            if (i + 1) % cfg["log_every"] == 0:
                wandb.log(loss_dict)
                print(loss_dict)

            if (i + 1) % cfg["recons_every"] == 0:
                x = get_recons_loss(local_encoder=encoder)
                print("Reconstruction:", x)
                recons_scores.append(x[0])
                freqs = get_freqs(5, local_encoder=encoder)
                act_freq_scores_list.append(freqs)
                wandb.log({
                    "recons_score": x[0],
                    "dead": (freqs==0).float().mean().item(),
                    "below_1e-6": (freqs<1e-6).float().mean().item(),
                    "below_1e-5": (freqs<1e-5).float().mean().item(),
                })

            if (i + 1) % cfg["save_every"] == 0:
                encoder.save()
                wandb.log({"reset_neurons": 0.0})
                freqs = get_freqs(50, local_encoder=encoder)
                to_be_reset = (freqs < cfg["reset_freq_threshold"])
                print("Resetting neurons!", to_be_reset.sum())
                re_init(to_be_reset, encoder)

finally:
    encoder.save()

In [None]:
print(encoder)

In [None]:
encoder.load(72)

In [None]:
from huggingface_hub import login
login()  # This will prompt for your auth token if you haven't saved it before

In [None]:
from huggingface_hub import HfApi, Repository

In [None]:
repo_name = "amantimalsina/mamba-sae"
api = HfApi()
repo_url = api.create_repo(repo_name, exist_ok=True)

# 4. Clone the empty repository
repo = Repository(local_dir="./my_model_repo", clone_from=repo_url)

In [None]:
version = 72
import shutil
shutil.copy(SAVE_DIR/(str(version)+".pt"), f"./my_model_repo/{version}.pt")
shutil.copy(SAVE_DIR/(str(version)+"_cfg.json"), f"./my_model_repo/{version}_cfg.json")

# 6. Create a simple README
with open("./my_model_repo/README.md", "w") as f:
    f.write(f"# AutoEncoder Model\n\nThis is a custom AutoEncoder model for Mamba-130m. Version: {version}")

In [None]:


# 6. Push the changes to the Hub
repo.git_add()
repo.git_commit("Added encoder for ssm_output")
repo.git_push()

print(f"Model saved to https://huggingface.co/{repo_name}")