In [14]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models, transforms
import numpy as np

In [15]:
# Paths to the first train/eval bundles of the part-one dataset.
train_dir = os.path.join('dataset', 'part_one_dataset', 'train_data')
eval_dir = os.path.join('dataset', 'part_one_dataset', 'eval_data')
train_path = os.path.join(train_dir, '1_train_data.tar.pth')
eval_path = os.path.join(eval_dir, '1_eval_data.tar.pth')

# Directory the embedding file below is read from.
# NOTE(review): the labels come from part_one_dataset but the embeddings are
# read from a 'part_2' directory; a later cell reads train_embeds_1.pt from
# 'part_1_vit_embeds' for the same labels - confirm which one is intended.
save_dir = 'part_2_vit_embeds'

# weights_only=False is passed explicitly on both loads so the behaviour does
# not depend on the installed PyTorch default (and no FutureWarning is raised).
# It unpickles arbitrary Python objects, so only load files you trust.
t = torch.load(train_path, weights_only=False)
# Replace the bundle's 'data' entry with the contents of the embedding file.
t['data'] = torch.load(os.path.join(save_dir, 'train_embeds_1.pt'), weights_only=False)

  t['data'] = torch.load(os.path.join(save_dir,f'train_embeds_{1}.pt'))


In [3]:
import numpy as np
from sklearn.metrics.pairwise import cosine_distances, manhattan_distances

class LWP:
    """Prototype-based classifier with a configurable distance function.

    One prototype is kept per label: the running mean of every sample that
    has been passed to ``fit`` for that label. ``predict`` assigns each sample
    the label whose prototype is closest under the chosen distance.
    """

    # Distance callables keyed by metric name. The sklearn pairwise helpers
    # work on 2-D inputs, hence the reshape(1, -1) and the [0][0] unpacking.
    DISTANCE_FUNCTIONS = {
        'euclidean': lambda x, y: np.linalg.norm(x - y),
        'cosine': lambda x, y: cosine_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'manhattan': lambda x, y: manhattan_distances(x.reshape(1, -1), y.reshape(1, -1))[0][0],
        'minkowski': lambda x, y, p=2: np.power(np.sum(np.power(np.abs(x - y), p)), 1/p)
    }

    def __init__(self, distance_metric='euclidean', **distance_params):
        """
        Args:
            distance_metric (str or callable): A key of ``DISTANCE_FUNCTIONS``
                or a callable ``f(x, y)`` returning the distance between two
                vectors.
            distance_params (dict): Additional parameters for the distance
                function. Only ``p`` is read, and only for ``'minkowski'``
                (default 2).

        Raises:
            ValueError: If ``distance_metric`` is neither callable nor a
                known metric name.
        """
        # label -> prototype vector (mean of the samples seen for that label)
        self.prototypes = {}
        # label -> number of samples folded into the prototype so far.
        # Labels 0-9 are pre-seeded with zero (kept for backward
        # compatibility); any other label is added the first time it is fit.
        self.class_counts = {i: 0 for i in range(10)}

        if callable(distance_metric):
            self.distance_fn = distance_metric
        elif distance_metric in self.DISTANCE_FUNCTIONS:
            if distance_metric == 'minkowski':
                p = distance_params.get('p', 2)
                self.distance_fn = lambda x, y: self.DISTANCE_FUNCTIONS[distance_metric](x, y, p)
            else:
                self.distance_fn = self.DISTANCE_FUNCTIONS[distance_metric]
        else:
            raise ValueError(f"Unknown distance metric: {distance_metric}. "
                             f"Available metrics: {list(self.DISTANCE_FUNCTIONS.keys())}")

    def fit(self, features, labels):
        """Create or incrementally update the prototype of every label in ``labels``.

        Args:
            features: Array-like with one sample per row (first axis).
            labels: Array-like with one label per sample.

        Calling ``fit`` again merges the new samples into the existing
        prototypes as a count-weighted running mean.
        """
        # Accept lists (and anything else np.asarray understands): the
        # boolean-mask indexing below only works on real arrays.
        features = np.asarray(features)
        labels = np.asarray(labels)

        for label in np.unique(labels):
            samples = features[labels == label]
            num_samples = len(samples)
            batch_mean = samples.mean(axis=0)

            if label not in self.prototypes:
                self.prototypes[label] = batch_mean
                self.class_counts[label] = num_samples
            else:
                # .get(): the prototype may have been assigned from outside
                # the class for a label that has no count entry yet.
                seen = self.class_counts.get(label, 0)
                total = seen + num_samples
                self.prototypes[label] = (
                    seen / total * self.prototypes[label] +
                    num_samples / total * batch_mean
                )
                self.class_counts[label] = total

    def predict(self, features):
        """Return, for each sample, the label of its nearest prototype.

        Args:
            features: Array-like with one sample per row.

        Returns:
            np.ndarray: One predicted label per sample.

        Raises:
            ValueError: If no prototypes exist yet.
        """
        if not self.prototypes:
            raise ValueError("LWP has no prototypes; call fit() before predict().")

        preds = []
        for feature in np.asarray(features):
            distances = {
                label: self.distance_fn(feature, proto)
                for label, proto in self.prototypes.items()
            }
            preds.append(min(distances, key=distances.get))
        return np.array(preds)

In [17]:
# Largest label value found in the bundle's targets.
print(max(t['targets']))

9


In [18]:
# Dimensions of the array stored under t['data'].
print(t['data'].shape)

(2500, 768)


