In [None]:
# Mount Google Drive into the Colab VM's filesystem.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/drive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import copy
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
import easydict

args = easydict.EasyDict({
    # --- paths ---
    'root': 'data',             # Omniglot download location
    'results_root': 'results',  # output directory (best_model.pth is saved here)
    # --- optimisation ---
    'epochs': 20,
    'learning_rate': 0.001,
    # StepLR multiplies the lr by gamma every `lr_scheduler_step` scheduler steps;
    # train() steps the scheduler once per epoch (one epoch = `iterations` episodes).
    'lr_scheduler_step': 20,
    'lr_scheduler_gamma': 0.5,
    'iterations': 100,          # episodes per epoch
    # --- episode shape: Nc classes x (Ns support + Nq query) samples per class ---
    'Nc_train': 60,
    'Ns_train': 5,
    'Nq_train': 5,
    'Nc_test': 5,
    'Ns_test': 5,
    'Nq_test': 15,
    # --- reproducibility / hardware ---
    'manual_seed': 423,
    'use_cuda': True,
})

use_gpu = args.use_cuda and torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')

In [None]:
# Show which device (CPU or CUDA) the model and episode batches will be moved to.
device

device(type='cuda')

# Seed

In [None]:
# Make runs repeatable: pin cuDNN to deterministic kernels and seed every RNG used here.
# Note: `torch.cuda.cudnn_enabled` is not a PyTorch flag (assigning it has no effect);
# the cuDNN switches live under `torch.backends.cudnn`.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick different kernels per run

seed = args.manual_seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # every visible GPU, not just the current one

In [None]:
# Create the results and data directories. exist_ok=True makes this idempotent on
# re-run and avoids the check-then-create race of an os.path.exists() guard.
for directory in (args.results_root, args.root):
  os.makedirs(directory, exist_ok=True)

# Dataset

In [None]:
from torchvision.datasets import Omniglot
import torchvision.transforms as transforms

# PIL image -> float tensor, then resize so the shorter side is 28 px.
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(28)])

# background=True downloads images_background (used for train/val);
# background=False downloads images_evaluation (held out for testing).
train_val_dataset = Omniglot(root=f"{args.root}/train", background=True, transform=transform, download=True)
test_dataset = Omniglot(root=f"{args.root}/test", background=False, transform=transform, download=True)

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to data/train/omniglot-py/images_background.zip


  0%|          | 0/9464212 [00:00<?, ?it/s]

Extracting data/train/omniglot-py/images_background.zip to data/train/omniglot-py
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to data/test/omniglot-py/images_evaluation.zip


  0%|          | 0/6462886 [00:00<?, ?it/s]

Extracting data/test/omniglot-py/images_evaluation.zip to data/test/omniglot-py


In [None]:
# Check
_first_image, first_label = train_val_dataset[0]
first_label

0

In [None]:
# Get labels for dataset
def get_labels(dataset):
  """Return the label of every sample in `dataset`, in index order.

  Each item is expected to be an (input, label) pair. Note that this indexes
  every sample once, so for image datasets each image is loaded and transformed.
  """
  return [dataset[idx][1] for idx in range(len(dataset))]

In [None]:
# One label per background-set sample; used below to build the train/validation split.
train_val_dataset_labels = get_labels(dataset=train_val_dataset)

In [None]:
from sklearn.model_selection import train_test_split

# Split the background set by *class*, not by sample: the last 25% of classes become
# the validation set. Train and validation therefore never share a class, and every
# validation class keeps all of its samples (the episode sampler needs
# Ns_test + Nq_test of them per class).
VAL_CLASS_FRACTION = 0.25

label_array = np.asarray(train_val_dataset_labels)
class_ids = np.unique(label_array)  # sorted ascending
n_val_classes = int(np.ceil(VAL_CLASS_FRACTION * len(class_ids)))
assert 0 < n_val_classes < len(class_ids), "need at least one train and one validation class"

