In [7]:
# Importing the necessary libraries and modules

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pickle
import os
import numpy as np
from collections import defaultdict

In [None]:
# Directory that holds the pickled feature files (*.pkl) read by load_all_pickles
directory_path = "results (1)"

def load_all_pickles(directory):
    """
    Read every ``.pkl`` file located directly inside ``directory``.

    Args:
    - directory (str): Path of the folder to scan (sub-folders are not visited).

    Returns:
    - dict: Maps each file name (including the ``.pkl`` suffix) to the object
      un-pickled from that file. Entries without the ``.pkl`` suffix are ignored.
    """
    loaded = {}
    for entry in os.listdir(directory):
        # Skip anything that is not a pickle file
        if not entry.endswith('.pkl'):
            continue
        # NOTE: un-pickling executes arbitrary code; only point this at trusted files
        with open(os.path.join(directory, entry), 'rb') as handle:
            loaded[entry] = pickle.load(handle)
    return loaded

# Loading all .pkl files into a dictionary keyed by file name
data_dict = load_all_pickles(directory_path)

# Extracting the training and held-out (evaluation) features
train_task1_data = data_dict['train_task1.pkl']
eval_task1_data = data_dict['eval_task1.pkl']

# Converting the lists to numpy arrays and arranging them as
# (dataset index, sample index, feature dimension) = (10, 2500, 1024)
train_task1_arr = np.reshape(np.array(train_task1_data), (10, 2500, 1024))
eval_task1_arr = np.reshape(np.array(eval_task1_data), (10, 2500, 1024))

# Loading the labels for D_1
# Forward slashes are used so the path resolves on every OS (the previous
# backslash-separated raw string only worked on Windows) and matches the
# eval-data paths used later in the notebook.
data = torch.load("dataset/part_one_dataset/train_data/1_train_data.tar.pth")
labels_list = data['targets']

In [9]:
# Bucket for the D_1 feature vectors, keyed by ground-truth class label
class_features = defaultdict(list)

# Walk features and labels in lock-step and file each feature under its class
for feature, label in zip(train_task1_arr[0], labels_list):
    class_features[label.item()].append(feature)

# One prototype per class: the element-wise mean of all its feature vectors
prototypes = {
    label: torch.stack(
        [
            # NumPy rows are converted to tensors; anything else is used as-is
            feature if not isinstance(feature, np.ndarray) else torch.tensor(feature)
            for feature in features
        ]
    ).mean(dim=0)  # average over the samples of the class
    for label, features in class_features.items()
}

# Row k of `prototype_tensors` is the prototype of the k-th smallest class label
prototype_tensors = torch.stack([prototypes[label] for label in sorted(prototypes)])

# Class labels in the same (ascending) order as the prototype rows
prototype_labels = torch.tensor(sorted(prototypes))

# Report which classes were found and the shape of the prototype matrix
print("Prototypes calculated for classes:", prototype_labels)
print("Prototype shape:", prototype_tensors.shape)  # Shape: [num_classes, feature_size]

Prototypes calculated for classes: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Prototype shape: torch.Size([10, 1024])


In [10]:
def lwp_classifier(features, prototypes, prototype_labels):
    """
    Nearest-prototype (Learning with Prototypes, LWP) classifier.

    Every sample receives the label of the prototype that is closest to it
    according to ``torch.cdist`` (default p=2, i.e. Euclidean distance).

    Args:
    - features (torch.Tensor or np.ndarray): Samples to classify. Shape: [num_samples, feature_dim].
    - prototypes (torch.Tensor or np.ndarray): One prototype per class. Shape: [num_classes, feature_dim].
    - prototype_labels (torch.Tensor): Label of each prototype row. Shape: [num_classes].

    Returns:
    - torch.Tensor: Predicted label for every sample. Shape: [num_samples].
    """
    # NumPy inputs become float32 tensors; tensor inputs are passed through untouched
    samples, centers = (
        torch.tensor(arr, dtype=torch.float32) if isinstance(arr, np.ndarray) else arr
        for arr in (features, prototypes)
    )

    # Row i of the distance matrix holds the distances from sample i to every
    # prototype; argmin over that row is the index of the closest prototype
    nearest = torch.cdist(samples, centers).argmin(dim=1)

    # Translate prototype row indices into class labels
    return prototype_labels[nearest]

# Classify the D_1 training features with the nearest-prototype classifier
# `train_task1_arr[0]` contains the feature embeddings for the first task
pred = lwp_classifier(train_task1_arr[0], prototype_tensors, prototype_labels)

