This notebook contains code accompanying our submission to NeurIPS 2024 Science of Deep Learning Workshop.

## 1. Setup 
(run, don't read!)

_Figure 1_

<img src="fig1-modularity.png" width="400">

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

from sklearn.cluster import KMeans
from scipy.sparse.linalg import svds

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

def randomseed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

### (a) Data

In [2]:
transform = torchvision.transforms.ToTensor()
train_dataset = CIFAR10(root='.', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 72.3MB/s] 


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


### (b) Models

In [18]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(16 * 16 * 16, 64, bias=False)
        self.fc2 = nn.Linear(64, 64, bias=False)
        self.fc3 = nn.Linear(64, 64, bias=False)
        self.fc4 = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 16 * 16 * 16)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x

### (c) Evaluation

In [19]:
def accuracy(model, data):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct / total

def classwise_accuracy(model, data):
    model.eval()
    correct = defaultdict(int)
    total = defaultdict(int)
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            for i in range(len(labels)):
                label = labels[i].item()
                total[label] += 1
                correct[label] += int(predicted[i] == label)
    return [round(correct[i] / total[i], 3) if total[i] > 0 else 0 for i in range(10)]

def clusterability(model, cluster_U_indices, cluster_V_indices):
    num_clusters = len(cluster_U_indices)
    A = model.fc2.weight ** 2
    mask = torch.zeros_like(A, dtype=torch.bool)
    
    for cluster_idx in range(num_clusters):
        u_indices = torch.tensor(cluster_U_indices[cluster_idx], dtype=torch.long)
        v_indices = torch.tensor(cluster_V_indices[cluster_idx], dtype=torch.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = torch.sum(A[mask])
    total_out_sum = torch.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

def clusterability_component(component, num_clusters=4):
    A = component ** 2
    rows, cols = A.shape
    
    # Create contiguous clusters of equal size
    row_cluster_size = rows // num_clusters
    col_cluster_size = cols // num_clusters
    
    cluster_U_indices = {}
    cluster_V_indices = {}
    
    for i in range(num_clusters):
        # Handle the last cluster which might have extra elements
        if i == num_clusters - 1:
            cluster_U_indices[i] = list(range(i * row_cluster_size, rows))
            cluster_V_indices[i] = list(range(i * col_cluster_size, cols))
        else:
            cluster_U_indices[i] = list(range(i * row_cluster_size, (i + 1) * row_cluster_size))
            cluster_V_indices[i] = list(range(i * col_cluster_size, (i + 1) * col_cluster_size))
    
    # Create mask for intra-cluster connections
    mask = torch.zeros_like(A, dtype=torch.bool)
    
    for cluster_idx in range(num_clusters):
        u_indices = torch.tensor(cluster_U_indices[cluster_idx], dtype=torch.long)
        v_indices = torch.tensor(cluster_V_indices[cluster_idx], dtype=torch.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = torch.sum(A[mask])
    total_out_sum = torch.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

## BSGC (Check if NNs trained by CE are not modular by default)

In [20]:
def bipartite_spectral_clustering(similarity_matrix, k, cluster_U_indices=None, cluster_V_indices=None):
    
    A = similarity_matrix.detach().cpu().numpy() # transform from gpu to cpu, and then to numpy
    A = np.abs(A)
    D_U = np.diag(np.sum(A, axis=1))
    D_V = np.diag(np.sum(A, axis=0))

    D_U_inv_sqrt = np.linalg.inv(np.sqrt(D_U))
    D_V_inv_sqrt = np.linalg.inv(np.sqrt(D_V))

    A_tilde = D_U_inv_sqrt @ A @ D_V_inv_sqrt

    U, Sigma, Vt = svds(A_tilde, k=k)

    if cluster_U_indices is None:
        kmeans_U = KMeans(n_clusters=k, random_state=42).fit(U)
        labels_U = kmeans_U.labels_

        cluster_U_indices = defaultdict(list)
        for i, label in enumerate(labels_U):
            cluster_U_indices[label].append(i)

    if cluster_V_indices is None:
        kmeans_V = KMeans(n_clusters=k, random_state=42).fit(Vt.T)
        labels_V = kmeans_V.labels_

        cluster_V_indices = defaultdict(list)
        for i, label in enumerate(labels_V):
            cluster_V_indices[label].append(i)

    return cluster_U_indices, cluster_V_indices

In [21]:
def total_clusterability(model, grads, subset_of_layers, num_clusters):

    nr_layers = len(subset_of_layers)
    clusterability_scores = np.zeros((nr_layers, 2))

    cluster_next_U_indices = [None, None]
    for l, layer_nr in enumerate(subset_of_layers):
        layer_grads = torch.stack(list({k: v[0] / v[0].norm() for k, v in grads[layer_nr].items()}.values()))
        grads_similarity_matrix = torch.zeros(layer_grads.shape[1], layer_grads.shape[1]).to(device)
        for i in range(layer_grads.shape[0]):
            grads_similarity_matrix += torch.mm(layer_grads[i].t(), layer_grads[i])

        similarity_matrices = [getattr(model, f"fc{layer_nr}").weight, grads_similarity_matrix]

        for i, similarity_matrix in enumerate(similarity_matrices):
            cluster_U_indices, cluster_V_indices = bipartite_spectral_clustering(similarity_matrix, num_clusters, cluster_next_U_indices[i])
            cl = clusterability(model, cluster_U_indices, cluster_V_indices)
            clusterability_scores[l, i] = cl.item()
            cluster_next_U_indices[i] = cluster_V_indices

    return clusterability_scores

## 2. Training the Unclustered Model

In [22]:
dataset = 'CIFAR10'

In [23]:
subset_of_layers = [2,3]

unclustered_model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(unclustered_model.parameters(), lr=1e-3)
train_losses = []
grads = {l: defaultdict(list) for l in subset_of_layers}

In [24]:
path = Path(f'results/{dataset}/')
path.mkdir(parents=True, exist_ok=True)

In [None]:
randomseed(42)
path = Path(f'results/{dataset}/')

for epoch in range(10):
    unclustered_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = unclustered_model(data)
        loss = criterion(output, target)
        loss.backward()
        # store gradients that are in the subset of layers
        for l in subset_of_layers:
            for p in unclustered_model.named_parameters():
                if p[0] == f'fc{l}.weight':
                    grads[l][batch_idx + (epoch * len(train_loader))].append(p[1].grad.clone())
                    break
        optimizer.step()
        train_losses.append(loss.item())
    acc = accuracy(unclustered_model, test_loader)
    print(f'Epoch {epoch+1}/{10}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}')
    # save model
    torch.save(unclustered_model.state_dict(), path / 'unclustered_model.pth')

Epoch 1/10, Loss: 1.5023, Accuracy: 0.4658
Epoch 2/10, Loss: 1.4825, Accuracy: 0.5414
Epoch 3/10, Loss: 0.7785, Accuracy: 0.5897
Epoch 4/10, Loss: 1.0224, Accuracy: 0.6220
Epoch 5/10, Loss: 1.3159, Accuracy: 0.6333
Epoch 6/10, Loss: 0.7811, Accuracy: 0.6320
Epoch 7/10, Loss: 0.8052, Accuracy: 0.6444
Epoch 8/10, Loss: 1.1430, Accuracy: 0.6538
Epoch 9/10, Loss: 1.0727, Accuracy: 0.6573
Epoch 10/10, Loss: 0.4991, Accuracy: 0.6537


In [26]:
path = Path(f'results/{dataset}/')

In [29]:
unclustered_model = CNN().to(device)
unclustered_model.load_state_dict(torch.load(path/'unclustered_model.pth', weights_only=True))

<All keys matched successfully>

## Clusterability of trained model

In [31]:
cluster_sizes = [2, 4, 6, 8, 10, 12, 14]
clusterability_scores = []
for num_clusters in cluster_sizes:
    clusterability_scores.append(total_clusterability(unclustered_model, grads, subset_of_layers, num_clusters))

clusterability_scores = np.array(clusterability_scores)

In [32]:
print(clusterability_scores.shape)

(7, 2, 2)


In [48]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=cluster_sizes, y=clusterability_scores.mean(axis=1)[:, 0], mode='lines+markers', name='Weight-based BSGC', line=dict(color='#CC0000')))
fig.add_trace(go.Scatter(x=cluster_sizes, y=clusterability_scores.mean(axis=1)[:, 1], mode='lines+markers', name='Gradient-based BSGC', line=dict(color='#0000CC')))

