# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Define the embedding network
class EmbeddingNet(nn.Module):
    """Convolutional encoder mapping a batch of images to flat embedding vectors.

    Architecture: two conv blocks (3x3 conv -> BatchNorm -> ReLU -> 2x2
    max-pool) followed by a two-layer MLP head. Each conv keeps the spatial
    size (kernel 3, stride 1, padding 1) and each pool floor-halves it, so an
    ``H x W`` input reaches the head as ``hidden_channels x (H // 4) x (W // 4)``.

    Args:
        in_channels: Number of channels in the input images.
        input_size: ``(H, W)`` of the input images; used to size the first
            linear layer. Must be at least ``(4, 4)``.
        hidden_channels: Channel width of both conv blocks.
        hidden_dim: Width of the intermediate linear layer.
        embedding_dim: Size of the output embedding.

    The defaults reproduce the original fixed architecture for 1x28x28 inputs
    (``64 * 7 * 7 -> 256 -> 128``). The layer order inside ``self.net`` is
    unchanged, so ``state_dict`` keys match checkpoints of the original.

    Raises:
        ValueError: if ``input_size`` is smaller than ``(4, 4)``.
    """

    def __init__(self, in_channels=1, input_size=(28, 28), hidden_channels=64,
                 hidden_dim=256, embedding_dim=128):
        super().__init__()
        height, width = input_size
        # Convs preserve H and W; the two MaxPool2d(2) layers floor-divide
        # each by 2 twice, i.e. by 4 overall.
        pooled_h, pooled_w = height // 4, width // 4
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_size must be at least (4, 4), got {tuple(input_size)}"
            )
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(hidden_channels * pooled_h * pooled_w, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)`` whose spatial size
                matches the ``input_size`` given at construction.

        Returns:
            Tensor of shape ``(N, embedding_dim)``.
        """
        return self.net(x)

# Matching Network model
class MatchingNetwork(nn.Module):
    """Few-shot classifier that scores queries by cosine similarity to a support set.

    Support and query examples are embedded with the same ``embedding_net``
    and L2-normalised. Each query is compared to every support example by
    cosine similarity. The ``k_shot`` similarities belonging to each class are
    averaged, and a log-softmax over the ``n_way`` class scores gives
    per-query class log-probabilities.

    NOTE(review): averaging similarities per class before a softmax differs
    from the attention-weighted label sum of the original Matching Networks
    formulation (softmax over all support examples, then summed per class).
    Cosine similarities lie in [-1, 1], so the logits fed to the softmax are
    bounded and the output distribution cannot become very peaked. Confirm
    both choices are intended; a learnable or fixed logit scale is a common
    remedy for the latter.

    Args:
        embedding_net: Module mapping a batch of inputs to ``(N, D)`` embeddings.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, support, queries, n_way, k_shot):
        """Return class log-probabilities for each query.

        Args:
            support: ``n_way * k_shot`` support examples along dim 0, grouped
                by class. Rows ``[c * k_shot, (c + 1) * k_shot)`` are taken to
                belong to class ``c``. Class labels are implied by this ordering.
            queries: Query examples along dim 0.
            n_way: Number of classes in the episode.
            k_shot: Number of support examples per class.

        Returns:
            Tensor of shape ``(num_queries, n_way)`` of log-probabilities,
            suitable for ``F.nll_loss``.

        Raises:
            ValueError: if ``support`` does not hold exactly ``n_way * k_shot``
                examples.
        """
        expected = n_way * k_shot
        if support.size(0) != expected:
            raise ValueError(
                f"support must contain n_way * k_shot = {expected} examples, "
                f"got {support.size(0)}"
            )

        support_embeddings = self.embedding_net(support)  # (n_way*k_shot, D)
        query_embeddings = self.embedding_net(queries)    # (num_queries, D)

        # Unit-length rows make the dot product below a cosine similarity.
        support_embeddings = F.normalize(support_embeddings, p=2, dim=1)
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)

        similarities = torch.matmul(query_embeddings, support_embeddings.T)  # (num_queries, n_way*k_shot)

        # Use the explicit query count rather than -1 so a malformed shape can
        # never be silently folded into the query axis.
        per_class = similarities.view(similarities.size(0), n_way, k_shot).mean(dim=2)
        return F.log_softmax(per_class, dim=1)

# Example usage
# Example usage: one optimisation step followed by one evaluation pass on a
# synthetic 5-way 5-shot episode of random 1x28x28 images.
n_way, k_shot = 5, 5
num_queries = 15
input_dim = (1, 28, 28)

# Synthetic episode. MatchingNetwork treats the support rows as class-major
# (k_shot consecutive rows per class). The query labels are drawn at random,
# so the printed accuracy only demonstrates the pipeline, not learning.
support_set = torch.rand(n_way * k_shot, *input_dim)
query_set = torch.rand(num_queries, *input_dim)
labels = torch.randint(0, n_way, (num_queries,))

embedding_net = EmbeddingNet()
model = MatchingNetwork(embedding_net)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# --- single training step ---
model.train()
optimizer.zero_grad()
log_probs = model(support_set, query_set, n_way, k_shot)
# The model already returns log-probabilities, so NLL (not cross-entropy)
# is the matching loss.
loss = nn.NLLLoss()(log_probs, labels)
loss.backward()
optimizer.step()

# --- evaluation (BatchNorm uses running stats, no autograd graph) ---
model.eval()
with torch.no_grad():
    log_probs = model(support_set, query_set, n_way, k_shot)
    pred_labels = log_probs.argmax(dim=1)
    accuracy = pred_labels.eq(labels).float().mean()
    print(f'Accuracy: {accuracy.item()}')
