In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

from sklearn.cluster import KMeans
from scipy.sparse.linalg import svds

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

def randomseed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [2]:
dataset = 'CIFAR10' # 'MNIST' or 'CIFAR10'

if dataset == 'MNIST':
    transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
    test_dataset = MNIST(root='.', train=False, download=True, transform=transform)
elif dataset == 'CIFAR10':
    transform = torchvision.transforms.ToTensor()
    train_dataset = CIFAR10(root='.', train=True, download=True, transform=transform)
    test_dataset = CIFAR10(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# for MNIST

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64, bias=False)
        self.fc2 = nn.Linear(64, 64, bias=False)
        self.fc3 = nn.Linear(64, 64, bias=False) # added
        self.fc4 = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
# for CIFAR10

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)  # Adjusted the input size here
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.bn1(self.conv2(x)))
        x = F.max_pool2d(x, 2)  # Reduces dimensions by half (32x32 -> 16x16)
        x = F.max_pool2d(x, 2)  # Reduces dimensions further (16x16 -> 8x8)
        x = x.view(-1, 64 * 8 * 8)  # Flatten properly
        x = F.relu(self.fc1(x))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x
    
def new_model(dataset, device):
    if dataset == 'MNIST':
        model = MLP()
    elif dataset == 'CIFAR10':
        model = CNN()
    model = model.to(device)
    return model        

In [7]:
def accuracy(model, data):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct / total

def classwise_accuracy(model, data):
    model.eval()
    correct = defaultdict(int)
    total = defaultdict(int)
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            for i in range(len(labels)):
                label = labels[i].item()
                total[label] += 1
                correct[label] += int(predicted[i] == label)
    return [round(correct[i] / total[i], 3) if total[i] > 0 else 0 for i in range(10)]

