In [75]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import Omniglot
from torchvision.models import ResNet18_Weights, resnet18
from tqdm import tqdm


In [76]:
class PrototypicalNetworks(nn.Module):
    """Prototypical Networks few-shot classifier.

    Every class in the support set is summarised by a prototype (the mean of
    its support embeddings). Each query is scored by its negative Euclidean
    distance to every prototype, so a higher score means a closer prototype.
    """

    def __init__(self, backbone: nn.Module):
        """
        Args:
            backbone: feature extractor applied to both support and query images.
                Its output is passed to ``torch.cdist``, so it is expected to
                return 2-D (batch, embedding_dim) features.
        """
        super().__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.

        Args:
            support_images: batch of labelled support images.
            support_labels: one class label per support image; the set of labels
                must be exactly ``0 .. n_way - 1``.
            query_images: batch of images to classify.

        Returns:
            Scores of shape (n_query, n_way). Column ``i`` is the score for class
            ``i``, and higher means closer.

        Raises:
            ValueError: if the support labels are not the contiguous range
                ``0 .. n_way - 1``. Otherwise a missing label would yield a NaN
                prototype and silently corrupt the scores.
        """
        # Invoke the backbone as a module (not via ``.forward``) so that any
        # registered forward hooks run.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # torch.unique returns sorted values, so it equals [0, 1, ..., n_way-1]
        # exactly when the labels are contiguous from zero.
        classes = torch.unique(support_labels)
        n_way = classes.numel()
        expected = torch.arange(n_way, dtype=classes.dtype, device=classes.device)
        if not torch.equal(classes, expected):
            raise ValueError(
                "support_labels must be the contiguous range 0..n_way-1, "
                f"got {classes.tolist()}"
            )

        # Prototype i is the mean embedding of the support images labelled i.
        z_proto = torch.stack(
            [z_support[support_labels == label].mean(0) for label in range(n_way)]
        )

        # Negative Euclidean distance: the nearest prototype gets the highest score.
        return -torch.cdist(z_query, z_proto)


# Backbone: ImageNet-pretrained ResNet-18 whose classification head is replaced
# by Flatten, so the network returns the features that would have gone into the
# classifier instead of class logits.
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` selected.
# NOTE(review): these weights are presumably overwritten by the
# prototypical_networks_2.pth checkpoint loaded in the next cell.
convolutional_network = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
convolutional_network.fc = nn.Flatten()
print(convolutional_network)

# Fall back to the CPU when no GPU is visible, instead of failing in `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PrototypicalNetworks(convolutional_network).to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  



In [77]:
# Restore the trained Prototypical Network weights.
# - map_location: load the tensors onto the device the model already lives on,
#   so a checkpoint saved from a GPU also loads on a CPU-only machine.
# - weights_only=True: restrict unpickling to tensors and plain containers, so
#   the file cannot execute arbitrary code on load.
model.load_state_dict(
    torch.load(
        "prototypical_networks_2.pth",
        map_location=next(model.parameters()).device,
        weights_only=True,
    )
)

  model.load_state_dict(torch.load("prototypical_networks_2.pth"))


<All keys matched successfully>

In [78]:
import torch

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def load_prototypes(filepath, device):
    """Load a saved prototype tensor and place it on ``device``.

    Args:
        filepath: path to a file written with ``torch.save`` that holds a single tensor.
        device: target ``torch.device`` (or device string) for the prototypes.

    Returns:
        The prototype tensor on ``device``. It is presumably
        (n_classes, embedding_dim), one row per class; TODO confirm against the
        code that saved it.

    Raises:
        TypeError: if the file does not contain a tensor.
    """
    # weights_only=True refuses to unpickle arbitrary objects, so a tampered
    # file cannot execute code when loaded.
    prototypes = torch.load(filepath, map_location=device, weights_only=True)
    if not isinstance(prototypes, torch.Tensor):
        raise TypeError(
            f"Expected a tensor in {filepath}, got {type(prototypes).__name__}"
        )
    print(f"Loaded prototypes from {filepath}")
    return prototypes.to(device)

# Class prototypes computed ahead of time, loaded onto the active device.
prototypes = load_prototypes(filepath="prototypes_resnet.pth", device=device)

Loaded prototypes from prototypes_resnet.pth


  prototypes = torch.load(filepath, map_location=device)


In [97]:
from PIL import Image
from torchvision import transforms

# Define the same preprocessing used during training/support
# Query-time preprocessing. It must mirror the transform used when the support
# prototypes were built; otherwise queries and prototypes are embedded from
# differently prepared inputs.
preprocess = transforms.Compose(
    [
        # Collapse to one luminance channel, then replicate it 3 times to give
        # the 3 input channels the ResNet's first conv layer takes.
        transforms.Grayscale(num_output_channels=3),
        # NOTE(review): 28x28 presumably matches the training resize; confirm.
        transforms.Resize((28, 28)),
        # PIL image -> float tensor in [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1] on every channel.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def prepare_query_image(image_path, transform=None):
    """Load an image from disk and turn it into a model-ready batch of one.

    Args:
        image_path: path of the image file to classify.
        transform: optional callable mapping a PIL image to a tensor. Defaults to
            the module-level ``preprocess`` pipeline, looked up at call time so a
            redefined ``preprocess`` is picked up.

    Returns:
        The transformed image tensor with a leading batch dimension of size 1.
    """
    if transform is None:
        transform = preprocess
    # Use a context manager so the file handle is released deterministically
    # rather than left to Pillow/GC. convert() returns an independent
    # in-memory copy that stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")  # ensure 3 channels regardless of source mode
    return transform(rgb).unsqueeze(0)  # add batch dimension

# A single validation image to classify, moved to the same device as the model.
QUERY_IMAGE_PATH = "val/lichhavi_a/0038.png"
query_img_tensor = prepare_query_image(QUERY_IMAGE_PATH).to(device)


In [98]:
def infer_single_query(backbone, prototypes, query_image_tensor):
    """Classify one query image by its nearest prototype.

    The backbone runs in eval mode (e.g. so BatchNorm layers use their running
    statistics) with gradients disabled. Its previous train/eval mode is
    restored afterwards, so calling this does not leave a shared module in a
    changed state.

    Args:
        backbone: feature extractor. Its output is compared to ``prototypes`` with
            ``torch.cdist``, so it is expected to return (1, embedding_dim) here.
        prototypes: one prototype per row, presumably
            (n_classes, embedding_dim); TODO confirm.
        query_image_tensor: a batch containing exactly one preprocessed image.

    Returns:
        int: row index of the closest prototype.
    """
    was_training = backbone.training
    backbone.eval()
    try:
        with torch.no_grad():
            z_query = backbone(query_image_tensor)
            dists = torch.cdist(z_query, prototypes)
            scores = -dists  # negative distance as similarity
            # .item() requires a single query in the batch.
            pred = torch.argmax(scores, dim=1).item()
    finally:
        # Restore the caller's mode even if the forward pass raised.
        backbone.train(was_training)
    return pred

# Nearest-prototype prediction for the single query image.
predicted_class = infer_single_query(
    backbone=model.backbone,
    prototypes=prototypes,
    query_image_tensor=query_img_tensor,
)

print(f"Predicted class index: {predicted_class}")

Predicted class index: 134


In [99]:
import pandas as pd
# Mapping from predicted class index to human-readable label. It is expected to
# have 'Index' and 'Label Name' columns, which get_label_from_index reads.
INDEX_TO_LABEL_CSV = "index_to_label.csv"
df = pd.read_csv(INDEX_TO_LABEL_CSV)
def get_label_from_index(df, index):
    """Return the 'Label Name' for the first row whose 'Index' equals ``index``.

    Falls back to the string "Unknown label" when no row matches.
    """
    matches = df.loc[df["Index"] == index, "Label Name"]
    if matches.empty:
        return "Unknown label"
    return matches.to_numpy()[0]
# Translate the predicted index into its label name.
predicted_label = get_label_from_index(df, predicted_class)
print(f"Predicted label: {predicted_label}")

Predicted label: lichhavi_a
