<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the SimpleFewShotModel class
class SimpleFewShotModel(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of the last dimension of the input features.
        output_dim: Number of output logits (classes).
        hidden_dim: Width of the hidden layer. Defaults to 128, the width
            this model originally hard-coded.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)  # input -> hidden
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # hidden -> logits

    def forward(self, x):
        """Return unnormalised logits for ``x`` (no softmax is applied)."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Define the train_few_shot function
def train_few_shot(model, support_data, support_labels, query_data, query_labels, inner_steps=1, lr=0.01):
    """Adapt ``model`` on a support set, then report its loss on a query set.

    The model's parameters are updated in place with plain SGD for
    ``inner_steps`` steps on the support set. No copy is made, so the caller's
    model keeps the adapted weights afterwards.

    Args:
        model: ``nn.Module`` producing class logits.
        support_data: Inputs used for adaptation.
        support_labels: Targets for ``support_data`` in a form accepted by
            ``nn.CrossEntropyLoss`` (e.g. integer class indices).
        query_data: Inputs used for evaluation after adaptation.
        query_labels: Targets for ``query_data``.
        inner_steps: Number of SGD steps on the support set (0 skips
            adaptation entirely).
        lr: SGD learning rate.

    Returns:
        The cross-entropy loss on the query set (default ``mean`` reduction)
        as a Python float.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Adaptation runs in whatever train/eval mode the caller left the model in.
    for _ in range(inner_steps):
        optimizer.zero_grad()
        loss = criterion(model(support_data), support_labels)
        loss.backward()
        optimizer.step()

    # Evaluate the query set in eval mode. This keeps the reported loss
    # deterministic (dropout off) and stops query batches from updating
    # batch-norm running statistics. Afterwards, restore the caller's mode so
    # this function has no lasting mode side effect.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            query_loss = criterion(model(query_data), query_labels)
    finally:
        model.train(was_training)
    return query_loss.item()

# Example usage
# Demo: build a fresh 5-feature / 2-class model and run one adaptation pass on a
# random task. The tensors below are created in the same order as before
# (model init, support inputs, support labels, query inputs, query labels), so
# the global RNG stream is consumed identically.
model = SimpleFewShotModel(input_dim=5, output_dim=2)
support_data, support_labels = torch.randn(10, 5), torch.randint(0, 2, (10,))  # 10 support examples
query_data, query_labels = torch.randn(5, 5), torch.randint(0, 2, (5,))  # 5 query examples

loss = train_few_shot(
    model,
    support_data,
    support_labels,
    query_data,
    query_labels,
)
print("Few-shot learning query loss:", loss)