In [3]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
import os

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Load and preprocess MNIST data
def load_and_preprocess_mnist():
    # Load MNIST dataset
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    
    # Normalize pixel values to range [0, 1]
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    
    # Reshape to include channel dimension
    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
    x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
    
    return x_train, y_train, x_test, y_test

# Custom sampling layer with reparameterization trick
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# Target distribution for clustering optimization
def target_distribution(q):
    weight = q ** 2 / tf.reduce_sum(q, axis=0)
    return tf.transpose(tf.transpose(weight) / tf.reduce_sum(weight, axis=1))

# Define the improved VAE model with clustering capability (VaDE-inspired)
class ClusteringVAE(keras.Model):
    def __init__(self, latent_dim=10, n_clusters=10, beta=1.0, alpha=1.0):
        super(ClusteringVAE, self).__init__()
        self.latent_dim = latent_dim
        self.n_clusters = n_clusters
        self.beta = beta
        self.alpha = alpha  # Weight for clustering loss
        self.clustering_initialized = False
        
        # Encoder network - deeper architecture with residual connections
        encoder_inputs = keras.Input(shape=(28, 28, 1))
        
        # First block
        x = layers.Conv2D(32, 3, activation=None, strides=1, padding="same")(encoder_inputs)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Conv2D(32, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.MaxPooling2D(pool_size=2, strides=2, padding='same')(x)
        
        # Second block with residual connection
        skip = layers.Conv2D(64, 1, strides=1, padding="same")(x)
        x = layers.Conv2D(64, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Conv2D(64, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Add()([x, skip])
        x = layers.LeakyReLU(0.2)(x)
        x = layers.MaxPooling2D(pool_size=2, strides=2, padding='same')(x)
        
        # Third block with residual connection
        skip = layers.Conv2D(128, 1, strides=1, padding="same")(x)
        x = layers.Conv2D(128, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Conv2D(128, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Add()([x, skip])
        x = layers.LeakyReLU(0.2)(x)
        
        # Flatten and dense layers
        x = layers.Flatten()(x)
        x = layers.Dense(256, activation=None)(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        
        z_mean = layers.Dense(latent_dim, name="z_mean")(x)
        z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
        z = Sampling()([z_mean, z_log_var])
        
        self.encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
        
        # Cluster assignment layer (GMM)
        self.cluster_centers = tf.Variable(
            initial_value=tf.random.normal(shape=(n_clusters, latent_dim)),
            trainable=True, name="cluster_centers"
        )
        
        # Decoder network - deeper with residual connections
        latent_inputs = keras.Input(shape=(latent_dim,))
        
        x = layers.Dense(7 * 7 * 128)(latent_inputs)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Reshape((7, 7, 128))(x)
        
        # First block with residual connection
        skip = x
        x = layers.Conv2D(128, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Conv2D(128, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Add()([x, skip])
        x = layers.LeakyReLU(0.2)(x)
        
        # Upsampling block
        x = layers.Conv2DTranspose(64, 3, strides=2, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        
        # Second block with residual connection
        skip = layers.Conv2D(64, 1, strides=1, padding="same")(x)
        x = layers.Conv2D(64, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Conv2D(64, 3, activation=None, strides=1, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Add()([x, skip])
        x = layers.LeakyReLU(0.2)(x)
        
        # Upsampling block
        x = layers.Conv2DTranspose(32, 3, strides=2, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        
        # Final convolution
        decoder_outputs = layers.Conv2D(1, 3, activation="sigmoid", padding="same")(x)
        
        self.decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
        
    def encode(self, x):
        z_mean, z_log_var, z = self.encoder(x)
        return z
    
    def compute_cluster_assignment(self, z):
        """Compute soft assignment q_ij between latent samples and clusters"""
        # Calculate squared Euclidean distance between latent samples and cluster centers
        q = 1.0 / (1.0 + tf.reduce_sum(
            tf.square(tf.expand_dims(z, axis=1) - tf.expand_dims(self.cluster_centers, axis=0)),
            axis=2) / 1.0)
        q = q / tf.reduce_sum(q, axis=1, keepdims=True)
        return q
        
    def decode(self, z):
        return self.decoder(z)
    
    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Compute cluster assignments
        q = self.compute_cluster_assignment(z_mean)
        return reconstructed, z_mean, z_log_var, q
    
    def compute_loss(self, x, y=None):
        # Forward pass
        z_mean, z_log_var, z = self.encoder(x)
        x_reconstructed = self.decoder(z)
        
        # Reconstruction loss (binary cross-entropy)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(x, x_reconstructed),
                axis=(1, 2, 3)
            )
        )
        
        # KL divergence loss
        kl_loss = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
        )
        
        # Initialize loss dict
        loss_dict = {
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
        
        # Compute total loss based on initialization state
        if not self.clustering_initialized:
            # Just ELBO loss during initial training
            total_loss = reconstruction_loss + self.beta * kl_loss
            loss_dict["total_loss"] = total_loss
            loss_dict["clustering_loss"] = tf.constant(0.0)
        else:
            # Compute cluster assignment
            q = self.compute_cluster_assignment(z_mean)
            
            # Get target distribution
            p = target_distribution(q)
            
            # Clustering loss (KL divergence between soft assignments and target)
            clustering_loss = tf.reduce_mean(
                tf.reduce_sum(p * tf.math.log(p / q), axis=1)
            )
            
            # Total loss including clustering component
            total_loss = reconstruction_loss + self.beta * kl_loss + self.alpha * clustering_loss
            
            loss_dict["clustering_loss"] = clustering_loss
            loss_dict["total_loss"] = total_loss
            
        return loss_dict
    
    def train_step(self, data):
        if isinstance(data, tuple):
            x = data[0]
        else:
            x = data
            
        with tf.GradientTape() as tape:
            loss_dict = self.compute_loss(x)
            total_loss = loss_dict["total_loss"]
            
        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        return loss_dict

    def initialize_clustering(self, x, y=None, epochs=10):
        """Initialize cluster centers using k-means on encoder features"""
        print("Extracting features for clustering initialization...")
        z_mean = self.encoder(x)[0].numpy()
        
        # Initialize with k-means
        print("Running k-means to initialize cluster centers...")
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=20, random_state=42)
        y_pred = kmeans.fit_predict(z_mean)
        
        # Set cluster centers to k-means centers
        self.cluster_centers.assign(kmeans.cluster_centers_)
        
        # Mark as initialized
        self.clustering_initialized = True
        
        # Calculate initial ARI if labels provided
        if y is not None:
            ari = adjusted_rand_score(y, y_pred)
            print(f"Initial clustering ARI after k-means: {ari:.4f}")
            
        # Finetune with clustering objective
        if epochs > 0:
            print(f"Fine-tuning clustering for {epochs} epochs...")
            dataset = tf.data.Dataset.from_tensor_slices(x).batch(256)
            
            for epoch in range(epochs):
                losses = {"total_loss": 0.0, "reconstruction_loss": 0.0, 
                         "kl_loss": 0.0, "clustering_loss": 0.0}
                num_batches = 0
                
                for batch in dataset:
                    batch_losses = self.train_step(batch)
                    for k in losses.keys():
                        losses[k] += batch_losses[k]
                    num_batches += 1
                
                # Average losses
                for k in losses.keys():
                    losses[k] /= num_batches
                
                # Evaluate ARI if labels provided
                if y is not None and (epoch + 1) % 2 == 0:
                    # Get updated assignments
                    q = self.compute_cluster_assignment(self.encoder(x)[0])
                    y_pred_updated = tf.argmax(q, axis=1).numpy()
                    ari = adjusted_rand_score(y, y_pred_updated)
                    print(f"Epoch {epoch+1}/{epochs}: total_loss={losses['total_loss']:.4f}, "
                          f"clustering_loss={losses['clustering_loss']:.4f}, ARI={ari:.4f}")
                else:
                    print(f"Epoch {epoch+1}/{epochs}: total_loss={losses['total_loss']:.4f}, "
                          f"clustering_loss={losses['clustering_loss']:.4f}")
        
        return self

# Function to evaluate ARI performance
def evaluate_ari(model, x_test, y_test):
    # Get latent representations
    z_mean = model.encoder(x_test)[0]
    
    # Get cluster assignments
    q = model.compute_cluster_assignment(z_mean)
    cluster_labels = tf.argmax(q, axis=1).numpy()
    
    # Calculate ARI score
    ari_score = adjusted_rand_score(y_test, cluster_labels)
    
    # Get latent representations for visualization
    latent_representations = z_mean.numpy()
    
    return ari_score, latent_representations, cluster_labels

# Function to visualize latent space using t-SNE
def visualize_latent_space(latent_representations, labels, cluster_labels=None, title_suffix=""):
    # Apply t-SNE for dimensionality reduction
    tsne = TSNE(n_components=2, random_state=42)
    latent_tsne = tsne.fit_transform(latent_representations)
    
    plt.figure(figsize=(12, 5))
    
    # Plot by true labels
    plt.subplot(1, 2 if cluster_labels is not None else 1, 1)
    scatter = plt.scatter(latent_tsne[:, 0], latent_tsne[:, 1], c=labels, cmap='tab10', alpha=0.7, s=5)
    plt.colorbar(scatter)
    plt.title(f"Latent Space (True Labels){title_suffix}")
    
    # Plot by cluster labels if available
    if cluster_labels is not None:
        plt.subplot(1, 2, 2)
        scatter = plt.scatter(latent_tsne[:, 0], latent_tsne[:, 1], c=cluster_labels, cmap='tab10', alpha=0.7, s=5)
        plt.colorbar(scatter)
        plt.title(f"Latent Space (Cluster Labels){title_suffix}")
    
    plt.tight_layout()
    plt.savefig(f"latent_space{title_suffix.replace(' ', '_')}.png")
    plt.show()

# Function to visualize reconstructions
def visualize_reconstructions(model, x_test, y_test=None, n=10):
    # Get random samples from each digit class if y_test provided
    if y_test is not None:
        x_sample = []
        for i in range(10):  # For each digit
            idx = np.where(y_test == i)[0]
            if len(idx) > 0:
                selected_idx = np.random.choice(idx, 1)[0]
                x_sample.append(x_test[selected_idx])
        x_sample = np.array(x_sample)
    else:
        # Get random samples
        random_indices = np.random.choice(len(x_test), n, replace=False)
        x_sample = x_test[random_indices]
    
    # Ensure we have exactly n samples
    if len(x_sample) < n:
        x_sample = x_test[np.random.choice(len(x_test), n, replace=False)]
    elif len(x_sample) > n:
        x_sample = x_sample[:n]
    
    # Reconstruct samples
    reconstructed = model.decoder(model.encoder(x_sample)[2])
    
    plt.figure(figsize=(20, 4))
    for i in range(n):
        # Original
        plt.subplot(2, n, i + 1)
        plt.imshow(x_sample[i].reshape(28, 28), cmap='gray')
        plt.axis('off')
        
        # Reconstruction
        plt.subplot(2, n, i + n + 1)
        plt.imshow(reconstructed[i].numpy().reshape(28, 28), cmap='gray')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig("reconstructions.png")
    plt.show()

# Function to visualize cluster centers as digit images
def visualize_cluster_centers(model):
    # Decode cluster centers to see what each cluster represents
    decoded_centers = model.decoder(model.cluster_centers)
    
    plt.figure(figsize=(20, 2))
    for i in range(model.n_clusters):
        plt.subplot(1, model.n_clusters, i + 1)
        plt.imshow(decoded_centers[i].numpy().reshape(28, 28), cmap='gray')
        plt.title(f"Cluster {i}")
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig("cluster_centers.png")
    plt.show()

# Function to visualize loss history
def plot_history(history):
    plt.figure(figsize=(15, 5))
    
    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(history['total_loss'], label='Total Loss')
    plt.plot(history['reconstruction_loss'], label='Reconstruction Loss')
    plt.plot(history['kl_loss'], label='KL Loss')
    if 'clustering_loss' in history:
        plt.plot(history['clustering_loss'], label='Clustering Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Loss')
    
    # Plot ARI
    if 'ari_scores' in history:
        plt.subplot(1, 2, 2)
        epochs = list(range(len(history['ari_scores'])))
        plt.plot(epochs, history['ari_scores'], marker='o')
        plt.xlabel('Evaluation Step')
        plt.ylabel('ARI Score')
        plt.title('ARI Performance')
    
    plt.tight_layout()
    plt.savefig("training_history.png")
    plt.show()

# Main function to run the improved experiment
def run_improved_vae_experiment(latent_dim=20, n_clusters=10, beta=1.0, alpha=1.0, 
                                pretrain_epochs=30, finetune_epochs=20, batch_size=256):
    # Create output directory for results
    os.makedirs("results", exist_ok=True)
    
    # Load data
    x_train, y_train, x_test, y_test = load_and_preprocess_mnist()
    
    # Create model
    vae = ClusteringVAE(latent_dim=latent_dim, n_clusters=n_clusters, beta=beta, alpha=alpha)
    vae.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001))
    
    # Create training dataset
    train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
    
    # Training loop for pretraining (standard VAE training)
    print("Phase 1: Pretraining VAE...")
    history = {"total_loss": [], "reconstruction_loss": [], "kl_loss": [], "ari_scores": []}
    
    for epoch in range(pretrain_epochs):
        epoch_loss = {"total_loss": 0.0, "reconstruction_loss": 0.0, "kl_loss": 0.0}
        num_batches = 0
        
        for batch in train_dataset:
            loss_dict = vae.train_step(batch)
            
            for key in epoch_loss:
                epoch_loss[key] += loss_dict[key].numpy()
            
            num_batches += 1
        
        # Average losses for the epoch
        for key in epoch_loss:
            epoch_loss[key] /= num_batches
            history[key].append(epoch_loss[key])
        
        # Evaluate ARI every few epochs
        if (epoch + 1) % 5 == 0 or epoch == pretrain_epochs - 1:
            # Temporary K-means for evaluation
            z_mean = vae.encoder(x_test)[0].numpy()
            kmeans = KMeans(n_clusters=n_clusters, random_state=42)
            cluster_labels = kmeans.fit_predict(z_mean)
            ari_score = adjusted_rand_score(y_test, cluster_labels)
            
            history["ari_scores"].append(ari_score)
            print(f"Epoch {epoch+1}/{pretrain_epochs}, Total Loss: {epoch_loss['total_loss']:.4f}, "
                  f"Reconstruction Loss: {epoch_loss['reconstruction_loss']:.4f}, "
                  f"KL Loss: {epoch_loss['kl_loss']:.4f}, ARI: {ari_score:.4f}")
        else:
            print(f"Epoch {epoch+1}/{pretrain_epochs}, Total Loss: {epoch_loss['total_loss']:.4f}, "
                  f"Reconstruction Loss: {epoch_loss['reconstruction_loss']:.4f}, "
                  f"KL Loss: {epoch_loss['kl_loss']:.4f}")
    
    # Visualize pretrained model results
    z_mean = vae.encoder(x_test)[0].numpy()
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(z_mean)
    visualize_latent_space(z_mean, y_test, cluster_labels, title_suffix=" (After Pretraining)")
    visualize_reconstructions(vae, x_test, y_test)
    
    # Initialize clustering
    print("\nPhase 2: Initializing clustering...")
    vae.initialize_clustering(x_train, y_train, epochs=5)
    
    # Fine-tuning with clustering objective
    print("\nPhase 3: Fine-tuning with clustering objective...")
    clustering_history = {"total_loss": [], "reconstruction_loss": [], 
                         "kl_loss": [], "clustering_loss": [], "ari_scores": []}
    
    for epoch in range(finetune_epochs):
        epoch_loss = {"total_loss": 0.0, "reconstruction_loss": 0.0, 
                     "kl_loss": 0.0, "clustering_loss": 0.0}
        num_batches = 0
        
        for batch in train_dataset:
            loss_dict = vae.train_step(batch)
            
            for key in epoch_loss:
                epoch_loss[key] += loss_dict[key].numpy()
            
            num_batches += 1
        
        # Average losses for the epoch
        for key in epoch_loss:
            epoch_loss[key] /= num_batches
            clustering_history[key].append(epoch_loss[key])
        
        # Evaluate ARI every few epochs
        if (epoch + 1) % 2 == 0 or epoch == finetune_epochs - 1:
            ari_score, latent_reps, cluster_labels = evaluate_ari(vae, x_test, y_test)
            clustering_history["ari_scores"].append(ari_score)
            
            print(f"Epoch {epoch+1}/{finetune_epochs}, Total Loss: {epoch_loss['total_loss']:.4f}, "
                  f"Reconstruction Loss: {epoch_loss['reconstruction_loss']:.4f}, "
                  f"KL Loss: {epoch_loss['kl_loss']:.4f}, "
                  f"Clustering Loss: {epoch_loss['clustering_loss']:.4f}, "
                  f"ARI: {ari_score:.4f}")
        else:
            print(f"Epoch {epoch+1}/{finetune_epochs}, Total Loss: {epoch_loss['total_loss']:.4f}, "
                  f"Reconstruction Loss: {epoch_loss['reconstruction_loss']:.4f}, "
                  f"KL Loss: {epoch_loss['kl_loss']:.4f}, "
                  f"Clustering Loss: {epoch_loss['clustering_loss']:.4f}")
    
    # Final evaluation
    final_ari, latent_reps, cluster_labels = evaluate_ari(vae, x_test, y_test)
    print(f"\nFinal ARI Score: {final_ari:.4f}")
    
    # Final visualizations
    visualize_latent_space(latent_reps, y_test, cluster_labels, title_suffix=" (Final)")
    visualize_reconstructions(vae, x_test, y_test)
    visualize_cluster_centers(vae)
    
    # Combine history and plot
    combined_history = {}
    for key in history:
        combined_history[key] = history[key] + clustering_history[key] if key in clustering_history else history[key]
    if 'clustering_loss' in clustering_history:
        combined_history['clustering_loss'] = [0] * len(history['total_loss']) + clustering_history['clustering_loss']
    
    plot_history(combined_history)
    
    return vae, combined_history, final_ari

# Run with optimized parameters for high ARI
vae, history, ari = run_improved_vae_experiment(
    latent_dim=20,         # Higher latent dims capture more structure
    n_clusters=10,         # 10 clusters for MNIST (one per digit)
    beta=1.0,              # Standard weight for KL divergence
    alpha=1.5,             # Weight for clustering objective
    pretrain_epochs=30,    # Pretraining epochs
    finetune_epochs=20,    # Finetuning epochs with clustering objective
    batch_size=64 #256         # Larger batch size for better gradient estimates
)

Phase 1: Pretraining VAE...


2025-04-15 21:25:55.910233: W tensorflow/core/framework/op_kernel.cc:1841] OP_REQUIRES failed at reduction_ops_common.h:147 : INVALID_ARGUMENT: Invalid reduction dimension (3 for input with 3 dimension(s)


InvalidArgumentError: {{function_node __wrapped__Sum_device_/job:localhost/replica:0/task:0/device:GPU:0}} Invalid reduction dimension (3 for input with 3 dimension(s) [Op:Sum]