In [19]:
# Pull features and labels out of the bundle, then flatten every sample to a
# single row (a no-op when the data is already 2-D).
targets = t['targets']
data = t['data']
data = data.reshape(data.shape[0], -1)
print(data.shape)

(2500, 768)


In [20]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.preprocessing import normalize

In [21]:
# Flatten the samples, rescale them with sklearn's normalize (unit L2 norm per
# row by default), then fit one cosine-distance prototype per class.
targets1 = t['targets']
data1 = t['data']
data1 = normalize(data1.reshape(data1.shape[0], -1))
dataloader = DataLoader(data1, shuffle=False, batch_size=32)
lwp_model = LWP(distance_metric='cosine')  # LWP model with cosine distance
lwp_model.fit(data1, targets1)

In [22]:

# Dump the label -> prototype mapping produced by fit().
print(f"Prototypes after fit: {lwp_model.prototypes}")


Prototypes after fit: {0: array([0.01729068, 0.01912501, 0.02102479, ..., 0.01516719, 0.0159201 ,
       0.01601679]), 1: array([0.01837207, 0.01810205, 0.01747006, ..., 0.01770059, 0.01728502,
       0.01624851]), 2: array([0.0164336 , 0.01742426, 0.01454287, ..., 0.01747445, 0.01748331,
       0.01434357]), 3: array([0.01795165, 0.0166975 , 0.01529295, ..., 0.01780526, 0.01648547,
       0.01536131]), 4: array([0.01489708, 0.01576391, 0.01352558, ..., 0.0190007 , 0.01918124,
       0.01478338]), 5: array([0.01479913, 0.01467359, 0.01309324, ..., 0.01706052, 0.01606419,
       0.01416187]), 6: array([0.0163735 , 0.01605917, 0.01306188, ..., 0.01921913, 0.01844561,
       0.01541324]), 7: array([0.01802458, 0.01884868, 0.01836053, ..., 0.01945586, 0.01867545,
       0.01490231]), 8: array([0.01797723, 0.02032782, 0.02269468, ..., 0.01198278, 0.01380034,
       0.01474392]), 9: array([0.02226604, 0.02317505, 0.02398837, ..., 0.01793398, 0.01759059,
       0.0167393 ])}


In [10]:
# Smoke test: fit a fresh cosine-distance LWP and report each prototype's shape.
print("Testing LWP fit...")
lwp_model = LWP(distance_metric='cosine')
lwp_model.fit(data1, targets1)

# One line per class so a wrong feature dimension is easy to spot.
print("Prototypes initialized:")
for label in lwp_model.prototypes:
    prototype = lwp_model.prototypes[label]
    print(f"Class {label}: Prototype shape = {prototype.shape}")


Testing LWP fit...
Prototypes initialized:
Class 0: Prototype shape = (3072,)
Class 1: Prototype shape = (3072,)
Class 2: Prototype shape = (3072,)
Class 3: Prototype shape = (3072,)
Class 4: Prototype shape = (3072,)
Class 5: Prototype shape = (3072,)
Class 6: Prototype shape = (3072,)
Class 7: Prototype shape = (3072,)
Class 8: Prototype shape = (3072,)
Class 9: Prototype shape = (3072,)


In [11]:
# Quick sanity summary of the fitted inputs.
print(f"Features shape: {data1.shape}")
print(f"Labels shape: {targets1.shape}")
print(f"Unique labels: {np.unique(targets1)}")


Features shape: (2500, 3072)
Labels shape: (2500,)
Unique labels: [0 1 2 3 4 5 6 7 8 9]


In [12]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics.pairwise import cosine_distances

In [13]:
# Dump the current label -> prototype mapping.
print(f"Prototypes: {lwp_model.prototypes}")

  data = torch.load(os.path.join(save_dir,f'train_embeds_{i}.pt'))


Prototypes: {0: array([0.01729068, 0.01912501, 0.02102479, ..., 0.01516719, 0.0159201 ,
       0.01601679]), 1: array([0.01837207, 0.01810205, 0.01747006, ..., 0.01770059, 0.01728502,
       0.01624851]), 2: array([0.0164336 , 0.01742426, 0.01454287, ..., 0.01747445, 0.01748331,
       0.01434357]), 3: array([0.01795165, 0.0166975 , 0.01529295, ..., 0.01780526, 0.01648547,
       0.01536131]), 4: array([0.01489708, 0.01576391, 0.01352558, ..., 0.0190007 , 0.01918124,
       0.01478338]), 5: array([0.01479913, 0.01467359, 0.01309324, ..., 0.01706052, 0.01606419,
       0.01416187]), 6: array([0.0163735 , 0.01605917, 0.01306188, ..., 0.01921913, 0.01844561,
       0.01541324]), 7: array([0.01802458, 0.01884868, 0.01836053, ..., 0.01945586, 0.01867545,
       0.01490231]), 8: array([0.01797723, 0.02032782, 0.02269468, ..., 0.01198278, 0.01380034,
       0.01474392]), 9: array([0.02226604, 0.02317505, 0.02398837, ..., 0.01793398, 0.01759059,
       0.0167393 ])}


In [14]:
# Keep a module-level handle on the model's prototype dict and display it.
prototypes = lwp_model.prototypes
print(prototypes)

