In [229]:
# Import necessary libraries
from sklearn.cluster import KMeans
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os
import pickle

In [None]:
directory_path = "results (1)" # Directory containing extracted features

# Function to load all .pkl files from a directory
def load_all_pickles(directory):
    """
    Loads all .pkl files from the specified directory and stores them in a dictionary.

    Args:
    - directory (str): The path to the directory containing .pkl files.

    Returns:
    - dict: A dictionary with filenames as keys and loaded data as values.
    """
    pickle_data = {}  # Initialize dictionary to store data

    # Iterate through files in the specified directory
    for filename in os.listdir(directory):
        if filename.endswith('.pkl'):  # Only consider .pkl files
            file_path = os.path.join(directory, filename)  # Get the full path of the file
            with open(file_path, 'rb') as f:  # Open the file in read-binary mode
                pickle_data[filename] = pickle.load(f)  # Load and store the data
    
    return pickle_data  # Return the dictionary of loaded data

# Load all .pkl files from the directory
data_dict = load_all_pickles(directory_path)


In [223]:
# Extract data from the loaded pickle dictionary and convert them to numpy arrays

# Load 'eval_task2.pkl' and 'train_task2.pkl' from the dictionary
eval_task2_data = data_dict['eval_task2.pkl']
train_task2_data = data_dict['train_task2.pkl']

# Convert the 'train_task2' and 'eval_task2' data to numpy arrays
train_task2_arr = np.array(train_task2_data)
eval_task2_arr = np.array(eval_task2_data)

# Reshape the data arrays to shape (10, 2500, 768) to match the desired structure
train_task2_arr = np.reshape(train_task2_arr, (10, 2500, 1024))
eval_task2_arr = np.reshape(eval_task2_arr, (10, 2500, 1024))

# Load 'eval_task1.pkl' from the dictionary
eval_task1_data = data_dict['eval_task1.pkl']

# Convert 'eval_task1' data to numpy array and reshape it
eval_task1_arr = np.array(eval_task1_data)
eval_task1_arr = np.reshape(eval_task1_arr, (10, 2500, 1024))

In [225]:
# Define the path to the file containing the prototype tensors
path = r"f_10_ptensors.pkl"

# Open the file in read-binary mode and load the prototype tensors using pickle
with open(path, 'rb') as f:
    prototype_tensors = pickle.load(f)

# Create a tensor of prototype labels (assuming 10 classes, hence labels from 0 to 9)
prototype_labels = torch.arange(0, 10)

In [231]:
import torch
import numpy as np

def lwp_classifier(features, prototypes, prototype_labels):
    """
    Classifies input feature vectors based on the nearest prototype using Learning with Prototypes (LWP) method.

    Args:
    - features (torch.Tensor or np.ndarray): The input feature vectors to classify. Shape: [num_samples, feature_dim].
    - prototypes (torch.Tensor or np.ndarray): The prototype vectors for each class. Shape: [num_classes, feature_dim].
    - prototype_labels (torch.Tensor): The class labels corresponding to the prototypes. Shape: [num_classes].

    Returns:
    - torch.Tensor: The predicted labels for the input features. Shape: [num_samples].
    """
    # Ensure inputs are PyTorch tensors, if they are NumPy arrays, convert them
    if isinstance(features, np.ndarray):
        features = torch.tensor(features, dtype=torch.float32)
    if isinstance(prototypes, np.ndarray):
        prototypes = torch.tensor(prototypes, dtype=torch.float32)

    # Compute pairwise distances between features and prototypes using Euclidean distance
    distances = torch.cdist(features, prototypes)

    # Find the index of the prototype that is closest to each feature
    predictions = torch.argmin(distances, dim=1)

    # Return the corresponding prototype labels for the predicted class
    return prototype_labels[predictions]


