In [1]:
!pip install torch torchvision numpy matplotlib




In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import Omniglot
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random


In [16]:
class FewShotDataset(Dataset):
    """Samples random N-way K-shot episodes from a labelled image dataset.

    Every call to ``__getitem__`` draws a fresh random episode using the
    global ``random`` module; the index passed in is ignored.

    Args:
        dataset: Iterable yielding ``(image, label)`` pairs. It is iterated
            once, in full, at construction time to group images by label.
        n_way: Number of classes per episode. If the dataset has fewer
            classes, all of them are used.
        k_shot: Number of support images taken per class.
        q_query: Number of query images taken per class.
        transform: Optional callable applied to every sampled image. The
            results are combined with ``torch.stack``, so they must be
            tensors of identical shape.
        episodes: Optional number of episodes reported by ``len()``, i.e.
            the epoch length seen by a ``DataLoader``. Defaults to ``None``,
            which reports the number of classes.

    Raises:
        ValueError: If ``episodes`` is negative.
    """

    def __init__(self, dataset, n_way, k_shot, q_query, transform=None, episodes=None):
        if episodes is not None and episodes < 0:
            raise ValueError("episodes must be non-negative")

        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        self.transform = transform
        self.episodes = episodes

        self.data_by_class = self._organize_by_class()

    def _organize_by_class(self):
        """Return a dict mapping each label to the list of its images."""
        data_by_class = {}
        for img, label in self.dataset:
            data_by_class.setdefault(label, []).append(img)
        return data_by_class

    def __getitem__(self, idx):
        """Draw one random episode.

        Returns:
            ``(support_set, query_set, labels)`` where ``support_set`` and
            ``query_set`` are stacked image tensors ordered class by class,
            and ``labels`` holds the episode-local class index
            (``0 .. n_way - 1``) of every query image.

        Note:
            A class holding fewer than ``k_shot + q_query`` images
            contributes only the images it has (support first, then query),
            so episodes built from such classes are smaller than usual.
        """
        class_ids = list(self.data_by_class.keys())
        sampled_classes = random.sample(class_ids, min(self.n_way, len(class_ids)))
        per_class = self.k_shot + self.q_query

        support_set, query_set, labels = [], [], []

        for class_index, class_id in enumerate(sampled_classes):
            class_images = self.data_by_class[class_id]
            # Sampling without replacement keeps support and query disjoint.
            drawn = random.sample(class_images, min(len(class_images), per_class))
            support_images = drawn[:self.k_shot]
            query_images = drawn[self.k_shot:]

            if self.transform:
                support_images = [self.transform(img) for img in support_images]
                query_images = [self.transform(img) for img in query_images]

            support_set.extend(support_images)
            query_set.extend(query_images)
            # Labels are episode-local indices, not the original class ids.
            labels.extend([class_index] * len(query_images))

        support_set = torch.stack(support_set)
        query_set = torch.stack(query_set)
        labels = torch.tensor(labels, dtype=torch.long)

        return support_set, query_set, labels

    def __len__(self):
        """Number of episodes per epoch (``episodes`` if set, else class count)."""
        if self.episodes is not None:
            return self.episodes
        return len(self.data_by_class)


