In [8]:
import torch
import os

In [9]:
train_dir = os.path.join('dataset', 'part_one_dataset', 'train_data')
eval_dir = os.path.join('dataset', 'part_one_dataset', 'eval_data')
one_embeds_dir = os.path.join('part_1_vit_embeds')
two_embeds_dir = os.path.join('part_2_vit_embeds')

In [10]:
domains = [{} for _ in range(20)]

for j in range(10):
    
    train_path = os.path.join(train_dir, f'{j+1}_train_data.tar.pth')
    t = torch.load(train_path, weights_only = False)
    
    domains[j]['labels'] = t['targets'] if 'targets' in t else None
    domains[j]['features'] = torch.load(os.path.join(one_embeds_dir,f'train_embeds_{j+1}.pt'), weights_only = False)
    
eval_domains = [{} for _ in range(20)]

for j in range(10):
    
    eval_path = os.path.join(eval_dir, f'{j+1}_eval_data.tar.pth')
    t = torch.load(eval_path, weights_only = False)

    data = t['data'] # both numpy.ndarray
    
    eval_domains[j]['labels'] = t['targets'] if 'targets' in t else None
    eval_domains[j]['features'] = torch.load(os.path.join(one_embeds_dir,f'eval_embeds_{j+1}.pt'), weights_only = False)
    
for j in range(10):
    
    train_path = os.path.join(train_dir, f'{j+1}_train_data.tar.pth')
    t = torch.load(train_path, weights_only = False)
    
    domains[j+10]['labels'] = t['targets'] if 'targets' in t else None
    domains[j+10]['features'] = torch.load(os.path.join(two_embeds_dir,f'train_embeds_{j+1}.pt'), weights_only = False)
    
for j in range(10):
    
    train_path = os.path.join(eval_dir, f'{j+1}_eval_data.tar.pth')
    t = torch.load(train_path, weights_only = False)
    
    eval_domains[j+10]['labels'] = t['targets'] if 'targets' in t else None
    eval_domains[j+10]['features'] = torch.load(os.path.join(two_embeds_dir,f'train_embeds_{j+1}.pt'), weights_only = False)
    
    

In [11]:
import numpy as np
from sklearn.metrics.pairwise import cosine_distances, manhattan_distances

class LWP:
    """Learning Vector Prototypes with configurable distance function"""
    
    DISTANCE_FUNCTIONS = {
        'euclidean': lambda x, y: np.linalg.norm(x - y),
        'cosine': lambda x, y: cosine_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'manhattan': lambda x, y: manhattan_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'minkowski': lambda x, y, p=2: np.power(np.sum(np.power(np.abs(x - y), p)), 1/p)
    }
    
    def __init__(self, distance_metric='euclidean', **distance_params):
        """
            distance_params (dict): Additional parameters for the distance function
        """
        self.prototypes = {}
        self.class_counts = {i: 0 for i in range(10)}
        
        if callable(distance_metric):
            self.distance_fn = distance_metric
        elif distance_metric in self.DISTANCE_FUNCTIONS:
            if distance_metric == 'minkowski':
                p = distance_params.get('p', 2)
                self.distance_fn = lambda x, y: self.DISTANCE_FUNCTIONS[distance_metric](x, y, p)
            else:
                self.distance_fn = self.DISTANCE_FUNCTIONS[distance_metric]
        else:
            raise ValueError(f"Unknown distance metric: {distance_metric}. " 
                           f"Available metrics: {list(self.DISTANCE_FUNCTIONS.keys())}")

    def fit(self, features, labels):
        unique_labels = np.unique(labels)
        for label in unique_labels:
            samples = features[labels == label]
            num_samples = len(samples)
            
            if label not in self.prototypes:  # Original condition was: if label not in self.prototypes
                self.prototypes[label] = samples.mean(axis=0)
                self.class_counts[label] = len(samples)
            else:
                self.class_counts[label] += len(samples)
                self.prototypes[label] = (
                    (self.class_counts[label] - num_samples) / self.class_counts[label] * self.prototypes[label] +
                    num_samples / self.class_counts[label] * samples.mean(axis=0)
                )

    def predict(self, features):
        preds = []
        for feature in features:
            distances = {
                label: self.distance_fn(feature, proto)
                for label, proto in self.prototypes.items()
            }
            preds.append(min(distances, key=distances.get))
        return np.array(preds)

In [12]:
from sklearn.cluster import KMeans

models = []

source_dataset = domains[0]
model = LWP(distance_metric='cosine')
model.fit(source_dataset['features'], source_dataset['labels'])
models.append(model)

for i in range(1, 20):
    features = domains[i]['features']
    
    # Clustering unlabeled data into 10 clusters
    kmeans = KMeans(n_clusters=10, random_state=42)
    clusters = kmeans.fit_predict(features)
    

    # Assign labels to clusters
    cluster_centers = kmeans.cluster_centers_
    predicted_labels = model.predict(cluster_centers)

    # Label unlabeled data
    pseudo_labels = np.array([predicted_labels[cluster] for cluster in clusters])

    model.fit(features, pseudo_labels)
    models.append(model)

In [13]:
from sklearn.metrics import accuracy_score
import pandas as pd

df = pd.DataFrame()

for idx,model in enumerate(models) :
    
    scores = []
    for eval_domain in eval_domains[:idx+1]:
        
        features = eval_domain['features']
        labels = eval_domain['labels']
        
        preds = model.predict(features)
        acc = accuracy_score(labels, preds)
        
        scores.append(acc)
    
    df[f'Domain {idx+1}'] = scores + [np.nan] * (len(eval_domains) - len(scores))

In [14]:
print(df)

    Domain 1  Domain 2  Domain 3  Domain 4  Domain 5  Domain 6  Domain 7  \
0      0.856    0.8560    0.8560    0.8560    0.8560    0.8560    0.8560   
1        NaN    0.8544    0.8544    0.8544    0.8544    0.8544    0.8544   
2        NaN       NaN    0.8620    0.8620    0.8620    0.8620    0.8620   
3        NaN       NaN       NaN    0.8816    0.8816    0.8816    0.8816   
4        NaN       NaN       NaN       NaN    0.8744    0.8744    0.8744   
5        NaN       NaN       NaN       NaN       NaN    0.8828    0.8828   
6        NaN       NaN       NaN       NaN       NaN       NaN    0.8664   
7        NaN       NaN       NaN       NaN       NaN       NaN       NaN   
8        NaN       NaN       NaN       NaN       NaN       NaN       NaN   
9        NaN       NaN       NaN       NaN       NaN       NaN       NaN   
10       NaN       NaN       NaN       NaN       NaN       NaN       NaN   
11       NaN       NaN       NaN       NaN       NaN       NaN       NaN   
12       NaN