In [256]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models, transforms
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_distances
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score, classification_report


In [257]:
train_dir = os.path.join('dataset', 'part_one_dataset', 'train_data')
eval_dir = os.path.join('dataset', 'part_one_dataset', 'eval_data')
train_path = os.path.join(train_dir, '1_train_data.tar.pth')
eval_path = os.path.join(eval_dir, '1_eval_data.tar.pth')

t = torch.load(train_path, weights_only = False)

### Basic LWP Model using Distance Function

In [258]:
import numpy as np
from sklearn.metrics.pairwise import cosine_distances, manhattan_distances

class LWP:
    """Learning Vector Prototypes with configurable distance function"""
    
    DISTANCE_FUNCTIONS = {
        'euclidean': lambda x, y: np.linalg.norm(x - y),
        'cosine': lambda x, y: cosine_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'manhattan': lambda x, y: manhattan_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'minkowski': lambda x, y, p=2: np.power(np.sum(np.power(np.abs(x - y), p)), 1/p)
    }
    
    def __init__(self, distance_metric='euclidean', **distance_params):
        """
            distance_params (dict): Additional parameters for the distance function
        """
        self.prototypes = {}
        self.class_counts = {i: 0 for i in range(10)}
        
        if callable(distance_metric):
            self.distance_fn = distance_metric
        elif distance_metric in self.DISTANCE_FUNCTIONS:
            if distance_metric == 'minkowski':
                p = distance_params.get('p', 2)
                self.distance_fn = lambda x, y: self.DISTANCE_FUNCTIONS[distance_metric](x, y, p)
            else:
                self.distance_fn = self.DISTANCE_FUNCTIONS[distance_metric]
        else:
            raise ValueError(f"Unknown distance metric: {distance_metric}. " 
                           f"Available metrics: {list(self.DISTANCE_FUNCTIONS.keys())}")

    def fit(self, features, labels):
        unique_labels = np.unique(labels)
        for label in unique_labels:
            samples = features[labels == label]
            num_samples = len(samples)
            
            if label not in self.prototypes:  # Original condition was: if label not in self.prototypes
                self.prototypes[label] = samples.mean(axis=0)
                self.class_counts[label] = len(samples)
            else:
                self.class_counts[label] += len(samples)
                self.prototypes[label] = (
                    (self.class_counts[label] - num_samples) / self.class_counts[label] * self.prototypes[label] +
                    num_samples / self.class_counts[label] * samples.mean(axis=0)
                )
    
            
    

    def predict(self, features):
        preds = []
        for feature in features:
            distances = {
                label: self.distance_fn(feature, proto)
                for label, proto in self.prototypes.items()
            }
            preds.append(min(distances, key=distances.get))
        return np.array(preds)
    
    def predict_proba(self, features):
        """
        Predict probabilities (normalized distances to prototypes).
        Args:
            features (np.array): Embeddings of the data points.
        Returns:
            np.array: Probabilities for each class.
        """
        prob_list = []
        for feature in features:
            distances = {
                label: self.distance_fn(feature, proto) for label, proto in self.prototypes.items()
            }
            # Convert distances to probabilities
            prob = np.exp(-np.array(list(distances.values())))  # Exponential of negative distances
            prob /= prob.sum()  # Normalize to sum to 1
            prob_list.append(prob)
        return np.vstack(prob_list)

In [259]:
data, targets = t['data'], t['targets'] 
data=data.reshape(data.shape[0], -1)
print(data.shape)

(2500, 3072)


In [260]:
data1, targets1= t['data'], t['targets']
data1=data1.reshape(data1.shape[0], -1)
data1=normalize(data1)
dataloader=DataLoader(data1, batch_size=32, shuffle=False)
lwp_model = LWP(distance_metric='cosine')  # LWP model with cosine distance
lwp_model.fit(data1, targets1) 

In [261]:

print("Prototypes after fit:", lwp_model.prototypes)