def evaluate_prototypes(prototypes, prototype_labels, data, labels):
    """
    Evaluates the accuracy of the prototypes by classifying the input data and comparing predictions to true labels.

    Args:
    - prototypes (torch.Tensor): The prototype vectors for each class. Shape: [num_classes, feature_dim].
    - prototype_labels (torch.Tensor): The class labels corresponding to the prototypes. Shape: [num_classes].
    - data (torch.Tensor or np.ndarray): The feature vectors of the data to classify. Shape: [num_samples, feature_dim].
    - labels (torch.Tensor or np.ndarray): The true labels of the data. Shape: [num_samples].

    Returns:
    - float: The classification accuracy as a percentage.
    """
    # Initialize counters for correct predictions and total samples
    correct = 0
    total = 0

    # Get predictions from the classifier
    predictions = lwp_classifier(data, prototypes, prototype_labels)
    
    # Ensure predictions and labels are PyTorch tensors
    if isinstance(predictions, np.ndarray):
        predictions = torch.tensor(predictions, dtype=torch.int64)
    if isinstance(labels, np.ndarray):
        labels = torch.tensor(labels, dtype=torch.int64)

    # Compare predictions and true labels, and sum the correct predictions
    correct += (predictions == labels).sum().item()

    # Calculate the total number of samples
    total += labels.size(0)

    # Calculate accuracy as a percentage
    accuracy = 100 * correct / total

    return accuracy

In [None]:
def kmeans_clustering_with_prototypes(target_embeddings, prototype_tensors, num_clusters=10):
    """
    Perform K-Means clustering where initial cluster centers are set to the prototype tensors.
    
    Args:
    - target_embeddings (torch.Tensor or np.ndarray): The feature embeddings of the target data.
    - prototype_tensors (torch.Tensor): The current prototype tensors (size [num_classes, feature_size]).
    - num_clusters (int): Number of clusters (10).
    
    Returns:
    - centroids (torch.Tensor): The final centroids of the clusters after K-Means.
    - cluster_labels (numpy.ndarray): The cluster labels assigned to the target embeddings.
    """
    
    # Ensure inputs are numpy arrays (KMeans from sklearn works with numpy)
    if isinstance(target_embeddings, torch.Tensor):
        target_embeddings = target_embeddings.cpu().numpy()
    if isinstance(prototype_tensors, torch.Tensor):
        prototype_tensors = prototype_tensors.cpu().numpy()
    
    # Initialize the KMeans model with the given prototypes as initial cluster centers
    kmeans = KMeans(n_clusters=num_clusters, init=prototype_tensors, n_init=1, random_state=42)
    
    # Fit KMeans to the target data
    cluster_labels = kmeans.fit_predict(target_embeddings)
    
    # Get the final centroids (updated prototypes)
    centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    
    return centroids, cluster_labels


