<a href="https://colab.research.google.com/github/amantimalsina/Mamba-SAE/blob/main/training_a_sparse_autoencoder_for_mamba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A very basic SAE Training Tutorial

Please note that it is very easy for tutorial code to go stale so please have a low bar for raising an issue in the

## Setup

In [1]:
%pip install sae-lens transformer-lens circuitsvis

# MambaLens:
!pip install git+https://github.com/Phylliida/MambaLens.git
# For faster inference in SSMs:
!pip install causal_conv1d mamba-ssm


In [1]:
# %%
import os
import json
import random
from pathlib import Path
import gc
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
# %%
import wandb
from tqdm import tqdm
import pprint
# %%
import argparse
# %%
from transformers import AutoTokenizer
import datasets
from datasets import load_dataset
# %%
from transformer_lens.utils import test_prompt
import circuitsvis as cv  # optional dep, install with pip install circuitsvis
from functools import partial
# %%
import mamba_lens

In [2]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


# Model Selection and Evaluation (Feel Free to Skip)

We'll train an SAE on `state-spaces/mamba-130m`, a small Mamba model that we load through MambaLens as a `HookedMamba`. Because the model is small, we can train an SAE on it quite quickly. Before we get started, let's load the model and see what it can do.

TransformerLens gives us 2 functions that are useful here (and CircuitsVis provides a third):
1. `transformer_lens.utils.test_prompt` shows how the model ranks a single candidate next token.
2. `HookedTransformer.generate` shows what happens when we sample from a TransformerLens model; in this notebook we instead feed the model's own predictions back into the prompt by hand.
3. `circuitsvis.logits.token_log_probs` visualizes the log probs of tokens at several positions in a prompt.

In [3]:
model = mamba_lens.HookedMamba.from_pretrained(
    "state-spaces/mamba-130m",
    device=device,  # device selected in the Setup section
)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

Moving model to device:  cuda


In [4]:

n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
d_state = model.cfg.d_state
d_conv = model.cfg.d_conv
n_ctx = model.cfg.n_ctx
d_inner = model.cfg.d_inner
print(f"n_layers: {n_layers}")
print(f"d_model: {d_model}")
print(f"d_state: {d_state}")
print(f"d_conv: {d_conv}")
print(f"n_ctx: {n_ctx}")
print(f"d_inner: {d_inner}")
print(model.cfg)

n_layers: 24
d_model: 768
d_state: 16
d_conv: 4
n_ctx: 2048
d_inner: 1536
MambaCfg(d_model=768, n_layers=24, vocab_size=50280, d_state=16, expand=2, dt_rank=48, d_conv=4, pad_vocab_size_multiple=8, conv_bias=True, bias=False, default_prepend_bos=True, tokenizer_prepends_bos=False, n_ctx=2048, device='cuda', initializer_cfg=MambaInitCfg(initializer_range=(0.02,), rescale_prenorm_residual=(True,), n_residuals_per_layer=(1,), dt_init=('random',), dt_scale=(1.0,), dt_min=(0.001,), dt_max=(0.1,), dt_init_floor=0.0001), d_inner=1536)


In [5]:
print(model)

HookedMamba(
  (embedding): Embedding(50280, 768)
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-23): 24 x HookedMambaBlock(
      (hook_resid_pre): HookPoint()
      (hook_layer_input): HookPoint()
      (norm): RMSNorm()
      (hook_normalized_input): HookPoint()
      (skip_proj): Linear(in_features=768, out_features=1536, bias=False)
      (hook_skip): HookPoint()
      (in_proj): Linear(in_features=768, out_features=1536, bias=False)
      (hook_in_proj): HookPoint()
      (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
      (hook_conv): HookPoint()
      (hook_ssm_input): HookPoint()
      (W_delta_1): Linear(in_features=1536, out_features=48, bias=False)
      (W_delta_2): Linear(in_features=48, out_features=1536, bias=True)
      (W_B): Linear(in_features=1536, out_features=16, bias=False)
      (W_C): Linear(in_features=1536, out_features=16, bias=False)
      (hook_h_start): HookPoint()
      (hook_delta_1): HookPoint()
    

Let's start by generating some stories using the model.

In [6]:

torch.cuda.empty_cache()

In [7]:
# Greedy decoding: append the top prediction for the last position, then feed the longer prompt back in.
prompt = "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,"
num_new_tokens = 20  # number of tokens to generate
for i in range(num_new_tokens):
    tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    logits, activations = model.run_with_cache(tokens,
                                               fast_ssm=True,
                                               fast_conv=True,
                                               warn_disabled_hooks=False
                                               )
    # logits has shape [batch, position, vocab]; only the final position predicts a new token.
    next_token_id = logits[0, -1].argmax(dim=-1).item()
    prompt += tokenizer.decode(next_token_id)
print(prompt)


### Spot checking model abilities with `transformer_lens.utils.test_prompt`

In [8]:
# Test the model with a prompt
test_prompt(
    "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,",
    " Lily",
    model,
    prepend_space_to_answer=False,
)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 25.97 Prob: 63.23% Token: | she|
Top 1th token. Logit: 24.37 Prob: 12.73% Token: | Lily|
Top 2th token. Logit: 23.73 Prob:  6.74% Token: | the|
Top 3th token. Logit: 23.00 Prob:  3.25% Token: | her|
Top 4th token. Logit: 21.95 Prob:  1.14% Token: | a|
Top 5th token. Logit: 21.90 Prob:  1.08% Token: | there|
Top 6th token. Logit: 21.52 Prob:  0.74% Token: | they|
Top 7th token. Logit: 21.24 Prob:  0.56% Token: | it|
Top 8th token. Logit: 20.74 Prob:  0.34% Token: | he|
Top 9th token. Logit: 20.61 Prob:  0.30% Token: | all|


In the output above, the model assigns about 63% probability to " she" as the next token and about 13% to " Lily". Other names, such as Lucy or Anna, do not appear in the top 10.
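The probabilities reported by `test_prompt` are a softmax over the logits, so the ratio between two tokens' probabilities depends only on the gap between their logits. The following is a minimal check using the rounded values printed above; the helper name `prob_ratio` is ours and is not part of TransformerLens.

```python
import math

def prob_ratio(logit_a: float, logit_b: float) -> float:
    """Ratio p(a) / p(b) under a softmax; the shared normalizer cancels."""
    return math.exp(logit_a - logit_b)

# Rounded logits and probabilities for " she" and " Lily" from the output above.
ratio_from_logits = prob_ratio(25.97, 24.37)
ratio_from_probs = 63.23 / 12.73
print(f"{ratio_from_logits:.2f} vs {ratio_from_probs:.2f}")
```

The two ratios should agree up to the rounding of the printed values.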

### Exploring Model Capabilities with Log Probs

Looking at the token ranking for a single prompt is interesting, but a much higher-throughput way to understand a model is to look at the log probs for every token in a piece of text. We can use the `circuitsvis` package to get a visualization that shows the tokenization and, on hover, the top 5 tokens by log probability. Darker tokens are those where the model assigned a higher probability to the actual next token.

In [9]:
# Let's make a longer prompt and see the log probabilities of the tokens
example_prompt = """Hi, how are you doing this? I'm really enjoying your posts"""
logits = model(example_prompt)  # one forward pass; the logits are reused below
cv.logits.token_log_probs(
    model.to_tokens(example_prompt),
    logits[0].log_softmax(dim=-1),
    model.to_string,
)
# hover on the output to see the result.
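To make the quantity behind the shading concrete, here is a small pure-Python sketch of "the log prob the model assigned to the actual next token". It is not how CircuitsVis is implemented; the function names and the toy logits are made up for illustration.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one list of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def next_token_log_probs(logits, tokens):
    """Log prob given to tokens[t + 1] by the prediction made at position t."""
    return [log_softmax(logits[t])[tokens[t + 1]] for t in range(len(tokens) - 1)]

# Toy example: 3 positions, vocabulary of 4 tokens (values are made up).
toy_logits = [[2.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 3.0, 0.0],
              [1.0, 1.0, 1.0, 1.0]]
toy_tokens = [1, 0, 2]
lps = next_token_log_probs(toy_logits, toy_tokens)
print(lps)
```

Note the shift by one: the prediction at position `t` is scored against the token at position `t + 1`, so a sequence of `n` tokens yields `n - 1` scores.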

In [25]:
"""
# Code used to remove the "rare freq direction", the shared direction among the ultra low frequency features.
# I experimented with removing it and retraining the autoencoder.
if cfg["remove_rare_dir"]:
    rare_freq_dir = torch.load("rare_freq_dir.pt")
    rare_freq_dir.requires_grad = False

# %%
"""
# Training cfg:
cfg = {
    # Data parameters
    "num_tokens": int(4e4),  # Number of 128-token sequences (rows of all_tokens) to use
    "batch_size": 32,  # Batch size for training

    # Model parameters
    "act_name": "hook_norm",  # Name of the activation to extract
    "dict_size": 1536,
    "l1_coeff": 3e-4,
    # beta1 and beta2 are set once, under "Training parameters" below
    "dict_mult": 32,
    "seq_len": 128,
    "remove_rare_dir": False,
    "device": "cuda:0",
    "enc_dtype": "fp32",
    "seed": 16,
    "act_size": 768,
    "model_batch_size": 32,

    # Training parameters
    "num_epochs": 10,  # Number of epochs to train
    "lr": 1e-3,  # Learning rate
    "beta1": 0.9,  # Adam optimizer beta1
    "beta2": 0.999,  # Adam optimizer beta2

    # Regularization
    "l1_weight": 1e-5,  # L1 regularization weight
    "l2_weight": 1e-5,  # L2 regularization weight

    # Logging and checkpointing
    "log_every": 100,  # Log every N batches
    "eval_every": 100,  # Evaluate reconstruction every N batches
    "recons_every": 500,  # Reconstruct every N batches
    "save_every": 500,  # Save model every N batches
    "reset_freq_threshold": 10**(-5.5),  # Frequency threshold for resetting neurons

    # Wandb configuration
    "wandb_project": "mamba_autoencoder",
    "wandb_name": "experiment_001",

    # Model specific (you might need to adjust these)
    "encoder_hidden_sizes": [512, 256],  # Hidden layer sizes for encoder
    "decoder_hidden_sizes": [256, 512],  # Hidden layer sizes for decoder
    "latent_dim": 64,  # Dimension of the latent space
}

# Dataset Preparation:

In [26]:
# %%
def shuffle_data(all_tokens):
    print("Shuffled data")
    return all_tokens[torch.randperm(all_tokens.shape[0])]

num_tokens = cfg['num_tokens']
# Use one path for both branches so the file that gets saved is the file that gets loaded.
tokens_path = "/workspace/data/c4_code_2e5_limited_tokens_reshaped.pt"
loading_data_first_time = False
if loading_data_first_time:
    data = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train", cache_dir="/workspace/cache/")
    data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf")
    data.set_format(type="torch", columns=["tokens"])
    limited_tokens = data["tokens"][:num_tokens]
    print(limited_tokens.shape)

    # Split each row into 8 sequences of 128 tokens and start each one with BOS.
    limited_tokens_reshaped = einops.rearrange(limited_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128)
    limited_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id
    limited_tokens_reshaped = limited_tokens_reshaped[torch.randperm(limited_tokens_reshaped.shape[0])]
    torch.save(limited_tokens_reshaped, tokens_path)

    print(f"Saved {limited_tokens_reshaped.shape[0]}")
    # Keep the tokens in memory so the later cells also work on a first run.
    all_tokens = limited_tokens_reshaped
else:
    #data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf")
    all_tokens = torch.load(tokens_path)
    all_tokens = shuffle_data(all_tokens)

Shuffled data


In [27]:
all_tokens = shuffle_data(all_tokens[:num_tokens,])
print(all_tokens.shape)

Shuffled data
torch.Size([40000, 128])


In [28]:
from torch.utils.data import Dataset, DataLoader

In [29]:
class TokenDataset(Dataset):
    def __init__(self, tokens, max_tokens=int(2e5)):
        self.tokens = tokens[:max_tokens]

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return self.tokens[idx]

# Load your data
all_tokens = all_tokens.to(device)  # Move the data to the appropriate device

# Create dataset and dataloader
dataset = TokenDataset(all_tokens)
dataloader = DataLoader(dataset, batch_size=cfg["batch_size"], shuffle=True)
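As a quick sanity check on the sizes involved, the sketch below works out how many batches one epoch contains and how many activation vectors each batch produces. The numbers are copied from `cfg` and from the `torch.Size([40000, 128])` printed above; the variable names are ours.

```python
import math

num_sequences = 40_000  # rows of all_tokens, i.e. cfg["num_tokens"]
seq_len = 128           # cfg["seq_len"]
batch_size = 32         # cfg["batch_size"]

batches_per_epoch = math.ceil(num_sequences / batch_size)
activations_per_batch = batch_size * seq_len  # one activation vector per token position
tokens_per_epoch = num_sequences * seq_len
print(batches_per_epoch, activations_per_batch, tokens_per_epoch)
```

The batch count matches the `Number of batches: 1250` line that the training cell prints.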

# Autoencoder Class
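For reference, the `forward` method of the class in this section can be read off as the following equations. The notation is ours: $x_i$ is one activation vector, $B$ is the number of activation vectors passed in, $f$ is the hidden activation, and $\lambda$ is `cfg["l1_coeff"]`.

```latex
\begin{aligned}
f(x) &= \mathrm{ReLU}\big((x - b_{dec})\, W_{enc} + b_{enc}\big) \\
\hat{x} &= f(x)\, W_{dec} + b_{dec} \\
\mathcal{L} &= \frac{1}{B} \sum_{i=1}^{B} \lVert \hat{x}_i - x_i \rVert_2^2
  \;+\; \lambda \sum_{i=1}^{B} \lVert f(x_i) \rVert_1
\end{aligned}
```

Note that, as written in the code, the reconstruction term is averaged over the batch while the L1 term is summed over it, so the relative weight of the sparsity penalty grows with the number of activation vectors per step.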

In [30]:
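from transformer_lens import utils  # download_file_from_hf is used by load_from_hf below

SAVE_DIR = Path("/workspace/checkpoints")
SAVE_DIR.mkdir(parents=True, exist_ok=True)  # get_version and save expect this directory to exist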
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])

    def forward(self, x):
        #print(f"x.shape: {x.shape}")
        x_cent = x - self.b_dec
        #print(f"x_cent.shape: {x_cent.shape}")
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        #print(f"acts.shape: {acts.shape}")
        x_reconstruct = acts @ self.W_dec + self.b_dec
        #print(f"x_reconstruct.shape: {x_reconstruct.shape}")
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean()
        #print(f"l2_loss: {l2_loss}")
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        #print(f"l1_loss: {l1_loss}")
        loss = l2_loss + l1_loss
        #print(f"loss: {loss}")
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

    def get_version(self):
        version_list = [int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)]
        if len(version_list):
            return 1+max(version_list)
        else:
            return 0

    def save(self):
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
            json.dump(cfg, f)
        print("Saved as version", version)

    @classmethod
    def load(cls, version):
        cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
        return self

    @classmethod
    def load_from_hf(cls, version):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47

        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self
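The method `make_decoder_weights_and_grad_unit_norm` above does two things: it removes from each decoder row's gradient the component parallel to that row, and it rescales each row to unit norm. The pure-Python sketch below illustrates the first step on a single row; the helper name and the numbers are made up for illustration.

```python
import math

def project_out_parallel(weight_row, grad_row):
    """Remove from grad_row the component parallel to weight_row."""
    norm = math.sqrt(sum(w * w for w in weight_row))
    unit = [w / norm for w in weight_row]
    parallel = sum(g * u for g, u in zip(grad_row, unit))
    return [g - parallel * u for g, u in zip(grad_row, unit)], unit

# Toy decoder row and gradient (made-up numbers).
new_grad, unit_row = project_out_parallel([3.0, 4.0], [1.0, 2.0])
dot = sum(g * u for g, u in zip(new_grad, unit_row))
print(new_grad, dot)
```

After the projection the gradient is orthogonal to the row, so a small optimizer step changes the row's direction rather than its length.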

## Utilities for Training

In [40]:
SEED = cfg["seed"]
GENERATOR = torch.manual_seed(SEED)
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
np.random.seed(SEED)
random.seed(SEED)
torch.set_grad_enabled(True)

n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
# %%
@torch.no_grad()
def get_acts(tokens, batch_size=1024):
    outputs, cache = model.run_with_cache(tokens,
                                          fast_ssm=True,
                                          fast_conv=True,
                                          warn_disabled_hooks=False,
                                          names_filter=cfg["act_name"])
    acts = cache[cfg["act_name"]]
    del outputs, cache
    torch.cuda.empty_cache()
    gc.collect()
    acts = acts.reshape(-1, acts.shape[-1])
    subsample = torch.randperm(acts.shape[0], generator=GENERATOR)[:batch_size]
    subsampled_acts = acts[subsample, :]
    return subsampled_acts, acts
# sub, acts = get_acts(torch.arange(20).reshape(2, 10), batch_size=3)
# sub.shape, acts.shape
# %%

In [32]:
def replacement_hook(pre_linear, hook, encoder):
    pre_linear_reconstr = encoder(pre_linear)[1]
    return pre_linear_reconstr

def mean_ablate_hook(pre_linear, hook):
    pre_linear[:] = pre_linear.mean([0, 1])
    return pre_linear

def zero_ablate_hook(pre_linear, hook):
    pre_linear[:] = 0.
    return pre_linear

@torch.no_grad()
def get_recons_loss(num_batches=5, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    loss_list = []
    for i in range(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss",
                                           fast_ssm=True,
                                           fast_conv=True,
                                           fwd_hooks=[(cfg["act_name"], partial(replacement_hook, encoder=local_encoder))])
        # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], mean_ablate_hook)])
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss",
                                             fast_ssm=True,
                                             fast_conv=True,
                                             fwd_hooks=[(cfg["act_name"], zero_ablate_hook)])
        loss_list.append((loss, recons_loss, zero_abl_loss))
    losses = torch.tensor(loss_list)
    loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()

    print(loss, recons_loss, zero_abl_loss)
    score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
    print(f"{score:.2%}")
    # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
    return score, loss, recons_loss, zero_abl_loss
# print(get_recons_loss())

# %%
# Frequency
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).to(cfg["device"])
    total = 0
    for i in tqdm(range(num_batches)):
        tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]

        _, cache = model.run_with_cache(tokens, names_filter=cfg["act_name"], fast_ssm=True, fast_conv=True, warn_disabled_hooks=False)
        acts = cache[cfg["act_name"]]
        acts = acts.reshape(-1, cfg["act_size"])

        hidden = local_encoder(acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total+=hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores==0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores
# %%
@torch.no_grad()
def re_init(indices, encoder):
    new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc)))
    new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec)))
    new_b_enc = (torch.zeros_like(encoder.b_enc))
    print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape)
    encoder.W_enc.data[:, indices] = new_W_enc[:, indices]
    encoder.W_dec.data[indices, :] = new_W_dec[indices, :]
    encoder.b_enc.data[indices] = new_b_enc[indices]
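The score returned by `get_recons_loss` is the fraction of the gap between the clean loss and the zero-ablation loss that is recovered when the SAE's reconstruction is spliced into the model. A minimal standalone version of that formula, fed with two triples of losses copied from the training log further down (the function name is ours):

```python
def recovered_score(clean_loss, spliced_loss, zero_abl_loss):
    """Fraction of the clean-vs-zero-ablation loss gap recovered by the SAE."""
    return (zero_abl_loss - spliced_loss) / (zero_abl_loss - clean_loss)

# (clean, spliced, zero-ablated) losses copied from two evaluations in the training log.
first_eval = recovered_score(9.475272178649902, 9.502836227416992, 10.825362205505371)
later_eval = recovered_score(9.686169624328613, 9.614639282226562, 10.825362205505371)
print(f"{first_eval:.2%}, {later_eval:.2%}")
```

The second value is above 100% because in that evaluation the spliced loss happened to be lower than the clean loss. Since the clean and zero-ablated losses are close to each other in this run, the denominator is small and the score is sensitive to noise.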

In [33]:
lim_tokens = all_tokens[:32]
print(lim_tokens.shape)

outputs, MambaActs = model.run_with_cache(
                                        lim_tokens,
                                        fast_ssm=True,
                                        fast_conv=True,
                                        warn_disabled_hooks=False
                                        )

acts = MambaActs[cfg["act_name"]]
print(acts.shape)

torch.Size([32, 128])
torch.Size([32, 128, 768])


In [34]:
del outputs, MambaActs
torch.cuda.empty_cache()
gc.collect()

0
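The training loop in the next section uses the firing frequencies from `get_freqs` in two ways: it logs the fraction of features that never fire, and it re-initializes features whose frequency falls below `cfg["reset_freq_threshold"]`. The sketch below mirrors that logic in plain Python on made-up frequencies; `summarize_freqs` is a hypothetical helper, not part of the notebook.

```python
def summarize_freqs(freqs, threshold=10 ** -5.5):
    """Return (fraction of features that never fire, indices below the reset threshold)."""
    dead_fraction = sum(1 for f in freqs if f == 0) / len(freqs)
    to_reset = [i for i, f in enumerate(freqs) if f < threshold]
    return dead_fraction, to_reset

# Made-up firing frequencies for five features.
toy_freqs = [0.0, 2e-6, 5e-6, 1e-3, 0.2]
dead_fraction, to_reset = summarize_freqs(toy_freqs)
print(dead_fraction, to_reset)
```

The default threshold of `10 ** -5.5` is about 3.2e-6, so a feature can be reset even if it fires occasionally.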

## Training Run

In [35]:
# %%
encoder = AutoEncoder(cfg)

In [None]:
# Wandb Args:
wandb_args = {
  "wandb_project": "mamba-sae",
  "wandb_name": None,
}

try:
    wandb.init(
        project=wandb_args["wandb_project"],
        name=wandb_args["wandb_name"],
    )

    num_batches = len(dataloader)
    print(f"Number of tokens: {cfg['num_tokens']}")
    print(f"Batch size: {cfg['batch_size']}")
    print(f"Number of batches: {num_batches}")

    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
    recons_scores = []
    act_freq_scores_list = []

    for epoch in range(cfg["num_epochs"]):  # one full pass over the dataloader per epoch
        for i, tokens in enumerate(tqdm(dataloader)):
            tokens = tokens.to(device)

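            # get_acts returns (subsampled_acts, all_acts). Index 1 selects every
            # activation in the batch, so the subsample size passed here is unused.
            acts = get_acts(tokens, 16)[1]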

            loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
            loss.backward()

            encoder.make_decoder_weights_and_grad_unit_norm()
            encoder_optim.step()
            encoder_optim.zero_grad()

            loss_dict = {"loss": loss.item(), "l2_loss": l2_loss.item(), "l1_loss": l1_loss.item()}
            del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts
            torch.cuda.empty_cache()
            gc.collect()

            """
            print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
            print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
            print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
            """

            if (i + 1) % cfg["log_every"] == 0:
                wandb.log(loss_dict)
                print(loss_dict)

            if (i + 1) % cfg["recons_every"] == 0:
                x = get_recons_loss(local_encoder=encoder)
                print("Reconstruction:", x)
                recons_scores.append(x[0])
                freqs = get_freqs(5, local_encoder=encoder)
                act_freq_scores_list.append(freqs)
                wandb.log({
                    "recons_score": x[0],
                    "dead": (freqs==0).float().mean().item(),
                    "below_1e-6": (freqs<1e-6).float().mean().item(),
                    "below_1e-5": (freqs<1e-5).float().mean().item(),
                })

            if (i + 1) % cfg["save_every"] == 0:
                encoder.save()
                wandb.log({"reset_neurons": 0.0})
                freqs = get_freqs(50, local_encoder=encoder)
                to_be_reset = (freqs < cfg["reset_freq_threshold"])
                print("Resetting neurons!", to_be_reset.sum())
                re_init(to_be_reset, encoder)

finally:
    encoder.save()


Number of tokens: 40000
Batch size: 32
Number of batches: 1250


  8%|▊         | 100/1250 [01:15<14:28,  1.32it/s]

{'loss': 204.88198852539062, 'l2_loss': 50.08789825439453, 'l1_loss': 154.79409790039062}


 16%|█▌        | 200/1250 [02:31<13:06,  1.33it/s]

{'loss': 178.07774353027344, 'l2_loss': 38.30793762207031, 'l1_loss': 139.76980590820312}


 24%|██▍       | 300/1250 [03:46<11:55,  1.33it/s]

{'loss': 170.08145141601562, 'l2_loss': 34.49451446533203, 'l1_loss': 135.58694458007812}


 32%|███▏      | 400/1250 [05:01<10:32,  1.34it/s]

{'loss': 173.08322143554688, 'l2_loss': 35.372093200683594, 'l1_loss': 137.71112060546875}


 40%|███▉      | 499/1250 [06:15<09:24,  1.33it/s]

{'loss': 164.3024139404297, 'l2_loss': 32.7568473815918, 'l1_loss': 131.54556274414062}
9.475272178649902 9.502836227416992 10.825362205505371
97.96%
Reconstruction: (0.9795835476014217, 9.475272178649902, 9.502836227416992, 10.825362205505371)





Num dead tensor(0.3737, device='cuda:0')
Saved as version 44



Num dead tensor(0.1094, device='cuda:0')
Resetting neurons! tensor(168, device='cuda:0')
torch.Size([1536, 768]) torch.Size([768, 1536]) torch.Size([1536])


 48%|████▊     | 600/1250 [08:02<08:05,  1.34it/s]

{'loss': 166.66751098632812, 'l2_loss': 34.73277282714844, 'l1_loss': 131.9347381591797}


 56%|█████▌    | 700/1250 [09:17<06:53,  1.33it/s]

{'loss': 163.65097045898438, 'l2_loss': 34.84036636352539, 'l1_loss': 128.81060791015625}


 64%|██████▍   | 800/1250 [10:32<05:38,  1.33it/s]

{'loss': 163.880126953125, 'l2_loss': 32.230712890625, 'l1_loss': 131.6494140625}


 72%|███████▏  | 900/1250 [11:48<04:22,  1.34it/s]

{'loss': 163.4452362060547, 'l2_loss': 31.602737426757812, 'l1_loss': 131.84249877929688}


 80%|███████▉  | 999/1250 [13:02<03:08,  1.33it/s]

{'loss': 166.22894287109375, 'l2_loss': 34.34717559814453, 'l1_loss': 131.88177490234375}
9.441526412963867 9.511152267456055 10.825362205505371
94.97%
Reconstruction: (0.9496863321013579, 9.441526412963867, 9.511152267456055, 10.825362205505371)





Num dead tensor(0.6139, device='cuda:0')
Saved as version 45



Num dead tensor(0.4264, device='cuda:0')
Resetting neurons! tensor(655, device='cuda:0')
torch.Size([1536, 768]) torch.Size([768, 1536]) torch.Size([1536])


 88%|████████▊ | 1100/1250 [14:48<01:52,  1.34it/s]

{'loss': 167.40994262695312, 'l2_loss': 36.55908966064453, 'l1_loss': 130.85086059570312}


 96%|█████████▌| 1200/1250 [16:03<00:37,  1.32it/s]

{'loss': 163.40879821777344, 'l2_loss': 35.63167953491211, 'l1_loss': 127.77711486816406}


100%|██████████| 1250/1250 [16:41<00:00,  1.25it/s]
  8%|▊         | 100/1250 [01:14<14:14,  1.35it/s]

{'loss': 158.4322509765625, 'l2_loss': 30.363689422607422, 'l1_loss': 128.0685577392578}


 16%|█▌        | 200/1250 [02:29<13:09,  1.33it/s]

{'loss': 158.50927734375, 'l2_loss': 32.23187255859375, 'l1_loss': 126.27740478515625}


 24%|██▍       | 300/1250 [03:44<11:50,  1.34it/s]

{'loss': 157.81356811523438, 'l2_loss': 31.87213134765625, 'l1_loss': 125.9414291381836}


 32%|███▏      | 400/1250 [04:59<10:37,  1.33it/s]

{'loss': 164.72000122070312, 'l2_loss': 35.16238784790039, 'l1_loss': 129.5576171875}


 40%|███▉      | 499/1250 [06:13<09:20,  1.34it/s]

{'loss': 157.7126007080078, 'l2_loss': 32.61067581176758, 'l1_loss': 125.1019287109375}
9.3403902053833 9.459783554077148 10.825362205505371
91.96%
Reconstruction: (0.9195989226167005, 9.3403902053833, 9.459783554077148, 10.825362205505371)





Num dead tensor(0.4694, device='cuda:0')
Saved as version 46



  0%|          | 0/50 [00:00<?, ?it/s][A
  4%|▍         | 2/50 [00:00<00:10,  4.64it/s][A
  6%|▌         | 3/50 [00:00<00:13,  3.48it/s][A
  8%|▊         | 4/50 [00:01<00:15,  3.02it/s][A
 10%|█         | 5/50 [00:01<00:15,  2.84it/s][A
 12%|█▏        | 6/50 [00:02<00:16,  2.72it/s][A
 14%|█▍        | 7/50 [00:02<00:16,  2.65it/s][A
 16%|█▌        | 8/50 [00:02<00:16,  2.62it/s][A
 18%|█▊        | 9/50 [00:03<00:15,  2.59it/s][A
 20%|██        | 10/50 [00:03<00:15,  2.57it/s][A
 22%|██▏       | 11/50 [00:03<00:15,  2.56it/s][A
 24%|██▍       | 12/50 [00:04<00:14,  2.54it/s][A
 26%|██▌       | 13/50 [00:04<00:14,  2.54it/s][A
 28%|██▊       | 14/50 [00:05<00:14,  2.53it/s][A
 30%|███       | 15/50 [00:05<00:13,  2.53it/s][A
 32%|███▏      | 16/50 [00:05<00:13,  2.51it/s][A
 34%|███▍      | 17/50 [00:06<00:13,  2.53it/s][A
 36%|███▌      | 18/50 [00:06<00:12,  2.52it/s][A
 38%|███▊      | 19/50 [00:07<00:12,  2.52it/s][A
 40%|████      | 20/50 [00:07<00:11,  2.53it/s]

 40%|████      | 500/1250 [06:45<2:05:34, 10.05s/it]

Num dead tensor(0.2292, device='cuda:0')
Resetting neurons! tensor(352, device='cuda:0')
torch.Size([1536, 768]) torch.Size([768, 1536]) torch.Size([1536])


 48%|████▊     | 600/1250 [08:00<08:09,  1.33it/s]

{'loss': 159.95684814453125, 'l2_loss': 33.66575241088867, 'l1_loss': 126.29109954833984}


 56%|█████▌    | 700/1250 [09:14<06:50,  1.34it/s]

{'loss': 157.31640625, 'l2_loss': 32.18115997314453, 'l1_loss': 125.13524627685547}


 64%|██████▍   | 800/1250 [10:29<05:36,  1.34it/s]

{'loss': 156.83468627929688, 'l2_loss': 33.05988311767578, 'l1_loss': 123.77479553222656}


 72%|███████▏  | 900/1250 [11:44<04:22,  1.33it/s]

{'loss': 157.30479431152344, 'l2_loss': 31.66326332092285, 'l1_loss': 125.64152526855469}


 80%|███████▉  | 999/1250 [12:58<03:08,  1.33it/s]

{'loss': 157.253662109375, 'l2_loss': 32.628509521484375, 'l1_loss': 124.62516021728516}
9.686169624328613 9.614639282226562 10.825362205505371
106.28%
Reconstruction: (1.06279038617699, 9.686169624328613, 9.614639282226562, 10.825362205505371)



100%|██████████| 5/5 [00:01<00:00,  3.04it/s]


Num dead tensor(0.4941, device='cuda:0')
Saved as version 47



 40%|████      | 20/50 [00:07<00:11,  2.52it/s]

 80%|████████  | 1000/1250 [13:30<41:52, 10.05s/it]

Num dead tensor(0.1348, device='cuda:0')
Resetting neurons! tensor(207, device='cuda:0')
torch.Size([1536, 768]) torch.Size([768, 1536]) torch.Size([1536])


 88%|████████▊ | 1100/1250 [14:45<01:52,  1.33it/s]

{'loss': 157.2391357421875, 'l2_loss': 32.00498962402344, 'l1_loss': 125.23414611816406}


 96%|█████████▌| 1200/1250 [15:59<00:37,  1.34it/s]

{'loss': 156.4464111328125, 'l2_loss': 31.889694213867188, 'l1_loss': 124.55672454833984}


100%|██████████| 1250/1250 [16:37<00:00,  1.25it/s]
  8%|▊         | 100/1250 [01:14<14:11,  1.35it/s]

{'loss': 157.82537841796875, 'l2_loss': 32.42594528198242, 'l1_loss': 125.3994369506836}


 16%|█▌        | 200/1250 [02:28<12:59,  1.35it/s]

{'loss': 158.70159912109375, 'l2_loss': 34.894020080566406, 'l1_loss': 123.80758666992188}


 24%|██▍       | 300/1250 [03:42<11:44,  1.35it/s]

{'loss': 155.95082092285156, 'l2_loss': 36.01087188720703, 'l1_loss': 119.93994903564453}


 32%|███▏      | 400/1250 [04:56<10:28,  1.35it/s]

{'loss': 157.43365478515625, 'l2_loss': 32.501319885253906, 'l1_loss': 124.93234252929688}


 40%|███▉      | 499/1250 [06:10<09:21,  1.34it/s]

{'loss': 154.39874267578125, 'l2_loss': 30.551908493041992, 'l1_loss': 123.84683227539062}
9.519570350646973 9.540260314941406 10.825362205505371
98.42%
Reconstruction: (0.9841552356009471, 9.519570350646973, 9.540260314941406, 10.825362205505371)



100%|██████████| 5/5 [00:01<00:00,  3.04it/s]


Num dead tensor(0.4141, device='cuda:0')
Saved as version 48



 40%|████      | 20/50 [00:07<00:11,  2.53it/s]

 40%|████      | 500/1250 [06:41<2:05:39, 10.05s/it]

Num dead tensor(0.0905, device='cuda:0')
Resetting neurons! tensor(139, device='cuda:0')
torch.Size([1536, 768]) torch.Size([768, 1536]) torch.Size([1536])


 48%|████▊     | 600/1250 [07:56<08:05,  1.34it/s]

{'loss': 154.52767944335938, 'l2_loss': 30.578054428100586, 'l1_loss': 123.94963073730469}


 56%|█████▌    | 700/1250 [09:10<06:50,  1.34it/s]

{'loss': 151.96009826660156, 'l2_loss': 29.811382293701172, 'l1_loss': 122.14871978759766}


 64%|██████▍   | 800/1250 [10:25<05:35,  1.34it/s]

{'loss': 153.05035400390625, 'l2_loss': 31.76234245300293, 'l1_loss': 121.28801727294922}


 72%|███████▏  | 900/1250 [11:39<04:19,  1.35it/s]

{'loss': 152.17774963378906, 'l2_loss': 32.47003173828125, 'l1_loss': 119.70771789550781}


 80%|███████▉  | 999/1250 [12:53<03:06,  1.35it/s]

{'loss': 155.60337829589844, 'l2_loss': 32.079833984375, 'l1_loss': 123.52354431152344}
9.653242111206055 9.577951431274414 10.825362205505371
106.42%
Reconstruction: (1.064234612389824, 9.653242111206055, 9.577951431274414, 10.825362205505371)



100%|██████████| 5/5 [00:01<00:00,  3.06it/s]


Num dead tensor(0.0846, device='cuda:0')
Saved as version 49



 40%|████      | 20/50 [00:07<00:11,  2.53it/s]

 80%|████████  | 1000/1250 [13:25<41:52, 10.05s/it]

Num dead tensor(0.0918, device='cuda:0')
Resetting neurons! tensor(141, device='cuda:0')
torch.Size([1536, 768]) torch.Size([768, 1536]) torch.Size([1536])


 88%|████████▊ | 1100/1250 [14:39<01:51,  1.35it/s]

{'loss': 151.1935577392578, 'l2_loss': 29.00714111328125, 'l1_loss': 122.18641662597656}


 96%|█████████▌| 1200/1250 [15:54<00:37,  1.34it/s]

{'loss': 153.392333984375, 'l2_loss': 30.90787124633789, 'l1_loss': 122.48445892333984}


100%|██████████| 1250/1250 [16:31<00:00,  1.26it/s]
  8%|▊         | 100/1250 [01:14<14:12,  1.35it/s]

{'loss': 155.57455444335938, 'l2_loss': 30.14177703857422, 'l1_loss': 125.43277740478516}


 12%|█▏        | 153/1250 [01:53<13:33,  1.35it/s]