<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class PrototypicalNetwork(nn.Module):
    """Prototypical network for few-shot classification.

    A two-layer MLP maps inputs into an embedding space. Each class is
    represented by the mean embedding (its "prototype") of its support
    examples. A query is assigned to the class whose prototype is nearest in
    Euclidean distance.

    Args:
        input_dim: Size of the last dimension of the input feature vectors.
        n_classes: Number of classes per episode; support labels are expected
            to be integers in ``range(n_classes)``.
        embedding_dim: Size of the output embedding.
        hidden_dim: Width of the hidden layer of the embedding MLP
            (defaults to 128, the previously hard-coded value).
    """

    def __init__(self, input_dim, n_classes, embedding_dim, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )
        self.n_classes = n_classes

    def forward(self, x):
        """Embed ``x`` of shape (..., input_dim) into (..., embedding_dim)."""
        return self.embedding(x)

    @torch.no_grad()  # inference only: argmin is non-differentiable, so skip the autograd graph
    def predict(self, query, support, support_labels):
        """Classify query samples by their nearest class prototype.

        Args:
            query: (num_query_samples, input_dim)
            support: (num_support_samples, input_dim)
            support_labels: (num_support_samples,) integer class labels.

        Returns:
            predictions: (num_query_samples,) predicted class indices in
            ``range(n_classes)``.

        Raises:
            ValueError: if some class in ``range(n_classes)`` has no support
                example (its prototype would otherwise be NaN and corrupt
                every prediction).
        """
        # Go through forward() so any override or hook on it applies here too.
        query_embedded = self(query)  # (num_query_samples, embedding_dim)
        support_embedded = self(support)  # (num_support_samples, embedding_dim)

        prototypes = self._class_prototypes(support_embedded, support_labels)  # (n_classes, embedding_dim)

        # Euclidean distances; argmin is the same as for squared distances.
        dists = torch.cdist(query_embedded, prototypes)  # (num_query_samples, n_classes)
        return torch.argmin(dists, dim=1)

    def _class_prototypes(self, support_embedded, support_labels):
        """Return the (n_classes, embedding_dim) stack of per-class mean embeddings.

        Raises:
            ValueError: if any class in ``range(n_classes)`` has no support sample.
        """
        prototypes = []
        missing = []
        for c in range(self.n_classes):
            members = support_embedded[support_labels == c]
            if members.shape[0] == 0:
                # mean() over zero rows would yield a NaN prototype.
                missing.append(c)
            else:
                prototypes.append(members.mean(dim=0))
        if missing:
            raise ValueError(
                f"no support examples for class(es) {missing}; every class in "
                f"range({self.n_classes}) needs at least one labelled support sample"
            )
        return torch.stack(prototypes)

# Example usage: one N_WAY-way, K_SHOT-shot episode on random features.
# The model is untrained, so the predicted classes are arbitrary; this only
# demonstrates the predict() API and the expected tensor shapes.
torch.manual_seed(0)  # make weights, data and printed output reproducible

N_WAY = 5  # classes per episode (model's n_classes)
K_SHOT = 5  # support samples per class
N_QUERY = 10  # query samples to classify
INPUT_DIM = 784

model = PrototypicalNetwork(input_dim=INPUT_DIM, n_classes=N_WAY, embedding_dim=64)

# Support set: K_SHOT samples per class, labels grouped by class,
# i.e. [0]*K_SHOT + [1]*K_SHOT + ... + [N_WAY-1]*K_SHOT.
support_set = torch.randn(N_WAY * K_SHOT, INPUT_DIM)
support_labels = torch.arange(N_WAY).repeat_interleave(K_SHOT)

# Query set
query_set = torch.randn(N_QUERY, INPUT_DIM)

# Predict classes for the query set
predictions = model.predict(query_set, support_set, support_labels)
print("Predicted Classes:", predictions)