In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset
from datasets import load_dataset
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from minisom import MiniSom

# Seed both PyTorch and NumPy so that weight initialisation, PyTorch-side
# sampling and the NumPy noise added in load_fashion_mnist_dataset are
# repeatable from run to run.
torch.manual_seed(42)
np.random.seed(42)

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_fashion_mnist_dataset(max_samples=10000):
    """Load a flattened, lightly-noised subset of the Fashion-MNIST train split.

    Args:
        max_samples: Upper bound on the number of examples kept, applied with
            normal slice semantics (``None`` keeps everything).

    Returns:
        A ``(images, labels)`` pair: ``images`` is a ``torch.FloatTensor`` of
        flattened pixel rows scaled to ``[0, 1]`` with Gaussian noise
        (mean 0, std 0.02) added and clipped back to ``[0, 1]``; ``labels`` is
        a ``torch.LongTensor`` of the matching class ids.
    """
    dataset = load_dataset("zalando-datasets/fashion_mnist", split="train")

    # Subset *before* touching the pixels: previously every image of the split
    # was decoded (twice - once for pixels, once for labels) and only then
    # truncated.  Slicing a range keeps the exact ``[:max_samples]`` semantics.
    dataset = dataset.select(range(len(dataset))[:max_samples])

    # Single pass over the retained examples so each image is decoded once.
    pixel_rows = []
    label_values = []
    for example in dataset:
        pixel_rows.append(np.array(example["image"]).flatten())
        label_values.append(example["label"])

    images = np.array(pixel_rows) / 255.0
    labels = np.array(label_values, dtype=np.int64)

    # Small Gaussian perturbation, clipped so pixels stay in [0, 1].
    noise = np.random.normal(0, 0.02, images.shape)
    images = np.clip(images + noise, 0, 1)
    return torch.FloatTensor(images), torch.LongTensor(labels)

