This notebook contains code accompanying our submission to NeurIPS 2024 Science of Deep Learning Workshop.

In [1]:
pip install plotly

Note: you may need to restart the kernel to use updated packages.


## 1. Setup 
(run, don't read!)

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

from sklearn.cluster import KMeans
from scipy.sparse.linalg import svds

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

def randomseed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

### (a) Data

In [3]:
dataset = 'CIFAR10' # 'MNIST' or 'CIFAR10'

if dataset == 'MNIST':
    transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
    test_dataset = MNIST(root='.', train=False, download=True, transform=transform)
elif dataset == 'CIFAR10':
    transform = torchvision.transforms.ToTensor()
    train_dataset = CIFAR10(root='.', train=True, download=True, transform=transform)
    test_dataset = CIFAR10(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


### (b) Models

In [4]:
# for MNIST

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64, bias=False)
        self.fc2 = nn.Linear(64, 64, bias=False)
        self.fc3 = nn.Linear(64, 64, bias=False) # added
        self.fc4 = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
# for CIFAR10

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(16 * 16 * 16, 64, bias=False)
        self.fc2 = nn.Linear(64, 64, bias=False)
        self.fc3 = nn.Linear(64, 64, bias=False)
        self.fc4 = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 16 * 16 * 16)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
def new_model(dataset, device):
    if dataset == 'MNIST':
        model = MLP()
    elif dataset == 'CIFAR10':
        model = CNN()
    model = model.to(device)
    return model        

### (c) Evaluation

In [5]:
def accuracy(model, data):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct / total

def classwise_accuracy(model, data):
    model.eval()
    correct = defaultdict(int)
    total = defaultdict(int)
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            for i in range(len(labels)):
                label = labels[i].item()
                total[label] += 1
                correct[label] += int(predicted[i] == label)
    return [round(correct[i] / total[i], 3) if total[i] > 0 else 0 for i in range(10)]

