In [50]:
import torch 
from torch import nn 
import matplotlib.pyplot as plt
import numpy as np
from torch.nn import functional as F
from tqdm import tqdm
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import DataLoader
from collections import defaultdict
import random
from torch.utils.data import Subset
from torchvision.models import resnet18

In [51]:
class Encoder(nn.Module):
    def __init__(self, pretrained=True, out_dim=128):
        super(Encoder, self).__init__()
        base_model = resnet18(pretrained=pretrained)

        # Modify first conv layer to accept 1-channel input instead of 3
        original_conv1 = base_model.conv1
        base_model.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=original_conv1.bias is not None
        )

        # Optional: initialize weights of new conv1
        if pretrained:
            with torch.no_grad():
                # Copy weights from RGB to 1-channel (average over input channels)
                base_model.conv1.weight = nn.Parameter(
                    base_model.conv1.weight.mean(dim=1, keepdim=True)
                )

        self.features = nn.Sequential(*list(base_model.children())[:-1])  # up to avgpool
        self.flatten = nn.Flatten()
        self.projection = nn.Linear(512, out_dim)

    def forward(self, x):
        x = self.features(x)      # (B, 512, 1, 1)
        x = self.flatten(x)       # (B, 512)
        x = self.projection(x)    # (B, 128)
        return x


In [52]:
encoder = Encoder()


The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.



In [53]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  

device(type='cuda')

In [54]:
encoder.load_state_dict(torch.load("best_model.pth", map_location=device))


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



<All keys matched successfully>

In [55]:
encoder.to(device)
encoder.eval()

Encoder(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [56]:
from torchvision import datasets
eval_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset_path = "./dataset"
eval_dataset = datasets.ImageFolder(root=dataset_path, transform=eval_transform)
#eval_dataset = Omniglot(root='./data', background=True, download=True, transform=eval_transform)
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)

In [57]:


# Build a map from class label → list of sample indices
class_to_indices = defaultdict(list)
for idx, (_, label) in enumerate(eval_dataset):
    class_to_indices[label].append(idx)

# Randomly choose 5 classes
selected_classes = random.sample(list(class_to_indices.keys()), 5)

# For each class, pick up to 20 samples
selected_indices = []
for cls in selected_classes:
    selected_indices.extend(class_to_indices[cls][:20])  # pick 20 samples per class

# Create subset of dataset
subset = Subset(eval_dataset, selected_indices)
subset_loader = DataLoader(subset, batch_size=64, shuffle=False)


In [58]:
all_embeddings = []
all_labels = []

with torch.no_grad():
    for images, labels in subset_loader:
        images = images.to(device)
        embeddings = encoder(images)
        all_embeddings.append(embeddings.cpu().numpy())
        all_labels.append(labels.numpy())

X = np.concatenate(all_embeddings, axis=0)
y = np.concatenate(all_labels, axis=0)

In [59]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=3, perplexity=30, init='pca', random_state=42)
X_3d = tsne.fit_transform(X)  # shape: (num_samples, 3)


In [60]:
import plotly.express as px
import pandas as pd

# Create a dataframe for easier plotting
df = pd.DataFrame({
    'x': X_3d[:, 0],
    'y': X_3d[:, 1],
    'z': X_3d[:, 2],
    'class': y
})

# Filter only selected class labels (5 classes)
df = df[df['class'].isin(selected_classes)]

# Map original class labels to consistent string labels
label_map = {label: f"Class {label}" for label in selected_classes}
df['class'] = df['class'].map(label_map)

# Create interactive 3D scatter plot
fig = px.scatter_3d(df, x='x', y='y', z='z',
                    color='class',
                    title="3D t-SNE of 5 Omniglot Classes (BYOL Embeddings)",
                    labels={"class": "Omniglot Class"},
                    opacity=0.8)

fig.update_traces(marker=dict(size=5))
fig.show()


In [61]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y 
)

In [62]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(f"KNN Accuracy: {accuracy_score(y_test, y_pred) * 100:.2f}%")

KNN Accuracy: 66.67%


In [63]:
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
svc= SVC(kernel='rbf', random_state=42)
svc.fit(X_train, y_train)
y_pred = svc.predict(X_test)
print(f"SVM Accuracy: {accuracy_score(y_test, y_pred) * 100:.2f}%")


SVM Accuracy: 60.00%


In [64]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

rf = RandomForestClassifier(n_estimators=1000, random_state=42)
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
print(f"Random Forest Accuracy: {accuracy_score(y_test, y_pred) * 100:.2f}%") 

Random Forest Accuracy: 90.00%


In [65]:
from sklearn.manifold import TSNE

# Compute 2D t-SNE
tsne_2d = TSNE(n_components=2, perplexity=30, init='pca', random_state=42)
X_2d = tsne_2d.fit_transform(X)

# Prepare DataFrame for plotting
df_2d = pd.DataFrame({
    'x': X_2d[:, 0],
    'y': X_2d[:, 1],
    'class': y
})
df_2d = df_2d[df_2d['class'].isin(selected_classes)]
df_2d['class'] = df_2d['class'].map(label_map)

import plotly.express as px

fig2d = px.scatter(
    df_2d, x='x', y='y', color='class',
    title="2D t-SNE of 5 Classes (BYOL Embeddings)",
    labels={"class": "Script Class"},
    opacity=0.8
)
fig2d.update_traces(marker=dict(size=7))
fig2d.show()