<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

# Define the Prototypical Network
class PrototypicalNetwork(nn.Module):
    """Two-layer MLP encoder that maps flat inputs to embedding vectors.

    The encoder is ``Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, output_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Build the encoder with the given layer widths."""
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the encoder output for ``x``."""
        embedding = self.encoder(x)
        return embedding

# Sample a few-shot task (support/query indices) from a fixed list of class ids
def sample_task(dataset, n_way, n_shot, n_query):
    """Sample support/query example indices for one few-shot episode.

    For each of the first ``n_way`` class ids in ``[0, 1, 2, 3]``, up to
    ``n_shot + n_query`` positions whose target equals that id are drawn at
    random without replacement. The first ``n_shot`` of them become the
    support indices and the remainder the query indices.

    Args:
        dataset: Object exposing ``targets``, a tensor or sequence of integer
            labels (one per example).
        n_way: Number of classes to include. Values above 4 are capped at the
            four available class ids.
        n_shot: Support examples per class.
        n_query: Query examples per class.

    Returns:
        ``(support_set, query_set, classes)``. The first two are lists of 1-D
        index tensors, one per class; ``classes`` lists the class ids in the
        same order. A class with fewer than ``n_shot + n_query`` examples
        yields correspondingly shorter index tensors.
    """
    class_ids = [0, 1, 2, 3][:n_way]
    # as_tensor accepts both tensors (no copy, no warning) and plain lists.
    targets = torch.as_tensor(dataset.targets)
    per_class = n_shot + n_query

    support_set = []
    query_set = []
    for class_id in class_ids:
        # flatten() keeps the result 1-D even when exactly one example
        # matches; squeeze() would collapse it to a 0-d tensor and break len().
        indices = torch.nonzero(targets == class_id).flatten()
        indices = indices[torch.randperm(len(indices))[:per_class]]
        support_set.append(indices[:n_shot])
        query_set.append(indices[n_shot:])
    return support_set, query_set, class_ids

# Compute prototypes and loss
def compute_prototypes_and_loss(model, dataset, support_set, query_set, classes, output_dim):
    """Run one prototype-classification episode and return loss and predictions.

    Each class prototype is the mean embedding of that class's support
    examples. Query examples are embedded, their ``torch.cdist`` distances to
    every prototype are computed, and the negated distances are used as logits
    for a cross-entropy loss.

    Args:
        model: Module mapping a ``(batch, 784)`` float tensor to embeddings.
            It is moved to CUDA when available, otherwise to CPU.
        dataset: Object exposing ``data``, indexable by an index tensor and
            viewable as rows of ``28 * 28`` values in the 0-255 range.
        support_set: List of index tensors, one per class, used for prototypes.
        query_set: List of index tensors, one per class, to classify.
        classes: Class ids aligned with ``query_set``; only as many query
            groups as there are entries here are consumed.
        output_dim: Unused; accepted for interface compatibility.

    Returns:
        ``(loss, predicted_labels, query_labels)``. Both label tensors are
        episode-relative: label ``i`` refers to the ``i``-th prototype. When
        ``classes`` is ``[0, 1, ..., n_way - 1]`` these coincide with the
        class ids.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    def embed(indices):
        # Flatten the selected images to 784-vectors and rescale by 1/255.
        batch = dataset.data[indices].view(-1, 28 * 28).float().to(device) / 255.0
        return model(batch)

    # One prototype per support group: the mean of its embeddings.
    prototypes = torch.stack([embed(indices).mean(dim=0) for indices in support_set])

    all_query_embeddings = []
    all_query_labels = []
    for position, (indices, _class_id) in enumerate(zip(query_set, classes)):
        embeddings = embed(indices)
        all_query_embeddings.append(embeddings)
        # The label is the prototype's row position, not the raw class id:
        # cross-entropy indexes the logit columns, which are ordered like the
        # prototypes. Sizing from the embeddings also tolerates 0-d indices.
        all_query_labels.append(
            torch.full((embeddings.shape[0],), position, dtype=torch.long, device=device)
        )
    query_embeddings = torch.cat(all_query_embeddings)
    query_labels = torch.cat(all_query_labels)

    # Nearest prototype wins.
    distances = torch.cdist(query_embeddings, prototypes)
    predicted_labels = torch.argmin(distances, dim=1)

    # Smaller distance -> larger logit.
    criterion = nn.CrossEntropyLoss()
    loss = criterion(-distances, query_labels)
    return loss, predicted_labels, query_labels

# Example usage: sample one 4-way, 5-shot episode (15 queries per class) from the MNIST training split
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
support_set, query_set, classes = sample_task(dataset, n_way=4, n_shot=5, n_query=15)

# Encoder (784 -> 256 -> 128) and its Adam optimizer
model = PrototypicalNetwork(input_dim=784, hidden_dim=256, output_dim=128)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Training loop: ten gradient steps on the same sampled episode
for epoch in range(10):
    optimizer.zero_grad()
    loss, predictions, query_labels = compute_prototypes_and_loss(
        model=model,
        dataset=dataset,
        support_set=support_set,
        query_set=query_set,
        classes=classes,
        output_dim=128,
    )
    loss.backward()
    optimizer.step()
    # Fraction of query examples whose nearest prototype matches their label
    accuracy = predictions.eq(query_labels).float().mean().item()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}, Accuracy: {accuracy:.4f}")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

