## Clustering

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [91]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128, bias=False)
        self.fc2 = nn.Linear(128, 128, bias=False)
        self.fc3 = nn.Linear(128, 10, bias=False)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [97]:
def accuracy(model, data):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data:
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct / total

In [98]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # normalize the data
])

train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
test_dataset = MNIST(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [99]:
model = MLP()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
train_losses = []

In [100]:
for epoch in range(5):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        if batch_idx % 400 == 0:
            acc = accuracy(model, test_loader)
            print(f'Epoch {epoch+1}/{5}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}')

Epoch 1/5, Batch 0/938, Loss: 2.2910, Accuracy: 0.1211
Epoch 1/5, Batch 400/938, Loss: 2.1990, Accuracy: 0.3780
Epoch 1/5, Batch 800/938, Loss: 2.0073, Accuracy: 0.5130
Epoch 2/5, Batch 0/938, Loss: 1.9565, Accuracy: 0.5420
Epoch 2/5, Batch 400/938, Loss: 1.7545, Accuracy: 0.6396
Epoch 2/5, Batch 800/938, Loss: 1.3991, Accuracy: 0.7526
Epoch 3/5, Batch 0/938, Loss: 1.2213, Accuracy: 0.7728
Epoch 3/5, Batch 400/938, Loss: 0.8774, Accuracy: 0.8091
Epoch 3/5, Batch 800/938, Loss: 0.7369, Accuracy: 0.8320
Epoch 4/5, Batch 0/938, Loss: 0.7429, Accuracy: 0.8351
Epoch 4/5, Batch 400/938, Loss: 0.6001, Accuracy: 0.8506
Epoch 4/5, Batch 800/938, Loss: 0.6224, Accuracy: 0.8591
Epoch 5/5, Batch 0/938, Loss: 0.4617, Accuracy: 0.8617
Epoch 5/5, Batch 400/938, Loss: 0.4582, Accuracy: 0.8702
Epoch 5/5, Batch 800/938, Loss: 0.5737, Accuracy: 0.8776


In [101]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=train_losses, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Optimization Step')
fig.update_yaxes(title_text='CrossEntropy Loss')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [102]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=False)
  (fc2): Linear(in_features=128, out_features=128, bias=False)
  (fc3): Linear(in_features=128, out_features=10, bias=False)
)

In [103]:
import numpy as np
import numpy as np
from sklearn.cluster import KMeans
from scipy.sparse.linalg import svds

In [106]:
def spectral_clustering(model, k):
    
    A = model.fc2.weight.detach().cpu().numpy()
    A = np.abs(A)
    
    D_U = np.diag(np.sum(A, axis=1))
    D_V = np.diag(np.sum(A, axis=0))

    D_U_inv_sqrt = np.linalg.inv(np.sqrt(D_U))
    D_V_inv_sqrt = np.linalg.inv(np.sqrt(D_V))

    A_tilde = D_U_inv_sqrt @ A @ D_V_inv_sqrt

    U, Sigma, Vt = svds(A_tilde, k=k)

    kmeans_U = KMeans(n_clusters=k, random_state=42).fit(U)
    kmeans_V = KMeans(n_clusters=k, random_state=42).fit(Vt.T)

    labels_U = kmeans_U.labels_
    labels_V = kmeans_V.labels_

    # convert labels to indices
    cluster_U_indices = defaultdict(list)
    cluster_V_indices = defaultdict(list)
    for i, label in enumerate(labels_U):
        cluster_U_indices[label].append(i)
    for i, label in enumerate(labels_V):
        cluster_V_indices[label].append(i)

    return cluster_U_indices, cluster_V_indices

In [113]:
num_clusters = 2
cluster_U_indices, cluster_V_indices = spectral_clustering(model, num_clusters)

In [114]:
for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

Cluster 0 has 65 nodes in U and 80 nodes in V
Cluster 1 has 63 nodes in U and 48 nodes in V


## How good is a cluster?

Now, we need a (possibly information-theoretic) way to evaluate how good this cluster is.