Prototypes after fit: {0: array([0.01729068, 0.01912501, 0.02102479, ..., 0.01516719, 0.0159201 ,
       0.01601679]), 1: array([0.01837207, 0.01810205, 0.01747006, ..., 0.01770059, 0.01728502,
       0.01624851]), 2: array([0.0164336 , 0.01742426, 0.01454287, ..., 0.01747445, 0.01748331,
       0.01434357]), 3: array([0.01795165, 0.0166975 , 0.01529295, ..., 0.01780526, 0.01648547,
       0.01536131]), 4: array([0.01489708, 0.01576391, 0.01352558, ..., 0.0190007 , 0.01918124,
       0.01478338]), 5: array([0.01479913, 0.01467359, 0.01309324, ..., 0.01706052, 0.01606419,
       0.01416187]), 6: array([0.0163735 , 0.01605917, 0.01306188, ..., 0.01921913, 0.01844561,
       0.01541324]), 7: array([0.01802458, 0.01884868, 0.01836053, ..., 0.01945586, 0.01867545,
       0.01490231]), 8: array([0.01797723, 0.02032782, 0.02269468, ..., 0.01198278, 0.01380034,
       0.01474392]), 9: array([0.02226604, 0.02317505, 0.02398837, ..., 0.01793398, 0.01759059,
       0.0167393 ])}


### Compute cosine similarity and select top samples

In [262]:
def select_top_samples(embeddings, centroids, top_percentage=0.5):
    """
    Select top-k% samples with highest cosine similarity to centroids
    and include top-2 pseudo-labels.
    Args:
        embeddings (np.array): The data embeddings.
        centroids (np.array): Prototypes or centroids for each class.
        top_percentage (float): Percentage of top samples to select (0 < top_percentage <= 1).
    Returns:
        tuple: Top-k% embeddings, top-1 pseudo-labels, and top-2 pseudo-labels.
    """
    if not (0 < top_percentage <= 1):
        raise ValueError("top_percentage must be between 0 and 1.")

    # Compute similarity scores
    similarities = cosine_similarity(embeddings, centroids)

    # Compute the number of samples to select (50% of total embeddings)
    total_samples = embeddings.shape[0]
    top_k = int(total_samples * top_percentage)
    print(f"Selecting top {top_k} samples out of {total_samples} (percentage: {top_percentage * 100}%)...")

    # Find top-k samples with the highest cosine similarity to centroids
    max_similarities = np.max(similarities, axis=1)
    sorted_indices = np.argsort(max_similarities)[::-1]  # Sort by similarity in descending order
    top_indices = sorted_indices[:top_k]  # Select top-k indices

    # Generate pseudo-labels for top-k samples
    top_1_labels = np.argmax(similarities[top_indices], axis=1)  # Top-1 labels
    second_highest_indices = np.argsort(similarities[top_indices], axis=1)[:, -2]  # Top-2 labels
    top_2_labels = second_highest_indices

    # Return top embeddings and their pseudo-labels
    return embeddings[top_indices], top_1_labels, top_2_labels


### Knowledge Distillation based LWP Model