def clusterability(matrix, auto_index=True, cluster_U_indices=None, cluster_V_indices=None, num_clusters=4):

    if auto_index:
        cluster_size = (matrix.shape[0] // num_clusters, matrix.shape[1] // num_clusters)
        cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
        cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}

    num_clusters = len(cluster_U_indices)
    A = matrix ** 2
    mask = torch.zeros_like(A, dtype=torch.bool)
    
    for cluster_idx in range(num_clusters):
        u_indices = torch.tensor(cluster_U_indices[cluster_idx], dtype=torch.long)
        v_indices = torch.tensor(cluster_V_indices[cluster_idx], dtype=torch.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = torch.sum(A[mask])
    total_out_sum = torch.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

In [5]:
dataset, device

('CIFAR10', device(type='cuda', index=0))

In [16]:
unclustered_model = new_model(dataset, device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(unclustered_model.parameters(), lr=1e-3)
train_losses = []

In [17]:
randomseed(42)
path = Path(f'checkpoints/{dataset}/')
path.mkdir(parents=True, exist_ok=True)

In [18]:
# print starting accuracy and loss
acc = accuracy(unclustered_model, test_loader)
print(f'Starting Accuracy: {acc:.4f}')

for epoch in range(10):
    unclustered_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = unclustered_model(data)
        train_loss = criterion(output, target)
        train_losses.append(train_loss.item())

        ## CLUSTERABILITY

        fc1_c = clusterability(unclustered_model.fc1.weight)
        fc2_c = clusterability(unclustered_model.fc2.weight)
        fc3_c = clusterability(unclustered_model.fc3.weight)

        cluster_loss = (fc1_c + fc2_c + fc3_c) / 3

        ## END CLUSTERABILITY

        loss = train_loss - (20 * cluster_loss)

        loss.backward()
        optimizer.step()
    acc = accuracy(unclustered_model, test_loader)
    print(f'Epoch {epoch+1}/{10}, Train Loss: {train_loss.item():.4f}, Accuracy: {acc:.4f}, Clusterability: {cluster_loss.item():.4f}')
    # save model
torch.save(unclustered_model.state_dict(), path / 'fc_clustered_model.pth')

Starting Accuracy: 0.1000
Epoch 1/10, Train Loss: 1.0746, Accuracy: 0.4836, Clusterability: 0.9955
Epoch 2/10, Train Loss: 1.3663, Accuracy: 0.6238, Clusterability: 0.9957
Epoch 3/10, Train Loss: 1.1802, Accuracy: 0.6410, Clusterability: 0.9963
Epoch 4/10, Train Loss: 1.2017, Accuracy: 0.6479, Clusterability: 0.9965
Epoch 5/10, Train Loss: 1.0795, Accuracy: 0.6648, Clusterability: 0.9968
Epoch 6/10, Train Loss: 0.8805, Accuracy: 0.6652, Clusterability: 0.9968
Epoch 7/10, Train Loss: 0.7450, Accuracy: 0.6826, Clusterability: 0.9970
Epoch 8/10, Train Loss: 0.5133, Accuracy: 0.7020, Clusterability: 0.9972
Epoch 9/10, Train Loss: 0.7286, Accuracy: 0.7090, Clusterability: 0.9973
Epoch 10/10, Train Loss: 1.4023, Accuracy: 0.7166, Clusterability: 0.9975


In [10]:
# load model
unclustered_model.load_state_dict(torch.load(path / 'unclustered_model.pth'))

  unclustered_model.load_state_dict(torch.load(path / 'unclustered_model.pth'))


<All keys matched successfully>

In [11]:
unclustered_model.fc1.weight.shape, unclustered_model.fc2.weight.shape, unclustered_model.fc3.weight.shape

(torch.Size([128, 4096]), torch.Size([64, 128]), torch.Size([10, 64]))

In [12]:
# clusterability of unclustered model
clusterability(unclustered_model.fc1.weight), clusterability(unclustered_model.fc2.weight), clusterability(unclustered_model.fc3.weight)

(tensor(0.2404, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2502, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2164, device='cuda:0', grad_fn=<DivBackward0>))

In [29]:
num_clusters = 4
block = unclustered_model.fc1.weight
cluster_size = (block.shape[0] // num_clusters, block.shape[1] // num_clusters)
cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}

In [30]:
clusterability(block, cluster_U_indices, cluster_V_indices)

tensor(0.2502, device='cuda:0', grad_fn=<DivBackward0>)

In [31]:
unclustered_model = new_model(dataset, device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(unclustered_model.parameters(), lr=1e-3)
train_losses = []
cluster_losses = []

In [35]:
randomseed(42)
path = Path(f'checkpoints/{dataset}/')
path.mkdir(parents=True, exist_ok=True)
unclustered_model = new_model(dataset, device)

# print starting accuracy and loss
acc = accuracy(unclustered_model, test_loader)
print(f'Starting Accuracy: {acc:.4f}')
# cluster_loss = clusterability(unclustered_model.fc1.weight, cluster_U_indices, cluster_V_indices)
# print(f'Starting Cluster Loss: {cluster_loss:.4f}')

for epoch in range(10):
    unclustered_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = unclustered_model(data)
        train_loss = criterion(output, target)
        train_losses.append(train_loss.item())
        cluster_loss = clusterability(unclustered_model.fc1.weight, cluster_U_indices, cluster_V_indices)
        cluster_losses.append(cluster_loss.item())
        loss = train_loss - (20 * cluster_loss)
        loss.backward()
        optimizer.step()
    acc = accuracy(unclustered_model, test_loader)
    print(f'Epoch {epoch+1}/{10}, Train Loss: {train_losses[-1].item():.4f}, Accuracy: {acc:.4f}, Cluster Loss: {cluster_losses[-1].item():.4f}')
    # save model
torch.save(unclustered_model.state_dict(), path / 'fc1_clustered_model.pth')

Starting Accuracy: 0.1000


AttributeError: 'float' object has no attribute 'item'