{0: array([0.01729068, 0.01912501, 0.02102479, ..., 0.01516719, 0.0159201 ,
       0.01601679]), 1: array([0.01837207, 0.01810205, 0.01747006, ..., 0.01770059, 0.01728502,
       0.01624851]), 2: array([0.0164336 , 0.01742426, 0.01454287, ..., 0.01747445, 0.01748331,
       0.01434357]), 3: array([0.01795165, 0.0166975 , 0.01529295, ..., 0.01780526, 0.01648547,
       0.01536131]), 4: array([0.01489708, 0.01576391, 0.01352558, ..., 0.0190007 , 0.01918124,
       0.01478338]), 5: array([0.01479913, 0.01467359, 0.01309324, ..., 0.01706052, 0.01606419,
       0.01416187]), 6: array([0.0163735 , 0.01605917, 0.01306188, ..., 0.01921913, 0.01844561,
       0.01541324]), 7: array([0.01802458, 0.01884868, 0.01836053, ..., 0.01945586, 0.01867545,
       0.01490231]), 8: array([0.01797723, 0.02032782, 0.02269468, ..., 0.01198278, 0.01380034,
       0.01474392]), 9: array([0.02226604, 0.02317505, 0.02398837, ..., 0.01793398, 0.01759059,
       0.0167393 ])}


In [15]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics.pairwise import cosine_similarity
# Function to compute cosine similarity and select top samples
def select_top_samples(embeddings, centroids, top_k=50):
    """
    Keep the ``top_k`` percent of samples that are most similar to a centroid.

    Every sample is pseudo-labelled with the row index (into ``centroids``) of
    its most cosine-similar centroid; samples are then ranked by that best
    similarity and only the highest-ranked ``top_k`` percent are returned.

    Returns:
        tuple: (selected embeddings, their pseudo-labels), ordered from most
        to least similar.
    """
    sim_matrix = cosine_similarity(embeddings, centroids)
    nearest_centroid = sim_matrix.argmax(axis=1)
    best_similarity = sim_matrix.max(axis=1)

    # Integer percentage of the sample count; rounds down.
    n_keep = len(best_similarity) * top_k // 100
    # argsort is ascending, so reverse it before taking the head.
    keep = np.argsort(best_similarity)[::-1][:n_keep]

    return embeddings[keep], nearest_centroid[keep]

In [16]:
def process_dataset_with_knn(embed_dir, dataset_idx, lwp_model, k=5, top_k=50):
    """Pseudo-label one embedding file and refresh the model's prototypes.

    Steps:
      1. Load ``train_embeds_{dataset_idx}.pt`` from ``embed_dir``.
      2. Pseudo-label the ``top_k`` percent of samples most similar to the
         current prototypes (via ``select_top_samples``).
      3. Fit a Euclidean kNN on those samples and use it to label every
         sample in the file.
      4. Set each predicted class's prototype to the mean of the samples
         assigned to it.

    Args:
        embed_dir (str): Directory containing the embedding files.
        dataset_idx (int): Index used in the embedding file name.
        lwp_model: Object exposing a ``prototypes`` dict (label -> vector);
            it is modified in place.
        k (int): Number of neighbours for the kNN classifier.
        top_k (int): Percentage of samples used to train the kNN.

    Returns:
        None.
    """
    print(f"Processing dataset {dataset_idx} with kNN...")

    # Load embeddings. weights_only=False is explicit so the call does not
    # depend on the installed PyTorch default; it unpickles arbitrary objects,
    # so only point this at trusted files.
    embed_path = os.path.join(embed_dir, f'train_embeds_{dataset_idx}.pt')
    embeddings = torch.load(embed_path, weights_only=False)

    # Stack prototypes in sorted-label order and remember that order:
    # select_top_samples returns centroid ROW INDICES, which only coincide
    # with the labels when the keys are exactly 0..K-1.
    class_labels = sorted(lwp_model.prototypes)
    centroids = np.vstack([lwp_model.prototypes[label] for label in class_labels])
    top_embeddings, top_centroid_idx = select_top_samples(embeddings, centroids, top_k=top_k)
    top_pseudo_labels = np.asarray(class_labels)[top_centroid_idx]

    # Train kNN on the top samples
    knn = KNeighborsClassifier(n_neighbors=k, metric='euclidean')
    knn.fit(top_embeddings, top_pseudo_labels)

    # Use kNN to assign pseudo-labels for all embeddings
    all_pseudo_labels = knn.predict(embeddings)
    print(f"Dataset {dataset_idx}: Assigned pseudo-labels using kNN")

    # Update prototypes (class centroids) with kNN pseudo-labeled samples.
    # Every label from np.unique has at least one sample, so no emptiness
    # check is needed.
    # NOTE(review): this overwrites the previous prototype instead of merging
    # it (no running mean, class_counts untouched) - confirm that is intended.
    for label in np.unique(all_pseudo_labels):
        lwp_model.prototypes[label] = embeddings[all_pseudo_labels == label].mean(axis=0)

    print(f"Dataset {dataset_idx}: Updated prototypes for classes {list(lwp_model.prototypes.keys())}")

In [17]:
# Directories for embeddings
part_one_embed_dir = 'part_1_vit_embeds'
part_two_embed_dir = 'part_2_vit_embeds'

# Load the index-1 train and eval embedding files from the part-one directory.
# weights_only=False is explicit so the loads do not depend on the installed
# PyTorch default (and no FutureWarning is raised); it unpickles arbitrary
# objects, so only load trusted files.
train_embed_path = os.path.join(part_one_embed_dir, 'train_embeds_1.pt')
eval_embed_path = os.path.join(part_one_embed_dir, 'eval_embeds_1.pt')
train_embeddings = torch.load(train_embed_path, weights_only=False)
eval_embeddings = torch.load(eval_embed_path, weights_only=False)

  train_embeddings = torch.load(train_embed_path)
  eval_embeddings = torch.load(eval_embed_path)


