In [1]:
import torch
import os
from torchvision import transforms

In [2]:
# Part-one locations. The raw train/eval splits share a common parent folder;
# save_dir is where the loading cells read the `*_embeds_*.pt` files from.
train_dir, eval_dir = (
    os.path.join('dataset', 'part_one_dataset', split)
    for split in ('train_data', 'eval_data')
)
save_dir = 'part_1_vit_embeds'

In [3]:
import numpy as np
from sklearn.metrics.pairwise import cosine_distances, manhattan_distances

class LWP:
    """Prototype-based classifier (LWP) with a configurable distance function.

    One prototype per class label is kept: the running mean of every feature
    vector passed to ``fit`` for that label. ``fit`` can be called repeatedly;
    each call folds the new batch into the existing prototypes. ``predict``
    assigns each sample the label of the closest prototype.

    Attributes:
        prototypes (dict): label -> current prototype (mean feature vector).
        class_counts (dict): label -> number of samples folded into the
            prototype so far. Pre-seeded with zeros for labels 0..9.
        distance_fn (callable): ``f(x, y) -> scalar`` used by ``predict``.
    """

    DISTANCE_FUNCTIONS = {
        'euclidean': lambda x, y: np.linalg.norm(x - y),
        'cosine': lambda x, y: cosine_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'manhattan': lambda x, y: manhattan_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'minkowski': lambda x, y, p=2: np.power(np.sum(np.power(np.abs(x - y), p)), 1/p)
    }

    def __init__(self, distance_metric='euclidean', **distance_params):
        """
        Args:
            distance_metric (str or callable): A key of ``DISTANCE_FUNCTIONS``
                or any callable ``f(x, y)`` returning a comparable distance.
            distance_params (dict): Additional parameters for the distance
                function. Only ``p`` (order of the 'minkowski' distance,
                default 2) is read.

        Raises:
            ValueError: If ``distance_metric`` is neither callable nor a known
                metric name.
        """
        self.prototypes = {}
        self.class_counts = {i: 0 for i in range(10)}

        if callable(distance_metric):
            self.distance_fn = distance_metric
        elif distance_metric in self.DISTANCE_FUNCTIONS:
            if distance_metric == 'minkowski':
                # Bind the order once so distance_fn keeps the two-argument shape.
                p = distance_params.get('p', 2)
                minkowski = self.DISTANCE_FUNCTIONS['minkowski']
                self.distance_fn = lambda x, y: minkowski(x, y, p)
            else:
                self.distance_fn = self.DISTANCE_FUNCTIONS[distance_metric]
        else:
            raise ValueError(f"Unknown distance metric: {distance_metric}. " 
                           f"Available metrics: {list(self.DISTANCE_FUNCTIONS.keys())}")

    def fit(self, features, labels):
        """Fold a labelled batch into the per-class prototypes.

        Args:
            features: Indexable batch of feature vectors; must support boolean
                mask indexing and ``.mean(axis=0)`` (e.g. a 2-D numpy array).
            labels: One label per row of ``features``. Any array-like is
                accepted, including plain Python lists.
        """
        # Without this, a plain list makes `labels == label` a scalar False and
        # every class silently gets an empty selection (NaN prototype).
        labels = np.asarray(labels)

        for label in np.unique(labels):
            samples = features[labels == label]
            num_samples = len(samples)
            batch_mean = samples.mean(axis=0)

            if label not in self.prototypes:
                self.prototypes[label] = batch_mean
                self.class_counts[label] = num_samples
            else:
                # Count-weighted average of the old prototype and the new batch
                # mean, i.e. the mean over all samples seen for this label.
                seen = self.class_counts[label]
                total = seen + num_samples
                self.prototypes[label] = (
                    seen / total * self.prototypes[label] +
                    num_samples / total * batch_mean
                )
                self.class_counts[label] = total

    def _nearest_label(self, feature):
        """Return the label whose prototype is closest to ``feature``."""
        if not self.prototypes:
            raise ValueError("LWP has no prototypes yet; call fit() before predict().")
        # Ties resolve to the label that was inserted first.
        return min(
            self.prototypes,
            key=lambda label: self.distance_fn(feature, self.prototypes[label]),
        )

    def predict(self, features):
        """Predict the nearest-prototype label for every sample.

        Args:
            features: Iterable of feature vectors.

        Returns:
            numpy.ndarray: One predicted label per input sample.

        Raises:
            ValueError: If there are samples to classify but ``fit`` has not
                produced any prototype yet.
        """
        return np.array([self._nearest_label(feature) for feature in features])

