In [6]:
#----- imports --------

import tqdm
import torch
from torch import nn
from torch.nn import functional as F
import wandb
import os
import tokenizers
from matplotlib import pyplot as plt
import numpy as np
import json
import random
import tqdm


device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"

config = {
    "learning_rate": 1e-3,
    # 'sae_size': 2**18,
     'sae_size': 2**14,
    "sae_learning_rate": 5e-5,
    "sae_sparsity_penalty": 250,
    "model_embedding_layer": 6,
    "eval_interval": 500,
    "max_iters": 60000, 
    "H": 32, # hidden dimension size
    "B": 64,
    "T": 256,
    "C": 256,
    "feedforward_factor": 3,
    "n_heads": 8,
    "n_layers": 12,
    "tokenizer_vocab_size": 2**13,
    "git_hash": os.popen("git rev-parse HEAD").read().strip()
}

# initial
for k,v in config.items():
    locals ()[k] = v


# wandb.init(
#    project = "scaling-monosemanticity",
#    config = config,
# )

VBox(children=(Label(value='0.004 MB of 0.012 MB uploaded\r'), FloatProgress(value=0.36583816405939806, max=1.…

In [7]:
class SparseAutoEncoder(nn.Module):
    def __init__(self, activations_dim, sparse_dim):
        super().__init__()
        self.activations_dim = activations_dim
        encoder_weight = torch.randn(activations_dim, sparse_dim)
        decoder_weight = torch.randn(sparse_dim, activations_dim)
        self.encoder_bias = nn.Parameter(torch.zeros(sparse_dim))
        self.decoder_bias = nn.Parameter(torch.zeros(activations_dim))
        self.sparse_dim = sparse_dim
        self.sparsity_penalty = sae_sparsity_penalty

        # set the encoder_weight to have the activations dim to be normalized to have l2 norm randomly between 0.05 and 1
        direction_lengths = torch.rand(sparse_dim) * 0.95 + 0.05
        # normalize the encoder_weight along columns (dim -2) to have l2 norm of 1
        encoder_weight = F.normalize(encoder_weight, p=2, dim=0)
        # multiply the column norms by the direction_lengths
        encoder_weight = encoder_weight * direction_lengths
        # make the decoder weight be the transpose of the encoder_weight
        decoder_weight = torch.transpose(encoder_weight, 0, 1)

        self.encoder_weight = nn.Parameter(encoder_weight)
        self.decoder_weight = nn.Parameter(decoder_weight)



    def forward(self, x):
        # preprocessing normalization
        # now on average any embedding has euclidian length 1

        encoded = F.relu(x @ self.encoder_weight + self.encoder_bias) # all act. are positive
        decoded = encoded @ self.decoder_weight + self.decoder_bias

        reconstruction_l2_loss = F.mse_loss(x, decoded)

        # every row in the tall decoder matrix
        # is the "sum" of the total influence of a feature on the output
        # the l2 norm of that row is the "influence" of that feature on that output
        # calculate that, store as row
        decoder_l2 = torch.linalg.norm(self.decoder_weight, dim=-1)
        # the feature activation is the sparse activation * it's influence on output
        feature_activations = (encoded) * decoder_l2
        # sum of feature activations
        # divide by the batch size * sequence length
        # should work if there is no batch dimension
        if x.ndim == 3:
            batch_dim, sequence_dim, _ = x.shape
        elif x.ndim == 2:
            batch_dim = 1
            sequence_dim, _ = x.shape
        elif x.ndim == 1:
            batch_dim = 1
            sequence_dim = 1
        else:
            raise ValueError(f"x has {x.ndim} dimensions, but it should have 1, 2, or 3")
        
        sparsity_loss = torch.sum(feature_activations) * self.sparsity_penalty / (batch_dim * sequence_dim * self.sparse_dim)

        total_loss = reconstruction_l2_loss + sparsity_loss

        return {"encoded": encoded, "decoded": decoded, 'feature_activations': feature_activations, "reconstruction_loss": reconstruction_l2_loss, "sparsity_loss": sparsity_loss, "total_loss": total_loss}





sae = SparseAutoEncoder(C, sae_size)
optimizer = torch.optim.Adam(sae.parameters(), lr=sae_learning_rate)




In [8]:
# Calculate the total number of parameters
total_params = sum(p.numel() for p in sae.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")

Total trainable parameters: 8405248


In [9]:
def load_tensor(filepath):
    # load the .pt tensor
    tensor = torch.load(filepath)
    tensor = torch.cat(tensor, dim=0)
    tensor = tensor.to(device)
    return tensor
    

In [10]:
@torch.no_grad()
def estimate_sae_loss(eval_iters, tensor):
    sae_loss = 0
    sae_sparsity_loss = 0
    sae_reconstruction_loss = 0
    count = 0
    for i in range(0, eval_iters, B):
        count += 1
        start = i
        end = i+B
        assert tensor.shape[0] >= end, f"too many eval_iters"
        sample = tensor[start:end]
        sae_output = sae(sample)
        sae_loss += sae_output['total_loss'].item()
        sae_sparsity_loss += sae_output['sparsity_loss'].item()
        sae_reconstruction_loss += sae_output['reconstruction_loss'].item()
    avg_loss = sae_loss/count
    avg_sparsity_loss = sae_sparsity_loss/count
    avg_reconstruction_loss = sae_reconstruction_loss/count
    return {"reconstruction_loss": avg_reconstruction_loss, "sparsity_loss": avg_sparsity_loss, "total_loss": avg_loss}
    


estimate_sae_loss(100, load_tensor("residuals/residuals_train_1.pt"))

{'reconstruction_loss': 2926.012451171875,
 'sparsity_loss': 179.74524688720703,
 'total_loss': 3105.7576904296875}

In [11]:
train_filepaths = []
val_filepaths = []
for file in os.listdir(f'residuals'):
    if file.startswith(f"residuals_train"):
        train_filepaths.append(f"residuals/{file}")
    elif file.startswith(f"residuals_val"):
        val_filepaths.append(f"residuals/{file}")



In [12]:
optimizer = torch.optim.Adam(sae.parameters(), lr=sae_learning_rate)
num_epochs = 1
logging_interval = 50000

for epoch in range(num_epochs):
    for filepath in train_filepaths:
        val_residuals_tensor = load_tensor(random.choice(val_filepaths))
        print(f"val loss on next datafile")
        val_data = estimate_sae_loss(1000, val_residuals_tensor)# keys: reconstruction_loss, sparsity_loss, total_loss
        # wandb.log({"val_reconstruction_loss": val_data['reconstruction_loss'], "val_sparsity_loss": val_data['sparsity_loss'], "val_total_loss": val_data['total_loss']})
        del val_residuals_tensor
        residuals_tensor = load_tensor(filepath)
        print(f"train loss on next datafile")
        train_data = estimate_sae_loss(1000, residuals_tensor)# keys: reconstruction_loss, sparsity_loss, total_loss
        # wandb.log({"train_reconstruction_loss": train_data['reconstruction_loss'], "train_sparsity_loss": train_data['sparsity_loss'], "train_total_loss": train_data['total_loss']})
        print(f"training on {filepath}")

        for i in tqdm.tqdm(range(0, residuals_tensor.shape[0]-B, B)):
            start = i
            end = i+B
            assert residuals_tensor.shape[0] >= end, f"too many train samples"
            sample = residuals_tensor[start:end]
            optimizer.zero_grad()
            sae_output = sae(sample)
            sae_reconstruction_loss = sae_output['reconstruction_loss']
            sae_sparsity_loss = sae_output['sparsity_loss']
            total_loss = sae_reconstruction_loss + sae_sparsity_loss
            total_loss.backward()
            optimizer.step()
            if i % logging_interval == 0:
                pass
                # wandb.log({"frequent_reconstruction_loss": sae_reconstruction_loss, "frequent_sparsity_loss": sae_sparsity_loss, "frequent_total_loss": total_loss})
            
# wandb.finish()
torch.save(sae.state_dict(), 'sae_large_model_weights.pth')
            


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_5.pt


100%|██████████| 51187/51187 [01:49<00:00, 469.56it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_11.pt


100%|██████████| 51187/51187 [01:53<00:00, 452.39it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_6.pt


100%|██████████| 51187/51187 [01:33<00:00, 547.27it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_7.pt


100%|██████████| 51187/51187 [01:32<00:00, 553.85it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_4.pt


100%|██████████| 51187/51187 [01:36<00:00, 528.18it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_10.pt


100%|██████████| 51187/51187 [01:48<00:00, 471.65it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_12.pt


100%|██████████| 44865/44865 [01:23<00:00, 537.28it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_3.pt


100%|██████████| 51187/51187 [01:34<00:00, 539.97it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_2.pt


100%|██████████| 51187/51187 [02:02<00:00, 416.55it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_1.pt


100%|██████████| 51187/51187 [01:40<00:00, 509.24it/s] 


val loss on next datafile
train loss on next datafile
training on residuals/residuals_train_8.pt


 10%|▉         | 5034/51187 [00:12<01:50, 418.13it/s] 


KeyboardInterrupt: 

In [14]:
# Save the model weights (2**14 size)
torch.save(sae.state_dict(), 'sae_14_model_weights.pth')


: 