In [18]:
# Labels come from the first part-one training bundle.
train_path = os.path.join('dataset', 'part_one_dataset', 'train_data', '1_train_data.tar.pth')
# weights_only=False is explicit so the load does not depend on the installed
# PyTorch default; it unpickles arbitrary objects, so only load trusted files.
data = torch.load(train_path, weights_only=False)
targets = data['targets']

# Initialize and fit LWP model
lwp_model = LWP(distance_metric='cosine')
lwp_model.fit(train_embeddings, targets)

  data = torch.load(train_path)


In [19]:
# Run the kNN pseudo-labelling update over part-one files 2 through 10, in order.
for i in range(2, 11):
    process_dataset_with_knn(
        embed_dir=part_one_embed_dir,
        dataset_idx=i,
        lwp_model=lwp_model,
        k=5,
        top_k=50,
    )

  embeddings = torch.load(embed_path)


Processing dataset 2 with kNN...
Dataset 2: Assigned pseudo-labels using kNN
Dataset 2: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 3 with kNN...


  embeddings = torch.load(embed_path)


Dataset 3: Assigned pseudo-labels using kNN
Dataset 3: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 4 with kNN...


  embeddings = torch.load(embed_path)


Dataset 4: Assigned pseudo-labels using kNN
Dataset 4: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 5 with kNN...


  embeddings = torch.load(embed_path)


Dataset 5: Assigned pseudo-labels using kNN
Dataset 5: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 6 with kNN...


  embeddings = torch.load(embed_path)


Dataset 6: Assigned pseudo-labels using kNN
Dataset 6: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 7 with kNN...


  embeddings = torch.load(embed_path)


Dataset 7: Assigned pseudo-labels using kNN
Dataset 7: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 8 with kNN...


  embeddings = torch.load(embed_path)


Dataset 8: Assigned pseudo-labels using kNN
Dataset 8: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 9 with kNN...


  embeddings = torch.load(embed_path)


Dataset 9: Assigned pseudo-labels using kNN
Dataset 9: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 10 with kNN...


  embeddings = torch.load(embed_path)


Dataset 10: Assigned pseudo-labels using kNN
Dataset 10: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [20]:
# Global ids 11-20 are stored in the part-two directory under file indices
# 1-10, hence the "- 10" when building the file index.
for i in range(11, 21):
    process_dataset_with_knn(
        embed_dir=part_two_embed_dir,
        dataset_idx=i - 10,
        lwp_model=lwp_model,
        k=5,
        top_k=50,
    )

Processing dataset 1 with kNN...


  embeddings = torch.load(embed_path)


Dataset 1: Assigned pseudo-labels using kNN
Dataset 1: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 2 with kNN...


  embeddings = torch.load(embed_path)


Dataset 2: Assigned pseudo-labels using kNN
Dataset 2: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 3 with kNN...


  embeddings = torch.load(embed_path)


Dataset 3: Assigned pseudo-labels using kNN
Dataset 3: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 4 with kNN...


  embeddings = torch.load(embed_path)


Dataset 4: Assigned pseudo-labels using kNN
Dataset 4: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 5 with kNN...


  embeddings = torch.load(embed_path)


Dataset 5: Assigned pseudo-labels using kNN
Dataset 5: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 6 with kNN...


  embeddings = torch.load(embed_path)


Dataset 6: Assigned pseudo-labels using kNN
Dataset 6: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 7 with kNN...


  embeddings = torch.load(embed_path)


Dataset 7: Assigned pseudo-labels using kNN
Dataset 7: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 8 with kNN...


  embeddings = torch.load(embed_path)


Dataset 8: Assigned pseudo-labels using kNN
Dataset 8: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 9 with kNN...


  embeddings = torch.load(embed_path)


Dataset 9: Assigned pseudo-labels using kNN
Dataset 9: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Processing dataset 10 with kNN...


  embeddings = torch.load(embed_path)


Dataset 10: Assigned pseudo-labels using kNN
Dataset 10: Updated prototypes for classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [21]:
from sklearn.metrics import accuracy_score, classification_report
import numpy as np
import os
import torch

