In [30]:
import os
import torch
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score, classification_report


In [31]:
train_dir = os.path.join('dataset', 'part_one_dataset', 'train_data')
eval_dir = os.path.join('dataset', 'part_one_dataset', 'eval_data')
train_path = os.path.join(train_dir, '1_train_data.tar.pth')
eval_path = os.path.join(eval_dir, '1_eval_data.tar.pth')

t = torch.load(train_path, weights_only = False)

### Basic LWP Model using Distance Function

In [32]:
import numpy as np
from sklearn.metrics.pairwise import cosine_distances, manhattan_distances

class LWP:
    """Learning Vector Prototypes with configurable distance function"""
    
    DISTANCE_FUNCTIONS = {
        'euclidean': lambda x, y: np.linalg.norm(x - y),
        'cosine': lambda x, y: cosine_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'manhattan': lambda x, y: manhattan_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'minkowski': lambda x, y, p=2: np.power(np.sum(np.power(np.abs(x - y), p)), 1/p)
    }
    
    def __init__(self, distance_metric='euclidean', **distance_params):
        """
            distance_params (dict): Additional parameters for the distance function
        """
        self.prototypes = {}
        self.class_counts = {i: 0 for i in range(10)}
        
        if callable(distance_metric):
            self.distance_fn = distance_metric
        elif distance_metric in self.DISTANCE_FUNCTIONS:
            if distance_metric == 'minkowski':
                p = distance_params.get('p', 2)
                self.distance_fn = lambda x, y: self.DISTANCE_FUNCTIONS[distance_metric](x, y, p)
            else:
                self.distance_fn = self.DISTANCE_FUNCTIONS[distance_metric]
        else:
            raise ValueError(f"Unknown distance metric: {distance_metric}. " 
                           f"Available metrics: {list(self.DISTANCE_FUNCTIONS.keys())}")

    def fit(self, features, labels):
        unique_labels = np.unique(labels)
        for label in unique_labels:
            samples = features[labels == label]
            num_samples = len(samples)
            
            if label not in self.prototypes:  # Original condition was: if label not in self.prototypes
                self.prototypes[label] = samples.mean(axis=0)
                self.class_counts[label] = len(samples)
            else:
                self.class_counts[label] += len(samples)
                self.prototypes[label] = (
                    (self.class_counts[label] - num_samples) / self.class_counts[label] * self.prototypes[label] +
                    num_samples / self.class_counts[label] * samples.mean(axis=0)
                )
    
    def predict(self, features):
        preds = []
        for feature in features:
            distances = {
                label: self.distance_fn(feature, proto)
                for label, proto in self.prototypes.items()
            }
            preds.append(min(distances, key=distances.get))
        return np.array(preds)
    
    def predict_proba(self, features):
        """
        Predict probabilities (normalized distances to prototypes).
        Args:
            features (np.array): Embeddings of the data points.
        Returns:
            np.array: Probabilities for each class.
        """
        prob_list = []
        for feature in features:
            # Calculate distances to all prototypes
            
            # for label, proto in enumerate(self.prototypes.values()) :
            #     print('shape of proto is' , proto.shape)
            distances = {
                label: self.distance_fn(feature, proto) for label, proto in enumerate(self.prototypes.values())
            }
            
            # Convert distances to probabilities
            exp_neg_distances = np.exp(-np.array(list(distances.values())))  # Exponential of negative distances
            probabilities = exp_neg_distances / exp_neg_distances.sum()  # Normalize to sum to 1
            
            prob_list.append(probabilities)
        
        return np.vstack(prob_list)


### Compute cosine similarity and select top samples

