In [1]:
import torch
import random
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import Omniglot
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [2]:
device = torch.device("cpu")
device

device(type='cpu')

In [3]:
class OmniglotFewShot(Dataset):
    def __init__(self, root, mode="train", transform=None, n_way=5, k_shots=5, n_query=5):
        super(OmniglotFewShot, self).__init__()
        self.root = root
        self.transform = transform
        self.n_way = n_way
        self.k_shots = k_shots
        self.n_query = n_query
        self.mode = mode
        self.data = Omniglot(root=self.root, background=True if self.mode == "train" else False, download=True, transform=self.transform)
        self.indices_by_class = self._create_indices_by_class()

    def _create_indices_by_class(self):
        indices_by_class = {}
        for idx, (_, label) in enumerate(self.data):
            if label not in indices_by_class:
                indices_by_class[label] = []
            indices_by_class[label].append(idx)
        return indices_by_class

    def __len__(self):
        return len(self.indices_by_class)

    def __getitem__(self, _):
        class_indices = np.random.choice(list(self.indices_by_class.keys()), self.n_way, replace=False)
        
        support_set = []
        query_set = []
        support_labels = []
        query_labels = []

        for class_index in class_indices:
            indices = np.random.choice(self.indices_by_class[class_index], self.k_shots + self.n_query, replace=False)
            class_support_set_indices = indices[:self.k_shots]
            class_query_set_indices = indices[self.k_shots:]

            for i in class_support_set_indices:
                image, _ = self.data[i]
                support_set.append(image)
                support_labels.append(class_index)

            for i in class_query_set_indices:
                image, _ = self.data[i]
                query_set.append(image)
                query_labels.append(class_index)

        # Convert lists to tensors for PyTorch compatibility
        support_set = torch.stack(support_set)
        query_set = torch.stack(query_set)
        support_labels = torch.tensor(support_labels)
        query_labels = torch.tensor(query_labels)


        support_set = support_set.to(device)
        query_set = query_set.to(device)
        support_labels = support_labels.to(device)
        query_labels = query_labels.to(device)


        return support_set, query_set, support_labels, query_labels


In [4]:
class RandomRotation90:
    def __call__(self, img):
        angle = random.choice([0, 90, 180, 270])
        return transforms.functional.rotate(img, angle)

In [5]:
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    RandomRotation90(),
    transforms.ToTensor()
])


In [6]:
train_dataset = OmniglotFewShot(root='./data', mode='train', transform=transform, n_way=60, k_shots=5, n_query=5)
test_dataset = OmniglotFewShot(root='./data', mode='test', transform=transform, n_way=5, k_shots=5, n_query=15)

Files already downloaded and verified


Files already downloaded and verified


In [7]:
train_loader = DataLoader(train_dataset,shuffle=True)
test_loader = DataLoader(test_dataset,shuffle=False)

In [8]:
class Embedder(nn.Module):
    def __init__(self):
        super(Embedder, self).__init__()
        self.in_channels = 1
        self.out_channels = 64

        self.encoder = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.encoder(x)
        x = self.flatten(x)
        return x


In [9]:
Embedder = Embedder()

In [10]:
Embedder.to(device)