# Function to evaluate on eval embeddings
def evaluate_on_eval_embeddings(embed_dir, dataset_idx, model, ground_truth=None):
    """
    Evaluate the trained model using eval embeddings.

    Args:
        embed_dir (str): Path to the directory containing eval embeddings.
        dataset_idx (int): Index used in the embedding / label file names.
        model: Trained model exposing ``predict(embeddings)``.
        ground_truth (array-like, optional): True labels for the eval set.
            When omitted, labels are loaded from
            ``dataset/part_{one,two}_dataset/eval_data`` if ``embed_dir``
            contains ``'part_1'`` or ``'part_2'`` respectively.

    Returns:
        dict: ``{"accuracy", "report"}`` when ground truth is available,
        otherwise ``{"predicted_labels"}``.
    """
    # Load eval embeddings. weights_only=False is explicit so the call does
    # not depend on the installed PyTorch default; it unpickles arbitrary
    # objects, so only load trusted files.
    embed_path = os.path.join(embed_dir, f'eval_embeds_{dataset_idx}.pt')
    eval_embeddings = torch.load(embed_path, weights_only=False)

    # Predict labels using the trained model
    print(f"Evaluating on dataset {dataset_idx}...")
    predicted_labels = model.predict(eval_embeddings)

    # Only fall back to the on-disk labels when the caller did not supply
    # any; previously a supplied ground_truth was silently overwritten.
    if ground_truth is None:
        if 'part_1' in embed_dir:
            part_name = 'part_one_dataset'
        elif 'part_2' in embed_dir:
            part_name = 'part_two_dataset'
        else:
            part_name = None

        if part_name is not None:
            labels_path = os.path.join('dataset', part_name, 'eval_data', f'{dataset_idx}_eval_data.tar.pth')
            datum = torch.load(labels_path, weights_only=False)
            ground_truth = datum['targets']

    if ground_truth is not None:
        # Compute evaluation metrics if ground truth is available
        accuracy = accuracy_score(ground_truth, predicted_labels)
        report = classification_report(ground_truth, predicted_labels, zero_division=0)
        print(f"Accuracy on eval set {dataset_idx}: {accuracy * 100:.2f}%")
        print("Classification Report:")
        print(report)
        return {"accuracy": accuracy, "report": report}
    else:
        # No ground truth available
        print(f"Predictions on eval set {dataset_idx}: {predicted_labels[:10]}...")
        return {"predicted_labels": predicted_labels}

# Directories for eval embeddings
part_one_eval_dir = 'part_1_vit_embeds'
part_two_eval_dir = 'part_2_vit_embeds'

# Load the labels of the first part-one eval bundle.
# weights_only=False is explicit so the load does not depend on the installed
# PyTorch default; it unpickles arbitrary objects, so only load trusted files.
eval_labels_path = os.path.join('dataset', 'part_one_dataset', 'eval_data', '1_eval_data.tar.pth')
eval_data = torch.load(eval_labels_path, weights_only=False)
eval_ground_truth = eval_data['targets']

# Evaluate on D1 using LWP model
evaluate_on_eval_embeddings(part_one_eval_dir, dataset_idx=1, model=lwp_model, ground_truth=eval_ground_truth)

# Evaluate on D2-D10. No labels are passed here, but these runs are still
# scored: evaluate_on_eval_embeddings loads the eval labels from the part-one
# dataset directory itself because the embedding directory contains 'part_1'.
for i in range(2, 11):
    huh = evaluate_on_eval_embeddings(part_one_eval_dir, dataset_idx=i, model=lwp_model)
    print(huh)


Evaluating on dataset 1...


  eval_data = torch.load(eval_labels_path)
  eval_embeddings = torch.load(embed_path)
  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 1: 86.12%
Classification Report:
              precision    recall  f1-score   support

           0       0.91      0.87      0.89       252
           1       0.90      0.97      0.93       217
           2       0.99      0.64      0.77       264
           3       0.73      0.82      0.77       242
           4       0.69      0.92      0.79       257
           5       0.86      0.80      0.83       252
           6       0.97      0.88      0.92       269
           7       0.84      0.88      0.86       233
           8       0.95      0.93      0.94       266
           9       0.90      0.94      0.92       248

    accuracy                           0.86      2500
   macro avg       0.87      0.86      0.86      2500
weighted avg       0.87      0.86      0.86      2500

Evaluating on dataset 2...


  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 2: 87.52%
Classification Report:
              precision    recall  f1-score   support

           0       0.93      0.92      0.92       271
           1       0.94      0.96      0.95       261
           2       0.97      0.69      0.80       267
           3       0.77      0.85      0.81       258
           4       0.71      0.91      0.80       236
           5       0.82      0.76      0.79       212
           6       0.97      0.86      0.91       250
           7       0.84      0.90      0.87       252
           8       0.96      0.95      0.96       258
           9       0.91      0.94      0.93       235

    accuracy                           0.88      2500
   macro avg       0.88      0.87      0.87      2500
weighted avg       0.88      0.88      0.88      2500

