## SIEs on MNIST MLPs

What happens when we train SAEs on MLPs? With information?

### Setup

In [13]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize the data
])

train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
test_dataset = MNIST(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

model = MLP()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
train_losses = []

for epoch in range(5):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        if batch_idx % 400 == 0:
            print(f'Epoch {epoch+1}/{5}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

Epoch 1/5, Batch 0/938, Loss: 2.3570
Epoch 1/5, Batch 400/938, Loss: 2.2115
Epoch 1/5, Batch 800/938, Loss: 2.0289
Epoch 2/5, Batch 0/938, Loss: 1.9463
Epoch 2/5, Batch 400/938, Loss: 1.6465
Epoch 2/5, Batch 800/938, Loss: 1.2891
Epoch 3/5, Batch 0/938, Loss: 1.1406
Epoch 3/5, Batch 400/938, Loss: 0.7989
Epoch 3/5, Batch 800/938, Loss: 0.6858
Epoch 4/5, Batch 0/938, Loss: 0.6213
Epoch 4/5, Batch 400/938, Loss: 0.7032
Epoch 4/5, Batch 800/938, Loss: 0.5854
Epoch 5/5, Batch 0/938, Loss: 0.6899
Epoch 5/5, Batch 400/938, Loss: 0.6359
Epoch 5/5, Batch 800/938, Loss: 0.3879


In [14]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=train_losses, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Optimization Step')
fig.update_yaxes(title_text='CrossEntropy Loss')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

### Lessgo!

In [15]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)

In [16]:
class AutoEncoder(nn.Module):
    
    def __init__(self, hidden_size=1024, device='cuda', l1_coeff=3e-4):
        super().__init__()
        self.l1_coeff = 3e-4 if not l1_coeff else l1_coeff
        self.d_model = 128
        self.W_enc = nn.Linear(self.d_model, hidden_size)
        self.W_dec = nn.Linear(hidden_size, self.d_model)
        self.b_enc = nn.Parameter(torch.zeros(hidden_size))
        self.b_dec = nn.Parameter(torch.zeros(self.d_model))
        self.W_dec.weight.data = self.W_dec.weight.data / self.W_dec.weight.data.norm(dim=-1, keepdim=True)
        self.to(device)

    def forward(self, x):
        x_cent = x - self.b_dec
        acts = F.relu(self.W_enc(x_cent) + self.b_enc)
        x_hat = self.W_dec(acts) + self.b_dec
        l2_loss = (x - x_hat).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        l0_loss = (acts > 0).sum()
        loss = l2_loss + l1_loss
        return loss, x_hat, acts, l2_loss, l1_loss, l0_loss
    
    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec.weight.data / self.W_dec.weight.data.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.weight.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.weight.grad -= W_dec_grad_proj
        self.W_dec.data = W_dec_normed

In [17]:
activation_store = []

def hook_fn(module, input, output):
    global activation_store
    output_ = output.detach().cpu().numpy()
    activation_store.append(output_)

# clear hooks
model.fc2._forward_hooks.clear()
model.fc2.register_forward_hook(hook_fn)

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    model(data)

activation_store = np.concatenate(activation_store, axis=0)
activation_store = torch.tensor(activation_store).to(device)

model.fc2._forward_hooks.clear()

In [18]:
activation_store.shape

torch.Size([60000, 128])

In [19]:
sae = AutoEncoder(l1_coeff=3e-3)
sae_reconstruction_loss = []
sae_l1_loss = []
epochs = 200
lr = 1e-3
optimizer = optim.Adam(sae.parameters(), lr=lr)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.99)

