# Omniglot dataset check

In [2]:
# Install virtualenv and create an environment named `venv` in the working directory.
# NOTE(review): each `!` line appears to run in its own subshell, so the `source ... activate`
# below presumably does not change the environment used by the kernel or by later cells —
# confirm whether this venv is actually needed.
!pip install --upgrade virtualenv
!virtualenv venv
!source venv/bin/activate

Collecting virtualenv
  Downloading virtualenv-20.13.2-py2.py3-none-any.whl (8.7 MB)
[K     |████████████████████████████████| 8.7 MB 4.1 MB/s 
Collecting distlib<1,>=0.3.1
  Downloading distlib-0.3.4-py2.py3-none-any.whl (461 kB)
[K     |████████████████████████████████| 461 kB 38.2 MB/s 
[?25hCollecting platformdirs<3,>=2
  Downloading platformdirs-2.5.1-py3-none-any.whl (14 kB)
Installing collected packages: platformdirs, distlib, virtualenv
Successfully installed distlib-0.3.4 platformdirs-2.5.1 virtualenv-20.13.2
created virtual environment CPython3.7.12.final.0-64 in 1241ms
  creator CPython3Posix(dest=/content/venv, clear=False, no_vcs_ignore=False, global=False)
  seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/root/.local/share/virtualenv)
    added seed packages: pip==22.0.3, setuptools==60.9.3, wheel==0.37.1
  activators BashActivator,CShellActivator,FishActivator,NushellActivator,PowerShellActivator,PythonActivator


In [3]:
!pip install torchmeta

Collecting torchmeta
  Downloading torchmeta-1.8.0-py3-none-any.whl (210 kB)
[?25l[K     |█▋                              | 10 kB 23.7 MB/s eta 0:00:01[K     |███▏                            | 20 kB 25.2 MB/s eta 0:00:01[K     |████▊                           | 30 kB 11.7 MB/s eta 0:00:01[K     |██████▎                         | 40 kB 9.3 MB/s eta 0:00:01[K     |███████▉                        | 51 kB 3.7 MB/s eta 0:00:01[K     |█████████▍                      | 61 kB 4.3 MB/s eta 0:00:01[K     |███████████                     | 71 kB 4.5 MB/s eta 0:00:01[K     |████████████▌                   | 81 kB 4.3 MB/s eta 0:00:01[K     |██████████████                  | 92 kB 4.8 MB/s eta 0:00:01[K     |███████████████▋                | 102 kB 4.2 MB/s eta 0:00:01[K     |█████████████████▏              | 112 kB 4.2 MB/s eta 0:00:01[K     |██████████████████▊             | 122 kB 4.2 MB/s eta 0:00:01[K     |████████████████████▎           | 133 kB 4.2 MB/s eta 0:00:01

In [4]:
!pip install torchvision



In [5]:
from torchmeta.datasets import Omniglot
from torchmeta.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.utils.data import BatchMetaDataLoader

In [6]:
# Omniglot meta-training split: every task draws 5 classes, images are resized to 28 and
# converted to tensors, targets are remapped to 5 categorical labels, and classes are
# augmented with 90/180/270 degree rotations. Files are downloaded into `data` if missing.
dataset = Omniglot(
    root="data",
    num_classes_per_task=5,
    transform=Compose([Resize(28), ToTensor()]),
    target_transform=Categorical(num_classes=5),
    class_augmentations=[Rotation([90, 180, 270])],
    meta_train=True,
    download=True,
)

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to data/omniglot/images_background.zip


  0%|          | 0/9464212 [00:00<?, ?it/s]

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to data/omniglot/images_evaluation.zip


  0%|          | 0/6462886 [00:00<?, ?it/s]

In [7]:
# Split every task into a 'train' part (5 examples per class) and a 'test' part
# (15 examples per class), then serve the tasks in batches of 16.
# Note: this rebinds `dataset`, so re-running the cell wraps the dataset a second time.
dataset = ClassSplitter(
    dataset,
    shuffle=True,
    num_train_per_class=5,
    num_test_per_class=15,
)
dataloader = BatchMetaDataLoader(
    dataset,
    batch_size=16,
    num_workers=4,
)

  cpuset_checked))


In [8]:
# Sanity check: print the tensor shapes of the first 10 batches produced by `dataloader`.
for i, batch in enumerate(dataloader):
  if i == 10:
    break
  train_inputs, train_targets = batch['train']
  test_inputs, test_targets = batch['test']
  # Build the whole report first and emit it with a single print call.
  report = [
      '------' * 10,
      f'Batch ID: {i}',
      f'Train inputs shape per batch: {train_inputs.shape}',
      f'Train targets shape per batch: {train_targets.shape}',
      '',
      f'Test inputs shape per batch: {test_inputs.shape}',
      f'Test targets shape per batch: {test_targets.shape}',
  ]
  print('\n'.join(report))

  cpuset_checked))
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "


------------------------------------------------------------
Batch ID: 0
Train inputs shape per batch: torch.Size([16, 25, 1, 28, 28])
Train targets shape per batch: torch.Size([16, 25])

Test inputs shape per batch: torch.Size([16, 75, 1, 28, 28])
Test targets shape per batch: torch.Size([16, 75])
------------------------------------------------------------
Batch ID: 1
Train inputs shape per batch: torch.Size([16, 25, 1, 28, 28])
Train targets shape per batch: torch.Size([16, 25])

Test inputs shape per batch: torch.Size([16, 75, 1, 28, 28])
Test targets shape per batch: torch.Size([16, 75])
------------------------------------------------------------
Batch ID: 2
Train inputs shape per batch: torch.Size([16, 25, 1, 28, 28])
Train targets shape per batch: torch.Size([16, 25])

Test inputs shape per batch: torch.Size([16, 75, 1, 28, 28])
Test targets shape per batch: torch.Size([16, 75])
------------------------------------------------------------
Batch ID: 3
Train inputs shape per batc

# Model architecture for Omniglot dataset

In [9]:
import torch.nn as nn
from collections import OrderedDict
from torchmeta.modules import (MetaModule, MetaConv2d, MetaBatchNorm2d, MetaSequential, MetaLinear)

In [10]:
def conv_block(in_channels, out_channels, **kwargs):
  """Build one convolutional stage: conv -> batch norm -> ReLU -> 2x2 max pooling.

  Args:
    in_channels: number of channels fed to the convolution.
    out_channels: number of channels the convolution (and batch norm) produce.
    **kwargs: forwarded unchanged to `MetaConv2d` (e.g. kernel_size, stride, padding, bias).

  Returns:
    A `MetaSequential` whose sub-modules are named 'conv', 'norm', 'relu' and 'pool'.
  """
  stages = OrderedDict()
  stages['conv'] = MetaConv2d(in_channels, out_channels, **kwargs)
  stages['norm'] = MetaBatchNorm2d(out_channels, momentum=1., track_running_stats=False)
  stages['relu'] = nn.ReLU()
  stages['pool'] = nn.MaxPool2d(2)
  return MetaSequential(stages)

In [11]:
class MetaConvModel(MetaModule):
  """Four `conv_block` stages followed by a linear classifier.

  Args:
    in_channels: channel count of the input images.
    num_ways: number of logits produced by the classifier (classes per task).
    hidden_size: number of output channels of every conv block.
    feature_size: input width of the classifier. It must equal the number of values left
      per sample once the output of the four blocks is flattened.
  """
  def __init__(self, in_channels, num_ways, hidden_size=64, feature_size=64):
    super(MetaConvModel, self).__init__()
    self.in_channels = in_channels
    self.num_ways = num_ways
    # Keep this a plain value: a trailing comma here would silently store a 1-tuple.
    self.hidden_size = hidden_size
    self.feature_size = feature_size
    self.features = MetaSequential(OrderedDict([
                                                ('layer1', conv_block(in_channels, hidden_size, kernel_size=3, stride=1, padding=1, bias=True)),
                                                ('layer2', conv_block(hidden_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=True)),
                                                ('layer3', conv_block(hidden_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=True)),
                                                ('layer4', conv_block(hidden_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=True))
    ]))
    self.classifier = MetaLinear(feature_size, num_ways, bias=True)

  def forward(self, inputs, params=None):
    """Return the class logits for `inputs`.

    Args:
      inputs: batch of images passed to the convolutional feature extractor.
      params: optional parameter dictionary used instead of the module's own parameters;
        the 'features' and 'classifier' sub-dictionaries are routed to the matching parts.
    """
    features = self.features(inputs, params=self.get_subdict(params, 'features'))
    # Flatten everything except the batch dimension before the linear layer.
    features = features.view((features.size(0), -1))
    logits = self.classifier(features, params=self.get_subdict(params, 'classifier'))
    return logits

In [21]:
def ModelConvOmniglot(num_ways, hidden_size=64):
  """Create the `MetaConvModel` configuration used for Omniglot.

  The model is built with a single input channel and with the classifier input width
  (`feature_size`) set equal to `hidden_size`.

  Args:
    num_ways: number of classes per task (size of the logit vector).
    hidden_size: channel width of every convolutional block.
  """
  return MetaConvModel(
      in_channels=1,
      num_ways=num_ways,
      hidden_size=hidden_size,
      feature_size=hidden_size,
  )

# Meta Learners

In [12]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torchmeta.utils import gradient_update_parameters

In [13]:
def compute_accuracy(logits, targets):
  """Return, as a Python float, the fraction of rows classified correctly.

  The predicted class of a row is the index of its largest logit along dimension 1.
  The computation runs without gradient tracking.
  """
  with torch.no_grad():
    predicted = torch.max(logits, dim=1).indices
    hits = (predicted == targets).float()
    return hits.mean().item()

In [14]:
def tensors_to_device(tensors, device=torch.device('cpu')):
  """Recursively move every tensor of a nested structure to `device`.

  Args:
    tensors: a tensor, or a list / tuple / dict (including OrderedDict) nesting tensors.
    device: destination device; defaults to the CPU.

  Returns:
    A structure with the same container types as the input, holding the moved tensors.

  Raises:
    NotImplementedError: if an element is neither a tensor nor a supported container.
  """
  if isinstance(tensors, torch.Tensor):
    return tensors.to(device=device)

  # Rebuild containers with their own type so tuples stay tuples, OrderedDicts stay ordered.
  container = type(tensors)
  if isinstance(tensors, (list, tuple)):
    return container(tensors_to_device(item, device=device) for item in tensors)
  if isinstance(tensors, dict):
    moved = [(name, tensors_to_device(item, device=device)) for name, item in tensors.items()]
    return container(moved)

  raise NotImplementedError()

In [42]:
class MAML(object):
  """Meta-learner with MAML-style training and evaluation loops.

  For every task of a batch the model is adapted on the task's 'train' split with
  `num_adaptation_steps` gradient updates of size `step_size`; the loss of the adapted
  parameters on the task's 'test' split, averaged over the batch, is the outer loss that
  `optimizer` minimises.

  Args:
    model: module called as `model(inputs, params=...)`.
    optimizer: optimizer over the model parameters; required for training only.
    step_size: step size of the inner (adaptation) updates.
    num_adaptation_steps: number of inner updates per task.
    scheduler: optional learning-rate scheduler stepped once per training batch.
    loss_func: callable `(logits, targets) -> loss` used for inner and outer losses.
    device: device the model and the batches are moved to.
  """
  def __init__(self, model, optimizer=None, step_size=0.1, num_adaptation_steps=1, scheduler=None, loss_func=F.cross_entropy, device=None):
    self.model = model
    self.optimizer = optimizer
    self.step_size = step_size
    self.num_adaptation_steps = num_adaptation_steps
    self.scheduler = scheduler
    self.loss_func = loss_func
    self.device = device

    if scheduler is not None:
      for group in self.optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])
      # `base_lrs` is a list attribute of the scheduler, not a method: assign it.
      self.scheduler.base_lrs = [group['initial_lr'] for group in self.optimizer.param_groups]

  def adapt(self, inputs, targets, num_adaptation_steps=1, step_size=0.1):
    """Adapt the model to one task and return the adapted parameters.

    Returns:
      (params, results) where `params` is the output of the last
      `gradient_update_parameters` call (None when no step is taken) and `results` holds
      'inner_losses' (one value per step) and 'accuracy_before' (accuracy measured on
      `inputs` before the first update).
    """
    params = None
    # Initialize results['inner losses'] (per task)
    results = {'inner_losses': np.zeros((num_adaptation_steps, ), dtype=np.float32)}
    for step in range(num_adaptation_steps):
      logits = self.model(inputs, params=params)
      inner_loss = self.loss_func(logits, targets)
      results['inner_losses'][step] = inner_loss.item()
      if step == 0:
        results['accuracy_before'] = compute_accuracy(logits, targets)
      self.model.zero_grad()
      params = gradient_update_parameters(self.model, inner_loss, step_size=step_size, params=params)
    return params, results

  # per batch
  def get_outer_loss(self, batch):
    """Compute the mean outer loss of one batch of tasks.

    Args:
      batch: mapping with 'train' and 'test' entries, each an (inputs, targets) pair whose
        first dimension indexes the tasks.

    Returns:
      (mean_outer_loss, results): the loss tensor to back-propagate, and a dict with
      per-task inner/outer losses and accuracies before/after adaptation.

    Raises:
      RuntimeError: if the batch has no 'test' entry.
    """
    if 'test' not in batch:
      raise RuntimeError('The batch does not contain any test dataset')

    _,  test_targets = batch['test']
    num_tasks = test_targets.size(0) # Here, 16

    # Initialize results
    results = {
        'num_tasks': num_tasks,
        'inner_losses': np.zeros((self.num_adaptation_steps, num_tasks), dtype=np.float32),
        'outer_losses': np.zeros((num_tasks, ), dtype=np.float32),
        'mean_outer_loss': 0.,
        'accuracies_before': np.zeros((num_tasks, ), dtype=np.float32), # Before adaptation
        'accuracies_after': np.zeros((num_tasks, ), dtype=np.float32) # After adaptation
    }

    mean_outer_loss = torch.tensor(0., device=self.device)

    # zip over the first dimension of the four tensors, i.e. one iteration per task.
    for task_id, (train_inputs, train_targets, test_inputs, test_targets) in enumerate(zip(*batch['train'], *batch['test'])):
      # adaptation_results(per task): {'inner_losses': (num_adaptation_steps, ), 'accuracy_before': compute_accuracy(logits, targets)}
      params, adaptation_results = self.adapt(train_inputs, train_targets,
                                              num_adaptation_steps=self.num_adaptation_steps, step_size=self.step_size)
      results['inner_losses'][:, task_id] = adaptation_results['inner_losses']
      results['accuracies_before'][task_id] = adaptation_results['accuracy_before']
      # The outer loss only needs a graph while training.
      with torch.set_grad_enabled(self.model.training):
        test_logits = self.model(test_inputs, params=params) # param is updated with (train_inputs, train_targets)
        outer_loss = self.loss_func(test_logits, test_targets)
        results['outer_losses'][task_id] = outer_loss.item()
        mean_outer_loss += outer_loss
      results['accuracies_after'][task_id] = compute_accuracy(test_logits, test_targets)

    mean_outer_loss.div_(num_tasks)
    results['mean_outer_loss'] = mean_outer_loss.item()

    return mean_outer_loss, results

  def train_iter(self, dataloader, max_batches=500):
    """Run up to `max_batches` meta-updates, yielding the results dict of each batch.

    The dataloader is re-iterated until `max_batches` batches have been processed. Note
    that the results are yielded before the backward pass and optimizer step of the batch.

    Raises:
      RuntimeError: if no optimizer was given to the constructor.
    """
    if self.optimizer is None:
      raise RuntimeError('Trying to call `train_iter`, while the optimizer is `None`.')
    num_batches = 0
    self.model.train()
    # Move the learner's own model, not whatever global happens to be called `model`.
    self.model.to(device=self.device)
    while num_batches < max_batches:
      saw_batch = False
      for batch in dataloader:
        saw_batch = True
        if num_batches >= max_batches: break

        if self.scheduler is not None:
          self.scheduler.step(epoch=num_batches)

        self.optimizer.zero_grad()
        batch = tensors_to_device(batch, device=self.device)
        outer_loss, results = self.get_outer_loss(batch)
        yield results

        outer_loss.backward()
        self.optimizer.step()

        num_batches += 1
      if not saw_batch:
        # An empty dataloader would otherwise keep this loop spinning forever.
        break

  def train(self, dataloader, max_batches=500, verbose=True, **kwargs):
    """Train for up to `max_batches` batches, showing loss/accuracy in a tqdm bar.

    Extra keyword arguments are forwarded to `tqdm`.
    """
    with tqdm(total=max_batches, disable=not verbose, **kwargs) as pbar:
      for results in self.train_iter(dataloader, max_batches=max_batches):
        pbar.update(1)
        postfix = {'loss': '{0:.4f}'.format(results['mean_outer_loss'])}
        if 'accuracies_after' in results:
          postfix['accuracy'] = '{0:.4f}'.format(np.mean(results['accuracies_after']))
        pbar.set_postfix(**postfix)

  def evaluate_iter(self, dataloader, max_batches=500):
    """Evaluate up to `max_batches` batches, yielding the results dict of each batch."""
    num_batches = 0
    self.model.eval()
    # Move the learner's own model, not whatever global happens to be called `model`.
    self.model.to(device=self.device)
    while num_batches < max_batches:
      saw_batch = False
      for batch in dataloader:
        saw_batch = True
        if num_batches >= max_batches: break

        batch = tensors_to_device(batch, device=self.device)
        _, results = self.get_outer_loss(batch)
        yield results

        num_batches += 1
      if not saw_batch:
        # An empty dataloader would otherwise keep this loop spinning forever.
        break

  def evaluate(self, dataloader, max_batches=500, verbose=True, **kwargs):
    """Evaluate and return running means over the processed batches.

    Returns:
      dict with 'mean_outer_loss' and, when the per-batch results report accuracies,
      'accuracies_after' (mean accuracy after adaptation). With no batches the means are 0.
    """
    mean_outer_loss, mean_accuracy, count = 0., 0., 0
    # Remember whether accuracies were reported, so nothing depends on the loop variable
    # still being bound after the loop (it is not when no batch was processed).
    has_accuracy = False
    with tqdm(total=max_batches, disable=not verbose, **kwargs) as pbar:
      for results in self.evaluate_iter(dataloader, max_batches=max_batches):
        pbar.update(1)
        count += 1
        mean_outer_loss += (results['mean_outer_loss'] - mean_outer_loss) / count # Running average for the batches
        postfix = {'loss': '{0:.4f}'.format(mean_outer_loss)}
        if 'accuracies_after' in results:
          has_accuracy = True
          mean_accuracy += (np.mean(results['accuracies_after']) - mean_accuracy) / count # Running average for the batches
          postfix['accuracy'] = '{0:.4f}'.format(mean_accuracy)
        pbar.set_postfix(**postfix)

    mean_results = {'mean_outer_loss': mean_outer_loss}
    if has_accuracy:
      mean_results['accuracies_after'] = mean_accuracy

    return mean_results

# Train

In [43]:
import torch
import math
import os
import time
import json
import logging

In [44]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
device

device(type='cuda')

In [45]:
# Root directory for results; every run gets its own timestamped sub-folder below it.
output_folder = '/path/to/data'
if not os.path.exists(output_folder):
  os.makedirs(output_folder)
  logging.debug('Creating output_folder `{0}`'.format(output_folder))

# Create the run folder whether or not `output_folder` already existed: `model_path` is
# needed by the training and test cells in both cases.
folder = os.path.join(output_folder, time.strftime('%Y-%m-%d_%H%M%S'))
# exist_ok keeps a re-run within the same second from failing.
os.makedirs(folder, exist_ok=True)
logging.debug('Creating folder `{0}`'.format(folder))

folder = os.path.abspath(folder)
model_path = os.path.abspath(os.path.join(folder, 'model.th'))

# For Omniglot dataset
batch_size=16
num_workers=4

def _build_omniglot_split(**split_flag):
  """Return (dataset, dataloader) for one Omniglot split.

  Args:
    **split_flag: exactly one of meta_train=True, meta_val=True or meta_test=True,
      forwarded to `Omniglot` to select the split.

  Every task has 5 classes, split into 5 'train' and 15 'test' examples per class; the
  loader batches `batch_size` tasks using `num_workers` workers.
  """
  split_dataset = Omniglot(root="data",
                    num_classes_per_task=5,
                    transform=Compose([Resize(28), ToTensor()]),
                    target_transform=Categorical(num_classes=5),
                    class_augmentations=[Rotation([90, 180, 270])],
                    download=True,
                    **split_flag)
  split_dataset = ClassSplitter(split_dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)
  # Build the loader from this split's dataset, not from a variable left over from an
  # earlier cell.
  split_dataloader = BatchMetaDataLoader(split_dataset, batch_size=batch_size, num_workers=num_workers)
  return split_dataset, split_dataloader

meta_train_dataset, meta_train_dataloader = _build_omniglot_split(meta_train=True)
meta_val_dataset, meta_val_dataloader = _build_omniglot_split(meta_val=True)
# The third copy used to rebuild the train split; expose the held-out test split instead.
meta_test_dataset, meta_test_dataloader = _build_omniglot_split(meta_test=True)

  cpuset_checked))
  cpuset_checked))
  cpuset_checked))


In [46]:
# Settings of the run: batches per train/eval pass, outer learning rate and loss.
num_batches = 100
meta_lr = 0.001
loss_func = F.cross_entropy

# 5-way model, Adam as the outer optimizer, and a MAML learner taking a single inner
# step of size 0.1 per task, without a learning-rate scheduler.
model = ModelConvOmniglot(num_ways=5, hidden_size=64)
meta_optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)
metalearner = MAML(
    model=model,
    optimizer=meta_optimizer,
    step_size=0.1,
    num_adaptation_steps=1,
    scheduler=None,
    loss_func=loss_func,
    device=device,
)

In [None]:
model.to(device)
# Best validation score so far: accuracy (higher is better) when the evaluation reports
# it, otherwise the mean outer loss (lower is better).
best_value = None

num_epochs = 50
for epoch in range(num_epochs):
  metalearner.train(dataloader=meta_train_dataloader, max_batches=num_batches, verbose=True)
  results = metalearner.evaluate(dataloader=meta_val_dataloader, max_batches=num_batches, verbose=True)

  # Decide afresh every epoch; otherwise a True left over from an earlier epoch would
  # trigger a save even when the score did not improve.
  save_model = False
  if 'accuracies_after' in results:
    if (best_value is None) or (best_value < results['accuracies_after']):
      best_value = results['accuracies_after']
      save_model = True
  elif (best_value is None) or (best_value > results['mean_outer_loss']):
    best_value = results['mean_outer_loss']
    save_model = True

  if save_model and (output_folder is not None):
    # Only the weights are written (a state dict), not the module object.
    with open(model_path, 'wb') as f:
      torch.save(model.state_dict(), f)

  cpuset_checked))
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
100%|██████████| 100/100 [01:42<00:00,  1.03s/it, accuracy=0.9742, loss=0.1229]
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
100%|██████████| 100/100 [01:25<00:00,  1.16it/s, accuracy=0.9690, loss=0.1371]
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpolation should be of type InterpolationMode instead of int. "
  "Argument interpola

# Test

In [None]:
# Read back the checkpoint at `model_path`, mapping its tensors onto `device`.
# NOTE(review): the training cell writes `model.state_dict()` to this path, so `best_model`
# presumably holds a state dict rather than a module — confirm before calling it as a model.
with open(model_path, 'rb') as f:
  best_model = torch.load(f, map_location=device)