class ContrastiveDataset(Dataset):
    """Dataset yielding ``(anchor, positive, negative, anchor_label)`` triplets.

    The positive is a random other sample with the anchor's label (the anchor
    itself when no other sample shares its label).  The negative is a random
    sample with a different label; when more than 50 such samples exist, it is
    replaced half of the time by the closest (squared Euclidean distance in
    input space) of 50 randomly drawn negatives.

    All three images are returned flattened to 1-D so that they always share
    one shape.

    NOTE: a class with the same name is defined again further down this file;
    once the whole file has been executed, that later definition is the one
    bound to ``ContrastiveDataset``.

    Args:
        images: Tensor of samples, indexed along dim 0.
        labels: 1-D tensor of class labels aligned with ``images``.
        transform: Stored on the instance; not applied by ``__getitem__``.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        anchor_img = self.images[idx].reshape(-1)
        anchor_label = self.labels[idx]

        # Positive: any other sample of the same class, else the anchor itself.
        pos_indices = torch.where(self.labels == anchor_label)[0]
        pos_indices = pos_indices[pos_indices != idx]
        if len(pos_indices) == 0:
            pos_idx = idx
        else:
            pos_idx = pos_indices[torch.randint(0, len(pos_indices), (1,))]
        # Indexing with a 1-element tensor yields a leading dim of size 1;
        # flattening removes it so the positive matches the anchor's shape.
        pos_img = self.images[pos_idx].reshape(-1)

        # Negative: uniformly random sample of a different class.
        neg_indices = torch.where(self.labels != anchor_label)[0]
        neg_idx = neg_indices[torch.randint(0, len(neg_indices), (1,))]
        neg_img = self.images[neg_idx].reshape(-1)

        # Hard-negative mining over a random pool of 50 candidates.
        if len(neg_indices) > 50:
            random_neg_indices = neg_indices[torch.randint(0, len(neg_indices), (50,))]
            random_neg_samples = self.images[random_neg_indices]
            flat_candidates = random_neg_samples.reshape(len(random_neg_samples), -1)
            distances = torch.sum((anchor_img.unsqueeze(0) - flat_candidates) ** 2, dim=1)
            hard_neg_idx = random_neg_indices[torch.argmin(distances)]
            hard_neg_img = self.images[hard_neg_idx].reshape(-1)
            # Use the hard negative with probability 0.5.
            if torch.rand(1) > 0.5:
                neg_img = hard_neg_img

        return anchor_img, pos_img, neg_img, anchor_label

class ImprovedDEC(nn.Module):
    """Fully connected autoencoder with a soft cluster-assignment head.

    The encoder is five linear layers; the first four are each followed by
    ReLU, LayerNorm and dropout, and the fifth by ReLU only, so embeddings are
    non-negative.  The decoder mirrors the encoder with ReLU activations and a
    final sigmoid.  ``cluster_centers`` is a learnable
    ``(n_clusters, embed_dim)`` parameter initialised from a standard normal.

    Args:
        input_dim: Size of each flattened input vector.
        hidden_dim1, hidden_dim2, hidden_dim3, hidden_dim4: Hidden widths of
            the encoder (the decoder uses them in reverse).
        embed_dim: Size of the latent embedding.
        n_clusters: Number of cluster centres.
        dropout_rate: Dropout probability used inside the encoder.
    """

    def __init__(self, input_dim=784, hidden_dim1=512, hidden_dim2=256, hidden_dim3=128,
                 hidden_dim4=64, embed_dim=128, n_clusters=10, dropout_rate=0.2):
        super().__init__()

        # Encoder stack: input_dim -> h1 -> h2 -> h3 -> h4 -> embed_dim
        self.encoder_layer1 = nn.Linear(input_dim, hidden_dim1)
        self.encoder_layer2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.encoder_layer3 = nn.Linear(hidden_dim2, hidden_dim3)
        self.encoder_layer4 = nn.Linear(hidden_dim3, hidden_dim4)
        self.encoder_layer5 = nn.Linear(hidden_dim4, embed_dim)

        # One LayerNorm per hidden encoder stage
        self.layernorm1 = nn.LayerNorm(hidden_dim1)
        self.layernorm2 = nn.LayerNorm(hidden_dim2)
        self.layernorm3 = nn.LayerNorm(hidden_dim3)
        self.layernorm4 = nn.LayerNorm(hidden_dim4)

        # Decoder stack: embed_dim -> h4 -> h3 -> h2 -> h1 -> input_dim
        self.decoder_layer1 = nn.Linear(embed_dim, hidden_dim4)
        self.decoder_layer2 = nn.Linear(hidden_dim4, hidden_dim3)
        self.decoder_layer3 = nn.Linear(hidden_dim3, hidden_dim2)
        self.decoder_layer4 = nn.Linear(hidden_dim2, hidden_dim1)
        self.decoder_layer5 = nn.Linear(hidden_dim1, input_dim)

        self.dropout = nn.Dropout(dropout_rate)
        self.n_clusters = n_clusters
        self.cluster_centers = nn.Parameter(torch.randn(n_clusters, embed_dim))

    def encode(self, x):
        """Map inputs to non-negative embeddings of size ``embed_dim``."""
        hidden_stages = (
            (self.encoder_layer1, self.layernorm1),
            (self.encoder_layer2, self.layernorm2),
            (self.encoder_layer3, self.layernorm3),
            (self.encoder_layer4, self.layernorm4),
        )
        hidden = x
        for linear, norm in hidden_stages:
            # Linear -> ReLU -> LayerNorm -> Dropout
            hidden = self.dropout(norm(F.relu(linear(hidden))))
        return F.relu(self.encoder_layer5(hidden))

    def decode(self, z):
        """Reconstruct inputs from embeddings; outputs pass through a sigmoid."""
        hidden = z
        for linear in (self.decoder_layer1, self.decoder_layer2,
                       self.decoder_layer3, self.decoder_layer4):
            hidden = F.relu(linear(hidden))
        return torch.sigmoid(self.decoder_layer5(hidden))

    def forward(self, x):
        """Return ``(x_hat, q, z)``: reconstruction, soft assignments, embedding.

        ``q[i, j]`` is ``(1 / (1 + d_ij / 2)) ** 2`` normalised over ``j``,
        where ``d_ij`` is the squared Euclidean distance between embedding
        ``i`` and cluster centre ``j``.
        """
        z = self.encode(x)
        x_hat = self.decode(z)

        sq_dist = (z.unsqueeze(1) - self.cluster_centers).pow(2).sum(dim=2)
        q = (1.0 / (1.0 + (sq_dist / 2))).pow(2)
        # Normalise each row so the assignments of a sample sum to one.
        q = q / q.sum(dim=1, keepdim=True)

        return x_hat, q, z

def enhanced_triplet_loss(anchor, positive, negative, margin=1.0, beta=0.2):
    """Triplet hinge loss whose per-sample terms are softmax-weighted by difficulty.

    For every triplet the hinge term is
    ``max(0, (1 + beta) * d(a, p) - (1 - beta) * d(a, n) + margin)`` with ``d``
    the squared Euclidean distance.  The batch loss is the average of these
    terms weighted by ``exp(term)``, i.e. ``sum(l * e^l) / sum(e^l)``, so
    harder triplets contribute more.

    Args:
        anchor, positive, negative: Embedding batches of identical shape
            ``(batch, dim)``.
        margin: Hinge margin.
        beta: Scales the positive distance up by ``1 + beta`` and the negative
            distance down by ``1 - beta``.

    Returns:
        Scalar tensor with the weighted loss.
    """
    pos_dist = torch.sum((anchor - positive) ** 2, dim=1)
    neg_dist = torch.sum((anchor - negative) ** 2, dim=1)
    scaled_pos_dist = (1 + beta) * pos_dist
    scaled_neg_dist = (1 - beta) * neg_dist
    base_loss = torch.clamp(scaled_pos_dist - scaled_neg_dist + margin, min=0.0)

    # mean(l * e^l) / mean(e^l) == sum(l * softmax(l)).  Computing the weights
    # with softmax subtracts the maximum internally, so large hinge values no
    # longer overflow exp() to inf and turn the loss into inf / inf = NaN.
    weight = torch.softmax(base_loss, dim=0)
    return torch.sum(base_loss * weight)

def target_distribution(q):
    """Sharpen soft assignments ``q`` into the auxiliary target distribution.

    Each entry is squared and divided by its column sum (the soft cluster
    frequency); every row is then renormalised to sum to one.

    Args:
        q: ``(n_samples, n_clusters)`` tensor of soft assignments.

    Returns:
        Tensor of the same shape whose rows sum to one.
    """
    sharpened = q.pow(2) / q.sum(dim=0)
    row_totals = sharpened.sum(dim=1, keepdim=True)
    return sharpened / row_totals

def cluster_separation_loss(embeddings, cluster_centers, cluster_assignments):
    """Margin loss pulling samples to their own centre and away from the next one.

    For each sample the squared distance to its assigned centre is compared
    with the squared distance to the nearest *other* centre, using a fixed
    margin of 2.0.  Note that the hinge (clamp at zero) is applied to the
    batch mean, not to each sample individually.

    Args:
        embeddings: ``(batch, dim)`` latent vectors.
        cluster_centers: ``(n_clusters, dim)`` centres.
        cluster_assignments: ``(batch,)`` integer index of each sample's centre.

    Returns:
        Non-negative scalar tensor.
    """
    margin = 2.0

    # Squared distance of every sample to the centre it is assigned to.
    own_centers = cluster_centers[cluster_assignments]
    own_sq_dist = (embeddings - own_centers).pow(2).sum(dim=1)

    # Squared distance of every sample to every centre, via broadcasting.
    offsets = embeddings.unsqueeze(1) - cluster_centers.unsqueeze(0)
    sq_dist_to_all = offsets.pow(2).sum(dim=2)

    # Add a huge constant at each sample's own centre so the row minimum is
    # always taken over the other centres.
    penalty = torch.zeros_like(sq_dist_to_all)
    penalty.scatter_(1, cluster_assignments.unsqueeze(1), 1e6)
    nearest_other_sq_dist = (sq_dist_to_all + penalty).min(dim=1).values

    mean_violation = (own_sq_dist - nearest_other_sq_dist + margin).mean()
    return mean_violation.clamp(min=0.0)

def between_cluster_variance_loss(embeddings, cluster_assignments, n_clusters):
    """Negative mean squared distance between the centroids of populated clusters.

    Minimising the returned value pushes the per-batch cluster means apart.

    Args:
        embeddings: ``(batch, dim)`` latent vectors.
        cluster_assignments: ``(batch,)`` integer cluster index per sample.
        n_clusters: Number of cluster ids to look for (``0 .. n_clusters - 1``).

    Returns:
        Scalar tensor; ``0.0`` when fewer than two clusters have members in
        the batch.
    """
    # Centroid of every cluster that has at least one member in this batch.
    centroids = []
    for cluster_id in range(n_clusters):
        members = embeddings[cluster_assignments == cluster_id]
        if members.size(0) == 0:
            continue
        centroids.append(members.mean(dim=0))

    if len(centroids) < 2:
        return torch.tensor(0.0, device=embeddings.device)

    centroids = torch.stack(centroids)
    n_populated = centroids.size(0)

    # All pairwise squared distances between centroids, via broadcasting.
    gaps = centroids.unsqueeze(1) - centroids.unsqueeze(0)
    pairwise_sq_dist = gaps.pow(2).sum(dim=2)

    # Average over the ordered off-diagonal pairs only.
    off_diagonal = 1.0 - torch.eye(n_populated, device=embeddings.device)
    n_pairs = n_populated * (n_populated - 1)
    spread = (pairwise_sq_dist * off_diagonal).sum() / n_pairs
    return -spread

def main():
    """Train ImprovedDEC on Fashion-MNIST and report clustering quality.

    Stages, in order:
      1. Autoencoder pretraining (reconstruction + small L2 penalty on the
         embedding) with early stopping and LR reduction on plateau.
      2. Encoder pretraining with the enhanced triplet loss.
      3. Cluster-centre initialisation from the best of several KMeans runs.
      4. DEC refinement (reconstruction + KL + separation + between-cluster
         terms), with a contrastive pass every third iteration; the weights
         with the best silhouette score are kept.
      5. Final evaluation: saves the weights to ``improved_dec_model.pt``,
         prints silhouette / Davies-Bouldin scores, calls
         ``visualize_clusters`` and compares against plain K-Means.
    """
    print("Loading Fashion-MNIST dataset...")
    data, labels = load_fashion_mnist_dataset(max_samples=10000)
    model = ImprovedDEC(
        input_dim=784,
        hidden_dim1=512,
        hidden_dim2=256,
        hidden_dim3=128,
        hidden_dim4=64,
        embed_dim=128,
        n_clusters=10,
        dropout_rate=0.2
    ).to(device)

    data = data.to(device)
    labels = labels.to(device)

    standard_loader = DataLoader(TensorDataset(data), batch_size=128, shuffle=True)
    # The DEC phase needs to know *which* rows a shuffled batch contains so it
    # can look up their target distribution; this loader yields the row index
    # alongside each sample.
    sample_indices = torch.arange(data.size(0), device=data.device)
    dec_loader = DataLoader(TensorDataset(data, sample_indices), batch_size=128, shuffle=True)
    contrastive_dataset = ContrastiveDataset(data, labels)
    contrastive_loader = DataLoader(
        contrastive_dataset,
        batch_size=64,
        shuffle=True,
        drop_last=True,
        collate_fn=contrastive_collate_fn
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # ------------------------------------------------------------------
    # Stage 1: autoencoder pretraining
    # ------------------------------------------------------------------
    print("\nPretraining autoencoder...")
    model.train()
    best_recon_loss = float('inf')
    patience_counter = 0
    max_patience = 10

    for epoch in range(100):
        total_loss = 0
        for batch in standard_loader:
            batch_data = batch[0].to(device)
            optimizer.zero_grad()
            x_hat, _, z = model(batch_data)

            recon_loss = F.mse_loss(x_hat, batch_data)

            # Mild penalty on the embedding norm.
            l2_reg = 0.001 * torch.mean(torch.sum(z**2, dim=1))

            loss = recon_loss + l2_reg
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(standard_loader)
        print(f"Pretrain Epoch [{epoch+1}/100], Recon Loss: {avg_loss:.4f}")

        if avg_loss < best_recon_loss:
            best_recon_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= max_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        scheduler.step(avg_loss)

    # ------------------------------------------------------------------
    # Stage 2: contrastive (triplet) pretraining of the encoder
    # ------------------------------------------------------------------
    print("\nPretraining with enhanced contrastive loss...")
    for epoch in range(50):
        total_contrastive_loss = 0
        batch_count = 0

        for anchor_img, pos_img, neg_img, _ in contrastive_loader:
            anchor_img = anchor_img.to(device)
            pos_img = pos_img.to(device)
            neg_img = neg_img.to(device)

            batch_size = anchor_img.size(0)
            if batch_size <= 1:
                continue
            anchor_embed = model.encode(anchor_img)
            pos_embed = model.encode(pos_img)
            neg_embed = model.encode(neg_img)

            # Margin and beta grow linearly with the epoch.
            dynamic_margin = 1.0 + 0.02 * epoch
            loss = enhanced_triplet_loss(
                anchor_embed, pos_embed, neg_embed,
                margin=dynamic_margin,
                beta=0.1 + 0.01 * epoch
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_contrastive_loss += loss.item()
            batch_count += 1

        if batch_count > 0:
            print(f"Contrastive Epoch [{epoch+1}/50], Triplet Loss: {total_contrastive_loss/batch_count:.4f}")

    # ------------------------------------------------------------------
    # Stage 3: initialise the cluster centres with KMeans
    # ------------------------------------------------------------------
    print("\nInitializing cluster centers with KMeans...")
    model.eval()
    with torch.no_grad():
        z_init = model.encode(data).cpu().numpy()

    scaler = StandardScaler()
    z_init_scaled = scaler.fit_transform(z_init)

    # Keep the KMeans run (out of five seeds) with the best silhouette score.
    best_score = -1
    best_kmeans = None
    for i in range(5):
        kmeans = KMeans(n_clusters=model.n_clusters, n_init=10, random_state=42+i).fit(z_init_scaled)
        if len(np.unique(kmeans.labels_)) > 1:
            score = silhouette_score(z_init_scaled, kmeans.labels_)
            if score > best_score:
                best_score = score
                best_kmeans = kmeans

    if best_kmeans is None:
        best_kmeans = KMeans(n_clusters=model.n_clusters, n_init=20, random_state=42).fit(z_init_scaled)

    # KMeans ran on standardised embeddings; map the centres back to the
    # model's own embedding space before storing them.
    kmeans_centers = best_kmeans.cluster_centers_ * scaler.scale_ + scaler.mean_
    model.cluster_centers.data = torch.tensor(kmeans_centers, dtype=torch.float32).to(device)

    # ------------------------------------------------------------------
    # Stage 4: DEC refinement
    # ------------------------------------------------------------------
    print("\nTraining DEC with enhanced clustering and contrastive loss...")
    max_iter = 300
    update_interval = 5
    tol = 5e-4
    y_pred_last = best_kmeans.labels_
    best_silhouette = -1
    best_model_state = None

    for iteration in range(max_iter):
        if iteration % update_interval == 0:
            # Refresh the target distribution from the whole dataset.
            model.eval()
            with torch.no_grad():
                _, q_all, z_final = model(data)
            p = target_distribution(q_all).detach()
            y_pred = q_all.cpu().numpy().argmax(axis=1)
            z_np = z_final.cpu().numpy()

            if len(np.unique(y_pred)) > 1:
                current_silhouette = silhouette_score(z_np, y_pred)
                print(f"Iter {iteration}: Silhouette score = {current_silhouette:.4f}")

                if current_silhouette > best_silhouette:
                    best_silhouette = current_silhouette
                    # state_dict() hands out references to the live tensors
                    # and dict.copy() is shallow, so clone every tensor;
                    # otherwise later optimizer steps overwrite the snapshot.
                    best_model_state = {
                        name: tensor.detach().clone()
                        for name, tensor in model.state_dict().items()
                    }

            delta = np.sum(y_pred != y_pred_last) / len(y_pred)
            print(f"Iter {iteration}: Label change rate = {delta:.4f}")
            if delta < tol and iteration > 50:
                print("Converged.")
                break
            y_pred_last = y_pred

        model.train()
        total_loss = 0

        for batch_data, batch_idx in dec_loader:
            batch_data = batch_data.to(device)
            optimizer.zero_grad()
            x_hat, q_batch, z_batch = model(batch_data)

            # Targets of exactly the samples in this (shuffled) batch.  Taking
            # the first rows of ``p`` instead would pair each sample with some
            # other sample's target distribution.
            p_batch = p[batch_idx]
            cluster_assignments = torch.argmax(q_batch, dim=1)
            recon_loss = F.mse_loss(x_hat, batch_data)
            kl_loss = F.kl_div(q_batch.log(), p_batch.to(device), reduction='batchmean')

            sep_loss = cluster_separation_loss(z_batch, model.cluster_centers, cluster_assignments)

            bv_loss = between_cluster_variance_loss(z_batch, cluster_assignments, model.n_clusters)

            # Clustering terms are ramped up linearly and then capped.
            clustering_weight = min(0.01 * (iteration / 30), 0.1)
            sep_weight = min(0.005 * (iteration / 50), 0.05)
            bv_weight = min(0.005 * (iteration / 50), 0.05)

            loss = recon_loss + clustering_weight * kl_loss + sep_weight * sep_loss + bv_weight * bv_loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Every third iteration, run one pass of triplet training as well.
        if iteration % 3 == 0:
            contrastive_total_loss = 0
            batch_count = 0

            for anchor_img, pos_img, neg_img, _ in contrastive_loader:
                anchor_img = anchor_img.to(device)
                pos_img = pos_img.to(device)
                neg_img = neg_img.to(device)

                batch_size = anchor_img.size(0)
                if batch_size <= 1:
                    continue

                anchor_embed = model.encode(anchor_img)
                pos_embed = model.encode(pos_img)
                neg_embed = model.encode(neg_img)

                dynamic_margin = 1.0 + 0.01 * iteration / 10
                contr_loss = enhanced_triplet_loss(
                    anchor_embed, pos_embed, neg_embed,
                    margin=dynamic_margin,
                    beta=0.2
                )

                optimizer.zero_grad()
                contr_loss.backward()
                optimizer.step()
                contrastive_total_loss += contr_loss.item()
                batch_count += 1

            if batch_count > 0 and iteration % 10 == 0:
                print(f"Iter {iteration}: Contrastive Loss = {contrastive_total_loss/batch_count:.4f}")

        if iteration % update_interval == 0:
            print(f"Iter {iteration}: Loss = {total_loss/len(dec_loader):.4f}")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with silhouette score: {best_silhouette:.4f}")

    torch.save(model.state_dict(), "improved_dec_model.pt")

    # ------------------------------------------------------------------
    # Stage 5: final evaluation and K-Means comparison
    # ------------------------------------------------------------------
    model.eval()
    with torch.no_grad():
        _, q_final, z_final = model(data)
        cluster_labels = q_final.cpu().numpy().argmax(axis=1)
        z_np = z_final.cpu().numpy()

    print(f"\nFinal Evaluation:")
    print(f"Latent space std: {z_np.std(axis=0).mean():.4f}")

    if len(np.unique(cluster_labels)) > 1:
        final_silhouette = silhouette_score(z_np, cluster_labels)
        final_db = davies_bouldin_score(z_np, cluster_labels)
        print(f"DEC Clustering - Silhouette Score: {final_silhouette:.4f}, Davies-Bouldin Index: {final_db:.4f}")
    else:
        print("DEC Clustering - Warning: Only one cluster found")

    # NOTE(review): visualize_clusters is not defined in the visible part of
    # this file - presumably provided elsewhere; confirm before running.
    visualize_clusters(z_np, cluster_labels,
                      "Improved DEC Clusters (t-SNE)",
                      "improved_dec_fashion_mnist_clusters.png",
                      sample_size=2000)

    print("\nTrying different K-Means configurations for comparison...")
    for n_clusters in [5, 8, 10, 12]:
        kmeans = KMeans(n_clusters=n_clusters, n_init=20, random_state=42).fit(z_np)
        kmeans_labels = kmeans.labels_
        if len(np.unique(kmeans_labels)) > 1:
            sil_score = silhouette_score(z_np, kmeans_labels)
            db_score = davies_bouldin_score(z_np, kmeans_labels)
            print(f"K-Means (k={n_clusters}) - Silhouette: {sil_score:.4f}, Davies-Bouldin: {db_score:.4f}")

            if n_clusters == 10:
                visualize_clusters(z_np, kmeans_labels,
                                  f"K-Means (k={n_clusters}) Clusters",
                                  f"kmeans_{n_clusters}_fashion_mnist_clusters.png",
                                  sample_size=2000)

import torch
from torch.utils.data import DataLoader

def contrastive_collate_fn(batch):
    """Collate ``(anchor, positive, negative, label)`` samples into batch tensors.

    Each image is flattened to 1-D before stacking, so samples that carry an
    extra leading dimension of size one still stack into ``(batch, features)``.

    Args:
        batch: Sequence of 4-tuples of tensors as produced by the dataset.

    Returns:
        ``(anchors, positives, negatives, labels)``, each stacked along dim 0.
    """
    anchors = torch.stack([anchor.view(-1) for anchor, _, _, _ in batch])
    positives = torch.stack([pos.view(-1) for _, pos, _, _ in batch])
    negatives = torch.stack([neg.view(-1) for _, _, neg, _ in batch])
    labels = torch.stack([label for _, _, _, label in batch])
    return anchors, positives, negatives, labels
class ContrastiveDataset(Dataset):
    """Dataset of ``(anchor, positive, negative, anchor_label)`` triplets.

    Every returned image is flattened to 1-D.  The positive is a random other
    sample sharing the anchor's label (the anchor itself if there is none).
    The negative is a random sample with a different label; when more than 50
    such samples exist it is swapped, with probability 0.5, for the nearest
    (squared Euclidean distance on raw pixels) of 50 randomly drawn negatives.

    Args:
        images: Tensor of samples, indexed along dim 0.
        labels: 1-D tensor of class labels aligned with ``images``.
        transform: Stored on the instance; not applied by ``__getitem__``.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        label = self.labels[idx]
        anchor = self.images[idx].view(-1)

        # --- positive: same class, different sample when one exists ---
        same_class = torch.where(self.labels == label)[0]
        same_class = same_class[same_class != idx]
        if len(same_class) > 0:
            pick = torch.randint(0, len(same_class), (1,))
            positive_index = same_class[pick]
        else:
            positive_index = idx
        positive = self.images[positive_index].view(-1)

        # --- negative: uniformly random sample of another class ---
        other_class = torch.where(self.labels != label)[0]
        n_other = len(other_class)
        pick = torch.randint(0, n_other, (1,))
        negative = self.images[other_class[pick]].view(-1)

        # --- optional hard negative from a pool of 50 random candidates ---
        if n_other > 50:
            pool = other_class[torch.randint(0, n_other, (50,))]
            pool_images = self.images[pool]
            pool_flat = pool_images.view(len(pool_images), -1)
            sq_dist = torch.sum((anchor.unsqueeze(0) - pool_flat) ** 2, dim=1)
            hardest = self.images[pool[torch.argmin(sq_dist)]].view(-1)
            if torch.rand(1) > 0.5:
                negative = hardest

        return anchor, positive, negative, label