# Add random baseline (y = 1/x)
random_baseline = [1/x for x in cluster_sizes]
fig.add_trace(go.Scatter(x=cluster_sizes, y=random_baseline, mode='lines', name='Random Baseline', 
                         line=dict(color='black', width=3, dash='dash')))

fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    legend=dict(
        x=0.95,
        y=0.95,
        xanchor='right',
        yanchor='top',
        bgcolor='rgba(255, 255, 255, 0.8)',
        bordercolor='black',
        borderwidth=1
    )
)
fig.update_xaxes(tickvals=cluster_sizes)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
fig.update_xaxes(title_text='Number of Clusters (k)')
fig.update_yaxes(title_text='Clusterability')
fig.update_layout(width=500, height=400, autosize=False)
fig.update_yaxes(range=[0, 0.7])
fig.update_xaxes(title_font=dict(family='serif', size=20, color='black'))
fig.update_yaxes(title_font=dict(family='serif', size=20, color='black'))
fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

fig.show()

In [49]:
figures_path = Path('figures')

if not figures_path.exists():
    figures_path.mkdir()

In [50]:
fig.write_image(str(figures_path / 'clusterability_scores.pdf'))

This doesn't seem to be good enough for interpretability.

## 3. Optimizing for Modularity

In [41]:
model = CNN().to(device)
randomseed(42)

