## Setup

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [3]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128, bias=False)
        self.fc2 = nn.Linear(128, 128, bias=False)
        self.fc3 = nn.Linear(128, 10, bias=False)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
def accuracy(model, data):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct / total

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # normalize the data
])

train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
test_dataset = MNIST(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [6]:
model = MLP()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
train_losses = []

In [7]:
# struct to store gradients of fc2 during each update
grads = defaultdict(list)

In [8]:
for epoch in range(5):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        # store fc2 gradients
        for p in model.named_parameters():
            if p[0] == 'fc2.weight':
                grads[batch_idx + (epoch * len(train_loader))].append(p[1].grad.clone())
        optimizer.step()
        train_losses.append(loss.item())
        if batch_idx % 400 == 0:
            acc = accuracy(model, test_loader)
            print(f'Epoch {epoch+1}/{5}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/5, Batch 0/938, Loss: 2.3006, Accuracy: 0.0897
Epoch 1/5, Batch 400/938, Loss: 2.2268, Accuracy: 0.3548
Epoch 1/5, Batch 800/938, Loss: 2.0596, Accuracy: 0.5094
Epoch 2/5, Batch 0/938, Loss: 1.9660, Accuracy: 0.5300
Epoch 2/5, Batch 400/938, Loss: 1.6622, Accuracy: 0.6303
Epoch 2/5, Batch 800/938, Loss: 1.2533, Accuracy: 0.7237
Epoch 3/5, Batch 0/938, Loss: 1.4081, Accuracy: 0.7438
Epoch 3/5, Batch 400/938, Loss: 1.1240, Accuracy: 0.7966
Epoch 3/5, Batch 800/938, Loss: 0.7808, Accuracy: 0.8251
Epoch 4/5, Batch 0/938, Loss: 0.8641, Accuracy: 0.8311
Epoch 4/5, Batch 400/938, Loss: 0.7361, Accuracy: 0.8449
Epoch 4/5, Batch 800/938, Loss: 0.8414, Accuracy: 0.8556
Epoch 5/5, Batch 0/938, Loss: 0.5157, Accuracy: 0.8579
Epoch 5/5, Batch 400/938, Loss: 0.6084, Accuracy: 0.8662
Epoch 5/5, Batch 800/938, Loss: 0.4657, Accuracy: 0.8719


In [9]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=train_losses, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Optimization Step')
fig.update_yaxes(title_text='CrossEntropy Loss')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [10]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=False)
  (fc2): Linear(in_features=128, out_features=128, bias=False)
  (fc3): Linear(in_features=128, out_features=10, bias=False)
)

## Clustering

In [11]:
import numpy as np
import numpy as np
from sklearn.cluster import KMeans
from scipy.sparse.linalg import svds

# bipartite clustering

In [12]:
def spectral_clustering(model, k):
    
    A = model.fc2.weight.detach().cpu().numpy()
    A = np.abs(A)
    
    D_U = np.diag(np.sum(A, axis=1))
    D_V = np.diag(np.sum(A, axis=0))

    D_U_inv_sqrt = np.linalg.inv(np.sqrt(D_U))
    D_V_inv_sqrt = np.linalg.inv(np.sqrt(D_V))

    A_tilde = D_U_inv_sqrt @ A @ D_V_inv_sqrt

    U, Sigma, Vt = svds(A_tilde, k=k)

    kmeans_U = KMeans(n_clusters=k, random_state=42).fit(U)
    kmeans_V = KMeans(n_clusters=k, random_state=42).fit(Vt.T)

    labels_U = kmeans_U.labels_
    labels_V = kmeans_V.labels_

    # convert labels to indices
    cluster_U_indices = defaultdict(list)
    cluster_V_indices = defaultdict(list)
    for i, label in enumerate(labels_U):
        cluster_U_indices[label].append(i)
    for i, label in enumerate(labels_V):
        cluster_V_indices[label].append(i)

    return cluster_U_indices, cluster_V_indices

In [13]:
num_clusters = 2
cluster_U_indices, cluster_V_indices = spectral_clustering(model, num_clusters)

In [14]:
for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

Cluster 0 has 65 nodes in U and 59 nodes in V
Cluster 1 has 63 nodes in U and 69 nodes in V


## How good is a cluster?

Now, we need a (possibly information-theoretic) way to evaluate how good this clustering is.

In [15]:
def clustering_goodness_f1(model, cluster_U_indices, cluster_V_indices, num_clusters):
    A = model.fc2.weight.detach().cpu().numpy()
    # square
    A = np.square(A)
    
    # fraction of edges within the same cluster
    intra_cluster_out_sum = 0
    total_out_sum = 0
    for i in range(len(A)):
        for j in range(len(A)):
            total_out_sum += A[i, j]
            same_cluster = False
            for cluster_idx in range(num_clusters):
                if i in cluster_U_indices[cluster_idx] and j in cluster_V_indices[cluster_idx]:
                    same_cluster = True
                    break
            if same_cluster:
                intra_cluster_out_sum += A[i, j]
    
    intra_cluster_fraction = intra_cluster_out_sum / total_out_sum

    return round(intra_cluster_fraction, 3)

In [17]:
clusterability = clustering_goodness_f1(model, cluster_U_indices, cluster_V_indices, num_clusters)
print(f'The model has a clusterability of {clusterability} (out of 1) with {num_clusters} clusters.')

The model has a clusterability of 0.548 (out of 1) with 2 clusters.


This is somewhat okayish? This means the model is by default not very clusterable, which means there's good scope for improvement.

We don't really know if it is going to be possible to do this clustering, but a search on the algorithm space seems easy if we do fix this as our objective measure for clusterability.

