In [None]:
# Install virtualenv and create an environment under ./venv.
# NOTE(review): each `!` line presumably runs in its own throwaway shell, so the
# `source venv/bin/activate` below would not change the interpreter used by the
# kernel or by later cells. The later torchmeta install log is consistent with
# this: it reports uninstalling a pre-existing torch 1.10.0+cu111, i.e. pip is
# not operating on a fresh venv. Confirm whether this cell is needed at all.
!pip install --upgrade virtualenv
!virtualenv venv
!source venv/bin/activate

Collecting virtualenv
  Downloading virtualenv-20.13.2-py2.py3-none-any.whl (8.7 MB)
[K     |████████████████████████████████| 8.7 MB 4.2 MB/s 
[?25hCollecting platformdirs<3,>=2
  Downloading platformdirs-2.5.1-py3-none-any.whl (14 kB)
Collecting distlib<1,>=0.3.1
  Downloading distlib-0.3.4-py2.py3-none-any.whl (461 kB)
[K     |████████████████████████████████| 461 kB 44.5 MB/s 
Installing collected packages: platformdirs, distlib, virtualenv
Successfully installed distlib-0.3.4 platformdirs-2.5.1 virtualenv-20.13.2
created virtual environment CPython3.7.12.final.0-64 in 1663ms
  creator CPython3Posix(dest=/content/venv, clear=False, no_vcs_ignore=False, global=False)
  seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/root/.local/share/virtualenv)
    added seed packages: pip==22.0.3, setuptools==60.9.3, wheel==0.37.1
  activators BashActivator,CShellActivator,FishActivator,NushellActivator,PowerShellActivator,PythonActivator


In [None]:
!pip install torchvision



In [None]:
!pip install torchmeta

Collecting torchmeta
  Downloading torchmeta-1.8.0-py3-none-any.whl (210 kB)
[K     |████████████████████████████████| 210 kB 4.1 MB/s 
Collecting torch<1.10.0,>=1.4.0
  Downloading torch-1.9.1-cp37-cp37m-manylinux1_x86_64.whl (831.4 MB)
[K     |████████████████████████████████| 831.4 MB 6.7 kB/s 
[?25hCollecting ordered-set
  Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)
Collecting torchvision<0.11.0,>=0.5.0
  Downloading torchvision-0.10.1-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)
[K     |████████████████████████████████| 22.1 MB 11.7 MB/s 
Installing collected packages: torch, torchvision, ordered-set, torchmeta
  Attempting uninstall: torch
    Found existing installation: torch 1.10.0+cu111
    Uninstalling torch-1.10.0+cu111:
      Successfully uninstalled torch-1.10.0+cu111
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.11.1+cu111
    Uninstalling torchvision-0.11.1+cu111:
      Successfully uninstalled torchvision-0.11.1+cu111


# Model architecture

In [None]:
import torch.nn as nn
from torchmeta.modules import (MetaModule, MetaSequential, MetaConv2d, MetaBatchNorm2d, MetaLinear)

In [None]:
def conv3x3(in_channels, out_channels, **kwargs):
  """Build one convolutional block: 3x3 conv -> batch norm -> ReLU -> 2x2 max-pool.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution (and normalised
      by the batch-norm layer).
    **kwargs: extra keyword arguments forwarded to MetaConv2d.

  Returns:
    A MetaSequential holding the four layers in order.
  """
  # padding=1 with a 3x3 kernel keeps the spatial size; only the pooling halves it.
  conv = MetaConv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs)
  # Running statistics are not tracked for the normalisation layer.
  norm = MetaBatchNorm2d(out_channels, momentum=1, track_running_stats=False)
  activation = nn.ReLU()
  pool = nn.MaxPool2d(2)
  return MetaSequential(conv, norm, activation, pool)

In [None]:
class ConvNetwork(MetaModule):
  """Four-block convolutional classifier assembled from torchmeta meta-modules.

  The feature extractor stacks four conv3x3 blocks; a single MetaLinear layer
  maps the flattened features to class logits.

  Args:
    in_channels: number of channels of the input images.
    num_classes: number of output logits.
    hidden_channels: width of every convolutional block (default 64).
  """

  def __init__(self, in_channels, num_classes, hidden_channels=64):
    super().__init__()

    self.in_channels = in_channels
    self.num_classes = num_classes
    self.hidden_channels = hidden_channels

    # The first block maps the input channels to the hidden width; the other
    # three keep that width.
    widths = [in_channels] + [hidden_channels] * 4
    blocks = [conv3x3(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
    self.features = MetaSequential(*blocks)

    # The linear layer expects exactly `hidden_channels` inputs, i.e. a 1x1
    # feature map. For 28x28 Omniglot images the four 2x2 max-pools give
    # 28 -> 14 -> 7 -> 3 -> 1.
    self.classifier = MetaLinear(hidden_channels, num_classes)

  def forward(self, inputs, params=None):
    """Compute class logits for a batch of images.

    Args:
      inputs: batch of images fed to the feature extractor.
      params: optional parameter dictionary; the 'features' and 'classifier'
        sub-dictionaries are extracted with get_subdict and passed to the
        corresponding sub-modules. None is forwarded through get_subdict as-is.

    Returns:
      The output of the classifier layer (one row of logits per input).
    """
    feature_params = self.get_subdict(params, 'features')
    embeddings = self.features(inputs, params=feature_params)
    # Collapse everything but the batch dimension before the linear layer.
    flattened = embeddings.view(embeddings.size(0), -1)
    classifier_params = self.get_subdict(params, 'classifier')
    return self.classifier(flattened, params=classifier_params)

# get_accuracy function

In [None]:
import torch
from collections import OrderedDict

In [None]:
def get_accuracy(logits, targets):
  """Return the fraction of rows whose highest-scoring class equals the target.

  Args:
    logits: tensor of scores; the maximum is taken along dim=1.
    targets: tensor of class indices compared element-wise with the predictions.

  Returns:
    A 0-dim float tensor holding the mean of the per-row correctness flags.
  """
  # torch.max along a dim yields (values, indices); only the indices are needed.
  predicted_classes = torch.max(logits, dim=1).indices
  is_correct = predicted_classes.eq(targets)
  return is_correct.float().mean()

# Train

In [None]:
import os
from tqdm import tqdm

import torch
import torch.nn.functional as F

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from torchmeta.utils.gradient_based import gradient_update_parameters

In [None]:
def train(args):
  """Meta-train a ConvNetwork on Omniglot few-shot tasks (MAML-style) and optionally save it.

  For every batch of tasks, the model is adapted on each task's 'train' split
  with a single gradient_update_parameters step, evaluated on the task's 'test'
  split with the adapted parameters, and the averaged test loss is
  back-propagated through an Adam optimiser (lr=1e-3).

  Args:
    args: configuration object read via attribute access. Fields used:
      folder, download        -- forwarded to the omniglot() dataset helper.
      num_shots, num_ways     -- shots/ways of each task; num_ways is also the
                                 number of output classes of the model.
      batch_size, num_workers -- forwarded to BatchMetaDataLoader; batch_size is
                                 also the divisor used to average loss/accuracy.
      hidden_channels         -- width of the ConvNetwork blocks.
      device                  -- device the model and data are moved to.
      step_size, first_order  -- forwarded to gradient_update_parameters.
      num_batches             -- number of meta-updates (batches) to perform.
      output_folder           -- directory for the saved state dict, or None to
                                 skip saving.

  Returns:
    None. Progress (accuracy on the test splits) is shown on a tqdm bar.
  """
  dataset = omniglot(args.folder, shots=args.num_shots, ways=args.num_ways, shuffle=True, test_shots=15, meta_train=True, download=args.download)
  dataloader = BatchMetaDataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

  model = ConvNetwork(in_channels=1, num_classes=args.num_ways, hidden_channels=args.hidden_channels)
  model.to(device=args.device)
  model.train()

  meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

  # Training loop
  with tqdm(dataloader, total=args.num_batches) as pbar:
    for batch_idx, batch in enumerate(pbar):
      # Stop before touching batch number `num_batches`. Checking here rather
      # than at the end of the body gives exactly `num_batches` meta-updates
      # instead of `num_batches + 1`.
      if batch_idx >= args.num_batches: break

      model.zero_grad()

      train_inputs, train_targets = batch['train']
      train_inputs = train_inputs.to(device=args.device)
      train_targets = train_targets.to(device=args.device)

      test_inputs, test_targets = batch['test']
      test_inputs = test_inputs.to(device=args.device)
      test_targets = test_targets.to(device=args.device)

      # Accumulators over the tasks of this batch.
      outer_loss = torch.tensor(0., device=args.device)
      accuracy = torch.tensor(0., device=args.device)
      for task_idx, (train_input, train_target, test_input, test_target) in enumerate(zip(train_inputs, train_targets, test_inputs, test_targets)):
        # Inner step: loss on the task's 'train' split with the current parameters.
        train_logit = model(train_input)
        inner_loss = F.cross_entropy(train_logit, train_target)

        model.zero_grad()
        params = gradient_update_parameters(model, inner_loss, step_size=args.step_size, first_order=args.first_order)

        # Outer objective: loss on the 'test' split using the adapted parameters.
        test_logit = model(test_input, params=params)
        outer_loss += F.cross_entropy(test_logit, test_target)

        # Accuracy is only reported, so keep it out of the autograd graph.
        with torch.no_grad():
          accuracy += get_accuracy(test_logit, test_target)

      outer_loss.div_(args.batch_size)
      accuracy.div_(args.batch_size)

      outer_loss.backward()
      meta_optimizer.step()

      pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

  # Saving model
  if args.output_folder is not None:
    # Create the target directory if needed so open() below cannot fail on a missing folder.
    os.makedirs(args.output_folder, exist_ok=True)
    filename = os.path.join(args.output_folder, 'maml_omniglot_{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
    with open(filename, 'wb') as f:
      state_dict = model.state_dict()
      torch.save(state_dict, f)


In [None]:
import easydict

# Set to False to force CPU execution even when a GPU is available.
USE_CUDA = True

# Experiment configuration consumed by train().
args = easydict.EasyDict({
    'folder': 'data',          # root directory handed to the omniglot() dataset helper
    'num_shots': 5,            # shots per class in each task's 'train' split
    'num_ways': 5,             # classes per task; also the number of model outputs
    'first_order': False,      # forwarded to gradient_update_parameters
    'step_size': 0.4,          # inner-loop step size for gradient_update_parameters
    'hidden_channels': 64,     # width of every ConvNetwork block
    'output_folder': None,     # directory for the saved state dict; None skips saving
    'batch_size': 16,          # tasks per meta-batch
    'num_batches': 100,        # number of meta-batches to train on
    'num_workers': 1,          # worker processes used by the data loader
    'download': True,          # let the dataset helper download Omniglot if missing
    'use_cuda': USE_CUDA,
    # Honour use_cuda: the GPU is selected only when it is both requested and available.
    'device': torch.device('cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu')
})

In [None]:
# Run meta-training with the configuration held in `args`.
train(args)

  "Argument interpolation should be of type InterpolationMode instead of int. "
100%|██████████| 100/100 [01:55<00:00,  1.15s/it, accuracy=0.9267]
