<a href="https://colab.research.google.com/github/KyleRoss-rice/tiny-cifar10-experiments/blob/main/challenge_2_prototypical_networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install learn2learn



In [None]:
import torch
import torch.optim as optim
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import torchvision
from torch.utils.data import Subset
from torchvision import datasets, transforms
import torchvision.models as models

import numpy as np
from numpy.random import RandomState
import matplotlib.pyplot as plt

from copy import deepcopy

import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels
from learn2learn.vision.datasets import MiniImagenet

# Useful Functions

In [None]:
def pairwise_distances_logits(a, b):
    """Return negative squared Euclidean distances between the rows of *a* and *b*.

    Args:
        a: (n, d) tensor, e.g. query embeddings.
        b: (m, d) tensor, e.g. class prototypes.

    Returns:
        (n, m) tensor whose [i, j] entry is -||a[i] - b[j]||^2. Larger means
        closer, so the result can be fed directly to a softmax / cross-entropy
        as logits.
    """
    diff = a[:, None, :] - b[None, :, :]   # broadcast to (n, m, d)
    return -diff.pow(2).sum(dim=2)         # reduce over the embedding dimension


In [None]:
def accuracy(predictions, targets):
    """Return the fraction of rows whose arg-max prediction equals the target.

    Args:
        predictions: (n, n_classes) scores or logits.
        targets: n integer class labels.

    Returns:
        0-dim float tensor in [0, 1].
    """
    predicted = torch.argmax(predictions, dim=1).reshape(targets.shape)
    n_correct = torch.eq(predicted, targets).sum()
    return n_correct.float() / targets.size(0)


In [None]:
def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
    """Run one prototypical-network episode and return its loss and accuracy.

    The episode's samples are sorted by label, so each class occupies a
    contiguous run of ``shot + query_num`` items. The first ``shot`` items of
    every run form the support set and the remaining items form the query set.

    Args:
        model: embedding network; ``model(data)`` must return one embedding
            row per image.
        batch: ``(data, labels)`` pair as yielded by a DataLoader over a
            learn2learn TaskDataset (a leading task dimension of size 1 is
            squeezed away).
        ways: number of classes in the episode.
        shot: support samples per class.
        query_num: query samples per class.
        metric: callable ``metric(query, prototypes) -> logits`` of shape
            (n_query, ways); defaults to ``pairwise_distances_logits``.
        device: device to run on; defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        ``(loss, acc)``: cross-entropy loss over the query set and the query
        accuracy, both 0-dim tensors.
    """
    if metric is None:
        metric = pairwise_distances_logits
    if device is None:
        # nn.Module has no .device() method; infer the device from its parameters.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    data, labels = batch
    data = data.to(device)
    labels = labels.to(device)

    # Sort by label so that every class forms a contiguous run of samples.
    order = torch.sort(labels).indices
    data = data.squeeze(0)[order].squeeze(0)
    labels = labels.squeeze(0)[order].squeeze(0)

    embeddings = model(data)

    # Mark the first `shot` samples of each class run as support; the rest are queries.
    support_mask = np.zeros(data.size(0), dtype=bool)
    run_starts = np.arange(ways) * (shot + query_num)
    for offset in range(shot):
        support_mask[run_starts + offset] = True
    query_mask = torch.from_numpy(~support_mask)
    support_mask = torch.from_numpy(support_mask)

    # A class prototype is the mean embedding of its support samples -> (ways, embed_dim).
    prototypes = embeddings[support_mask].reshape(ways, shot, -1).mean(dim=1)

    query = embeddings[query_mask]
    query_labels = labels[query_mask].long()

    # Use the caller-supplied metric (it was previously ignored in favour of a hard-coded one).
    logits = metric(query, prototypes)   # (n_query, ways)
    loss = F.cross_entropy(logits, query_labels)
    acc = accuracy(logits, query_labels)
    return loss, acc


# The model

