## Setup

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [3]:
def randomseed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128, bias=False)
        self.fc2 = nn.Linear(128, 128, bias=False)
        self.fc3 = nn.Linear(128, 10, bias=False)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [5]:
def accuracy(model, data):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct / total

In [6]:
randomseed(42)

In [7]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # normalize the data
])

train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
test_dataset = MNIST(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [8]:
model = MLP()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
train_losses = []

In [9]:
# struct to store gradients of fc2 during each update
grads = defaultdict(list)

In [75]:
randomseed(42)

for epoch in range(5):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        # store fc2 gradients
        for p in model.named_parameters():
            if p[0] == 'fc2.weight':
                grads[batch_idx + (epoch * len(train_loader))].append(p[1].grad.clone())
        optimizer.step()
        train_losses.append(loss.item())
        if batch_idx % 400 == 0:
            acc = accuracy(model, test_loader)
            print(f'Epoch {epoch+1}/{5}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/5, Batch 0/938, Loss: 2.3001, Accuracy: 0.1095, LR: 0.0004
Epoch 1/5, Batch 400/938, Loss: 2.1856, Accuracy: 0.4397, LR: 0.0002
Epoch 1/5, Batch 800/938, Loss: 1.9595, Accuracy: 0.6761, LR: 0.0002
Epoch 2/5, Batch 0/938, Loss: 1.9510, Accuracy: 0.6977, LR: 0.0001
Epoch 2/5, Batch 400/938, Loss: 1.5521, Accuracy: 0.7519, LR: 0.0001
Epoch 2/5, Batch 800/938, Loss: 1.1782, Accuracy: 0.7865, LR: 0.0001
Epoch 3/5, Batch 0/938, Loss: 1.1924, Accuracy: 0.7949, LR: 0.0001
Epoch 3/5, Batch 400/938, Loss: 0.8333, Accuracy: 0.8205, LR: 0.0
Epoch 3/5, Batch 800/938, Loss: 0.8115, Accuracy: 0.8373, LR: 0.0
Epoch 4/5, Batch 0/938, Loss: 0.6676, Accuracy: 0.8391, LR: 0.0
Epoch 4/5, Batch 400/938, Loss: 0.5736, Accuracy: 0.8497, LR: 0.0
Epoch 4/5, Batch 800/938, Loss: 0.4774, Accuracy: 0.8627, LR: 0.0
Epoch 5/5, Batch 0/938, Loss: 0.5586, Accuracy: 0.8646, LR: 0.0
Epoch 5/5, Batch 400/938, Loss: 0.4237, Accuracy: 0.8736, LR: 0.0
Epoch 5/5, Batch 800/938, Loss: 0.6097, Accuracy: 0.8771, LR: 0.0

In [76]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=train_losses, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Optimization Step')
fig.update_yaxes(title_text='CrossEntropy Loss')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [10]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=False)
  (fc2): Linear(in_features=128, out_features=128, bias=False)
  (fc3): Linear(in_features=128, out_features=10, bias=False)
)

## Clustering

In [12]:
import numpy as np
import numpy as np
from sklearn.cluster import KMeans
from scipy.sparse.linalg import svds

# bipartite clustering

In [13]:
def spectral_clustering(model, k):
    
    A = model.fc2.weight.detach().cpu().numpy()
    A = np.abs(A)
    
    D_U = np.diag(np.sum(A, axis=1))
    D_V = np.diag(np.sum(A, axis=0))

    D_U_inv_sqrt = np.linalg.inv(np.sqrt(D_U))
    D_V_inv_sqrt = np.linalg.inv(np.sqrt(D_V))

    A_tilde = D_U_inv_sqrt @ A @ D_V_inv_sqrt

    U, Sigma, Vt = svds(A_tilde, k=k)

    kmeans_U = KMeans(n_clusters=k, random_state=42).fit(U)
    kmeans_V = KMeans(n_clusters=k, random_state=42).fit(Vt.T)

    labels_U = kmeans_U.labels_
    labels_V = kmeans_V.labels_

    # convert labels to indices
    cluster_U_indices = defaultdict(list)
    cluster_V_indices = defaultdict(list)
    for i, label in enumerate(labels_U):
        cluster_U_indices[label].append(i)
    for i, label in enumerate(labels_V):
        cluster_V_indices[label].append(i)

    return cluster_U_indices, cluster_V_indices

