## Setup

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

def randomseed(seed):
    """Seed every random number generator used in this notebook.

    The same value is handed to Python's ``random`` module, NumPy's global
    generator, ``torch.manual_seed`` and ``torch.cuda.manual_seed``.  cuDNN
    is then put into deterministic mode with benchmarking switched off.

    Args:
        seed: Seed value passed unchanged to each seeding function.
    """
    # The seeders are independent of one another, so the order is irrelevant.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

class MLP(nn.Module):
    """Bias-free multilayer perceptron with two ReLU hidden layers.

    With the default arguments the layers are 784 -> 128 -> 128 -> 10, which
    is the architecture the rest of this notebook loads weights for.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of both hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=128, num_classes=10):
        super().__init__()
        # Attribute names and creation order are unchanged, so state dicts
        # keyed on fc1/fc2/fc3 still load and seeded initialisation is the same.
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.fc3 = nn.Linear(hidden_dim, num_classes, bias=False)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for input flattened to (N, input_dim)."""
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. a
        # transposed image batch, copying only when it has to.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [8]:
# Load the model weights and the cluster indices from disk.
# NOTE(review): torch.load unpickles the file contents, so only open files
# that come from a trusted source.
# map_location='cpu' lets files that were saved from a GPU session load on a
# CPU-only machine and keeps the cluster indices usable for NumPy indexing;
# the model itself is moved to `device` explicitly on the last line.
model = MLP()
model.load_state_dict(torch.load('model_cluster.pt', map_location='cpu'))
cluster_U_indices = torch.load('cluster_U_indices.pt', map_location='cpu')
cluster_V_indices = torch.load('cluster_V_indices.pt', map_location='cpu')
model.to(device)

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=False)
  (fc2): Linear(in_features=128, out_features=128, bias=False)
  (fc3): Linear(in_features=128, out_features=10, bias=False)
)

## Interpreting Clusters

In [25]:
def visualize_fc2_layer(layer, title="fc2 Layer Visualization", width_scale=1):
    """Draw a linear layer as a bipartite graph of its input and output units.

    Input units are placed on the bottom row (y = -1) and output units on the
    top row (y = 1), each row spread evenly over x in [-1, 1].  Every weight
    becomes one line between its input and its output unit: red when the
    weight is negative, blue otherwise, with line width
    ``abs(weight) * width_scale``.  The figure is displayed with
    ``fig.show()``; nothing is returned.

    Args:
        layer: Object exposing a 2-D ``weight`` tensor laid out as
            (outputs, inputs), such as ``nn.Linear``.
        title: Figure title.  The default keeps the original fc2 wording;
            pass a different title when plotting another layer.
        width_scale: Factor applied to ``abs(weight)`` to get the line width.
    """
    weights = layer.weight.detach().cpu().numpy()
    num_outputs, num_inputs = weights.shape

    input_xs = np.linspace(-1, 1, num_inputs)
    output_xs = np.linspace(-1, 1, num_outputs)

    fig = go.Figure()

    # Nodes: all inputs first, then all outputs, labelled starting from 1.
    fig.add_trace(go.Scatter(
        x=list(input_xs) + list(output_xs),
        y=[-1] * num_inputs + [1] * num_outputs,
        mode='markers+text',
        text=[f"I{i+1}" for i in range(num_inputs)] + [f"O{j+1}" for j in range(num_outputs)],
        textposition='top center',
        marker=dict(size=10, color='lightgray'),
        showlegend=False
    ))

    # Edges: a scatter trace has a single line width, so every weight gets
    # its own trace (inputs in the outer loop, outputs in the inner loop).
    for i in range(num_inputs):
        for j in range(num_outputs):
            weight = weights[j, i]
            fig.add_trace(go.Scatter(
                x=[input_xs[i], output_xs[j]],
                y=[-1, 1],
                mode='lines',
                line=dict(
                    color='red' if weight < 0 else 'blue',
                    width=abs(weight) * width_scale,
                ),
                showlegend=False
            ))

    fig.update_layout(
        title=title,
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False, range=[-1.5, 1.5]),
        showlegend=False,
        plot_bgcolor='white',
        # Fixed 1200x900 (4:3) canvas instead of automatic sizing.
        autosize=False,
        width=1200,
        height=900,
    )

    fig.show()

In [26]:
def reorder_network(model, cluster_U_indices, cluster_V_indices):
    """Permute the rows and columns of ``model.fc2.weight`` cluster by cluster.

    The clusters of ``cluster_U_indices`` are concatenated (in dict order) to
    give the new row order and those of ``cluster_V_indices`` to give the new
    column order: new row ``i`` is old row ``U[i]`` and new column ``j`` is
    old column ``V[j]``.  Rows or columns beyond the number of supplied
    indices are left as zeros.

    The weight is overwritten in place, keeping the parameter's device and
    dtype.  No other layer of ``model`` is modified, so the returned model is
    not guaranteed to compute the same function as before.

    Args:
        model: Module with an ``fc2`` layer holding a 2-D ``weight``.
        cluster_U_indices: Mapping whose values are iterables of row indices.
        cluster_V_indices: Mapping whose values are iterables of column indices.

    Returns:
        The same ``model`` object, with ``fc2.weight`` reordered.

    Raises:
        IndexError: If more indices are supplied than there are rows or
            columns, or if an index is out of range.
    """
    weight_param = model.fc2.weight
    weights = weight_param.detach().cpu().numpy()
    num_rows, num_cols = weights.shape

    # int() accepts plain ints as well as 0-d tensors and NumPy integers.
    U_indices = np.asarray(
        [int(idx) for cluster in cluster_U_indices.values() for idx in cluster],
        dtype=np.intp,
    )
    V_indices = np.asarray(
        [int(idx) for cluster in cluster_V_indices.values() for idx in cluster],
        dtype=np.intp,
    )
    if len(U_indices) > num_rows or len(V_indices) > num_cols:
        raise IndexError(
            f"got {len(U_indices)} row and {len(V_indices)} column indices "
            f"for a weight matrix of shape {weights.shape}"
        )

    # Rows first, then columns of the row-permuted matrix.
    row_ordered = np.zeros_like(weights)
    row_ordered[:len(U_indices), :] = weights[U_indices, :]
    reordered_weights = np.zeros_like(weights)
    reordered_weights[:, :len(V_indices)] = row_ordered[:, V_indices]

    # copy_ writes into the existing parameter, so a model that already lives
    # on the GPU (or is not float32) keeps its device and dtype.
    with torch.no_grad():
        weight_param.copy_(torch.from_numpy(reordered_weights))

    return model

In [None]:
# reorder_network rewrites model.fc2.weight in place and returns the same
# model object, so running this cell again permutes the already-permuted
# weights. Re-run the loading cell before re-running this one.
model = reorder_network(model, cluster_U_indices, cluster_V_indices)

In [27]:
# Plot the fc2 weights as a bipartite graph: input units on the bottom row,
# output units on the top row, one line per weight.
visualize_fc2_layer(model.fc2)

Nice!

## Mechanistic Interpretability

In [20]:
# Display the module repr to recall the layer sizes.
model

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=False)
  (fc2): Linear(in_features=128, out_features=128, bias=False)
  (fc3): Linear(in_features=128, out_features=10, bias=False)
)

In [21]:
# Shape of the output layer's weight matrix, laid out as (out_features, in_features).
model.fc3.weight.data.shape

torch.Size([10, 128])

In [28]:
# Reuse the same plotting helper for the output layer fc3.
# NOTE(review): the figure produced by this call is still titled
# "fc2 Layer Visualization", since that is the title the helper uses here.
visualize_fc2_layer(model.fc3)