In [1]:
from multiprocessing import cpu_count

from sae.anthropic import SAEConfig, SAEPLDataset, SAEPLModel
from sae.hooks import RecordingHookPoint

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
import lightning as L

In [2]:
%load_ext jaxtyping
%jaxtyping.typechecker typeguard.typechecked

In [3]:
# following taken from https://github.com/pytorch/examples/blob/main/mnist/main.py

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [4]:
def train(model, device, train_loader, optimizer, epoch, log_interval):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                ),
                end="\r",
            )

In [5]:
torch.manual_seed(42)

<torch._C.Generator at 0x7efb8da9d390>

In [6]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

In [7]:
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, transform=transform)

In [8]:
train_kwargs = {
    "batch_size": 64,
    "num_workers": cpu_count(),
    "shuffle": True,
}

In [9]:
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)

In [10]:
device = torch.device("cuda")
model = Net().to(device)

In [11]:
epochs = 100

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, torch.optim.Adam(model.parameters()), epoch, 100)



In [12]:
hook_fc1 = RecordingHookPoint(model, "fc1")
hook_fc2 = RecordingHookPoint(model, "fc2")

train_loader = torch.utils.data.DataLoader(
    dataset1, batch_size=128, num_workers=cpu_count(), shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    dataset2, batch_size=128, num_workers=cpu_count(), shuffle=False
)

with torch.no_grad():
    for img, _ in train_loader:
        img = img.to(device)
        model(img)
        del img
        torch.cuda.empty_cache()

train_fc1_activations = torch.cat(hook_fc1.activation_store)
train_fc2_activations = torch.cat(hook_fc2.activation_store)

hook_fc1.reset_activation_store()
hook_fc2.reset_activation_store()

In [13]:
with torch.no_grad():
    for img, _ in test_loader:
        img = img.to(device)
        model(img)
        del img
        torch.cuda.empty_cache()

test_fc1_activations = torch.cat(hook_fc1.activation_store)
test_fc2_activations = torch.cat(hook_fc2.activation_store)

hook_fc1.close()
hook_fc2.close()

del hook_fc1
del hook_fc2

In [14]:
sae_config = SAEConfig(
    input_dim=train_fc1_activations.size(1), latent_dim=2**15, batch_size=512, sparsity_coefficient=0.2
)

In [15]:
sae_fc1 = SAEPLModel(sae_config)

In [16]:
trainer = L.Trainer(
    max_steps=100000,
    accelerator="gpu",
    logger=False,
    enable_checkpointing=False,
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:
ds = SAEPLDataset(
    torch.cat([train_fc1_activations, test_fc1_activations]), sae_config
)

In [18]:
trainer.fit(sae_fc1, datamodule=ds)

You are using a CUDA device ('NVIDIA GeForce RTX 3090 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params | Mode 
--------------------------------------
0 | sae  | SAE  | 8.4 M  | train
--------------------------------------
8.4 M     Trainable params
0         Non-trainable params
8.4 M     Total params
33.686    Total estimated model params size (MB)
1         Modules in train mode
0         Modules in eval mode


Epoch 0:   2%|▏         | 2/124 [00:00<00:12,  9.50it/s, train/loss=0.957, train/loss_mse=0.957, train/loss_sparsity=0.000, lr=5e-5, sparsity_coefficient=4e-5, dead_neurons=0.000, train/firing_rate=1.91e+4] 



Epoch 806:  45%|████▌     | 56/124 [00:00<00:00, 144.47it/s, train/loss=0.657, train/loss_mse=0.657, train/loss_sparsity=0.000942, lr=2.5e-9, sparsity_coefficient=0.200, dead_neurons=30426.0, train/firing_rate=0.102, val/loss=0.649, val/loss_mse=0.646, val/loss_sparsity=0.00284]    

`Trainer.fit` stopped: `max_steps=100000` reached.


Epoch 806:  45%|████▌     | 56/124 [00:00<00:00, 144.24it/s, train/loss=0.657, train/loss_mse=0.657, train/loss_sparsity=0.000942, lr=2.5e-9, sparsity_coefficient=0.200, dead_neurons=30426.0, train/firing_rate=0.102, val/loss=0.649, val/loss_mse=0.646, val/loss_sparsity=0.00284]