# Make sure predictions and ground-truth labels are int64 tensors.
# `torch.as_tensor` is a no-op for inputs that already are int64 tensors, so
# re-running this cell (where `labels_list` is already a tensor) no longer
# triggers the "copy construct from a tensor" UserWarning of `torch.tensor`.
pred = torch.as_tensor(pred, dtype=torch.int64)
labels_list = torch.as_tensor(labels_list, dtype=torch.int64)

# Compute classification accuracy
# (pred == labels_list): Boolean tensor indicating correct predictions
# .float(): Convert to float for mean computation
# .mean(): Fraction of correct predictions
accuracy = (pred == labels_list).float().mean().item() * 100

# Print the classification accuracy on the training data
print(f"Classification accuracy on training data: {accuracy:.2f}%")

Classification accuracy on training data: 98.72%


In [11]:
def update_prototypes(prototypes, prototype_labels, new_features, pseudo_labels, alpha,
                      samples_per_dataset=2500):
    """
    Updates the class prototypes by blending them with the mean of newly
    pseudo-labelled features.

    For the prototype in row ``r`` with class label ``c`` the update is

        new_r = (w * old_r + sum(features pseudo-labelled c)) / (w + n_c)

    where ``w = alpha * samples_per_dataset`` acts as a pseudo-count for the old
    prototype and ``n_c`` is the number of new samples pseudo-labelled ``c``.

    Args:
    - prototypes (torch.Tensor): Current prototype vectors. Shape: [num_classes, feature_dim].
    - prototype_labels (torch.Tensor): Class label of each prototype row. Shape: [num_classes].
      Labels do not need to be 0..num_classes-1; rows are matched by position.
    - new_features (torch.Tensor or np.ndarray): Feature embeddings of the new data.
      Shape: [num_samples, feature_dim]. Non-tensor input is converted to float32.
    - pseudo_labels (torch.Tensor): Pseudo labels for the new data points. Shape: [num_samples].
    - alpha (float): Weighting factor controlling the influence of the old prototypes.
    - samples_per_dataset (int or float): Scale applied to ``alpha`` to obtain the
      pseudo-count of the old prototype. Defaults to 2500, the value that was
      previously hard-coded (presumably the number of samples per dataset).

    Returns:
    - updated_prototypes (torch.Tensor): Updated prototype vectors. Shape: [num_classes, feature_dim].
      The input ``prototypes`` tensor is left unmodified.
    """
    # Clone the prototypes to avoid modifying the original tensor
    updated_prototypes = prototypes.clone()

    # Convert once up front so boolean-mask indexing below is plain tensor indexing
    if not isinstance(new_features, torch.Tensor):
        new_features = torch.as_tensor(new_features, dtype=torch.float32)
    pseudo_labels = torch.as_tensor(pseudo_labels)

    # Pseudo-count given to the existing prototype in the weighted average
    prior_weight = alpha * samples_per_dataset

    # Iterate over prototype ROWS: the row position, not the label value, is the
    # index into `prototypes`, so arbitrary label values are handled correctly
    for row, label in enumerate(prototype_labels):
        # Boolean mask selecting the new samples pseudo-labelled as this class
        mask = (pseudo_labels == label)
        new_class_features = new_features[mask]

        # No new samples for this class: keep its prototype as it is
        if new_class_features.size(0) == 0:
            continue

        # Weighted average of the old prototype and the new class samples
        updated_prototypes[row] = (
            prior_weight * prototypes[row] + new_class_features.sum(dim=0)
        ) / (prior_weight + mask.sum())

    return updated_prototypes