In [33]:
def select_top_samples(embeddings, centroids, top_percentage=0.5):
    """
    Select top-k% samples with highest cosine similarity to centroids
    and include top-2 pseudo-labels.
    Args:
        embeddings (np.array): The data embeddings.
        centroids (np.array): Prototypes or centroids for each class.
        top_percentage (float): Percentage of top samples to select (0 < top_percentage <= 1).
    Returns:
        tuple: Top-k% embeddings, top-1 pseudo-labels, and top-2 pseudo-labels.
    """
    if not (0 < top_percentage <= 1):
        raise ValueError("top_percentage must be between 0 and 1.")

    # Compute Euclidean distances
    distances = np.sqrt(((embeddings[:, np.newaxis, :] - centroids) ** 2).sum(axis=2))
    # Convert distances to similarities (higher is better)
    similarities = 1 / (1 + distances)

    # Compute the number of samples to select (50% of total embeddings)
    total_samples = embeddings.shape[0]
    top_k = int(total_samples * top_percentage)
    print(f"Selecting top {top_k} samples out of {total_samples} (percentage: {top_percentage * 100}%)...")

    # Find top-k samples with the highest cosine similarity to centroids
    max_similarities = np.max(similarities, axis=1)
    sorted_indices = np.argsort(max_similarities)[::-1]  # Sort by similarity in descending order
    top_indices = sorted_indices[:top_k]  # Select top-k indices

    # Generate pseudo-labels for top-k samples
    top_1_labels = np.argmax(similarities[top_indices], axis=1)  # Top-1 labels
    second_highest_indices = np.argsort(similarities[top_indices], axis=1)[:, -2]  # Top-2 labels
    top_2_labels = second_highest_indices

    # Return top embeddings and their pseudo-labels
    return embeddings[top_indices], top_1_labels, top_2_labels


### Knowledge Distillation based LWP Model