is_val = np.isin(label_array, class_ids[-n_val_classes:])
train_indices = np.flatnonzero(~is_val).tolist()
val_indices = np.flatnonzero(is_val).tolist()

train_dataset = torch.utils.data.Subset(train_val_dataset, train_indices)
val_dataset = torch.utils.data.Subset(train_val_dataset, val_indices)

In [None]:
# A Subset's labels are just the parent's labels at the subset's indices, so reuse the
# background-set labels instead of loading and transforming every image a second time.
train_dataset_labels = [train_val_dataset_labels[i] for i in train_dataset.indices]
val_dataset_labels = [train_val_dataset_labels[i] for i in val_dataset.indices]
test_dataset_labels = get_labels(test_dataset)

# BatchSampler

In [None]:
import numpy as np
import torch

class ProtoBatchSampler(object):
  """Batch sampler that yields one few-shot episode per iteration.

  Each batch holds `num_samples` distinct dataset indices for each of `Nc`
  distinct, randomly chosen classes (Nc * num_samples indices in total), shuffled
  together. Meant for `DataLoader(batch_sampler=...)`; the yielded indices are
  positions in `labels`.
  """

  def __init__(self, labels, Nc, num_samples, iterations):
    """
    labels: dataset labels, one per sample (list position = dataset index)
    Nc: number of classes per episode
    num_samples: samples per class per episode (Ns + Nq)
    iterations: number of episodes (batches) per epoch

    Raises ValueError if `labels` is empty, if Nc exceeds the number of distinct
    classes, or if any class has fewer than `num_samples` samples (it could not
    fill its slots with distinct samples).
    """
    super(ProtoBatchSampler, self).__init__()
    self.labels = labels
    self.Nc = Nc
    self.num_samples = num_samples
    self.iterations = iterations

    label_array = np.asarray(self.labels)
    if label_array.size == 0:
      raise ValueError("labels must not be empty")
    classes, class_rows, self.counts_per_class = np.unique(
        label_array, return_inverse=True, return_counts=True)
    self.classes = torch.LongTensor(classes)

    if Nc > len(classes):
      raise ValueError(f"Nc={Nc} exceeds the {len(classes)} available classes")
    if self.counts_per_class.min() < num_samples:
      raise ValueError(
          f"every class needs at least num_samples={num_samples} samples; "
          f"the smallest class has {self.counts_per_class.min()}")

    # Row r lists, in ascending order, the dataset indices whose label is classes[r];
    # shorter rows are NaN-padded. Built in one pass with a per-row fill pointer.
    # float64 keeps indices exact (NaN padding requires a float dtype).
    table = np.full((len(classes), int(self.counts_per_class.max())), np.nan)
    next_slot = np.zeros(len(classes), dtype=np.int64)
    for data_idx, row_idx in enumerate(class_rows):
      table[row_idx, next_slot[row_idx]] = data_idx
      next_slot[row_idx] += 1
    self.indices_by_class = torch.from_numpy(table)

  def __iter__(self):
    """Yield `iterations` shuffled LongTensors of Nc * num_samples dataset indices."""
    batch_size = self.Nc * self.num_samples
    for _ in range(self.iterations):
      batch = torch.empty(batch_size, dtype=torch.long)
      selected_rows = torch.randperm(len(self.classes))[:self.Nc]
      for i, row_idx in enumerate(selected_rows.tolist()):
        # num_samples distinct members of this class (sizes validated in __init__).
        sample_indices = torch.randperm(int(self.counts_per_class[row_idx]))[:self.num_samples]
        batch[i * self.num_samples:(i + 1) * self.num_samples] = \
            self.indices_by_class[row_idx][sample_indices].long()
      # Shuffle the whole episode so classes are interleaved.
      yield batch[torch.randperm(batch_size)]

  def __len__(self):
    return self.iterations