In [12]:
def evaluate_prototypes(prototypes, prototype_labels, data, labels):
    """
    Evaluates the classification accuracy of prototypes on a given dataset.

    Args:
    - prototypes (torch.Tensor): Prototype vectors for each class. Shape: [num_classes, feature_dim].
    - prototype_labels (torch.Tensor): Class labels corresponding to the prototypes. Shape: [num_classes].
    - data (torch.Tensor or np.ndarray): Feature embeddings to classify. Shape: [num_samples, feature_dim].
    - labels (torch.Tensor, np.ndarray or sequence of ints): True labels for the data.
      Must contain exactly num_samples entries; any shape is flattened.

    Returns:
    - accuracy (float): Classification accuracy as a percentage.

    Raises:
    - ValueError: If the number of labels differs from the number of predictions.
    - ZeroDivisionError: If ``labels`` is empty.
    """
    # Use the LWP classifier to predict labels for the input data
    predictions = lwp_classifier(data, prototypes, prototype_labels)

    # Coerce both sides to flat int64 tensors. `torch.as_tensor` accepts tensors,
    # NumPy arrays and plain sequences; flattening prevents a [N, 1]-shaped label
    # array from silently broadcasting into an N x N comparison.
    predictions = torch.as_tensor(predictions, dtype=torch.int64).reshape(-1)
    labels = torch.as_tensor(labels, dtype=torch.int64).reshape(-1)

    if predictions.numel() != labels.numel():
        raise ValueError(
            f"Got {predictions.numel()} predictions but {labels.numel()} labels"
        )

    # Count where predictions match the labels
    correct = (predictions == labels).sum().item()

    # Total number of evaluated samples
    total = labels.numel()

    # Compute accuracy as a percentage
    accuracy = 100 * correct / total

    return accuracy

In [13]:
# Iteratively update prototypes and evaluate on held-out datasets

# Ground-truth labels of each held-out dataset, loaded from disk at most once.
# Model f_i is evaluated on D_1..D_i, so without this cache the file for D_j
# would be re-read for every i >= j.
val_labels_by_dataset = {}

for i in range(1, 11):  # Process datasets D_1 to D_10
    print(f"Processing dataset D_{i}...")
    
    # Features of the current dataset D_i (row i - 1 of the feature array)
    train_data = train_task1_arr[i - 1]
    
    # Generate pseudo-labels for the training data using the current prototypes
    pseudo_labels = lwp_classifier(train_data, prototype_tensors, prototype_labels)
    
    # Update prototypes using the new training data and pseudo-labels.
    # NOTE: `prototype_tensors` is rebound here, so re-running this cell without
    # re-running the prototype-initialisation cell compounds the updates.
    prototype_tensors = update_prototypes(prototype_tensors, prototype_labels, train_data, pseudo_labels, alpha=0.2)
    print(f"Prototypes updated!\nEvaluating updated model f_{i} on all previous held-out datasets...")
    
    # Evaluate the updated prototypes on all datasets seen so far
    for j in range(1, i + 1):  # Evaluate on D_1 to D_i
        # Load the labels of D_j on first use only; keep just the targets
        if j not in val_labels_by_dataset:
            val_data = torch.load(f"dataset/part_one_dataset/eval_data/{j}_eval_data.tar.pth")
            val_labels_by_dataset[j] = val_data['targets']
        val_labels = val_labels_by_dataset[j]  # Ground-truth labels for the validation data
        val_feat = eval_task1_arr[j - 1]  # Feature embeddings for dataset D_j
        
        # Compute accuracy of the updated model f_i on dataset D_j
        accuracy = evaluate_prototypes(prototype_tensors, prototype_labels, val_feat, val_labels)
        print(f"Accuracy of f_{i} on D_{j}: {accuracy:.2f}%")

# Save the final prototype tensors after processing all datasets
with open('f_10_ptensors.pkl', 'wb') as f:
    pickle.dump(prototype_tensors, f)

Processing dataset D_1...
Prototypes updated!
Evaluating updated model f_1 on all previous held-out datasets...
Accuracy of f_1 on D_1: 98.32%
Processing dataset D_2...
Prototypes updated!
Evaluating updated model f_2 on all previous held-out datasets...
Accuracy of f_2 on D_1: 98.36%
Accuracy of f_2 on D_2: 97.84%
Processing dataset D_3...
Prototypes updated!
Evaluating updated model f_3 on all previous held-out datasets...
Accuracy of f_3 on D_1: 98.16%
Accuracy of f_3 on D_2: 97.76%
Accuracy of f_3 on D_3: 98.16%
Processing dataset D_4...
Prototypes updated!
Evaluating updated model f_4 on all previous held-out datasets...
Accuracy of f_4 on D_1: 98.16%
Accuracy of f_4 on D_2: 97.76%
Accuracy of f_4 on D_3: 98.04%
Accuracy of f_4 on D_4: 97.92%
Processing dataset D_5...
Prototypes updated!
Evaluating updated model f_5 on all previous held-out datasets...
Accuracy of f_5 on D_1: 98.20%
Accuracy of f_5 on D_2: 97.68%
Accuracy of f_5 on D_3: 97.96%
Accuracy of f_5 on D_4: 98.00%
Accura