In [34]:
class KnowledgeDistillationLWP:
    def __init__(self, lwp_model, distance_metric='euclidean', alpha=0.5, beta=0.5):
        """
        Knowledge Distillation-based LWP Model with top-2 pseudo-labeling.
        Args:
            distance_metric (str): Distance metric to use for LWP (e.g., cosine).
            alpha (float): Weighting factor for distillation loss.
            beta (float): Weighting factor for top-2 pseudo-label updates.
        """
        self.old_model = LWP(distance_metric=distance_metric)
        self.lwp_model = lwp_model
        self.alpha = alpha  # Trade-off between current and old knowledge
        self.beta = beta  # Trade-off between top-1 and top-2 pseudo-label prototypes

    def fit(self, features, labels):
        """
        Fit LWP model with knowledge distillation.
        Args:
            features (np.array): Embeddings of the current dataset.
            labels (np.array): Optional true labels (only for D1).
        """
        # Store the current model as the old model before updating
        self.old_model.class_counts = self.lwp_model.class_counts
        for i in range(10):
            self.old_model.prototypes[i] = self.lwp_model.prototypes[i]
        
        # Fit the current LWP model to new data
        self.lwp_model.fit(features, labels)
        # print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAa")
        # print(self.old_model.prototypes.keys())
        # print(np.unique(labels))
        # print(self.lwp_model.prototypes.keys())
        # print("Old Model Prototypes after fit:", self.old_model.prototypes)
        print("Class Counts after fit in old:", self.old_model.class_counts)

    
    def distillation_loss(self, new_features):
        """
        Compute KL Divergence between old model and current model predictions.
        Args:
            new_features (np.array): Embeddings of the current dataset.
        Returns:
            float: Knowledge distillation loss.
        """
        if self.old_model is None:
            return 0  # No distillation loss for the first dataset
        
        # Predictions from the old model
        old_predictions = self.old_model.predict_proba(new_features)
        # print(old_predictions)
        # Predictions from the current model
        current_predictions = self.lwp_model.predict_proba(new_features)
        
        # Compute KL divergence
        kl_div = np.sum(old_predictions * np.log((old_predictions + 1e-8) / (current_predictions + 1e-8)), axis=1)
        return kl_div.mean()

    def update_model(self, features):
        """
        Update the LWP model using distillation loss.
        Args:
            features (np.array): Embeddings of the current dataset.
        """
        if features.size == 0:
            print("No features available for update. Skipping distillation...")
            return

        # Check if prototypes exist
        if not self.lwp_model.prototypes:
            print("No prototypes available. Skipping update...")
            return
        # print(features.shape)
        # return
        centroids = np.vstack([proto for proto in self.lwp_model.prototypes.values()])
        top_embeddings, top_1_labels, top_2_labels = select_top_samples(features, centroids)
        # print("why do i exist:", np.unique(top_1_labels))
        if top_embeddings.shape[0] != top_1_labels.shape[0]:
            print("Shape mismatch between top_embeddings and top_1_labels. Skipping...")
            return
        # Update prototypes using top-1 and top-2 pseudo-labels
        iterations = 0
        for label in np.unique(top_1_labels):
            iterations += 1
            top_1_samples = top_embeddings[top_1_labels == label]
            top_2_samples = top_embeddings[top_2_labels == label]
            # print("hi")
            if len(top_1_samples) > 0:
                new_proto_top1 = top_1_samples.mean(axis=0)
                if len(top_2_samples) > 0:
                    new_proto_top2 = top_2_samples.mean(axis=0)
                    # print('shapes are')
                    # print(new_proto_top1.shape)
                    # print(new_proto_top2.shape)
                    # Combine top-1 and top-2 updates using the beta factor
                    self.lwp_model.prototypes[label] = (
                        (1 - self.beta) * new_proto_top1 + self.beta * new_proto_top2
                    )
                else:
                    self.lwp_model.prototypes[label] = new_proto_top1
        # print(f"feature shape (X): {features.shape}")

        if self.old_model:
            # Ensure prototype dimensions match features
            self.old_model.prototypes = {
                label: proto.reshape(-1, features.shape[1])
                for label, proto in self.old_model.prototypes.items()
        }
                    
        # Compute distillation loss
        # print(self.old_model.prototypes)
        distillation_loss = self.distillation_loss(features) if self.old_model.class_counts[0] else 0.0
        print(f"Distillation Loss: {distillation_loss:.4f}")

        
        # Adjust prototypes based on distillation loss
        
        for label, proto in  enumerate(self.lwp_model.prototypes.values()):
            # print("i'm in love with the shape of X", label)
            if label in self.old_model.prototypes:
                self.lwp_model.prototypes[label] = (
                    (1 - self.alpha) * proto + self.alpha * self.old_model.prototypes[label]
                )
            else:
                print(f"Label {label} not found in old model prototypes. Skipping...")

    def predict(self, features):
        """
        Predict pseudo-labels for the given features.
        Args:
            features (np.array): Embeddings of the data points.
        Returns:
            np.array: Predicted labels.
        """
        return self.lwp_model.predict(features)
    
    def predict_proba(self, features):
        """
        Predict probabilities (normalized distances to prototypes).
        Args:
            features (np.array): Embeddings of the data points.
        Returns:
            np.array: Predicted probabilities for each class.
        """
        distances = []
        for feature in features:
            dist = {
                label: self.lwp_model.distance_fn(feature, proto)
                for label, proto in self.lwp_model.prototypes
            }
            # Convert distances to probabilities
            prob = np.exp(-np.array(list(dist.values())))
            prob /= prob.sum()
            distances.append(prob)
        return np.vstack(distances)

### Process a dataset using kNN with top-2 pseudo-labels and knowledge distillation.

In [35]:
# Function to evaluate the model
def evaluate_on_eval_embeddings(embed_dir, dataset_idx, model, ground_truth=None):
    embed_path = os.path.join(embed_dir, f'eval_embeds_{dataset_idx}.pt')
    eval_embeddings = torch.load(embed_path, weights_only=False)

    print(f"Evaluating on dataset {dataset_idx}...")
    predicted_labels = model.predict(eval_embeddings)

    if ground_truth is not None:
        accuracy = accuracy_score(ground_truth, predicted_labels)
        report = classification_report(ground_truth, predicted_labels, zero_division=0)
        print(f"Accuracy on eval set {dataset_idx}: {accuracy * 100:.2f}%")
        print("Classification Report:")
        print(report)
        return {"accuracy": accuracy, "report": report}
    else:
        print(f"Predictions on eval set {dataset_idx}: {predicted_labels[:10]}...")
        return {"predicted_labels": predicted_labels}

