In [1]:
# mounting google drive to colab
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

In [3]:
import easydict

args = easydict.EasyDict({
    'root': 'data', # path to dataset
    'results_root': 'results', # root to store models, loss and accuracies
    'epochs': 20,
    'learning_rate': 0.001,
    'lr_scheduler_step': 20,
    'lr_scheduler_gamma': 0.5,
    'iterations': 100,
    'Nc_train': 60,
    'Ns_train': 5,
    'Nq_train': 5,
    'Nc_test': 5,
    'Ns_test': 5,
    'Nq_test': 15,
    'manual_seed': 423,
    'use_cuda': True
})


device = torch.device('cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu')

In [4]:
device

device(type='cuda')

# Seed

In [5]:
torch.cuda.cudnn_enabled = False

seed = args.manual_seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [6]:
if not os.path.exists(args.results_root):
  os.makedirs(args.results_root)
if not os.path.exists(args.root):
  os.makedirs(args.root)

# Dataset (Omniglot)

In [7]:
from torchvision.datasets import Omniglot
import torchvision.transforms as transforms

transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Resize(28)
                                ])

train_val_dataset = Omniglot(root=args.root+'/train', background=True, transform=transform, download=True)
test_dataset = Omniglot(root=args.root+'/test', background=False, transform=transform, download=True)

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to data/train/omniglot-py/images_background.zip


  0%|          | 0/9464212 [00:00<?, ?it/s]

Extracting data/train/omniglot-py/images_background.zip to data/train/omniglot-py
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to data/test/omniglot-py/images_evaluation.zip


  0%|          | 0/6462886 [00:00<?, ?it/s]

Extracting data/test/omniglot-py/images_evaluation.zip to data/test/omniglot-py


In [8]:
train_val_dataset[0][1]

0

In [9]:
# get labels for dataset
def get_labels(dataset):
  labels = list()
  for i in range(len(dataset)):
    y = dataset[i][1]
    labels.append(y)
  return labels

In [10]:
train_val_dataset_labels = get_labels(train_val_dataset)

In [69]:
from sklearn.model_selection import train_test_split
train_indices, val_indices = train_test_split((range(len(train_val_dataset_labels))), test_size=0.25, shuffle=False)

In [75]:
train_dataset = torch.utils.data.Subset(train_val_dataset, train_indices)
val_dataset = torch.utils.data.Subset(train_val_dataset, val_indices)

In [76]:
train_dataset_labels = get_labels(train_dataset)
val_dataset_labels = get_labels(val_dataset)
test_dataset_labels = get_labels(dataset=test_dataset)

# BatchSampler

In [None]:
# coding=utf-8
import numpy as np
import torch


class PrototypicalBatchSampler(object):

    def __init__(self, labels, Nc, num_samples, iterations):
        super(PrototypicalBatchSampler, self).__init__()
        self.labels = labels
        self.Nc = Nc
        self.num_samples = num_samples
        self.iterations = iterations

        self.classes, self.counts = np.unique(self.labels, return_counts=True)
        self.classes = torch.LongTensor(self.classes)

        self.index = range(len(self.labels))
        self.indices_by_class = np.empty((len(self.classes), max(self.counts)), dtype=int) * np.nan
        self.indices_by_class = torch.Tensor(self.indices_by_class)
        self.numel_per_class = torch.zeros_like(self.classes)
        for idx, label in enumerate(self.labels):
            class_idx = np.argwhere(self.classes == label).item()
            self.indices_by_class[class_idx, np.where(np.isnan(self.indices_by_class[class_idx]))[0][0]] = idx
            self.numel_per_class[class_idx] += 1

    def __iter__(self):
        for it in range(self.iterations):
            batch_size = self.num_samples * self.Nc
            batch = torch.LongTensor(batch_size)
            selected_class_indices = torch.randperm(len(self.classes))[:self.Nc]
            for i, c in enumerate(self.classes[selected_class_indices]):
                batch_indices = slice(i * self.num_samples, (i + 1) * self.num_samples)
                label_idx = torch.arange(len(self.classes)).long()[self.classes == c].item()
                sample_indices = torch.randperm(self.numel_per_class[label_idx])[:self.num_samples]
                batch[batch_indices] = self.indices_by_class[label_idx][sample_indices]
            batch = batch[torch.randperm(len(batch))]
            yield batch

    def __len__(self):
        return self.iterations

