What if clusterability during training is actually simpler than we thought it was?

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

def randomseed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64, bias=False)
        self.fc2 = nn.Linear(64, 64, bias=False)
        self.fc3 = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
def accuracy(model, data):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct / total

def classwise_accuracy(model, data):
    model.eval()
    correct = defaultdict(int)
    total = defaultdict(int)
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            for i in range(len(labels)):
                label = labels[i].item()
                total[label] += 1
                correct[label] += int(predicted[i] == label)
    return [round(correct[i] / total[i], 3) if total[i] > 0 else 0 for i in range(10)]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # normalize the data
])

train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
test_dataset = MNIST(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

def cluster_goodness_fast(model, cluster_U_indices, cluster_V_indices, num_clusters):
    A = model.fc2.weight ** 2
    mask = torch.zeros_like(A, dtype=torch.bool)
    
    for cluster_idx in range(num_clusters):
        u_indices = torch.tensor(cluster_U_indices[cluster_idx], dtype=torch.long)
        v_indices = torch.tensor(cluster_V_indices[cluster_idx], dtype=torch.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = torch.sum(A[mask])
    total_out_sum = torch.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

In [2]:
model = MLP()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
ce_losses = []
cluster_losses = []
lomda = 15.0
epochs = 10

In [3]:
# equally divide fc2 into 4 clusters (64 / 4 = 16) as a dict

num_clusters = 4

cluster_U_indices = {i: list(range(i * 16, (i + 1) * 16)) for i in range(num_clusters)}
cluster_V_indices = {i: list(range(i * 16, (i + 1) * 16)) for i in range(num_clusters)}

In [4]:
randomseed(42)

for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss_ce = criterion(output, target)
        loss_cluster = cluster_goodness_fast(model, cluster_U_indices, cluster_V_indices, num_clusters)
        cluster_losses.append(loss_cluster.item())
        ce_losses.append(loss_ce.item())
        loss = loss_ce - lomda * loss_cluster
        loss.backward()
        optimizer.step()
        if batch_idx % 200 == 0:
            acc = accuracy(model, test_loader)
            print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, Clusterability: {loss_cluster.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/10, Batch 0/938, Loss: -1.4736, Clusterability: 0.2513, Accuracy: 0.1165


KeyboardInterrupt: 

In [3]:
path: str = 'checkpoints/'

In [None]:
Path(path).mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), path + 'model.pth')
# store the cluster indices
torch.save(cluster_U_indices, path + 'cluster_U_indices.pth')
torch.save(cluster_V_indices, path + 'cluster_V_indices.pth')

In [4]:
# load everything back
model = MLP()
model.load_state_dict(torch.load(path + 'model.pth'))
model.to(device)
cluster_U_indices = torch.load(path + 'cluster_U_indices.pth')
cluster_V_indices = torch.load(path + 'cluster_V_indices.pth')

In [5]:
# print classwise accuracy of the model
classwise_accuracy(model, test_loader)

[0.967, 0.965, 0.859, 0.884, 0.909, 0.816, 0.924, 0.908, 0.854, 0.875]

In [28]:
# get accuracy for each cluster turned off one by one
cluster_off_accuracies = []

for cluster_idx in range(num_clusters):
    model = MLP()
    model.load_state_dict(torch.load(path + 'model.pth'))
    model.to(device)
    
    # turn off the cluster
    for i in cluster_U_indices[cluster_idx]:
        model.fc2.weight.data[i] = 0
    for i in cluster_V_indices[cluster_idx]:
        model.fc2.weight.data[:, i] = 0
    
    acc = accuracy(model, test_loader)
    cluster_off_accuracies.append(acc)

    print(f'Cluster {cluster_idx} turned off, Accuracy: {acc:.4f}')

model.load_state_dict(torch.load(path + 'model.pth'))

Cluster 0 turned off, Accuracy: 0.7925
Cluster 1 turned off, Accuracy: 0.6091
Cluster 2 turned off, Accuracy: 0.7872
Cluster 3 turned off, Accuracy: 0.7198


<All keys matched successfully>

In [29]:
# get accuracy for each cluster turned on one by one
cluster_on_accuracies = []

for cluster_idx in range(num_clusters):
    model = MLP()
    model.load_state_dict(torch.load(path + 'model.pth'))
    model.to(device)
    
    # turn off every cluster except the current one
    for i in range(num_clusters):
        if i != cluster_idx:
            for j in cluster_U_indices[i]:
                model.fc2.weight.data[j] = 0
            for j in cluster_V_indices[i]:
                model.fc2.weight.data[:, j] = 0

    acc = accuracy(model, test_loader)
    cluster_on_accuracies.append(acc)

    print(f'Cluster {cluster_idx} turned on, Accuracy: {acc:.4f}')

model.load_state_dict(torch.load(path + 'model.pth'))

Cluster 0 turned on, Accuracy: 0.3264
Cluster 1 turned on, Accuracy: 0.2674
Cluster 2 turned on, Accuracy: 0.4802
Cluster 3 turned on, Accuracy: 0.2272


<All keys matched successfully>

In [46]:
# plot the class-wise accuracies for each cluster turned off

classwise_accuracies = []

for cluster_idx in tqdm.trange(num_clusters):
    model = MLP()
    model.load_state_dict(torch.load(path + 'model.pth'))
    model.to(device)
    
    # turn off the cluster
    for i in cluster_U_indices[cluster_idx]:
        model.fc2.weight.data[i] = 0
    for i in cluster_V_indices[cluster_idx]:
        model.fc2.weight.data[:, i] = 0
    
    classwise_accuracies.append(classwise_accuracy(model, test_loader))

100%|██████████| 4/4 [00:05<00:00,  1.40s/it]


In [47]:
model.load_state_dict(torch.load(path + 'model.pth'))

<All keys matched successfully>

In [48]:
color_scale = pc.qualitative.G10  # You can use other scales like `pc.sequential.Plasma` or `pc.sequential.Viridis`
colors = color_scale[:num_clusters]  # Ensure we have enough colors for the number of clusters

# Create subplots: one row per cluster
fig = make_subplots(rows=num_clusters, cols=1, shared_xaxes=False, 
                    subplot_titles=[f'Cluster {i}' for i in range(num_clusters)])

for i in range(num_clusters):
    fig.add_trace(go.Bar(
        x=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
        y=classwise_accuracies[i],
        marker_color=colors[i],
        name=f'Cluster {i}',
    ), row=i+1, col=1)

# white background
fig.update_layout(plot_bgcolor='white')

# gridlines
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

fig.update_layout(height=150 * num_clusters, width=700, title_text='Class-wise accuracy with each cluster turned OFF.')

fig.show()

In [40]:
# plot the class-wise accuracies for each cluster turned on

classwise_accuracies = []

for cluster_idx in tqdm.trange(num_clusters):
    model = MLP()
    model.load_state_dict(torch.load(path + 'model.pth'))
    model.to(device)
    
    # turn off every cluster except the current one
    for i in range(num_clusters):
        if i != cluster_idx:
            for j in cluster_U_indices[i]:
                model.fc2.weight.data[j] = 0
            for j in cluster_V_indices[i]:
                model.fc2.weight.data[:, j] = 0
    
    classwise_accuracies.append(classwise_accuracy(model, test_loader))

model.load_state_dict(torch.load(path + 'model.pth'))

100%|██████████| 4/4 [00:05<00:00,  1.42s/it]


<All keys matched successfully>

In [43]:
# Create subplots: one row per cluster

fig = make_subplots(rows=num_clusters, cols=1, shared_xaxes=False,
                    subplot_titles=[f'Cluster {i}' for i in range(num_clusters)])

for i in range(num_clusters):
    fig.add_trace(go.Bar(
        x=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
        y=classwise_accuracies[i],
        marker_color=colors[i],
        name=f'Cluster {i}',
    ), row=i+1, col=1)

# white background
fig.update_layout(plot_bgcolor='white')

# gridlines

fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

fig.update_layout(height=150 * num_clusters, width=700, title_text='Class-wise accuracy with each cluster turned ON.')

fig.show()

In [45]:
classwise_accuracy(model, test_loader)

[0.967, 0.965, 0.859, 0.884, 0.909, 0.816, 0.924, 0.908, 0.854, 0.875]

In [11]:
def visualize_layer(layer):
    weights = layer.weight.detach().cpu().numpy() 
    num_outputs, num_inputs = weights.shape

    colors = pc.qualitative.G10

    node_x = []
    node_y = []
    node_text = []
    edge_x = []
    edge_y = []
    edge_color = []
    edge_width = []

    input_layer_positions = np.linspace(-1, 1, num_inputs)
    output_layer_positions = np.linspace(-1, 1, num_outputs)
    
    for i in range(num_inputs):
        node_x.append(input_layer_positions[i])
        node_y.append(-1)
        node_text.append(f"{i+1}")

    for j in range(num_outputs):
        node_x.append(output_layer_positions[j])
        node_y.append(1)
        node_text.append(f"{j+1}")

    # Add edges
    for i in range(num_inputs):
        for j in range(num_outputs):
            weight = weights[j, i]
            color = 'red' if weight < 0 else 'blue'
            # color based on cluster index
            
            thickness = abs(weight) * 3
            edge_x.extend([input_layer_positions[i], output_layer_positions[j]])
            edge_y.extend([-1, 1])
            edge_color.append(color)
            edge_width.append(thickness)

    fig = go.Figure()
    
    fig.add_trace(go.Scatter(
        x=node_x, 
        y=node_y,
        mode='markers+text',
        text=node_text,
        textposition='top center',
        marker=dict(size=10, color='lightgray'),
        showlegend=False
    ))
    
    for i in range(0, len(edge_x), 2):
        fig.add_trace(go.Scatter(
            x=[edge_x[i], edge_x[i+1]],
            y=[edge_y[i], edge_y[i+1]],
            mode='lines',
            line=dict(color=edge_color[i//2], width=edge_width[i//2]),
            showlegend=False
        ))

    # Update layout
    fig.update_layout(
        title="fc2 Layer Visualization",
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False, range=[-1.5, 1.5]),
        showlegend=False
    )

    # white background
    fig.update_layout(plot_bgcolor='white')
    # square aspect ratio
    fig.update_layout(
        autosize=False,
        width=1400,
        height=700,
    )

    fig.show()

In [13]:
visualize_layer(model.fc3)

In [14]:
visualize_layer(model.fc2)

Alright. Now how can we ACTUALLY show that these clusters make the model more modular?

## Max-Activating Datapoints

As a measure of aggregated interpretability of a model / layer / cluster.

In [10]:
def max_activating_datapoints(k, model, data_loader, layer, neuron_index):
    # get the activations of the layer for all the data points
    activations = []
    model.eval()
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            flatten_images = images.view(-1, 28 * 28)
            fc_1_out = torch.relu(model.fc1(flatten_images))
            fc_2_out = torch.relu(model.fc2(fc_1_out))
            fc_3_out = model.fc3(fc_2_out)
            if layer == 1:
                activations.extend(fc_1_out[:, neuron_index].cpu().numpy())
            elif layer == 2:
                activations.extend(fc_2_out[:, neuron_index].cpu().numpy())
            elif layer == 3:
                activations.extend(fc_3_out[:, neuron_index].cpu().numpy())

    activations = np.array(activations)
    indices = np.argsort(activations)[-k:]
    # print(indices)
    # get the images corresponding to the top k activations
    images = []
    labels = []
    for i in indices:
        images.append(data_loader.dataset[i][0])
        labels.append(data_loader.dataset[i][1])
    return images, labels

In [12]:
for neuron_index in range(len(model.fc2.weight)):
    _, labels = max_activating_datapoints(20, model, test_loader, 2, neuron_index)
    print(f"Neuron {neuron_index} is most activated by the following digits: {labels}")

Neuron 0 is most activated by the following digits: [0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Neuron 1 is most activated by the following digits: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Neuron 2 is most activated by the following digits: [4, 6, 0, 0, 0, 6, 0, 0, 6, 2, 0, 0, 6, 0, 0, 6, 6, 6, 0, 0]
Neuron 3 is most activated by the following digits: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
Neuron 4 is most activated by the following digits: [5, 3, 2, 1, 2, 1, 1, 2, 1, 3, 3, 1, 1, 1, 3, 3, 2, 2, 2, 3]
Neuron 5 is most activated by the following digits: [9, 9, 8, 2, 8, 4, 8, 8, 2, 9, 8, 9, 9, 2, 9, 9, 2, 8, 8, 8]
Neuron 6 is most activated by the following digits: [8, 6, 6, 6, 6, 6, 6, 6, 2, 2, 2, 6, 6, 6, 6, 6, 2, 6, 6, 2]
Neuron 7 is most activated by the following digits: [0, 0, 5, 2, 2, 0, 0, 0, 0, 5, 8, 8, 0, 0, 0, 0, 0, 0, 8, 0]
Neuron 8 is most activated by the following digits: [6, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [33]:
def rough_sense_of_interpretability_of_an_mlp_layer(model, test_loader, layer_id, k):
    neuron_interpretability = []
    interpretable_neurons = [0 for _ in range(10)]
    for neuron_index in tqdm.trange(len(model.fc2.weight)):
        _, labels = max_activating_datapoints(k, model, test_loader, layer_id, neuron_index)
        # a rough sense of interpretability is the number of unique classes in the top k activations
        num_unique_classes = len(set(labels))
        neuron_interpretability.append(num_unique_classes / 10)
        interpretable_neurons[num_unique_classes - 1] += 1
    
    return sum(neuron_interpretability) / len(neuron_interpretability), interpretable_neurons

In [34]:
rsoil, dist = rough_sense_of_interpretability_of_an_mlp_layer(model, test_loader, 2, 50)

100%|██████████| 64/64 [01:14<00:00,  1.17s/it]


In [35]:
print(f"Rough sense of interpretability of fc2 layer of a clustered model: {1 - rsoil:.4f}")

Rough sense of interpretability of fc2 layer of a clustered model: 0.6172


This seems pretty interpretable. Now let's do the same for a layer that is trained without the clusterability loss.

In [15]:
mlp_unclustered = MLP()
mlp_unclustered.to(device)

# Train the model without clustering
optimizer = optim.SGD(mlp_unclustered.parameters(), lr=1e-3)
ce_losses_unclustered = []
epochs = 10

randomseed(42)

criterion = nn.CrossEntropyLoss()

In [16]:
for epoch in range(epochs):
    mlp_unclustered.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = mlp_unclustered(data)
        loss = criterion(output, target)
        ce_losses_unclustered.append(loss.item())
        loss.backward()
        optimizer.step()
        if batch_idx % 200 == 0:
            acc = accuracy(mlp_unclustered, test_loader)
            print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/10, Batch 0/938, Loss: 2.3010, Accuracy: 0.0915
Epoch 1/10, Batch 200/938, Loss: 2.2730, Accuracy: 0.1896
Epoch 1/10, Batch 400/938, Loss: 2.2251, Accuracy: 0.3262
Epoch 1/10, Batch 600/938, Loss: 2.1291, Accuracy: 0.4512
Epoch 1/10, Batch 800/938, Loss: 1.9879, Accuracy: 0.5176
Epoch 2/10, Batch 0/938, Loss: 1.9585, Accuracy: 0.5661
Epoch 2/10, Batch 200/938, Loss: 1.8125, Accuracy: 0.6224
Epoch 2/10, Batch 400/938, Loss: 1.6536, Accuracy: 0.6647
Epoch 2/10, Batch 600/938, Loss: 1.3344, Accuracy: 0.6977
Epoch 2/10, Batch 800/938, Loss: 1.1758, Accuracy: 0.7333
Epoch 3/10, Batch 0/938, Loss: 1.1345, Accuracy: 0.7504
Epoch 3/10, Batch 200/938, Loss: 1.1083, Accuracy: 0.7760
Epoch 3/10, Batch 400/938, Loss: 0.9748, Accuracy: 0.7985
Epoch 3/10, Batch 600/938, Loss: 0.9295, Accuracy: 0.8118
Epoch 3/10, Batch 800/938, Loss: 0.7305, Accuracy: 0.8190
Epoch 4/10, Batch 0/938, Loss: 0.7625, Accuracy: 0.8261
Epoch 4/10, Batch 200/938, Loss: 0.6998, Accuracy: 0.8352
Epoch 4/10, Batch 400/

Wow, this is crazy! We might not even have a Pareto frontier as we had imagined, since with the same parameters without the clusterablity loss, the model seems to be performing worse (only 90% accuracy compared to 93%?). Anyway it wasn't worse.

In [36]:
rsoil_unclustered, dist_unclustered = rough_sense_of_interpretability_of_an_mlp_layer(mlp_unclustered, test_loader, 2, 50)

100%|██████████| 64/64 [01:18<00:00,  1.22s/it]


In [37]:
print(f"Rough sense of interpretability of fc2 layer of an unclustered model: {1 - rsoil_unclustered:.4f}")

Rough sense of interpretability of fc2 layer of an unclustered model: 0.6312


In [38]:
dist_unclustered, dist

([13, 7, 7, 20, 10, 3, 0, 0, 0, 4], [5, 19, 10, 12, 8, 4, 0, 0, 0, 6])

Hmm.. so based on our rough sense of interpreatablity of an MLP layer, it looks like 

In [24]:
for neuron_index in range(len(mlp_unclustered.fc2.weight)):
    _, labels = max_activating_datapoints(20, mlp_unclustered, test_loader, 2, neuron_index)
    print(f"Neuron {neuron_index} is most activated by the following digits: {labels}")

Neuron 0 is most activated by the following digits: [6, 6, 6, 6, 6, 5, 6, 6, 6, 6, 6, 6, 3, 6, 6, 6, 6, 6, 5, 6]
Neuron 1 is most activated by the following digits: [0, 0, 0, 0, 0, 0, 0, 6, 0, 4, 0, 0, 0, 2, 0, 0, 0, 6, 0, 0]
Neuron 2 is most activated by the following digits: [3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
Neuron 3 is most activated by the following digits: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
Neuron 4 is most activated by the following digits: [8, 8, 8, 8, 5, 8, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
Neuron 5 is most activated by the following digits: [0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0]
Neuron 6 is most activated by the following digits: [0, 2, 2, 2, 2, 2, 2, 6, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
Neuron 7 is most activated by the following digits: [7, 7, 7, 7, 7, 9, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
Neuron 8 is most activated by the following digits: [2, 4, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2

In [43]:
# aggregate sum of dist till index i
dist_agg = [0] * 10
for i in range(10):
    dist_agg[i] = sum(dist[i:])

dist_unclustered_agg = [0] * 10
for i in range(10):
    dist_unclustered_agg[i] = sum(dist_unclustered[i:])

In [44]:
# plot the distribution of the number of unique classes in the top k activations for each neuron in the fc2 layer of the clustered model against the unclustered model

fig = go.Figure()

# lines for the clustered model, not histogram
fig.add_trace(go.Scatter(
    x=list(range(10)),
    y=dist_agg,
    mode='lines',
    name='Clustered model',
    line=dict(color='darkblue', width=2)
))

# lines for the unclustered model, not histogram
fig.add_trace(go.Scatter(
    x=list(range(10)),
    y=dist_unclustered_agg,
    mode='lines',
    name='Unclustered model',
    line=dict(color='darkred', width=2)
))

fig.update_layout(
    title_text='Distribution of the number of unique classes in the top k activations for each neuron in the fc2 layer',
    xaxis_title_text='Number of unique classes',
    yaxis_title_text='Number of neurons',
    barmode='overlay'
)

fig.update_traces(marker_line_width=0)

# white background
fig.update_layout(plot_bgcolor='white')

# gridlines

fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

fig.show()

## Circuit Search-Space Pruning

We look at the reduction (pruning) of the search space of circuit discovery as a measure for interpretability of a given behavior. On our simple example:

In [6]:
total_params = sum(p.numel() for p in model.parameters())

total_params

54912

In [7]:
784 * 64 + 64 * 64 + 64 * 10

54912

In [8]:
trimmed_params = 784 * 16 + 16 * 16 + 16 * 10

print(f'Percentage pruned: {(total_params - trimmed_params) / total_params * 100:.2f}%')

Percentage pruned: 76.40%


In [9]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=64, bias=False)
  (fc2): Linear(in_features=64, out_features=64, bias=False)
  (fc3): Linear(in_features=64, out_features=10, bias=False)
)

While this metric is well-expected to give is gains, it really does not form that good of an argument in our framework. Here's a better attempt.

## Circuit Complexity Reduction (CCR)

The reduction in the expected circuit size for a given behavior. Since we restrict circuit-forming to use fixed clusters, we can likely expect simpler circuits? For this we need to implement a circuit discovery method. For now, we'll just use accuracy-based pruning. We can shift to others (ACDC/EAP) later on.

In [12]:
# Evaluate the model
def evaluate_model(model, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data.view(-1, 784))
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total

In [33]:
def evaluate_model_for_label(model, test_loader, label):
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # filter out the data points with the given label
            mask = target == label
            data = data[mask]
            target = target[mask]
            output = model(data.view(-1, 784))
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total if total > 0 else 0

In [119]:
def fast_label_loss(model, x, label):
    with torch.no_grad():
        output = model(x)
        criterion = nn.CrossEntropyLoss()
        target = torch.tensor([label] * x.size(0)).to(device)
        loss = criterion(output, target)
    return loss

Step 1: Get the label dataset.

In [74]:
label = 8

label_data = torch.stack([test_dataset[i][0] for i in range(len(test_dataset)) if test_dataset[i][1] == label])

label_data = label_data.to(device)

label_data.shape

torch.Size([974, 1, 28, 28])

In [122]:
fast_label_loss(model, label_data, label)

tensor(0.4604, device='cuda:0')

In [111]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=64, bias=False)
  (fc2): Linear(in_features=64, out_features=64, bias=False)
  (fc3): Linear(in_features=64, out_features=10, bias=False)
)

Start with a clone of the original model.

In [138]:
pruned_model = MLP()
pruned_model.load_state_dict(model.state_dict())
pruned_model.to(device)

curr_loss = fast_label_loss(pruned_model, label_data, label)
pruned = 0
not_pruned = 0
threshold = 0.0001

for i in tqdm.trange(model.fc3.weight.size(0)):
    for j in range(model.fc3.weight.size(1)):
        temp = pruned_model.fc3.weight.data[i, j]
        pruned_model.fc3.weight.data[i, j] = 0
        loss = fast_label_loss(pruned_model, label_data, label)
        if loss - curr_loss < threshold:
            pruned += 1
            curr_loss = loss
        else:
            pruned_model.fc3.weight.data[i, j] = temp
            not_pruned += 1

print(f'Pruned: {pruned}, Not pruned: {not_pruned}')

100%|██████████| 10/10 [00:00<00:00, 40.36it/s]

Pruned: 84, Not pruned: 556





In [139]:
curr_loss

tensor(0.1098, device='cuda:0')

In [133]:
classwise_accuracy(pruned_model, test_loader)

[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

In [103]:
classwise_accuracy(model, test_loader)

[0.967, 0.965, 0.859, 0.884, 0.909, 0.816, 0.924, 0.908, 0.854, 0.875]

In [105]:
pruned, not_pruned

(54714, 198)

## Dissemination

Here's one idea to share this:

We create a new PyTorch model training optimizer that optimizes for clusterability on `nn.Module` and inherits from `torch.optim.Optimizer`.


> torch.optim.Modularity(n_clusters=4, alpha=20)

This should be doable by the end of MATS. 

Can we try finishing this up and writing and submitting it to the NeurIPS workshop on Science of Deep Learning?