In [None]:
!pip install torch torchvision scikit-learn matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, davies_bouldin_score, adjusted_rand_score
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt

# Seed the RNGs this notebook draws from (torch: weight init and DataLoader
# shuffling; numpy: sampling) so Restart & Run All reproduces the same run,
# up to nondeterministic GPU kernels. scikit-learn calls pass random_state.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

In [None]:
# MNIST digits converted to float tensors by ToTensor (torchvision scales
# 8-bit images to [0, 1]); no other preprocessing is applied.
transform = transforms.ToTensor()
mnist_kwargs = dict(root='data', download=True, transform=transform)
train_ds = datasets.MNIST(train=True, **mnist_kwargs)
test_ds  = datasets.MNIST(train=False, **mnist_kwargs)

# Training batches are reshuffled every epoch; the test loader keeps a fixed order.
batch_size = 256
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

In [None]:
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for 1x28x28 images with a dense latent code.

    Encoder: two (Conv3x3 -> ReLU -> BatchNorm -> MaxPool2) stages map the
    input to 64x7x7, then a linear layer projects it to ``latent_dim``.
    Decoder: a linear layer back to 64x7x7, then two stride-2 transposed
    convolutions up to 1x28x28, finished by a Sigmoid.

    Args:
        latent_dim: size of the bottleneck code z.

    forward(x) returns ``(reconstruction, z)``.
    """

    def __init__(self, latent_dim=10):
        super().__init__()
        # NOTE: the attribute names and Sequential positions below define the
        # state_dict keys, and `enc` / `fc_enc` are reused by name elsewhere.
        self.enc = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2),   # 28x28 -> 14x14
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2),   # 14x14 -> 7x7
        )
        flat_dim = 64 * 7 * 7
        self.fc_enc = nn.Linear(flat_dim, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, flat_dim)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 7 -> 14
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),   # 14 -> 28
            nn.Sigmoid(),
        )

    def forward(self, x):
        feats = self.enc(x)                              # (B, 64, 7, 7)
        z = self.fc_enc(feats.view(feats.size(0), -1))   # (B, latent_dim)
        grid = self.fc_dec(z).view(-1, 64, 7, 7)         # (B, 64, 7, 7)
        return self.dec(grid), z                         # (B, 1, 28, 28), (B, latent_dim)

# Instantiate the autoencoder and report its trainable parameter count
latent_dim = 10
ae = ConvAutoencoder(latent_dim).to(device)
n_trainable = sum(param.numel() for param in ae.parameters() if param.requires_grad)
print("Trainable params:", n_trainable)

# Pre-train the autoencoder on pixel-wise reconstruction (MSE)
criterion = nn.MSELoss()
optimizer = optim.Adam(ae.parameters(), lr=1e-3, weight_decay=1e-5)
epochs_ae = 20

for epoch in range(1, epochs_ae + 1):
    ae.train()
    loss_sum = 0
    for batch, _ in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        reconstruction, _ = ae(batch)
        batch_loss = criterion(reconstruction, batch)
        batch_loss.backward()
        optimizer.step()
        # MSELoss averages over the batch; re-weight by batch size so dividing
        # by the dataset size gives a dataset-wide mean (last batch may be smaller).
        loss_sum += batch_loss.item() * batch.size(0)
    print(f"AE Epoch {epoch}/{epochs_ae} — Loss: {loss_sum/len(train_ds):.4f}")

In [None]:
# Encode every training image with the pre-trained encoder. eval() makes
# BatchNorm use its running statistics; sample order is irrelevant for
# K-Means, so the shuffled train_loader is fine here.
ae.eval()
all_z = []
with torch.no_grad():
    for imgs, _ in train_loader:
        imgs = imgs.to(device)
        _, z = ae(imgs)
        all_z.append(z.cpu().numpy())
all_z = np.vstack(all_z)  # shape (60000, latent_dim)

# K-Means to get initial μ. Use several restarts (n_init=20, as in the DEC
# paper): since scikit-learn 1.4 the default ('auto') is a single k-means++
# run, and DEC has a hard time recovering from a poor initial μ.
n_clusters = 10
km = KMeans(n_clusters=n_clusters, n_init=20, random_state=0).fit(all_z)
mu = torch.tensor(km.cluster_centers_, dtype=torch.float, device=device)  # (n_clusters, latent_dim)

In [None]:
class DEC(nn.Module):
    """Deep Embedded Clustering head on top of a pre-trained encoder.

    Args:
        conv_encoder: feature extractor; its output is flattened per sample
            before being passed to ``fc_enc``.
        fc_enc: module mapping the flattened features to the latent code z.
        mu: initial cluster centres, shape (n_clusters, latent_dim); wrapped
            in an ``nn.Parameter`` so the optimizer updates them.
        alpha: degrees of freedom of the Student's t kernel.

    forward(x) returns ``(z, q)`` where ``q[i, j]`` is the soft assignment of
    sample i to centre j; every row of q sums to 1.
    """

    def __init__(self, conv_encoder, fc_enc, mu, alpha=1.0):
        super().__init__()
        self.conv_encoder = conv_encoder
        self.fc_enc = fc_enc
        self.mu = nn.Parameter(mu)
        self.alpha = alpha

    def forward(self, x):
        features = self.conv_encoder(x)
        z = self.fc_enc(features.view(features.size(0), -1))   # (B, latent_dim)
        # Squared Euclidean distance from every sample to every centre: (B, K)
        offsets = z[:, None, :] - self.mu[None, :, :]
        sq_dist = offsets.pow(2).sum(dim=2)
        # Student's t kernel, then normalise each row into a distribution
        kernel = (1.0 + sq_dist / self.alpha).pow(-(self.alpha + 1) / 2)
        q = kernel / kernel.sum(dim=1, keepdim=True)
        return z, q

# Target distribution P
def target_dist(q):
    """Sharpened self-training target P derived from soft assignments q.

    p_ij is proportional to q_ij**2 / f_j, where f_j = sum_i q_ij is the soft
    size of cluster j over the rows of q. Each row of the result sums to 1.
    """
    cluster_freq = q.sum(dim=0, keepdim=True)
    sharpened = q.pow(2) / cluster_freq
    return sharpened / sharpened.sum(dim=1, keepdim=True)

# DEC wraps the pre-trained encoder modules themselves (shared with `ae`, not
# copied), so optimising `dec` also updates `ae.enc` / `ae.fc_enc`.
dec = DEC(conv_encoder=ae.enc, fc_enc=ae.fc_enc, mu=mu).to(device)
optimizer_dec = optim.Adam(params=dec.parameters(), lr=1e-3)

In [None]:
epochs_dec = 30
# DEC stopping rule: stop once fewer than `tol` (0.1%) of the training points
# change their hard cluster assignment between consecutive checks.
tol = 0.001

# Assignments are compared sample-by-sample, so they come from a fixed-order
# loader. The shuffled train_loader visits samples in a new order every epoch,
# so its batches cannot be compared row-by-row across epochs.
assign_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)


def hard_assignments(model, loader):
    """Return argmax_j q_ij for every sample in `loader`, in loader order.

    Runs without gradients in eval mode and leaves `model` in eval mode; the
    training loop switches back with `.train()` at the start of each epoch.
    """
    model.eval()
    labels = []
    with torch.no_grad():
        for imgs, _ in loader:
            _, q = model(imgs.to(device))
            labels.append(q.argmax(dim=1).cpu())
    return torch.cat(labels).numpy()


# Baseline: nearest initial centre for every point, before any DEC update
prev_labels = hard_assignments(dec, assign_loader)

for ep in range(epochs_dec):
    dec.train()
    total_loss = 0
    for imgs, _ in train_loader:
        imgs = imgs.to(device)
        _, q = dec(imgs)
        # P is built from the current mini-batch's Q and detached, so it acts as
        # a fixed self-training target for this step. (The DEC paper recomputes
        # P over the whole dataset every few iterations; per-batch P approximates it.)
        p = target_dist(q).detach()
        # kl_div takes log-probabilities as input and probabilities as target: KL(P || Q)
        loss = torch.nn.functional.kl_div(q.log(), p, reduction='batchmean')
        optimizer_dec.zero_grad()
        loss.backward()
        optimizer_dec.step()
        total_loss += loss.item()*imgs.size(0)
    print(f"DEC Epoch {ep+1}/{epochs_dec} — KL Loss: {total_loss/len(train_ds):.4f}")

    # Fraction of training points whose hard cluster changed during this epoch
    curr_labels = hard_assignments(dec, assign_loader)
    delta_label = np.mean(curr_labels != prev_labels)
    prev_labels = curr_labels
    if delta_label < tol:
        print(f"Converged at epoch {ep+1}")
        break

In [None]:
# Extract final embeddings & hard assignments. train_loader is shuffled, but
# embeddings, soft assignments and true labels are appended from the same
# batches, so the three arrays stay row-aligned.
dec.eval()
embeds, qs = [], []
labels_true = []
with torch.no_grad():
    for imgs, lbls in train_loader:
        imgs = imgs.to(device)
        z, q = dec(imgs)
        embeds.append(z.cpu().numpy())
        qs.append(q.cpu().numpy())
        labels_true.append(lbls.numpy())
embeds = np.vstack(embeds)
qs     = np.vstack(qs)
labels_true = np.concatenate(labels_true)

# Hard cluster = argmax_j q_ij
labels_dec = np.argmax(qs, axis=1)

# Metrics. ARI compares against the ground-truth digits. Silhouette and DBI
# are internal indices that raise ValueError when fewer than two clusters are
# used, which happens if DEC collapses onto a subset of the centres.
n_used = len(np.unique(labels_dec))
ari = adjusted_rand_score(labels_true, labels_dec)
if n_used > 1:
    sil = silhouette_score(embeds, labels_dec)
    db  = davies_bouldin_score(embeds, labels_dec)
    print(f"DEC Clustering — Silhouette: {sil:.3f}, DBI: {db:.3f}, ARI: {ari:.3f}")
else:
    print(f"DEC Clustering — only {n_used} cluster used, ARI: {ari:.3f}")
if n_used < n_clusters:
    print(f"Note: {n_clusters - n_used} of {n_clusters} clusters received no points")

# t-SNE on all ~60k embeddings is very slow, so embed a fixed random subset.
# Set tsne_max_points = None to embed every point.
tsne_max_points = 10_000
if tsne_max_points is not None and len(embeds) > tsne_max_points:
    tsne_idx = np.random.default_rng(0).choice(len(embeds), size=tsne_max_points, replace=False)
else:
    tsne_idx = np.arange(len(embeds))
tsne = TSNE(n_components=2, random_state=0, perplexity=30)
emb2d = tsne.fit_transform(embeds[tsne_idx])

fig, ax = plt.subplots(figsize=(8, 6))
# Pin the colour range to the cluster ids so colours and colorbar ticks stay
# aligned even when some clusters are empty (otherwise the norm follows the data).
sc = ax.scatter(emb2d[:, 0], emb2d[:, 1], c=labels_dec[tsne_idx], cmap='tab10', s=10,
                vmin=-0.5, vmax=n_clusters - 0.5)
fig.colorbar(sc, ax=ax, ticks=range(n_clusters), label='DEC cluster')
ax.set_title('t-SNE of DEC Embeddings')
ax.set_xlabel('t-SNE 1')
ax.set_ylabel('t-SNE 2')
plt.show()