In [None]:
class Convnet(nn.Module):
    """Embedding network for prototypical networks, built on learn2learn's ConvBase.

    Args:
        x_dim: number of input image channels (forwarded as ``channels``).
        hid_dim: forwarded to ConvBase as ``hidden``.
        z_dim: forwarded to ConvBase as ``output_size``.
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        encoder_kwargs = dict(
            output_size=z_dim,
            hidden=hid_dim,
            channels=x_dim,
            max_pool=True,
        )
        self.encoder = l2l.vision.models.ConvBase(**encoder_kwargs)

    def forward(self, x):
        """Encode a batch of images and flatten each feature map into one vector per image."""
        features = self.encoder(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)


# Data

In [None]:
# Number of training epochs. NOTE(review): not referenced by the training loop
# below, which runs a fixed 20 epochs.
max_epoch =250

# Shots = labelled support samples per class; their mean embedding is the class prototype.
train_shot = 1   # Support samples per class during meta-training.
valid_shot = 1   # The same number of shots is used for training, validation and testing.
test_shot = 1

# Ways = classes per episode (one prototype per way).
train_way = 10   # Meta-training on CIFAR-FS (100 classes). Training with more ways than at test time tends to yield a better, more general embedding (Prototypical Networks paper).
valid_way = 5    # 5-way episodes on CIFAR-10.
test_way = 5     # TODO: to be decided.

# Queries = samples per class per episode that are classified and scored by the loss/accuracy.
train_query = 15 # 15 queries per class, as in the Prototypical Networks mini-ImageNet setup.
valid_query = 5  # valid_query + valid_shot must not exceed the 10 images per class in the 100-sample CIFAR-10 validation subset.
test_query = 5   # TODO: to be decided.

gpu=True
device = torch.device('cpu')
if gpu and torch.cuda.device_count():
    print("Using gpu")
    torch.cuda.manual_seed(43)   # NOTE: seeds only the CUDA RNGs, not the CPU RNG.
    device = torch.device('cuda')

device   # Notebook display of the selected device.

Using gpu


device(type='cuda')

In [None]:
## Test set I : 2000 samples from CIFAR-10 training data
## Test set II: 2000 samples from CIFAR-10 testing data
import random

n_subsets = 3        # 3 different subsets are used to compute the mean and std of the test accuracy
n_train_epochs = 20  # NOTE: the config cell's max_epoch is not used here
n_train_tasks = 100  # training episodes per epoch (created on the fly)
n_eval_tasks = 10    # episodes sampled from each validation / test set
accs = []
final_accs = []


def _episode_loader(dataset, ways, shot, query, num_tasks=None):
    """Wrap *dataset* as a learn2learn episode source.

    Every episode contains `ways` classes with `shot + query` samples each,
    with labels remapped to 0..ways-1. `num_tasks=None` keeps TaskDataset's
    default.

    Returns:
        (meta_dataset, loader): the MetaDataset wrapper and a DataLoader over
        the resulting TaskDataset.
    """
    meta = l2l.data.MetaDataset(dataset)
    task_transforms = [
        NWays(meta, ways),           # N-way: number of classes per episode
        KShots(meta, query + shot),  # `shot` samples form the support set, the rest are queries
        LoadData(meta),
        RemapLabels(meta),
    ]
    extra = {} if num_tasks is None else {'num_tasks': num_tasks}
    tasks = l2l.data.TaskDataset(meta, task_transforms=task_transforms, **extra)
    return meta, DataLoader(tasks, pin_memory=True, shuffle=True)


def _evaluate(model, loader, ways, shot, query, device):
    """Return (mean loss, mean accuracy) of *model* over every episode in *loader*."""
    model.eval()
    total_loss, total_acc, n_tasks = 0.0, 0.0, 0
    with torch.no_grad():  # evaluation only: no autograd graph is needed
        for batch in loader:
            loss, acc = fast_adapt(model, batch, ways, shot, query,
                                   metric=pairwise_distances_logits, device=device)
            total_loss += loss.item()
            total_acc += acc.item()
            n_tasks += 1
    if n_tasks == 0:
        raise ValueError("loader yielded no episodes")
    return total_loss / n_tasks, total_acc / n_tasks


# None of the data below depends on the subset index, so load it once.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transf = transforms.Compose([transforms.ToTensor(), normalize])

cifar_data_val = datasets.CIFAR10(root='datasets', train=True, transform=transf, download=True)
cifar_data_test = cifar_data_val  # Test set I comes from the same CIFAR-10 training split (disjoint indices)
final_test_data = datasets.CIFAR10(root='datasets', train=False, transform=transf, download=True)  # Used to construct test set II
train_targets = np.array(cifar_data_val.targets)
final_test_targets = np.array(final_test_data.targets)

# Meta-train on CIFAR-FS with the SAME normalization as the CIFAR-10 evaluation
# data, so the embedding is trained and evaluated on identically preprocessed inputs.
cifar_fs_train = l2l.vision.datasets.CIFARFS(root='./data', mode='train', download=True, transform=transf)
train_dataset, train_loader = _episode_loader(cifar_fs_train, train_way, train_shot, train_query)

for subset in range(n_subsets):
    prng = RandomState(subset)
    random_permute = prng.permutation(np.arange(0, 5000))
    indx_val = np.concatenate([np.where(train_targets == c)[0][random_permute[0:10]] for c in range(10)])
    indx_tst = np.concatenate([np.where(train_targets == c)[0][random_permute[10:210]] for c in range(10)])
    random_permute = prng.permutation(np.arange(0, 900))
    tst = np.concatenate([np.where(final_test_targets == c)[0][random_permute[0:200]] for c in range(10)])

    # Validation: 100 samples from CIFAR-10 training data
    valid_dataset, valid_loader = _episode_loader(
        Subset(cifar_data_val, indx_val), valid_way, valid_shot, valid_query, n_eval_tasks)
    # Test set I: 2000 samples from CIFAR-10 training data
    test_dataset, test_loader = _episode_loader(
        Subset(cifar_data_test, indx_tst), test_way, test_shot, test_query, n_eval_tasks)
    # Test set II: 2000 samples from CIFAR-10 testing data
    final_test_dataset, final_test_loader = _episode_loader(
        Subset(final_test_data, tst), test_way, test_shot, test_query, n_eval_tasks)

    # The model is constructed on the CPU, so seeding only CUDA would not make
    # its initialization reproducible; torch.manual_seed seeds CPU and all CUDA
    # devices. learn2learn's episode sampling is assumed to use Python's
    # `random` module -- TODO confirm.
    torch.manual_seed(43)
    random.seed(43 + subset)

    ################################# Training ###################################

    model = Convnet()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # NOTE: with step_size=20 and 20 epochs, the halved learning rate never takes effect.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=20, gamma=0.5)
    train_losses = []  # per-epoch mean values (plain floats, no autograd graph kept)
    train_acc = []
    valid_losses = []
    valid_acc = []

    for epoch in range(1, n_train_epochs + 1):
        model.train()
        epoch_loss, epoch_acc = 0.0, 0.0
        for _ in range(n_train_tasks):
            batch = next(iter(train_loader))
            loss, acc = fast_adapt(model, batch, train_way, train_shot, train_query,
                                   metric=pairwise_distances_logits, device=device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        lr_scheduler.step()
        train_losses.append(epoch_loss / n_train_tasks)
        train_acc.append(epoch_acc / n_train_tasks)
        print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(
            epoch, train_losses[-1], train_acc[-1]))

        val_loss, val_acc = _evaluate(model, valid_loader, valid_way, valid_shot, valid_query, device)
        valid_losses.append(val_loss)
        valid_acc.append(val_acc)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(
            epoch, val_loss, val_acc))

    # NOTE: each test run scores n_eval_tasks * test_way * test_query query images;
    # the "x/2000" figure is that accuracy rescaled to the 2000-image set.
    _, test_acc = _evaluate(model, test_loader, test_way, test_shot, test_query, device)
    print(f"The testing accuracy using test set I (2000 CIFAR-10 training samples) is: {round(test_acc * 2000)}/2000 {test_acc * 100:.2f}% ")
    accs.append(test_acc * 100)

    _, final_acc = _evaluate(model, final_test_loader, test_way, test_shot, test_query, device)
    print(f"The testing accuracy using test set II (2000 CIFAR-10 testing samples) is: {round(final_acc * 2000)}/2000 {final_acc * 100:.2f}% ")
    final_accs.append(final_acc * 100)

accs = np.array(accs)
print(f"Acc over {n_subsets} instances of test set I is: {accs.mean():.3f} +- {accs.std():.3f}")
final_accs = np.array(final_accs)
print(f"Acc over {n_subsets} instances of test set II is: {final_accs.mean():.3f} +- {final_accs.std():.3f}")

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
epoch 1, train, loss=2.6878 acc=0.2467
epoch 1, val, loss=4.0647 acc=0.2920
epoch 2, train, loss=2.1222 acc=0.2432
epoch 2, val, loss=3.2027 acc=0.2960
epoch 3, train, loss=2.0954 acc=0.2477
epoch 3, val, loss=3.2423 acc=0.3160
epoch 4, train, loss=2.0465 acc=0.2785
epoch 4, val, loss=2.7669 acc=0.2920
epoch 5, train, loss=2.0617 acc=0.2676
epoch 5, val, loss=2.6553 acc=0.3120
epoch 6, train, loss=2.0538 acc=0.2600
epoch 6, val, loss=2.6660 acc=0.3120
epoch 7, train, loss=2.0314 acc=0.2777
epoch 7, val, loss=2.0929 acc=0.3360
epoch 8, train, loss=1.9750 acc=0.3018
epoch 8, val, loss=2.3334 acc=0.3480
epoch 9, train, loss=1.9762 acc=0.2992
epoch 9, val, loss=2.3298 acc=0.3440
epoch 10, train, loss=1.9597 acc=0.3032
epoch 10, val, loss=2.1469 acc=0.3720
epoch 11, train, loss=1.9751 acc=0.3057
epoch 11, val, loss=2.1219 acc=0.3360
epoch 12, train, loss=1.9436 acc=0.3097
epoch 

# Calculating testing accuracy using a similar approach to the "fast_adapt" function (Sanity check suggested by Dr. Eugene)

In this section we try to calculate the testing accuracy using the following approach.
  
  1- First, embed the 100 samples from the CIFAR-10 training data (the validation subset) with the trained embedding network.
 
  2- Then compute each class's prototype by averaging the embeddings of that class's images.
 
  3- Measure the distance between each embedded test image and the 10 prototypes, and assign the label of the closest prototype to the test image.


  Note: This section is only a sanity check; it was run for only one experiment (10-way 1-shot) and only one subset of test set I.

In [None]:
# dataloaders[0]: the 100-image validation subset as a single batch (used to build the prototypes).
# dataloaders[1]: test set I, one image at a time, in a fixed order.
prototype_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=100, shuffle=True)
eval_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
dataloaders = (prototype_loader, eval_loader)

In [None]:
model.eval()
data_iter = iter(dataloaders[0])
# The first loader yields the whole 100-image validation subset as one batch.
data, target = next(data_iter)
data = data.to(device)
target = target.to(device)
with torch.no_grad():
    embeddings = model(data)  # One embedding row per CIFAR-10 validation image

In [None]:
# One prototype per CIFAR-10 class: the mean embedding of that class's validation images.
# The embedding size is taken from `embeddings` instead of being hard-coded to 256,
# and the result is kept on the CPU because the next cell evaluates on the CPU.
n_classes = 10
prototypes = torch.stack(
    [embeddings[target == c].mean(dim=0) for c in range(n_classes)]
).to(device="cpu", dtype=torch.float32)
  

In [None]:
# Sanity check: nearest-prototype classification of test set I, one image at a time, on the CPU.
device = "cpu"
model.to(device)
model.eval()
prototypes = prototypes.to(device)  # Tensor.to() is not in-place; reassign the result
correct = 0
with torch.no_grad():
    for x, y in dataloaders[1]:
        x, y = x.to(device), y.to(device)
        # Euclidean distance between the embedded test image and each of the 10 prototypes.
        dist = torch.norm(prototypes - model(x), dim=1)
        prediction = dist.argmin()  # Label of the closest prototype
        correct += prediction.eq(y.view_as(prediction)).sum().item()
print(f"The accuracy is {correct*100/len(dataloaders[1]):.2f}%")

The accuracy is 28.30%
