In [1]:
import torch 
from torch import nn 
import matplotlib.pyplot as plt
import numpy as np
from torch.nn import functional as F
from tqdm import tqdm
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import DataLoader
from collections import defaultdict
import random
from torch.utils.data import Subset

In [2]:
class Encoder(nn.Module):
    def __init__(self, hidden_dim=64, out_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1),  # 28x28
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 14x14

            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 7x7

            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 3x3

            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))  # -> 1x1
        )
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.encoder(x)          # Shape: (B, hidden_dim, 1, 1)
        x = x.view(x.size(0), -1)    # Flatten to (B, hidden_dim)
        return self.fc(x)           

In [3]:
encoder = Encoder()

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  

device(type='cuda')

In [5]:
encoder.load_state_dict(torch.load("byol_best_model.pth", map_location=device))

  encoder.load_state_dict(torch.load("byol_best_model.pth", map_location=device))


<All keys matched successfully>

In [6]:
encoder.to(device)
encoder.eval()

Encoder(
  (encoder): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14

In [7]:
eval_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

eval_dataset = Omniglot(root='./data', background=True, download=True, transform=eval_transform)
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified


In [8]:


# Build a map from class label → list of sample indices
class_to_indices = defaultdict(list)
for idx, (_, label) in enumerate(eval_dataset):
    class_to_indices[label].append(idx)

# Randomly choose 5 classes
selected_classes = random.sample(list(class_to_indices.keys()), 5)

# For each class, pick up to 20 samples
selected_indices = []
for cls in selected_classes:
    selected_indices.extend(class_to_indices[cls][:20])  # pick 20 samples per class

# Create subset of dataset
subset = Subset(eval_dataset, selected_indices)
subset_loader = DataLoader(subset, batch_size=64, shuffle=False)


In [9]:
all_embeddings = []
all_labels = []

with torch.no_grad():
    for images, labels in subset_loader:
        images = images.to(device)
        embeddings = encoder(images)
        all_embeddings.append(embeddings.cpu().numpy())
        all_labels.append(labels.numpy())

X = np.concatenate(all_embeddings, axis=0)
y = np.concatenate(all_labels, axis=0)

In [10]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=3, perplexity=30, init='pca', random_state=42)
X_3d = tsne.fit_transform(X)  # shape: (num_samples, 3)


In [11]:
import plotly.express as px
import pandas as pd

# Create a dataframe for easier plotting
df = pd.DataFrame({
    'x': X_3d[:, 0],
    'y': X_3d[:, 1],
    'z': X_3d[:, 2],
    'class': y
})

# Filter only selected class labels (5 classes)
df = df[df['class'].isin(selected_classes)]

# Map original class labels to consistent string labels
label_map = {label: f"Class {label}" for label in selected_classes}
df['class'] = df['class'].map(label_map)

# Create interactive 3D scatter plot
fig = px.scatter_3d(df, x='x', y='y', z='z',
                    color='class',
                    title="3D t-SNE of 5 Omniglot Classes (BYOL Embeddings)",
                    labels={"class": "Omniglot Class"},
                    opacity=0.8)

fig.update_traces(marker=dict(size=5))
fig.show()


In [12]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y 
)

In [13]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(f"KNN Accuracy: {accuracy_score(y_test, y_pred) * 100:.2f}%")

KNN Accuracy: 80.00%


In [14]:
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
svc= SVC(kernel='rbf', random_state=42)
svc.fit(X_train, y_train)
y_pred = svc.predict(X_test)
print(f"SVM Accuracy: {accuracy_score(y_test, y_pred) * 100:.2f}%")


SVM Accuracy: 76.67%


In [15]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

rf = RandomForestClassifier(n_estimators=1000, random_state=42)
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
print(f"Random Forest Accuracy: {accuracy_score(y_test, y_pred) * 100:.2f}%") 

Random Forest Accuracy: 93.33%


In [16]:
from sklearn.manifold import TSNE

# Compute 2D t-SNE
tsne_2d = TSNE(n_components=2, perplexity=30, init='pca', random_state=42)
X_2d = tsne_2d.fit_transform(X)

# Prepare DataFrame for plotting
df_2d = pd.DataFrame({
    'x': X_2d[:, 0],
    'y': X_2d[:, 1],
    'class': y
})
df_2d = df_2d[df_2d['class'].isin(selected_classes)]
df_2d['class'] = df_2d['class'].map(label_map)

import plotly.express as px

fig2d = px.scatter(
    df_2d, x='x', y='y', color='class',
    title="2D t-SNE of 5 Omniglot Classes (BYOL Embeddings)",
    labels={"class": "Omniglot Class"},
    opacity=0.8
)
fig2d.update_traces(marker=dict(size=7))
fig2d.show()