{'accuracy': 0.8752, 'report': '              precision    recall  f1-score   support\n\n           0       0.93      0.92      0.92       271\n           1       0.94      0.96      0.95

  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 3: 87.12%
Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.88      0.88       231
           1       0.95      0.93      0.94       263
           2       0.98      0.66      0.79       275
           3       0.77      0.81      0.79       236
           4       0.74      0.92      0.82       266
           5       0.83      0.83      0.83       232
           6       0.96      0.92      0.94       241
           7       0.84      0.89      0.86       267
           8       0.96      0.93      0.94       241
           9       0.88      0.96      0.92       248

    accuracy                           0.87      2500
   macro avg       0.88      0.87      0.87      2500
weighted avg       0.88      0.87      0.87      2500

{'accuracy': 0.8712, 'report': '              precision    recall  f1-score   support\n\n           0       0.88      0.88      0.88       231\n           1       0.95      0.93      0.94

  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 4: 88.64%
Classification Report:
              precision    recall  f1-score   support

           0       0.89      0.88      0.89       231
           1       0.95      0.99      0.97       234
           2       0.99      0.73      0.84       247
           3       0.77      0.83      0.80       241
           4       0.78      0.95      0.85       261
           5       0.84      0.75      0.80       273
           6       0.98      0.93      0.96       251
           7       0.85      0.91      0.88       263
           8       0.95      0.93      0.94       251
           9       0.93      0.96      0.94       248

    accuracy                           0.89      2500
   macro avg       0.89      0.89      0.89      2500
weighted avg       0.89      0.89      0.89      2500

{'accuracy': 0.8864, 'report': '              precision    recall  f1-score   support\n\n           0       0.89      0.88      0.89       231\n           1       0.95      0.99      0.97

  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 5: 87.44%
Classification Report:
              precision    recall  f1-score   support

           0       0.94      0.91      0.93       256
           1       0.96      0.97      0.96       265
           2       0.98      0.60      0.75       239
           3       0.77      0.84      0.80       251
           4       0.66      0.93      0.78       227
           5       0.88      0.81      0.84       246
           6       0.97      0.87      0.92       256
           7       0.83      0.87      0.85       261
           8       0.94      0.94      0.94       249
           9       0.92      0.98      0.95       250

    accuracy                           0.87      2500
   macro avg       0.89      0.87      0.87      2500
weighted avg       0.89      0.87      0.87      2500

{'accuracy': 0.8744, 'report': '              precision    recall  f1-score   support\n\n           0       0.94      0.91      0.93       256\n           1       0.96      0.97      0.96

  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 6: 88.16%
Classification Report:
              precision    recall  f1-score   support

           0       0.90      0.87      0.89       245
           1       0.95      0.97      0.96       276
           2       0.99      0.70      0.82       235
           3       0.72      0.84      0.78       226
           4       0.78      0.96      0.86       257
           5       0.87      0.75      0.80       271
           6       0.97      0.90      0.93       235
           7       0.87      0.91      0.89       257
           8       0.94      0.95      0.95       234
           9       0.91      0.95      0.93       264

    accuracy                           0.88      2500
   macro avg       0.89      0.88      0.88      2500
weighted avg       0.89      0.88      0.88      2500

{'accuracy': 0.8816, 'report': '              precision    recall  f1-score   support\n\n           0       0.90      0.87      0.89       245\n           1       0.95      0.97      0.96

  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 7: 87.04%
Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.89      0.89       245
           1       0.94      0.96      0.95       239
           2       0.98      0.67      0.80       251
           3       0.73      0.85      0.78       247
           4       0.72      0.93      0.81       268
           5       0.89      0.76      0.82       272
           6       0.96      0.91      0.93       266
           7       0.83      0.86      0.85       232
           8       0.96      0.92      0.94       238
           9       0.92      0.96      0.94       242

    accuracy                           0.87      2500
   macro avg       0.88      0.87      0.87      2500
weighted avg       0.88      0.87      0.87      2500

{'accuracy': 0.8704, 'report': '              precision    recall  f1-score   support\n\n           0       0.88      0.89      0.89       245\n           1       0.94      0.96      0.95

  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 8: 87.16%
Classification Report:
              precision    recall  f1-score   support

           0       0.91      0.89      0.90       267
           1       0.92      0.96      0.94       238
           2       0.97      0.64      0.77       238
           3       0.74      0.82      0.78       248
           4       0.71      0.88      0.79       247
           5       0.89      0.84      0.86       251
           6       0.97      0.91      0.94       253
           7       0.83      0.88      0.85       257
           8       0.94      0.93      0.93       264
           9       0.92      0.96      0.94       237

    accuracy                           0.87      2500
   macro avg       0.88      0.87      0.87      2500
weighted avg       0.88      0.87      0.87      2500

{'accuracy': 0.8716, 'report': '              precision    recall  f1-score   support\n\n           0       0.91      0.89      0.90       267\n           1       0.92      0.96      0.94

  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 9: 86.92%
Classification Report:
              precision    recall  f1-score   support

           0       0.94      0.90      0.92       268
           1       0.95      0.96      0.96       261
           2       0.97      0.66      0.78       236
           3       0.73      0.82      0.77       253
           4       0.70      0.94      0.81       254
           5       0.84      0.75      0.79       270
           6       0.99      0.87      0.92       250
           7       0.85      0.89      0.87       246
           8       0.94      0.96      0.95       218
           9       0.91      0.95      0.93       244

    accuracy                           0.87      2500
   macro avg       0.88      0.87      0.87      2500
weighted avg       0.88      0.87      0.87      2500

{'accuracy': 0.8692, 'report': '              precision    recall  f1-score   support\n\n           0       0.94      0.90      0.92       268\n           1       0.95      0.96      0.96

  datum = torch.load(os.path.join('dataset', 'part_one_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))


In [22]:
# Evaluate on D11-D20 (described by the author as unlabeled datasets with a
# different distribution). They are stored in the part-two directory under
# file indices 1-10, hence the "- 10".
for i in range(11, 21):
    huh = evaluate_on_eval_embeddings(
        part_two_eval_dir,
        dataset_idx=i - 10,
        model=lwp_model,
    )
    print(huh)

  eval_embeddings = torch.load(embed_path)


Evaluating on dataset 1...


  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 1: 78.72%
Classification Report:
              precision    recall  f1-score   support

           0       0.81      0.86      0.83       238
           1       0.87      0.90      0.89       262
           2       0.94      0.60      0.73       256
           3       0.51      0.77      0.61       255
           4       0.67      0.87      0.75       268
           5       0.76      0.71      0.73       246
           6       0.89      0.73      0.80       225
           7       0.88      0.78      0.83       259
           8       0.95      0.82      0.88       236
           9       0.92      0.84      0.87       255

    accuracy                           0.79      2500
   macro avg       0.82      0.79      0.79      2500
weighted avg       0.82      0.79      0.79      2500

{'accuracy': 0.7872, 'report': '              precision    recall  f1-score   support\n\n           0       0.81      0.86      0.83       238\n           1       0.87      0.90      0.89

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 2: 66.04%
Classification Report:
              precision    recall  f1-score   support

           0       0.47      0.82      0.60       238
           1       0.94      0.43      0.59       262
           2       0.97      0.29      0.45       256
           3       0.47      0.55      0.51       255
           4       0.67      0.81      0.74       268
           5       0.65      0.73      0.69       246
           6       0.71      0.83      0.77       225
           7       0.78      0.74      0.76       259
           8       0.77      0.81      0.79       236
           9       0.68      0.63      0.65       255

    accuracy                           0.66      2500
   macro avg       0.71      0.66      0.65      2500
weighted avg       0.71      0.66      0.65      2500

{'accuracy': 0.6604, 'report': '              precision    recall  f1-score   support\n\n           0       0.47      0.82      0.60       238\n           1       0.94      0.43      0.59

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 3: 82.68%
Classification Report:
              precision    recall  f1-score   support

           0       0.81      0.85      0.83       238
           1       0.91      0.91      0.91       262
           2       0.97      0.58      0.73       256
           3       0.65      0.72      0.68       255
           4       0.69      0.91      0.79       268
           5       0.81      0.75      0.78       246
           6       0.89      0.86      0.88       225
           7       0.89      0.88      0.88       259
           8       0.92      0.89      0.90       236
           9       0.86      0.92      0.89       255

    accuracy                           0.83      2500
   macro avg       0.84      0.83      0.83      2500
weighted avg       0.84      0.83      0.83      2500

{'accuracy': 0.8268, 'report': '              precision    recall  f1-score   support\n\n           0       0.81      0.85      0.83       238\n           1       0.91      0.91      0.91

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 4: 86.40%
Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.88      0.86       238
           1       0.95      0.95      0.95       262
           2       0.99      0.67      0.80       256
           3       0.74      0.81      0.77       255
           4       0.71      0.92      0.80       268
           5       0.84      0.80      0.82       246
           6       0.93      0.88      0.91       225
           7       0.89      0.87      0.88       259
           8       0.94      0.92      0.93       236
           9       0.92      0.94      0.93       255

    accuracy                           0.86      2500
   macro avg       0.88      0.86      0.87      2500
weighted avg       0.87      0.86      0.86      2500

{'accuracy': 0.864, 'report': '              precision    recall  f1-score   support\n\n           0       0.84      0.88      0.86       238\n           1       0.95      0.95      0.95 

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 5: 87.20%
Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.88      0.88       238
           1       0.95      0.96      0.96       262
           2       0.98      0.69      0.81       256
           3       0.75      0.79      0.77       255
           4       0.75      0.91      0.82       268
           5       0.84      0.80      0.82       246
           6       0.96      0.90      0.93       225
           7       0.85      0.91      0.88       259
           8       0.95      0.93      0.94       236
           9       0.90      0.96      0.93       255

    accuracy                           0.87      2500
   macro avg       0.88      0.87      0.87      2500
weighted avg       0.88      0.87      0.87      2500

{'accuracy': 0.872, 'report': '              precision    recall  f1-score   support\n\n           0       0.88      0.88      0.88       238\n           1       0.95      0.96      0.96 

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 6: 79.80%
Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.88      0.79       238
           1       0.86      0.94      0.90       262
           2       0.99      0.52      0.69       256
           3       0.60      0.75      0.67       255
           4       0.69      0.84      0.76       268
           5       0.85      0.71      0.77       246
           6       0.78      0.87      0.82       225
           7       0.86      0.84      0.85       259
           8       0.97      0.75      0.85       236
           9       0.89      0.88      0.89       255

    accuracy                           0.80      2500
   macro avg       0.82      0.80      0.80      2500
weighted avg       0.82      0.80      0.80      2500

{'accuracy': 0.798, 'report': '              precision    recall  f1-score   support\n\n           0       0.71      0.88      0.79       238\n           1       0.86      0.94      0.90 

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 7: 80.72%
Classification Report:
              precision    recall  f1-score   support

           0       0.79      0.84      0.82       238
           1       0.89      0.94      0.92       262
           2       0.92      0.65      0.76       256
           3       0.54      0.83      0.65       255
           4       0.68      0.88      0.77       268
           5       0.84      0.69      0.76       246
           6       0.92      0.80      0.86       225
           7       0.94      0.76      0.84       259
           8       0.95      0.81      0.88       236
           9       0.90      0.86      0.88       255

    accuracy                           0.81      2500
   macro avg       0.84      0.81      0.81      2500
weighted avg       0.84      0.81      0.81      2500

{'accuracy': 0.8072, 'report': '              precision    recall  f1-score   support\n\n           0       0.79      0.84      0.82       238\n           1       0.89      0.94      0.92

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 8: 80.92%
Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.80      0.82       238
           1       0.96      0.87      0.91       262
           2       0.97      0.53      0.68       256
           3       0.70      0.66      0.68       255
           4       0.60      0.92      0.73       268
           5       0.78      0.76      0.77       246
           6       0.82      0.90      0.86       225
           7       0.88      0.83      0.86       259
           8       0.90      0.89      0.90       236
           9       0.86      0.93      0.89       255

    accuracy                           0.81      2500
   macro avg       0.83      0.81      0.81      2500
weighted avg       0.83      0.81      0.81      2500

{'accuracy': 0.8092, 'report': '              precision    recall  f1-score   support\n\n           0       0.84      0.80      0.82       238\n           1       0.96      0.87      0.91

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))
  eval_embeddings = torch.load(embed_path)


Accuracy on eval set 9: 67.84%
Classification Report:
              precision    recall  f1-score   support

           0       0.65      0.71      0.68       238
           1       0.87      0.76      0.81       262
           2       1.00      0.30      0.46       256
           3       0.47      0.66      0.55       255
           4       0.75      0.59      0.66       268
           5       0.75      0.62      0.68       246
           6       0.89      0.45      0.60       225
           7       0.68      0.81      0.74       259
           8       0.68      0.92      0.78       236
           9       0.59      0.95      0.73       255

    accuracy                           0.68      2500
   macro avg       0.73      0.68      0.67      2500
weighted avg       0.73      0.68      0.67      2500

{'accuracy': 0.6784, 'report': '              precision    recall  f1-score   support\n\n           0       0.65      0.71      0.68       238\n           1       0.87      0.76      0.81

  datum = torch.load(os.path.join('dataset', 'part_two_dataset', 'eval_data', f'{dataset_idx}_eval_data.tar.pth'))


In [23]:

# # Train LWP model on unlabeled 
# NOTE(review): every line of this cell is commented out, so the cell executes nothing.
# As written, the disabled code would, for each dataset index 2..10:
#   1. load `{i}_train_data.tar.pth` from `train_dir` and reshape `data` to rows of 3072 values,
#   2. pass the rows through `normalize` and iterate over them in batches of 32 (no shuffling),
#   3. pseudo-label every sample with the label of its nearest `lwp_model` prototype,
#   4. keep the most confident half, where confidence = 1 / (1 + distance to that prototype),
#   5. average the kept samples per pseudo-label into class centroids and print them.
# The centroids are only printed; nothing here writes them back into `lwp_model`.
# `normalize`, `DataLoader` and `lwp_model` are not defined in this cell -- TODO confirm
# they are in scope before re-enabling it.

# for i in range(2,11):
#     train_path=os.path.join(train_dir, f'{i}_train_data.tar.pth')
#     print(f"Processing dataset {i} from {train_path}")

#     dataset=torch.load(train_path, weights_only=False)
#     data= dataset['data']
#     data=data.reshape(-1,3072)
#     print(f"Reshaped dataset shape: {data.shape}")  # Should be (N, 3072)

# NOTE(review): `data` was already reshaped to (-1, 3072) above, so the reshape inside
# the `normalize` call below looks redundant -- confirm before removing.
#     data = normalize(data.reshape(data.shape[0],-1))  # Normalize data (important for distance calculations)
#     # Prepare DataLoader for the dataset
# NOTE(review): `tensor_dataset` is built but never used; the DataLoader is given `data`
# instead. Confirm which of the two was meant to be iterated.
#     tensor_dataset = torch.tensor(data, dtype=torch.float32)
#     dataloader = DataLoader(data, batch_size=32, shuffle=False)
#     confidence_scores = []
#     embeddings = []
#     predictions = []
# NOTE(review): debug-only pass over the loader; it prints one line per batch and
# computes nothing.
#     for batch in dataloader:
#         print(f"Batch shape: {batch.shape}")  # Should be (batch_size, 3072)

#     with torch.no_grad():
#         for batch in dataloader:

#             inputs = batch.numpy()
#             print(inputs.shape)
#             batch_predictions = []
#             batch_distances = []
# NOTE(review): this prints the full prototype dict once per batch (debug output).
#             print("Prototypes:", lwp_model.prototypes)

#             # Predict pseudo-labels and calculate distances to prototypes
# NOTE(review): the two shape prints inside this loop fire once per sample (debug output).
#             for sample in inputs:
#                 print(f"Sample shape before flattening: {sample.shape}")
#                 sample = sample.flatten()  # Ensure the sample is a 1D vector
#                 print(f"Sample shape after flattening: {sample.shape}")
#                 dist_to_prototypes = {label: lwp_model.distance_fn(sample, proto) for label, proto in lwp_model.prototypes.items()}
#                 closest_label = min(dist_to_prototypes, key=dist_to_prototypes.get)
#                 closest_distance = dist_to_prototypes[closest_label]
#                 batch_predictions.append(closest_label)
#                 batch_distances.append(closest_distance)

#             predictions.extend(batch_predictions)
#             confidence = 1 / (1 + np.array(batch_distances))  # Convert distances to confidence scores
#             confidence_scores.extend(confidence.tolist())
#             embeddings.extend(inputs)

#     embeddings = np.array(embeddings)

#     # Step 4: Select top 50% most confident samples
#     sorted_indices = np.argsort(confidence_scores)[::-1]  # Sort by confidence scores (descending)
#     top_50_percent_indices = sorted_indices[:len(sorted_indices) // 2]

#     top_50_embeddings = embeddings[top_50_percent_indices]
#     top_50_predictions = np.array(predictions)[top_50_percent_indices]

#     # Step 5: Construct class centroids
# NOTE(review): `class_centroids` is rebuilt from scratch for every dataset `i` and is
# only printed below; it is not merged into `lwp_model.prototypes` in this cell.
#     class_centroids = {}
#     for label in np.unique(top_50_predictions):
#         class_embeddings = top_50_embeddings[top_50_predictions == label]
#         if class_embeddings.size > 0:
#             centroid = class_embeddings.mean(axis=0)
#             class_centroids[label] = centroid

#     # Print centroids
#     print(f"Class centroids calculated for dataset {i}:")
#     for label, centroid in class_centroids.items():
#         print(f"Class {label}: Centroid = {centroid[:5]}...")
   