In [None]:
Nc_train = args.Nc_train
num_samples_train = args.Ns_train + args.Nq_train

Nc_test = args.Nc_test
num_samples_test = args.Ns_test + args.Nq_test

In [80]:
train_sampler = PrototypicalBatchSampler(labels=train_dataset_labels,
                                         Nc=Nc_train, 
                                         num_samples=num_samples_train, 
                                         iterations=args.iterations)
val_sampler = PrototypicalBatchSampler(labels=val_dataset_labels,
                                       Nc=Nc_test,
                                       num_samples=num_samples_test,
                                       iterations=args.iterations)
test_sampler = PrototypicalBatchSampler(labels=test_dataset_labels, 
                                        Nc=Nc_test, 
                                        num_samples=num_samples_test,
                                        iterations=args.iterations)

# DataLoader

In [81]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_sampler=train_sampler)

val_dataloader = DataLoader(dataset=val_dataset,
                            batch_sampler=val_sampler)

test_dataloader = DataLoader(dataset=test_dataset,
                             batch_sampler=test_sampler)

# PrototypicalNet architecture

In [82]:
import torch.nn as nn

def conv_block(in_channels, out_channels):
  return nn.Sequential(
      nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(),
      nn.MaxPool2d(2)
  )

In [83]:
class ProtoNet(nn.Module):
  def __init__(self, in_channels=1, out_channels=64):
    super(ProtoNet, self).__init__()
    self.embedding = nn.Sequential(
        conv_block(in_channels, out_channels),
        conv_block(out_channels, out_channels),
        conv_block(out_channels, out_channels),
        conv_block(out_channels, out_channels)
    )

  def forward(self, x):
    out = self.embedding(x)
    out = out.view(x.size(0), -1)
    return out

In [84]:
model = ProtoNet()
model.to(device)

ProtoNet(
  (embedding): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), 

# Optimizer & Scheduler

In [85]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.learning_rate)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                            gamma=args.lr_scheduler_gamma,
                                            step_size=args.lr_scheduler_step)

# Loss function

In [86]:
import torch
from torch.nn import functional as F

def euclidean_dist(x, y):
    # x: N x D
    # y: M x D
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    if x.size(1) != y.size(1):
        raise Exception

    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)

    return torch.pow(x - y, 2).sum(2)


def prototypical_loss(input, target, n_support):
    
    target = target.to('cpu')
    input = input.to('cpu')

    def supp_idxs(c):
        return target.eq(c).nonzero()[:n_support].squeeze(1)

    classes = torch.unique(target)
    n_classes = len(classes)
    n_query = target.eq(classes[0].item()).sum().item() - n_support

    support_idxs = list(map(supp_idxs, classes))

    prototypes = torch.stack([input[idx_list].mean(0) for idx_list in support_idxs])
    query_idxs = torch.stack(list(map(lambda c: target.eq(c).nonzero()[n_support:], classes))).view(-1)

    query_samples = input[query_idxs]
    # print("query_sampeles: ", query_samples)
    # print("prototypes: ", prototypes)
    dists = euclidean_dist(query_samples, prototypes)

    log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)

    target_inds = torch.arange(0, n_classes)
    target_inds = target_inds.view(n_classes, 1, 1)
    target_inds = target_inds.expand(n_classes, n_query, 1).long()

    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
    _, y_hat = log_p_y.max(2)
    acc_val = y_hat.eq(target_inds.squeeze()).float().mean()

    return loss_val,  acc_val

# Train