def create_dataloaders(data, labels):
    """Build the two loaders used for training.

    Args:
        data: Tensor of samples.
        labels: Tensor of class labels aligned with ``data``.

    Returns:
        ``(standard_loader, contrastive_loader)``: a shuffled loader over
        ``data`` alone with batch size 128, and a shuffled triplet loader with
        batch size 64 that drops the last incomplete batch and collates with
        ``contrastive_collate_fn``.
    """
    plain_dataset = TensorDataset(data)
    standard_loader = DataLoader(plain_dataset, batch_size=128, shuffle=True)

    triplet_dataset = ContrastiveDataset(data, labels)
    triplet_options = dict(
        batch_size=64,
        shuffle=True,
        drop_last=True,
        collate_fn=contrastive_collate_fn,
    )
    contrastive_loader = DataLoader(triplet_dataset, **triplet_options)

    return standard_loader, contrastive_loader

# Run the full training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()

Loading Fashion-MNIST dataset...

Pretraining autoencoder...
Pretrain Epoch [1/100], Recon Loss: 0.1063
Pretrain Epoch [2/100], Recon Loss: 0.0823
Pretrain Epoch [3/100], Recon Loss: 0.0687
Pretrain Epoch [4/100], Recon Loss: 0.0583
Pretrain Epoch [5/100], Recon Loss: 0.0485
Pretrain Epoch [6/100], Recon Loss: 0.0447
Pretrain Epoch [7/100], Recon Loss: 0.0428
Pretrain Epoch [8/100], Recon Loss: 0.0411
Pretrain Epoch [9/100], Recon Loss: 0.0401
Pretrain Epoch [10/100], Recon Loss: 0.0391
Pretrain Epoch [11/100], Recon Loss: 0.0387
Pretrain Epoch [12/100], Recon Loss: 0.0376
Pretrain Epoch [13/100], Recon Loss: 0.0374
Pretrain Epoch [14/100], Recon Loss: 0.0368
Pretrain Epoch [15/100], Recon Loss: 0.0361
Pretrain Epoch [16/100], Recon Loss: 0.0356
Pretrain Epoch [17/100], Recon Loss: 0.0350
Pretrain Epoch [18/100], Recon Loss: 0.0343
Pretrain Epoch [19/100], Recon Loss: 0.0339
Pretrain Epoch [20/100], Recon Loss: 0.0335
Pretrain Epoch [21/100], Recon Loss: 0.0332
Pretrain Epoch [22/100],