In [36]:
# Directories for embeddings
part_one_embed_dir = 'part_1_vit_embeds'
part_two_embed_dir = 'part_2_vit_embeds'

# Load D1 embeddings and targets
train_embed_path = os.path.join(part_one_embed_dir, 'train_embeds_1.pt')
eval_embed_path = os.path.join(part_one_embed_dir, 'eval_embeds_1.pt')
train_embeddings = torch.load(train_embed_path, weights_only=False)
eval_embeddings = torch.load(eval_embed_path, weights_only=False)
train_path = os.path.join('dataset', 'part_one_dataset', 'train_data', '1_train_data.tar.pth')
data = torch.load(train_path, weights_only=False)
targets = data['targets']

# Initialize and fit LWP model
lwp_model = LWP(distance_metric='euclidean')
lwp_model.fit(train_embeddings, targets)
# print(train_embeddings.shape, targets.shape)

In [37]:
import copy 

lwp_model = LWP(distance_metric='euclidean')
lwp_model.fit(train_embeddings, targets)

kd_lwp_model = KnowledgeDistillationLWP(copy.deepcopy(lwp_model),alpha=0.5)

# Process D2-D10 with kNN and pseudo-labeling
for i in range(2, 11):
    dataset_idx = i
    k = 5
    embed_dir = part_one_embed_dir
    top_percentage = 0.5
    """
    Process the dataset using kNN and pseudo-labeling.
    
    Args:
        embed_dir (str): Directory containing embeddings.
        dataset_idx (int): Dataset index.
        lwp_model (LWP): LWP model instance.
        kd_lwp_model (KD-LWP): KD-LWP model instance.
        k (int): Number of neighbors for kNN.
        top_percentage (float): Percentage of samples to select for pseudo-labeling (0 < top_percentage <= 1).
    """
    print(f"Processing dataset {dataset_idx} with kNN...")
    # Load embeddings
    embed_path = os.path.join(embed_dir, f'train_embeds_{dataset_idx}.pt')
    embeddings = torch.load(embed_path, weights_only=False)
    # print(f"Loaded embeddings for dataset {dataset_idx}: {embeddings.shape}")
    
    if embeddings.size == 0:
        print(f"Dataset {dataset_idx} contains no embeddings. Skipping...")
        break

    # Compute centroids (prototypes)
    if not lwp_model.prototypes:
        print(f"No prototypes available in LWP model for dataset {dataset_idx}. Skipping...")
        break
    # compute prototypes(centroids)
    centroids = np.vstack([proto for proto in lwp_model.prototypes.values()])
    # Select top 50% samples based on cosine similarity to centroids
    top_embeddings, top_1_labels, top_2_labels = select_top_samples(embeddings, centroids, top_percentage=top_percentage)
    # Ensure k is not greater than the number of top embeddings
    # print(top_embeddings.shape, top_1_labels.shape, top_2_labels.shape)
    k = min(k, len(top_embeddings))
    # print(f"Using k={k} for kNN classification (based on top {len(top_embeddings)} embeddings)...")
    
    # Train kNN on the top samples
    knn = KNeighborsClassifier(n_neighbors=k, metric='euclidean')
    knn.fit(top_embeddings, top_1_labels)

    # Use kNN to assign pseudo-labels for all embeddings
    all_pseudo_labels = knn.predict(embeddings)
    # print(f"Dataset {dataset_idx}: Assigned pseudo-labels using kNN")
    lwp_model.fit(embeddings, all_pseudo_labels)

    # Update prototypes using top-2 pseudo-labels
    kd_lwp_model.update_model(embeddings)
    # print("why so sad: ", kd_lwp_model.lwp_model.prototypes.keys())
    kd_lwp_model.fit(embeddings, all_pseudo_labels)
    # # Update prototypes (class centroids) with kNN pseudo-labeled samples
    # for label in np.unique(all_pseudo_labels):
    #     class_embeddings = embeddings[all_pseudo_labels == label]
    #     if class_embeddings.size > 0:
    #         centroid = class_embeddings.mean(axis=0)
    #         lwp_model.prototypes[label] = centroid

    print(f"Dataset {dataset_idx}: Updated prototypes for classes {list(kd_lwp_model.lwp_model.prototypes.keys())}")

