# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, normalized_mutual_info_score, adjusted_rand_score
from sklearn.metrics import davies_bouldin_score, calinski_harabasz_score  # Fixed import
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split  # Added for validation split

# Reproducibility: seed both RNGs, then use the GPU when one is available.
np.random.seed(42)
torch.manual_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

data = load_dataset("mnist")


def _split_to_unit_images(split):
    """Return the images of ``data[split]`` as an (N, 28, 28) array divided by 255.0."""
    pixels = [np.array(img).reshape(28, 28) for img in data[split]["image"]]
    return np.array(pixels) / 255.0


# Preprocess: image arrays plus their label vectors for both splits.
train_images = _split_to_unit_images("train")
test_images = _split_to_unit_images("test")
train_labels = np.array(data["train"]["label"])
test_labels = np.array(data["test"]["label"])

# Insert the channel axis so the tensors are shaped (N, 1, 28, 28).
train_tensor = torch.tensor(train_images, dtype=torch.float32).unsqueeze(1)
test_tensor = torch.tensor(test_images, dtype=torch.float32).unsqueeze(1)

# Hold out 20% of the training images for validation (fixed split seed).
sample_indices = range(len(train_tensor))
train_idx, val_idx = train_test_split(sample_indices, test_size=0.2, random_state=42)
train_subset, val_tensor = train_tensor[train_idx], train_tensor[val_idx]

# Loaders: only the training loader shuffles.
batch_size = 128
train_loader = DataLoader(
    TensorDataset(train_subset), batch_size=batch_size, shuffle=True
)
val_loader = DataLoader(
    TensorDataset(val_tensor), batch_size=batch_size, shuffle=False
)

