In [1]:
import torch
import random
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import Omniglot
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.autograd import Variable


In [2]:
# Prefer the GPU when one is visible, but fall back to the CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
class OmniglotFewShot(Dataset):
    """Episodic few-shot sampler built on top of torchvision's Omniglot dataset.

    Every ``__getitem__`` call ignores its index and draws a fresh random
    episode made of ``n_way`` classes, each contributing ``k_shots`` support
    images and ``n_query`` query images.

    Parameters:
    - root: directory the Omniglot files are downloaded to / read from.
    - mode: "train" selects ``background=True``; any other value selects
      ``background=False``.
    - transform: per-image transform forwarded to Omniglot.
    - n_way: number of classes per episode.
    - k_shots: support images per class.
    - n_query: query images per class.
    """

    def __init__(self, root, mode="train", transform=None, n_way=5, k_shots=5, n_query=5):
        super().__init__()
        self.root = root
        self.transform = transform
        self.n_way = n_way
        self.k_shots = k_shots
        self.n_query = n_query
        self.mode = mode
        self.data = Omniglot(root=self.root, background=(self.mode == "train"), download=True, transform=self.transform)
        self.indices_by_class = self._create_indices_by_class()

    def _create_indices_by_class(self):
        """Return ``{label: [dataset indices]}`` for the wrapped dataset.

        Labels are read from the dataset's own (file name, class) index when it
        is available, so no image has to be decoded or transformed just to learn
        its class. ``_flat_character_images`` is a private torchvision attribute,
        hence the fallback to plain iteration when it is missing.
        """
        flat_index = getattr(self.data, "_flat_character_images", None)
        if flat_index is not None:
            # No target_transform is configured in __init__, so these are the
            # same labels that iterating the dataset would yield.
            labels = (label for _, label in flat_index)
        else:
            labels = (label for _, label in self.data)

        indices_by_class = {}
        for idx, label in enumerate(labels):
            indices_by_class.setdefault(label, []).append(idx)
        return indices_by_class

    def __len__(self):
        """Number of distinct classes; this is the episode count a DataLoader sees per epoch."""
        return len(self.indices_by_class)

    def __getitem__(self, _):
        """Sample one random episode (the index argument is ignored).

        Returns:
        - support_set: tensor shaped (n_way, k_shots, *image_shape).
        - query_set: tensor shaped (n_way, n_query, *image_shape).
        - support_labels: dataset class label of every support image, class by class.
        - query_labels: dataset class label of every query image, class by class.

        All four tensors are moved to the notebook-level ``device``.
        """
        class_indices = np.random.choice(list(self.indices_by_class.keys()), self.n_way, replace=False)

        support_set = []
        query_set = []
        support_labels = []
        query_labels = []

        for class_index in class_indices:
            # Draw support and query indices in one go so they never overlap.
            indices = np.random.choice(self.indices_by_class[class_index], self.k_shots + self.n_query, replace=False)

            for i in indices[:self.k_shots]:
                image, _ = self.data[i]
                support_set.append(image)
                support_labels.append(class_index)

            for i in indices[self.k_shots:]:
                image, _ = self.data[i]
                query_set.append(image)
                query_labels.append(class_index)

        # Images were appended class by class, so a plain reshape recovers the
        # (class, sample) grid.
        support_set = torch.stack(support_set, dim=0).reshape(self.n_way, self.k_shots, *support_set[0].shape)
        query_set = torch.stack(query_set, dim=0).reshape(self.n_way, self.n_query, *query_set[0].shape)
        support_labels = torch.tensor(support_labels)
        query_labels = torch.tensor(query_labels)

        # NOTE(review): moving tensors to the global `device` inside the dataset
        # ties it to notebook state; kept because the training loop relies on it.
        support_set = support_set.to(device)
        query_set = query_set.to(device)
        support_labels = support_labels.to(device)
        query_labels = query_labels.to(device)

        return support_set, query_set, support_labels, query_labels


In [4]:
class RandomRotation90:
    """Callable transform that rotates an image by a randomly chosen multiple of 90 degrees."""

    def __call__(self, img):
        # A new angle is drawn on every call.
        right_angles = [0, 90, 180, 270]
        degrees = random.choice(right_angles)
        rotated = transforms.functional.rotate(img, degrees)
        return rotated