In [17]:
class PrototypicalNetwork(nn.Module):
    """Small convolutional encoder used to embed images for prototype matching.

    Two ``Conv -> ReLU -> MaxPool(2)`` stages are followed by a linear
    projection to ``embedding_size`` dimensions.

    Args:
        input_size: Number of input image channels.
        embedding_size: Dimensionality of the output embedding.
        image_size: Spatial size of the input images, either a single int
            for square images or a ``(height, width)`` pair. Defaults to 28,
            which gives the 7x7 feature map the linear layer was originally
            sized for.

    Raises:
        ValueError: If the image is too small to survive both pooling stages.
    """

    def __init__(self, input_size, embedding_size, image_size=28):
        super().__init__()
        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        self.encoder = nn.Sequential(
            nn.Conv2d(input_size, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # The padded 3x3 convolutions keep the spatial size; each MaxPool2d(2)
        # halves it (rounding down), so two pools divide each side by 4.
        feat_h, feat_w = height // 4, width // 4
        if feat_h < 1 or feat_w < 1:
            raise ValueError("image_size must be at least 4 pixels on each side")
        self.fc = nn.Linear(128 * feat_h * feat_w, embedding_size)

    def forward(self, x):
        """Embed a batch of images ``(N, C, H, W)`` into ``(N, embedding_size)``."""
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def compute_prototypes(self, support_set, support_labels):
        """Average the support vectors of each class.

        Args:
            support_set: Tensor indexed along dim 0 by sample.
            support_labels: 1-D tensor of labels aligned with ``support_set``.

        Returns:
            Tensor with one row per distinct label, in the order returned by
            ``torch.unique`` (ascending label value).
        """
        prototypes = [
            support_set[support_labels == label].mean(dim=0)
            for label in torch.unique(support_labels)
        ]
        return torch.stack(prototypes)

    def classify(self, query_set, prototypes):
        """Return, for each query vector, the index of the nearest prototype
        under the Euclidean distance computed by ``torch.cdist``."""
        distances = torch.cdist(query_set, prototypes)
        return distances.argmin(dim=1)


In [18]:
def train_prototypical_network(model, data_loader, optimizer, criterion, n_way, k_shot):
    """Train ``model`` for one pass over ``data_loader`` of few-shot episodes.

    Each item of ``data_loader`` must be a ``(support_set, query_set, labels)``
    triple holding a single episode. Leading dimensions are flattened, so
    images must end in an explicit ``(channels, height, width)`` shape. The
    support images must be grouped class by class with exactly ``k_shot``
    images per class, because prototypes are formed by reshaping the support
    embeddings to ``(n_way, k_shot, -1)``.

    Args:
        model: Module mapping an image batch to embedding vectors.
        data_loader: Iterable of episodes as described above.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``; the logits are negative
            Euclidean distances from each query to each prototype.
        n_way: Number of classes per episode.
        k_shot: Number of support images per class.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the episodes seen, or
        ``(0.0, 0.0)`` if ``data_loader`` yields nothing.
    """
    # Put the data where the model already lives. Picking the device from CUDA
    # availability alone breaks when a CPU model runs on a CUDA-capable host.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    total_loss, total_accuracy, num_episodes = 0.0, 0.0, 0

    for support_set, query_set, labels in data_loader:
        # Collapse every leading dimension (e.g. the DataLoader batch dimension)
        # while keeping the trailing (C, H, W) image shape, whatever its size.
        support_set = support_set.reshape(-1, *support_set.shape[-3:]).to(device)
        query_set = query_set.reshape(-1, *query_set.shape[-3:]).to(device)
        labels = labels.reshape(-1).to(device)

        # One forward pass for both sets, then split the embeddings back apart.
        num_support = support_set.size(0)
        embeddings = model(torch.cat([support_set, query_set], dim=0))
        support_embeddings = embeddings[:num_support]
        query_embeddings = embeddings[num_support:]

        # Class prototype = mean embedding of that class's k_shot support images.
        prototypes = support_embeddings.view(n_way, k_shot, -1).mean(dim=1)

        # Closer prototype -> larger logit.
        logits = -torch.cdist(query_embeddings, prototypes)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predicted = logits.argmax(dim=1)
        accuracy = (predicted == labels).float().mean().item()

        total_loss += loss.item()
        total_accuracy += accuracy
        num_episodes += 1

    # Count episodes directly: works for loaders without __len__ and avoids a
    # ZeroDivisionError on an empty loader.
    if num_episodes == 0:
        return 0.0, 0.0
    return total_loss / num_episodes, total_accuracy / num_episodes


In [19]:
if __name__ == "__main__":
    # Experiment configuration. N_WAY and K_SHOT are shared by the episode
    # sampler and the training step, which must agree on the episode layout.
    SEED = 42
    N_WAY, K_SHOT, Q_QUERY = 5, 5, 5
    EMBEDDING_SIZE = 64
    LEARNING_RATE = 0.001
    EPOCHS = 10

    # Episode sampling uses `random`, DataLoader shuffling and weight
    # initialisation use torch; seed them all so a fresh run is repeatable.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 28x28 single-channel tensors match the input size the network is built for.
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    # The datasets are loaded untransformed; FewShotDataset applies `transform`
    # to each image it samples.
    omniglot_train = Omniglot(root="./data", background=True, download=True, transform=None)
    # Evaluation split: downloaded here but not used by the training loop below.
    omniglot_test = Omniglot(root="./data", background=False, download=True, transform=None)

    train_dataset = FewShotDataset(
        omniglot_train, n_way=N_WAY, k_shot=K_SHOT, q_query=Q_QUERY, transform=transform
    )
    # batch_size=1: each batch is exactly one episode.
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    model = PrototypicalNetwork(input_size=1, embedding_size=EMBEDDING_SIZE).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    epochs = EPOCHS
    for epoch in range(epochs):
        loss, accuracy = train_prototypical_network(
            model, train_loader, optimizer, criterion, n_way=N_WAY, k_shot=K_SHOT
        )
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")




Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10, Loss: 0.1819, Accuracy: 0.9375
Epoch 2/10, Loss: 0.0802, Accuracy: 0.9732
Epoch 3/10, Loss: 0.0537, Accuracy: 0.9819
Epoch 4/10, Loss: 0.0440, Accuracy: 0.9852
Epoch 5/10, Loss: 0.0371, Accuracy: 0.9870
Epoch 6/10, Loss: 0.0355, Accuracy: 0.9873
Epoch 7/10, Loss: 0.0315, Accuracy: 0.9884
Epoch 8/10, Loss: 0.0289, Accuracy: 0.9900
Epoch 9/10, Loss: 0.0265, Accuracy: 0.9913
Epoch 10/10, Loss: 0.0266, Accuracy: 0.9912