# Define the Prototypical Network
class PrototypicalNetwork(nn.Module):
    """MLP encoder: ``Linear -> ReLU -> Linear``, producing embedding vectors."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the two linear layers and chain them into ``self.encoder``."""
        super().__init__()
        input_layer = nn.Linear(input_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, output_dim)
        self.encoder = nn.Sequential(input_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Apply the encoder to ``x`` and return the result."""
        encoded = self.encoder(x)
        return encoded

# Sample a few-shot learning task
def sample_task(dataset, n_way, n_shot, n_query):
    """Sample support/query example indices for one few-shot episode.

    For each of the first ``n_way`` class ids in ``[0, 1, 2, 3]``, up to
    ``n_shot + n_query`` positions whose target equals that id are drawn at
    random without replacement. The first ``n_shot`` of them become the
    support indices and the remainder the query indices.

    Args:
        dataset: Object exposing ``targets``, a tensor or sequence of integer
            labels (one per example).
        n_way: Number of classes to include. Values above 4 are capped at the
            four available class ids.
        n_shot: Support examples per class.
        n_query: Query examples per class.

    Returns:
        ``(support_set, query_set, classes)``. The first two are lists of 1-D
        index tensors, one per class; ``classes`` lists the class ids in the
        same order. A class with fewer than ``n_shot + n_query`` examples
        yields correspondingly shorter index tensors.
    """
    class_ids = [0, 1, 2, 3][:n_way]
    # as_tensor leaves tensors untouched and also accepts plain label lists,
    # for which a bare ``targets == c`` would not produce a mask.
    targets = torch.as_tensor(dataset.targets)
    per_class = n_shot + n_query

    support_set = []
    query_set = []
    for class_id in class_ids:
        # flatten() keeps the result 1-D even when exactly one example
        # matches; squeeze() would collapse it to a 0-d tensor and break len().
        indices = torch.nonzero(targets == class_id).flatten()
        indices = indices[torch.randperm(len(indices))[:per_class]]
        support_set.append(indices[:n_shot])
        query_set.append(indices[n_shot:])
    return support_set, query_set, class_ids

# Compute prototypes and classify
def compute_prototypes_and_loss(model, dataset, support_set, query_set, classes, output_dim):
    """Run one prototype-classification episode and return loss and predictions.

    Each class prototype is the mean embedding of that class's support
    examples. Query examples are embedded, their ``torch.cdist`` distances to
    every prototype are computed, and the negated distances are used as logits
    for a cross-entropy loss.

    Args:
        model: Module mapping a ``(batch, 784)`` float tensor to embeddings.
            It is moved to CUDA when available, otherwise to CPU.
        dataset: Object exposing ``data``, indexable by an index tensor and
            viewable as rows of ``28 * 28`` values in the 0-255 range.
        support_set: List of index tensors, one per class, used for prototypes.
        query_set: List of index tensors, one per class, to classify.
        classes: Class ids aligned with ``query_set``; only as many query
            groups as there are entries here are consumed.
        output_dim: Unused; accepted for interface compatibility.

    Returns:
        ``(loss, predicted_labels, query_labels)``. Both label tensors are
        episode-relative: label ``i`` refers to the ``i``-th prototype. When
        ``classes`` is ``[0, 1, ..., n_way - 1]`` these coincide with the
        class ids.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    def embed(indices):
        # Flatten the selected images to 784-vectors and rescale by 1/255.
        batch = dataset.data[indices].view(-1, 28 * 28).float().to(device) / 255.0
        return model(batch)

    # One prototype per support group: the mean of its embeddings.
    prototypes = torch.stack([embed(indices).mean(dim=0) for indices in support_set])

    all_query_embeddings = []
    all_query_labels = []
    for position, (indices, _class_id) in enumerate(zip(query_set, classes)):
        embeddings = embed(indices)
        all_query_embeddings.append(embeddings)
        # The label is the prototype's row position, not the raw class id:
        # cross-entropy indexes the logit columns, which are ordered like the
        # prototypes. Sizing from the embeddings also tolerates 0-d indices.
        all_query_labels.append(
            torch.full((embeddings.shape[0],), position, dtype=torch.long, device=device)
        )
    query_embeddings = torch.cat(all_query_embeddings)
    query_labels = torch.cat(all_query_labels)

    # Nearest prototype wins.
    distances = torch.cdist(query_embeddings, prototypes)
    predicted_labels = torch.argmin(distances, dim=1)

    # Smaller distance -> larger logit.
    criterion = nn.CrossEntropyLoss()
    loss = criterion(-distances, query_labels)
    return loss, predicted_labels, query_labels

# Example usage: draw a single 4-way episode (5 support + 15 query examples per class) from MNIST train
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
support_set, query_set, classes = sample_task(dataset, n_way=4, n_shot=5, n_query=15)

# Build the 784 -> 256 -> 128 encoder and an Adam optimizer over its parameters
model = PrototypicalNetwork(input_dim=784, hidden_dim=256, output_dim=128)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Training loop: repeat ten optimization steps on the one sampled episode
for epoch in range(10):
    optimizer.zero_grad()
    loss, predictions, query_labels = compute_prototypes_and_loss(
        model=model,
        dataset=dataset,
        support_set=support_set,
        query_set=query_set,
        classes=classes,
        output_dim=128,
    )
    loss.backward()
    optimizer.step()
    # Share of query examples assigned to the prototype matching their label
    accuracy = predictions.eq(query_labels).float().mean().item()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}, Accuracy: {accuracy:.4f}")