<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Prototypical_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # Add this import
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Define the Prototypical Network model
class ProtoNet(nn.Module):
    """Two-layer MLP embedding network for a Prototypical Network.

    Maps inputs whose last dimension is ``input_dim`` to embeddings of size
    ``hidden_dim`` via Linear -> ReLU -> Linear (no activation on the output).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Attribute names fc1/fc2 determine the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Return the embedding of ``x`` (last dim ``input_dim`` -> ``hidden_dim``)."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# Sample few-shot task data (e.g., 5 classes, 5 support examples per class, 15 query examples per class)
def generate_data(num_classes, num_support, num_query, input_dim):
    """Return random features and class-grouped labels for one synthetic episode.

    Features are drawn uniformly from [0, 1) with ``np.random.rand`` (global
    NumPy RNG) and have shape ``(num_classes * (num_support + num_query),
    input_dim)``. Labels are laid out class by class: ``num_support + num_query``
    copies of 0, then of 1, and so on up to ``num_classes - 1``.
    """
    samples_per_class = num_support + num_query
    features = np.random.rand(num_classes * samples_per_class, input_dim)
    targets = np.array(
        [cls for cls in range(num_classes) for _ in range(samples_per_class)]
    )
    return features, targets

# Prepare the data: a single synthetic episode of
# num_classes * (num_support + num_query) random samples.
input_dim = 64      # feature size of each sample
num_classes = 5     # classes per episode
num_support = 5     # support examples per class
num_query = 15      # query examples per class
hidden_dim = 128    # width of both ProtoNet layers

data, labels = generate_data(num_classes, num_support, num_query, input_dim)
dataset = TensorDataset(
    torch.tensor(data, dtype=torch.float32),
    torch.tensor(labels, dtype=torch.long),
)
# The batch size equals the episode size, so each pass yields one batch holding
# the whole episode. shuffle=True permutes its rows on every pass, so consumers
# must not assume the rows are grouped by class.
loader = DataLoader(
    dataset,
    batch_size=num_classes * (num_support + num_query),
    shuffle=True,
)

# Instantiate the model
model = ProtoNet(input_dim, hidden_dim)

# Define the training loop
def _split_episode(data, labels, num_classes, num_support, num_query):
    """Split one episode batch into class-ordered support and query sets.

    Rows are grouped by label rather than by position, so the split is correct
    even when the loader shuffles the episode. Within each class, the first
    ``num_support`` rows (in batch order) form the support set and the next
    ``num_query`` rows form the query set; any further rows are ignored.

    Returns:
        ``(support_set, query_set, query_targets)``. ``support_set`` has
        ``num_classes * num_support`` rows ordered class by class (sorted label
        order). ``query_set`` has ``num_classes * num_query`` rows in the same
        class order. ``query_targets`` holds each query's class index in
        ``0 .. num_classes - 1``, i.e. the row of its prototype.

    Raises:
        ValueError: if the batch does not contain exactly ``num_classes``
            distinct labels, or some class has fewer than
            ``num_support + num_query`` rows.
    """
    classes = torch.unique(labels)  # sorted ascending by default
    if classes.numel() != num_classes:
        raise ValueError(
            f"expected {num_classes} classes in the episode, got {classes.numel()}"
        )
    per_class = num_support + num_query
    support_parts, query_parts, target_parts = [], [], []
    for class_index, cls in enumerate(classes):
        rows = torch.nonzero(labels == cls, as_tuple=True)[0]
        if rows.numel() < per_class:
            raise ValueError(
                f"class {cls.item()} has {rows.numel()} samples, need {per_class} "
                f"({num_support} support + {num_query} query)"
            )
        support_parts.append(data[rows[:num_support]])
        query_parts.append(data[rows[num_support:per_class]])
        # Targets are prototype indices, not raw label values, so they line up
        # with the rows of the prototype matrix built from support_parts.
        target_parts.append(
            torch.full((num_query,), class_index, dtype=torch.long, device=labels.device)
        )
    return torch.cat(support_parts), torch.cat(query_parts), torch.cat(target_parts)


def train_protonet(model, loader, num_classes, num_support, num_query, num_epochs, learning_rate):
    """Train ``model`` as a Prototypical Network with episodic cross-entropy.

    Each ``(data, labels)`` batch yielded by ``loader`` is treated as one
    episode. It is split per class into support and query rows. Each class
    prototype is the mean support embedding of that class. Queries are
    classified by softmax over the negative Euclidean distance
    (``torch.cdist``) to the prototypes. Parameters are updated with Adam.

    Args:
        model: module mapping a ``(N, features)`` batch to ``(N, embed_dim)``
            embeddings.
        loader: iterable of ``(data, labels)`` batches. Each batch must hold
            exactly ``num_classes`` distinct labels, each with at least
            ``num_support + num_query`` rows; row order does not matter.
        num_classes: number of classes per episode.
        num_support: support examples per class used to build prototypes.
        num_query: query examples per class used to compute the loss.
        num_epochs: number of passes over ``loader``.
        learning_rate: Adam learning rate.

    Returns:
        list[float]: the loss of every episode, in training order. Callers may
        ignore it.

    Raises:
        ValueError: if an episode violates the class/count requirements above.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    losses = []
    for epoch in range(num_epochs):
        for data, labels in loader:
            optimizer.zero_grad()

            # Split by label (not by row position) so shuffled batches work.
            support_set, query_set, query_labels = _split_episode(
                data, labels, num_classes, num_support, num_query
            )

            # Compute class prototypes: support rows are grouped class by class.
            prototypes = model(support_set).view(num_classes, num_support, -1).mean(1)

            # Compute distances from query set to prototypes
            query_embeddings = model(query_set)
            distances = torch.cdist(query_embeddings, prototypes)

            # Closer prototype => larger logit.
            loss = F.cross_entropy(-distances, query_labels)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            print(f"Epoch {epoch+1}, Loss: {loss.item()}")
    return losses

# Run training: 10 passes over the episode loader with Adam at lr=1e-3.
train_protonet(
    model,
    loader,
    num_classes=num_classes,
    num_support=num_support,
    num_query=num_query,
    num_epochs=10,
    learning_rate=0.001,
)