# Omniglot 数据集单样本学习：MAML 算法

### One Shot Learning on Omniglot: MAML Algorithm

In [1]:
import os
import os.path

from PIL import Image
import numpy as np
import torch
import torch.autograd
from torch import nn
from torch import optim
from torch.nn import functional as F
import torch.nn.init as init

## 配置 Configurations

In [2]:
image_size = 28 # 图片尺度大小 size of each image
way_count = 5 # 类别数 number of classes in each task
shot_count = 1 # 支持集样本数 sample number in each support set
query_count = 15 # 查询集样本数 sample number in each query set
train_batch_size = 32 # 每次训练使用的任务数 task number for training
test_batch_size = 200 # 每次测试使用的任务数 task number for testing
train_step_count = 5 # 训练时采用的优化步数 gradient descent steps for training
test_step_count = 10 # 测试时采用的优化步数 gradient descent steps for testing
train_epoch_count = 801 # 训练回合数 total training epoch
test_epoch_interval = 200 # 每训练多少次测试一次 test frequency

## 读取数据  Read Data

请下载以下两个文件并解压到2个不同的文件夹。

Please download and unzip the following two zip files as two different folders.


```
https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip
https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip
```

文件路径作为下列 `get_raw_data()` 函数的首个参数。

The unzipped folders be the first parameters of the following function `get_raw_data()`.

In [3]:
def get_raw_data(directory, image_size): # 载入图片并做预处理
    characters = []
    for char_dir, dirs, file_names in os.walk(directory):
        if [_ for filename in file_names if filename.endswith('.png')]:
            images = []
            for file_name in file_names:
                file_path = os.path.join(char_dir, file_name)
                image = Image.open(file_path).convert('L')
                image = image.resize((image_size, image_size))
                image = np.reshape(image, (1, image_size, image_size))
                image = image / 255.
                images.append(image)
            images = np.array(images)
            characters.append((char_dir, images))
    print('Found {} characters'.format(len(characters)))
    return characters


raw_data = {}
raw_data['train'] = get_raw_data(
        '.\omniglot\images_background', image_size=image_size)
raw_data['test'] = get_raw_data(
        '.\omniglot\mages_evaluation', image_size=image_size)

Found 964 characters
Found 659 characters


随机读取单样本分类任务的函数

Read tasks randomly for the one-shot classification tasks

In [4]:
def get_random_task(raw_data,
        way_count, shot_count, query_count, permute=True):
    s_inputs, s_labels = [], []
    q_inputs, q_labels = [], []
    chars = np.random.choice(len(raw_data), way_count, replace=False)
    for char_index, char in enumerate(chars):
        images = raw_data[char][1]
        image_indices = np.random.choice(images.shape[0],
                shot_count + query_count, replace=False)
        s_inputs.append(images[image_indices[:shot_count]])
        q_inputs.append(images[image_indices[shot_count:]])
        s_labels += [char_index,] * shot_count
        q_labels += [char_index,] * query_count
    s_inputs, s_labels = np.concatenate(s_inputs), np.array(s_labels)
    q_inputs, q_labels = np.concatenate(q_inputs), np.array(q_labels)

    if permute:
        s_perms = np.random.permutation(way_count * shot_count)
        s_inputs, s_labels = s_inputs[s_perms], s_labels[s_perms]
        q_perms = np.random.permutation(way_count * query_count)
        q_inputs, q_labels = q_inputs[q_perms], q_labels[q_perms]
    return s_inputs, s_labels, q_inputs, q_labels


def get_random_tasks(raw_data, batch_size,
        way_count, shot_count, query_count, permute=True):
    tasks = []
    for idx in range(batch_size):
        task = get_random_task(raw_data,
            way_count, shot_count, query_count, permute=True)
        tasks.append(task)
    return tasks

##  神经网络 Neural Network