In [5]:
# Per-image preprocessing, applied in this order: resize to 28x28, rotate by a
# random multiple of 90 degrees (redrawn each time an image is loaded), then
# convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    RandomRotation90(),
    transforms.ToTensor()
])


In [6]:
class Embedder(nn.Module):
    """Four-block convolutional encoder whose final feature map is flattened.

    Each block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2).

    Parameters:
    - in_channels: number of channels of the input images (default 1).
    - out_channels: number of filters used by every convolution (default 64).
    """

    def __init__(self, in_channels=1, out_channels=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Only the first block sees the raw image channels; the other three map
        # out_channels -> out_channels. Layers are created in the same order as
        # a hand-written Sequential, so state-dict keys stay encoder.0 ... encoder.15.
        layers = []
        block_in = self.in_channels
        for _ in range(4):
            layers.extend(self._conv_block(block_in, self.out_channels))
            block_in = self.out_channels
        self.encoder = nn.Sequential(*layers)

        self.flatten = nn.Flatten()

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return the layers of one conv -> batch-norm -> ReLU -> 2x2 max-pool block."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]

    def forward(self, x):
        """Encode a batch of images and flatten each feature map to a vector."""
        return self.flatten(self.encoder(x))


In [7]:
# NOTE(review): this rebinds the name `Embedder` from the class to an instance,
# so re-running this cell calls the instance rather than the class constructor.
Embedder = Embedder()

In [8]:
# Move the embedder to the selected device; the returned module is the cell output.
Embedder.to(device)

Embedder(
  (encoder): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1

In [9]:
class Protonet(nn.Module):
    """Prototypical-network episode head wrapped around an embedding network."""

    def __init__(self, embedder):
        super().__init__()
        self.embedder = embedder

    def forward(self, support, query):
        """
        Run one few-shot episode and return its loss and metrics.

        Parameters:
        - support: support images shaped (n_class, n_support, *image_shape); a
          leading dimension of size 1 (e.g. a DataLoader batch of one episode)
          is squeezed away first.
        - query: query images shaped (n_class, n_query, *image_shape), with the
          same optional leading dimension. Row ``c`` of ``query`` is treated as
          belonging to the class in row ``c`` of ``support``.

        Returns:
        - loss_val: scalar tensor, the mean negative log-probability of the
          true class over all query images.
        - a dict with the Python floats 'loss' and 'acc' for logging.
        """
        xs = support.squeeze(0)
        xq = query.squeeze(0)

        n_class = xs.shape[0]
        n_support = xs.shape[1]
        n_query = xq.shape[1]

        # Embed support and query images in a single pass through the embedder.
        x = torch.cat([xs.reshape(n_class * n_support, *xs.shape[2:]),
                       xq.reshape(n_class * n_query, *xq.shape[2:])], 0)

        z = self.embedder(x)
        z_dim = z.size(-1)

        # target_inds[c, q, 0] == c; built on the embeddings' device so the
        # module does not depend on a notebook-level `device` variable.
        target_inds = (
            torch.arange(n_class, device=z.device)
            .repeat_interleave(n_query)
            .view(n_class, n_query, 1)
        )

        # A prototype is the mean embedding of a class's support images.
        z_proto = z[:n_class * n_support].view(n_class, n_support, z_dim).mean(1)
        zq = z[n_class * n_support:]

        # Euclidean (p=2, not squared) distance from every query to every prototype.
        dists = torch.cdist(zq, z_proto)

        log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)

        loss_val = -log_p_y.gather(2, target_inds).mean()

        _, y_hat = log_p_y.max(2)
        # Drop only the trailing singleton dim: a bare squeeze() would also
        # collapse n_class == 1 or n_query == 1 and make the comparison broadcast.
        acc_val = torch.eq(y_hat, target_inds.squeeze(-1)).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item()
        }


In [10]:
# Wrap the embedder instance (the name `Embedder` now refers to it) in the episodic Protonet head.
model = Protonet(Embedder)

In [11]:
# Adam over all parameters of the model (i.e. the wrapped embedder), initial learning rate 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [12]:
# Halve the learning rate every 2000 scheduler steps; train() steps it once per training episode.
scheduler = StepLR(optimizer, step_size=2000, gamma=0.5)