In [None]:
# Training episodes: Nc_train classes x (Ns_train + Nq_train) samples each.
train_sampler = ProtoBatchSampler(
    labels=train_dataset_labels,
    Nc=args.Nc_train,
    num_samples=args.Ns_train + args.Nq_train,
    iterations=args.iterations,
)

# Validation and test share the same evaluation episode shape.
eval_episode_kwargs = dict(Nc=args.Nc_test,
                           num_samples=args.Ns_test + args.Nq_test,
                           iterations=args.iterations)
val_sampler = ProtoBatchSampler(labels=val_dataset_labels, **eval_episode_kwargs)
test_sampler = ProtoBatchSampler(labels=test_dataset_labels, **eval_episode_kwargs)

# DataLoader

In [None]:
from torch.utils.data import DataLoader

# One DataLoader per split. Each episode sampler is passed as `batch_sampler`, so every
# batch is one full episode (batch_sampler is mutually exclusive with batch_size/shuffle).
split_pairs = ((train_dataset, train_sampler),
               (val_dataset, val_sampler),
               (test_dataset, test_sampler))
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(dataset=split_dataset, batch_sampler=split_sampler)
    for split_dataset, split_sampler in split_pairs
)

# Protonet architecture

In [None]:
import torch.nn as nn

def conv_block(in_channels, out_channels):
  """3x3 conv (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool.

  Keeps the spatial size through the conv, then the pool floor-halves H and W.
  """
  layers = [
      nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(),
      nn.MaxPool2d(2),
  ]
  return nn.Sequential(*layers)

class Protonet(nn.Module):
  """Embedding network: four stacked conv_blocks, then flatten to one vector per input.

  in_channels: channels of the input images.
  out_channels: channels produced by every conv_block.
  """

  def __init__(self, in_channels=1, out_channels=64):
    super().__init__()
    block_inputs = [in_channels] + [out_channels] * 3
    self.embedding = nn.Sequential(*(conv_block(c_in, out_channels) for c_in in block_inputs))

  def forward(self, x):
    """Embed a batch x -> [batch, flattened features]."""
    features = self.embedding(x)
    return features.view(features.size(0), -1)

In [None]:
# Build the embedding network on the selected device; the trailing expression
# displays the architecture.
model = Protonet().to(device)
model

Protonet(
  (embedding): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), 

# Loss function

In [None]:
import torch
import torch.nn.functional as F

def euclidean_distance(x, y):
  """
  Pairwise *squared* Euclidean distance between the rows of two matrices.

  x: [N, D] (e.g. query embeddings)
  y: [M, D] (e.g. class prototypes)
  returns: [N, M] tensor with out[n, m] = ||x[n] - y[m]||^2

  Raises ValueError if an input is not 2-D or the feature sizes differ.
  """
  if x.dim() != 2 or y.dim() != 2:
    raise ValueError(f"expected 2-D inputs, got shapes {tuple(x.shape)} and {tuple(y.shape)}")
  if x.size(1) != y.size(1):
    raise ValueError(f"feature dimension mismatch: {x.size(1)} vs {y.size(1)}")
  # Broadcast [N, 1, D] against [1, M, D] -> [N, M, D], then reduce over D.
  diff = x.unsqueeze(1) - y.unsqueeze(0)
  return diff.pow(2).sum(dim=2)