In [5]:
net0 = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=0),
        nn.ReLU(),
        nn.BatchNorm2d(num_features=64),
        nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0),
        nn.ReLU(),
        nn.BatchNorm2d(num_features=64),
        nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0),
        nn.ReLU(),
        nn.BatchNorm2d(num_features=64),
        nn.Conv2d(64, 64, kernel_size=2, stride=1, padding=0),
        nn.ReLU(),
        nn.BatchNorm2d(num_features=64),
        nn.Flatten(),
        nn.Linear(64, way_count),
        )

初始化网络 Initialize the network

In [6]:
def weights_init(m):
    if type(m) in [nn.Conv2d, nn.ConvTranspose2d]:
        init.kaiming_normal_(m.weight)
        init.constant_(m.bias, 0.)
    elif type(m) in [nn.Linear,]:
        init.kaiming_normal_(m.weight)
        init.constant_(m.bias, 0.)
    elif type(m) in [nn.BatchNorm2d,]:
        init.constant_(m.weight, 1.)
        init.constant_(m.bias, 0.)

net0.apply(weights_init)

Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))
  (1): ReLU()
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
  (4): ReLU()
  (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
  (7): ReLU()
  (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1))
  (10): ReLU()
  (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): Flatten()
  (13): Linear(in_features=64, out_features=5, bias=True)
)

MAML 算法中需要构建梯度下降的计算图并用梯度下降后的网络变量作为网络权重，所以仅仅使用 `nn.Sequential` 构造得到的 `net0` 是不够的。下面的类扩展了 `nn.Sequential` 类，使得网络能根据外部变量进行计算。

Since MAML algorithm calculates gradient descent and uses the resulting variables as the network parameters, the `nn.Sequential` instance `net0` does not suffice to support the algorithm. The following class extends the class `nn.Sequential` so that the network parameter is configurable.

In [7]:
class ConfigurableSequential(nn.Module):
    
    def __init__(self, net):
        super(ConfigurableSequential, self).__init__()
        self.net = net
    
    def forward(self, x, parameters=None, buffers=None, bn_training=True):
        if parameters is None:
            parameters = list(self.net.parameters()) # 所有要优化的变量
        if buffers is None:
            buffers = list(self.net.buffers())  # BatchNorm维护的均值和方差
        
        param_index, buffer_index = 0, 0
        for m in self.net.modules():
            if type(m) == nn.Sequential:
                pass
            elif type(m) in [nn.Conv2d,]:
                weights, bias = parameters[param_index], parameters[param_index + 1]
                param_index += 2
                x = F.conv2d(x, weights, bias, stride=m.stride, padding=m.padding)
            elif type(m) in [nn.Linear,]:
                weights, bias = parameters[param_index], parameters[param_index + 1]
                param_index += 2
                x = F.linear(x, weights, bias)
            elif type(m) in [nn.BatchNorm2d,]:
                weights, bias = parameters[param_index], parameters[param_index + 1]
                param_index += 2
                running_mean, running_var = buffers[buffer_index], buffers[buffer_index+1]
                buffer_index += 3
                x = F.batch_norm(x, running_mean, running_var,
                        weight=weights, bias=bias, training=bn_training)
            elif type(m) in [nn.Flatten,]:
                x = x.view(x.size(0), -1)
            elif type(m) in [nn.ReLU,]:
                x = F.relu(x, inplace=m.inplace)
            else:
                raise NotImplementedError
        return x


net = ConfigurableSequential(net0)

## MAML 算法
MAML Algorithm

In [8]:
criterion = nn.CrossEntropyLoss()
meta_optimizer = optim.Adam(net.parameters())