Embedder(
  (encoder): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1

In [11]:
class Protonet(nn.Module):
    def __init__(self, embedder):
        super(Protonet, self).__init__()
        self.embedder = embedder

    def forward(self, support, query, n_way, k_shots):
        """
        Perform the forward pass and compute the prototypes.

        Parameters:
        - support: The support set.
        - query: The query set.
        - n_way: The number of classes (ways).
        - k_shots: The number of examples per class in the support set.

        Returns:
        - query_embeddings: The embeddings of the query set.
        - prototypes: The class prototypes.
        """
        # Embed support and query sets
        print(support.shape)
        support_embeddings = self.embedder(support.squeeze(0))
        query_embeddings = self.embedder(query.squeeze(0))
        
        # Calculate the prototypes for each class in the support set
        # Reshape support embeddings to [n_way, k_shots, embedding_size] and compute mean across k_shots
        prototypes = support_embeddings.view(n_way, k_shots, -1).mean(dim=1)
        
        return query_embeddings, prototypes


In [12]:
model = Protonet(Embedder)

In [20]:
def protonet_loss(query_embeddings, prototypes, query_labels, n_way):
    """
    Computes the Prototypical Networks loss given embeddings of the query set,
    class prototypes, and query labels.

    Parameters:
    - query_embeddings: The embeddings of the query set.
    - prototypes: The class prototypes.
    - query_labels: The labels for the query set.
    - n_way: The number of classes (ways).

    Returns:
    - Loss calculated using Negative Log Likelihood.
    """
    # Calculate the Euclidean distance from each query sample to the prototypes
    distances = torch.cdist(query_embeddings, prototypes) # shape n_instances,n_ways
    
    # Find the index of the nearest centroid for each instance
    nearest_indices = torch.argmin(distances, dim=1) # min_dist_index
    print(query_labels[nearest_indices])
    # Convert query_labels to a 1D tensor
    query_labels = query_labels.view(-1)
    
    # Convert distances to log probabilities
    log_p_y = F.log_softmax(-distances, dim=1)  # Apply softmax and take log
    
    # Compute the negative log likelihood loss
    loss = F.nll_loss(log_p_y, query_labels)
    
    return loss


In [14]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [15]:
scheduler = StepLR(optimizer, step_size=2000, gamma=0.5)

In [21]:
num_epochs = 10
total_episodes = 0  # Initialize total episodes


# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_loss = 0.0

    # Create tqdm progress bar
    pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch')
    
    for support_set, query_set, support_labels, query_labels in pbar:
        optimizer.zero_grad()


        query_embeddings, prototypes = model(support_set, query_set, 60, 5)

        print("------------")
        print(query_labels)
        
        # Forward pass
        loss = protonet_loss(query_embeddings, prototypes, query_labels, 60)
        
        # Backward pass
        loss.backward()
        
        # Optimize
        optimizer.step()

        # Update total loss
        total_loss += loss.item()

        # Update progress bar description with current loss
        pbar.set_postfix({'Loss': loss.item()})
        
        total_episodes += 1  # Increment total episodes
        scheduler.step(total_episodes)  # Update learning rate
        
    # Calculate average loss for the epoch
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Avg. Loss: {avg_loss:.4f}')

# Training complete
print('Training complete.')


Epoch 1/10:   0%|          | 0/964 [00:00<?, ?batch/s]

torch.Size([1, 300, 1, 28, 28])
------------
tensor([[545, 545, 545, 545, 545, 479, 479, 479, 479, 479, 200, 200, 200, 200,
         200, 703, 703, 703, 703, 703, 904, 904, 904, 904, 904, 374, 374, 374,
         374, 374, 656, 656, 656, 656, 656,  49,  49,  49,  49,  49, 760, 760,
         760, 760, 760, 599, 599, 599, 599, 599, 935, 935, 935, 935, 935, 658,
         658, 658, 658, 658,  87,  87,  87,  87,  87, 349, 349, 349, 349, 349,
         879, 879, 879, 879, 879,  12,  12,  12,  12,  12, 478, 478, 478, 478,
         478, 246, 246, 246, 246, 246, 542, 542, 542, 542, 542, 721, 721, 721,
         721, 721,  14,  14,  14,  14,  14, 445, 445, 445, 445, 445, 594, 594,
         594, 594, 594, 591, 591, 591, 591, 591, 357, 357, 357, 357, 357, 733,
         733, 733, 733, 733, 644, 644, 644, 644, 644, 135, 135, 135, 135, 135,
         170, 170, 170, 170, 170, 175, 175, 175, 175, 175, 632, 632, 632, 632,
         632, 746, 746, 746, 746, 746, 100, 100, 100, 100, 100, 194, 194, 194,
       

Epoch 1/10:   0%|          | 0/964 [00:00<?, ?batch/s]


IndexError: index 10 is out of bounds for dimension 0 with size 1