In [87]:
def train(args, train_dataloader, model, optimizer, lr_scheduler, device):
  train_loss = []
  train_acc = []
  val_loss = []
  val_acc = []
  best_acc = 0
  best_state = None

  best_model_path = os.path.join(args.results_root, 'best_model.pth')
  last_model_path = os.path.join(args.results_root, 'last_model.pth')

  for epoch in range(args.epochs):
    print("***** Epoch:{} *****".format(epoch+1))

    # train
    train_iterator = iter(train_dataloader)
    model.train()
    for batch in tqdm(train_iterator):
      input, target = batch
      input, target = input.to(device), target.to(device)
      output = model(input)
      loss, acc = prototypical_loss(output, target=target, n_support=args.Ns_train)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      train_loss.append(loss.item())
      train_acc.append(acc.item())
    epoch_avg_train_loss = np.mean(train_loss[-args.iterations:])
    epoch_avg_train_acc = np.mean(train_acc[-args.iterations:])
    print("Epoch: {} / Avg Train Loss: {} / Avg Train Accuracy: {}".format(epoch+1, epoch_avg_train_loss, epoch_avg_train_acc))
    lr_scheduler.step()

    # validation
    val_iterator = iter(val_dataloader)
    model.eval()
    for batch in tqdm(val_iterator):
      input, target = batch
      input, target = input.to(device), target.to(device)
      output = model(input)
      loss, acc = prototypical_loss(output, target=target, n_support=args.Ns_test)
      val_loss.append(loss.item())
      val_acc.append(acc.item())
    epoch_avg_val_loss = np.mean(val_loss[-args.iterations:])
    epoch_avg_val_acc = np.mean(val_acc[-args.iterations:])
    print("Epoch: {} / Avg Val Loss: {} / Avg Val Accuracy: {}".format(epoch+1, epoch_avg_val_loss, epoch_avg_val_acc))
    if epoch_avg_val_acc > best_acc:
      best_acc = epoch_avg_val_acc
      best_state = model.state_dict()
      torch.save(model.state_dict(), best_model_path)

  torch.save(model.state_dict(), last_model_path)

  return best_state, best_acc, train_loss, train_acc, val_loss, val_acc

In [90]:
 best_state, best_acc, train_loss, train_acc, val_loss, val_acc = train(args=args, train_dataloader=train_dataloader, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, device=device)

***** Epoch:1 *****


100%|██████████| 100/100 [00:33<00:00,  3.03it/s]


Epoch: 1 / Avg Train Loss: 0.0351407729415223 / Avg Train Accuracy: 0.9864333426952362


100%|██████████| 100/100 [00:05<00:00, 19.31it/s]


Epoch: 1 / Avg Val Loss: 0.023662043417086453 / Avg Val Accuracy: 0.9930666702985763
***** Epoch:2 *****


100%|██████████| 100/100 [00:34<00:00,  2.87it/s]


Epoch: 2 / Avg Train Loss: 0.027434692538809032 / Avg Train Accuracy: 0.9895666742324829


100%|██████████| 100/100 [00:05<00:00, 19.15it/s]


Epoch: 2 / Avg Val Loss: 0.018043732257260157 / Avg Val Accuracy: 0.9950666701793671
***** Epoch:3 *****


100%|██████████| 100/100 [00:32<00:00,  3.12it/s]


Epoch: 3 / Avg Train Loss: 0.025317456219345332 / Avg Train Accuracy: 0.9905000078678131


100%|██████████| 100/100 [00:05<00:00, 18.72it/s]


Epoch: 3 / Avg Val Loss: 0.017625825217099936 / Avg Val Accuracy: 0.9948000019788742
***** Epoch:4 *****


100%|██████████| 100/100 [00:36<00:00,  2.73it/s]


Epoch: 4 / Avg Train Loss: 0.02572942486847751 / Avg Train Accuracy: 0.9901000088453293


100%|██████████| 100/100 [00:06<00:00, 14.58it/s]


Epoch: 4 / Avg Val Loss: 0.026342027434074666 / Avg Val Accuracy: 0.9940000033378601
***** Epoch:5 *****


100%|██████████| 100/100 [00:34<00:00,  2.90it/s]


Epoch: 5 / Avg Train Loss: 0.021920030976179988 / Avg Train Accuracy: 0.9909000068902969


100%|██████████| 100/100 [00:05<00:00, 19.14it/s]


Epoch: 5 / Avg Val Loss: 0.021986272270254827 / Avg Val Accuracy: 0.9945333343744278
***** Epoch:6 *****


100%|██████████| 100/100 [00:34<00:00,  2.87it/s]


Epoch: 6 / Avg Train Loss: 0.02487734501133673 / Avg Train Accuracy: 0.9904666739702225


100%|██████████| 100/100 [00:05<00:00, 18.81it/s]


