<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Prototypical_Networks_for_Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets

# Custom dataset class for few-shot learning
class FewShotDataset(Dataset):
    """Expose a chosen subset of another dataset as its own ``Dataset``.

    Position ``idx`` of this dataset maps to ``dataset[indices[idx]]``; the
    wrapped dataset is expected to yield ``(image, label)`` pairs.
    """

    def __init__(self, dataset, indices):
        # Keep references only; nothing is copied or loaded eagerly.
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        source_position = self.indices[idx]
        sample = self.dataset[source_position]
        # Unpack so a non-pair sample fails loudly and a pair comes back as a tuple.
        image, label = sample
        return (image, label)

# Define the Prototypical Network
class PrototypicalNetwork(nn.Module):
    """Convolutional embedding network used to score queries against prototypes.

    The encoder expects single-channel images (first ``Conv2d`` takes 1 input
    channel) of size 28x28: the two 2x max-pools reduce 28 -> 14 -> 7, which is
    what the ``Linear(128 * 7 * 7, 256)`` projection is sized for. Each image is
    mapped to a 256-dimensional embedding passed through a final ReLU.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two conv -> ReLU -> 2x max-pool stages.
        feature_layers = [
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        # Projection head: flatten the 128x7x7 feature map into a 256-d embedding.
        head_layers = [
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 256),
            nn.ReLU(),
        ]
        # A single Sequential keeps the parameter names (encoder.0, .3, .7).
        self.encoder = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, support_set, query):
        """Embed the support images and the query images with the shared encoder.

        Returns:
            ``(support_embeddings, query_embeddings)``.
        """
        embed = self.encoder
        return embed(support_set), embed(query)

def calculate_prototypes(support_embeddings, support_labels, num_classes=10):
    """Return one prototype (the mean support embedding) per class.

    Args:
        support_embeddings: 2-D tensor ``(n_support, embed_dim)``.
        support_labels: integer tensor ``(n_support,)`` of class ids.
        num_classes: number of prototype rows to produce.

    Returns:
        Tensor ``(num_classes, embed_dim)`` with the same dtype and device as
        ``support_embeddings``. Row ``c`` is the mean of the embeddings whose
        label is ``c``. A class with no support example gets an all-zero row.
        Labels outside ``range(num_classes)`` are ignored. Gradients flow back
        to ``support_embeddings``.
    """
    labels = support_labels.to(support_embeddings.device).long()
    # Labels outside range(num_classes) have no prototype row; ignore them.
    in_range = (labels >= 0) & (labels < num_classes)
    labels = labels[in_range]
    embeddings = support_embeddings[in_range]

    # new_zeros inherits dtype/device from the embeddings, so no float32 cast
    # and no CPU allocation followed by a device copy.
    sums = support_embeddings.new_zeros(num_classes, support_embeddings.size(1))
    sums = sums.index_add(0, labels, embeddings)
    # Per-class counts; clamp avoids 0/0 so empty classes stay at zero.
    counts = torch.bincount(labels, minlength=num_classes).clamp(min=1)
    return sums / counts.unsqueeze(1).to(sums.dtype)

def pairwise_distances(x, y):
    """Return squared Euclidean distances between every row of ``x`` and of ``y``.

    Args:
        x: tensor ``(n, d)``.
        y: tensor ``(m, d)``.

    Returns:
        Tensor ``(n, m)`` whose entry ``[i, j]`` is ``sum((x[i] - y[j]) ** 2)``.
    """
    grid = (x.size(0), y.size(0), x.size(1))
    # Lay both operands out on a common (n, m, d) grid before subtracting.
    rows = x[:, None].expand(grid)
    cols = y[None].expand(grid)
    squared = (rows - cols).pow(2)
    return squared.sum(dim=2)

# Training function
def train_prototypical_network(model, support_loader, query_loader, criterion, optimizer, device, num_classes=10):
    """Run one pass of episodic training over paired support/query batches.

    Each step pairs one support batch with one query batch (via ``zip``). It
    builds one prototype per class from the support embeddings, scores every
    query by negative squared distance to each prototype, and takes an
    optimizer step on ``criterion(logits, labels)``.

    The two loaders are sampled independently, so a query batch can contain
    classes that have no example in its support batch. ``calculate_prototypes``
    leaves an all-zero row for such classes. Their logits are therefore masked
    to ``-inf``, and their queries are left out of the loss, instead of being
    pulled toward that zero placeholder.

    Args:
        model: module whose ``forward(support, query)`` returns
            ``(support_embeddings, query_embeddings)``.
        support_loader: iterable of ``(images, labels)`` support batches.
        query_loader: iterable of ``(images, labels)`` query batches. Training
            stops when the shorter of the two loaders is exhausted.
        criterion: loss applied to ``(logits, labels)``. It must tolerate
            ``-inf`` logits for non-target classes (``nn.CrossEntropyLoss`` does).
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch tensor is moved to.
        num_classes: number of prototypes; labels are expected in ``range(num_classes)``.

    Returns:
        Mean loss over the optimized steps, or ``nan`` if no step had a query
        whose class appeared in its support batch.
    """
    model.train()
    class_ids = torch.arange(num_classes, device=device)
    total_loss = 0.0
    num_steps = 0
    for (support_images, support_labels), (query_images, query_labels) in zip(support_loader, query_loader):
        support_images = support_images.to(device)
        support_labels = support_labels.to(device)
        query_images = query_images.to(device)
        query_labels = query_labels.to(device)

        support_embeddings, query_embeddings = model(support_images, query_images)
        prototypes = calculate_prototypes(support_embeddings, support_labels, num_classes)
        # Closer prototype -> larger logit.
        logits = -pairwise_distances(query_embeddings, prototypes)

        # Classes with at least one support example in this batch.
        present = (support_labels.unsqueeze(1) == class_ids).any(dim=0)
        # Queries whose class has a real prototype; the rest cannot be scored.
        scorable = (query_labels.unsqueeze(1) == class_ids[present]).any(dim=1)
        if not scorable.any():
            continue
        logits = logits.masked_fill(~present, float('-inf'))

        loss = criterion(logits[scorable], query_labels[scorable])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_steps += 1

    return total_loss / num_steps if num_steps else float('nan')

# Evaluation function
def evaluate_prototypical_network(model, support_loader, query_loader, device, num_classes=10):
    """Measure nearest-prototype accuracy over paired support/query batches.

    For each ``zip``-ped pair of batches, prototypes are built from the support
    batch, and each query is assigned to the class with the nearest prototype
    (squared Euclidean distance). Classes with no support example in the batch
    are masked out, so they cannot be predicted. Queries belonging to such
    classes are excluded from the count, because no prototype exists for their
    true class.

    Args:
        model: module whose ``forward(support, query)`` returns
            ``(support_embeddings, query_embeddings)``.
        support_loader: iterable of ``(images, labels)`` support batches.
        query_loader: iterable of ``(images, labels)`` query batches.
        device: device every batch tensor is moved to.
        num_classes: number of prototypes; labels are expected in ``range(num_classes)``.

    Returns:
        Fraction of scorable queries classified correctly, or ``nan`` when no
        query could be scored (e.g. empty loaders).
    """
    model.eval()
    class_ids = torch.arange(num_classes, device=device)
    correct = 0
    total = 0
    with torch.no_grad():
        for (support_images, support_labels), (query_images, query_labels) in zip(support_loader, query_loader):
            support_images = support_images.to(device)
            support_labels = support_labels.to(device)
            query_images = query_images.to(device)
            query_labels = query_labels.to(device)

            support_embeddings, query_embeddings = model(support_images, query_images)
            prototypes = calculate_prototypes(support_embeddings, support_labels, num_classes)
            logits = -pairwise_distances(query_embeddings, prototypes)

            # Classes absent from this support batch only have a zero placeholder
            # prototype; never predict them.
            present = (support_labels.unsqueeze(1) == class_ids).any(dim=0)
            logits = logits.masked_fill(~present, float('-inf'))
            scorable = (query_labels.unsqueeze(1) == class_ids[present]).any(dim=1)

            predicted = logits.argmax(dim=1)
            correct += (predicted[scorable] == query_labels[scorable]).sum().item()
            total += int(scorable.sum().item())

    if total == 0:
        return float('nan')
    return correct / total

# Load and preprocess data
transform = transforms.Compose([
    transforms.Grayscale(),  # single channel, matching the encoder's Conv2d(1, ...)
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # rescale [0, 1] pixels to [-1, 1]
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
support_indices = list(range(600))  # Example support set indices
query_indices = list(range(600, 1000))  # Example query set indices

support_set = FewShotDataset(mnist_dataset, support_indices)
query_set = FewShotDataset(mnist_dataset, query_indices)

support_loader = DataLoader(support_set, batch_size=32, shuffle=True)
query_loader = DataLoader(query_set, batch_size=32, shuffle=True)

# Held-out evaluation episodes from the MNIST test split. Evaluating on the
# training loaders would score the very queries the optimizer was fit on.
mnist_test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
eval_support_set = FewShotDataset(mnist_test_dataset, list(range(600)))
eval_query_set = FewShotDataset(mnist_test_dataset, list(range(600, 1000)))
eval_support_loader = DataLoader(eval_support_set, batch_size=32, shuffle=False)
eval_query_loader = DataLoader(eval_query_set, batch_size=32, shuffle=False)

# Define model, criterion, optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PrototypicalNetwork().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
train_prototypical_network(model, support_loader, query_loader, criterion, optimizer, device)

# Evaluate the model on held-out data
accuracy = evaluate_prototypical_network(model, eval_support_loader, eval_query_loader, device)
print(f'Accuracy: {accuracy:.4f}')