In [41]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm
from torchvision import datasets 

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

In [42]:
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_support = self.backbone.forward(support_images)
        z_query = self.backbone.forward(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        # And here is the super complicated operation to transform those distances into classification scores!
        scores = -dists
        return scores


convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()
print(convolutional_network)

model = PrototypicalNetworks(convolutional_network).cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [43]:
model.load_state_dict(torch.load("prototypical_networks_2.pth"))

  model.load_state_dict(torch.load("prototypical_networks_2.pth"))


<All keys matched successfully>

In [44]:
image_size = 28
transform=transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ]
)

In [45]:
def get_labels(self):
    return [label for _, label in self.samples]

In [46]:
# Parameters
N_WAY = 327
N_SHOT = 10
N_QUERY = 0
N_EVALUATION_TASKS = 1

# Properly assign get_labels as a method (not the result)
def get_labels(self):
    return [label for _, label in self.samples]

# Create val_set and val_loader
val_set = datasets.ImageFolder("val", transform=transform)

# Assign get_labels method to val_set
val_set.get_labels = get_labels.__get__(val_set)

# Create sampler and loader
test_sampler = TaskSampler(
    val_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_EVALUATION_TASKS
)

val_loader = DataLoader(
    val_set,
    batch_sampler=test_sampler,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [47]:
import torch

def compute_prototypes(backbone, support_images, support_labels, n_way):
    backbone.eval()
    with torch.no_grad():
        z_support = backbone(support_images)  # [num_support, embedding_dim]

    prototypes = []
    for c in range(n_way):
        class_embeddings = z_support[support_labels == c]
        proto = class_embeddings.mean(0)
        prototypes.append(proto)

    prototypes = torch.stack(prototypes)  # [n_way, embedding_dim]
    return prototypes

def save_prototypes(prototypes, filepath):
    torch.save(prototypes.cpu(), filepath)
    print(f"Prototypes saved to {filepath}")

In [48]:
# Get one episodic batch from val_loader
batch = next(iter(val_loader))

# Extract support set images and labels from the tuple
support_images = batch[0].to(model.backbone.conv1.weight.device)  # shape: [N_WAY * N_SHOT, C, H, W]
support_labels = batch[1].to(model.backbone.conv1.weight.device)  # shape: [N_WAY * N_SHOT]

# Convert grayscale (1 channel) images to 3 channels for ResNet
if support_images.shape[1] == 1:
	support_images = support_images.repeat(1, 3, 1, 1)

print("Support images shape:", support_images.shape)
print("Support labels shape:", support_labels.shape)
print("Unique classes in support labels:", torch.unique(support_labels))

# Compute prototypes by averaging embeddings per class
prototypes = compute_prototypes(model.backbone, support_images, support_labels, n_way=N_WAY)

print("Prototypes shape:", prototypes.shape)

Support images shape: torch.Size([3270, 3, 28, 28])
Support labels shape: torch.Size([3270])
Unique classes in support labels: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 15

In [49]:
save_prototypes(prototypes, "prototypes_resnet.pth")

Prototypes saved to prototypes_resnet.pth


In [50]:
import json
import csv

# The fifth element of the batch tuple contains the list of label indices in order
index_to_label = {idx: label for idx, label in enumerate(batch[4])}
# Print mapping from index to label name
for idx, label_idx in index_to_label.items():
    label_name = val_set.classes[label_idx]
    print(f"Index {idx}: Label {label_idx} -> {label_name}")
    # Save mapping as CSV instead of printing
    with open("index_to_label.csv", "w", encoding="utf-8", newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Index", "Label Index", "Label Name"])
        for idx, label_idx in index_to_label.items():
            label_name = val_set.classes[label_idx]
            writer.writerow([idx, label_idx, label_name])
with open("index_to_label.txt", "w", encoding="utf-8") as f:
    # Create a dictionary mapping index to label name
    index_to_label_name = {idx: val_set.classes[label_idx] for idx, label_idx in index_to_label.items()}
    # Save as JSON
    json.dump(index_to_label_name, f, ensure_ascii=False, indent=2)

Index 0: Label 0 -> bhujimol-a
Index 1: Label 1 -> bhujimol-aa
Index 2: Label 2 -> bhujimol-ah
Index 3: Label 3 -> bhujimol-ai
Index 4: Label 4 -> bhujimol-am
Index 5: Label 5 -> bhujimol-au
Index 6: Label 6 -> bhujimol-ba
Index 7: Label 7 -> bhujimol-bha
Index 8: Label 8 -> bhujimol-ca
Index 9: Label 9 -> bhujimol-cha
Index 10: Label 10 -> bhujimol-da
Index 11: Label 11 -> bhujimol-daa
Index 12: Label 12 -> bhujimol-dha
Index 13: Label 13 -> bhujimol-dhaa
Index 14: Label 14 -> bhujimol-e
Index 15: Label 15 -> bhujimol-ga
Index 16: Label 16 -> bhujimol-gha
Index 17: Label 17 -> bhujimol-gja
Index 18: Label 18 -> bhujimol-ha
Index 19: Label 19 -> bhujimol-i
Index 20: Label 20 -> bhujimol-ii
Index 21: Label 21 -> bhujimol-ja
Index 22: Label 22 -> bhujimol-jha
Index 23: Label 23 -> bhujimol-ka
Index 24: Label 24 -> bhujimol-kha
Index 25: Label 25 -> bhujimol-ksa
Index 26: Label 26 -> bhujimol-la
Index 27: Label 27 -> bhujimol-lr
Index 28: Label 28 -> bhujimol-lrr
Index 29: Label 29 -> bhu