<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Embedding network used by the prototypical few-shot classifier below.
class PrototypicalNetwork(nn.Module):
    """Two-layer MLP that maps input vectors into an embedding space.

    The network is only the encoder (Linear -> ReLU -> Linear). Class
    prototypes and distances to them are computed outside this module.

    Args:
        input_dim: size of the last dimension of the inputs.
        hidden_dim: width of the hidden layer and of the output embedding.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Layers are created in the same order as a literal nn.Sequential(...)
        # call, so parameter initialisation and state_dict keys are unchanged.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding of ``x``; its last dimension is ``hidden_dim``."""
        embedded = self.encoder(x)
        return embedded

# Pairwise *squared* Euclidean distance (no square root is taken).
def euclidean_dist(a, b):
    """Return squared Euclidean distances between all rows of ``a`` and ``b``.

    Args:
        a: 2-D tensor with ``n`` rows.
        b: 2-D tensor with ``m`` rows and the same number of columns as ``a``.

    Returns:
        ``(n, m)`` tensor whose entry ``[i, j]`` is ``((a[i] - b[j]) ** 2).sum()``.
    """
    rows, cols = a.size(0), b.size(0)
    # Broadcast both operands to (rows, cols, features) so every pair is compared.
    left = a.unsqueeze(1).expand(rows, cols, -1)
    right = b.unsqueeze(0).expand(rows, cols, -1)
    diff = left - right
    return diff.pow(2).sum(2)

# Few-shot episode configuration for the example (MNIST-like: 784 features).
N_WAY = 10        # number of classes in the episode
K_SHOT = 5        # support examples per class
N_QUERY = 30      # query examples
INPUT_DIM = 784
HIDDEN_DIM = 64


def compute_prototypes(embeddings, labels, num_classes):
    """Return the mean embedding (prototype) of each class.

    Args:
        embeddings: 2-D tensor of support-set embeddings, one row per example.
        labels: 1-D integer tensor of class labels aligned with ``embeddings``.
        num_classes: prototypes are produced for classes ``0 .. num_classes - 1``.

    Returns:
        ``(num_classes, dim)`` tensor whose row ``c`` is the mean of the rows of
        ``embeddings`` labelled ``c``.

    Raises:
        ValueError: if some class has no support example. Averaging an empty
            selection would yield a NaN prototype and silently poison the loss.
    """
    missing = [c for c in range(num_classes) if not bool((labels == c).any())]
    if missing:
        raise ValueError(f"no support examples for classes {missing}")
    return torch.stack([embeddings[labels == c].mean(0) for c in range(num_classes)])


if __name__ == "__main__":
    # Example data. The support set is a balanced N-way K-shot set, so every
    # class has a prototype. With purely random labels a class can be absent
    # and its prototype would be NaN.
    support_data = torch.randn(N_WAY * K_SHOT, INPUT_DIM)
    support_labels = torch.arange(N_WAY).repeat(K_SHOT)
    query_data = torch.randn(N_QUERY, INPUT_DIM)
    query_labels = torch.randint(0, N_WAY, (N_QUERY,))

    # One batch per loader: each epoch is a single support/query episode.
    support_loader = DataLoader(
        TensorDataset(support_data, support_labels), batch_size=len(support_data)
    )
    query_loader = DataLoader(
        TensorDataset(query_data, query_labels), batch_size=len(query_data)
    )

    model = PrototypicalNetwork(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    epochs = 10
    for epoch in range(epochs):
        model.train()
        # Distinct loop names, so the module-level tensors above are not rebound.
        for (s_x, s_y), (q_x, q_y) in zip(support_loader, query_loader):
            support_embeddings = model(s_x)
            query_embeddings = model(q_x)

            prototypes = compute_prototypes(support_embeddings, s_y, N_WAY)
            dists = euclidean_dist(query_embeddings, prototypes)

            # A closer prototype must score higher. CrossEntropyLoss expects
            # logits, so use negative distances (softmax over -d). Passing raw
            # distances would train queries away from their own class.
            loss = criterion(-dists, q_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

    print("Training completed.")