## Tutorial Code for MAML and Prototypical Networks
- Most of the code is adapted from the examples in [higher](https://github.com/facebookresearch/higher/blob/main/examples/support/omniglot_loaders.py).
- If you find a bug, please contact lsnfamily02@kaist.ac.kr (이신의).



In [None]:
!pip install higher

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import argparse
import time
import typing

import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import higher

import torchvision.transforms as transforms
from PIL import Image
import numpy as np

import torch.utils.data as data
import os
import errno


## Dataset and DataLoader for Omniglot.

In [None]:
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation:
#     https://github.com/dragen1860/MAML-Pytorch
#     https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
#     https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py

import  torchvision.transforms as transforms
from PIL import Image
import numpy as np

import torch
import torch.utils.data as data
import os
import errno


class Omniglot(data.Dataset):
    """Omniglot images as a dataset of (image path, class index) items.

    Each entry of ``self.all_items`` is a ``(filename, category, folder)``
    tuple; ``self.idx_classes`` maps every category string to an integer
    index. ``__getitem__`` yields the image *path* (joined with '/'), which
    ``transform`` is expected to turn into actual image data.

    Args:
    - root: the directory where the dataset will be stored
    - transform: how to transform the input (receives the image path)
    - target_transform: how to transform the target (receives the class index)
    - download: download the dataset if it is not found under ``root``

    Raises:
    - RuntimeError: the dataset is missing and ``download`` is False
    """

    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, transform=None, target_transform=None,
                 download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if not self._check_exists():
            if not download:
                raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
            self.download()

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        """Return ``(img, target)`` for item ``index``, after the optional transforms."""
        filename, category, folder = self.all_items[index]
        img = '/'.join([folder, filename])

        target = self.idx_classes[category]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.all_items)

    def _check_exists(self):
        """Return True when both extracted image folders exist under ``root/processed``."""
        processed_dir = os.path.join(self.root, self.processed_folder)
        return os.path.exists(os.path.join(processed_dir, "images_evaluation")) and \
               os.path.exists(os.path.join(processed_dir, "images_background"))

    def _make_dirs(self):
        """Create the raw and processed folders, tolerating ones that already exist."""
        # Each folder is created independently: with a shared try/except an
        # already existing raw folder would skip creating the processed one.
        os.makedirs(os.path.join(self.root, self.raw_folder), exist_ok=True)
        os.makedirs(os.path.join(self.root, self.processed_folder), exist_ok=True)

    def download(self):
        """Download every archive in ``urls`` into the raw folder and unzip it
        into the processed folder. Does nothing if the data is already there."""
        import shutil
        import urllib.request
        import zipfile

        if self._check_exists():
            return

        self._make_dirs()
        raw_dir = os.path.join(self.root, self.raw_folder)
        file_processed = os.path.join(self.root, self.processed_folder)

        for url in self.urls:
            print('== Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Stream to disk and close both handles even if the copy fails.
            with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            print("== Unzip from " + file_path + " to " + file_processed)
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(file_processed)
        print("Download finished.")


def find_classes(root_dir):
    """Walk ``root_dir`` and list every file whose name ends with "png".

    Returns a list of ``(filename, label, folder)`` tuples, where ``label`` is
    the last two '/'-separated components of ``folder`` joined by '/'.
    """
    found = []
    for folder, _subdirs, filenames in os.walk(root_dir):
        images = [name for name in filenames if name.endswith("png")]
        if not images:
            continue
        parts = folder.split('/')
        count = len(parts)
        label = parts[count - 2] + "/" + parts[count - 1]
        found.extend((name, label, folder) for name in images)
    print("== Found %d items " % len(found))
    return found


def index_classes(items):
    """Assign consecutive integer indices to the distinct labels in ``items``.

    The label is element 1 of each item; indices follow first-appearance order.
    Returns the ``{label: index}`` dictionary.
    """
    class_to_idx = {}
    for item in items:
        # setdefault leaves labels that were already seen untouched.
        class_to_idx.setdefault(item[1], len(class_to_idx))
    print("== Found %d classes" % len(class_to_idx))
    return class_to_idx


class OmniglotNShot(object):
    """Samples N-way K-shot episodes (meta-batches of tasks) from Omniglot.

    The whole dataset is held in memory as one array indexed
    [class, image, 1, imgsz, imgsz] and cached on disk as ``root/omniglot.npy``.
    The first 1200 classes are used for meta-training, the rest for meta-testing.
    """

    # Number of meta-batches pre-sampled every time a cache is (re)built.
    _EPISODES_PER_CACHE = 10

    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None):
        """
        :param root: directory holding (or receiving) the dataset and omniglot.npy
        :param batchsz: task num, i.e. number of tasks per meta-batch
        :param n_way: number of classes per task
        :param k_shot: support images per class
        :param k_query: query images per class
        :param imgsz: images are resized to imgsz x imgsz
        :param device: device the sampled tensors are moved to
        """

        self.resize = imgsz
        self.device = device
        cache_path = os.path.join(root, 'omniglot.npy')
        if not os.path.isfile(cache_path):
            # if root/omniglot.npy does not exist, download and preprocess the images
            dataset = Omniglot(
                root, download=True,
                transform=transforms.Compose(
                    [lambda x: Image.open(x).convert('L'),
                     lambda x: x.resize((imgsz, imgsz)),
                     lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                     lambda x: np.transpose(x, [2, 0, 1]),
                     lambda x: x/255.]),
            )

            # group the images by label: {label: [img1, img2, ...]}
            by_label = dict()
            for (img, label) in dataset:
                by_label.setdefault(label, []).append(img)

            # the label values themselves are dropped; only the grouping is kept
            self.x = np.array([np.array(imgs) for imgs in by_label.values()]).astype(float)
            print('data shape:', self.x.shape)
            del by_label  # Free memory
            # save all dataset into npy file.
            np.save(cache_path, self.x)
            print('write into omniglot.npy.')
        else:
            # if omniglot.npy exists, just load it.
            self.x = np.load(cache_path)
            print('load from omniglot.npy.')
            # The cache file name does not encode the image size, so a file
            # written with another imgsz would only fail later in a reshape.
            if tuple(self.x.shape[-2:]) != (imgsz, imgsz):
                raise ValueError(
                    'omniglot.npy holds images of size %s but imgsz=%d was requested; '
                    'delete the file to rebuild it' % (tuple(self.x.shape[-2:]), imgsz))

        # TODO: can not shuffle here, we must keep training and test set distinct!
        self.x_train, self.x_test = self.x[:1200], self.x[1200:]

        # self.normalization()

        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # total number of classes
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query
        # support and query images of a class are drawn without replacement
        assert (k_shot + k_query) <= self.x.shape[1], \
            'k_shot + k_query exceeds the number of images per class'

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
        print("DB: train", self.x_train.shape, "test", self.x_test.shape)

        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                               "test": self.load_data_cache(self.datasets["test"])}

    def normalization(self):
        """
        Normalizes our data, to have a mean of 0 and std of 1
        (statistics are taken from the training split and applied to both splits).
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std

        # refresh the statistics so they describe the normalized training data
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)

    def load_data_cache(self, data_pack):
        """
        Collects several batches data for N-shot learning
        :param data_pack: array indexed [cls_num, imgs_per_class, 1, resize, resize]
        :return: A list of [support_set_x, support_set_y, target_x, target_y]
                 tensor groups, one per pre-sampled meta-batch
        """
        #  take 5 way 1 shot as example: 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        # use the actual number of images per class instead of assuming 20
        imgs_per_class = data_pack.shape[1]
        data_cache = []

        for _ in range(self._EPISODES_PER_CACHE):  # num of episodes

            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
            for _ in range(self.batchsz):  # one batch means one set

                x_spt, y_spt, x_qry, y_qry = [], [], [], []
                selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)

                for j, cur_class in enumerate(selected_cls):

                    selected_img = np.random.choice(imgs_per_class, self.k_shot + self.k_query, False)

                    # first k_shot images form the support set, the rest the query set;
                    # classes are relabelled 0..n_way-1 within the task
                    x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
                    x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
                    y_spt.append([j] * self.k_shot)
                    y_qry.append([j] * self.k_query)

                # shuffle inside a batch
                perm = np.random.permutation(setsz)
                x_spt = np.array(x_spt).reshape(setsz, 1, self.resize, self.resize)[perm]
                y_spt = np.array(y_spt).reshape(setsz)[perm]
                perm = np.random.permutation(querysz)
                x_qry = np.array(x_qry).reshape(querysz, 1, self.resize, self.resize)[perm]
                y_qry = np.array(y_qry).reshape(querysz)[perm]

                # append [sptsz, 1, resize, resize] => [b, setsz, 1, resize, resize]
                x_spts.append(x_spt)
                y_spts.append(y_spt)
                x_qrys.append(x_qry)
                y_qrys.append(y_qry)

            # [b, setsz, 1, resize, resize]
            x_spts = np.array(x_spts).astype(float).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
            # int64 explicitly: plain `int` is 32-bit on some platforms, while
            # cross-entropy targets must be 64-bit integers
            y_spts = np.array(y_spts).astype(np.int64).reshape(self.batchsz, setsz)
            # [b, qrysz, 1, resize, resize]
            x_qrys = np.array(x_qrys).astype(float).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
            y_qrys = np.array(y_qrys).astype(np.int64).reshape(self.batchsz, querysz)

            x_spts, y_spts, x_qrys, y_qrys = [
                torch.from_numpy(z).to(self.device) for z in
                [x_spts, y_spts, x_qrys, y_qrys]
            ]
            # convert double to float
            x_spts = x_spts.float()
            x_qrys = x_qrys.float()
            data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

        return data_cache

    def next(self, mode='train'):
        """
        Gets next batch from the dataset with name.
        :param mode: The name of the splitting (one of "train", "test")
        :return: [x_spt, y_spt, x_qry, y_qry] for one meta-batch
        """
        # re-sample the cache once every cached meta-batch has been served
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])

        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1

        return next_batch

## MAML

In [None]:
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""


def run_maml():
    """Train and evaluate MAML on few-shot Omniglot classification.

    Builds the Omniglot episode loader and a small convolutional classifier,
    then alternates one meta-training pass and one meta-test pass for
    10 epochs. Uses the GPU when one is available and the CPU otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_way', type=int, help='n way', default=5)
    parser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    parser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    parser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    parser.add_argument('--seed', type=int, help='random seed', default=1004)
    # args=[] keeps argparse from reading the notebook kernel's own argv.
    args = parser.parse_args(args=[])

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # torch.cuda.current_device() raises on machines without CUDA, so fall
    # back to the CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Set up the Omniglot loader.
    db = OmniglotNShot(
        './tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network that will be
    # automatically monkey-patched by higher later.
    # Before higher, models could *not* be created like this
    # and the parameters needed to be manually updated and copied
    # for the updates.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Flatten(start_dim=1),
        nn.Linear(64, args.n_way)).to(device)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)
    print("run MAML")
    log = []
    for epoch in range(10):
        train_maml(db, net, device, meta_opt, epoch, log)
        test_maml(db, net, device, epoch, log)

## Meta training loop with support set and query set
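- Inner optimization (adapt to task $t$ on its support set $\mathcal{D}^t_s$, starting from $\theta^t_0 = \theta_\text{init}$): $\theta^t_i = \theta^t_{i-1} -\alpha \cdot \nabla_{\theta^t_{i-1}}\mathcal{L}(\theta^t_{i-1}; \mathcal{D}^t_s)$ for $i=1, \ldots, k$.
- Outer optimization (update the initialization using the loss of the adapted parameters $\theta^t_k$ on the query set $\mathcal{D}^t_q$, back-propagating through the $k$ inner steps): 
$\theta_\text{init} \leftarrow \theta_\text{init} - \beta \cdot \nabla_{\theta^t_k} \mathcal{L}(\theta^t_k;\mathcal{D}^t_q)\prod_{i=1}^k\left(I-\alpha\cdot \nabla^2_{\theta^t_{i-1}}\mathcal{L}(\theta^t_{i-1};\mathcal{D}^t_s)\right)$. In the code below, this meta-gradient is averaged over the tasks of a meta-batch and applied with Adam.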

In [None]:
# meta-train
def train_maml(db, net, device, meta_opt, epoch, log):
    """Run one epoch of MAML meta-training.

    For every task of every meta-batch the network is adapted on the support
    set with a few differentiable SGD steps; the loss of the adapted network
    on the query set is back-propagated through those steps into the initial
    parameters, which ``meta_opt`` then updates once per meta-batch.

    :param db: episode loader providing ``next()``, ``x_train`` and ``batchsz``
    :param net: the model whose initial parameters are meta-learned
    :param device: unused here; the loader already returns tensors on its device
    :param meta_opt: optimizer over ``net.parameters()`` (outer loop)
    :param epoch: index of the current epoch, used for logging only
    :param log: list that receives one dict per meta-batch
    """
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz
    n_inner_iter = 5

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num = x_spt.size(0)
        querysz = x_qry.size(1)

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                # Record a plain float. A detached tensor would share storage
                # with qry_loss, so scaling the loss in place would also
                # rescale the recorded value and under-report the loss.
                qry_losses.append(qry_loss.item())
                qry_acc = (qry_logits.argmax(
                    dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps. Dividing by
                # task_num makes the accumulated gradient a mean over tasks.
                (qry_loss / task_num).backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100. * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })


## Meta-Test 

In [None]:
def test_maml(db, net, device, epoch, log):
    """Evaluate the meta-learned initialization on the test split.

    For each test task the network is adapted on the support set with
    5 SGD steps (no higher-order gradients are tracked), then the
    per-example loss and correctness on the query set are collected.
    The mean loss and the accuracy (in percent) over all query examples are
    printed and appended to ``log``. ``device`` is not used.
    """
    # NOTE(review): the network is deliberately left in train mode here as in
    # the original; presumably so BatchNorm keeps using batch statistics —
    # confirm before changing.
    net.train()
    total_batches = db.x_test.shape[0] // db.batchsz

    per_example_losses = []
    per_example_hits = []

    for _ in range(total_batches):
        support_x, support_y, query_x, query_y = db.next('test')

        # Unpacking also checks that the support images form a 5-D tensor.
        n_tasks, _setsz, _channels, _height, _width = support_x.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        adaptation_steps = 5
        sgd = torch.optim.SGD(net.parameters(), lr=1e-1)

        for task in range(n_tasks):
            task_support_x, task_support_y = support_x[task], support_y[task]
            with higher.innerloop_ctx(net, sgd, track_higher_grads=False) as (fnet, diffopt):
                # Adapt the functional copy of the model to this task.
                for _step in range(adaptation_steps):
                    support_loss = F.cross_entropy(fnet(task_support_x), task_support_y)
                    diffopt.step(support_loss)

                # Query loss and correctness under the adapted parameters.
                logits = fnet(query_x[task]).detach()
                losses = F.cross_entropy(logits, query_y[task], reduction='none')
                hits = logits.argmax(dim=1) == query_y[task]
                per_example_losses.append(losses.detach())
                per_example_hits.append(hits.detach())

    qry_losses = torch.cat(per_example_losses).mean().item()
    qry_accs = 100. * torch.cat(per_example_hits).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    record = {
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    }
    log.append(record)


## Prototypical Network

In [None]:
class ProtoNet(nn.Module):
    """Prototypical Network for few-shot classification.

    Images are embedded with a small CNN; each class prototype is the mean
    embedding of that class's support examples, and query examples are
    classified by (negative squared Euclidean) distance to the prototypes.

    :param n_way: number of classes per task. When None (default) it is
        inferred for every batch as ``max support label + 1``, which costs
        one host synchronisation per forward pass.
    """

    def __init__(self, n_way=None):
        super().__init__()
        self.n_way = n_way
        self.embedding = nn.Sequential(
            nn.Conv2d(1, 64, 3),
            nn.BatchNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Flatten(start_dim=1)
        )

    def get_num_samples(self, labels, num_classes, dtype):
        """Count the examples of each class.

        :param labels: integer labels, [b, n_examples]
        :return: counts with the given dtype, [b, num_classes]
        """
        with torch.no_grad():
            ones = torch.ones_like(labels, dtype=dtype)
            num_samples = ones.new_zeros((labels.size(0), num_classes))
            num_samples.scatter_add_(1, labels, ones)
        return num_samples

    def make_prototypes(self, embeddings, labels, nways):
        """Average the embeddings of each class.

        :param embeddings: [b, n_examples, d]
        :param labels: [b, n_examples], values in 0..nways-1
        :return: prototypes, [b, nways, d]; a class without examples gets
                 an all-zero prototype
        """
        B, _, H = embeddings.size()

        num_samples = self.get_num_samples(
            labels=labels, num_classes=nways, dtype=embeddings.dtype)
        # at least 1, so that empty classes do not divide by zero
        num_samples = num_samples.unsqueeze(-1).clamp(min=1)

        prototypes = embeddings.new_zeros((B, nways, H))
        indices = labels.unsqueeze(-1).expand_as(embeddings)
        # sum the embeddings per class, then divide by the class counts
        prototypes.scatter_add_(1, indices, embeddings).div_(num_samples)

        return prototypes

    def prototypical_loss(self, prototypes, embeddings, labels):
        """Cross-entropy over negative squared distances to the prototypes."""
        # prototypes: [b, n_way, d]
        # embeddings: [b, n_way*q_shots, d]
        # labels:     [b, n_way*q_shots]
        # distances:  [b, n_way, n_way*q_shots]  (class dimension at dim 1,
        #             as F.cross_entropy expects for multi-dimensional targets)
        sqr_dist = (prototypes.unsqueeze(2) - embeddings.unsqueeze(1)) ** 2
        distances = torch.sum(sqr_dist, dim=-1)
        return F.cross_entropy(-distances, labels)

    def get_accuracy(self, prototypes, embeddings, labels):
        """Fraction of embeddings whose nearest prototype is their label
        (mean over examples, then over the batch)."""
        # the accuracy is only reported, so no autograd graph is needed
        with torch.no_grad():
            # [b, n_examples, n_way]
            sq_distances = torch.sum((prototypes.unsqueeze(
                1) - embeddings.unsqueeze(2)) ** 2, dim=-1)
            _, preds = torch.min(sq_distances, dim=-1)
            return torch.mean(preds.eq(labels).float(), 1).mean(0)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        """Compute the query loss and accuracy for a meta-batch of tasks.

        :param x_spt: support images, [b, s_size, c, h, w]
        :param y_spt: support labels, [b, s_size]
        :param x_qry: query images, [b, q_size, c, h, w]
        :param y_qry: query labels, [b, q_size]
        :return: (loss, acc) as scalar tensors
        """
        b, s_size, c, h, w = x_spt.size()
        q_size = x_qry.size(1)

        s_embed = self.embedding(x_spt.reshape(b * s_size, c, h, w))
        s_embed = s_embed.view(b, s_size, -1)

        q_embed = self.embedding(x_qry.reshape(b * q_size, c, h, w))
        q_embed = q_embed.view(b, q_size, -1)

        y_spt = y_spt.view(b, s_size)
        y_qry = y_qry.view(b, q_size)

        # Number of classes: explicit setting, else derived from the support
        # labels (a fixed 5 breaks for tasks with any other number of ways).
        nways = self.n_way if self.n_way is not None else int(y_spt.max().item()) + 1

        # Create the prototypes
        prototypes = self.make_prototypes(embeddings=s_embed,
                                          labels=y_spt,
                                          nways=nways)

        loss = self.prototypical_loss(prototypes=prototypes,
                                      embeddings=q_embed,
                                      labels=y_qry)

        acc = self.get_accuracy(prototypes=prototypes,
                                embeddings=q_embed,
                                labels=y_qry)
        return loss, acc

In [None]:
def run_protonet():
    """Train and evaluate a Prototypical Network on few-shot Omniglot.

    Builds the Omniglot episode loader and a ``ProtoNet``, then alternates one
    training pass and one test pass for 20 epochs. Uses the GPU when one is
    available and the CPU otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_way', type=int, help='n way', default=5)
    parser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    parser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    parser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    parser.add_argument('--seed', type=int, help='random seed', default=1004)
    # args=[] keeps argparse from reading the notebook kernel's own argv.
    args = parser.parse_args(args=[])

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # torch.cuda.current_device() raises on machines without CUDA, so fall
    # back to the CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Set up the Omniglot loader.
    db = OmniglotNShot(
        './tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # The embedding network; unlike MAML there is no inner-loop adaptation,
    # so a single ordinary optimizer trains it directly.
    net = ProtoNet().to(device)
    opt = optim.Adam(net.parameters(), lr=1e-3)
    print("run ProtoTypical Network")
    log = []
    for epoch in range(20):
        train_protonet(db, net, device, opt, epoch, log)
        test_protonet(db, net, device, epoch, log)

In [None]:
# meta-train
def train_protonet(db, net, device, opt, epoch, log):
    """Run one epoch of Prototypical Network training.

    Each iteration draws one meta-batch from ``db``, lets ``net`` compute the
    query loss and accuracy, and takes one optimizer step on that loss.

    :param db: episode loader providing ``next()``, ``x_train`` and ``batchsz``
    :param net: callable module returning ``(loss, acc)`` for a meta-batch
    :param device: unused here; the loader already returns tensors on its device
    :param opt: optimizer over the parameters of ``net``
    :param epoch: index of the current epoch, used for logging only
    :param log: list that receives one dict (plain Python numbers) per meta-batch
    """
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        loss, acc = net(x_spt, y_spt, x_qry, y_qry)

        net.zero_grad()
        loss.backward()
        opt.step()

        # Convert to plain floats so the log holds no tensors (which could
        # keep device memory alive) and matches what test_protonet records.
        loss_value = loss.item()
        qry_accs = 100. * float(acc)
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {loss_value:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': loss_value,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })


def test_protonet(db, net, device, epoch, log):
    net.eval()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)
        with torch.no_grad():
            loss, acc = net(x_spt, y_spt, x_qry, y_qry)

            qry_losses.append(loss.item())
            qry_accs.append(acc.item())

    qry_losses = np.mean(qry_losses)
    qry_accs = 100. * np.mean(qry_accs)
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })

In [None]:
# Train and evaluate the Prototypical Network (run_protonet runs 20 epochs).
run_protonet()

load from omniglot.npy.
DB: train (1200, 20, 1, 28, 28) test (423, 20, 1, 28, 28)
run ProtoTypical Network
[Epoch 0.00] Train Loss: 1.19 | Acc: 80.38 | Time: 1.76
[Epoch 0.11] Train Loss: 0.58 | Acc: 88.54 | Time: 0.01
[Epoch 0.22] Train Loss: 0.40 | Acc: 92.63 | Time: 0.01
[Epoch 0.32] Train Loss: 0.22 | Acc: 95.25 | Time: 0.02
[Epoch 0.43] Train Loss: 0.31 | Acc: 93.79 | Time: 0.01
[Epoch 0.54] Train Loss: 0.19 | Acc: 95.92 | Time: 0.48
[Epoch 0.65] Train Loss: 0.23 | Acc: 95.00 | Time: 0.02
[Epoch 0.76] Train Loss: 0.25 | Acc: 95.25 | Time: 0.02
[Epoch 0.86] Train Loss: 0.17 | Acc: 95.79 | Time: 0.01
[Epoch 0.97] Train Loss: 0.14 | Acc: 96.83 | Time: 0.01
[Epoch 1.00] Test Loss: 0.33 | Acc: 93.21
[Epoch 1.00] Train Loss: 0.15 | Acc: 96.17 | Time: 0.01
[Epoch 1.11] Train Loss: 0.18 | Acc: 96.13 | Time: 0.01
[Epoch 1.22] Train Loss: 0.14 | Acc: 96.67 | Time: 0.03
[Epoch 1.32] Train Loss: 0.13 | Acc: 97.08 | Time: 0.03
[Epoch 1.43] Train Loss: 0.08 | Acc: 98.17 | Time: 0.01
[Epoch 1.54

In [None]:
# Meta-train and meta-test MAML (run_maml runs 10 epochs).
run_maml()

load from omniglot.npy.
DB: train (1200, 20, 1, 28, 28) test (423, 20, 1, 28, 28)
run MAML
[Epoch 0.00] Train Loss: 0.01 | Acc: 89.92 | Time: 1.65
[Epoch 0.11] Train Loss: 0.01 | Acc: 93.96 | Time: 2.37
[Epoch 0.22] Train Loss: 0.01 | Acc: 93.75 | Time: 1.43
[Epoch 0.32] Train Loss: 0.01 | Acc: 96.50 | Time: 1.43
[Epoch 0.43] Train Loss: 0.01 | Acc: 96.88 | Time: 1.45
[Epoch 0.54] Train Loss: 0.01 | Acc: 97.04 | Time: 1.74
[Epoch 0.65] Train Loss: 0.01 | Acc: 96.58 | Time: 1.45
[Epoch 0.76] Train Loss: 0.01 | Acc: 96.46 | Time: 1.46
[Epoch 0.86] Train Loss: 0.01 | Acc: 96.62 | Time: 1.88
[Epoch 0.97] Train Loss: 0.01 | Acc: 97.58 | Time: 1.41
[Epoch 1.00] Test Loss: 0.23 | Acc: 95.20
[Epoch 1.00] Train Loss: 0.01 | Acc: 96.87 | Time: 1.43
[Epoch 1.11] Train Loss: 0.01 | Acc: 97.29 | Time: 1.43
[Epoch 1.22] Train Loss: 0.00 | Acc: 96.71 | Time: 1.78
[Epoch 1.32] Train Loss: 0.00 | Acc: 97.25 | Time: 1.44
[Epoch 1.43] Train Loss: 0.00 | Acc: 97.71 | Time: 1.43
[Epoch 1.54] Train Loss: 0.