In [43]:
num_clusters = 4
similarity_matrix = model.fc2.weight
input_size = similarity_matrix.shape[1]
output_size = similarity_matrix.shape[0]

nodes_per_cluster_U = input_size // num_clusters
nodes_per_cluster_V = output_size // num_clusters

cluster_U_indices = {}
cluster_V_indices = {}

for i in range(num_clusters):
    start_idx_U = i * nodes_per_cluster_U
    end_idx_U = start_idx_U + nodes_per_cluster_U if i < num_clusters - 1 else input_size
    cluster_U_indices[i] = list(range(start_idx_U, end_idx_U))
    
    start_idx_V = i * nodes_per_cluster_V
    end_idx_V = start_idx_V + nodes_per_cluster_V if i < num_clusters - 1 else output_size
    cluster_V_indices[i] = list(range(start_idx_V, end_idx_V))

for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

Cluster 0 has 16 nodes in U and 16 nodes in V
Cluster 1 has 16 nodes in U and 16 nodes in V
Cluster 2 has 16 nodes in U and 16 nodes in V
Cluster 3 has 16 nodes in U and 16 nodes in V


In [45]:
clusterability_score = clusterability(model, cluster_U_indices, cluster_V_indices)
print(f'Clusterability score: {round(clusterability_score.item(), 3)}')

Clusterability score: 0.239


### (c) Train the model with Clusterability

In [47]:
cluster_losses = []
ce_losses = []
lomda = 20.0
epochs = 10
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

randomseed(42)

