<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class ProtoNet(nn.Module):
    """Convolutional embedding network for prototypical few-shot learning.

    Two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages are followed by a flatten
    and a linear projection that maps each image to an embedding vector.

    Args:
        input_dim: Number of input channels (fed to the first ``Conv2d``).
        embedding_dim: Size of the output embedding vector.
        input_size: Spatial size of the input images, either a single int
            for square images or a ``(height, width)`` pair. The default of
            28 yields the ``128 * 7 * 7`` linear layer used previously.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two
            pooling stages (any side shorter than 4 pixels).
    """

    def __init__(self, input_dim, embedding_dim, input_size=28):
        super().__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # The 3x3 convolutions use padding=1 and so keep the spatial size;
        # each MaxPool2d(2) halves it, rounding down.
        pooled_height = height // 2 // 2
        pooled_width = width // 2 // 2
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for two 2x2 pooling stages"
            )

        self.embedding = nn.Sequential(
            nn.Conv2d(input_dim, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(128 * pooled_height * pooled_width, embedding_dim),
        )

    def forward(self, x):
        """Return the embedding produced by the ``embedding`` stack for ``x``."""
        return self.embedding(x)

def euclidean_distance(a, b):
    """Return the pairwise Euclidean (L2) distances between rows of ``a`` and ``b``.

    Delegates to ``torch.cdist`` with its default 2-norm, so entry ``[i, j]``
    of the result is the distance between ``a[i]`` and ``b[j]``.
    """
    pairwise = torch.cdist(a, b, p=2.0)
    return pairwise

def compute_prototypes(embeddings, labels, n_classes):
    """Compute one prototype per class as the mean of that class's embeddings.

    Args:
        embeddings: Tensor whose first dimension indexes samples.
        labels: Tensor of class ids aligned with the first dimension of
            ``embeddings``; classes are looked up as ``0 .. n_classes - 1``.
        n_classes: Number of classes to build prototypes for.

    Returns:
        Tensor with ``n_classes`` rows, where row ``i`` is the mean embedding
        of the samples labelled ``i``.

    Raises:
        ValueError: If some class in ``range(n_classes)`` has no samples.
            The mean over an empty selection would be NaN and would silently
            propagate into the loss and the gradients.
    """
    prototypes = []
    for class_id in range(n_classes):
        class_embeddings = embeddings[labels == class_id]
        if class_embeddings.shape[0] == 0:
            raise ValueError(
                f"no support embeddings for class {class_id}; "
                "cannot compute its prototype"
            )
        prototypes.append(class_embeddings.mean(dim=0))
    return torch.stack(prototypes)

# Example usage: train ProtoNet on randomly generated few-shot episodes.
input_dim = 1  # e.g., for grayscale images
embedding_dim = 64
n_classes = 5
n_support = 5  # support samples per class in each episode
n_query = 2  # query samples per class in each episode
model = ProtoNet(input_dim, embedding_dim)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# The labels are the same for every episode, so build them once. Deriving
# them (and the batch sizes below) from n_classes / n_support / n_query keeps
# the data and the labels in sync when any of those values is changed.
support_labels = torch.arange(n_classes).repeat_interleave(n_support)
query_labels = torch.arange(n_classes).repeat_interleave(n_query)

for epoch in range(100):
    optimizer.zero_grad()
    # Random tensors stand in for real images; 28x28 is the input size the
    # model's final linear layer is built for.
    support_set = torch.randn(n_classes * n_support, input_dim, 28, 28)
    query_set = torch.randn(n_classes * n_query, input_dim, 28, 28)

    support_embeddings = model(support_set)
    query_embeddings = model(query_set)

    prototypes = compute_prototypes(support_embeddings, support_labels, n_classes)
    distances = euclidean_distance(query_embeddings, prototypes)
    # Negated distances act as logits: the closest prototype gets the
    # highest score under the cross-entropy loss.
    loss = criterion(-distances, query_labels)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch + 1}, Loss: {loss.item():.4f}')