### Some other measure of goodness
- information theoretic one (which is more annoying to optimize over)
- L2 norm for per-neuron cross-weights compared with those inside a cluster (could be easier to optimize over)
- the ability to search for human-interpretable circuits (the final goal, but really difficult to even formalize)

In [18]:
# try some more goodness measures

## Do Gradient Updates create Clusters?

The overall idea is this:

As circuits form, neurons (and parameters) that create a particular circuit during learning likely update together.

Thus, maybe parameters that update together should be clustered together. Let's explore how much gradients can tell us.

In [23]:
[len(cluster_U_indices[i]) for i in range(len(cluster_U_indices.keys()))] # clusters in U (the left side)

[65, 63]

In [24]:
len(grads), len(grads[30]), grads[0][0].shape

(4690, 1, torch.Size([128, 128]))

In [27]:
grads = {k: torch.stack(v) for k, v in grads.items()}

In [28]:
grads = {k: v[0] for k, v in grads.items()}

In [29]:
# normalize gradients
grads = {k: v / v.norm() for k, v in grads.items()}

In [33]:
# make one big tensor
grads_tensor = torch.stack(list(grads.values()))

In [34]:
grads_tensor.shape # (step, in_neuron, out_neuron)

torch.Size([4690, 128, 128])

In [42]:
# mean and std of gradients for each update step
grads_mean = grads_tensor.mean(dim=(1, 2))
grads_std = grads_tensor.std(dim=(1, 2))
grads_mean.shape, grads_mean[:5], grads_std.shape, grads_std[-5:]

(torch.Size([4690]),
 tensor([ 2.4728e-04,  1.3061e-04, -2.4917e-05, -1.3795e-05,  7.5134e-04],
        device='cuda:0'),
 torch.Size([4690]),
 tensor([0.0078, 0.0078, 0.0078, 0.0078, 0.0078], device='cuda:0'))

In [45]:
similarity_matrix = torch.zeros(grads_tensor.shape[1], grads_tensor.shape[1]).to(device)

for i in range(grads_tensor.shape[0]):
    similarity_matrix += torch.mm(grads_tensor[i].t(), grads_tensor[i])

In [52]:
similarity_matrix /= grads_tensor.shape[0]

similarity_matrix.shape

torch.Size([128, 128])

In [56]:
def gradient_based_clustering(similarity_matrix, num_clusters):
    # bipartite clustering
    A = similarity_matrix.detach().cpu().numpy()
    k = num_clusters

    D_U = np.diag(np.sum(A, axis=1))
    D_V = np.diag(np.sum(A, axis=0))

    D_U_inv_sqrt = np.linalg.inv(np.sqrt(D_U))
    D_V_inv_sqrt = np.linalg.inv(np.sqrt(D_V))

    A_tilde = D_U_inv_sqrt @ A @ D_V_inv_sqrt

    U, Sigma, Vt = svds(A_tilde, k=k)

    kmeans_U = KMeans(n_clusters=k, random_state=42).fit(U)
    kmeans_V = KMeans(n_clusters=k, random_state=42).fit(Vt.T)

    labels_U = kmeans_U.labels_
    labels_V = kmeans_V.labels_

    # convert labels to indices
    cluster_U_indices = defaultdict(list)
    cluster_V_indices = defaultdict(list)
    for i, label in enumerate(labels_U):
        cluster_U_indices[label].append(i)
    for i, label in enumerate(labels_V):
        cluster_V_indices[label].append(i)

    return cluster_U_indices, cluster_V_indices

In [57]:
cluster_U_indices, cluster_V_indices = gradient_based_clustering(similarity_matrix, num_clusters)

In [58]:
for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

Cluster 0 has 60 nodes in U and 60 nodes in V
Cluster 1 has 68 nodes in U and 68 nodes in V


In [59]:
clustering_goodness_f1(model, cluster_U_indices, cluster_V_indices, num_clusters)

0.509

Alright, this also doesn't give any good clusters based on our "weight interference goodness measure".

Now, let's explore some of these gradients!

In [60]:
grads_tensor.shape

torch.Size([4690, 128, 128])

In [63]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=grads_tensor[:, 0, 0].cpu().numpy(), mode='lines', name='Input Neuron 0', line=dict(color='darkred', width=2)))
fig.add_trace(go.Scatter(y=grads_tensor[:, 1, 1].cpu().numpy(), mode='lines', name='Output Neuron 1', line=dict(color='darkblue', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Optimization Step')
fig.update_yaxes(title_text='Normalized Gradient')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [78]:
grads_tensor_input_sum = grads_tensor.sum(dim=2)

grads_tensor_input_sum.shape

torch.Size([4690, 128])

In [102]:
rows, cols = 10, 4
fig = make_subplots(rows=rows, cols=cols, subplot_titles=[f'' for i in range(rows) for j in range(cols)])

colors = px.colors.qualitative.Plotly

for i in range(rows):
    for j in range(cols):
        fig.add_trace(go.Scatter(y=grads_tensor_input_sum[:, i + j].cpu().numpy(), mode='lines', name=f'neuron {i * cols + j}', line=dict(color=colors[(i * cols + j) % 10], width=2)), row=i+1, col=j+1)

fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
for i in range(1, rows + 1):
    for j in range(1, cols + 1):
        fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=i, col=j)
        fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=i, col=j)
fig.update_xaxes(title_text='step')
fig.update_yaxes(title_text='grad')
fig.update_layout(width=300 * cols, height=180 * rows, autosize=False)
fig.show()

Wow! So there are three main kinds of neurons:
- A. Those that learn a few discrete things and then die out?
- B. Those that update a lot at first and then seem to "settle" to something?
- C. Those that perpetually keep on updating