for epoch in range(epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss_ce = criterion(output, target)
        loss_cluster = clusterability(model, cluster_U_indices, cluster_V_indices)
        cluster_losses.append(loss_cluster.item())
        ce_losses.append(loss_ce.item())
        loss = loss_ce - lomda * loss_cluster
        loss.backward()
        optimizer.step()
    acc = accuracy(model, test_loader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Clusterability: {loss_cluster.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/10, Loss: -18.3475, Clusterability: 0.9999, Accuracy: 0.4873
Epoch 2/10, Loss: -18.9488, Clusterability: 0.9999, Accuracy: 0.5437
Epoch 3/10, Loss: -18.9767, Clusterability: 0.9999, Accuracy: 0.5910
Epoch 4/10, Loss: -19.0499, Clusterability: 0.9998, Accuracy: 0.6291
Epoch 5/10, Loss: -18.7698, Clusterability: 0.9998, Accuracy: 0.6387
Epoch 6/10, Loss: -18.8655, Clusterability: 0.9998, Accuracy: 0.6344
Epoch 7/10, Loss: -19.3721, Clusterability: 0.9998, Accuracy: 0.6448
Epoch 8/10, Loss: -19.0225, Clusterability: 0.9998, Accuracy: 0.6604
Epoch 9/10, Loss: -18.9060, Clusterability: 0.9998, Accuracy: 0.6519
Epoch 10/10, Loss: -19.3666, Clusterability: 0.9999, Accuracy: 0.6604


### (d) Store our clusters and the clustered model

In [51]:
path: str = f'results/{dataset}/'
Path(path).mkdir(parents=True, exist_ok=True)

torch.save(model.state_dict(), f'{path}model.pth')
torch.save(cluster_U_indices, f'{path}cluster_U_indices.pth')
torch.save(cluster_V_indices, f'{path}cluster_V_indices.pth')

## 4. Interpreting the Clusters

In [53]:
path: str = f'results/{dataset}/'

# load the model and cluster indices
model = CNN().to(device)
model.load_state_dict(torch.load(f'{path}model.pth', weights_only=True))
cluster_U_indices = torch.load(f'{path}cluster_U_indices.pth', weights_only=True)
cluster_V_indices = torch.load(f'{path}cluster_V_indices.pth', weights_only=True)

In [54]:
classwise_accuracies = classwise_accuracy(model, test_loader)
classwise_accuracies

[0.755, 0.794, 0.443, 0.505, 0.635, 0.551, 0.797, 0.697, 0.699, 0.728]

In [55]:
num_clusters = len(cluster_U_indices)
num_clusters

4

### (a) Classwise accuracies with individual clusters turned ON and OFF

In [58]:
num_clusters = len(cluster_U_indices)

classwise_accuracies_off = []
classwise_accuracies_on = []

for cluster_idx in tqdm.trange(num_clusters):
    model = CNN().to(device)
    model.load_state_dict(torch.load(path + 'model.pth', weights_only=True))
    
    for i in cluster_U_indices[cluster_idx]:
        model.fc2.weight.data[i] = 0
    for i in cluster_V_indices[cluster_idx]:
        model.fc2.weight.data[:, i] = 0
    
    classwise_accuracies_off.append(classwise_accuracy(model, test_loader))

model.load_state_dict(torch.load(path + 'model.pth', weights_only=True))

classwise_accuracies = []

for cluster_idx in tqdm.trange(num_clusters):
    model = CNN().to(device)
    model.load_state_dict(torch.load(path + 'model.pth', weights_only=True))
    
    for i in range(num_clusters):
        if i != cluster_idx:
            for j in cluster_U_indices[i]:
                model.fc2.weight.data[j] = 0
            for j in cluster_V_indices[i]:
                model.fc2.weight.data[:, j] = 0
    
    classwise_accuracies_on.append(classwise_accuracy(model, test_loader))

model.load_state_dict(torch.load(path + 'model.pth', weights_only=True))

100%|██████████| 4/4 [00:03<00:00,  1.10it/s]
100%|██████████| 4/4 [00:03<00:00,  1.10it/s]


<All keys matched successfully>

In [59]:
cifar_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [61]:
def plot_classwise_cluster_perf(num_clusters, classwise_accuracies_on, classwise_accuracies_off, cifar_labels=cifar_labels):
    colors = [pc.qualitative.Dark24_r[0], pc.qualitative.Dark24_r[1], pc.qualitative.Dark24_r[15], pc.qualitative.Dark24_r[5]]
    fig = make_subplots(rows=2, cols=4, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.1, subplot_titles=[f'Cluster {i} (ON)' for i in range(num_clusters)] + [f'Cluster {i} (OFF)' for i in range(num_clusters)], row_heights=[0.15, 0.15])
    
    for i in range(num_clusters):
        fig.add_trace(go.Bar(x=cifar_labels, y=classwise_accuracies_on[i], marker_color=colors[i], name=f'Cluster {i} (ON)'), row=1, col=i+1)
        fig.update_xaxes(tickangle=-90)
        fig.add_trace(go.Bar(x=cifar_labels, y=classwise_accuracies_off[i], marker_color=colors[i], name=f'Cluster {i} (OFF)'), row=2, col=i+1)
    
    fig.update_layout(plot_bgcolor='white')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
    fig.update_layout(height=140 * num_clusters, width=1000)
    fig.update_layout(showlegend=True)
    fig.add_annotation(text="Accuracy", xref="paper", yref="paper", x=-0.1, y=0.5, showarrow=False, font=dict(size=22, family="Computer Modern"), align="center", textangle=-90)
    fig.add_annotation(text="Class", xref="paper", yref="paper", x=0.5, y=-0.35, showarrow=False, font=dict(size=22, family="Computer Modern"), align="center")
    fig.update_layout(font_family="Computer Modern", font_size=20)
    fig.update_yaxes(range=[0, 1], tickvals=np.arange(0, 1, 0.2))
    fig.update_layout(showlegend=False)
    fig.update_xaxes(tickvals=np.arange(10))
    fig.update_layout(margin=dict(t=50, b=140))
    fig.update_layout(font=dict(family='serif', size=20, color='black'))
    fig.update_xaxes(title_font=dict(family='serif', size=20, color='black'))
    fig.update_yaxes(title_font=dict(family='serif', size=20, color='black'))
    fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
    fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
    fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
    return fig

fig = plot_classwise_cluster_perf(num_clusters, classwise_accuracies_on, classwise_accuracies_off)

fig.show()

In [62]:
# save as pdf

fig.write_image(str(figures_path / 'classwise_cluster_all_accuracies.pdf'), scale=5)

In [63]:
path

'results/CIFAR10/'

## 5. Comparison with Unclustered Model

### (b) Circuit discovery search-space reduction

In [64]:
total_params = sum(p.numel() for p in model.parameters())
total_params

273744

In [65]:
(3 * 16 * 3 * 3) + 16 +  4096 * 64 + 64 * 64 + 64 * 10

267328

In [66]:
trimmed_params = (3 * 16 * 3 * 3) + 16 +  4096 * 16 + 16 * 16 + 16 * 10

print(f'Percentage pruned: {(total_params - trimmed_params) / total_params * 100:.2f}%')

Percentage pruned: 75.74%


# The below stuff (pruning) takes a long time

### (c) Circuit Complexity Reduction (CCR)

The reduction in the expected circuit size for a given behavior. Since we restrict circuit-forming to use fixed clusters, we can likely expect simpler circuits? For this we need to implement a circuit discovery method. For now, we'll just use accuracy-based pruning. We can shift to others (ACDC/EAP) later on.

In [67]:
def fast_label_perf(model, x, label):
    with torch.no_grad():
        output = model(x)
        criterion = nn.CrossEntropyLoss()
        target = torch.tensor([label] * x.size(0)).to(device)
        loss = criterion(output, target)
        accuracy = (output.argmax(dim=1) == target).sum().item() / x.size(0)
    return loss, accuracy

In [73]:
import copy

def prune_model(model, x, label, device, verbose=False):
    pruned_model = copy.deepcopy(model)
    
    layers = list(pruned_model.children())
    for layer in reversed(layers):
        if verbose:
            print(f'Pruning layer: {layer}')
        if isinstance(layer, nn.Linear):
            for neuron_idx in tqdm.trange(layer.weight.shape[0]):
                for weight_idx in range(layer.weight.shape[1]):
                    # Create a mask to zero out the weight
                    mask = torch.ones_like(layer.weight)
                    mask[neuron_idx, weight_idx] = 0
                    
                    # Apply the mask
                    original_weight = layer.weight[neuron_idx, weight_idx].item()
                    layer.weight.data[neuron_idx, weight_idx] = 0
                    
                    # Check performance
                    loss_pruned, acc_pruned = fast_label_perf(pruned_model, x, label)
                    loss_original, acc_original = fast_label_perf(model, x, label)
                    
                    # If performance decreases, restore the weight
                    if (loss_pruned - loss_original) > 0:
                        layer.weight.data[neuron_idx, weight_idx] = original_weight

        # fraction pruned
        num_zeros = torch.sum(layer.weight == 0).item()
        total_params = layer.weight.numel()
        if verbose:
            print(f'Fraction pruned: {num_zeros / total_params:.4f}')
                    
    return pruned_model

In [74]:
label = 8

label_data = torch.stack([test_dataset[i][0] for i in range(len(test_dataset)) if test_dataset[i][1] == label])

label_data = label_data.to(device)

label_data.shape

torch.Size([1000, 3, 32, 32])

In [75]:
pruned_model = prune_model(model, label_data, label, device, verbose=True)

Pruning layer: Linear(in_features=64, out_features=10, bias=False)


100%|██████████| 10/10 [00:01<00:00,  6.91it/s]


Fraction pruned: 0.9891
Pruning layer: Linear(in_features=64, out_features=64, bias=False)


100%|██████████| 64/64 [00:09<00:00,  7.06it/s]


Fraction pruned: 0.9929
Pruning layer: Linear(in_features=64, out_features=64, bias=False)


100%|██████████| 64/64 [00:09<00:00,  7.03it/s]


Fraction pruned: 0.9963
Pruning layer: Linear(in_features=4096, out_features=64, bias=False)


100%|██████████| 64/64 [09:27<00:00,  8.86s/it]

Fraction pruned: 0.9953
Pruning layer: Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Fraction pruned: 0.0000
Pruning layer: Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Fraction pruned: 0.0000





In [77]:
torch.save(pruned_model.state_dict(), f'{path}pruned_model.pth')

In [78]:
# comparing with if we had pruned the unclustered model

pruned_model_unclustered = prune_model(unclustered_model, label_data, label, device, verbose=True)

Pruning layer: Linear(in_features=64, out_features=10, bias=False)


100%|██████████| 10/10 [00:01<00:00,  7.13it/s]


Fraction pruned: 0.9922
Pruning layer: Linear(in_features=64, out_features=64, bias=False)


100%|██████████| 64/64 [00:09<00:00,  7.09it/s]


Fraction pruned: 0.9893
Pruning layer: Linear(in_features=64, out_features=64, bias=False)


100%|██████████| 64/64 [00:09<00:00,  7.10it/s]


Fraction pruned: 0.9858
Pruning layer: Linear(in_features=4096, out_features=64, bias=False)


100%|██████████| 64/64 [09:27<00:00,  8.87s/it]

Fraction pruned: 0.9911
Pruning layer: Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Fraction pruned: 0.0000
Pruning layer: Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Fraction pruned: 0.0000





In [79]:
torch.save(pruned_model_unclustered.state_dict(), f'{path}pruned_model_unclustered.pth')

In [80]:
print(classwise_accuracy(pruned_model, test_loader))
print(classwise_accuracy(pruned_model_unclustered, test_loader))

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]


In [81]:
def effective_circuit_size(model):
    # fraction of non-zero weights
    total_params = 0
    num_zeros = 0
    for layer in model.children():
        if isinstance(layer, nn.Linear):
            total_params += layer.weight.numel()
            num_zeros += torch.sum(layer.weight == 0).item()
    return round(1 - (num_zeros / total_params), 3)

In [82]:
effective_circuit_size(model), effective_circuit_size(pruned_model), effective_circuit_size(pruned_model_unclustered)

(1.0, 0.005, 0.009)

In [83]:
path

'results/CIFAR10/'

In [84]:
# store the unclustered model
torch.save(unclustered_model.state_dict(), f'{path}unclustered_model.pth')

In [None]:
# effective circuit sizes for pruned models v pruned unclustered models for each label
# NOTE:
# this will take a while to run; better to run this as a script in the background (there's a "pruning.py" for this)
# and then directly load the results

ecs_pruned_all_labels = []
ecs_pruned_unclustered_all_labels = []

for label in tqdm.trange(10):
    label_data = torch.stack([test_dataset[i][0] for i in range(len(test_dataset)) if test_dataset[i][1] == label])
    label_data = label_data.to(device)
    
    pruned_model = prune_model(model, label_data, label, device, verbose=False)
    pruned_model_unclustered = prune_model(unclustered_model, label_data, label, device, verbose=False)
    
    ecs_pruned_all_labels.append(effective_circuit_size(pruned_model))
    ecs_pruned_unclustered_all_labels.append(effective_circuit_size(pruned_model_unclustered))

    # print the effective circuit sizes for each label
    print(f'Label: {label}, ECS (pruned): {ecs_pruned_all_labels[-1]}, ECS (pruned unclustered): {ecs_pruned_unclustered_all_labels[-1]}')

torch.save(ecs_pruned_all_labels, path + 'ecs_pruned_all_labels.pth')
torch.save(ecs_pruned_unclustered_all_labels, path + 'ecs_pruned_unclustered_all_labels.pth')

: 

In [None]:
path = 'results/CIFAR10/'

# load the effective circuit sizes
ecs_pruned_all_labels = torch.load(path + 'ecs_pruned_all_labels.pth')
ecs_pruned_unclustered_all_labels = torch.load(path + 'ecs_pruned_unclustered_all_labels.pth')

: 

In [None]:
ecs_pruned_all_labels, ecs_pruned_unclustered_all_labels

: 

In [None]:
# cifar label names
cifar_labels = ['airplane', 'auto.', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

: 

In [None]:
fig = go.Figure()

# Calculate percentage increase
percentage_increase = [(unclustered - clustered) / unclustered * 100 
                       for clustered, unclustered in zip(ecs_pruned_all_labels, ecs_pruned_unclustered_all_labels)]

fig.add_trace(go.Scatter(
    x=cifar_labels,
    y=ecs_pruned_all_labels,
    mode='lines+markers',
    name='Pruned clustered model',
    line=dict(color='darkblue', width=2),
    fill='tonexty'  # This will shade the area between the two plots
))

fig.add_trace(go.Scatter(
    x=cifar_labels,
    y=ecs_pruned_unclustered_all_labels,
    mode='lines+markers',
    name='Pruned unclustered model',
    line=dict(color='darkred', width=2),
    fill='tonexty'  # This will shade the area between the two plots
))

# thicker markers
fig.update_traces(marker=dict(size=10))

# Add percentage increase annotations
for i, (x, y, pct) in enumerate(zip(cifar_labels, ecs_pruned_unclustered_all_labels, percentage_increase)):
    fig.add_annotation(
        x=x,
        y=y,
        text=f"+{pct:.1f}%",
        showarrow=False,
        yshift=30,
        font=dict(
            size=20,
            color="black"
        ),
    )

# show a huge downward arrow from ecs_pruned_unclustered_all_labels to ecs_pruned_all_labels for each label
for i, (x, y1, y2) in enumerate(zip(cifar_labels, ecs_pruned_unclustered_all_labels, ecs_pruned_all_labels)):
    fig.add_annotation(
        x=x,
        y=(y1 + y2) / 2,
        text="",
        showarrow=True,
        arrowhead=2,
        arrowsize=float(abs(y1 - y2) / max(ecs_pruned_unclustered_all_labels) * 2),
        arrowwidth=2,
        arrowcolor='darkred',
        ax=0,
        ay=-50,
    )

fig.update_layout(
    title_text='.',
    xaxis_title_text='Label',
    yaxis_title_text='Effective Circuit Size (lesser is better)',
    width=800,  # Fixed width
    height=600,  # Fixed height
    plot_bgcolor='white',
    yaxis=dict(range=[0, max(max(ecs_pruned_all_labels), max(ecs_pruned_unclustered_all_labels)) * 1.1]),  # y-axis starts at 0
)

fig.update_xaxes(
    tickmode='linear',
    tick0=0,
    dtick=1,
    showgrid=True,
    gridwidth=1,
    gridcolor='LightGray'
)

fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

# fig width and height
fig.update_layout(
    autosize=False,
    width=1000,
    height=550,
)

# increase font size of x ticks
fig.update_xaxes(tickfont=dict(size=14))
# make them at 45 degree angle
fig.update_xaxes(tickangle=0)

# legend on top
fig.update_layout(legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
))

# latex font for everything (for research papers)

# fig.update_layout(
#     font=dict(
#         family="Computer Modern",
#         size=20,
#         color="Black"
#     )
# )

# everthing latex font (for research paper)
fig.update_layout(font=dict(family='serif', size=20, color='black'))
fig.update_xaxes(title_font=dict(family='serif', size=20, color='black'))
fig.update_yaxes(title_font=dict(family='serif', size=20, color='black'))
fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

fig.show()

: 

Beautiful! 🤌

In [None]:
# save as pdf
figures_path = Path('figures')

fig.write_image(str(figures_path / 'ecs_pruned_all_labels.pdf'))
fig.write_image(str(figures_path / 'ecs_pruned_all_labels.jpg'), scale=5)

: 

In [None]:
print(classwise_accuracy(unclustered_model, test_loader))

: 

In [None]:
print(classwise_accuracy(model, test_loader))

: 

In [None]:
fig = make_subplots(rows=2, cols=5, subplot_titles=[f'Label {i}' for i in range(10)])

for i in range(10):
    img, _ = test_dataset[i]  # Assuming test_dataset returns a tuple (image, label)
    img = img.permute(1, 2, 0).cpu().numpy()  # Change shape from (C, H, W) to (H, W, C)
    img = (img * 255).astype(np.uint8)  # Ensure pixel values are in the range [0, 255]
    
    fig.add_trace(go.Image(z=img), row=(i // 5) + 1, col=(i % 5) + 1)

fig.update_layout(height=600, width=1000, title_text='One image for each label in CIFAR10')
fig.show()

: 

Thank you! Feel free to contact Satvik (zsatvik@gmail.com) for stuff related to this codebase.