# Experimentation - Task 2

We tried mainly 3 other approaches, which are described here - 

i) Pseudo-dataset generation and then applying Gaussian Mixture Model using Expectation Maximization

ii)	KL-Divergence minimization between new datasets and pseudo-generated dataset

iii) Loss function minimization (as described in Paper 2 - https://arxiv.org/pdf/2301.10418)

In [2]:
import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import warnings
from sklearn.decomposition import PCA
from scipy.stats import chi2
import matplotlib.pyplot as plt
from sklearn.preprocessing import PowerTransformer
import torchvision.models as models
import torch.nn as nn
from torchvision import transforms
import pickle
from scipy.stats import multivariate_normal
from scipy.special import logsumexp
warnings.filterwarnings("ignore")

## 1) GMM using EM

The preprocessing steps remain the same as before.

In [1]:
class PreprocessingPipeline:
    def __init__(self, n_components=50):
        self.scaler = StandardScaler()
        self.transformer = PowerTransformer(method='yeo-johnson')
        self.pca = PCA(n_components=n_components)
        self.pca_mean = None
        self.pca_std = None

    def fit_transform(self, X_train):
        """
        Fits the preprocessing pipeline on the training set and applies transformations.
        
        Parameters:
        n: ndarray of shape (n_samples, n_features), training data.
        
        Returns:
        - pca_train: Preprocessed training data.
        """
        # Step 1: Standardization
        standardized_train = self.scaler.fit_transform(X_train)

        # Step 2: Yeo-Johnson Transformation
        transformed_train = self.transformer.fit_transform(standardized_train)

        # Step 3: PCA
        pca_train = self.pca.fit_transform(transformed_train)

        # Step 4: Standardize PCA-transformed data
        self.pca_mean = np.mean(pca_train, axis=0)
        self.pca_std = np.std(pca_train, axis=0)
        pca_train_standardized = (pca_train - self.pca_mean) / self.pca_std

        return pca_train_standardized

    def transform(self, X):
        """
        Applies the fitted preprocessing pipeline to new data.

        Parameters:
        - X: ndarray of shape (n_samples, n_features), data to preprocess.
        
        Returns:
        - pca_transformed: Preprocessed data.
        """
        # Step 1: Standardization
        standardized = self.scaler.transform(X)

        # Step 2: Yeo-Johnson Transformation
        transformed = self.transformer.transform(standardized)

        # Step 3: PCA
        pca_transformed = self.pca.transform(transformed)

        # Step 4: Standardize PCA-transformed data using training set stats
        pca_transformed_standardized = (pca_transformed - self.pca_mean) / self.pca_std

        return pca_transformed_standardized

In [3]:
pipeline=PreprocessingPipeline()
X_train=[pipeline.fit_transform(torch.load('extracted_data\X_train_1.pth'))]+[pipeline.transform(torch.load(f'extracted_data\X_train_{i}.pth')) for i in range(2,21)]
X_eval=[pipeline.transform(torch.load(f'extracted_data\X_eval_{i}.pth')) for i in range(1,21)]

This is the modified QDA model including pseudo-dataset generation and expectation maximization algorithm.

In [4]:
import numpy as np
from scipy.stats import multivariate_normal

class QDAClassifier:
    def __init__(self):
        self.class_means = {}
        self.class_covariances = {}
        self.class_priors = {}
        self.class_counts = {}
        self.total_samples = 0
    
    def fit(self, X, y):
        """
        Fits the QDA model to the data.

        Parameters:
        - X: ndarray of shape (n_samples, n_features), training data
        - y: ndarray of shape (n_samples,), class labels
        """
        self.total_samples = X.shape[0]
        classes = np.unique(y)
        
        for c in classes:
            # Get data points belonging to class c
            class_data = X[y == c]
            class_count = class_data.shape[0]
            
            # Compute class-specific statistics
            self.class_means[c] = np.mean(class_data, axis=0)
            self.class_covariances[c] = np.cov(class_data, rowvar=False)
            self.class_priors[c] = class_count / self.total_samples
            self.class_counts[c] = class_count

    def predict(self, X):
        """
        Predict the class labels for a dataset X.
    
        Parameters:
        - X: ndarray of shape (n_samples, n_features), the input data matrix.
    
        Returns:
        - predictions: ndarray of shape (n_samples,), the predicted class labels for each sample.
        """
        n_samples = X.shape[0]
        predictions = np.zeros(n_samples, dtype=int)  # Initialize predictions array
        
        # Get the classes from the keys of class_means
        classes = list(self.class_means.keys())
    
        for i in range(n_samples):
            posteriors = []
            for c in classes:
                # Compute likelihood P(x | y = c) using the multivariate Gaussian PDF
                mean = self.class_means[c]
                cov = self.class_covariances[c]
                prior = self.class_priors[c]
                likelihood = multivariate_normal.pdf(X[i], mean=mean, cov=cov)
    
                # Compute posterior P(y = c | x) = P(x | y = c) * P(y = c)
                posterior = likelihood * prior
                posteriors.append(posterior)
            
            # Predict the class with the highest posterior
            predictions[i] = classes[np.argmax(posteriors)]
        
        return predictions
    
    def get_class_statistics(self):
        """
        Returns the learned statistics for each class.
        
        Returns:
        - class_means: dict of class means
        - class_covariances: dict of class covariance matrices
        - class_priors: dict of class priors
        - class_counts: dict of the number of samples per class
        """
        return {
            'means': self.class_means,
            'covariances': self.class_covariances,
            'priors': self.class_priors,
            'counts': self.class_counts,
        }
    def update(self, X):

    # Step 2: Predict labels
        predicted_labels = self.predict(X)
    
    # Step 3: Update statistics per class
        for c in self.class_means.keys():
            # Get new samples for class c
            new_samples = X[predicted_labels == c]
            new_count = new_samples.shape[0]
            if new_count == 0:
                continue  # No new samples for this class
        
            # Update mean
            current_mean = self.class_means[c]
            current_count = self.class_counts[c]
            new_mean = (current_mean * current_count + new_samples.sum(axis=0)) / (current_count + new_count)
        
            # Update covariance
            current_cov = self.class_covariances[c]
            scatter_current = current_cov * current_count
            scatter_new = np.cov(new_samples, rowvar=False) * new_count
            scatter_updated = scatter_current + scatter_new
            updated_cov = scatter_updated / (current_count + new_count)
        
        # Update priors, counts, etc.
            self.class_means[c] = new_mean
            self.class_covariances[c] = updated_cov
            self.class_counts[c] += new_count
            self.class_priors[c] = self.class_counts[c] / self.total_samples

    # Update total sample count
        self.total_samples += X.shape[0]
    def mod_update(self, X, h):
        predicted_labels = self.predict(X)
        for c in self.class_means.keys():
            # Get new samples for class c
            new_samples = X[predicted_labels == c]
            new_count = new_samples.shape[0]
            if new_count == 0:
                continue  # No new samples for this class
        
            # Update mean
            current_mean = self.class_means[c]
            # current_count = self.class_counts[c]
            new_mean = current_mean *(1-h) + (h)*new_samples.sum(axis=0)/new_count
        
            # Update covariance
            current_cov = self.class_covariances[c]
            scatter_current = current_cov * (1-h)
            scatter_new = np.cov(new_samples, rowvar=False) * h
            scatter_updated = scatter_current + scatter_new
            updated_cov = scatter_updated
        
        # Update priors, counts, etc.
            self.class_means[c] = new_mean
            self.class_covariances[c] = updated_cov
            self.class_priors[c] = self.class_counts[c]*(1-h) + new_count*h
            self.class_counts[c] += new_count
    
    
    def generate_samples(self, num_samples):
        """
        Generate synthetic samples using the learned class distributions.
        
        Parameters:
        - num_samples: int, total number of synthetic samples to generate.
        
        Returns:
        - X_generated: ndarray of shape (num_samples, n_features), the generated samples.
        - y_generated: ndarray of shape (num_samples,), the corresponding class labels.
        """
        # Initialize storage for generated samples and labels
        X_generated = []
        y_generated = []

        # Generate samples for each class based on the prior probabilities
        for c, prior in self.class_priors.items():
            # Number of samples to generate for this class
            class_samples = int(np.round(prior * num_samples))
            
            # Sample from the Gaussian distribution for this class
            mean = self.class_means[c]
            cov = self.class_covariances[c]
            generated = np.random.multivariate_normal(mean, cov, size=class_samples)
            
            # Append to the result
            X_generated.append(generated)
            y_generated.extend([c] * class_samples)

        # Concatenate and shuffle to create the final dataset
        X_generated = np.vstack(X_generated)
        y_generated = np.array(y_generated)
        indices = np.arange(len(y_generated))
        np.random.shuffle(indices)

        return X_generated[indices], y_generated[indices]


import numpy as np
from scipy.stats import multivariate_normal

from scipy.special import logsumexp

class QDAClassifierWithEM(QDAClassifier):
    def expectation_maximization(self, X, max_iter=100, tol=1e-6, regularization=1e-6, min_iter=10):
        """
        Performs the EM algorithm to refine the QDA parameters using unlabeled data.

        Parameters:
        - X: ndarray of shape (n_samples, n_features), unlabeled data
        - max_iter: int, maximum number of EM iterations
        - tol: float, convergence tolerance for parameter changes
        - regularization: float, regularization term for covariance matrices
        - min_iter: int, minimum number of iterations to avoid premature stopping
        """
        X_generated=self.generate_samples(len(X))
        X=np.vstack((X_generated, X))
        # Initialize parameters from the current QDA model
        classes = list(self.class_means.keys())
        k = len(classes)
        n_samples, n_features = X.shape

        # Initialize variables
        responsibilities = np.zeros((n_samples, k))
        log_likelihood = -np.inf
        temp_class_means=self.class_means
        temp_class_counts=self.class_counts
        temp_class_covariances=self.class_covariances
        temp_class_priors=self.class_priors
        for iteration in range(max_iter):
            # Step 1: Expectation (E-step)
            log_probabilities = np.zeros((n_samples, k))
            for i, c in enumerate(classes):
                mean = temp_class_means[c]
                cov = temp_class_covariances[c] + regularization * np.eye(n_features)  # Regularized covariance
                prior = temp_class_priors[c]
                log_probabilities[:, i] = np.log(prior) + multivariate_normal.logpdf(X, mean=mean, cov=cov)

            # Stabilized log-sum-exp for log-likelihood normalization
            log_responsibilities = log_probabilities - logsumexp(log_probabilities, axis=1, keepdims=True)
            responsibilities = np.exp(log_responsibilities)

            # Step 2: Maximization (M-step)
            new_means, new_covariances, new_priors, new_class_counts = {}, {}, {}, {}
            for i, c in enumerate(classes):
                effective_count = responsibilities[:, i].sum()
                new_class_counts[c] = effective_count

                # Update mean
                new_mean = (responsibilities[:, i][:, np.newaxis] * X).sum(axis=0) / effective_count
                new_means[c] = new_mean

                # Update covariance
                centered_X = X - new_mean
                weighted_cov = (responsibilities[:, i][:, np.newaxis, np.newaxis] * 
                                np.einsum('ni,nj->nij', centered_X, centered_X)).sum(axis=0)
                new_covariances[c] = weighted_cov / effective_count + regularization * np.eye(n_features)

                # Update prior
                new_priors[c] = effective_count / n_samples

            # Compute log-likelihood
            new_log_likelihood = logsumexp(log_probabilities).sum()

            # Convergence check
            if iteration >= min_iter and np.abs(new_log_likelihood - log_likelihood) < tol:
                print(f"EM converged at iteration {iteration + 1}")
                break

            log_likelihood = new_log_likelihood
     
            # Update parameters for the next iteration
            temp_class_means = new_means
            temp_class_covariances = new_covariances
            temp_class_priors = new_priors
            temp_class_counts = new_class_counts
        
        for c in classes:
            #Combine counts
            total_count = self.class_counts[c] + new_class_counts[c]

            # Weighted mean update
            self.class_means[c] = (self.class_means[c] * self.class_counts[c] + new_means[c] * new_class_counts[c]) / total_count

            # Weighted covariance update
            centered_self_cov = self.class_covariances[c] * self.class_counts[c]
            centered_new_cov = new_covariances[c] * new_class_counts[c]
    
            scatter_combined = (centered_self_cov + centered_new_cov + (self.class_counts[c] * new_class_counts[c]) / total_count * np.outer(self.class_means[c] - new_means[c], self.class_means[c] - new_means[c]))
            self.class_covariances[c] = scatter_combined / total_count

                # Update prior probabilities
            self.class_priors[c] = total_count/ (self.total_samples + n_samples)

                # Update total counts
            self.class_counts[c] = total_count

        self.total_samples=self.total_samples+n_samples

        print(f"EM completed in {iteration + 1} iterations.")

In [6]:
path = r"C:\Users\ARITRA\Documents\Notebooks\CS771_MiniProject2\dataset\dataset\part_one_dataset"
traindata_1 = torch.load(f"{path}\\train_data\\1_train_data.tar.pth", map_location=torch.device('cpu'))
y_train_1 = traindata_1['targets']
qdaem = QDAClassifierWithEM()
qdaem.fit(X_train[0], y_train_1)
for i in range(1,10):
    qdaem.update(X_train[i])
    print("Done")
pickle.dump(qdaem, open("final_f10.pkl", "wb"))

Done
Done
Done
Done
Done
Done
Done
Done
Done


## 2) KL-Divergence minimization

In [7]:
import numpy as np
from scipy.optimize import minimize

# Define KL divergence for multivariate Gaussians
def kl_divergence(params, mu_p, cov_p, mu_q, cov_q):
    n = mu_p.shape[0]  # Dimensionality
    a = np.diag(params[:n])  # Diagonal scaling matrix
    b = params[n:]  # Offset vector
    
    # Transformed mean and covariance
    mu_k = a @ mu_q + b
    cov_k = a @ cov_q @ a.T

    # KL divergence components
    term1 = np.log(np.linalg.det(cov_k) / np.linalg.det(cov_p))
    term2 = np.trace(np.linalg.inv(cov_k) @ cov_p)
    term3 = (mu_p - mu_k).T @ np.linalg.inv(cov_k) @ (mu_p - mu_k)
    term4 = -n

    return 0.5 * (term1 + term2 + term3 + term4)

# Parameters for 10 classes of p(x) and q(x)
classes = 10
dim = 3  # Dimensionality of the Gaussians

# Generate random means and covariances for illustration
np.random.seed(42)
mu_p_list = [np.random.rand(dim) for _ in range(classes)]
cov_p_list = [np.eye(dim) for _ in range(classes)]  # Identity covariance for simplicity

mu_q_list = [np.random.rand(dim) for _ in range(classes)]
cov_q_list = [np.eye(dim) * 2 for _ in range(classes)]  # Scaled identity covariance

# Placeholder for transformed Gaussian parameters
transformed_gaussians = []

for i in range(classes):
    mu_p = mu_p_list[i]
    cov_p = cov_p_list[i]
    mu_q = mu_q_list[i]
    cov_q = cov_q_list[i]
    
    # Initial guess for parameters
    initial_params = np.concatenate([np.ones(dim), np.zeros(dim)])  # [a_diag, b]
    
    # Minimize KL divergence for the current class
    result = minimize(kl_divergence, initial_params, args=(mu_p, cov_p, mu_q, cov_q))
    
    # Extract optimal parameters
    optimal_params = result.x
    a_optimal = np.diag(optimal_params[:dim])
    b_optimal = optimal_params[dim:]
    
    # Compute transformed Gaussian parameters
    mu_k = a_optimal @ mu_q + b_optimal
    cov_k = a_optimal @ cov_q @ a_optimal.T
    
    # Save the transformed parameters
    transformed_gaussians.append((mu_k, cov_k))

# Output the transformed Gaussian distributions
for i, (mu_k, cov_k) in enumerate(transformed_gaussians):
    print(f"Class {i + 1}:")
    print(f"Transformed Mean (mu_k): {mu_k}")
    print(f"Transformed Covariance (cov_k):\n{cov_k}\n")

Class 1:
Transformed Mean (mu_k): [0.37454019 0.95071425 0.73199401]
Transformed Covariance (cov_k):
[[0.99999974 0.         0.        ]
 [0.         1.00000011 0.        ]
 [0.         0.         1.00000022]]

Class 2:
Transformed Mean (mu_k): [0.59865906 0.15601888 0.15599435]
Transformed Covariance (cov_k):
[[1.0000004  0.         0.        ]
 [0.         1.00000001 0.        ]
 [0.         0.         1.00000037]]

Class 3:
Transformed Mean (mu_k): [0.05808353 0.86617708 0.60111572]
Transformed Covariance (cov_k):
[[1.00000039 0.         0.        ]
 [0.         1.00000059 0.        ]
 [0.         0.         1.00000037]]

Class 4:
Transformed Mean (mu_k): [0.7080738  0.02058475 0.96990983]
Transformed Covariance (cov_k):
[[1.00000417 0.         0.        ]
 [0.         0.99999926 0.        ]
 [0.         0.         1.00000582]]

Class 5:
Transformed Mean (mu_k): [0.8324418  0.21233873 0.18182489]
Transformed Covariance (cov_k):
[[0.99999839 0.         0.        ]
 [0.         1.0000

In [8]:
import pickle
qda = pickle.load(open('f10_qda.pkl','rb'))

In [9]:
X_train_11 = torch.load('extracted_data/X_train_11.pth')
X_eval_11 = torch.load('extracted_data/X_eval_11.pth')
evaldata_11 = torch.load(r"C:\Users\ARITRA\Documents\Notebooks\CS771_MiniProject2\dataset\dataset\part_two_dataset\eval_data\1_eval_data.tar.pth")
y_test_11 = evaldata_11['targets']
pca_train=pipeline.transform(X_train_11)
pca_eval=pipeline.transform(X_eval_11)

In [10]:
y_pseudo_11 = qda.predict(pca_train)

In [11]:
import numpy as np

# Example inputs
# pca_train = np.array(...)  # n x d matrix
# y_pseudo_11 = np.array(...)  # n x 1 vector or n-length array

# Unique classes in y_pseudo_11
classes = np.unique(y_pseudo_11)

# Initialize dictionaries to store results
means = []
covariances = []

for cls in classes:
    # Get indices of the current class
    indices = np.where(y_pseudo_11 == cls)[0]
    
    # Extract rows corresponding to the current class
    class_data = pca_train[indices]
    
    # Compute mean and covariance
    means.append(np.mean(class_data, axis=0))
    covariances.append(np.cov(class_data, rowvar=False))

In [12]:
mu_p_list = qda.get_class_statistics()['means']
cov_p_list = qda.get_class_statistics()['covariances']  # Identity covariance for simplicity

mu_q_list = means
cov_q_list = covariances  # Scaled identity covariance

In [13]:
# Placeholder for transformed Gaussian parameters
classes = 10
dim = 50
transformed_gaussians = []

for i in range(classes):
    mu_p = mu_p_list[i]
    cov_p = cov_p_list[i]
    mu_q = mu_q_list[i]
    cov_q = cov_q_list[i]
    
    # Initial guess for parameters
    initial_params = np.concatenate([np.ones(dim), np.zeros(dim)])  # [a_diag, b]
    
    # Minimize KL divergence for the current class
    result = minimize(kl_divergence, initial_params, args=(mu_p, cov_p, mu_q, cov_q))
    
    # Extract optimal parameters
    optimal_params = result.x
    a_optimal = np.diag(optimal_params[:dim])
    b_optimal = optimal_params[dim:]
    
    # Compute transformed Gaussian parameters
    mu_k = a_optimal @ mu_q + b_optimal
    cov_k = a_optimal @ cov_q @ a_optimal.T
    
    # Save the transformed parameters
    transformed_gaussians.append((mu_k, cov_k))

# Output the transformed Gaussian distributions
for i, (mu_k, cov_k) in enumerate(transformed_gaussians):
    print(f"Class {i + 1}:")
    print(f"Transformed Mean (mu_k): {mu_k}")
    print(f"Transformed Covariance (cov_k):\n{cov_k}\n")

Class 1:
Transformed Mean (mu_k): [ 0.72245117 -1.10639717 -0.47952426  0.53057121 -0.66278662 -0.31387892
  0.45828145 -0.22612007  0.26091339 -0.23146773  0.54051239  0.14684358
  0.33987582  0.25774372 -0.22241453  0.43204617 -0.55152541 -0.30305663
 -0.0772027  -0.09135301  0.26700167 -0.03132013  0.0373942   0.10497126
 -0.17761777  0.11709951 -0.18453042 -0.0618883  -0.22547649  0.26504388
  0.08665083 -0.16907683 -0.20727861  0.05310048 -0.13724183 -0.02825839
  0.16732125 -0.04251067  0.0107204   0.09130389  0.10061499 -0.01385897
 -0.02002014  0.02397484 -0.00154638  0.01654231  0.14574389  0.18598571
 -0.00281289 -0.02428848]
Transformed Covariance (cov_k):
[[ 0.60323397  0.1168835  -0.0734929  ...  0.0431948  -0.00945292
  -0.00689636]
 [ 0.1168835   0.53403591  0.07577374 ...  0.26088746  0.00995005
   0.10506076]
 [-0.0734929   0.07577374  0.90827066 ...  0.18386272  0.0810887
   0.03801341]
 ...
 [ 0.0431948   0.26088746  0.18386272 ...  1.30870178 -0.02543037
  -0.006398

In [14]:
# List to store the transformed data points
transformed_pca_data = []

# Loop through each data point in pca_train
for idx, data_point in enumerate(pca_train):
    # Look up the class of the current data point
    class_label = y_pseudo_11[idx]
    
    # Retrieve optimal transformation parameters for the identified class
    a_optimal = np.diag(np.sqrt(np.diag(transformed_gaussians[class_label][1]) / np.diag(cov_q_list[class_label])))  # Diagonal scaling matrix
    b_optimal = transformed_gaussians[class_label][0] - a_optimal @ mu_q_list[class_label]  # Shift vector to align the means

    # Transform the current data point
    transformed_data_point = (a_optimal @ data_point) + b_optimal

    # Append the transformed data point to the list
    transformed_pca_data.append(transformed_data_point)

# Convert the list to a NumPy array for further processing
transformed_pca_data = np.array(transformed_pca_data)

# Output the transformed data points for verification (optional)
# print(f"Transformed PCA Data (First 5 rows):\n{transformed_pca_data[:5]}")

In [15]:
qda.update(transformed_pca_data)

In [16]:
prediction=qda.predict_modified(pca_eval)
print(np.mean(prediction == y_test_11)*100)

AttributeError: 'QDAClassifier' object has no attribute 'predict_modified'

Accuracy is not upto the mark, though the approach was interesting.

## 3) Loss function minimization

In [None]:
import torch
import glob

# Paths to train and eval datasets
train_files = sorted(glob.glob("X_train_*.pth"))  # Glob pattern for train datasets
eval_files = sorted(glob.glob("X_eval_*.pth"))    # Glob pattern for eval datasets

# Load all train datasets
train_datasets = [torch.load(file) for file in train_files]

# Load all eval datasets
eval_datasets = [torch.load(file) for file in eval_files]

# Store datasets in a dictionary (optional)
datasets = {
    "train": train_datasets,
    "eval": eval_datasets
}

# Print summary
print(f"Loaded {len(train_datasets)} train datasets.")
print(f"Loaded {len(eval_datasets)} eval datasets.")


In [None]:
y_eval_datasets = []
for i in range(1, 11):
    y_eval_datasets.append(torch.load(f"dataset\\dataset\\part_two_dataset\\eval_data\\{i}_eval_data.tar.pth")['targets'])

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

def compute_pseudo_labels(X, class_means):
    """
    Computes pseudo-labels for the dataset using Euclidean distance.
    """
    # Ensure X is a tensor with consistent dtype
    if isinstance(X, np.ndarray):
        X = torch.tensor(X, dtype=torch.float32)
    elif X.dtype != torch.float32:
        X = X.to(torch.float32)
    
    # Convert class_means dict to a tensor with consistent dtype
    class_means_tensor = torch.stack(
        [torch.tensor(class_means[k], dtype=torch.float32) for k in sorted(class_means.keys())]
    )  # Shape: [K, d]
    
    # Calculate the Euclidean distance between each data point and each class mean
    distances = torch.cdist(X, class_means_tensor)  # Shape: [N, K]
    
    # Get the index of the class with the minimum distance for each data point
    pseudo_labels = torch.argmin(distances, dim=1)  # Shape: [N]
    
    return pseudo_labels

import torch

# Example data
X = train_datasets_pre_processed[10]  # `data` is a tensor of size (2500, 2048)
pseudo_labels = compute_pseudo_labels(train_datasets_pre_processed[10], class_means)  # A tensor of size (2500,)
# `class_means` is now a list of tensors: [tensor(2048), tensor(2048), ..., tensor(2048)]

# Initialize `updated_means` as a list of tensors similar to `class_means`
updated_means = [torch.tensor(class_mean).clone().detach().zero_().requires_grad_(True) for mean, class_mean in class_means.items()]

def loss_function(X, updated_means, class_means, pseudo_labels, lambda_reg=0.1):
    """
    Optimized loss function calculation with soft assignment and regularization of class means.
    
    Parameters:
    - X: torch.Tensor of shape (N, d), dataset with N samples and d features.
    - updated_means: list of tensors, updated class means, each of shape (d,).
    - class_means: list of tensors, original class means, each of shape (d,).
    - pseudo_labels: torch.Tensor of shape (N,), pseudo-labels for each sample.
    - lambda_reg: Scalar, the weight of the regularization term (default is 0.1).
    
    Returns:
    - loss: Scalar torch.Tensor, the computed loss value.
    """
    # Convert class_means and updated_means to tensors
    updated_means_tensor = torch.stack(updated_means)  # Shape: [K, d]
    class_means_tensor = torch.stack([torch.tensor(class_mean).clone().detach().zero_().requires_grad_(True) for mean, class_mean in class_means.items()])  # Shape: [K, d]
    
    # Compute distances between all samples and all class means
    X_expanded = torch.tensor(X).unsqueeze(1)  # Shape: [N, 1, d]
    updated_distances = torch.norm(X_expanded - updated_means_tensor, dim=2)  # Shape: [N, K]
    class_distances = torch.norm(X_expanded - class_means_tensor, dim=2)  # Shape: [N, K]
    
    # Soft assignment (probabilities for each class based on the distances)
    updated_soft_assignment = torch.softmax(-updated_distances, dim=1)  # Shape: [N, K]
    class_soft_assignment = torch.softmax(-class_distances, dim=1)  # Shape: [N, K]
    
    # First term: log(num / den) for updated_means
    updated_num = torch.gather(updated_distances**2, 1, pseudo_labels.unsqueeze(1)).squeeze(1)  # Shape: [N]
    updated_den = (updated_distances**2).sum(dim=1)  # Shape: [N]
    first_term = -torch.log(updated_num / updated_den).sum()

    # Second term: log(num / den) for combined distances
    combined_num = updated_num + torch.gather(class_distances**2, 1, pseudo_labels.unsqueeze(1)).squeeze(1)
    
    # Compute the denominator for the second term
    combined_den = (updated_distances**2).sum(dim=1) + (class_distances**2).sum(dim=1)
    
    # Compute cross-sample pairwise distances
    pairwise_distances = torch.cdist(torch.tensor(X), torch.tensor(X), p=2)  # Shape: [N, N]
    mask = (pseudo_labels.unsqueeze(1) != pseudo_labels.unsqueeze(0))  # Mask for different class pairs
    combined_den += (pairwise_distances**2 * mask).sum(dim=1)
    
    second_term = -torch.log(combined_num / combined_den).sum()

    # Regularization term to keep the means as far apart as possible
    pairwise_means_distances = torch.cdist(updated_means_tensor, updated_means_tensor, p=2)  # Shape: [K, K]
    regularization_term = torch.sum(1.0 / (pairwise_means_distances + 1e-6))  # Avoid division by zero

    # Total loss with regularization
    loss = -first_term - second_term - lambda_reg * regularization_term  # Subtract regularization to maximize distance
    return loss

# Optimizer
optimizer = torch.optim.SGD(updated_means, lr=0.01)

# Training loop
num_epochs = 100  # Number of epochs
for epoch in range(num_epochs):
    optimizer.zero_grad()  # Reset gradients
    loss = loss_function(X, updated_means, class_means, pseudo_labels)  # Compute loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update `updated_means`

    # Normalize updated_means after each step (optional)
    for i in range(len(updated_means)):
        updated_means[i].data = updated_means[i].data / updated_means[i].norm()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

In [None]:
def predict(X, updated_means):
    """
    Predicts class labels for the given data points using the updated means.
    
    Parameters:
    - X: torch.Tensor of shape (N, d), where N is the number of data points and d is the feature dimension.
    - updated_means: list of torch.Tensor, where each tensor is the mean vector of a class (d-dimensional).
    
    Returns:
    - predictions: torch.Tensor of shape (N,), the predicted class labels.
    """
    # Convert updated_means to a tensor of shape (K, d)
    updated_means_tensor = torch.stack(updated_means)  # Shape: [K, d]
    
    # Compute distances between each data point and each class mean
    distances = torch.cdist(torch.tensor(X), updated_means_tensor)  # Shape: [N, K]
    
    # Predict the class with the minimum distance
    predictions = torch.argmin(distances, dim=1)  # Shape: [N]
    return predictions



def evaluate(X_eval, y_eval, updated_means):
    
    # Get predictions
    predictions = predict(X_eval, updated_means)
    
    # Compare predictions with true labels
    correct = (predictions == y_eval).sum().item()
    total = y_eval.shape[0]  # Use .shape instead of .size
    print(predictions)
    # Compute accuracy
    accuracy = correct / total
    return accuracy


# Evaluate the model
accuracy = evaluate(eval_datasets_pre_processed[10], y_eval_datasets[1], updated_means)
print(f"Accuracy on evaluation data: {accuracy:.2%}")