<a href="https://colab.research.google.com/github/abdullojony/auec_demo/blob/main/auec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Libraries

In [None]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import umap
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import accuracy_score, normalized_mutual_info_score, adjusted_rand_score
import matplotlib.pyplot as plt
from scipy.stats import mode
import random
from collections import Counter

In [None]:
# Mount Google Drive at /content/drive (Colab-only API).
# NOTE(review): CONFIG uses relative paths ("./drive/MyDrive/..."), so the notebook
# presumably relies on the working directory being /content — confirm before
# running outside Colab.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Configuration

In [None]:
# --- Configuration ---

# Single dictionary holding every tunable value used by the pipeline below
# (paths, batch sizes, optimiser settings, UMAP / clustering parameters, plot style).
# Every stage function receives this dict and reads its keys by name.
CONFIG = {
    # General settings
    "DEVICE": 'cuda' if torch.cuda.is_available() else 'cpu',  # Chosen once, when this cell runs
    "SEED": 42,  # Passed to seed_everything, KMeans(random_state=...) and UMAP(random_state=...)
    "MNIST_DIR": "./drive/MyDrive/data/mnist/",  # Directory to store MNIST data
    "OUTPUT_DIR": "./drive/MyDrive/output/",    # Directory to store all outputs (models, plots, data)
    "NUM_CLUSTERS": 10,         # Number of clusters (digits for MNIST)

    # Data loading
    "PRETRAIN_BATCH_SIZE": 256,
    "TRAIN_BATCH_SIZE": 256,
    "EVAL_BATCH_SIZE": 60000, # For evaluating on the full dataset

    # Pre-training CAE
    "PRETRAIN_EPOCHS": 10,
    "PRETRAIN_LR": 0.01,
    "PRETRAIN_WEIGHT_DECAY": 1e-5,
    "PRETRAIN_KMEANS_INIT": 'k-means++',  # `init` argument of the K-Means that seeds the centroids

    # Stage I: AUEC Training
    "AUEC_EPOCHS": 10, # Original: 30
    "AUEC_LR": 0.01,
    "AUEC_WEIGHT_DECAY": 1e-5,
    "AUEC_RECONSTRUCTION_WEIGHT": 1.0,   # Multiplier on the MSE reconstruction term
    "AUEC_SEPARABILITY_WEIGHT": 0.0003,  # Multiplier on the WCSS clustering term

    # Stage II: UMAP
    "UMAP_N_NEIGHBORS": 8,
    "UMAP_N_COMPONENTS": 8,  # Scatter plots are only produced when this is 2
    "UMAP_MIN_DIST": 0.01,

    # Stage III: Clustering
    # K-Means specific
    "KMEANS_INIT": 'k-means++',
    # MDBSCAN specific
    "MDBSCAN_EPSILON": 0.17,      # Passed to DBSCAN as `eps`
    "MDBSCAN_MIN_SAMPLES": 12,    # Passed to DBSCAN as `min_samples`

    # Plotting
    "PLOT_FIGURE_SIZE": (10, 7),          # Loss-curve figure size
    "PLOT_SCATTER_FIGURE_SIZE": (8, 8),   # Embedding scatter figure size
    "PLOT_CMAP": 'Spectral',
    "PLOT_SCATTER_S": 5,                  # Marker size for the scatter plots
}

# Utility Functions

In [None]:
# --- Utility Functions ---

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs with the same value.

    Also exports PYTHONHASHSEED and, when CUDA is available, seeds the CUDA
    generators and switches cuDNN to deterministic, non-benchmark mode.

    Args:
        seed (int): Seed applied to every random number generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU-side generators.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Current device first, then every device (multi-GPU setups).
        for apply_cuda_seed in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            apply_cuda_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    print(f"Seeded everything with seed {seed}")

def ensure_dir(directory_path):
    """Create `directory_path` (including parents) unless it already exists.

    Args:
        directory_path (str): Path that should exist after the call.
    """
    already_there = os.path.exists(directory_path)
    if already_there:
        return
    os.makedirs(directory_path)
    print(f"Created directory: {directory_path}")


# Loss Functions

In [None]:
# --- Loss Functions ---

def update_assignments(enc_output, centroids):
    """
    Picks, for every encoded point, the index of its nearest centroid.
    Nearness is the pairwise distance computed by `torch.cdist`.
    Args:
        enc_output (torch.Tensor): Encoded points, one row per point.
        centroids (torch.Tensor): Centroids, one row per cluster.
    Returns:
        torch.Tensor: Index of the closest centroid for each row of `enc_output`.
    """
    return torch.cdist(enc_output, centroids).argmin(dim=1)

def update_centroids(enc_output, assignments, num_clusters, prev_centroids=None):
    """
    Updates centroids based on the mean of assigned data points.

    A cluster that received no points keeps its row from `prev_centroids` when
    that argument is given; without it the row stays at zero (the historical
    behaviour), which moves an empty cluster to the origin.

    Args:
        enc_output (torch.Tensor): Encoded output, one row per point.
        assignments (torch.Tensor): Cluster index of each row of `enc_output`.
        num_clusters (int): Total number of clusters.
        prev_centroids (torch.Tensor, optional): Previous centroids with
            `num_clusters` rows, used as the fallback for empty clusters.
            It is copied, never modified in place.
    Returns:
        torch.Tensor: Updated centroids with `num_clusters` rows.
    Raises:
        ValueError: If `prev_centroids` does not have `num_clusters` rows.
    """
    if prev_centroids is None:
        centroids = torch.zeros((num_clusters, enc_output.shape[1]), device=enc_output.device)
    else:
        if prev_centroids.shape[0] != num_clusters:
            raise ValueError(
                f"prev_centroids has {prev_centroids.shape[0]} rows, expected {num_clusters}"
            )
        # Work on a private copy so the caller's tensor is left untouched.
        centroids = prev_centroids.detach().clone().to(enc_output.device)

    for k in range(num_clusters):
        assigned_data = enc_output[assignments == k]
        if assigned_data.size(0) > 0:
            centroids[k] = assigned_data.mean(dim=0)
    return centroids

def wcss_loss(enc_output, assignments, centroids):
    """
    Calculates the Within-Cluster Sum of Squares (WCSS), averaged over points.

    Each point is compared with the centroid it is assigned to; the squared
    differences are summed over all points and dimensions and divided by the
    number of points. The computation is a single gather, so it involves no
    Python-level loop over clusters and no host synchronisation.

    Args:
        enc_output (torch.Tensor): Encoded output, one row per point.
        assignments (torch.Tensor): Cluster index of each row of `enc_output`.
        centroids (torch.Tensor): Cluster centroids, one row per cluster.
    Returns:
        torch.Tensor: Scalar WCSS loss; a constant 0.0 when there are no
        centroids or no assignments.
    """
    device = enc_output.device
    num_points = assignments.numel()

    if centroids.shape[0] == 0 or num_points == 0:
        # No centroids, no assignments
        return torch.tensor(0.0, device=device)

    # Row i of `assigned_centroids` is the centroid of the cluster point i belongs to.
    index = assignments.to(device=device, dtype=torch.long)
    assigned_centroids = centroids.to(device)[index]

    return torch.sum((enc_output - assigned_centroids) ** 2) / num_points

# Convolutional Autoencoder

In [None]:
# --- Convolutional Autoencoder Model ---

class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        # Input: 1x28x28
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),  # Output: 8x14x14
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1), # Output: 16x7x7
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),# Output: 32x4x4
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, stride=1, padding=0), # Output: 64x1x1
            nn.ReLU(True),
            nn.Flatten(),  # Output: 64
        )

        self.decoder = nn.Sequential(
            nn.Unflatten(1, (64, 1, 1)), # Input: 64 -> Output: 64x1x1
            nn.ConvTranspose2d(64, 32, 4, stride=1, padding=0), # Output: 32x4x4
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 2, stride=2, padding=1, output_padding=1), # Output: 16x7x7
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1), # Output: 8x14x14
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2,  padding=1, output_padding=1), # Output: 1x28x28
            nn.Sigmoid() # Using Sigmoid for image pixel values (0-1)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

# Data Loading

In [None]:
# --- Data Loading ---

def get_mnist_dataloaders(config):
    """Download MNIST if needed and build the loaders used by the pipeline.

    Args:
        config (dict): Reads "MNIST_DIR", "PRETRAIN_BATCH_SIZE",
            "TRAIN_BATCH_SIZE" and "EVAL_BATCH_SIZE".
    Returns:
        tuple: (shuffled pre-training loader, shuffled AUEC loader,
        unshuffled evaluation loader, training-set targets). All three
        loaders wrap the training split; the test split is only loaded
        so its size can be reported.
    """
    to_tensor = transforms.ToTensor()  # MNIST images are 0-1
    data_root = config["MNIST_DIR"]

    train_dataset = MNIST(data_root, train=True, download=True, transform=to_tensor)
    test_dataset = MNIST(data_root, train=False, download=True, transform=to_tensor)

    # Two shuffled loaders so each training stage can use its own batch size.
    pretrain_loader = DataLoader(train_dataset, batch_size=config["PRETRAIN_BATCH_SIZE"], shuffle=True)
    auec_loader = DataLoader(train_dataset, batch_size=config["TRAIN_BATCH_SIZE"], shuffle=True)

    # Fixed-order loader for embedding extraction / evaluation.
    eval_loader = DataLoader(train_dataset, batch_size=config["EVAL_BATCH_SIZE"], shuffle=False)

    print(f"MNIST training data: {len(train_dataset)} samples")
    print(f"MNIST test data: {len(test_dataset)} samples")

    return pretrain_loader, auec_loader, eval_loader, train_dataset.targets

# Pre-training

In [None]:
# --- Pre-training Stage ---

def pretrain_cae(model, train_dl, config, device):
    """Pre-trains the Convolutional Autoencoder and derives initial centroids.

    The model is trained with an MSE reconstruction loss. Afterwards the whole
    dataset behind `train_dl` is encoded once more in evaluation mode and in
    dataset order, and K-Means is fitted on those codes. Encoding after training
    (rather than during the last epoch) means every code comes from the final
    weights and the saved K-Means labels line up with the dataset indices.

    Args:
        model (nn.Module): Autoencoder whose forward returns (encoded, decoded).
        train_dl (DataLoader): Loader yielding (images, labels) batches.
        config (dict): Reads "OUTPUT_DIR", "PRETRAIN_EPOCHS", "PRETRAIN_LR",
            "PRETRAIN_WEIGHT_DECAY", "PRETRAIN_KMEANS_INIT", "PRETRAIN_BATCH_SIZE",
            "NUM_CLUSTERS" and "SEED".
        device: Device the model and batches are moved to.
    Returns:
        tuple: (initial centroids as a float tensor on `device`,
        path of the saved pre-trained weights).
    """
    print("\n--- Starting CAE Pre-training ---")
    output_dir = config["OUTPUT_DIR"]
    os.makedirs(output_dir, exist_ok=True)  # torch.save / np.save need the directory

    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=config["PRETRAIN_LR"], weight_decay=config["PRETRAIN_WEIGHT_DECAY"])

    for epoch in range(1, config["PRETRAIN_EPOCHS"] + 1):
        model.train()
        train_loss_pre = 0.0

        for images, _ in train_dl:
            images = images.to(device)
            _, decoded = model(images)
            loss = criterion(decoded, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a per-sample average.
            train_loss_pre += loss.item() * images.size(0)

        train_loss_pre /= len(train_dl.dataset) # Average loss per sample
        print(f'Pre-train Epoch: {epoch}/{config["PRETRAIN_EPOCHS"]}\tTraining Loss: {train_loss_pre:.6f}')

    # Save pre-trained model weights
    pretrain_model_path = os.path.join(output_dir, f"CAE_Pretrained_Epochs{config['PRETRAIN_EPOCHS']}_LR{config['PRETRAIN_LR']}.pth")
    torch.save(model.state_dict(), pretrain_model_path)
    print(f"Saved pre-trained model to {pretrain_model_path}")

    # Encode the full dataset with the final weights, unshuffled, in eval mode.
    ordered_dl = DataLoader(
        train_dl.dataset,
        batch_size=train_dl.batch_size or config["PRETRAIN_BATCH_SIZE"],
        shuffle=False,
    )
    encoded_batches = []
    model.eval()
    with torch.no_grad():
        for images, _ in ordered_dl:
            encoded, _ = model(images.to(device))
            encoded_batches.append(encoded.cpu().numpy())
    encoded_data_full = np.concatenate(encoded_batches, axis=0)

    # Apply KMeans to the encoded data to get initial centroids
    print("Applying KMeans to pre-trained embeddings for initial centroids...")
    kmeans = KMeans(n_clusters=config["NUM_CLUSTERS"], init=config["PRETRAIN_KMEANS_INIT"], n_init=10, random_state=config["SEED"])
    kmeans.fit(encoded_data_full)

    initial_centroids = torch.from_numpy(kmeans.cluster_centers_).float().to(device)
    initial_labels = kmeans.labels_

    # Save initial centroids and labels
    centroids_path = os.path.join(output_dir, "InitialCentroids_KMeans.npy")
    labels_path = os.path.join(output_dir, "InitialLabels_KMeans.npy")
    np.save(centroids_path, initial_centroids.cpu().numpy())
    np.save(labels_path, initial_labels)
    print(f"Saved initial K-Means centroids to {centroids_path}")

    return initial_centroids, pretrain_model_path

# Stage I: Training of the Model

In [None]:
# --- Stage I: AUEC Training ---

def train_auec(model, train_dl_auec, initial_centroids, config, device, pretrained_model_path=None):
    """Trains the model with AUEC loss (Reconstruction + WCSS).

    For every mini-batch the points are assigned to the nearest current
    centroid, the weighted sum of reconstruction MSE and WCSS is minimised, and
    the centroids are then recomputed from that batch's codes (the codes
    produced before the optimiser step). A cluster that receives no points in a
    batch keeps its previous centroid instead of being reset to the zero vector.

    Args:
        model (nn.Module): Autoencoder whose forward returns (encoded, decoded).
        train_dl_auec (DataLoader): Loader yielding (images, labels) batches.
        initial_centroids (torch.Tensor): Starting centroids; cloned, not modified.
        config (dict): Reads "OUTPUT_DIR", "NUM_CLUSTERS", "AUEC_EPOCHS", "AUEC_LR",
            "AUEC_WEIGHT_DECAY", "AUEC_RECONSTRUCTION_WEIGHT",
            "AUEC_SEPARABILITY_WEIGHT" and "PLOT_FIGURE_SIZE".
        device: Device the model, batches and centroids are moved to.
        pretrained_model_path (str, optional): State dict to load before training.
    Returns:
        tuple: (path of the saved AUEC weights, centroids after the last batch).
    """
    print("\n--- Starting AUEC Training (Stage I) ---")
    output_dir = config["OUTPUT_DIR"]
    os.makedirs(output_dir, exist_ok=True)  # plt.savefig / torch.save need the directory
    num_clusters = config["NUM_CLUSTERS"]

    if pretrained_model_path:
        print(f"Loading pre-trained weights from: {pretrained_model_path}")
        model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=config["AUEC_LR"], weight_decay=config["AUEC_WEIGHT_DECAY"])
    reconstruction_criterion = nn.MSELoss()

    centroids = initial_centroids.clone().to(device)

    losses_r_auec = []
    losses_c_auec = []
    losses_total_auec = []

    for epoch in range(1, config["AUEC_EPOCHS"] + 1):
        model.train()
        total_loss_r = 0
        total_loss_c = 0
        total_loss = 0

        for images, _ in train_dl_auec:
            images = images.to(device)
            enc_auec, dec_auec = model(images)

            # Reconstruction loss
            loss_r = config["AUEC_RECONSTRUCTION_WEIGHT"] * reconstruction_criterion(dec_auec, images)

            # Assignments are a discrete choice: computed without gradient tracking.
            with torch.no_grad():
                assignments = update_assignments(enc_auec.detach(), centroids)

            # Clustering loss (WCSS) uses the attached codes so gradients reach the encoder.
            loss_c = config["AUEC_SEPARABILITY_WEIGHT"] * wcss_loss(enc_auec, assignments, centroids)

            # Total loss
            current_total_loss = loss_r + loss_c

            optimizer.zero_grad()
            current_total_loss.backward()
            optimizer.step()

            # Recompute centroids from this batch; clusters with no members in the
            # batch keep their previous centroid rather than collapsing to zero.
            with torch.no_grad():
                batch_centroids = update_centroids(enc_auec.detach(), assignments, num_clusters)
                if batch_centroids.shape == centroids.shape:
                    member_counts = torch.bincount(assignments, minlength=num_clusters)[:num_clusters]
                    is_empty = (member_counts == 0).unsqueeze(1)
                    batch_centroids = torch.where(
                        is_empty, centroids.to(batch_centroids.dtype), batch_centroids
                    )
                centroids = batch_centroids

            total_loss_r += loss_r.item() * images.size(0)
            total_loss_c += loss_c.item() * images.size(0)
            total_loss += current_total_loss.item() * images.size(0)

        avg_loss_r = total_loss_r / len(train_dl_auec.dataset)
        avg_loss_c = total_loss_c / len(train_dl_auec.dataset)
        avg_total_loss = total_loss / len(train_dl_auec.dataset)

        losses_r_auec.append(avg_loss_r)
        losses_c_auec.append(avg_loss_c)
        losses_total_auec.append(avg_total_loss)

        print(f'AUEC Epoch: {epoch}/{config["AUEC_EPOCHS"]} \t'
              f'Total Loss: {avg_total_loss:.6f} \t'
              f'Recon Loss: {avg_loss_r:.6f} \t'
              f'Cluster Loss: {avg_loss_c:.6f}')

    # Plot loss curve
    plt.figure(figsize=config["PLOT_FIGURE_SIZE"])
    plt.plot(losses_r_auec, label='Reconstruction Loss')
    plt.plot(losses_c_auec, label='Clustering (WCSS) Loss')
    plt.plot(losses_total_auec, label='Total Loss')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("AUEC Training Loss Curves")
    plt.legend()
    loss_curve_path = os.path.join(output_dir, f"AUEC_LossCurve_Epochs{config['AUEC_EPOCHS']}_RW{config['AUEC_RECONSTRUCTION_WEIGHT']}_SW{config['AUEC_SEPARABILITY_WEIGHT']}.png")
    plt.savefig(loss_curve_path)
    plt.close()
    print(f"Saved AUEC loss curve to {loss_curve_path}")

    # Save the final AUEC model
    auec_model_path = os.path.join(output_dir, f"AUEC_Model_Final_Epochs{config['AUEC_EPOCHS']}.pth")
    torch.save(model.state_dict(), auec_model_path)
    print(f"Saved final AUEC model to {auec_model_path}")

    return auec_model_path, centroids # Return final centroids as well


# Evaluation and Embedding Extraction

In [None]:
# --- Evaluation and Embedding Extraction ---

def get_embeddings(model, full_train_dl, config, device, model_path=None):
    """Encode every batch of `full_train_dl` and save the stacked codes.

    Args:
        model (nn.Module): Model whose forward returns (encoded, decoded).
        full_train_dl: Iterable of (images, labels) batches.
        config (dict): Reads "OUTPUT_DIR".
        device: Device the model and batches are moved to.
        model_path (str, optional): State dict loaded into `model` first.
    Returns:
        np.ndarray: Encoded outputs of all batches concatenated along axis 0;
        also written to "AUEC_CompressedEmbeddings.npy" in the output directory.
    """
    print("\n--- Extracting Embeddings ---")
    output_dir = config["OUTPUT_DIR"]

    if model_path:
        print(f"Loading model for embedding extraction from: {model_path}")
        state = torch.load(model_path, map_location=device)
        model.load_state_dict(state)
    model.to(device)
    model.eval()

    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        encoded_batches = [
            model(images.to(device))[0].cpu().numpy()
            for images, _ in full_train_dl
        ]

    compressed_embedding = np.concatenate(encoded_batches, axis=0)
    embedding_path = os.path.join(output_dir, "AUEC_CompressedEmbeddings.npy")
    np.save(embedding_path, compressed_embedding)
    print(f"Saved compressed embeddings ({compressed_embedding.shape}) to {embedding_path}")
    return compressed_embedding


# Stage II: UMAP
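**UMAP is configured through `CONFIG`: this run uses $n_C = 8$ components (`UMAP_N_COMPONENTS`) and $n_N = 8$ neighbors (`UMAP_N_NEIGHBORS`). Set $n_C = 2$ for ease of visualization — the scatter plots are only produced for 2-D embeddings.**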

In [None]:
# --- Stage II: UMAP ---

def apply_umap(embeddings, config):
    """Reduce `embeddings` with UMAP and save the result.

    Args:
        embeddings (np.ndarray): Input vectors, one row per sample.
        config (dict): Reads "OUTPUT_DIR", "UMAP_N_NEIGHBORS",
            "UMAP_N_COMPONENTS", "UMAP_MIN_DIST" and "SEED".
    Returns:
        np.ndarray: Output of `fit_transform`; also written to an ".npy" file
        whose name encodes the component and neighbour counts.
    """
    print("\n--- Applying UMAP (Stage II) ---")
    output_dir = config["OUTPUT_DIR"]

    umap_settings = {
        "n_neighbors": config["UMAP_N_NEIGHBORS"],
        "n_components": config["UMAP_N_COMPONENTS"],
        "min_dist": config["UMAP_MIN_DIST"],
        "random_state": config["SEED"],
        "metric": 'euclidean',  # common metric for image embeddings
    }
    refined_embedding = umap.UMAP(**umap_settings).fit_transform(embeddings)

    file_name = f"UMAP_RefinedEmbeddings_NComp{config['UMAP_N_COMPONENTS']}_NNeigh{config['UMAP_N_NEIGHBORS']}.npy"
    umap_path = os.path.join(output_dir, file_name)
    np.save(umap_path, refined_embedding)
    print(f"Saved UMAP refined embeddings ({refined_embedding.shape}) to {umap_path}")
    return refined_embedding


# Stage III: Clustering

In [None]:
# --- Stage III: Clustering ---

def map_cluster_labels_to_true(y_pred_cluster, y_true):
    """Maps cluster labels to true labels based on majority voting.

    Every predicted cluster id (including a noise id such as -1, if present) is
    replaced by the most frequent true label among its members. Ties go to the
    smallest label. The vote uses `np.unique`, so it does not depend on the
    SciPy version.

    Args:
        y_pred_cluster (array-like): Predicted cluster id per sample.
        y_true (array-like): True label per sample, same length.
    Returns:
        np.ndarray: Majority-vote label per sample, with the dtype and shape of
        `y_pred_cluster`.
    """
    y_pred_cluster = np.asarray(y_pred_cluster)
    y_true = np.asarray(y_true)
    mapped_labels = np.zeros_like(y_pred_cluster)

    for cluster_id in np.unique(y_pred_cluster):
        mask = (y_pred_cluster == cluster_id)
        # `mask` always selects at least one sample because cluster_id comes
        # from y_pred_cluster itself. np.unique returns sorted values, so
        # argmax picks the smallest label among equally frequent ones.
        candidates, votes = np.unique(y_true[mask], return_counts=True)
        mapped_labels[mask] = candidates[np.argmax(votes)]
    return mapped_labels

def plot_embedding_space(embedding_2d, labels, title, filepath, config):
    """Save a labelled scatter plot of the first two embedding columns.

    Args:
        embedding_2d (np.ndarray): Points to plot; columns 0 and 1 are used.
        labels (array-like): Value per point used for colouring.
        title (str): Figure title.
        filepath (str): Destination image path.
        config (dict): Reads "PLOT_SCATTER_FIGURE_SIZE", "PLOT_CMAP",
            "PLOT_SCATTER_S" and "NUM_CLUSTERS".
    """
    plt.figure(figsize=config["PLOT_SCATTER_FIGURE_SIZE"])

    x_values = embedding_2d[:, 0]
    y_values = embedding_2d[:, 1]
    points = plt.scatter(x_values, y_values, c=labels, cmap=config["PLOT_CMAP"], s=config["PLOT_SCATTER_S"])
    plt.gca().set_aspect('equal', 'datalim')

    # One discrete colour band per cluster index, with a tick at each band centre.
    band_edges = np.arange(config["NUM_CLUSTERS"] + 1) - 0.5
    colorbar = plt.colorbar(points, boundaries=band_edges)
    colorbar.set_ticks(np.arange(config["NUM_CLUSTERS"]))

    plt.title(title, fontsize=16)
    plt.xlabel("UMAP Component 1")
    plt.ylabel("UMAP Component 2")
    plt.savefig(filepath)
    plt.close()
    print(f"Saved plot: {filepath}")

def evaluate_clustering(y_true, y_pred_mapped, algorithm_name):
    """Compute and print ACC, NMI and ARI for mapped cluster labels.

    Samples whose mapped label is -1 are left out of the evaluation.

    Args:
        y_true (np.ndarray): True label per sample.
        y_pred_mapped (np.ndarray): Mapped predicted label per sample.
        algorithm_name (str): Name used in the printed report.
    Returns:
        dict: {"ACC", "NMI", "ARI"}; all zero when no sample remains.
    """
    keep = y_pred_mapped != -1
    if not np.all(keep):
        # Some samples carry the -1 placeholder: report how many are scored.
        print(f"Note: Evaluating {algorithm_name} on {np.sum(keep)}/{len(y_true)} points (excluding unmapped labels).")

    truth = y_true[keep]
    predicted = y_pred_mapped[keep]

    if len(truth) == 0:
        print(f"{algorithm_name}: No valid points for evaluation.")
        return {"ACC": 0, "NMI": 0, "ARI": 0}

    scores = {
        "ACC": accuracy_score(truth, predicted),
        "NMI": normalized_mutual_info_score(truth, predicted),
        "ARI": adjusted_rand_score(truth, predicted),
    }

    print(f"\n--- {algorithm_name} Clustering Results ---")
    print(f"Accuracy (ACC): {scores['ACC']:.4f}")
    print(f"Normalized Mutual Information (NMI): {scores['NMI']:.4f}")
    print(f"Adjusted Rand Index (ARI): {scores['ARI']:.4f}")
    return scores

def run_kmeans_clustering(refined_embedding, y_true, config):
    """Cluster the refined embedding with K-Means, save labels, and score them.

    Args:
        refined_embedding (np.ndarray): Points to cluster, one row per sample.
        y_true (np.ndarray): True label per sample.
        config (dict): Reads "OUTPUT_DIR", "NUM_CLUSTERS", "KMEANS_INIT", "SEED"
            (plus the plotting keys when the embedding has two columns).
    Returns:
        dict: Metrics returned by `evaluate_clustering`.
    """
    print("\n--- Running K-Means Clustering (Stage III) ---")
    output_dir = config["OUTPUT_DIR"]

    kmeans_settings = {
        "n_clusters": config["NUM_CLUSTERS"],
        "init": config["KMEANS_INIT"],
        "n_init": 10,  # Standard value for k-means++
        "random_state": config["SEED"],
    }
    raw_labels = KMeans(**kmeans_settings).fit_predict(refined_embedding)
    mapped_labels = map_cluster_labels_to_true(raw_labels, y_true)

    # Persist the majority-vote labels.
    kmeans_labels_path = os.path.join(output_dir, "KMeans_PredictedLabels_Mapped.npy")
    np.save(kmeans_labels_path, mapped_labels)
    print(f"Saved K-Means mapped predicted labels to {kmeans_labels_path}")

    # A scatter plot only makes sense for a two-column embedding.
    is_two_dimensional = refined_embedding.shape[1] == 2
    if is_two_dimensional:
        plot_path = os.path.join(output_dir, "KMeans_UMAP_Plot.png")
        plot_embedding_space(refined_embedding, mapped_labels, "K-Means Clustering on UMAP Embedding", plot_path, config)

    return evaluate_clustering(y_true, mapped_labels, "K-Means")


def run_mdbscan_clustering(refined_embedding, y_true, config):
    """Runs modified DBSCAN clustering and evaluates.

    DBSCAN is run first. The (up to) NUM_CLUSTERS largest non-noise clusters are
    kept, and every other point — noise or member of a smaller cluster — is
    reassigned to the kept cluster whose mean is nearest (Euclidean distance).
    The reassignment is done with one vectorised distance computation.

    Args:
        refined_embedding (np.ndarray): Points to cluster, one row per sample.
        y_true (np.ndarray): True label per sample.
        config (dict): Reads "OUTPUT_DIR", "MDBSCAN_EPSILON", "MDBSCAN_MIN_SAMPLES",
            "NUM_CLUSTERS" (plus the plotting keys when the embedding has two columns).
    Returns:
        dict: Metrics returned by `evaluate_clustering`.
    """
    print("\n--- Running MDBSCAN Clustering (Stage III) ---")
    output_dir = config["OUTPUT_DIR"]

    dbscan = DBSCAN(eps=config["MDBSCAN_EPSILON"], min_samples=config["MDBSCAN_MIN_SAMPLES"])
    y_pred_dbscan_raw = dbscan.fit_predict(refined_embedding)

    cluster_counts = Counter(y_pred_dbscan_raw)
    print(f"DBSCAN initial clusters found: {cluster_counts}")

    # MDBSCAN modification: Reassign outliers and small clusters
    valid_clusters = {k: v for k, v in cluster_counts.items() if k != -1} # Exclude noise points

    y_pred_mdbscan_reassigned = np.copy(y_pred_dbscan_raw) # Start with raw predictions

    if not valid_clusters:
        print("MDBSCAN: No valid non-noise clusters found by DBSCAN. Mapping raw DBSCAN output (including noise).")
        # No reassignment possible if no valid clusters to form centroids
    else:
        num_top_clusters = min(config["NUM_CLUSTERS"], len(valid_clusters))
        ranked = sorted(valid_clusters.items(), key=lambda item: item[1], reverse=True)
        largest_cluster_ids = [cluster_id for cluster_id, _ in ranked[:num_top_clusters]]

        print(f"MDBSCAN: Top {len(largest_cluster_ids)} largest non-noise clusters selected: {largest_cluster_ids}")

        if not largest_cluster_ids:
            # Only reachable when NUM_CLUSTERS selects zero clusters.
            print("MDBSCAN: No centroids for reassignment (e.g. selected largest clusters were empty, though unlikely). Mapping raw DBSCAN output.")
        else:
            # Every kept id was counted from y_pred_dbscan_raw, so each has members.
            kept_ids = np.asarray(largest_cluster_ids)
            kept_centroids = np.stack([
                refined_embedding[y_pred_dbscan_raw == cluster_id].mean(axis=0)
                for cluster_id in largest_cluster_ids
            ])

            # Noise (-1) is never a kept id, so this mask covers noise and small clusters.
            needs_reassignment = ~np.isin(y_pred_dbscan_raw, kept_ids)
            if needs_reassignment.any():
                stray_points = refined_embedding[needs_reassignment]
                distances = np.linalg.norm(
                    stray_points[:, None, :] - kept_centroids[None, :, :], axis=2
                )
                y_pred_mdbscan_reassigned[needs_reassignment] = kept_ids[distances.argmin(axis=1)]

    # Map the (potentially reassigned) labels to true labels
    y_pred_mdbscan_mapped = map_cluster_labels_to_true(y_pred_mdbscan_reassigned, y_true)

    # Save mapped labels
    mdbscan_labels_path = os.path.join(output_dir, "MDBSCAN_PredictedLabels_Mapped.npy")
    np.save(mdbscan_labels_path, y_pred_mdbscan_mapped)
    print(f"Saved MDBSCAN mapped predicted labels to {mdbscan_labels_path}")

    # Plotting (if 2D)
    if refined_embedding.shape[1] == 2:
        plot_path = os.path.join(output_dir, "MDBSCAN_UMAP_Plot.png")
        plot_embedding_space(refined_embedding, y_pred_mdbscan_mapped, "MDBSCAN Clustering on UMAP Embedding", plot_path, config)

    return evaluate_clustering(y_true, y_pred_mdbscan_mapped, "MDBSCAN")


# Execution

In [None]:
# End-to-end driver: seed, load data, pre-train, AUEC training, embedding
# extraction, UMAP, then K-Means and MDBSCAN evaluation. Each expensive stage is
# skipped when its cached artefact exists and the matching RUN_* flag is False.

# Apply seed and create output directory
seed_everything(CONFIG["SEED"])
ensure_dir(CONFIG["OUTPUT_DIR"])
print(f"Using device: {CONFIG['DEVICE']}")

# 1. Load Data
# Using separate Dataloaders for pretraining and AUEC training to allow different batch sizes/shuffling
train_dl_pretrain, train_dl_auec, full_train_dl, train_true_labels = get_mnist_dataloaders(CONFIG)
train_true_labels = train_true_labels.numpy() # For evaluation

# 2. Initialize Model
cae_model = ConvAutoencoder()

# 3. Pre-training
# Set to True to run pre-training, False to load pre-trained model if available
RUN_PRETRAINING = True
pretrain_model_filename = f"CAE_Pretrained_Epochs{CONFIG['PRETRAIN_EPOCHS']}_LR{CONFIG['PRETRAIN_LR']}.pth"
pretrain_model_path = os.path.join(CONFIG["OUTPUT_DIR"], pretrain_model_filename)

initial_centroids_filename = "InitialCentroids_KMeans.npy"
initial_centroids_path = os.path.join(CONFIG["OUTPUT_DIR"], initial_centroids_filename)


if RUN_PRETRAINING or not os.path.exists(pretrain_model_path) or not os.path.exists(initial_centroids_path):
    print("Running pre-training phase...")
    initial_centroids, saved_pretrain_path = pretrain_cae(cae_model, train_dl_pretrain, CONFIG, CONFIG["DEVICE"])
    pretrain_model_path = saved_pretrain_path # Update path if it was just saved
else:
    print(f"Skipping pre-training. Loading pre-trained model from {pretrain_model_path} and centroids from {initial_centroids_path}")
    # Model state will be loaded in train_auec, just need centroids here
    initial_centroids = torch.from_numpy(np.load(initial_centroids_path)).float().to(CONFIG["DEVICE"])

# 4. Stage I: AUEC Training
RUN_AUEC_TRAINING = True
auec_model_filename = f"AUEC_Model_Final_Epochs{CONFIG['AUEC_EPOCHS']}.pth"
auec_model_path = os.path.join(CONFIG["OUTPUT_DIR"], auec_model_filename)

if RUN_AUEC_TRAINING or not os.path.exists(auec_model_path):
    print("Running AUEC training phase...")
    # Pass pretrain_model_path so AUEC training can load the weights
    final_auec_model_path, final_centroids = train_auec(cae_model, train_dl_auec, initial_centroids, CONFIG, CONFIG["DEVICE"], pretrained_model_path=pretrain_model_path)
    auec_model_path = final_auec_model_path # Update path
else:
    print(f"Skipping AUEC training. Assuming final model exists at {auec_model_path}")
    # The saved weights are loaded by get_embeddings below if embeddings must be regenerated.

# 5. Evaluation: Get Embeddings from AUEC Model
# Regenerate when AUEC was (re)trained or when no cached embeddings exist.
compressed_embeddings_path = os.path.join(CONFIG["OUTPUT_DIR"], "AUEC_CompressedEmbeddings.npy")
# Decide BEFORE generating: get_embeddings creates the file, so checking for the
# file afterwards could never detect that it was missing.
embeddings_regenerated = RUN_AUEC_TRAINING or not os.path.exists(compressed_embeddings_path)
if embeddings_regenerated:
    compressed_embedding = get_embeddings(cae_model, full_train_dl, CONFIG, CONFIG["DEVICE"], model_path=auec_model_path)
else:
    print(f"Loading existing compressed embeddings from {compressed_embeddings_path}")
    compressed_embedding = np.load(compressed_embeddings_path)


# 6. Stage II: UMAP
refined_embeddings_path = os.path.join(CONFIG["OUTPUT_DIR"], f"UMAP_RefinedEmbeddings_NComp{CONFIG['UMAP_N_COMPONENTS']}_NNeigh{CONFIG['UMAP_N_NEIGHBORS']}.npy")
# Re-run UMAP whenever its input was just regenerated, or when no cached result exists;
# a cached UMAP result computed from older embeddings would be stale.
if embeddings_regenerated or not os.path.exists(refined_embeddings_path):
    refined_embedding = apply_umap(compressed_embedding, CONFIG)
else:
    print(f"Loading existing UMAP refined embeddings from {refined_embeddings_path}")
    refined_embedding = np.load(refined_embeddings_path)

# Plot true labels on UMAP embedding
if refined_embedding.shape[1] == 2:
    true_labels_plot_path = os.path.join(CONFIG["OUTPUT_DIR"], "UMAP_TrueLabels_Plot.png")
    plot_embedding_space(refined_embedding, train_true_labels, "UMAP Embedding with True Labels", true_labels_plot_path, CONFIG)

# 7. Stage III: Clustering
# K-Means
kmeans_results = run_kmeans_clustering(refined_embedding, train_true_labels, CONFIG)

# MDBSCAN
mdbscan_results = run_mdbscan_clustering(refined_embedding, train_true_labels, CONFIG)

print("\n--- All Stages Complete ---")
print("K-Means Results:", kmeans_results)
print("MDBSCAN Results:", mdbscan_results)

Seeded everything with seed 42
Created directory: ./drive/MyDrive/output/
Using device: cuda


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.09MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 135kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.59MB/s]


MNIST training data: 60000 samples
MNIST test data: 10000 samples
Running pre-training phase...

--- Starting CAE Pre-training ---
Pre-train Epoch: 1/10	Training Loss: 0.041054
Pre-train Epoch: 2/10	Training Loss: 0.012996
Pre-train Epoch: 3/10	Training Loss: 0.011119
Pre-train Epoch: 4/10	Training Loss: 0.010500
Pre-train Epoch: 5/10	Training Loss: 0.010015
Pre-train Epoch: 6/10	Training Loss: 0.009779
Pre-train Epoch: 7/10	Training Loss: 0.009533
Pre-train Epoch: 8/10	Training Loss: 0.009383
Pre-train Epoch: 9/10	Training Loss: 0.009362
Pre-train Epoch: 10/10	Training Loss: 0.009196
Saved pre-trained model to ./drive/MyDrive/output/CAE_Pretrained_Epochs10_LR0.01.pth
Applying KMeans to pre-trained embeddings for initial centroids...
Saved initial K-Means centroids to ./drive/MyDrive/output/InitialCentroids_KMeans.npy
Running AUEC training phase...

--- Starting AUEC Training (Stage I) ---
Loading pre-trained weights from: ./drive/MyDrive/output/CAE_Pretrained_Epochs10_LR0.01.pth
AUEC 

  warn(


Saved UMAP refined embeddings ((60000, 8)) to ./drive/MyDrive/output/UMAP_RefinedEmbeddings_NComp8_NNeigh8.npy

--- Running K-Means Clustering (Stage III) ---
Saved K-Means mapped predicted labels to ./drive/MyDrive/output/KMeans_PredictedLabels_Mapped.npy

--- K-Means Clustering Results ---
Accuracy (ACC): 0.9751
Normalized Mutual Information (NMI): 0.9353
Adjusted Rand Index (ARI): 0.9459

--- Running MDBSCAN Clustering (Stage III) ---
DBSCAN initial clusters found: Counter({np.int64(3): 6487, np.int64(4): 5975, np.int64(6): 5972, np.int64(1): 5916, np.int64(7): 5873, np.int64(5): 5872, np.int64(9): 5778, np.int64(2): 5635, np.int64(8): 5468, np.int64(0): 5303, np.int64(-1): 765, np.int64(10): 685, np.int64(11): 49, np.int64(12): 25, np.int64(18): 25, np.int64(16): 22, np.int64(14): 21, np.int64(23): 19, np.int64(21): 17, np.int64(20): 15, np.int64(24): 14, np.int64(22): 14, np.int64(13): 13, np.int64(15): 13, np.int64(17): 12, np.int64(19): 12})
MDBSCAN: Top 10 largest non-noise clu