# Model
class TripletAutoencoderCNN(nn.Module):
    """Convolutional autoencoder that also exposes an L2-normalised latent code.

    The encoder applies two stride-2 convolutions to a single-channel 28x28
    input (28x28 -> 14x14 -> 7x7) followed by a linear projection to
    ``latent_dim`` features. The decoder mirrors this with a linear layer and
    two transposed convolutions, ending in a sigmoid.
    """

    def __init__(self, latent_dim=64):
        """Build the encoder and decoder stacks.

        Args:
            latent_dim: Size of the latent vector produced by the encoder.
        """
        super().__init__()
        bottleneck_shape = (64, 7, 7)
        flat_features = 64 * 7 * 7

        # Layers are created in the same order as they are applied so that
        # seeded weight initialisation is reproducible.
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # -> 14x14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # -> 7x7
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, bottleneck_shape),
            nn.ConvTranspose2d(
                64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x``, normalise the code, and decode it.

        Returns:
            A ``(reconstruction, latent)`` tuple, where ``latent`` has unit
            L2 norm along dimension 1 and is what the decoder receives.
        """
        latent = F.normalize(self.encoder(x), p=2, dim=1)
        return self.decoder(latent), latent

# Network on the selected device, its reconstruction criterion, and the
# Adam optimizer that is shared by both training phases.
model = TripletAutoencoderCNN()
model = model.to(device)
recon_loss = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Triplet mining (unsupervised)
def mine_triplets(data_tensor, model, n_samples=10000, embed_batch_size=1024):
    """Mine (anchor, positive, negative) triplets without using labels.

    Every sample in ``data_tensor`` is embedded with ``model`` (which is
    switched to eval mode and left in it). The first ``n_samples`` samples
    then act as anchors, and for each anchor:

    * the positive is the sample with the smallest non-zero embedding
      distance to the anchor;
    * the negative is the farthest sample among those whose distance
      exceeds 0.5.

    Anchors with no sample farther than 0.5 away are skipped.

    Args:
        data_tensor: Tensor of samples; the first dimension indexes samples.
        model: Module whose forward pass returns ``(reconstruction, embedding)``.
        n_samples: Maximum number of anchors. Clamped to ``len(data_tensor)``
            so a short tensor no longer raises ``IndexError``.
        embed_batch_size: Number of samples embedded per forward pass, which
            bounds peak memory during embedding.

    Returns:
        A list of ``(anchor, positive, negative)`` tuples of tensors taken
        from ``data_tensor``. Empty when ``data_tensor`` is empty or no
        anchor has a valid negative.
    """
    model.eval()

    # Run on the device the model lives on rather than on a module global.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cpu")

    chunks = []
    with torch.no_grad():
        for start in range(0, len(data_tensor), embed_batch_size):
            batch = data_tensor[start:start + embed_batch_size].to(target_device)
            _, batch_embeddings = model(batch)
            chunks.append(batch_embeddings.cpu())
    if not chunks:
        return []
    embeddings = torch.cat(chunks).numpy()

    n_anchors = min(n_samples, len(embeddings))
    triplets = []
    for i in range(n_anchors):
        dists = np.linalg.norm(embeddings - embeddings[i], axis=1)
        # Mask zero distances so the anchor (and exact duplicates) cannot be
        # selected as its own positive.
        pos_idx = np.argmin(np.where(dists == 0, np.inf, dists))
        neg_pool = np.flatnonzero(dists > 0.5)
        if neg_pool.size == 0:
            continue
        neg_idx = neg_pool[np.argmax(dists[neg_pool])]
        triplets.append((data_tensor[i], data_tensor[pos_idx], data_tensor[neg_idx]))
    return triplets

# Triplet Loss
def triplet_loss(a, p, n, margin=1.0):
    """Return the batch-mean hinge triplet loss.

    For each row, the loss is ``max(0, ||a - p|| - ||a - n|| + margin)`` with
    norms taken along dimension 1; the result is the mean over rows.
    """
    pos_dist = (a - p).norm(dim=1)
    neg_dist = (a - n).norm(dim=1)
    hinge = F.relu(pos_dist - neg_dist + margin)
    return hinge.mean()

# Phase 1: reconstruction-only training of the autoencoder
for epoch in range(12):
    # Optimisation pass over the training split.
    model.train()
    total_train_loss = 0
    for batch in train_loader:
        inputs = batch[0].to(device)
        optimizer.zero_grad()
        recon, _ = model(inputs)
        loss = recon_loss(recon, inputs)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # Gradient-free pass over the validation split.
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch[0].to(device)
            recon = model(inputs)[0]
            total_val_loss += recon_loss(recon, inputs).item()

    # Both totals are sums of per-batch means, so they are averaged over the
    # number of batches.
    print(f"Epoch {epoch+1}, Train Recon Loss: {total_train_loss/len(train_loader):.4f}, Val Recon Loss: {total_val_loss/len(val_loader):.4f}")

# Phase 2: Triplet training
# Triplets are mined once, up front, from the embeddings of the model as it
# stands after Phase 1.
triplets = mine_triplets(train_subset[:20000], model)  # Use train_subset
val_triplets = mine_triplets(val_tensor[:5000], model, n_samples=2000)  # Validation triplets
if not triplets or not val_triplets:
    # Without this guard the per-epoch averages below would divide by zero.
    raise RuntimeError("Triplet mining returned no triplets; cannot run triplet training.")


def _make_triplet_loader(mined, shuffle):
    """Stack a list of (anchor, positive, negative) tensors into a batched DataLoader."""
    anchors, positives, negatives = (torch.stack(column) for column in zip(*mined))
    dataset = TensorDataset(anchors, positives, negatives)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Train on mini-batches rather than one triplet per optimizer step: in train
# mode the BatchNorm layers then normalise with statistics of a real batch
# instead of a single image, and each epoch needs far fewer optimizer steps.
triplet_train_loader = _make_triplet_loader(triplets, shuffle=True)
triplet_val_loader = _make_triplet_loader(val_triplets, shuffle=False)

for epoch in range(5):
    model.train()
    total_train_loss = 0
    for a, p, n in triplet_train_loader:
        a, p, n = a.to(device), p.to(device), n.to(device)
        _, a_z = model(a)
        _, p_z = model(p)
        _, n_z = model(n)
        loss = triplet_loss(a_z, p_z, n_z)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # triplet_loss returns a batch mean; weight it by the batch size so
        # the epoch figure is a true per-triplet average.
        total_train_loss += loss.item() * a.size(0)

    # Validation step
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for a, p, n in triplet_val_loader:
            a, p, n = a.to(device), p.to(device), n.to(device)
            _, a_z = model(a)
            _, p_z = model(p)
            _, n_z = model(n)
            total_val_loss += triplet_loss(a_z, p_z, n_z).item() * a.size(0)
    print(f"Triplet Epoch {epoch+1}, Train Triplet Loss: {total_train_loss/len(triplets):.4f}, Val Triplet Loss: {total_val_loss/len(val_triplets):.4f}")

# Evaluate: embed the test set, cluster the embeddings, and report metrics
model.eval()
with torch.no_grad():
    z_test = model(test_tensor.to(device))[1].cpu().numpy()

kmeans = KMeans(n_clusters=10, random_state=42)
preds = kmeans.fit_predict(z_test)

# The first three indices use only the embeddings and cluster ids; NMI and
# ARI compare the cluster ids against test_labels.
print("\nEvaluation on MNIST:")
print(f"Silhouette Score: {silhouette_score(z_test, preds):.4f}")
print(f"Davies-Bouldin Index: {davies_bouldin_score(z_test, preds):.4f}")
print(f"Calinski-Harabasz Index: {calinski_harabasz_score(z_test, preds):.4f}")
print(f"NMI: {normalized_mutual_info_score(test_labels, preds):.4f}")
print(f"ARI: {adjusted_rand_score(test_labels, preds):.4f}")

# t-SNE visualization (full test set), colored by KMeans cluster id
tsne = TSNE(n_components=2, random_state=42)
z_2d = tsne.fit_transform(z_test)
tsne_x, tsne_y = z_2d[:, 0], z_2d[:, 1]
sns.scatterplot(x=tsne_x, y=tsne_y, hue=preds, palette="tab10", s=10)
plt.title("t-SNE of Clustered Embeddings (Full Test Set)")
plt.show()
# -------------------------------
# Parameter Count
# -------------------------------
def count_parameters(model):
    """Return the total element count of ``model``'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Count the parameters of `model` that require gradients and print the total.
total_params = count_parameters(model)
print(f"\nTotal Trainable Parameters: {total_params}")