def update_prototypes_with_kmeans(prototypes, prototype_labels, new_features, pseudo_labels, beta=1.0, num_clusters=10):
    """
    Update the prototypes by clustering the new features using K-Means, and then averaging the existing prototypes
    with the new cluster centroids.

    Args:
    - prototypes (torch.Tensor): Current prototypes (size: [num_classes, feature_dim]).
    - prototype_labels (torch.Tensor): Labels corresponding to the prototypes.
    - new_features (torch.Tensor or np.ndarray): New feature vectors to update prototypes with.
    - pseudo_labels (torch.Tensor): Pseudo labels for the new features.
    - beta (float): Weighting factor to balance old prototypes and new cluster centroids.
    - num_clusters (int): Number of clusters for K-Means.

    Returns:
    - updated_prototypes (torch.Tensor): The updated prototypes after averaging with K-Means centroids.
    """
    
    updated_prototypes = prototypes.clone()  # Create a clone to update the prototypes

    # Check if new_features is a tensor or a NumPy array and convert to numpy if necessary
    if isinstance(new_features, torch.Tensor):
        new_features = new_features.numpy()  # Convert to NumPy array if it's a tensor

    # Perform K-Means clustering for the new features
    kmeans = KMeans(n_clusters=num_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(new_features)  # Cluster new features

    # Get the centroids from K-Means (new prototypes)
    cluster_centroids = kmeans.cluster_centers_  # Shape: [num_clusters, feature_dim]

    for label in prototype_labels:
        # Create a mask for the current label
        mask = (pseudo_labels == label)
        new_class_features = new_features[mask]  # Features corresponding to this class label

        # If no features for this label, skip updating
        if new_class_features.shape[0] == 0:
            continue
        
        # Perform K-Means for the new features belonging to the current class
        # Select features for this class from clusters
        class_features = new_class_features[cluster_labels[mask] == label]

        if class_features.shape[0] == 0:
            continue  # Skip if no features for this class

        # Compute the new prototype by averaging the current prototype with the K-Means centroids for this class
        new_prototype = (beta * prototypes[label] + class_features.mean(axis=0)) / (beta + 1)

        # Update the prototype for the current label
        updated_prototypes[label] = new_prototype

    return updated_prototypes

In [233]:
# Iterate over datasets D_11 to D_20 (i.e., the next 10 datasets)
for i in range(10):
    # Print the current dataset being processed (D_11 to D_20)
    print(f"Processing dataset D_{i + 11}...")

    # Select the training data for the current dataset (from train_task2_arr)
    train_data = train_task2_arr[i]

    # Generate pseudo-labels for the current training data using the prototypes
    pseudo_labels = lwp_classifier(train_data, prototype_tensors, prototype_labels)

    # Update the prototypes using KMeans clustering and the current training data
    prototype_tensors = update_prototypes_with_kmeans(prototype_tensors, prototype_labels, train_data, pseudo_labels, beta=1)

    # Print that prototypes have been updated
    print(f"Prototypes updated!\nEvaluating updated model f_{i + 11} on all previous held-out datasets...")

    # Evaluate the updated prototypes on datasets D_1 to D_10
    for j in range(1, 11):
        # Load validation data for dataset D_j from eval_data
        val_data = torch.load(f"dataset/part_one_dataset/eval_data/{j}_eval_data.tar.pth")
        val_labels = val_data['targets']
        val_feat = eval_task1_arr[j - 1]  # Features from eval_task1_arr for dataset D_j
        
        # Compute the accuracy of the updated model on this validation dataset
        accuracy = evaluate_prototypes(prototype_tensors, prototype_labels, val_feat, val_labels)
        print(f"Accuracy of f_{i + 11} on D{j}: {accuracy}")

    # Evaluate the updated prototypes on datasets D_11 to D_(i+11)
    for j in range(11, i + 12):
        # Load validation data for dataset D_j from the part_two_dataset
        val_data = torch.load(f"dataset/part_two_dataset/eval_data/{j - 10}_eval_data.tar.pth")
        val_labels = val_data['targets']
        val_feat = eval_task2_arr[j - 11]  # Features from eval_task2_arr for dataset D_j
        
        # Compute the accuracy of the updated model on this validation dataset
        accuracy = evaluate_prototypes(prototype_tensors, prototype_labels, val_feat, val_labels)
        print(f"Accuracy of f_{i + 11} on D{j}: {accuracy}")

Processing dataset D_11...
Prototypes updated!
Evaluating updated model f_11 on all previous held-out datasets...
Accuracy of f_11 on D1: 98.32
Accuracy of f_11 on D2: 97.88
Accuracy of f_11 on D3: 98.36
Accuracy of f_11 on D4: 98.0
Accuracy of f_11 on D5: 98.12
Accuracy of f_11 on D6: 98.36
Accuracy of f_11 on D7: 97.56
Accuracy of f_11 on D8: 97.88
Accuracy of f_11 on D9: 97.92
Accuracy of f_11 on D10: 98.28
Accuracy of f_11 on D11: 90.36
Processing dataset D_12...
Prototypes updated!
Evaluating updated model f_12 on all previous held-out datasets...
Accuracy of f_12 on D1: 98.32
Accuracy of f_12 on D2: 97.88
Accuracy of f_12 on D3: 98.36
Accuracy of f_12 on D4: 98.0
Accuracy of f_12 on D5: 98.12
Accuracy of f_12 on D6: 98.36
Accuracy of f_12 on D7: 97.56
Accuracy of f_12 on D8: 97.88
Accuracy of f_12 on D9: 97.92
Accuracy of f_12 on D10: 98.28
Accuracy of f_12 on D11: 90.36
Accuracy of f_12 on D12: 75.92
Processing dataset D_13...
Prototypes updated!
Evaluating updated model f_13 on