<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the PrototypicalNetwork class
class PrototypicalNetwork(nn.Module):
    """Convolutional embedding network for prototypical few-shot learning.

    Two ``Conv2d(kernel_size=3) -> ReLU -> MaxPool2d(2)`` stages are followed
    by a linear projection to ``embedding_dim`` features.

    Args:
        embedding_dim: Size of the output embedding vector.
        in_channels: Number of channels in the input images (default 1).
        image_size: Height/width of the square input images (default 28).
            The linear layer's input size is derived from it, so inputs of a
            different spatial size will not fit ``self.fc``.

    Raises:
        ValueError: If ``image_size`` is too small to survive both
            convolution/pooling stages.
    """

    def __init__(self, embedding_dim, in_channels=1, image_size=28):
        super().__init__()
        spatial = self._feature_map_size(image_size)
        if spatial < 1:
            raise ValueError(
                f"image_size={image_size} is too small for two "
                "conv(3x3) + maxpool(2) stages"
            )
        self.embedding = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3),  # First convolutional layer
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3),  # Second convolutional layer
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 128 channels out of the last conv, spatial x spatial feature map.
        self.fc = nn.Linear(128 * spatial * spatial, embedding_dim)

    @staticmethod
    def _feature_map_size(size):
        """Return the side length of the feature map after both stages."""
        for _ in range(2):
            # An unpadded 3x3 conv shrinks the side by 2; MaxPool2d(2) halves
            # it, flooring odd sizes (ceil_mode=False).
            size = (size - 2) // 2
        return size

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, embedding_dim)``.
        """
        x = self.embedding(x)
        x = torch.flatten(x, start_dim=1)  # Keep the batch dim, flatten the rest
        return self.fc(x)

# Define the function to compute prototypes
def compute_prototypes(embeddings, labels, num_classes):
    """Return the mean embedding (prototype) of each class.

    Args:
        embeddings: 2-D tensor of shape ``(num_samples, dim)``.
        labels: 1-D tensor of ``num_samples`` integer class ids, used as a
            boolean mask (``labels == c``) over the rows of ``embeddings``.
            Ids outside ``[0, num_classes)`` are ignored.
        num_classes: Number of prototypes to produce (classes
            ``0 .. num_classes - 1``).

    Returns:
        Tensor of shape ``(num_classes, dim)`` with the same dtype and device
        as ``embeddings``. Row ``c`` is the mean of the embeddings labelled
        ``c``. A class with no samples yields a row of NaN. Gradients flow
        back to ``embeddings``.
    """
    if num_classes == 0:
        # torch.stack rejects an empty list; return an empty prototype matrix.
        return embeddings.new_zeros((0, embeddings.size(1)))
    # Build every prototype from `embeddings` itself so the result inherits
    # its dtype and device (torch.zeros would default to float32 on the CPU).
    return torch.stack(
        [embeddings[labels == c].mean(dim=0) for c in range(num_classes)]
    )

# Example usage: a batch of 32 single-channel 28x28 images over 5 classes.
NUM_CLASSES = 5
BATCH_SIZE = 32
model = PrototypicalNetwork(embedding_dim=64)
input_data = torch.randn(BATCH_SIZE, 1, 28, 28)  # (batch, channels, height, width)
# Cycle through the class ids so every class has samples; random labels
# (torch.randint) can leave a class empty, giving an all-NaN prototype.
labels = torch.arange(BATCH_SIZE) % NUM_CLASSES
embeddings = model(input_data)  # Compute embeddings: (BATCH_SIZE, 64)
prototypes = compute_prototypes(embeddings, labels, num_classes=NUM_CLASSES)

# Print the shape of the prototypes
print(prototypes.shape)  # Expected: torch.Size([5, 64])