In [263]:
class KnowledgeDistillationLWP:
    def __init__(self, distance_metric='cosine', alpha=0.5, beta=0.5):
        """
        Knowledge Distillation-based LWP Model with top-2 pseudo-labeling.
        Args:
            distance_metric (str): Distance metric to use for LWP (e.g., cosine).
            alpha (float): Weighting factor for distillation loss.
            beta (float): Weighting factor for top-2 pseudo-label updates.
        """
        self.lwp_model = LWP(distance_metric=distance_metric)
        self.old_model = None  # Placeholder for storing old model
        self.alpha = alpha  # Trade-off between current and old knowledge
        self.beta = beta  # Trade-off between top-1 and top-2 pseudo-label prototypes

    def fit(self, features, labels=None):
        """
        Fit LWP model with knowledge distillation.
        Args:
            features (np.array): Embeddings of the current dataset.
            labels (np.array): Optional true labels (only for D1).
        """
        # Store the current model as the old model before updating
        if self.old_model is None:
            self.old_model = LWP(distance_metric='cosine')
            self.old_model.prototypes = self.lwp_model.prototypes.copy()
            self.old_model.class_counts = self.lwp_model.class_counts.copy()
        
        # Fit the current LWP model to new data
        self.lwp_model.fit(features, labels)
    
    def distillation_loss(self, new_features):
        """
        Compute KL Divergence between old model and current model predictions.
        Args:
            new_features (np.array): Embeddings of the current dataset.
        Returns:
            float: Knowledge distillation loss.
        """
        if self.old_model is None:
            return 0  # No distillation loss for the first dataset
        
        # Predictions from the old model
        old_predictions = self.old_model.predict_proba(new_features)
        # Predictions from the current model
        current_predictions = self.lwp_model.predict_proba(new_features)
        
        # Compute KL divergence
        kl_div = np.sum(old_predictions * np.log((old_predictions + 1e-8) / (current_predictions + 1e-8)), axis=1)
        return kl_div.mean()

    def update_model(self, features):
        """
        Update the LWP model using distillation loss.
        Args:
            features (np.array): Embeddings of the current dataset.
        """
        if features.size == 0:
            print("No features available for update. Skipping distillation...")
            return

        # Check if prototypes exist
        if not self.lwp_model.prototypes:
            print("No prototypes available. Skipping update...")
            return

        try:
            # Retrieve prototypes and compute pseudo-labels
            centroids = np.vstack([proto for _, proto in sorted(self.lwp_model.prototypes.items())])
        except ValueError as e:
            print(f"Error constructing centroids: {e}. Skipping update...")
            return
        # Retrieve prototypes and compute pseudo-labels
        centroids = np.vstack([proto for _, proto in sorted(self.lwp_model.prototypes.items())])
        top_embeddings, top_1_labels, top_2_labels = select_top_samples(features, centroids)
        if top_embeddings.shape[0] != top_1_labels.shape[0]:
            print("Shape mismatch between top_embeddings and top_1_labels. Skipping...")
            return
        # Update prototypes using top-1 and top-2 pseudo-labels
        for label in np.unique(top_1_labels):
            top_1_samples = top_embeddings[top_1_labels == label]
            top_2_samples = top_embeddings[top_2_labels == label]
            
            if len(top_1_samples) > 0:
                new_proto_top1 = top_1_samples.mean(axis=0)
                if len(top_2_samples) > 0:
                    new_proto_top2 = top_2_samples.mean(axis=0)
                    # Combine top-1 and top-2 updates using the beta factor
                    self.lwp_model.prototypes[label] = (
                        (1 - self.beta) * new_proto_top1 + self.beta * new_proto_top2
                    )
                else:
                    self.lwp_model.prototypes[label] = new_proto_top1
        print(f"feature shape (X): {features.shape}")
        if self.old_model:
            for label, proto in self.old_model.prototypes.items():
                print(f"Prototype for class {label}: Shape = {proto.shape}")

        
        # Align embedding dimensions
        features = features.reshape(features.shape[0], -1)
        
        if self.old_model:
            # Ensure prototype dimensions match features
            self.old_model.prototypes = {
                label: proto.reshape(-1, features.shape[1])
                for label, proto in self.old_model.prototypes.items()
        }
                    
        # Compute distillation loss
        distillation_loss = self.distillation_loss(features) if self.old_model else 0.0
        print(f"Distillation Loss: {distillation_loss:.4f}")

        
        # Adjust prototypes based on distillation loss
        for label, proto in self.lwp_model.prototypes.items():
            if label in self.old_model.prototypes:
                self.lwp_model.prototypes[label] = (
                    (1 - self.alpha) * proto + self.alpha * self.old_model.prototypes[label]
                )
            else:
                print(f"Label {label} not found in old model prototypes. Skipping...")

    def predict(self, features):
        """
        Predict pseudo-labels for the given features.
        Args:
            features (np.array): Embeddings of the data points.
        Returns:
            np.array: Predicted labels.
        """
        return self.lwp_model.predict(features)
    
    def predict_proba(self, features):
        """
        Predict probabilities (normalized distances to prototypes).
        Args:
            features (np.array): Embeddings of the data points.
        Returns:
            np.array: Predicted probabilities for each class.
        """
        distances = []
        for feature in features:
            dist = {
                label: self.lwp_model.distance_fn(feature, proto)
                for label, proto in self.lwp_model.prototypes.items()
            }
            # Convert distances to probabilities
            prob = np.exp(-np.array(list(dist.values())))
            prob /= prob.sum()
            distances.append(prob)
        return np.vstack(distances)

