In [1]:
import numpy as np

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import higher

from loader import PNGDataset

In [2]:
class ConvNet(nn.Module):
    def __init__(self, embedding_size=64, w1=64, w2=128, w3=256, dropout_rate=0.2, use_bn=True):
        super(ConvNet, self).__init__()
        self.embedding_size = embedding_size
        
        self.conv1 = nn.Conv2d(3, w1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(w1, w2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(w2, w3, kernel_size=3, padding=1)
        
        self.bn1 = nn.BatchNorm2d(w1) if use_bn else nn.Identity()
        self.bn2 = nn.BatchNorm2d(w2) if use_bn else nn.Identity()
        self.bn3 = nn.BatchNorm2d(w3) if use_bn else nn.Identity()
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.fun = F.relu
        
        self.fc1 = nn.Linear(w3 * 4 * 4, 128)
        self.drop =  nn.Dropout(p=dropout_rate)
        self.fc2 = nn.Linear(128, embedding_size)

        return
    
    def forward(self, x):
        x = self.pool(self.fun(self.bn1(self.conv1(x))))
        x = self.pool(self.fun(self.bn2(self.conv2(x))))
        x = self.pool(self.fun(self.bn3(self.conv3(x))))
        x = self.drop(torch.flatten(x, 1))
        x = self.fun(self.fc1(x))
        x = self.fc2(self.drop(x))
        
        return x

In [3]:
class MAMLModel(nn.Module):
    def __init__(self, model, n_classes=10):
        super(MAMLModel, self).__init__()
        
        self.model = model
        self.n_classes = n_classes
    
    def forward(self, support_set, query_set):        
        # Compute support set embeddings
        support_embeddings = self.model(support_set)
        query_embeddings = self.model(query_set)
        
        # Compute prototypes (mean of support set embeddings)
        prototypes = self.compute_prototypes(support_embeddings)
        
        # Calculate distances between query samples and prototypes
        distances = self.compute_distances(query_embeddings, prototypes)
        
        return distances
    
    def compute_prototypes(self, support_embeddings):
        # Reshape support_embeddings and compute class prototypes
        support_embeddings = support_embeddings.view(self.n_classes, -1, support_embeddings.size(-1))
        prototypes = support_embeddings.mean(dim=1)
        
        return prototypes
    
    def compute_distances(self, query_embeddings, prototypes):
        # Compute squared Euclidean distance between query embeddings and prototypes
        distances = torch.cdist(query_embeddings, prototypes)
        
        return distances

In [4]:
class FewShotDataset(Dataset):
    def __init__(self, dataset, n_classes=10, n_shots=5, n_queries=15, transform=None):
        self.dataset = dataset
        self.n_classes = n_classes
        self.n_shots = n_shots
        self.n_queries = n_queries
        self.transform = transform
        
        # Group images by their labels
        self.class_images = {}
        for i in range(len(dataset)):
            img, label = dataset[i]
            if label not in self.class_images:
                self.class_images[label] = []
            self.class_images[label].append(img)
    
    def __len__(self):
        return 1000  # Number of episodes
    
    def __getitem__(self, index):
        # Randomly sample classes
        selected_classes = np.random.choice(list(self.class_images.keys()), self.n_classes, replace=False)
        
        support_set = []
        query_set = []
        support_labels = []
        query_labels = []
        
        for label in selected_classes:
            images = self.class_images[label]
            np.random.shuffle(images)
            
            # Sample K images for the support set
            support_images = images[:self.n_shots]
            support_set.extend(support_images)
            support_labels.extend([label] * self.n_shots)
            
            # Sample query set (n_queries)
            query_images = images[self.n_shots:self.n_shots + self.n_queries]
            query_set.extend(query_images)
            query_labels.extend([label] * self.n_queries)
        
        # Apply transformations if specified
        if self.transform:
            support_set = [self.transform(img) for img in support_set]
            query_set = [self.transform(img) for img in query_set]
        
        support_set = torch.stack(support_set)
        query_set = torch.stack(query_set)
        
        return support_set, query_set, torch.tensor(support_labels), torch.tensor(query_labels)

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4788952171802521, 0.4722793698310852, 0.43047481775283813],
        std=[0.24205632507801056, 0.2382805347442627, 0.25874853134155273]
    )
])

In [6]:
train_dataset = PNGDataset("data/sample/train")

In [7]:
few_shot_dataset = FewShotDataset(train_dataset, n_classes=2, n_shots=5, n_queries=15, transform=transform)
train_loader = DataLoader(few_shot_dataset, batch_size=32, shuffle=True)

In [9]:
model = MAMLModel(ConvNet(embedding_size=64), n_classes=2)
optimizer = optim.Adam(model.parameters(), lr=0.0005)

In [10]:
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for support_set, query_set, support_labels, query_labels in train_loader:    
        # Define meta-objective for MAML
        with higher.innerloop_ctx(model, optimizer, copy_initial_weights=True) as (meta_model, meta_optimizer):
            # Loop over episodes
            loss = 0.0
            for i in range(support_set.size(0)):
                support_batch = support_set[i]
                query_batch = query_set[i]
    
                # Forward pass through the model
                distances = meta_model(support_batch, query_batch)
                
                # Calculate loss for this task (using cross entropy)
                predicted_labels = torch.argmin(distances, dim=1)
                task_loss = F.cross_entropy(distances, query_labels[i])
                loss += task_loss
    
            # Compute gradients and update the model
            meta_optimizer.step(loss)
    
        # Compute the meta-gradient and update the model
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
    
    # Evaluate and perform additional steps as necessary

Epoch [1/5], Loss: 21.8740
Epoch [2/5], Loss: 21.8455
Epoch [3/5], Loss: 21.8742
Epoch [4/5], Loss: 21.8552
Epoch [5/5], Loss: 21.8874