In [13]:
def train(train_loader, test_loader, best_val_loss=float('inf'), best_acc=-float('inf'), num_epochs=1000, patience=3):
    """Train the notebook-level ``model`` with early stopping on validation loss.

    Uses the module-level ``model``, ``optimizer`` and ``scheduler``.

    Parameters:
    - train_loader: iterable of (support_set, query_set, support_labels,
      query_labels) training episodes; the labels are not used.
    - test_loader: iterable of the same 4-tuples, evaluated after every epoch.
    - best_val_loss: starting value for the best validation loss seen so far.
    - best_acc: starting value for the accuracy paired with ``best_val_loss``.
    - num_epochs: maximum number of epochs to run.
    - patience: stop after this many consecutive epochs without a new best
      validation loss.

    Returns:
    - (best_val_loss, best_acc): the lowest average validation loss and the
      average validation accuracy measured in that same epoch.

    NOTE(review): no checkpoint is kept, so on return ``model`` holds the
    weights of the last epoch run, not those of the best epoch.
    """
    wait = 0  # consecutive epochs without a validation-loss improvement
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        total_acc = 0.0

        train_pbar = tqdm(train_loader, desc=f'Training Epoch {epoch + 1}/{num_epochs}', unit='batch')

        for support_set, query_set, _, _ in train_pbar:
            optimizer.zero_grad()
            loss, results = model(support_set, query_set)
            loss.backward()
            optimizer.step()
            # One scheduler step per episode. The scheduler keeps its own step
            # counter, so the deprecated explicit `epoch` argument is not needed.
            scheduler.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            total_acc += results['acc']

            train_pbar.set_postfix({'Loss': batch_loss, 'Accuracy': results['acc']})

        avg_loss = total_loss / len(train_loader)
        avg_acc = total_acc / len(train_loader)
        print(f'Training Epoch [{epoch + 1}/{num_epochs}], Avg. Loss: {avg_loss:.4f}, Avg. Accuracy: {avg_acc:.4f}')

        # Validation loop
        model.eval()
        total_val_loss = 0.0
        total_val_acc = 0.0

        val_pbar = tqdm(test_loader, desc=f'Validation Epoch {epoch + 1}/{num_epochs}', unit='batch')

        with torch.no_grad():
            for support_set, query_set, _, _ in val_pbar:
                loss, results = model(support_set, query_set)
                batch_loss = loss.item()
                total_val_loss += batch_loss
                total_val_acc += results['acc']

                val_pbar.set_postfix({'Loss': batch_loss, 'Accuracy': results['acc']})

        avg_val_loss = total_val_loss / len(test_loader)
        avg_val_acc = total_val_acc / len(test_loader)
        print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Avg. Loss: {avg_val_loss:.4f}, Avg. Accuracy: {avg_val_acc:.4f}')

        if avg_val_loss < best_val_loss:
            # New best epoch: remember its metrics and reset the patience counter.
            best_val_loss = avg_val_loss
            best_acc = avg_val_acc
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f'Validation loss did not improve for {patience} epochs. Early stopping...')
                break

    # Training complete
    print('Training complete.')
    return best_val_loss, best_acc


In [14]:
# Training episodes: 60-way, 5 support + 5 query images per class, drawn from
# Omniglot's background split (mode='train').
train_dataset = OmniglotFewShot(root='./data', mode='train', transform=transform, n_way=60, k_shots=5, n_query=5)
# No batch_size is given, so each batch is one episode (DataLoader's default batch size is 1).
# The dataset ignores the index it is asked for, so shuffle=True does not change what is sampled.
train_loader = DataLoader(train_dataset,shuffle=True)

Files already downloaded and verified


In [15]:
# Evaluation episodes: 5-way, 5 support + 15 query images per class, drawn from
# the non-background split (mode='test').
test_dataset = OmniglotFewShot(root='./data', mode='test', transform=transform, n_way=5, k_shots=5, n_query=15)
test_loader = DataLoader(test_dataset,shuffle=True)

Files already downloaded and verified


In [16]:
# Train with early stopping; the printed tuple is (best validation loss, accuracy in that epoch).
# NOTE(review): test_loader is also the validation set that drives early stopping here.
print(f"5 way 5 shot results {train(train_loader, test_loader)}")

Training Epoch 1/1000: 100%|██████████| 964/964 [03:20<00:00,  4.81batch/s, Loss=0.291, Accuracy=0.913]


Training Epoch [1/1000], Avg. Loss: 0.5608, Avg. Accuracy: 0.8498


Validation Epoch 1/1000: 100%|██████████| 659/659 [00:21<00:00, 30.78batch/s, Loss=0.104, Accuracy=0.987] 