### Process a dataset using kNN with top-2 pseudo-labels and knowledge distillation.

In [264]:
# Initialize the Knowledge Distillation LWP model
kd_lwp_model = KnowledgeDistillationLWP(alpha=0.5)

In [265]:
def process_dataset_with_knn(embed_dir, dataset_idx, lwp_model, kd_lwp_model, k=5, top_percentage=0.5):
    """
    Process the dataset using kNN and pseudo-labeling.
    
    Args:
        embed_dir (str): Directory containing embeddings.
        dataset_idx (int): Dataset index.
        lwp_model (LWP): LWP model instance.
        kd_lwp_model (KD-LWP): KD-LWP model instance.
        k (int): Number of neighbors for kNN.
        top_percentage (float): Percentage of samples to select for pseudo-labeling (0 < top_percentage <= 1).
    """
    print(f"Processing dataset {dataset_idx} with kNN...")
    # Load embeddings
    embed_path = os.path.join(embed_dir, f'train_embeds_{dataset_idx}.pt')
    embeddings = torch.load(embed_path)
    if embeddings.size == 0:
        print(f"Dataset {dataset_idx} contains no embeddings. Skipping...")
        return

    # Compute centroids (prototypes)
    if not lwp_model.prototypes:
        print(f"No prototypes available in LWP model for dataset {dataset_idx}. Skipping...")
        return
    # compute prototypes(centroids)
    centroids = np.vstack([proto for _, proto in sorted(lwp_model.prototypes.items())])
    # Select top 50% samples based on cosine similarity to centroids
    top_embeddings, top_1_labels, top_2_labels = select_top_samples(embeddings, centroids, top_percentage=top_percentage)
    # Ensure k is not greater than the number of top embeddings
    k = min(k, len(top_embeddings))
    print(f"Using k={k} for kNN classification (based on top {len(top_embeddings)} embeddings)...")
    
    
    # Train kNN on the top samples
    knn = KNeighborsClassifier(n_neighbors=k, metric='euclidean')
    knn.fit(top_embeddings, top_1_labels)

    # Use kNN to assign pseudo-labels for all embeddings
    all_pseudo_labels = knn.predict(embeddings)
    print(f"Dataset {dataset_idx}: Assigned pseudo-labels using kNN")
    lwp_model.fit(embeddings, all_pseudo_labels)
    # Update prototypes using top-2 pseudo-labels
    kd_lwp_model.update_model(embeddings)
    kd_lwp_model.fit(embeddings)
    # # Update prototypes (class centroids) with kNN pseudo-labeled samples
    # for label in np.unique(all_pseudo_labels):
    #     class_embeddings = embeddings[all_pseudo_labels == label]
    #     if class_embeddings.size > 0:
    #         centroid = class_embeddings.mean(axis=0)
    #         lwp_model.prototypes[label] = centroid

    print(f"Dataset {dataset_idx}: Updated prototypes for classes {list(kd_lwp_model.lwp_model.prototypes.keys())}")

In [266]:
# Function to evaluate the model
def evaluate_on_eval_embeddings(embed_dir, dataset_idx, model, ground_truth=None):
    embed_path = os.path.join(embed_dir, f'eval_embeds_{dataset_idx}.pt')
    eval_embeddings = torch.load(embed_path)

    print(f"Evaluating on dataset {dataset_idx}...")
    predicted_labels = model.predict(eval_embeddings)

    if ground_truth is not None:
        accuracy = accuracy_score(ground_truth, predicted_labels)
        report = classification_report(ground_truth, predicted_labels, zero_division=0)
        print(f"Accuracy on eval set {dataset_idx}: {accuracy * 100:.2f}%")
        print("Classification Report:")
        print(report)
        return {"accuracy": accuracy, "report": report}
    else:
        print(f"Predictions on eval set {dataset_idx}: {predicted_labels[:10]}...")
        return {"predicted_labels": predicted_labels}