In [12]:
for epoch in range(epochs):
    curr_l2 = 0
    curr_l1 = 0
    curr_l0 = 0
    epoch_loss = 0
    batch_size = 64
    for i in range(0, activation_store.shape[0], batch_size):
        x = activation_store[i:i+64]
        loss, x_hat, acts, l2_loss, l1_loss, l0_loss = sae(x)
        loss.backward()
        sae.make_decoder_weights_and_grad_unit_norm()
        optimizer.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()
        curr_l2 += l2_loss.item()
        curr_l1 += l1_loss.item()
        curr_l0 += l0_loss.item()
    lr_scheduler.step()
    sae_reconstruction_loss.append(curr_l2)
    sae_l1_loss.append(curr_l1)
    if epoch % (epochs // 40) == 0:
        print(f'Epoch {epoch+1}/{epochs}, Reconstruction Loss: {sae_reconstruction_loss[-1]:.4f}, L1 Loss: {sae_l1_loss[-1]:.4f}, L0 Loss: {curr_l0 / 60000:.4f}')

Epoch 1/200, Reconstruction Loss: 5506.6682, L1 Loss: 7636.7437, L0 Loss: 56.3189
Epoch 6/200, Reconstruction Loss: 1459.3415, L1 Loss: 2675.0650, L0 Loss: 24.3861
Epoch 11/200, Reconstruction Loss: 1279.9638, L1 Loss: 2456.0299, L0 Loss: 26.2061
Epoch 16/200, Reconstruction Loss: 1195.3213, L1 Loss: 2336.5762, L0 Loss: 27.6139
Epoch 21/200, Reconstruction Loss: 1154.1105, L1 Loss: 2236.8220, L0 Loss: 28.3518
Epoch 26/200, Reconstruction Loss: 1126.0550, L1 Loss: 2152.8710, L0 Loss: 28.8948
Epoch 31/200, Reconstruction Loss: 1103.3253, L1 Loss: 2080.2445, L0 Loss: 29.3405
Epoch 36/200, Reconstruction Loss: 1084.5896, L1 Loss: 2017.8386, L0 Loss: 29.7229
Epoch 41/200, Reconstruction Loss: 1067.8886, L1 Loss: 1962.8597, L0 Loss: 30.0466
Epoch 46/200, Reconstruction Loss: 1054.4712, L1 Loss: 1914.2676, L0 Loss: 30.3140
Epoch 51/200, Reconstruction Loss: 1041.9422, L1 Loss: 1869.8053, L0 Loss: 30.5562
Epoch 56/200, Reconstruction Loss: 1031.6115, L1 Loss: 1829.0843, L0 Loss: 30.7638
Epoch 

In [11]:
# show both losses side by side
fig = make_subplots(rows=1, cols=2, subplot_titles=('Reconstruction Loss', 'L1 Loss'))
fig.add_trace(go.Scatter(y=sae_reconstruction_loss, mode='lines', name='', line=dict(color='darkred', width=3)), row=1, col=1)
fig.add_trace(go.Scatter(y=sae_l1_loss, mode='lines', name='', line=dict(color='darkblue', width=3)), row=1, col=2)
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_xaxes(title_text='Optimization Step')
fig.update_yaxes(title_text='Loss')
fig.update_layout(width=1200, height=400, autosize=False)
fig.show()

## Visualizing Features

In [12]:
sae.W_dec.weight.data.shape

torch.Size([128, 1024])

In [13]:
features = sae.W_dec.weight.data.T # (1024, 128)

In [14]:
def get_freqs(sae, mlp_acts):
    frequencies = [0 for _ in range(features.shape[0])]
    sae_acts = sae(mlp_acts)[2]
    active_points = (sae_acts > 0).float().sum(0)
    for i in range(features.shape[0]):
        frequencies[i] = active_points[i].item()
    return [int(x) for x in frequencies]

In [15]:
freqs = get_freqs(sae, activation_store)

In [16]:
# plot histogram of frequencies
fig = go.Figure()
fig.add_trace(go.Histogram(x=freqs, marker=dict(color='darkred')))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_xaxes(title_text='Number of Active Points')
fig.update_yaxes(title_text='Frequency')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [17]:
active_features = [i for i, freq in enumerate(freqs) if freq > 0]
dead_features = [i for i, freq in enumerate(freqs) if freq == 0]
print(f'Feature Activation: {len(active_features)} active, {len(dead_features)} dead')

Feature Activation: 147 active, 877 dead


In [22]:
len(active_features)

(147, 877)

In [28]:
sae_acts = sae(activation_store)[2]
sae_acts.shape
labels = [train_dataset[i][1] for i in range(len(train_dataset))]
len(labels)

60000

In [31]:
activating_datapoints = [[] for _ in range(len(active_features))]
activating_targets = [[] for _ in range(len(active_features))]

for i in tqdm.trange(len(labels)):
    for j, idx in enumerate(active_features):
        if sae_acts[i, idx] > 0:
            activating_datapoints[j].append(activation_store[i].cpu().numpy())
            activating_targets[j].append(labels[i])

100%|██████████| 60000/60000 [06:36<00:00, 151.14it/s]


In [55]:
feature_idx = active_features[2]

new_index = active_features.index(feature_idx)
activating_datapoints_ = activating_datapoints[new_index]
activating_targets_ = activating_targets[new_index]

# plot the distribution of the target labels
fig = go.Figure()
fig.add_trace(go.Histogram(x=activating_targets_, marker=dict(color='darkred'), xbins=dict(size=1), histnorm='probability'))
fig.update_layout(bargap=0.1)
fig.update_xaxes(tick0=0, dtick=1)
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_xaxes(title_text='Target Label')
fig.update_yaxes(title_text='Frequency')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()