In [13]:
num_clusters = 2
cluster_U_indices, cluster_V_indices = spectral_clustering(model, num_clusters)

In [14]:
for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

Cluster 0 has 58 nodes in U and 63 nodes in V
Cluster 1 has 70 nodes in U and 65 nodes in V


## How good is a cluster?

Now, we need a (possibly information-theoretic) way to evaluate how good this clustering is.

In [16]:
def clustering_goodness_f1(model, cluster_U_indices, cluster_V_indices, num_clusters):
    A = model.fc2.weight.detach().cpu().numpy()
    # square
    A = np.square(A)
    
    # fraction of edges within the same cluster
    intra_cluster_out_sum = 0
    total_out_sum = 0
    for i in range(len(A)):
        for j in range(len(A)):
            total_out_sum += A[i, j]
            same_cluster = False
            for cluster_idx in range(num_clusters):
                if i in cluster_U_indices[cluster_idx] and j in cluster_V_indices[cluster_idx]:
                    same_cluster = True
                    break
            if same_cluster:
                intra_cluster_out_sum += A[i, j]
    
    intra_cluster_fraction = intra_cluster_out_sum / total_out_sum

    return round(intra_cluster_fraction, 3)

In [16]:
clusterability = clustering_goodness_f1(model, cluster_U_indices, cluster_V_indices, num_clusters)
print(f'The model has a clusterability of {clusterability} (out of 1) with {num_clusters} clusters.')

The model has a clusterability of 0.452 (out of 1) with 2 clusters.


This is somewhat okayish? This means the model is by default not very clusterable, which means there's good scope for improvement.

We don't really know if it is going to be possible to do this clustering, but a search on the algorithm space seems easy if we do fix this as our objective measure for clusterability.

### Some other measure of goodness
- information theoretic one (which is more annoying to optimize over)
- L2 norm for per-neuron cross-weights compared with those inside a cluster (could be easier to optimize over)
- the ability to search for human-interpretable circuits (the final goal, but really difficult to even formalize)

In [17]:
# try some more goodness measures

## Do Gradient Updates create Clusters?

The overall idea is this:

As circuits form, neurons (and parameters) that create a particular circuit during learning likely update together.

Thus, maybe parameters that update together should be clustered together. Let's explore how much gradients can tell us.

In [18]:
[len(cluster_U_indices[i]) for i in range(len(cluster_U_indices.keys()))] # clusters in U (the left side)

[58, 70]

In [51]:
if type(grads) is not torch.Tensor:
    grads = torch.stack(list({k: v[0] / v[0].norm() for k, v in grads.items()}.values()))

In [17]:
grads.shape # (step, in_neuron, out_neuron)

torch.Size([4690, 128, 128])

In [25]:
## maybe gradients start getting similar only in the last 25% of the training

In [26]:
# mean and std of gradients for each update step
grads_mean = grads.mean(dim=(1, 2))
grads_std = grads.std(dim=(1, 2))
grads_mean.shape, grads_mean[:5], grads_std.shape, grads_std[-5:]

(torch.Size([4690]),
 tensor([ 0.0007,  0.0004,  0.0005, -0.0004,  0.0002], device='cuda:0'),
 torch.Size([4690]),
 tensor([0.0078, 0.0078, 0.0078, 0.0078, 0.0078], device='cuda:0'))

In [27]:
similarity_matrix = torch.zeros(grads.shape[1], grads.shape[1]).to(device)

for i in range(grads.shape[0]):
    similarity_matrix += torch.mm(grads[i].t(), grads[i])

In [28]:
similarity_matrix /= grads.shape[0]

similarity_matrix.shape

torch.Size([128, 128])