Epoch: 6 / Avg Val Loss: 0.02764376111201416 / Avg Val Accuracy: 0.9913333356380463
***** Epoch:7 *****


100%|██████████| 100/100 [00:32<00:00,  3.11it/s]


Epoch: 7 / Avg Train Loss: 0.023422658324707298 / Avg Train Accuracy: 0.9916000068187714


100%|██████████| 100/100 [00:05<00:00, 19.05it/s]


Epoch: 7 / Avg Val Loss: 0.019386253766739223 / Avg Val Accuracy: 0.9945333367586136
***** Epoch:8 *****


100%|██████████| 100/100 [00:31<00:00,  3.13it/s]


Epoch: 8 / Avg Train Loss: 0.020886421297909692 / Avg Train Accuracy: 0.9920000076293946


100%|██████████| 100/100 [00:05<00:00, 17.58it/s]


Epoch: 8 / Avg Val Loss: 0.023665253678305086 / Avg Val Accuracy: 0.9942666697502136
***** Epoch:9 *****


100%|██████████| 100/100 [00:34<00:00,  2.90it/s]


Epoch: 9 / Avg Train Loss: 0.022265104046673513 / Avg Train Accuracy: 0.9913333410024643


100%|██████████| 100/100 [00:05<00:00, 18.91it/s]


Epoch: 9 / Avg Val Loss: 0.03633340789250029 / Avg Val Accuracy: 0.9929333364963532
***** Epoch:10 *****


100%|██████████| 100/100 [00:34<00:00,  2.91it/s]


Epoch: 10 / Avg Train Loss: 0.025901515031000598 / Avg Train Accuracy: 0.9900000077486039


100%|██████████| 100/100 [00:05<00:00, 19.33it/s]


Epoch: 10 / Avg Val Loss: 0.02704202052288511 / Avg Val Accuracy: 0.9941333365440369
***** Epoch:11 *****


100%|██████████| 100/100 [00:31<00:00,  3.13it/s]


Epoch: 11 / Avg Train Loss: 0.024657122346106917 / Avg Train Accuracy: 0.9907000064849854


100%|██████████| 100/100 [00:05<00:00, 19.27it/s]


Epoch: 11 / Avg Val Loss: 0.04106119146645877 / Avg Val Accuracy: 0.9926666712760925
***** Epoch:12 *****


100%|██████████| 100/100 [00:31<00:00,  3.13it/s]


Epoch: 12 / Avg Train Loss: 0.023599828825099395 / Avg Train Accuracy: 0.9911333411931992


100%|██████████| 100/100 [00:05<00:00, 19.46it/s]


Epoch: 12 / Avg Val Loss: 0.025247796339148833 / Avg Val Accuracy: 0.9941333371400833
***** Epoch:13 *****


100%|██████████| 100/100 [00:31<00:00,  3.20it/s]


Epoch: 13 / Avg Train Loss: 0.02305776538909413 / Avg Train Accuracy: 0.9913333415985107


100%|██████████| 100/100 [00:05<00:00, 19.51it/s]


Epoch: 13 / Avg Val Loss: 0.030215339121524708 / Avg Val Accuracy: 0.9908000028133392
***** Epoch:14 *****


100%|██████████| 100/100 [00:31<00:00,  3.19it/s]


Epoch: 14 / Avg Train Loss: 0.023088916002307086 / Avg Train Accuracy: 0.9913000082969665


100%|██████████| 100/100 [00:05<00:00, 19.32it/s]


Epoch: 14 / Avg Val Loss: 0.01495821106342046 / Avg Val Accuracy: 0.9949333363771439
***** Epoch:15 *****


100%|██████████| 100/100 [00:31<00:00,  3.17it/s]


Epoch: 15 / Avg Train Loss: 0.02035960180917755 / Avg Train Accuracy: 0.9927666735649109


100%|██████████| 100/100 [00:05<00:00, 19.49it/s]


Epoch: 15 / Avg Val Loss: 0.03371865359893702 / Avg Val Accuracy: 0.9914666712284088
***** Epoch:16 *****


100%|██████████| 100/100 [00:31<00:00,  3.16it/s]


Epoch: 16 / Avg Train Loss: 0.023066647408995777 / Avg Train Accuracy: 0.9910666733980179


100%|██████████| 100/100 [00:05<00:00, 18.44it/s]