def clusterability(model, cluster_U_indices, cluster_V_indices):
    num_clusters = len(cluster_U_indices)
    A = model.fc2.weight ** 2
    mask = torch.zeros_like(A, dtype=torch.bool)
    
    for cluster_idx in range(num_clusters):
        u_indices = torch.tensor(cluster_U_indices[cluster_idx], dtype=torch.long)
        v_indices = torch.tensor(cluster_V_indices[cluster_idx], dtype=torch.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = torch.sum(A[mask])
    total_out_sum = torch.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

## Bipartite clustering and multi-layer clusterability-scores

In [6]:
def bipartite_spectral_clustering(similarity_matrix, k, cluster_U_indices=None, cluster_V_indices=None):
    
    A = similarity_matrix.detach().cpu().numpy() # transform from gpu to cpu, and then to numpy
    A = np.abs(A)
    
    # in D_U each diagonal entry is the sum of the elements in the corresponding row of matrix A
    # in D_V each diagonal entry is the sum of the elements in the corresponding column of matrix A
    D_U = np.diag(np.sum(A, axis=1))
    D_V = np.diag(np.sum(A, axis=0))

    # inverse square roots of the sum matrices
    D_U_inv_sqrt = np.linalg.inv(np.sqrt(D_U))
    D_V_inv_sqrt = np.linalg.inv(np.sqrt(D_V))

    # normalized similarity matrix
    A_tilde = D_U_inv_sqrt @ A @ D_V_inv_sqrt

    # singular value decomposition
    # U: left singular vectors of A_tilde
    # Vt: right singular vectors of A_tilde
    U, Sigma, Vt = svds(A_tilde, k=k)

    if cluster_U_indices is None:
        kmeans_U = KMeans(n_clusters=k, random_state=42).fit(U)
        labels_U = kmeans_U.labels_

        # convert labels to indices
        cluster_U_indices = defaultdict(list)
        for i, label in enumerate(labels_U):
            cluster_U_indices[label].append(i)

    if cluster_V_indices is None:
        kmeans_V = KMeans(n_clusters=k, random_state=42).fit(Vt.T)
        labels_V = kmeans_V.labels_

        # convert labels to indices
        cluster_V_indices = defaultdict(list)
        for i, label in enumerate(labels_V):
            cluster_V_indices[label].append(i)

    return cluster_U_indices, cluster_V_indices

In [7]:
def total_clusterability(model, grads, subset_of_layers, num_clusters):

    nr_layers = len(subset_of_layers)
    clusterability_scores = np.zeros((nr_layers, 2))

    cluster_next_U_indices = [None, None]
    for l, layer_nr in enumerate(subset_of_layers):
        # why are we dividing by the norm of the gradient?
        # if type(grads[layer_nr]) is not torch.Tensor:
        layer_grads = torch.stack(list({k: v[0] / v[0].norm() for k, v in grads[layer_nr].items()}.values()))

        # initializing the similarity matrix
        grads_similarity_matrix = torch.zeros(layer_grads.shape[1], layer_grads.shape[1]).to(device)
        # populating the similarity matrix with the dot product of the gradients
        # why?
        for i in range(layer_grads.shape[0]):
            grads_similarity_matrix += torch.mm(layer_grads[i].t(), layer_grads[i])

        # similarity_matrices = [model[f"fc{layer_nr}"].weight, grads_similarity_matrix]
        similarity_matrices = [getattr(model, f"fc{layer_nr}").weight, grads_similarity_matrix]

        for i, similarity_matrix in enumerate(similarity_matrices):
            cluster_U_indices, cluster_V_indices = bipartite_spectral_clustering(similarity_matrix, num_clusters, cluster_next_U_indices[i])
            cl = clusterability(model, cluster_U_indices, cluster_V_indices)
            clusterability_scores[l, i] = cl.item()
            cluster_next_U_indices[i] = cluster_V_indices

    return clusterability_scores

## 2. Training the Model

In [8]:
dataset, device

('CIFAR10', device(type='cuda', index=0))

In [9]:
subset_of_layers = [2,3]

unclustered_model = new_model(dataset, device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(unclustered_model.parameters(), lr=1e-3)
train_losses = []
# struct to store some gradients
grads = {l: defaultdict(list) for l in subset_of_layers}

In [10]:
path = Path(f'results/{dataset}/')
path.mkdir(parents=True, exist_ok=True)  # Create the directory if it doesn't exist


In [11]:
randomseed(42)
path = Path(f'results/{dataset}/')

for epoch in range(10):
    unclustered_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = unclustered_model(data)
        loss = criterion(output, target)
        loss.backward()
        # store gradients that are in the subset of layers
        for l in subset_of_layers:
            for p in unclustered_model.named_parameters():      
                if p[0] == f'fc{l}.weight':
                    grads[l][batch_idx + (epoch * len(train_loader))].append(p[1].grad.clone())
                    break
        optimizer.step()
        train_losses.append(loss.item())
    acc = accuracy(unclustered_model, test_loader)
    print(f'Epoch {epoch+1}/{10}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}')
    # save model
    torch.save(unclustered_model.state_dict(), path / 'unclustered_model.pth')

Epoch 1/10, Loss: 1.6169, Accuracy: 0.4686
Epoch 2/10, Loss: 1.2356, Accuracy: 0.5271
Epoch 3/10, Loss: 0.9296, Accuracy: 0.5546
Epoch 4/10, Loss: 1.1040, Accuracy: 0.5550
Epoch 5/10, Loss: 1.1980, Accuracy: 0.5930
Epoch 6/10, Loss: 0.9542, Accuracy: 0.5916
Epoch 7/10, Loss: 1.1042, Accuracy: 0.6067
Epoch 8/10, Loss: 1.1427, Accuracy: 0.6233
Epoch 9/10, Loss: 1.3096, Accuracy: 0.6230
Epoch 10/10, Loss: 0.7990, Accuracy: 0.6265


For three layers we get the following accuracies:

Epoch 1/10, Loss: 1.5809, Accuracy: 0.4778
Epoch 2/10, Loss: 1.1792, Accuracy: 0.5128
Epoch 3/10, Loss: 0.9798, Accuracy: 0.5593
Epoch 4/10, Loss: 1.1889, Accuracy: 0.5728
Epoch 5/10, Loss: 1.4932, Accuracy: 0.5862
Epoch 6/10, Loss: 0.8826, Accuracy: 0.5967
Epoch 7/10, Loss: 0.8351, Accuracy: 0.6075
Epoch 8/10, Loss: 1.3136, Accuracy: 0.6241
Epoch 9/10, Loss: 1.2181, Accuracy: 0.6328
Epoch 10/10, Loss: 0.6489, Accuracy: 0.6302

In [12]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=train_losses, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Optimization Step')
fig.update_yaxes(title_text='CrossEntropy Loss')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [13]:
path = Path(f'results/{dataset}/')

In [14]:
# load the unclustered model
unclustered_model = new_model(dataset, device)
unclustered_model.load_state_dict(torch.load(path/'unclustered_model.pth'))


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



<All keys matched successfully>

In [15]:
# if type(grads) is not torch.Tensor:
#     grads = torch.stack(list({k: v[0] / v[0].norm() for k, v in grads.items()}.values()))

# grads_similarity_matrix = torch.zeros(grads.shape[1], grads.shape[1]).to(device)

# for i in range(grads.shape[0]):
#     grads_similarity_matrix += torch.mm(grads[i].t(), grads[i])

# similarity_matrices = [unclustered_model.fc2.weight, grads_similarity_matrix]
# cluster_sizes = [2, 4, 6, 8, 10, 12, 14]
# clusterability_scores = np.zeros((len(similarity_matrices), len(cluster_sizes)))

In [16]:
# for i, similarity_matrix in enumerate(similarity_matrices):
#     for num_clusters in cluster_sizes:
#         cluster_U_indices, cluster_V_indices = bipartite_spectral_clustering(similarity_matrix, num_clusters)
#         cl = clusterability(unclustered_model, cluster_U_indices, cluster_V_indices, num_clusters)
#         clusterability_scores[i, cluster_sizes.index(num_clusters)] = cl.item()

## Clusterability of trained model

In [17]:
cluster_sizes = [2, 4, 6, 8, 10, 12, 14]
clusterability_scores = []
for num_clusters in cluster_sizes:
    clusterability_scores.append(total_clusterability(unclustered_model, grads, subset_of_layers, num_clusters))

clusterability_scores = np.array(clusterability_scores)

In [18]:
print(clusterability_scores.shape)

(7, 2, 2)


In [19]:
# plot

fig = go.Figure()
fig.add_trace(go.Scatter(x=cluster_sizes, y=clusterability_scores.mean(axis=1)[:, 0], mode='lines+markers', name='Weight-based BSGC'))
fig.add_trace(go.Scatter(x=cluster_sizes, y=clusterability_scores.mean(axis=1)[:, 1], mode='lines+markers', name='Gradient-based BSGC'))

fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
# show x and y axis ticks
fig.update_xaxes(tickvals=cluster_sizes)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
# show zero gridlines
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
fig.update_xaxes(title_text='Number of Clusters')
fig.update_yaxes(title_text='Clusterability')
fig.update_layout(width=600, height=400, autosize=False)

# y axis from 0 to 1
fig.update_yaxes(range=[0, 0.9])

# everthing latex font (for research paper)
fig.update_layout(font=dict(family='serif', size=20, color='black'))
fig.update_xaxes(title_font=dict(family='serif', size=20, color='black'))
fig.update_yaxes(title_font=dict(family='serif', size=20, color='black'))
fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

fig.show()

In [20]:
figures_path = Path('figures')

if not figures_path.exists():
    figures_path.mkdir()

In [21]:
pip install -U kaleido

Requirement already up-to-date: kaleido in /home/nandi/anaconda3/lib/python3.8/site-packages (0.2.1)
Note: you may need to restart the kernel to use updated packages.


In [22]:
# save as pdf

fig.write_image(str(figures_path / 'clusterability_scores.pdf'))

This doesn't seem to be good enough for interpretability.

## 3. Optimizing for Modularity

### (a) Train for a few steps.

In [23]:
dataset, device

('CIFAR10', device(type='cuda', index=0))

In [24]:
model = new_model(dataset, device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train_losses = []
total_epochs = 20
initial_epochs = 5

randomseed(42)

for epoch in range(initial_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    acc = accuracy(model, test_loader)
    print(f'Epoch {epoch+1}/{total_epochs}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/20, Loss: 1.6872, Accuracy: 0.4439
Epoch 2/20, Loss: 1.1770, Accuracy: 0.5112
Epoch 3/20, Loss: 0.9848, Accuracy: 0.5369
Epoch 4/20, Loss: 1.2573, Accuracy: 0.5756
Epoch 5/20, Loss: 1.4544, Accuracy: 0.5828


### (b) Get Clusters (via BSGC)

In [25]:
num_clusters = 4
similarity_matrix = model.fc2.weight
cluster_U_indices, cluster_V_indices = bipartite_spectral_clustering(similarity_matrix, num_clusters)

for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

Cluster 0 has 17 nodes in U and 44 nodes in V
Cluster 1 has 20 nodes in U and 7 nodes in V
Cluster 2 has 16 nodes in U and 11 nodes in V
Cluster 3 has 11 nodes in U and 2 nodes in V


In [27]:
clusterability_score = clusterability(model, cluster_U_indices, cluster_V_indices)
print(f'Clusterability score: {round(clusterability_score.item(), 3)}')

Clusterability score: 0.228


### (c) Train the rest of the model with the Enmeshment Loss

In [29]:
cluster_losses = []
ce_losses = []
lomda = 20.0

randomseed(42)

for epoch in range(total_epochs - initial_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss_ce = criterion(output, target)
        loss_cluster = clusterability(model, cluster_U_indices, cluster_V_indices)
        cluster_losses.append(loss_cluster.item())
        ce_losses.append(loss_ce.item())
        ###
        loss = loss_ce - lomda * loss_cluster
        ###
        loss.backward()
        optimizer.step()
    acc = accuracy(model, test_loader)
    print(f'Epoch {epoch+1}/{total_epochs}, Loss: {loss.item():.4f}, Clusterability: {loss_cluster.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/20, Loss: -18.6823, Clusterability: 0.9995, Accuracy: 0.5918
Epoch 2/20, Loss: -18.8673, Clusterability: 0.9998, Accuracy: 0.5949
Epoch 3/20, Loss: -19.0849, Clusterability: 0.9998, Accuracy: 0.5987
Epoch 4/20, Loss: -19.0779, Clusterability: 0.9998, Accuracy: 0.6129
Epoch 5/20, Loss: -18.8423, Clusterability: 0.9999, Accuracy: 0.6044
Epoch 6/20, Loss: -19.4043, Clusterability: 0.9999, Accuracy: 0.6110
Epoch 7/20, Loss: -19.1226, Clusterability: 0.9999, Accuracy: 0.6081
Epoch 8/20, Loss: -18.6821, Clusterability: 0.9999, Accuracy: 0.6154
Epoch 9/20, Loss: -18.8185, Clusterability: 0.9999, Accuracy: 0.6197
Epoch 10/20, Loss: -19.6229, Clusterability: 0.9999, Accuracy: 0.6178
Epoch 11/20, Loss: -19.3434, Clusterability: 0.9999, Accuracy: 0.6295
Epoch 12/20, Loss: -19.5446, Clusterability: 0.9999, Accuracy: 0.6099
Epoch 13/20, Loss: -19.6925, Clusterability: 0.9999, Accuracy: 0.6116
Epoch 14/20, Loss: -19.3498, Clusterability: 0.9999, Accuracy: 0.6065
Epoch 15/20, Loss: -19.3183, 

In [30]:
# plot the two losses on side by side plots
fig = make_subplots(rows=1, cols=2, subplot_titles=('Cross Entropy Loss', 'Model Clusterability'))
fig.add_trace(go.Scatter(y=ce_losses, mode='lines', name='', line=dict(color='darkred', width=2)), row=1, col=1)
fig.add_trace(go.Scatter(y=cluster_losses, mode='lines', name='', line=dict(color='darkblue', width=2)), row=1, col=2)
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
# show fine grid lines on both axes on both subplots
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=1)
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=1)
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=2)
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=2)

fig.update_xaxes(title_text='Optimization Step', row=1, col=1)
fig.update_yaxes(title_text='Cross Entropy', row=1, col=1)
fig.update_xaxes(title_text='Optimization Step', row=1, col=2)
fig.update_yaxes(title_text='Clusterability', row=1, col=2)
fig.update_layout(width=1000, height=400, autosize=False)
fig.show()

### (d) Store our clusters and the clustered model

In [31]:
path: str = f'results/{dataset}/'
Path(path).mkdir(parents=True, exist_ok=True)

torch.save(model.state_dict(), f'{path}model.pth')
torch.save(cluster_U_indices, f'{path}cluster_U_indices.pth')
torch.save(cluster_V_indices, f'{path}cluster_V_indices.pth')

## 4. Interpreting the Clusters

In [32]:
path: str = f'results/{dataset}/'

# load the model and cluster indices
model = new_model(dataset, device)
model.load_state_dict(torch.load(f'{path}model.pth'))
cluster_U_indices = torch.load(f'{path}cluster_U_indices.pth')
cluster_V_indices = torch.load(f'{path}cluster_V_indices.pth')


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is poss

In [33]:
classwise_accuracies = classwise_accuracy(model, test_loader)
classwise_accuracies

[0.708, 0.64, 0.435, 0.414, 0.549, 0.36, 0.782, 0.655, 0.661, 0.799]

In [34]:
num_clusters = len(cluster_U_indices)
num_clusters

4

### (a) Classwise accuracies with individual clusters turned ON and OFF

In [35]:
# plot the class-wise accuracies for each cluster turned OFF

num_clusters = len(cluster_U_indices)

classwise_accuracies_off = []
classwise_accuracies_on = []

for cluster_idx in tqdm.trange(num_clusters):
    model = new_model(dataset, device)
    model.load_state_dict(torch.load(path + 'model.pth'))
    model.to(device)
    
    # turn off the cluster
    for i in cluster_U_indices[cluster_idx]:
        model.fc2.weight.data[i] = 0
    for i in cluster_V_indices[cluster_idx]:
        model.fc2.weight.data[:, i] = 0
    
    classwise_accuracies_off.append(classwise_accuracy(model, test_loader))

model.load_state_dict(torch.load(path + 'model.pth'))

# plot the class-wise accuracies for each cluster turned ON

classwise_accuracies = []

for cluster_idx in tqdm.trange(num_clusters):
    model = CNN()
    model.load_state_dict(torch.load(path + 'model.pth'))
    model.to(device)
    
    # turn off every cluster except the current one
    for i in range(num_clusters):
        if i != cluster_idx:
            for j in cluster_U_indices[i]:
                model.fc2.weight.data[j] = 0
            for j in cluster_V_indices[i]:
                model.fc2.weight.data[:, j] = 0
    
    classwise_accuracies_on.append(classwise_accuracy(model, test_loader))

model.load_state_dict(torch.load(path + 'model.pth'))


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

100%|██████████| 4/4 [00:04<00:00,  1.10s/it]

You are using `torch.load` with `weights_only=False` (the current default value), which uses th

<All keys matched successfully>

In [36]:
def plot_classwise_cluster_perf(num_clusters, classwise_accuracies, type='ON'):

    color_scale = pc.qualitative.G10  # You can use other scales like `pc.sequential.Plasma` or `pc.sequential.Viridis`
    colors = color_scale[:num_clusters]  # Ensure we have enough colors for the number of clusters

    # Create subplots: one row per cluster
    fig = make_subplots(rows=num_clusters // 2, cols=2, shared_xaxes=False, 
                        subplot_titles=[f'Cluster {i} ({type})' for i in range(num_clusters)],
                        horizontal_spacing=0.15, vertical_spacing=0.15)

    for i in range(num_clusters):
        fig.add_trace(go.Bar(
            x=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
            y=classwise_accuracies[i],
            marker_color=colors[i],
            name=f'Cluster {i}',
        ), row=i // 2 + 1, col=i % 2 + 1)

    # white background
    fig.update_layout(plot_bgcolor='white')

    # gridlines
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

    fig.update_layout(height=150 * num_clusters, width=600)

    # hide the legend
    fig.update_layout(showlegend=False)

    # latex font on everything (for research papers)

    # everthing latex font (for research paper)
    fig.update_layout(font=dict(family='serif', size=20, color='black'))
    fig.update_xaxes(title_font=dict(family='serif', size=20, color='black'))
    fig.update_yaxes(title_font=dict(family='serif', size=20, color='black'))
    fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
    fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
    fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
    fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

    # show all x ticks
    fig.update_xaxes(tickvals=np.arange(10))

    # all y-axes have the same range and ticks
    fig.update_yaxes(range=[0, 1], tickvals=np.arange(0, 1.1, 0.25))

    return fig

In [37]:
fig = plot_classwise_cluster_perf(num_clusters, classwise_accuracies_off, type='OFF')
fig.show()

In [38]:
# save as pdf

fig.write_image(str(figures_path / 'classwise_cluster_off_accuracies.pdf'))

In [39]:
fig = plot_classwise_cluster_perf(num_clusters, classwise_accuracies_on, type='ON')
fig.show()

In [40]:
# save as pdf and png (high quality)

fig.write_image(str(figures_path / 'classwise_cluster_on_accuracies.pdf'))
fig.write_image(str(figures_path / 'classwise_cluster_on_accuracies.png'), scale=5)

In [41]:
cifar_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [42]:
def plot_classwise_cluster_perf(num_clusters, classwise_accuracies_on, classwise_accuracies_off, cifar_labels=cifar_labels):

    colors = [
        pc.qualitative.Dark24_r[0],
        pc.qualitative.Dark24_r[1],
        pc.qualitative.Dark24_r[15],
        pc.qualitative.Dark24_r[5],
    ]

    # Create subplots: one row per cluster
    fig = make_subplots(rows=2, cols=4, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.1, subplot_titles=[f'Cluster {i} (ON)' for i in range(num_clusters)] + [f'Cluster {i} (OFF)' for i in range(num_clusters)], row_heights=[0.15, 0.15])

    for i in range(num_clusters):
        fig.add_trace(go.Bar(
            x=cifar_labels,
            y=classwise_accuracies_on[i],
            marker_color=colors[i],
            name=f'Cluster {i} (ON)',
        ), row=1, col=i+1)
        # x tick angle
        fig.update_xaxes(tickangle=-90)

        fig.add_trace(go.Bar(
            x=cifar_labels,
            y=classwise_accuracies_off[i],
            marker_color=colors[i],
            name=f'Cluster {i} (OFF)',
        ), row=2, col=i+1)

    # white background
    fig.update_layout(plot_bgcolor='white')

    # gridlines
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

    fig.update_layout(height=140 * num_clusters, width=1000)

    # hide the legend
    fig.update_layout(showlegend=True)

    # show y-axis title once in the exact middle of four subplots
    fig.add_annotation(
        text="Accuracy",
        xref="paper", yref="paper",
        x=-0.1, y=0.5,
        showarrow=False,
        font=dict(size=22, family="Computer Modern"),
        align="center",
        textangle=-90
    )
    # show x-axis title once in the exact middle of four subplots
    fig.add_annotation(
        text="Class",
        xref="paper", yref="paper",
        x=0.5, y=-0.35,
        showarrow=False,
        font=dict(size=22, family="Computer Modern"),
        align="center",
    )


    # latex font on everything (for research papers)

    fig.update_layout(
        font_family="Computer Modern",
        font_size=20,
    )

    # all y-axes have the same range and ticks
    fig.update_yaxes(range=[0, 1], tickvals=np.arange(0, 1, 0.2))

    # no legend
    fig.update_layout(showlegend=False)

    # show all x ticks
    fig.update_xaxes(tickvals=np.arange(10))

    # remove space from top of figure and add some space at the bottom
    fig.update_layout(margin=dict(t=50, b=140))

    # everthing latex font (for research paper)
    fig.update_layout(font=dict(family='serif', size=20, color='black'))
    fig.update_xaxes(title_font=dict(family='serif', size=20, color='black'))
    fig.update_yaxes(title_font=dict(family='serif', size=20, color='black'))
    fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
    fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
    fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)
    return fig

fig = plot_classwise_cluster_perf(num_clusters, classwise_accuracies_on, classwise_accuracies_off)
fig.show()

In [43]:
# save as pdf

fig.write_image(str(figures_path / 'classwise_cluster_all_accuracies.pdf'), scale=5)

In [44]:
path

'results/CIFAR10/'

### (b) Layer Visualization

Note: This is not currently working (there's a bug), but skipping since it is also probably not that important.

In [45]:
path = Path(f'results/{dataset}/')

model = new_model(dataset, device)
model.load_state_dict(torch.load(f'{path}/model.pth'))
cluster_U_indices = torch.load(f'{path}/cluster_U_indices.pth')
cluster_V_indices = torch.load(f'{path}/cluster_V_indices.pth')


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is poss

In [46]:
classwise_accuracy(model, test_loader)

[0.708, 0.64, 0.435, 0.414, 0.549, 0.36, 0.782, 0.655, 0.661, 0.799]

In [47]:
cluster_U_indices.keys(), cluster_V_indices.keys(), len(cluster_U_indices)

(dict_keys([2, 3, 0, 1]), dict_keys([1, 0, 2, 3]), 4)

In [48]:
def visualize_layer(layer):
    weights = layer.weight.detach().cpu().numpy() 
    num_outputs, num_inputs = weights.shape

    colors = pc.qualitative.G10

    node_x = []
    node_y = []
    node_text = []
    edge_x = []
    edge_y = []
    edge_color = []
    edge_width = []

    input_layer_positions = np.linspace(-1, 1, num_inputs)
    output_layer_positions = np.linspace(-1, 1, num_outputs)
    
    for i in range(num_inputs):
        node_x.append(input_layer_positions[i])
        node_y.append(-1)
        # node_text.append(f"{i+1}")
        node_text.append(".")

    for j in range(num_outputs):
        node_x.append(output_layer_positions[j])
        node_y.append(1)
        # node_text.append(f"{j+1}")
        node_text.append(".")

    # Add edges
    for i in range(num_inputs):
        for j in range(num_outputs):
            weight = weights[j, i]
            color = 'red' if weight < 0 else 'blue'
            # color based on cluster index
            
            thickness = abs(weight)
            edge_x.extend([input_layer_positions[i], output_layer_positions[j]])
            edge_y.extend([-1, 1])
            edge_color.append(color)
            edge_width.append(thickness)
            # edge_width.append(1)

    fig = go.Figure()
    
    fig.add_trace(go.Scatter(
        x=node_x, 
        y=node_y,
        mode='markers+text',
        text=node_text,
        textposition='top center',
        marker=dict(size=10, color='lightgray'),
        showlegend=False
    ))

    # also on bottom center
    fig.add_trace(go.Scatter(
        x=node_x, 
        y=[-1]*len(node_y),
        mode='markers+text',
        text=node_text,
        textposition='bottom center',
        marker=dict(size=10, color='lightgray'),
        showlegend=False
    ))
    
    for i in range(0, len(edge_x), 2):
        fig.add_trace(go.Scatter(
            x=[edge_x[i], edge_x[i+1]],
            y=[edge_y[i], edge_y[i+1]],
            mode='lines',
            line=dict(color=edge_color[i//2], width=edge_width[i//2]),
            showlegend=False
        ))

    # Update layout
    fig.update_layout(
        title=".",
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False, range=[-1.5, 1.5]),
        showlegend=False
    )

    # white background
    fig.update_layout(plot_bgcolor='white')
    # square aspect ratio
    fig.update_layout(
        autosize=False,
        width=1200,
        height=700,
    )

    # remove axis ticks
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)

    return fig

In [49]:
old_layer = model.fc2
old_layer.weight.shape

torch.Size([64, 64])

In [50]:
new_layer = nn.Linear(64, 64, bias=False)
new_weights = torch.zeros_like(new_layer.weight)
new_weights.shape

torch.Size([64, 64])

In [51]:
flat_U, flat_V = [], []
for i in range(len(cluster_U_indices)):
    flat_U.extend(cluster_U_indices[i])
    flat_V.extend(cluster_V_indices[i])

for i in range(len(flat_U)):
    for j in range(len(flat_V)):
        index_i = flat_U.index(i)
        index_j = flat_V.index(j)
        new_weights[index_j, index_i] = old_layer.weight[i, j]

new_layer.weight.data = new_weights

In [52]:
fig = visualize_layer(new_layer)
fig.show()

In [53]:
# save exact figure as shown in the notebook in high quality

fig.write_image(str(figures_path / 'clustered_layer.pdf'), scale=5)

In [54]:
# similarly rearrange fc3 layer
old_layer = model.fc3
old_layer.weight.shape

torch.Size([64, 64])

In [55]:
new_layer = nn.Linear(64, 10, bias=False)
new_weights = torch.zeros_like(new_layer.weight)
new_weights.shape

torch.Size([10, 64])

In [56]:
# here, we only have clusters on the input side

for i in range(len(flat_U)):
    for j in range(10):
        new_weights[j, i] = old_layer.weight[j, flat_U[i]]

new_layer.weight.data = new_weights

In [57]:
visualize_layer(new_layer).show()

### (c) Max-Activating Datapoints: Neuron Semanticity

This turns out to be (expectedly, kinda) not that helpful.

In [58]:
def max_activating_datapoints(k, model, data_loader, layer, neuron_index):
    # get the activations of the layer for all the data points
    activations = []
    model.eval()
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            flatten_images = images.view(-1, 28 * 28) if dataset == 'MNIST' else images
            flatten_images = F.max_pool2d(torch.relu(model.conv1(images)), 2).view(-1, 16 * 16 * 16) if dataset == 'CIFAR10' else flatten_images
            fc_1_out = torch.relu(model.fc1(flatten_images))
            fc_2_out = torch.relu(model.fc2(fc_1_out))
            fc_3_out = model.fc3(fc_2_out)
            if layer == 1:
                activations.extend(fc_1_out[:, neuron_index].cpu().numpy())
            elif layer == 2:
                activations.extend(fc_2_out[:, neuron_index].cpu().numpy())
            elif layer == 3:
                activations.extend(fc_3_out[:, neuron_index].cpu().numpy())

    activations = np.array(activations)
    indices = np.argsort(activations)[-k:]
    # print(indices)
    # get the images corresponding to the top k activations
    images = []
    labels = []
    for i in indices:
        images.append(data_loader.dataset[i][0])
        labels.append(data_loader.dataset[i][1])
    return images, labels

In [59]:
for neuron_index in range(len(model.fc2.weight))[:10]:
    _, labels = max_activating_datapoints(20, model, test_loader, 2, neuron_index)
    print(f"Neuron {neuron_index} is most activated by the following digits: {labels}")

Neuron 0 is most activated by the following digits: [9, 9, 1, 3, 6, 1, 1, 5, 3, 9, 0, 9, 0, 0, 0, 9, 9, 9, 9, 9]
Neuron 1 is most activated by the following digits: [9, 7, 0, 2, 5, 3, 6, 4, 7, 0, 2, 7, 8, 7, 2, 5, 0, 0, 3, 7]
Neuron 2 is most activated by the following digits: [5, 5, 5, 5, 5, 1, 7, 3, 3, 5, 5, 9, 5, 5, 3, 6, 3, 5, 3, 5]
Neuron 3 is most activated by the following digits: [8, 1, 0, 1, 1, 1, 1, 1, 8, 8, 8, 1, 1, 8, 0, 8, 8, 0, 8, 8]
Neuron 4 is most activated by the following digits: [1, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Neuron 5 is most activated by the following digits: [6, 5, 5, 3, 5, 5, 0, 1, 7, 7, 3, 6, 7, 1, 7, 3, 5, 2, 7, 2]
Neuron 6 is most activated by the following digits: [7, 5, 3, 5, 5, 5, 2, 5, 5, 5, 2, 3, 5, 5, 5, 5, 5, 5, 5, 5]
Neuron 7 is most activated by the following digits: [1, 1, 1, 8, 8, 8, 8, 1, 9, 8, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1]
Neuron 8 is most activated by the following digits: [7, 2, 5, 7, 7, 7, 5, 7, 7, 7, 7, 7, 5, 7, 7

In [60]:
def neuron_label_semanticity(model, test_loader, layer_id, k):
    neuron_semanticity = []
    n_semantic_neurons = [0 for _ in range(10)]
    for neuron_index in tqdm.trange(len(model.fc2.weight)):
        _, labels = max_activating_datapoints(k, model, test_loader, layer_id, neuron_index)
        # a rough sense of interpretability is the number of unique classes in the top k activations
        num_unique_classes = len(set(labels))
        neuron_semanticity.append(num_unique_classes / 10)
        n_semantic_neurons[num_unique_classes - 1] += 1
    
    return sum(neuron_semanticity) / len(neuron_semanticity), n_semantic_neurons

In [61]:
rsoil, dist = neuron_label_semanticity(model, test_loader, 2, 50)

  0%|          | 0/64 [00:00<?, ?it/s]

100%|██████████| 64/64 [00:55<00:00,  1.15it/s]


## 5. Comparison with Unclustered Model

### (a) Neuron Label Semanticity

In [62]:
rsoil_unclustered, dist_unclustered = neuron_label_semanticity(unclustered_model, test_loader, 2, 50)

100%|██████████| 64/64 [00:52<00:00,  1.22it/s]


In [63]:
# aggregate sum of dist till index i
dist_agg = [0] * 10
for i in range(10):
    dist_agg[i] = sum(dist[i:])

dist_unclustered_agg = [0] * 10
for i in range(10):
    dist_unclustered_agg[i] = sum(dist_unclustered[i:])

In [64]:
# plot the distribution of the number of unique classes in the top k activations for each neuron in the fc2 layer of the clustered model against the unclustered model

fig = go.Figure()

# lines for the clustered model, not histogram
fig.add_trace(go.Scatter(
    x=list(range(10)),
    y=dist_agg,
    mode='lines',
    name='Clustered model',
    line=dict(color='darkblue', width=2)
))

# lines for the unclustered model, not histogram
fig.add_trace(go.Scatter(
    x=list(range(10)),
    y=dist_unclustered_agg,
    mode='lines',
    name='Unclustered model',
    line=dict(color='darkred', width=2)
))

fig.update_layout(
    title_text='Distribution of the number of unique classes in the top k activations for each neuron in the fc2 layer',
    xaxis_title_text='Number of unique classes',
    yaxis_title_text='Number of neurons',
    barmode='overlay'
)

fig.update_traces(marker_line_width=0)

# white background
fig.update_layout(plot_bgcolor='white')

# gridlines

fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

fig.show()

### (b) Circuit discovery search-space reduction

In [65]:
total_params = sum(p.numel() for p in model.parameters())
total_params

271424

In [66]:
(3 * 16 * 3 * 3) + 16 +  4096 * 64 + 64 * 64 + 64 * 10

267328

In [67]:
trimmed_params = (3 * 16 * 3 * 3) + 16 +  4096 * 16 + 16 * 16 + 16 * 10

print(f'Percentage pruned: {(total_params - trimmed_params) / total_params * 100:.2f}%')

Percentage pruned: 75.54%


# The below stuff (pruning) takes a long time

### (c) Circuit Complexity Reduction (CCR)

The reduction in the expected circuit size for a given behavior. Since we restrict circuit-forming to use fixed clusters, we can likely expect simpler circuits? For this we need to implement a circuit discovery method. For now, we'll just use accuracy-based pruning. We can shift to others (ACDC/EAP) later on.

In [68]:
def fast_label_perf(model, x, label):
    with torch.no_grad():
        output = model(x)
        criterion = nn.CrossEntropyLoss()
        target = torch.tensor([label] * x.size(0)).to(device)
        loss = criterion(output, target)
        accuracy = (output.argmax(dim=1) == target).sum().item() / x.size(0)
    return loss, accuracy

In [69]:
import copy

def prune_model(model, x, label, device, verbose=False):
    pruned_model = copy.deepcopy(model)
    
    layers = list(pruned_model.children())
    for layer in reversed(layers):
        if verbose:
            print(f'Pruning layer: {layer}')
        if isinstance(layer, nn.Linear):
            for neuron_idx in tqdm.trange(layer.weight.shape[0]):
                for weight_idx in range(layer.weight.shape[1]):
                    # Create a mask to zero out the weight
                    mask = torch.ones_like(layer.weight)
                    mask[neuron_idx, weight_idx] = 0
                    
                    # Apply the mask
                    original_weight = layer.weight[neuron_idx, weight_idx].item()
                    layer.weight.data[neuron_idx, weight_idx] = 0
                    
                    # Check performance
                    loss_pruned, acc_pruned = fast_label_perf(pruned_model, x, label)
                    loss_original, acc_original = fast_label_perf(model, x, label)
                    
                    # If performance decreases, restore the weight
                    if (loss_pruned - loss_original) > 0:
                        layer.weight.data[neuron_idx, weight_idx] = original_weight

        # fraction pruned
        num_zeros = torch.sum(layer.weight == 0).item()
        total_params = layer.weight.numel()
        if verbose:
            print(f'Fraction pruned: {num_zeros / total_params:.4f}')
                    
    return pruned_model

In [70]:
label = 8

label_data = torch.stack([test_dataset[i][0] for i in range(len(test_dataset)) if test_dataset[i][1] == label])

label_data = label_data.to(device)

label_data.shape

torch.Size([1000, 3, 32, 32])

In [71]:
# pruned_model = prune_model(model, label_data, label, device, verbose=True)

In [72]:
# torch.save(unclustered_model.state_dict(), path / 'pruned_model.pth')

In [73]:
# comparing with if we had pruned the unclustered model

pruned_model_unclustered = prune_model(unclustered_model, label_data, label, device, verbose=True)

  0%|          | 0/10 [00:00<?, ?it/s]

Pruning layer: Linear(in_features=64, out_features=10, bias=False)


100%|██████████| 10/10 [00:32<00:00,  3.24s/it]
  0%|          | 0/64 [00:00<?, ?it/s]

Fraction pruned: 0.9891
Pruning layer: Linear(in_features=64, out_features=64, bias=False)


 12%|█▎        | 8/64 [00:27<03:11,  3.42s/it]

KeyboardInterrupt: 

In [None]:
print(classwise_accuracy(pruned_model, test_loader))
print(classwise_accuracy(pruned_model_unclustered, test_loader))

: 

In [None]:
def effective_circuit_size(model):
    # fraction of non-zero weights
    total_params = 0
    num_zeros = 0
    for layer in model.children():
        if isinstance(layer, nn.Linear):
            total_params += layer.weight.numel()
            num_zeros += torch.sum(layer.weight == 0).item()
    return round(1 - (num_zeros / total_params), 3)

: 

In [None]:
effective_circuit_size(model), effective_circuit_size(pruned_model), effective_circuit_size(pruned_model_unclustered)

: 

In [None]:
path

: 

In [None]:
# store the unclustered model
torch.save(unclustered_model.state_dict(), f'{path}unclustered_model.pth')

: 

In [None]:
# effective circuit sizes for pruned models v pruned unclustered models for each label
# NOTE:
# this will take a while to run; better to run this as a script in the background (there's a "pruning.py" for this)
# and then directly load the results

ecs_pruned_all_labels = []
ecs_pruned_unclustered_all_labels = []

for label in tqdm.trange(10):
    label_data = torch.stack([test_dataset[i][0] for i in range(len(test_dataset)) if test_dataset[i][1] == label])
    label_data = label_data.to(device)
    
    pruned_model = prune_model(model, label_data, label, device, verbose=False)
    pruned_model_unclustered = prune_model(unclustered_model, label_data, label, device, verbose=False)
    
    ecs_pruned_all_labels.append(effective_circuit_size(pruned_model))
    ecs_pruned_unclustered_all_labels.append(effective_circuit_size(pruned_model_unclustered))

    # print the effective circuit sizes for each label
    print(f'Label: {label}, ECS (pruned): {ecs_pruned_all_labels[-1]}, ECS (pruned unclustered): {ecs_pruned_unclustered_all_labels[-1]}')

torch.save(ecs_pruned_all_labels, path + 'ecs_pruned_all_labels.pth')
torch.save(ecs_pruned_unclustered_all_labels, path + 'ecs_pruned_unclustered_all_labels.pth')

: 

In [None]:
path = 'results/CIFAR10/'

# load the effective circuit sizes
ecs_pruned_all_labels = torch.load(path + 'ecs_pruned_all_labels.pth')
ecs_pruned_unclustered_all_labels = torch.load(path + 'ecs_pruned_unclustered_all_labels.pth')

: 

In [None]:
ecs_pruned_all_labels, ecs_pruned_unclustered_all_labels

: 

In [None]:
# cifar label names
cifar_labels = ['airplane', 'auto.', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

: 

In [None]:
fig = go.Figure()

# Calculate percentage increase
percentage_increase = [(unclustered - clustered) / unclustered * 100 
                       for clustered, unclustered in zip(ecs_pruned_all_labels, ecs_pruned_unclustered_all_labels)]

fig.add_trace(go.Scatter(
    x=cifar_labels,
    y=ecs_pruned_all_labels,
    mode='lines+markers',
    name='Pruned clustered model',
    line=dict(color='darkblue', width=2),
    fill='tonexty'  # This will shade the area between the two plots
))

fig.add_trace(go.Scatter(
    x=cifar_labels,
    y=ecs_pruned_unclustered_all_labels,
    mode='lines+markers',
    name='Pruned unclustered model',
    line=dict(color='darkred', width=2),
    fill='tonexty'  # This will shade the area between the two plots
))

# thicker markers
fig.update_traces(marker=dict(size=10))

# Add percentage increase annotations
for i, (x, y, pct) in enumerate(zip(cifar_labels, ecs_pruned_unclustered_all_labels, percentage_increase)):
    fig.add_annotation(
        x=x,
        y=y,
        text=f"+{pct:.1f}%",
        showarrow=False,
        yshift=30,
        font=dict(
            size=20,
            color="black"
        ),
    )

# show a huge downward arrow from ecs_pruned_unclustered_all_labels to ecs_pruned_all_labels for each label
for i, (x, y1, y2) in enumerate(zip(cifar_labels, ecs_pruned_unclustered_all_labels, ecs_pruned_all_labels)):
    fig.add_annotation(
        x=x,
        y=(y1 + y2) / 2,
        text="",
        showarrow=True,
        arrowhead=2,
        arrowsize=float(abs(y1 - y2) / max(ecs_pruned_unclustered_all_labels) * 2),
        arrowwidth=2,
        arrowcolor='darkred',
        ax=0,
        ay=-50,
    )

fig.update_layout(
    title_text='.',
    xaxis_title_text='Label',
    yaxis_title_text='Effective Circuit Size (lesser is better)',
    width=800,  # Fixed width
    height=600,  # Fixed height
    plot_bgcolor='white',
    yaxis=dict(range=[0, max(max(ecs_pruned_all_labels), max(ecs_pruned_unclustered_all_labels)) * 1.1]),  # y-axis starts at 0
)

fig.update_xaxes(
    tickmode='linear',
    tick0=0,
    dtick=1,
    showgrid=True,
    gridwidth=1,
    gridcolor='LightGray'
)

fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

# fig width and height
fig.update_layout(
    autosize=False,
    width=1000,
    height=550,
)

# increase font size of x ticks
fig.update_xaxes(tickfont=dict(size=14))
# make them at 45 degree angle
fig.update_xaxes(tickangle=0)

# legend on top
fig.update_layout(legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
))

# latex font for everything (for research papers)

# fig.update_layout(
#     font=dict(
#         family="Computer Modern",
#         size=20,
#         color="Black"
#     )
# )

# everthing latex font (for research paper)
fig.update_layout(font=dict(family='serif', size=20, color='black'))
fig.update_xaxes(title_font=dict(family='serif', size=20, color='black'))
fig.update_yaxes(title_font=dict(family='serif', size=20, color='black'))
fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

fig.show()

: 

Beautiful! 🤌

In [None]:
# save as pdf
figures_path = Path('figures')

fig.write_image(str(figures_path / 'ecs_pruned_all_labels.pdf'))
fig.write_image(str(figures_path / 'ecs_pruned_all_labels.jpg'), scale=5)

: 

In [None]:
print(classwise_accuracy(unclustered_model, test_loader))

: 

In [None]:
print(classwise_accuracy(model, test_loader))

: 

In [None]:
fig = make_subplots(rows=2, cols=5, subplot_titles=[f'Label {i}' for i in range(10)])

for i in range(10):
    img, _ = test_dataset[i]  # Assuming test_dataset returns a tuple (image, label)
    img = img.permute(1, 2, 0).cpu().numpy()  # Change shape from (C, H, W) to (H, W, C)
    img = (img * 255).astype(np.uint8)  # Ensure pixel values are in the range [0, 255]
    
    fig.add_trace(go.Image(z=img), row=(i // 5) + 1, col=(i % 5) + 1)

fig.update_layout(height=600, width=1000, title_text='One image for each label in CIFAR10')
fig.show()

: 

Thank you! Feel free to contact Satvik (zsatvik@gmail.com) for stuff related to this codebase.