<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class PrototypicalNetwork(nn.Module):
    """Prototypical network with a single linear embedding layer.

    Every support example and query is passed through the same linear layer.
    Each class prototype is the mean of that class's embedded support
    examples. Queries are scored by their *negative* distance to every
    prototype, so the output can be passed directly to ``F.cross_entropy``
    as logits.

    Args:
        embedding_dim: Size of the input feature vectors; the linear layer
            maps ``embedding_dim -> embedding_dim``.
        squared_distance: If ``True``, score with the negative *squared*
            Euclidean distance (the metric used by Snell et al., 2017)
            instead of the plain Euclidean distance. Defaults to ``False``
            to preserve the original behaviour.
    """

    def __init__(self, embedding_dim, *, squared_distance=False):
        super().__init__()
        self.fc = nn.Linear(embedding_dim, embedding_dim)
        self.squared_distance = squared_distance

    def forward(self, support_set, query_set):
        """Return logits of shape ``(..., n_query, n_way)``.

        Args:
            support_set: Tensor of shape ``(..., n_way, k_shot, embedding_dim)``.
                Optional leading dimensions let several episodes be processed
                at once.
            query_set: Tensor of shape ``(..., n_query, embedding_dim)`` with
                the same leading dimensions as ``support_set``.
        """
        support_embeddings = self.fc(support_set)
        query_embeddings = self.fc(query_set)
        # Average over the shot axis (second-to-last) instead of a fixed
        # axis 1, so a prepended episode dimension is not mistaken for it.
        prototypes = support_embeddings.mean(dim=-2)
        if self.squared_distance:
            # Compute squared distances directly rather than squaring the
            # output of cdist (which takes a square root internally).
            diff = query_embeddings.unsqueeze(-2) - prototypes.unsqueeze(-3)
            distances = diff.pow(2).sum(dim=-1)
        else:
            distances = torch.cdist(query_embeddings, prototypes)
        return -distances

def train_prototypical_network(model, support_set, query_set, labels, optimizer):
    """Perform one optimisation step of ``model`` on a single episode.

    The model is put into training mode and any previously accumulated
    gradients are cleared. The model's logits for ``query_set`` (conditioned
    on ``support_set``) are then scored against ``labels`` with
    cross-entropy, and ``optimizer`` applies one update.

    Returns:
        The episode's cross-entropy loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(support_set, query_set)
    episode_loss = F.cross_entropy(logits, labels)

    episode_loss.backward()
    optimizer.step()

    loss_value = episode_loss.item()
    return loss_value

# Example usage: one training step on a synthetic 5-way, 5-shot episode.
# The guard keeps this from running when the file is imported as a module;
# inside a notebook __name__ is "__main__", so the cell still runs and the
# names below remain module-level globals.
if __name__ == "__main__":
    # Seed the RNG so the random data (and the printed loss) is reproducible.
    torch.manual_seed(0)

    n_way, k_shot, n_query_per_class, feature_dim = 5, 5, 5, 64

    model = PrototypicalNetwork(embedding_dim=feature_dim)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # (n_way, k_shot, feature_dim): k_shot raw feature vectors per class,
    # embedded by the model's linear layer inside forward().
    support_set = torch.randn(n_way, k_shot, feature_dim)
    # (n_way * n_query_per_class, feature_dim) query feature vectors.
    query_set = torch.randn(n_way * n_query_per_class, feature_dim)
    # Balanced episode labels: n_query_per_class queries for each class
    # (0,0,0,0,0,1,1,...), unlike randint which can leave a class empty.
    labels = torch.arange(n_way).repeat_interleave(n_query_per_class)

    loss = train_prototypical_network(model, support_set, query_set, labels, optimizer)
    print(f'Loss: {loss:.4f}')