In [267]:
# Directories for embeddings
part_one_embed_dir = 'part_1_vit_embeds'
part_two_embed_dir = 'part_2_vit_embeds'

# Load D1 embeddings and targets
train_embed_path = os.path.join(part_one_embed_dir, 'train_embeds_1.pt')
eval_embed_path = os.path.join(part_one_embed_dir, 'eval_embeds_1.pt')
train_embeddings = torch.load(train_embed_path)
eval_embeddings = torch.load(eval_embed_path)
train_path = os.path.join('dataset', 'part_one_dataset', 'train_data', '1_train_data.tar.pth')
data = torch.load(train_path)
targets = data['targets']

# Initialize and fit LWP model
lwp_model = LWP(distance_metric='cosine')
lwp_model.fit(train_embeddings, targets)

  train_embeddings = torch.load(train_embed_path)
  eval_embeddings = torch.load(eval_embed_path)
  data = torch.load(train_path)


In [268]:
kd_lwp_model = KnowledgeDistillationLWP(alpha=0.5)

# Process D2-D10 with kNN and pseudo-labeling
for i in range(2, 11):
    process_dataset_with_knn(part_one_embed_dir, i, lwp_model, kd_lwp_model, k=5, top_percentage=0.5)

  embeddings = torch.load(embed_path)


Processing dataset 2 with kNN...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Using k=5 for kNN classification (based on top 1250 embeddings)...
Dataset 2: Assigned pseudo-labels using kNN
No prototypes available. Skipping update...
Dataset 2: Updated prototypes for classes [None]
Processing dataset 3 with kNN...
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
Using k=5 for kNN classification (based on top 1250 embeddings)...


  embeddings = torch.load(embed_path)


Dataset 3: Assigned pseudo-labels using kNN
Selecting top 1250 samples out of 2500 (percentage: 50.0%)...
feature shape (X): (2500, 768)


ValueError: Incompatible dimension for X and Y matrices: X.shape[1] == 768 while Y.shape[1] == 1920000

In [None]:
# Evaluate on D1 (with ground truth)
eval_labels_path = os.path.join('dataset', 'part_one_dataset', 'eval_data', '1_eval_data.tar.pth')
eval_data = torch.load(eval_labels_path)
eval_ground_truth = eval_data['targets']

In [None]:
evaluate_on_eval_embeddings(part_one_embed_dir, dataset_idx=1, model=lwp_model, ground_truth=eval_ground_truth)

In [None]:
# Evaluate on D2-D10
for i in range(2, 11):
    evaluate_on_eval_embeddings(part_one_embed_dir, dataset_idx=i, model=lwp_model)

In [None]:
# Process D11-D20 (unlabeled datasets) with knowledge distillation
for i in range(11, 21):
    train_path = os.path.join(part_two_embed_dir, f'train_embeds_{i}.pt')
    print(f"Processing dataset D{i} from {train_path}...")
    
    # Load train embeddings
    train_embeddings = torch.load(train_path).numpy()
    
    # Perform knowledge distillation-based learning
    kd_lwp_model.update_model(train_embeddings)
    kd_lwp_model.fit(train_embeddings)
    
    # Evaluate or predict on eval set if needed
    eval_path = os.path.join(part_two_embed_dir, f'eval_embeds_{i}.pt')
    eval_embeddings = torch.load(eval_path).numpy()
    predictions = kd_lwp_model.predict(eval_embeddings)
    print(f"Pseudo-labels for eval set of D{i}: {predictions[:10]}")



In [None]:
# Evaluate on D11-D20 (unlabeled datasets)
for i in range(11, 21):
    evaluate_on_eval_embeddings(part_two_embed_dir, dataset_idx=i, model=lwp_model)