Validation Epoch [1/1000], Avg. Loss: 0.0733, Avg. Accuracy: 0.9783


Training Epoch 2/1000: 100%|██████████| 964/964 [03:06<00:00,  5.18batch/s, Loss=0.165, Accuracy=0.933] 


Training Epoch [2/1000], Avg. Loss: 0.2116, Avg. Accuracy: 0.9385


Validation Epoch 2/1000: 100%|██████████| 659/659 [00:20<00:00, 32.82batch/s, Loss=0.0134, Accuracy=1]    


Validation Epoch [2/1000], Avg. Loss: 0.0611, Avg. Accuracy: 0.9809


Training Epoch 3/1000: 100%|██████████| 964/964 [03:04<00:00,  5.21batch/s, Loss=0.196, Accuracy=0.923] 


Training Epoch [3/1000], Avg. Loss: 0.1558, Avg. Accuracy: 0.9527


Validation Epoch 3/1000: 100%|██████████| 659/659 [00:19<00:00, 34.50batch/s, Loss=0.0318, Accuracy=0.973]


Validation Epoch [3/1000], Avg. Loss: 0.0521, Avg. Accuracy: 0.9837


Training Epoch 4/1000: 100%|██████████| 964/964 [03:03<00:00,  5.24batch/s, Loss=0.104, Accuracy=0.97]  


Training Epoch [4/1000], Avg. Loss: 0.1348, Avg. Accuracy: 0.9584


Validation Epoch 4/1000: 100%|██████████| 659/659 [00:18<00:00, 35.12batch/s, Loss=0.123, Accuracy=0.973] 


Validation Epoch [4/1000], Avg. Loss: 0.0572, Avg. Accuracy: 0.9829


Training Epoch 5/1000: 100%|██████████| 964/964 [03:02<00:00,  5.27batch/s, Loss=0.194, Accuracy=0.933] 


Training Epoch [5/1000], Avg. Loss: 0.1156, Avg. Accuracy: 0.9638


Validation Epoch 5/1000: 100%|██████████| 659/659 [00:18<00:00, 35.11batch/s, Loss=0.179, Accuracy=0.973] 


Validation Epoch [5/1000], Avg. Loss: 0.0506, Avg. Accuracy: 0.9846


Training Epoch 6/1000: 100%|██████████| 964/964 [03:02<00:00,  5.27batch/s, Loss=0.104, Accuracy=0.963] 


Training Epoch [6/1000], Avg. Loss: 0.1047, Avg. Accuracy: 0.9670


Validation Epoch 6/1000: 100%|██████████| 659/659 [00:18<00:00, 35.37batch/s, Loss=0.101, Accuracy=0.947] 


Validation Epoch [6/1000], Avg. Loss: 0.0471, Avg. Accuracy: 0.9856


Training Epoch 7/1000: 100%|██████████| 964/964 [03:03<00:00,  5.26batch/s, Loss=0.153, Accuracy=0.96]  


Training Epoch [7/1000], Avg. Loss: 0.0966, Avg. Accuracy: 0.9693


Validation Epoch 7/1000: 100%|██████████| 659/659 [00:18<00:00, 35.08batch/s, Loss=0.0977, Accuracy=0.987]


Validation Epoch [7/1000], Avg. Loss: 0.0496, Avg. Accuracy: 0.9854


Training Epoch 8/1000: 100%|██████████| 964/964 [03:03<00:00,  5.25batch/s, Loss=0.145, Accuracy=0.96]  


Training Epoch [8/1000], Avg. Loss: 0.0914, Avg. Accuracy: 0.9705


Validation Epoch 8/1000: 100%|██████████| 659/659 [00:18<00:00, 35.11batch/s, Loss=0.0132, Accuracy=1]    


Validation Epoch [8/1000], Avg. Loss: 0.0462, Avg. Accuracy: 0.9859


Training Epoch 9/1000: 100%|██████████| 964/964 [03:03<00:00,  5.26batch/s, Loss=0.088, Accuracy=0.973] 


Training Epoch [9/1000], Avg. Loss: 0.0875, Avg. Accuracy: 0.9715


Validation Epoch 9/1000: 100%|██████████| 659/659 [00:18<00:00, 35.08batch/s, Loss=0.0029, Accuracy=1]    


Validation Epoch [9/1000], Avg. Loss: 0.0495, Avg. Accuracy: 0.9847