def proto_loss(input, target, Ns, Nq):
  """
  Prototypical-network loss and accuracy for one episode.

  For each class in `target`, the first Ns samples (in batch order) form the support
  set whose mean embedding is the class prototype; the remaining Nq samples are
  queries, classified by a softmax over negative squared distances to the prototypes.

  input (torch.Tensor): [num_samples, embedded_dim] embeddings
  target (torch.LongTensor): [num_samples] class labels; each class present must
    appear exactly Ns + Nq times.
  Ns, Nq: support / query samples per class (both >= 1).

  Returns (loss, acc): scalar tensors on input's device -- mean query negative
  log-likelihood and mean query accuracy.
  Raises ValueError on non-positive Ns/Nq or a class without exactly Ns + Nq samples.
  """
  if Ns < 1 or Nq < 1:
    raise ValueError(f"Ns and Nq must be positive, got Ns={Ns}, Nq={Nq}")

  # Compute on the embeddings' device (avoids a GPU->CPU copy every step).
  target = target.to(input.device)

  class_indices = torch.unique(target)  # sorted; prototype row i <-> class_indices[i]
  Nc = len(class_indices)

  support_idx, query_idx = [], []
  for class_idx in class_indices:
    members = target.eq(class_idx).nonzero().squeeze(1)  # 1-D positions of this class
    if members.numel() != Ns + Nq:
      raise ValueError(
          f"class {class_idx.item()} has {members.numel()} samples, expected Ns + Nq = {Ns + Nq}")
    support_idx.append(members[:Ns])
    query_idx.append(members[Ns:])

  prototypes = torch.stack([input[idx].mean(0) for idx in support_idx])  # [Nc, D]
  query_samples = input[torch.cat(query_idx)]  # [Nc*Nq, D], grouped by class

  dists = euclidean_distance(query_samples, prototypes)  # [Nc*Nq, Nc]
  log_p = F.log_softmax(-dists, dim=1).view(Nc, Nq, -1)  # [Nc, Nq, Nc]

  # Query block i belongs to class_indices[i], i.e. prototype row i.
  target_indices = torch.arange(Nc, device=input.device).view(Nc, 1, 1).expand(Nc, Nq, 1)

  loss = -log_p.gather(2, target_indices).mean()
  _, predict = log_p.max(2)  # [Nc, Nq]
  # squeeze(2), not squeeze(): with Nq == 1 a full squeeze gives [Nc], which would
  # broadcast against predict's [Nc, 1] into a wrong [Nc, Nc] comparison.
  acc = predict.eq(target_indices.squeeze(2)).float().mean()

  return loss, acc



# Optimizer & lr_scheduler

In [None]:
# Adam over all model parameters; StepLR multiplies the learning rate by gamma
# every `lr_scheduler_step` calls to lr_scheduler.step().
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_scheduler_step,
    gamma=args.lr_scheduler_gamma,
)

# Train

In [None]:
def train(model, dataloader, epoch, log_interval=20):
  """Run one training epoch of episodic Protonet updates.

  Uses the notebook-level `optimizer`, `lr_scheduler`, `device` and
  `args.Ns_train` / `args.Nq_train`. Steps the LR scheduler once, after the epoch.

  model: embedding network (switched to train mode here).
  dataloader: yields (inputs, targets) episodes.
  epoch: epoch number, used only in progress lines.
  log_interval: print a progress line every `log_interval` iterations (iteration 0 excluded).

  Returns (mean_loss, mean_accuracy) over the epoch's episodes, like validate().
  """
  model.train()

  train_loss = []
  train_acc = []

  for it, (inputs, targets) in enumerate(dataloader):
    optimizer.zero_grad()

    inputs = inputs.to(device)
    targets = targets.to(device)

    outs = model(inputs)
    loss, acc = proto_loss(outs, targets, args.Ns_train, args.Nq_train)
    train_loss.append(loss.item())
    train_acc.append(acc.item())

    loss.backward()
    optimizer.step()

    if it % log_interval == 0 and it > 0:
      print(f"| epoch {epoch:3d} | {it:5d}/{len(dataloader):5d} iteration | loss {train_loss[-1]:8.3f} | accuracy {train_acc[-1]:8.3f}")

  lr_scheduler.step()

  # Return the epoch averages (previously computed and discarded) so callers can log them.
  train_loss_epoch = np.mean(train_loss)
  train_acc_epoch = np.mean(train_acc)
  return train_loss_epoch, train_acc_epoch


# Validate

