### Hyperparameter Definition

`k_shot` is the number of samples per class in the support set.

`q_query` is the number of samples per class in the query set, so the query set holds $\text{q\_query} \times \text{n\_way}$ samples in total. The two per-class counts must satisfy
$$
\text{k\_shot} + \text{q\_query} \le \text{number of samples for each class}
$$

`meta_batch_size` is the number of tasks in each meta-training/meta-testing iteration.

`num_iterations` is the total number of meta-training iterations; each meta-training iteration updates the model's initialization parameters once.

`display_gap` is the interval, in iterations, at which the meta-training progress is printed and plotted.

In [None]:
# ---- data locations ----
train_data_path = './data/Omniglot/images_background/'
test_data_path = './data/Omniglot/images_evaluation/'

# ---- task construction: N-way, K-shot with Q query samples per class ----
n_way = 5
k_shot = 5
q_query = 5
meta_batch_size = 32      # tasks per meta-batch

# ---- optimization ----
outer_lr = 0.001          # learning rate of the meta (outer-loop) Adam optimizer
inner_lr = 0.04           # step size of the inner-loop gradient-descent adaptation
train_inner_step = 1      # inner-loop steps per task during meta-training
eval_inner_step = 3       # inner-loop steps per task at evaluation time
num_iterations = 1000

# ---- miscellaneous ----
num_workers = 0           # DataLoader worker processes
valid_size = 0.2          # fraction of training classes held out for validation
random_seed = 42
display_gap = 50

### Library Importation & Random Seed Setting

In [None]:
import os
import glob
import torch
import random
import collections
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data.dataset import Dataset
from torchvision.transforms import transforms

import matplotlib.pyplot as plt
%matplotlib inline

# Use the current CUDA device when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed every random number generator used in this notebook
# (task sampling and the train/valid split draw from `random`).
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)
# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so setting it here
# does not change str hashing in this process; it only affects child
# processes launched afterwards.
os.environ['PYTHONHASHSEED'] = str(random_seed)

# Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Dataset Definition

`MAMLDataset` is a general dataset framework for MAML, which leaves two functions to implement for each specified dataset. 

Since a task is sampled at random from the full set of classes on every access, the value returned by `__len__` basically does not influence the results.
This means the dataset can effectively be regarded as infinite. Here, we set its length to the number of classes, which is convenient for splitting.

`OmniglotDataset` is an implementation of this framework for the Omniglot dataset.

In [None]:
class MAMLDataset(Dataset):
    """Base class for datasets that yield randomly sampled MAML tasks.

    Indexing ignores the index: every ``dataset[i]`` returns a new task
    produced by ``get_one_task_data``. ``len(dataset)`` is the number of
    entries in ``file_list``.

    Subclasses implement two hooks:
        get_file_list(data_path): build ``self.file_list``.
        get_one_task_data(): sample and return one task.
    """

    def __init__(self, data_path, transform, n_way=5, k_shot=1, q_query=1):
        # The file-list hook runs first, before the task sizes below are set.
        self.file_list = self.get_file_list(data_path)
        self.transform = transform
        self.n_way, self.k_shot, self.q_query = n_way, k_shot, q_query

    def get_file_list(self, data_path):
        """Return the list of entries found under ``data_path`` (subclass hook)."""
        raise NotImplementedError('get_file_list function not implemented!')

    def get_one_task_data(self):
        """Sample and return the data of one task (subclass hook)."""
        raise NotImplementedError('get_one_task_data function not implemented!')

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        # `index` is intentionally unused: each access samples a fresh task.
        return self.get_one_task_data()


class OmniglotDataset(MAMLDataset):
    """MAML task sampler for an Omniglot-style directory tree.

    Every directory whose name starts with ``character`` (at any depth under
    the data path) is one class; its samples are the ``*.png`` files below it.
    """

    def get_file_list(self, data_path):
        """
        Get a list of all classes.
        Args:
            data_path: Omniglot data path (a trailing separator is optional)

        Returns: sorted list of all class directories

        """
        # os.path.join makes the trailing '/' on data_path optional. Sorting
        # removes the filesystem-dependent order of glob results, so that the
        # seeded train/valid split and task sampling are reproducible.
        pattern = os.path.join(data_path, "**", "character*")
        return sorted(glob.glob(pattern, recursive=True))

    def get_one_task_data(self):
        """
        Sample one task: a support set and a query set.

        ``n_way`` classes are drawn without replacement and labelled
        0..n_way-1 in the order they were drawn. From each class,
        ``k_shot + q_query`` distinct images are drawn: the first ``k_shot``
        go to the support set and the rest to the query set. Both sets are
        then shuffled independently.

        Returns: (support_images, support_labels, query_images, query_labels),
            all numpy arrays. Images are stacked ``np.array(self.transform(img))``
            results; labels are the integer class indices.

        Raises: ValueError if a sampled class has fewer than
            ``k_shot + q_query`` images (``random.sample`` also raises
            ValueError when fewer than ``n_way`` classes exist).
        """
        class_dirs = random.sample(self.file_list, self.n_way)
        n_per_class = self.k_shot + self.q_query
        support_data = []
        query_data = []

        for label, class_dir in enumerate(class_dirs):
            # Join with a separator: plain concatenation (class_dir + "**/*.png")
            # turns the last component into a prefix pattern such as
            # "character1**", which also matches sibling directories that share
            # the prefix (e.g. "character10") and leaks their images in.
            img_pattern = os.path.join(class_dir, "**", "*.png")
            img_list = sorted(glob.glob(img_pattern, recursive=True))
            if len(img_list) < n_per_class:
                raise ValueError(
                    f"class {class_dir!r} has {len(img_list)} images, but "
                    f"k_shot + q_query = {n_per_class} are required"
                )
            images = random.sample(img_list, n_per_class)

            support_data.extend((self._load_image(p), label) for p in images[:self.k_shot])
            query_data.extend((self._load_image(p), label) for p in images[self.k_shot:])

        random.shuffle(support_data)
        random.shuffle(query_data)

        support_image, support_label = self._stack(support_data)
        query_image, query_label = self._stack(query_data)
        return support_image, support_label, query_image, query_label

    def _load_image(self, img_path):
        """Open ``img_path``, apply ``self.transform`` and return a numpy array."""
        # The context manager closes the file handle deterministically; the
        # transform reads the pixel data before the file is closed.
        with Image.open(img_path) as img:
            return np.array(self.transform(img))

    @staticmethod
    def _stack(pairs):
        """Split ``(image, label)`` pairs into a stacked image array and a label array."""
        images = np.array([image for image, _ in pairs])
        labels = np.array([label for _, label in pairs])
        return images, labels

### Classifier Model Definition

This defines the base classifier for MAML, which can be replaced by any gradient-based classifier. Note, however, that the classifier must also provide a special `functional_forward` method, which runs the forward pass with an explicitly supplied set of parameters, as MAML's inner loop requires.

In [None]:
class ConvBlock(nn.Module):
    """3x3 convolution (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool.

    The sub-module names ``conv2d`` and ``bn`` define the state-dict keys
    (``conv2d.weight``, ``bn.bias``, ...) that ``Classifier.functional_forward``
    looks up, so they must stay as they are.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Creation order is kept (conv first) so random initialisation draws
        # from the RNG in the same sequence.
        self.conv2d = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.max_pool(self.relu(self.bn(self.conv2d(x))))


def ConvBlockFunction(input, w, b, w_bn, b_bn):
    """Functional counterpart of ``ConvBlock.forward`` with explicit weights.

    Batch norm always normalises with the statistics of the current batch
    (``training=True``) and keeps no running averages.

    Args:
        input: feature map fed to ``F.conv2d``.
        w, b: convolution weight and bias.
        w_bn, b_bn: batch-norm affine weight and bias (``None`` is accepted by
            ``F.batch_norm`` and disables the corresponding affine term).

    Returns: the batch-normalised, ReLU-activated, 2x2 max-pooled feature map.
    """
    conv_out = F.conv2d(input, w, b, padding=1)
    normed = F.batch_norm(conv_out, None, None, weight=w_bn, bias=b_bn, training=True)
    return F.max_pool2d(F.relu(normed), kernel_size=2, stride=2)


class Classifier(nn.Module):
    """Four 64-channel ``ConvBlock``s followed by a linear head with ``n_way`` outputs.

    ``nn.Linear(64, n_way)`` requires the flattened feature size to be 64, so
    the four 2x2 pools must shrink the spatial size to 1x1 (e.g. for 28x28
    inputs: 28 -> 14 -> 7 -> 3 -> 1).
    """

    def __init__(self, in_ch, n_way):
        super().__init__()
        self.conv1 = ConvBlock(in_ch, 64)
        self.conv2 = ConvBlock(64, 64)
        self.conv3 = ConvBlock(64, 64)
        self.conv4 = ConvBlock(64, 64)
        self.logits = nn.Linear(64, n_way)

    def forward(self, x):
        """Standard forward pass using the module's own parameters."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return self.logits(x.view(x.shape[0], -1))

    def functional_forward(self, x, params):
        """Forward pass with externally supplied parameters.

        Args:
            x: input batch.
            params: mapping from state-dict style names (e.g.
                ``"conv1.conv2d.weight"``, ``"logits.bias"``) to tensors. The
                batch-norm entries are optional (looked up with ``.get``).

        Returns: logits of shape ``(batch, n_way)``.
        """
        for prefix in ("conv1", "conv2", "conv3", "conv4"):
            x = ConvBlockFunction(
                x,
                params[prefix + ".conv2d.weight"],
                params[prefix + ".conv2d.bias"],
                params.get(prefix + ".bn.weight"),
                params.get(prefix + ".bn.bias"),
            )
        flat = x.view(x.shape[0], -1)
        return F.linear(flat, params["logits.weight"], params["logits.bias"])

### Helper Functions (1): get datasets & split train and valid

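`train_transform` and `test_trasfrom` are the image preprocessing pipelines for training and for validation/testing, respectively; data augmentation can be added to `train_transform`. You can change them in your implementation as you want.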

In [None]:
# Preprocessing for meta-training images. Resize(size=28) scales the shorter
# side to 28 pixels. The flips are left disabled; uncomment them to add data
# augmentation.
train_transform = transforms.Compose([
    transforms.Resize(size=28),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])

# Preprocessing for validation/test images (no augmentation).
test_transform = transforms.Compose([
    transforms.Resize(size=28),
    transforms.ToTensor(),
])
# Backward-compatible alias for the original misspelled name, which other
# cells reference.
test_trasfrom = test_transform

def get_dataset(
        train_data_path,
        test_data_path,
        n_way,
        k_shot,
        q_query,
        valid_ratio=None,
):
    """
    Build the train / valid / test MAML task datasets.

    The classes under ``train_data_path`` are split into disjoint train and
    valid class sets by ``spilt_train_valid``. The test dataset uses the
    classes under ``test_data_path``. The train dataset uses
    ``train_transform``; valid and test use ``test_trasfrom``.

    Args:
        train_data_path: root directory of the meta-training classes
        test_data_path: root directory of the meta-test classes
        n_way: number of classes per task
        k_shot: support samples per class
        q_query: query samples per class
        valid_ratio: fraction of the training classes held out for
            validation; defaults to the global ``valid_size`` hyperparameter

    Returns: (train_dataset, valid_dataset, test_dataset)
    """
    if valid_ratio is None:
        valid_ratio = valid_size

    train_dataset = OmniglotDataset(train_data_path,
                                    train_transform,
                                    n_way,
                                    k_shot,
                                    q_query)

    # Built from the same training classes; spilt_train_valid replaces its
    # class list with the held-out portion.
    valid_dataset = OmniglotDataset(train_data_path,
                                    test_trasfrom,
                                    n_way,
                                    k_shot,
                                    q_query)

    test_dataset = OmniglotDataset(test_data_path,
                                   test_trasfrom,
                                   n_way,
                                   k_shot,
                                   q_query)

    train_dataset, valid_dataset = spilt_train_valid(train_dataset,
                                                     valid_dataset,
                                                     valid_ratio)

    return train_dataset, valid_dataset, test_dataset


def spilt_train_valid(train_dataset, valid_dataset, valid_set_size):
    """
    Split the classes of ``train_dataset`` into disjoint train and valid sets.

    The class list of ``train_dataset`` is shuffled with the ``random``
    module. The last ``int(valid_set_size * len(train_dataset))`` classes go
    to ``valid_dataset`` and the rest stay in ``train_dataset``. Both
    datasets have their ``file_list`` replaced in place and are returned.
    Any previous ``valid_dataset.file_list`` is discarded.

    Args:
        train_dataset: dataset whose ``file_list`` is split
        valid_dataset: dataset that receives the held-out classes
        valid_set_size: held-out fraction, in [0, 1]

    Returns: (train_dataset, valid_dataset)

    Raises: ValueError if ``valid_set_size`` is outside [0, 1].
    """
    if not 0 <= valid_set_size <= 1:
        raise ValueError(f"valid_set_size must be in [0, 1], got {valid_set_size!r}")

    n_valid = int(valid_set_size * len(train_dataset))
    n_train = len(train_dataset) - n_valid

    # Shuffle a copy so the list object the caller handed in is not mutated.
    file_list = list(train_dataset.file_list)
    random.shuffle(file_list)

    train_dataset.file_list = file_list[:n_train]
    valid_dataset.file_list = file_list[n_train:]

    return train_dataset, valid_dataset

### Helper Function (2): train one meta-batch

In [None]:
def maml_train(model,
               support_images,
               support_labels,
               query_images,
               query_labels,
               inner_step,
               inner_lr,
               optimizer,
               loss_fn,
               is_train=True):
    """
    Run MAML on one meta-batch of tasks; if ``is_train``, also take one
    meta-optimizer step on the model's initial parameters.

    For each task, a copy of the parameters ("fast weights") is adapted on
    the support set with ``inner_step`` plain gradient-descent steps of size
    ``inner_lr``, using ``model.functional_forward``. The adapted weights are
    then scored on the task's query set.

    Args:
        model: module exposing ``named_parameters()`` and
            ``functional_forward(x, params)``
        support_images, support_labels: support data, one task per entry
            along the first dimension
        query_images, query_labels: query data, same per-task layout
        inner_step: number of inner-loop adaptation steps per task
        inner_lr: inner-loop learning rate
        optimizer: meta-optimizer over the model's parameters (used only
            when ``is_train``)
        loss_fn: loss called as ``loss_fn(logits, labels)``
        is_train: if True, back-propagate the meta loss through the inner
            loop (second-order MAML) and step ``optimizer``

    Returns: (meta loss, meta accuracy) — the mean query loss over tasks as a
        0-dim tensor, and the mean query accuracy as a numpy float.
    """
    meta_loss = []
    meta_acc = []

    for support_image, support_label, query_image, query_label in zip(
            support_images, support_labels, query_images, query_labels):

        fast_weights = collections.OrderedDict(model.named_parameters())
        for _ in range(inner_step):
            support_logit = model.functional_forward(support_image, fast_weights)
            support_loss = loss_fn(support_logit, support_label)
            # The graph of the inner gradients is only needed when the meta
            # loss will be back-propagated (second-order MAML). For evaluation
            # the adapted weights have identical values without it, and much
            # less memory is kept alive.
            grads = torch.autograd.grad(support_loss,
                                        tuple(fast_weights.values()),
                                        create_graph=is_train)
            fast_weights = collections.OrderedDict(
                (name, param - inner_lr * grad)
                for (name, param), grad in zip(fast_weights.items(), grads)
            )

        # Score the adapted weights on the query set.
        query_logit = model.functional_forward(query_image, fast_weights)
        query_prediction = query_logit.argmax(dim=1)

        meta_loss.append(loss_fn(query_logit, query_label))
        meta_acc.append((query_prediction == query_label).float().mean().item())

    meta_loss = torch.stack(meta_loss).mean()
    meta_acc = np.mean(meta_acc)

    if is_train:
        optimizer.zero_grad()
        meta_loss.backward()
        optimizer.step()

    return meta_loss, meta_acc

### Data Loading

In [None]:
# Build the task datasets: train and valid come from `train_data_path` (split
# into disjoint class sets), test comes from `test_data_path`.
train_tasks, valid_tasks, test_tasks = get_dataset(
    train_data_path,
    test_data_path,
    n_way,
    k_shot,
    q_query,
)

# Each DataLoader batch is one meta-batch of `meta_batch_size` tasks; a final
# incomplete meta-batch is dropped. Since __getitem__ ignores its index,
# `shuffle` only changes which indices are requested, not which tasks are drawn.
train_loader = DataLoader(
    train_tasks,
    batch_size=meta_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)
valid_loader = DataLoader(
    valid_tasks,
    batch_size=meta_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)
test_loader = DataLoader(
    test_tasks,
    batch_size=meta_batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
)

### Model & Optimizer Initialization

In [None]:
# Base learner; its parameters are the meta-learned initialization.
# nn.Module.to returns the module itself, so the assignment is the same object.
model = Classifier(in_ch=1, n_way=n_way).to(device)
# Outer-loop (meta) optimizer over the initialization.
optimizer = optim.Adam(model.parameters(), lr=outer_lr)
loss_fn = nn.CrossEntropyLoss().to(device)

### MAML Training & Testing

In [None]:
def _next_meta_batch(batch_iter, loader):
    """Fetch the next meta-batch, restarting `loader` when `batch_iter` is exhausted.

    Returns: (iterator to keep using,
              (support_images, support_labels, query_images, query_labels))
             with images as float and labels as long tensors on `device`.
    """
    try:
        batch = next(batch_iter)
    except StopIteration:
        batch_iter = iter(loader)
        batch = next(batch_iter)
    support_images, support_labels, query_images, query_labels = batch
    return batch_iter, (support_images.float().to(device),
                        support_labels.long().to(device),
                        query_images.float().to(device),
                        query_labels.long().to(device))


def _plot_curves(train_acc, valid_acc, train_loss, valid_loss):
    """Plot the accuracy and loss histories side by side."""
    plt.figure(figsize=(12, 4))
    plt.subplot(121)
    plt.plot(train_acc, '-o', label="train acc")
    plt.plot(valid_acc, '-o', label="valid acc")
    plt.xlabel('Training iteration')
    plt.ylabel('Accuracy')
    plt.title("Accuracy Curve by Iteration")
    plt.legend()
    plt.subplot(122)
    plt.plot(train_loss, '-o', label="train loss")
    plt.plot(valid_loss, '-o', label="valid loss")
    plt.xlabel('Training iteration')
    plt.ylabel('Loss')
    plt.title("Loss Curve by Iteration")
    plt.legend()
    plt.suptitle("Omniglot-{}way-{}shot".format(n_way, k_shot))
    plt.show()


valid_best_acc = 0
train_acc = []
valid_acc = []
train_loss = []
valid_loss = []

train_iter = iter(train_loader)
valid_iter = iter(valid_loader)

for iteration in range(1, num_iterations + 1):
    # Not named `display`, to avoid shadowing IPython's display() helper.
    show_progress = iteration == 1 or iteration % display_gap == 0

    # ========================= train model =====================
    model.train()
    train_iter, meta_batch = _next_meta_batch(train_iter, train_loader)

    # Meta-update the initialization on one meta-batch and record the average
    # query loss/accuracy over its tasks.
    loss, acc = maml_train(model, *meta_batch,
                           train_inner_step, inner_lr,
                           optimizer, loss_fn, is_train=True)
    train_loss.append(loss.item())
    train_acc.append(acc)

    if show_progress:
        print('======================== Iteration: {} ========================'.format(iteration))
        print('Meta Train Loss: {:.3f}, Meta Train Acc: {:.2f}%'.format(loss, 100 * acc))

    # ====================== validate model ====================
    model.eval()
    valid_iter, meta_batch = _next_meta_batch(valid_iter, valid_loader)

    # Validate with the same adaptation budget as the final test
    # (eval_inner_step), so the checkpoint is selected under test conditions.
    loss, acc = maml_train(model, *meta_batch,
                           eval_inner_step, inner_lr,
                           optimizer, loss_fn, is_train=False)
    valid_loss.append(loss.item())
    valid_acc.append(acc)

    if show_progress:
        print('Meta Valid Loss: {:.3f}, Meta Valid Acc: {:.2f}%'.format(loss, 100 * acc))
        print('==============================================================')
        _plot_curves(train_acc, valid_acc, train_loss, valid_loss)

    # ========================= save model =====================
    # Checkpoint whenever this iteration's validation meta-batch beats the best so far.
    if acc > valid_best_acc:
        print('Validation accuracy improved ({:.2f}% --> {:.2f}%).'.format(100 * valid_best_acc, 100 * acc))
        valid_best_acc = acc
        torch.save(model.state_dict(), 'maml-para.pt')

In [None]:
# ====================== evaluate model ====================
# Reload the best checkpoint. map_location lets a checkpoint saved on a GPU be
# loaded onto whatever `device` is in use now.
model.load_state_dict(torch.load('maml-para.pt', map_location=device))
test_acc = []
test_loss = []

# The description is set once at construction instead of on every iteration.
test_bar = tqdm(test_loader, desc="Testing")
model.eval()
for support_images, support_labels, query_images, query_labels in test_bar:
    support_images = support_images.float().to(device)
    support_labels = support_labels.long().to(device)
    query_images = query_images.float().to(device)
    query_labels = query_labels.long().to(device)

    # Adapt on each test task's support set (eval_inner_step steps) and score
    # on its query set; no meta-update is performed.
    loss, acc = maml_train(model,
                           support_images,
                           support_labels,
                           query_images,
                           query_labels,
                           eval_inner_step,
                           inner_lr,
                           optimizer,
                           loss_fn,
                           is_train=False)
    test_loss.append(loss.item())
    test_acc.append(acc)

test_loss = np.mean(test_loss)
test_acc = np.mean(test_acc)
print('Meta Test Loss: {:.3f}, Meta Test Acc: {:.2f}%'.format(test_loss, 100 * test_acc))