Training Epoch 10/1000: 100%|██████████| 964/964 [03:01<00:00,  5.30batch/s, Loss=0.0727, Accuracy=0.987]


Training Epoch [10/1000], Avg. Loss: 0.0845, Avg. Accuracy: 0.9723


Validation Epoch 10/1000: 100%|██████████| 659/659 [00:18<00:00, 35.39batch/s, Loss=0.0572, Accuracy=0.987] 


Validation Epoch [10/1000], Avg. Loss: 0.0525, Avg. Accuracy: 0.9844


Training Epoch 11/1000: 100%|██████████| 964/964 [03:01<00:00,  5.30batch/s, Loss=0.121, Accuracy=0.953] 


Training Epoch [11/1000], Avg. Loss: 0.0824, Avg. Accuracy: 0.9731


Validation Epoch 11/1000: 100%|██████████| 659/659 [00:18<00:00, 35.56batch/s, Loss=0.0076, Accuracy=1]    

Validation Epoch [11/1000], Avg. Loss: 0.0469, Avg. Accuracy: 0.9855
Validation loss did not improve for 3 epochs. Early stopping...
Training complete.
5 way 5 shot results (0.04615281028638108, 0.9858978368986359)





In [17]:
def validate(model, val_loader):
    """Evaluate ``model`` on every episode of ``val_loader`` with gradients disabled.

    Parameters:
    - model: callable taking (support_set, query_set) and returning
      (loss tensor, dict containing an 'acc' entry); it is put in eval mode.
    - val_loader: sized iterable of (support_set, query_set, support_labels,
      query_labels) episodes; the labels are not used.

    Returns:
    - (avg_val_loss, avg_val_acc): per-episode loss and accuracy averaged over
      ``len(val_loader)``.
    """
    model.eval()
    episode_count = len(val_loader)
    loss_sum = 0.0
    acc_sum = 0.0

    # Progress bar over the validation episodes
    progress = tqdm(val_loader, desc='Validation', unit='batch')

    with torch.no_grad():
        for support_set, query_set, _support_labels, _query_labels in progress:
            loss, metrics = model(support_set, query_set)
            episode_loss = loss.item()
            episode_acc = metrics['acc']

            loss_sum += episode_loss
            acc_sum += episode_acc

            progress.set_postfix({'Loss': episode_loss, 'Accuracy': episode_acc})

    # Averages over all validation episodes
    return loss_sum / episode_count, acc_sum / episode_count


In [18]:
# Re-evaluate the trained model on 5-way 1-shot episodes (15 queries per class).
# This rebinds test_dataset / test_loader from the earlier cell.
test_dataset = OmniglotFewShot(root='./data', mode='test', transform=transform, n_way=5, k_shots=1, n_query=15)
test_loader = DataLoader(test_dataset,shuffle=True)
print(f"5 way 1 shot results {validate(model, test_loader)}")

Files already downloaded and verified


Validation: 100%|██████████| 659/659 [00:14<00:00, 44.58batch/s, Loss=0.0776, Accuracy=0.987]

5 way 1 shot results (0.16194710969971388, 0.9468487814995877)





In [19]:
# Re-evaluate the trained model on 20-way 1-shot episodes (15 queries per class).
# This rebinds test_dataset / test_loader from the earlier cell.
test_dataset = OmniglotFewShot(root='./data', mode='test', transform=transform, n_way=20, k_shots=1, n_query=15)
test_loader = DataLoader(test_dataset,shuffle=True)
print(f"20 way 1 shot results {validate(model, test_loader)}")

Files already downloaded and verified


Validation: 100%|██████████| 659/659 [00:53<00:00, 12.26batch/s, Loss=0.603, Accuracy=0.787]

20 way 1 shot results (0.5185531473214058, 0.8467931410068807)





In [20]:
# Re-evaluate the trained model on 20-way 5-shot episodes (15 queries per class).
# This rebinds test_dataset / test_loader from the earlier cell.
test_dataset = OmniglotFewShot(root='./data', mode='test', transform=transform, n_way=20, k_shots=5, n_query=15)
test_loader = DataLoader(test_dataset,shuffle=True)
print(f"20 way 5 shot results {validate(model, test_loader)}")

Files already downloaded and verified


Validation: 100%|██████████| 659/659 [01:07<00:00,  9.78batch/s, Loss=0.0976, Accuracy=0.98] 

20 way 5 shot results (0.15649693870241435, 0.9539555158347389)