In [None]:
def validate(model, dataloader):
  """Evaluate `model` on every episode yielded by `dataloader`.

  Runs in eval mode with gradient tracking disabled, scoring each episode with
  proto_loss using args.Ns_test / args.Nq_test on the notebook-level `device`.
  Prints a progress line every 20 iterations (iteration 0 excluded).

  Returns (mean_loss, mean_accuracy) over all episodes.
  """
  report_every = 20
  model.eval()
  episode_losses, episode_accs = [], []

  with torch.no_grad():
    for step, (inputs, targets) in enumerate(dataloader):
      inputs, targets = inputs.to(device), targets.to(device)
      loss, acc = proto_loss(model(inputs), targets, args.Ns_test, args.Nq_test)
      episode_losses.append(loss.item())
      episode_accs.append(acc.item())

      if step > 0 and step % report_every == 0:
        print(f"| {step:5d}/{len(dataloader):5d} iteration | loss {loss:8.3f} | accuracy {acc:8.3f}")

  return np.mean(episode_losses), np.mean(episode_accs)

In [None]:
# Train for args.epochs epochs, keeping the weights with the best validation accuracy
# both in memory (best_state) and on disk (results_root/best_model.pth).
best_acc = float('-inf')  # any real accuracy beats this, so epoch 1 always records a state
best_state = None
for epoch in range(1, args.epochs + 1):
  print(f"========== Epoch: {epoch} ==========")
  print("Train")
  train(model, train_dataloader, epoch)
  print("Validation")
  val_loss_epoch, val_acc_epoch = validate(model, val_dataloader)

  print('-' * 70)
  if val_acc_epoch > best_acc:
    best_acc = val_acc_epoch
    # state_dict() returns references to the live parameter/buffer tensors, which later
    # optimizer steps update in place; deep-copy so best_state is a real snapshot.
    best_state = copy.deepcopy(model.state_dict())
    print(f"best validation accuracy {best_acc:8.3f}")
    torch.save(best_state, f"{args.results_root}/best_model.pth")

  print(f"| end of epoch {epoch:3d} | best accuracy {best_acc:8.3f}")

  print('=' * 70)

Train
| epoch   1 |    20/  100 iteration | loss    0.562 | accuracy    0.860
| epoch   1 |    40/  100 iteration | loss    0.278 | accuracy    0.923
| epoch   1 |    60/  100 iteration | loss    0.322 | accuracy    0.890
| epoch   1 |    80/  100 iteration | loss    0.322 | accuracy    0.900
Validation
|    20/  100 iteration | loss    0.201 | accuracy    0.907
|    40/  100 iteration | loss    0.013 | accuracy    1.000
|    60/  100 iteration | loss    0.004 | accuracy    1.000
|    80/  100 iteration | loss    0.088 | accuracy    0.960
----------------------------------------------------------------------
best validation accuracy    0.976
| end of epoch   1 | best accuracy    0.976
Train
| epoch   2 |    20/  100 iteration | loss    0.159 | accuracy    0.943
| epoch   2 |    40/  100 iteration | loss    0.156 | accuracy    0.940
| epoch   2 |    60/  100 iteration | loss    0.272 | accuracy    0.903
| epoch   2 |    80/  100 iteration | loss    0.211 | accuracy    0.933
Validation
|

In [None]:
# Reload the best checkpoint written during training (rather than whatever weights the
# model ended on), then report loss/accuracy on the evaluation split.
best_checkpoint = torch.load(f"{args.results_root}/best_model.pth", map_location=device)
model.load_state_dict(best_checkpoint)

test_loss, test_acc = validate(model, test_dataloader)

print(f"Test loss: {test_loss:8.3f} | Test acc: {test_acc:8.3f}")

|    20/  100 iteration | loss    0.000 | accuracy    1.000
|    40/  100 iteration | loss    0.002 | accuracy    1.000
|    60/  100 iteration | loss    0.007 | accuracy    1.000
|    80/  100 iteration | loss    0.000 | accuracy    1.000
Test loss:    0.050 | Test acc:    0.992