## Get Eval Data
Evaluating on trainset for now

In [4]:
# One dict per training domain. This cell fills slots 0-9 from the part-one
# data; slots 10-19 are filled by the part-two loading cell further down.
domains = [{} for _ in range(20)]

for j in range(10):
    train_path = os.path.join(train_dir, f'{j+1}_train_data.tar.pth')
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load files from a trusted source.
    t = torch.load(train_path, weights_only=False)

    # Domains whose file has no 'targets' entry are kept as unlabelled (None).
    domains[j]['labels'] = t['targets'] if 'targets' in t else None
    # weights_only is passed explicitly, matching the call above, so the load
    # does not depend on the torch version's default (and does not emit the
    # FutureWarning about it). Switch to True if these files hold plain tensors.
    domains[j]['features'] = torch.load(
        os.path.join(save_dir, f'train_embeds_{j+1}.pt'),
        weights_only=False,
    )

  domains[j]['features'] = torch.load(os.path.join(save_dir,f'train_embeds_{j+1}.pt'))


In [5]:
# One dict per evaluation domain. This cell fills slots 0-9 from the part-one
# data; slots 10-19 are filled by the part-two loading cell further down.
eval_domains = [{} for _ in range(20)]

for j in range(10):
    eval_path = os.path.join(eval_dir, f'{j+1}_eval_data.tar.pth')
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load files from a trusted source.
    t = torch.load(eval_path, weights_only=False)

    # Raw inputs; not used further in this cell (features come from the
    # embedding files below).
    data = t['data'] # both numpy.ndarray

    # Domains whose file has no 'targets' entry are kept as unlabelled (None).
    eval_domains[j]['labels'] = t['targets'] if 'targets' in t else None
    # weights_only is passed explicitly, matching the call above, so the load
    # does not depend on the torch version's default (and does not emit the
    # FutureWarning about it). Switch to True if these files hold plain tensors.
    eval_domains[j]['features'] = torch.load(
        os.path.join(save_dir, f'eval_embeds_{j+1}.pt'),
        weights_only=False,
    )

  eval_domains[j]['features'] = torch.load(os.path.join(save_dir,f'eval_embeds_{j+1}.pt'))


## Adding part two dataset here

In [6]:
# Part-two locations. These rebind the same three names used for part one, so
# every loading cell below this point reads from the part-two folders.
train_dir, eval_dir = (
    os.path.join('dataset', 'part_two_dataset', split)
    for split in ('train_data', 'eval_data')
)
save_dir = 'part_2_vit_embeds'

In [7]:
# Fill training-domain slots 10-19 from the part-two files (1-based on disk).
for j in range(10):
    train_path = os.path.join(train_dir, f'{j+1}_train_data.tar.pth')
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load files from a trusted source.
    t = torch.load(train_path, weights_only=False)

    # Domains whose file has no 'targets' entry are kept as unlabelled (None).
    domains[j+10]['labels'] = t['targets'] if 'targets' in t else None
    # weights_only is passed explicitly, matching the call above, so the load
    # does not depend on the torch version's default (and does not emit the
    # FutureWarning about it). Switch to True if these files hold plain tensors.
    domains[j+10]['features'] = torch.load(
        os.path.join(save_dir, f'train_embeds_{j+1}.pt'),
        weights_only=False,
    )

  domains[j+10]['features'] = torch.load(os.path.join(save_dir,f'train_embeds_{j+1}.pt'))


In [8]:
# Fill evaluation-domain slots 10-19 from the part-two files (1-based on disk).
for j in range(10):
    eval_path = os.path.join(eval_dir, f'{j+1}_eval_data.tar.pth')
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load files from a trusted source.
    t = torch.load(eval_path, weights_only=False)

    # Domains whose file has no 'targets' entry are kept as unlabelled (None).
    eval_domains[j+10]['labels'] = t['targets'] if 'targets' in t else None
    # weights_only is passed explicitly, matching the call above, so the load
    # does not depend on the torch version's default (and does not emit the
    # FutureWarning about it). Switch to True if these files hold plain tensors.
    eval_domains[j+10]['features'] = torch.load(
        os.path.join(save_dir, f'eval_embeds_{j+1}.pt'),
        weights_only=False,
    )

  eval_domains[j+10]['features'] = torch.load(os.path.join(save_dir,f'eval_embeds_{j+1}.pt'))