In [29]:
def gradient_based_clustering(similarity_matrix, num_clusters):
    # bipartite clustering
    A = similarity_matrix.detach().cpu().numpy()
    k = num_clusters

    D_U = np.diag(np.sum(A, axis=1))
    D_V = np.diag(np.sum(A, axis=0))

    D_U_inv_sqrt = np.linalg.inv(np.sqrt(D_U))
    D_V_inv_sqrt = np.linalg.inv(np.sqrt(D_V))

    A_tilde = D_U_inv_sqrt @ A @ D_V_inv_sqrt

    U, Sigma, Vt = svds(A_tilde, k=k)

    kmeans_U = KMeans(n_clusters=k, random_state=42).fit(U)
    kmeans_V = KMeans(n_clusters=k, random_state=42).fit(Vt.T)

    labels_U = kmeans_U.labels_
    labels_V = kmeans_V.labels_

    # convert labels to indices
    cluster_U_indices = defaultdict(list)
    cluster_V_indices = defaultdict(list)
    for i, label in enumerate(labels_U):
        cluster_U_indices[label].append(i)
    for i, label in enumerate(labels_V):
        cluster_V_indices[label].append(i)

    return cluster_U_indices, cluster_V_indices

In [30]:
cluster_U_indices, cluster_V_indices = gradient_based_clustering(similarity_matrix, num_clusters)

In [31]:
for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

Cluster 0 has 43 nodes in U and 43 nodes in V
Cluster 1 has 85 nodes in U and 85 nodes in V


In [32]:
clustering_goodness_f1(model, cluster_U_indices, cluster_V_indices, num_clusters)

0.547

Alright, this also doesn't give any good clusters based on our "weight interference goodness measure".

## Optimizing for Clusterability

Step 1: train for some time.

In [45]:
model = MLP()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
train_losses = []

randomseed(42)

