In [1]:
import torch
import os
from torchvision import transforms

In [2]:
# Folders holding the part-one train and eval splits, built relative to the
# working directory.
train_dir, eval_dir = (
    os.path.join('dataset', 'part_one_dataset', split_folder)
    for split_folder in ('train_data', 'eval_data')
)

In [3]:
# Paths of the first domain's train and eval files.
train_path, eval_path = (
    os.path.join(train_dir, '1_train_data.tar.pth'),
    os.path.join(eval_dir, '1_eval_data.tar.pth'),
)

# weights_only=False lets torch.load unpickle arbitrary objects from the file.
t = torch.load(train_path, weights_only=False)

In [4]:
from torchvision import models  
import torch

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-trained ResNet-18 with its last child module removed, switched to
# evaluation mode and moved to `device`. eval() and to() both return the
# module itself, so the calls are chained.
resnet = torch.nn.Sequential(
    *list(models.resnet18(weights=models.ResNet18_Weights.DEFAULT).children())[:-1]
).eval().to(device)

# Preprocessing applied to each image tensor before the forward pass:
# resize to 224x224, then normalise with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


  return torch._C._cuda_getDeviceCount() > 0


In [None]:
# One dict per training domain; each is filled with 'labels' and 'features'.
domains = [{} for _ in range(10)]

# Images per ResNet forward pass. The original sliced exactly 10 x 250 images,
# which silently dropped anything past 2500 and produced empty batches for
# smaller files; batching is now derived from the actual number of images.
batch_size = 250

for j in range(10):

    train_path = os.path.join(train_dir, f'{j+1}_train_data.tar.pth')
    # NOTE(review): weights_only=False unpickles arbitrary Python objects --
    # only load dataset files from a trusted source.
    t = torch.load(train_path, weights_only = False)

    data = t['data']  # image array (the original comment says numpy.ndarray)

    # Some domain files may carry no targets; store None in that case.
    domains[j]['labels'] = t['targets'] if 'targets' in t else None

    # Prefer precomputed embeddings. An explicit existence check replaces the
    # former bare `except:`, which also swallowed KeyboardInterrupt and turned
    # every load error into a silent recompute.
    cached_path = f'stuff/normalized_train_embeds_{j+1}.pt'
    if os.path.exists(cached_path):
        domains[j]['features'] = torch.load(cached_path, map_location = device)
        continue

    # NOTE(review): the fallback below saves to 'train_embeds_*.pt', a different
    # name from the 'normalized_train_embeds_*.pt' file read above, so its
    # output is never picked up by this cell -- confirm which step produces the
    # normalized files.
    # NOTE(review): pixel values are cast to float32 without rescaling; the
    # ImageNet mean/std in `transform` presumably expect inputs in [0, 1] --
    # confirm the value range of `data`.
    # Channels-last -> channels-first (N, H, W, C) -> (N, C, H, W).
    X_tensor = torch.tensor(data, dtype=torch.float32).permute(0, 3, 1, 2)

    embeds = []
    for start in range(0, X_tensor.size(0), batch_size):
        # `transform` is applied directly to each (C, H, W) tensor.
        batch = torch.stack(
            [transform(image) for image in X_tensor[start:start + batch_size]]
        ).to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            feature_maps = resnet(batch)

        # Collapse everything after the batch dimension into one feature vector.
        embeds.append(feature_maps.flatten(start_dim=1))

    embeds = torch.vstack(embeds)
    domains[j]['features'] = embeds

    # Make sure the output folder exists before saving.
    os.makedirs('stuff', exist_ok=True)
    torch.save(embeds, f'stuff/train_embeds_{j+1}.pt')

  domains[j]['features']  = torch.load(f'stuff/normalized_train_embeds_{j+1}.pt', map_location = device)


In [6]:
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

class LWP:
    """Nearest-prototype classifier with incrementally updated class means.

    Each label owns one prototype: the running mean of every feature vector
    fitted for that label so far. `predict` assigns each sample the label of
    the closest prototype under the Euclidean norm.

    Inputs may be NumPy arrays, sequences, or tensor-like objects exposing
    `detach()` / `cpu()` / `numpy()`; they are converted to NumPy internally,
    so prototypes are stored as NumPy arrays.
    """

    def __init__(self):
        # label -> running mean of all samples fitted for that label
        self.prototypes = {}
        # label -> number of samples folded into the prototype (labels 0-9 pre-seeded)
        self.class_counts = {i:0 for i in range(10)}

    @staticmethod
    def _as_numpy(values):
        """Return `values` as a NumPy array, unwrapping tensor-like objects first."""
        if hasattr(values, 'detach') and hasattr(values, 'cpu'):
            # Tensors living on an accelerator cannot be converted directly.
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    def fit(self, features, labels):
        """Fold a batch of samples into the per-class prototypes.

        Args:
            features: array-like whose first axis indexes samples.
            labels: array-like with one label per sample.

        A label seen for the first time gets the batch mean as its prototype.
        A known label gets the count-weighted average of its old prototype and
        the batch mean, which equals the mean over all samples seen so far.
        """
        features = self._as_numpy(features)
        # Converting labels keeps `labels == label` element-wise even for lists.
        labels = self._as_numpy(labels)

        for label in np.unique(labels):
            samples = features[labels == label]
            num_samples = len(samples)
            batch_mean = samples.mean(axis=0)

            if label not in self.prototypes:
                self.prototypes[label] = batch_mean
                self.class_counts[label] = num_samples
            else:
                seen_before = self.class_counts[label]
                total = seen_before + num_samples
                self.prototypes[label] = (
                    seen_before / total * self._as_numpy(self.prototypes[label])
                    + num_samples / total * batch_mean
                )
                self.class_counts[label] = total

    def predict(self, features):
        """Return the label of the nearest prototype for every sample.

        Args:
            features: array-like whose first axis indexes samples.

        Returns:
            numpy.ndarray of predicted labels, one per sample (an empty array
            for empty input). Ties go to the label whose prototype was created
            first.

        Raises:
            ValueError: if there are samples to classify but `fit` has not
                produced any prototype yet.
        """
        features = self._as_numpy(features)
        if len(features) == 0:
            return np.array([])
        if not self.prototypes:
            raise ValueError('LWP.predict() called before fit(): no prototypes available')

        # Dict order is insertion order, so argmin keeps first-created-wins ties.
        known_labels = list(self.prototypes)
        flat = features.reshape(len(features), -1)

        # One vectorised distance computation per prototype.
        distances = np.stack(
            [
                np.linalg.norm(
                    flat - self._as_numpy(self.prototypes[label]).reshape(1, -1),
                    axis=1,
                )
                for label in known_labels
            ],
            axis=1,
        )
        return np.array(known_labels)[distances.argmin(axis=1)]

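## Get Eval Data
Load (or compute) the embeddings of each domain's evaluation split from `eval_dir`. An earlier note here read "Evaluating on trainset for now"; the cell below reads the `*_eval_data.tar.pth` files, so confirm whether that note is still accurate.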

In [8]:
# One dict per evaluation domain; each is filled with 'labels' and 'features'.
eval_domains = [{} for _ in range(10)]

# Images per ResNet forward pass. The original sliced exactly 10 x 250 images,
# which silently dropped anything past 2500 and produced empty batches for
# smaller files; batching is now derived from the actual number of images.
batch_size = 250

for j in range(10):

    eval_path = os.path.join(eval_dir, f'{j+1}_eval_data.tar.pth')
    # NOTE(review): weights_only=False unpickles arbitrary Python objects --
    # only load dataset files from a trusted source.
    t = torch.load(eval_path, weights_only = False)

    data = t['data']  # image array (the original comment says numpy.ndarray)

    # Some domain files may carry no targets; store None in that case.
    eval_domains[j]['labels'] = t['targets'] if 'targets' in t else None

    # Prefer precomputed embeddings. An explicit existence check replaces the
    # former bare `except:`, which also swallowed KeyboardInterrupt and turned
    # every load error into a silent recompute.
    cached_path = f'stuff/normalized_eval_embeds_{j+1}.pt'
    if os.path.exists(cached_path):
        eval_domains[j]['features'] = torch.load(cached_path, map_location = device)
        continue

    # NOTE(review): the fallback below saves to 'eval_embeds_*.pt', a different
    # name from the 'normalized_eval_embeds_*.pt' file read above, so its
    # output is never picked up by this cell -- confirm which step produces the
    # normalized files.
    # NOTE(review): pixel values are cast to float32 without rescaling; the
    # ImageNet mean/std in `transform` presumably expect inputs in [0, 1] --
    # confirm the value range of `data`.
    # Channels-last -> channels-first (N, H, W, C) -> (N, C, H, W).
    X_tensor = torch.tensor(data, dtype=torch.float32).permute(0, 3, 1, 2)

    embeds = []
    for start in range(0, X_tensor.size(0), batch_size):
        # `transform` is applied directly to each (C, H, W) tensor.
        batch = torch.stack(
            [transform(image) for image in X_tensor[start:start + batch_size]]
        ).to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            feature_maps = resnet(batch)

        # Collapse everything after the batch dimension into one feature vector.
        embeds.append(feature_maps.flatten(start_dim=1))

    embeds = torch.vstack(embeds)
    eval_domains[j]['features'] = embeds

    # Make sure the output folder exists before saving.
    os.makedirs('stuff', exist_ok=True)
    torch.save(embeds, f'stuff/eval_embeds_{j+1}.pt')

  eval_domains[j]['features']  = torch.load(f'stuff/normalized_eval_embeds_{j+1}.pt', map_location = device)


In [16]:
from sklearn.metrics import accuracy_score
import pandas as pd

# Prototype classifier that is updated one training domain at a time.
model = LWP()

# Accuracy table: one column per training step, one row per eval domain.
df = pd.DataFrame()

for idx, domain in enumerate(domains):

    domain_features = domain['features']

    # Use the stored labels when the domain has them; otherwise pseudo-label
    # the domain with the model's own predictions.
    if domain['labels'] is not None:
        domain_labels = domain['labels']
    else:
        domain_labels = model.predict(domain_features)

    model.fit(domain_features, domain_labels)
    print(model.class_counts)

    # Score the updated model on every eval domain up to the current one.
    scores = [
        accuracy_score(seen['labels'], model.predict(seen['features']))
        for seen in eval_domains[:idx+1]
    ]

    # Pad with NaN so every column has one entry per eval domain.
    padding = [np.nan] * (len(eval_domains) - len(scores))
    df[f'Domain {idx+1}'] = scores + padding

{0: 253, 1: 243, 2: 255, 3: 244, 4: 262, 5: 236, 6: 250, 7: 253, 8: 254, 9: 250}
{0: 422, 1: 522, 2: 559, 3: 278, 4: 651, 5: 317, 6: 600, 7: 438, 8: 509, 9: 704}
{0: 585, 1: 801, 2: 858, 3: 388, 4: 1044, 5: 409, 6: 943, 7: 638, 8: 758, 9: 1076}
{0: 722, 1: 1090, 2: 1141, 3: 564, 4: 1371, 5: 567, 6: 1264, 7: 823, 8: 987, 9: 1471}
{0: 858, 1: 1360, 2: 1421, 3: 815, 4: 1690, 5: 717, 6: 1579, 7: 997, 8: 1215, 9: 1848}
{0: 1019, 1: 1629, 2: 1683, 3: 1119, 4: 2023, 5: 879, 6: 1857, 7: 1187, 8: 1445, 9: 2159}
{0: 1167, 1: 1894, 2: 1946, 3: 1398, 4: 2355, 5: 1040, 6: 2154, 7: 1367, 8: 1701, 9: 2478}
{0: 1303, 1: 2140, 2: 2210, 3: 1720, 4: 2674, 5: 1227, 6: 2464, 7: 1508, 8: 1917, 9: 2837}
{0: 1420, 1: 2385, 2: 2500, 3: 2072, 4: 3024, 5: 1394, 6: 2737, 7: 1684, 8: 2137, 9: 3147}
{0: 1575, 1: 2609, 2: 2801, 3: 2373, 4: 3355, 5: 1575, 6: 3029, 7: 1851, 8: 2361, 9: 3471}


In [17]:
# Show the accuracy table as plain text. NaN entries are the padding added for
# eval domains that were not scored at that training step.
print(df)

   Domain 1  Domain 2  Domain 3  Domain 4  Domain 5  Domain 6  Domain 7  \
0     0.266    0.2440    0.2344    0.2292    0.2276    0.2244    0.2208   
1       NaN    0.2424    0.2352    0.2292    0.2252    0.2232    0.2216   
2       NaN       NaN    0.2360    0.2344    0.2332    0.2328    0.2328   
3       NaN       NaN       NaN    0.2504    0.2440    0.2440    0.2380   
4       NaN       NaN       NaN       NaN    0.2404    0.2388    0.2372   
5       NaN       NaN       NaN       NaN       NaN    0.2236    0.2216   
6       NaN       NaN       NaN       NaN       NaN       NaN    0.2264   
7       NaN       NaN       NaN       NaN       NaN       NaN       NaN   
8       NaN       NaN       NaN       NaN       NaN       NaN       NaN   
9       NaN       NaN       NaN       NaN       NaN       NaN       NaN   

   Domain 8  Domain 9  Domain 10  
0    0.2184    0.2176     0.2192  
1    0.2176    0.2140     0.2136  
2    0.2316    0.2296     0.2292  
3    0.2360    0.2332     0.2324  