Processing dataset 2 with kNN...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Distillation Loss: 0.0000
Label 0 not found in old model prototypes. Skipping...
Label 1 not found in old model prototypes. Skipping...
Label 2 not found in old model prototypes. Skipping...
Label 3 not found in old model prototypes. Skipping...
Label 4 not found in old model prototypes. Skipping...
Label 5 not found in old model prototypes. Skipping...
Label 6 not found in old model prototypes. Skipping...
Label 7 not found in old model prototypes. Skipping...
Label 8 not found in old model prototypes. Skipping...
Label 9 not found in old model prototypes. Skipping...
Class Counts after fit in old: {0: 475, 1: 503, 2: 461, 3: 493, 4: 538, 5: 421, 6: 518, 7: 561, 8: 506, 9: 524}
Dataset 2: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 3 with kNN...
Selecting top 1250 samples out of 2500 (percentag

In [38]:
# Evaluate on D1 (with ground truth)
eval_labels_path = os.path.join('dataset', 'part_one_dataset', 'eval_data', '1_eval_data.tar.pth')
eval_data = torch.load(eval_labels_path)
eval_ground_truth = eval_data['targets']

  eval_data = torch.load(eval_labels_path)


In [39]:
evaluate_on_eval_embeddings(part_one_embed_dir, dataset_idx=1, model=lwp_model, ground_truth=eval_ground_truth)

Evaluating on dataset 1...
Accuracy on eval set 1: 89.16%
Classification Report:
              precision    recall  f1-score   support

           0       0.91      0.88      0.90       252
           1       0.93      0.94      0.94       217
           2       0.96      0.80      0.87       264
           3       0.85      0.82      0.83       242
           4       0.78      0.91      0.84       257
           5       0.82      0.90      0.86       252
           6       0.95      0.93      0.94       269
           7       0.90      0.86      0.88       233
           8       0.94      0.93      0.93       266
           9       0.91      0.95      0.93       248

    accuracy                           0.89      2500
   macro avg       0.90      0.89      0.89      2500
weighted avg       0.90      0.89      0.89      2500



{'accuracy': 0.8916,
 'report': '              precision    recall  f1-score   support\n\n           0       0.91      0.88      0.90       252\n           1       0.93      0.94      0.94       217\n           2       0.96      0.80      0.87       264\n           3       0.85      0.82      0.83       242\n           4       0.78      0.91      0.84       257\n           5       0.82      0.90      0.86       252\n           6       0.95      0.93      0.94       269\n           7       0.90      0.86      0.88       233\n           8       0.94      0.93      0.93       266\n           9       0.91      0.95      0.93       248\n\n    accuracy                           0.89      2500\n   macro avg       0.90      0.89      0.89      2500\nweighted avg       0.90      0.89      0.89      2500\n'}

In [40]:
# Evaluate on D2-D10
for i in range(2, 11):
    t = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{i}_eval_data.tar.pth') ,weights_only=False)
    evaluate_on_eval_embeddings(part_one_embed_dir, dataset_idx=i, model=lwp_model, ground_truth = t['targets']) 

Evaluating on dataset 2...
Accuracy on eval set 2: 89.80%
Classification Report:
              precision    recall  f1-score   support

           0       0.93      0.90      0.92       271
           1       0.95      0.95      0.95       261
           2       0.94      0.81      0.87       267
           3       0.88      0.83      0.85       258
           4       0.79      0.92      0.85       236
           5       0.78      0.89      0.83       212
           6       0.96      0.92      0.94       250
           7       0.88      0.88      0.88       252
           8       0.95      0.95      0.95       258
           9       0.91      0.95      0.93       235

    accuracy                           0.90      2500
   macro avg       0.90      0.90      0.90      2500
weighted avg       0.90      0.90      0.90      2500

Evaluating on dataset 3...
Accuracy on eval set 3: 91.00%
Classification Report:
              precision    recall  f1-score   support

           0       0.91 

In [41]:
# Process D11-D20 (unlabeled datasets) with knowledge distillation
for i in range(11, 21):
    train_path = os.path.join(part_two_embed_dir, f'train_embeds_{i-10}.pt')
    print(f"Processing dataset D{i} from {train_path}...")
    
    # Load train embeddings
    train_embeddings = torch.load(train_path, weights_only=False)
    
    # Perform knowledge distillation-based learning
    kd_lwp_model.update_model(train_embeddings)
    
    #FIXME: Uncomment this line to update the model with the new data, where to get the pseudo labels from?
    # kd_lwp_model.fit(train_embeddings)
    
    # Evaluate or predict on eval set if needed
    eval_path = os.path.join(part_two_embed_dir, f'eval_embeds_{i-10}.pt')
    eval_embeddings = torch.load(eval_path, weights_only=False)
    predictions = kd_lwp_model.predict(eval_embeddings)
    print(f"Pseudo-labels for eval set of D{i}: {predictions[:10]}")


Processing dataset D11 from part_2_vit_embeds/train_embeds_1.pt...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Distillation Loss: 0.0187
Pseudo-labels for eval set of D11: [4 4 8 8 1 7 3 8 2 2]
Processing dataset D12 from part_2_vit_embeds/train_embeds_2.pt...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Distillation Loss: 0.1400
Pseudo-labels for eval set of D12: [4 4 8 8 9 7 5 0 4 4]
Processing dataset D13 from part_2_vit_embeds/train_embeds_3.pt...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Distillation Loss: 0.0333
Pseudo-labels for eval set of D13: [4 4 8 8 1 7 3 2 2 6]
Processing dataset D14 from part_2_vit_embeds/train_embeds_4.pt...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Distillation Loss: 0.0228
Pseudo-labels for eval set of D14: [4 4 8 8 1 7 3 0 6 6]
Processing dataset D15 from part_2_vit_embeds/train_embeds_5.pt...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Distillation Loss: 0.03

In [42]:
# Evaluate on D11-D20 (unlabeled datasets)
for i in range(11, 21):
    t = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{i-10}_eval_data.tar.pth') ,weights_only=False)
    evaluate_on_eval_embeddings(part_two_embed_dir, dataset_idx=i-10, model=lwp_model, ground_truth=t['targets'])

Evaluating on dataset 1...
Accuracy on eval set 1: 81.76%
Classification Report:
              precision    recall  f1-score   support

           0       0.85      0.84      0.84       238
           1       0.91      0.89      0.90       262
           2       0.87      0.78      0.82       256
           3       0.64      0.73      0.68       255
           4       0.74      0.83      0.78       268
           5       0.69      0.80      0.74       246
           6       0.84      0.83      0.83       225
           7       0.90      0.75      0.82       259
           8       0.93      0.85      0.89       236
           9       0.91      0.88      0.90       255

    accuracy                           0.82      2500
   macro avg       0.83      0.82      0.82      2500
weighted avg       0.83      0.82      0.82      2500

Evaluating on dataset 2...
Accuracy on eval set 2: 66.36%
Classification Report:
              precision    recall  f1-score   support

           0       0.49 