In [9]:
from sklearn.metrics import accuracy_score
import pandas as pd

model = LWP(distance_metric='cosine')

# Column "Domain k" holds the accuracies on eval domains 1..k, measured right
# after the model has been updated with training domain k. Rows for eval
# domains not yet reached are NaN.
score_columns = {}

for idx, domain in enumerate(domains):
    train_features = domain['features']
    # Unlabelled domains are pseudo-labelled with the current model's own
    # predictions before being folded into the prototypes.
    if domain['labels'] is None:
        train_labels = model.predict(train_features)
    else:
        train_labels = domain['labels']

    model.fit(train_features, train_labels)
    print(model.class_counts)

    scores = []
    for eval_domain in eval_domains[:idx + 1]:
        eval_labels = eval_domain['labels']
        if eval_labels is None:
            # No ground truth for this eval domain: accuracy is undefined.
            scores.append(np.nan)
            continue

        preds = model.predict(eval_domain['features'])
        scores.append(accuracy_score(eval_labels, preds))

    # Pad so every column has one entry per eval domain.
    score_columns[f'Domain {idx+1}'] = scores + [np.nan] * (len(eval_domains) - len(scores))

# Build the table once instead of inserting columns into an empty frame.
df = pd.DataFrame(score_columns)

{0: 253, 1: 243, 2: 255, 3: 244, 4: 262, 5: 236, 6: 250, 7: 253, 8: 254, 9: 250}
{0: 490, 1: 510, 2: 473, 3: 520, 4: 518, 5: 487, 6: 503, 7: 508, 8: 493, 9: 498}
{0: 760, 1: 758, 2: 676, 3: 772, 4: 795, 5: 770, 6: 763, 7: 728, 8: 744, 9: 734}
{0: 1012, 1: 1025, 2: 871, 3: 1040, 4: 1070, 5: 1026, 6: 1015, 7: 952, 8: 983, 9: 1006}
{0: 1266, 1: 1296, 2: 1066, 3: 1293, 4: 1346, 5: 1303, 6: 1254, 7: 1181, 8: 1221, 9: 1274}
{0: 1513, 1: 1527, 2: 1282, 3: 1551, 4: 1638, 5: 1598, 6: 1495, 7: 1416, 8: 1449, 9: 1531}
{0: 1763, 1: 1770, 2: 1502, 3: 1799, 4: 1928, 5: 1869, 6: 1729, 7: 1636, 8: 1723, 9: 1781}
{0: 2019, 1: 2023, 2: 1740, 3: 2054, 4: 2201, 5: 2159, 6: 1959, 7: 1846, 8: 1958, 9: 2041}
{0: 2269, 1: 2288, 2: 1933, 3: 2313, 4: 2500, 5: 2411, 6: 2190, 7: 2075, 8: 2212, 9: 2309}
{0: 2512, 1: 2529, 2: 2134, 3: 2579, 4: 2781, 5: 2721, 6: 2425, 7: 2294, 8: 2456, 9: 2569}
{0: 2776, 1: 2790, 2: 2334, 3: 2925, 4: 3064, 5: 3009, 6: 2674, 7: 2453, 8: 2684, 9: 2791}
{0: 3245, 1: 2892, 2: 2404, 3: 3

In [10]:
# Bare last expression: the notebook renders the accuracy table with its rich
# DataFrame display instead of the line-wrapped plain-text print() output.
df

    Domain 1  Domain 2  Domain 3  Domain 4  Domain 5  Domain 6  Domain 7  \
0     0.8992    0.8940    0.8912    0.8900    0.8900    0.8880    0.8872   
1        NaN    0.9028    0.9000    0.8996    0.8976    0.8968    0.8960   
2        NaN       NaN    0.9076    0.9056    0.9052    0.9044    0.9028   
3        NaN       NaN       NaN    0.9156    0.9144    0.9144    0.9132   
4        NaN       NaN       NaN       NaN    0.9064    0.9052    0.9040   
5        NaN       NaN       NaN       NaN       NaN    0.9108    0.9112   
6        NaN       NaN       NaN       NaN       NaN       NaN    0.9032   
7        NaN       NaN       NaN       NaN       NaN       NaN       NaN   
8        NaN       NaN       NaN       NaN       NaN       NaN       NaN   
9        NaN       NaN       NaN       NaN       NaN       NaN       NaN   
10       NaN       NaN       NaN       NaN       NaN       NaN       NaN   
11       NaN       NaN       NaN       NaN       NaN       NaN       NaN   
12       NaN