<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Prototypical_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import Omniglot
from torchvision import transforms
import numpy as np

# Define the Prototypical Network
class ProtoNet(nn.Module):
    """Convolutional embedding network for a Prototypical Network.

    Two identical stages of Conv2d(3x3, 64 channels, padding=1) ->
    BatchNorm2d -> ReLU -> MaxPool2d(2) map an image batch to feature maps,
    which ``forward`` flattens into one embedding vector per input.

    Args:
        input_dim: Number of channels in the input images.
    """

    def __init__(self, input_dim):
        super().__init__()

        def conv_stage(in_channels):
            # padding=1 keeps the spatial size through the conv; the pool halves it.
            return [
                nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        # A single flat Sequential keeps the parameter names encoder.0 ... encoder.7.
        self.encoder = nn.Sequential(*conv_stage(input_dim), *conv_stage(64))

    def forward(self, x):
        """Return the encoder output flattened to shape ``(x.size(0), -1)``."""
        batch_size = x.size(0)
        feature_maps = self.encoder(x)
        return feature_maps.view(batch_size, -1)

def euclidean_dist(x, y):
    """Return pairwise squared Euclidean distances between rows of ``x`` and ``y``.

    For 2-D inputs ``x`` of shape (N, D) and ``y`` of shape (M, D), the
    result has shape (N, M), with entry ``[i, j] = ||x[i] - y[j]||^2``.
    """
    return ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(2)


def proto_loss(prototypes, embeddings, labels):
    """Compute the Prototypical-Network classification loss.

    Each embedding is scored against every prototype by its *negative* squared
    Euclidean distance. Cross-entropy is then taken over those scores, so the
    nearest prototype receives the highest probability.

    Args:
        prototypes: Class prototypes, one row per class.
        embeddings: Query embeddings, one row per query.
        labels: Integer class index (row of ``prototypes``) for each query.

    Returns:
        Scalar mean cross-entropy loss.
    """
    dists = euclidean_dist(embeddings, prototypes)
    # Logits must be negative distances. Using raw distances would train the
    # encoder to push each query AWAY from its own class prototype.
    return nn.functional.cross_entropy(-dists, labels)

class FewShotDataset(Dataset):
    """Episodic wrapper that turns a labelled dataset into N-way K-shot tasks.

    Every ``__getitem__`` call samples a fresh random episode, and the index
    argument is ignored. ``__len__`` therefore only controls how many episodes
    a DataLoader draws per pass.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs,
            where ``image`` is a tensor that can be combined with ``torch.stack``.
        n_way: Number of classes per episode.
        k_shot: Number of support examples per class.
        n_query: Number of query examples per class. Defaults to ``k_shot``.

    Raises:
        ValueError: If fewer than ``n_way`` classes have at least
            ``k_shot + n_query`` examples.
    """

    def __init__(self, dataset, n_way, k_shot, n_query=None):
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        # The query size used to be tied to k_shot; keep that as the default.
        self.n_query = k_shot if n_query is None else n_query
        self.indices_by_class = self._get_indices_by_class()

        # Only classes with enough examples for disjoint support/query draws
        # can be sampled. Otherwise np.random.choice(..., replace=False) would
        # raise at an arbitrary point during training.
        per_class = self.k_shot + self.n_query
        self._eligible_classes = [
            cls
            for cls, indices in self.indices_by_class.items()
            if len(indices) >= per_class
        ]
        if len(self._eligible_classes) < self.n_way:
            raise ValueError(
                f"n_way={self.n_way} needs at least {self.n_way} classes with "
                f">= {per_class} examples each (k_shot + n_query), but only "
                f"{len(self._eligible_classes)} qualify"
            )

    def _get_indices_by_class(self):
        """Map each label to the list of dataset indices that carry it.

        NOTE: this reads every item (transform included) once, which can be
        slow for large image datasets.
        """
        indices_by_class = {}
        for idx in range(len(self.dataset)):
            _, label = self.dataset[idx]
            indices_by_class.setdefault(label, []).append(idx)
        return indices_by_class

    def __len__(self):
        # Episodes do not depend on idx; this only sets episodes per DataLoader pass.
        return len(self.dataset)

    def __getitem__(self, idx):
        """Sample one random episode (``idx`` is ignored).

        Returns:
            ``(support_images, support_labels, query_images, query_labels)``.
            The images of each set are stacked along a new leading dimension.
            The labels are int64 episode-local class ids in ``0..n_way-1``,
            grouped by class in the order the classes were drawn.
        """
        selected_classes = np.random.choice(self._eligible_classes, self.n_way, replace=False)
        support_set = []
        query_set = []
        for cls in selected_classes:
            # Draw support and query together without replacement so they are disjoint.
            indices = np.random.choice(
                self.indices_by_class[cls], self.k_shot + self.n_query, replace=False
            )
            support_set.extend(indices[:self.k_shot])
            query_set.extend(indices[self.k_shot:])

        # The i-th drawn class becomes episode label i.
        episode_classes = torch.arange(self.n_way)
        support_labels = episode_classes.repeat_interleave(self.k_shot)
        query_labels = episode_classes.repeat_interleave(self.n_query)

        support_images = torch.stack([self.dataset[i][0] for i in support_set])
        query_images = torch.stack([self.dataset[i][0] for i in query_set])
        return support_images, support_labels, query_images, query_labels

# Prepare the Omniglot dataset
# Prepare the Omniglot dataset ("background" split used for training)
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = Omniglot(root='data', background=True, transform=transform, download=True)
# Each DataLoader item is one 5-way 5-shot episode; batch_size=1 means one episode per step.
few_shot_dataset = FewShotDataset(train_dataset, n_way=5, k_shot=5)
train_loader = DataLoader(few_shot_dataset, batch_size=1, shuffle=True)

model = ProtoNet(input_dim=1)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    # The encoder uses BatchNorm, so make the training mode explicit.
    model.train()
    running_loss = 0.0
    num_episodes = 0
    for batch in train_loader:
        # Drop the DataLoader's leading batch dimension (size 1) from every tensor.
        support_images, support_labels, query_images, query_labels = (t.squeeze(0) for t in batch)

        # Embed support and query images together in one forward pass.
        n_support = support_images.size(0)
        embeddings = model(torch.cat([support_images, query_images]))
        support_embeddings = embeddings[:n_support]
        query_embeddings = embeddings[n_support:]

        # One prototype per episode class: the mean of its support embeddings.
        n_way = int(support_labels.max().item()) + 1
        prototypes = torch.stack(
            [support_embeddings[support_labels == c].mean(0) for c in range(n_way)]
        )

        loss = proto_loss(prototypes, query_embeddings, query_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_episodes += 1

    if num_episodes == 0:
        raise RuntimeError("train_loader produced no episodes")
    # Report the mean over the epoch rather than only the last episode's loss.
    print(f'Epoch {epoch+1}, Loss: {running_loss / num_episodes:.4f}')