for epoch in range(1):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        if batch_idx % 400 == 0:
            acc = accuracy(model, test_loader)
            print(f'Epoch {epoch+1}/{5}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/5, Batch 0/938, Loss: 2.3259, Accuracy: 0.0631
Epoch 1/5, Batch 400/938, Loss: 2.2052, Accuracy: 0.3871
Epoch 1/5, Batch 800/938, Loss: 1.9979, Accuracy: 0.6055


Step 2: Get clusters.

In [30]:
num_clusters = 2
cluster_U_indices, cluster_V_indices = spectral_clustering(model, num_clusters)

for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

Cluster 0 has 61 nodes in U and 52 nodes in V
Cluster 1 has 67 nodes in U and 76 nodes in V


Step 3: Get clusterability.

In [31]:
clusterability = clustering_goodness_f1(model, cluster_U_indices, cluster_V_indices, num_clusters)
print(f'The model has a clusterability of {clusterability} (out of 1) with {num_clusters} clusters.')

The model has a clusterability of 0.459 (out of 1) with 2 clusters.


Step 4: Make a differentiable cluster goodness function.

In [19]:
def cluster_goodness_(model, cluster_U_indices, cluster_V_indices, num_clusters):
    A = model.fc2.weight
    # square
    A = A ** 2
    
    # fraction of edges within the same cluster
    intra_cluster_out_sum = 0
    total_out_sum = torch.sum(A)
    
    for i in range(A.size(0)):
        for j in range(A.size(1)):
            same_cluster = False
            for cluster_idx in range(num_clusters):
                if i in cluster_U_indices[cluster_idx] and j in cluster_V_indices[cluster_idx]:
                    same_cluster = True
                    break
            if same_cluster:
                intra_cluster_out_sum += A[i, j]
    
    intra_cluster_fraction = intra_cluster_out_sum / total_out_sum

    return intra_cluster_fraction

In [34]:
def cluster_goodness_fast(model, cluster_U_indices, cluster_V_indices, num_clusters):
    A = model.fc2.weight ** 2
    mask = torch.zeros_like(A, dtype=torch.bool)
    
    for cluster_idx in range(num_clusters):
        u_indices = torch.tensor(cluster_U_indices[cluster_idx], dtype=torch.long)
        v_indices = torch.tensor(cluster_V_indices[cluster_idx], dtype=torch.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = torch.sum(A[mask])
    total_out_sum = torch.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

In [35]:
clusterability = cluster_goodness_(model, cluster_U_indices, cluster_V_indices, num_clusters)
print(f'The model has a clusterability of {clusterability} (out of 1) with {num_clusters} clusters.')

clusterability = cluster_goodness_fast(model, cluster_U_indices, cluster_V_indices, num_clusters)
print(f'The model has a clusterability of {clusterability} (out of 1) with {num_clusters} clusters.')

The model has a clusterability of 0.45867615938186646 (out of 1) with 2 clusters.
The model has a clusterability of 0.4586756229400635 (out of 1) with 2 clusters.


In [36]:
# try a backward pass
optimizer.zero_grad()
clusterability.backward()
optimizer.zero_grad()

Step 5: Train the rest of the model while including the clusterablitlity loss function.

In [56]:
cluster_losses = []
ce_losses = []
lomda = 3.0
epochs = 10

randomseed(42)

for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss_ce = criterion(output, target)
        loss_cluster = cluster_goodness_fast(model, cluster_U_indices, cluster_V_indices, num_clusters)
        cluster_losses.append(loss_cluster.item())
        ce_losses.append(loss_ce.item())
        loss = loss_ce - lomda * loss_cluster
        loss.backward()
        optimizer.step()
        if batch_idx % 200 == 0:
            acc = accuracy(model, test_loader)
            print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, CE loss: {loss.item():.4f}, Clusterability: {loss_cluster.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/10, Batch 0/938, CE loss: -2.9347, Clusterability: 0.9993, Accuracy: 0.9471
Epoch 1/10, Batch 200/938, CE loss: -2.9391, Clusterability: 0.9993, Accuracy: 0.9469
Epoch 1/10, Batch 400/938, CE loss: -2.7890, Clusterability: 0.9993, Accuracy: 0.9469
Epoch 1/10, Batch 600/938, CE loss: -2.8117, Clusterability: 0.9993, Accuracy: 0.9476
Epoch 1/10, Batch 800/938, CE loss: -2.9150, Clusterability: 0.9993, Accuracy: 0.9472
Epoch 2/10, Batch 0/938, CE loss: -2.8254, Clusterability: 0.9993, Accuracy: 0.9475
Epoch 2/10, Batch 200/938, CE loss: -2.8185, Clusterability: 0.9993, Accuracy: 0.9480
Epoch 2/10, Batch 400/938, CE loss: -2.8282, Clusterability: 0.9993, Accuracy: 0.9476
Epoch 2/10, Batch 600/938, CE loss: -2.8706, Clusterability: 0.9993, Accuracy: 0.9475
Epoch 2/10, Batch 800/938, CE loss: -2.8736, Clusterability: 0.9993, Accuracy: 0.9480
Epoch 3/10, Batch 0/938, CE loss: -2.8492, Clusterability: 0.9993, Accuracy: 0.9483
Epoch 3/10, Batch 200/938, CE loss: -2.8423, Clusterability:

In [50]:
# plot the two losses on side by side plots
fig = make_subplots(rows=1, cols=2, subplot_titles=('Cross Entropy Loss', 'Model Clusterability'))
fig.add_trace(go.Scatter(y=ce_losses, mode='lines', name='', line=dict(color='darkred', width=2)), row=1, col=1)
fig.add_trace(go.Scatter(y=cluster_losses, mode='lines', name='', line=dict(color='darkblue', width=2)), row=1, col=2)
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
# show fine grid lines on both axes on both subplots
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=1)
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=1)
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=2)
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=2)

fig.update_xaxes(title_text='Optimization Step', row=1, col=1)
fig.update_yaxes(title_text='Cross Entropy', row=1, col=1)
fig.update_xaxes(title_text='Optimization Step', row=1, col=2)
fig.update_yaxes(title_text='Clusterability', row=1, col=2)
fig.update_layout(width=1000, height=400, autosize=False)
fig.show()

In [None]:
# store the new, cluster-promoting model and the cluster indices
torch.save(model.state_dict(), 'model_cluster.pt')
torch.save(cluster_U_indices, 'cluster_U_indices.pt')
torch.save(cluster_V_indices, 'cluster_V_indices.pt')

In [None]:
# load the model and cluster indices
model = MLP()
model.load_state_dict(torch.load('model_cluster.pt'))
cluster_U_indices = torch.load('cluster_U_indices.pt')
cluster_V_indices = torch.load('cluster_V_indices.pt')

## Studying the clusters!

Yay! I should store the "clustered" model and study/interpret it a bit to see what exactly happened during this clustering-promoted training.