In [9]:
def run(net, tasks, step_count, criterion,
        update_lr=0.4, meta_optimizer=None):
    
    q_corrects = np.zeros(shape=(step_count + 1,))
    q_totals = np.zeros(shape=(step_count + 1,))
    q_losses = np.zeros(shape=(step_count + 1,))
    if meta_optimizer:
        maml_losses = []
    
    for task in tasks:
        s_inputs, s_labels, q_inputs, q_labels = task
        s_inputs = torch.from_numpy(s_inputs).float()
        s_labels = torch.from_numpy(s_labels).long()
        q_inputs = torch.from_numpy(q_inputs).float()
        q_labels = torch.from_numpy(q_labels).long()
        
        variables = list(net.parameters())
        for k in range(step_count + 1):
            if k:
                # 梯度下降算法 gradient descent
                s_logits = net(s_inputs, variables)
                s_loss = criterion(s_logits, s_labels)
                grads = torch.autograd.grad(s_loss, variables)
                variables = [v - update_lr * g for g, v in zip(grads, variables)]

            with torch.no_grad():
                # 计算损失 calculate loss
                q_logits = net(q_inputs, variables)
                q_loss = criterion(q_logits, q_labels)
                q_losses[k] += q_loss.item()
                
                # 计算准确度 calculate accurate
                q_preds = F.softmax(q_logits, dim=1).argmax(dim=1)
                q_correct = torch.eq(q_preds, q_labels).sum().item()  # convert to numpy
                q_corrects[k] += q_correct
                q_totals[k] += q_labels.shape[0]
        
        if meta_optimizer:
            # 计算元学习损失 calculate loss for MAML
            q_logits = net(q_inputs, variables)
            maml_loss = criterion(q_logits, q_labels)
            maml_losses.append(maml_loss)
    
    if meta_optimizer: # 更新 MAML update
        maml_loss = torch.mean(torch.stack(maml_losses))
        meta_optimizer.zero_grad()
        maml_loss.backward()
        meta_optimizer.step()
    
    q_accs = q_corrects / q_totals
    q_losses = q_losses / len(tasks)
    return q_accs, q_losses

训练与测试 Train & test

In [10]:
for epoch in range(train_epoch_count):

    # 训练 train
    train_tasks = get_random_tasks(raw_data['train'], train_batch_size,
            way_count, shot_count, query_count)
    train_accs, _ = run(net, train_tasks, step_count=train_step_count,
            criterion=criterion, meta_optimizer=meta_optimizer)
    if epoch % 50 == 0: # 训练多次后输出一次训练结果
        print('epoch {} : train accuracy {}'.format(epoch, train_accs))

    # 测试 test
    if epoch % test_epoch_interval == 0: # 训练多次后测试一次
        test_tasks = get_random_tasks(raw_data['test'], test_batch_size,
                way_count, shot_count, query_count)
        test_accs, _ = run(net, train_tasks, step_count=test_step_count,
                criterion=criterion)
        print('epoch {} : test accuracy {}'.format(epoch, test_accs))

epoch 0 : train accuracy [0.18958333 0.30875    0.39791667 0.41541667 0.41458333 0.41541667]
epoch 0 : test accuracy [0.195      0.3225     0.40625    0.40958333 0.40916667 0.41
 0.41041667 0.41041667 0.41041667 0.41166667 0.41166667]
epoch 50 : train accuracy [0.20833333 0.55541667 0.59083333 0.6025     0.60416667 0.6075    ]
epoch 100 : train accuracy [0.1825     0.69291667 0.7275     0.72791667 0.72875    0.73      ]
epoch 150 : train accuracy [0.17416667 0.755      0.77291667 0.78166667 0.78208333 0.78291667]
epoch 200 : train accuracy [0.19958333 0.79708333 0.82625    0.82791667 0.82916667 0.83041667]
epoch 200 : test accuracy [0.2025     0.79958333 0.82666667 0.83333333 0.835      0.83541667
 0.83583333 0.83625    0.83541667 0.83541667 0.83541667]
epoch 250 : train accuracy [0.22875    0.86333333 0.8775     0.88041667 0.88291667 0.88208333]
epoch 300 : train accuracy [0.17458333 0.83625    0.84166667 0.84416667 0.84541667 0.84666667]
epoch 350 : train accuracy [0.19125    0.86291

经过训练，测试集上的准确率由 40% 左右提升到超过 90%。

After training, the accuracy in test tasks increases from ~40% to >90%.