Epoch: 16 / Avg Val Loss: 0.03391813380694318 / Avg Val Accuracy: 0.992000002861023
***** Epoch:17 *****


100%|██████████| 100/100 [00:32<00:00,  3.12it/s]


Epoch: 17 / Avg Train Loss: 0.022763706827536225 / Avg Train Accuracy: 0.991200008392334


100%|██████████| 100/100 [00:05<00:00, 18.66it/s]


Epoch: 17 / Avg Val Loss: 0.021377006456449913 / Avg Val Accuracy: 0.9944000029563904
***** Epoch:18 *****


100%|██████████| 100/100 [00:31<00:00,  3.14it/s]


Epoch: 18 / Avg Train Loss: 0.024983593514189124 / Avg Train Accuracy: 0.9907000082731247


100%|██████████| 100/100 [00:05<00:00, 19.11it/s]


Epoch: 18 / Avg Val Loss: 0.03489652028162192 / Avg Val Accuracy: 0.9924000054597855
***** Epoch:19 *****


100%|██████████| 100/100 [00:31<00:00,  3.16it/s]


Epoch: 19 / Avg Train Loss: 0.021078624428482726 / Avg Train Accuracy: 0.9914666736125946


100%|██████████| 100/100 [00:05<00:00, 19.18it/s]


Epoch: 19 / Avg Val Loss: 0.04063123741452464 / Avg Val Accuracy: 0.9921333372592926
***** Epoch:20 *****


100%|██████████| 100/100 [00:31<00:00,  3.15it/s]


Epoch: 20 / Avg Train Loss: 0.01690622205613181 / Avg Train Accuracy: 0.9938000059127807


100%|██████████| 100/100 [00:05<00:00, 19.13it/s]

Epoch: 20 / Avg Val Loss: 0.0208341412601557 / Avg Val Accuracy: 0.9948000025749206





# Test

In [91]:
def test(args, test_dataloader, model, device):
  for epoch in range(5):
    test_loss = list()
    test_acc = list()
    test_iterator = iter(test_dataloader)
    for batch in test_iterator:
        input, target = batch
        input, target = input.to(device), target.to(device)
        output = model(input)
        loss, acc = prototypical_loss(output, target=target, n_support=args.Ns_test)
        test_loss.append(loss.item())
        test_acc.append(acc.item())
    epoch_avg_test_loss = np.mean(test_loss)
    epoch_avg_test_acc = np.mean(test_acc)
    print("Test Epoch: {} / Avg Test Loss: {} / Avg Test Accuracy: {}".format(epoch+1, epoch_avg_test_loss, epoch_avg_test_acc))

In [92]:
print('===== Testing with last model =====')
test(args=args, test_dataloader=test_dataloader, model=model, device=device)

model.load_state_dict(best_state)
print('===== Testing with best model =====')
test(args=args, test_dataloader=test_dataloader, model=model, device=device)

===== Testing with last model =====
Test Epoch: 1 / Avg Test Loss: 0.04859817686585927 / Avg Test Accuracy: 0.9912000054121017
Test Epoch: 2 / Avg Test Loss: 0.03093400157634363 / Avg Test Accuracy: 0.9928000050783158
Test Epoch: 3 / Avg Test Loss: 0.03593346374054543 / Avg Test Accuracy: 0.9920000064373017
Test Epoch: 4 / Avg Test Loss: 0.04853993382797693 / Avg Test Accuracy: 0.990533338189125
Test Epoch: 5 / Avg Test Loss: 0.03914078666101156 / Avg Test Accuracy: 0.9910666728019715
===== Testing with best model =====
Test Epoch: 1 / Avg Test Loss: 0.05165282077998984 / Avg Test Accuracy: 0.9925333392620087
Test Epoch: 2 / Avg Test Loss: 0.043470068675542844 / Avg Test Accuracy: 0.9922666716575622
Test Epoch: 3 / Avg Test Loss: 0.07363565162398611 / Avg Test Accuracy: 0.9898666727542877
Test Epoch: 4 / Avg Test Loss: 0.04484728974145563 / Avg Test Accuracy: 0.9906666713953018
Test Epoch: 5 / Avg Test Loss: 0.04123870913677613 / Avg Test Accuracy: 0.992000002861023
