# <center> Omniglot Classification </center>

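**Omniglot** is a classic benchmark dataset for meta-learning. It contains 1623 distinct character classes with 20 samples per class. This notebook shows how to use the MAML (Model-Agnostic Meta-Learning) algorithm for few-shot learning on Omniglot.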


In [1]:
import time

import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(1227)
np.random.seed(1227)

import higher

from support.omniglot_loaders import OmniglotNShot

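## Task configuration

The core idea of meta-learning is "learning to learn": training a model so that it can adapt quickly to new tasks. In few-shot classification:

- **Support set**: the few labeled samples used for fast adaptation
- **Query set**: the samples used to evaluate the adapted model
- **Task**: N classes (N-way), each with K support samples (K-shot) and a number of query samples

The key parameters of the few-shot task:
- `n_way`: number of classes per task (5 here)
- `k_spt`: number of support samples per class (3, i.e. 3-shot learning)
- `k_qry`: number of query samples per class (6, used for evaluation)
- `task_num`: number of tasks per meta-batch (4 tasks contribute to each meta-update)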
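With the values set in the `Args` cell below, the per-task set sizes follow directly. This is plain arithmetic, not part of the notebook's pipeline; it also explains why the logged training accuracies move in steps of roughly 0.83 percentage points, since the reported accuracy is the mean of `task_num` per-task query accuracies.

```python
# Worked arithmetic for the configuration used in this notebook
# (n_way=5, k_spt=3, k_qry=6, task_num=4, as set in the Args cell).
n_way, k_spt, k_qry, task_num = 5, 3, 6, 4

setsz = n_way * k_spt    # support images per task
querysz = n_way * k_qry  # query images per task

# The printed training accuracy averages the per-task query accuracies over
# task_num tasks, so it can only change in steps of this many points:
acc_step = 100.0 / (querysz * task_num)

print(setsz, querysz, round(acc_step, 3))  # 15 30 0.833
```

For example, the first logged training accuracy, 72.50, corresponds to 87 correct query predictions out of 4 × 30 = 120.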


In [2]:
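class Args:
    def __init__(self):
        self.n_way = 5       # classes per task (N-way)
        self.k_spt = 3       # support samples per class (K-shot)
        self.k_qry = 6       # query samples per class
        self.task_num = 4    # tasks per meta-batch
        self.device = 'cpu'  # not used below: the device is chosen in the next cell

args = Args()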

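## Data loader and network construction

### Task sampling
The `OmniglotNShot` data loader implements the task-sampling logic of meta-learning:

1. **Task construction**: randomly pick `n_way` classes from all available character classes to form a task
2. **Sample drawing**: for each chosen class, randomly draw `k_shot` samples for the support set and `k_query` samples for the query set
3. **Batch generation**: generate `task_num` different tasks at once (passed to the loader as `batchsz`) to form a meta-batch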
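The three steps above can be sketched in plain Python. This is a hypothetical illustration, not the `OmniglotNShot` implementation: `sample_task` and `sample_meta_batch` are made-up names, they return (class index, sample index, episode label) triples instead of images, and the sketch assumes that the support and query samples of a class are drawn without overlap. The counts (1200 training characters with 20 drawings each) come from the loader output shown later.

```python
import random

def sample_task(n_classes, n_per_class, n_way, k_shot, k_query, rng):
    """Hypothetical N-way K-shot episode sampler.

    Returns two lists of (class_index, sample_index, episode_label) triples.
    Episode labels are re-indexed to 0..n_way-1, so the same character can
    receive a different label in different tasks.
    """
    classes = rng.sample(range(n_classes), n_way)  # step 1: pick n_way classes
    support, query = [], []
    for label, cls in enumerate(classes):
        # Step 2: draw k_shot + k_query distinct samples of this class,
        # then split them into support and query.
        idx = rng.sample(range(n_per_class), k_shot + k_query)
        support += [(cls, j, label) for j in idx[:k_shot]]
        query += [(cls, j, label) for j in idx[k_shot:]]
    return support, query

def sample_meta_batch(task_num, rng, **task_kwargs):
    # Step 3: a meta-batch is task_num independently sampled tasks.
    return [sample_task(rng=rng, **task_kwargs) for _ in range(task_num)]

rng = random.Random(1227)
batch = sample_meta_batch(4, rng, n_classes=1200, n_per_class=20,
                          n_way=5, k_shot=3, k_query=6)
support, query = batch[0]
print(len(batch), len(support), len(query))  # 4 15 30
```

Re-indexing labels per task is what keeps the classifier head at `n_way` outputs even though there are far more characters in total.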

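### Convolutional network architecture
A simple CNN classifier:
- 3 convolutional blocks, each a 3×3 convolution followed by BatchNorm, ReLU and 2×2 max pooling
- a final fully connected layer that outputs logits for the 5 classes
- Adam as the meta-optimizer that updates the initial (meta-)parameters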
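Why the final layer is `nn.Linear(16, ...)`: each unpadded 3×3 convolution shrinks the side length by 2 and each 2×2 max pool (default floor mode) halves it, so a 28×28 input ends as a 1×1 map with 16 channels. A small sketch of that shape arithmetic, using the standard output-size formula; `conv_out` is a hypothetical helper:

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output side length of a convolution or pooling layer (floor division).
    return (size + 2 * padding - kernel) // stride + 1

size = 28  # Omniglot images are resized to 28x28 (imgsz=28 in the loader)
trace = [size]
for _ in range(3):
    size = conv_out(size, kernel=3)            # nn.Conv2d(..., 3), no padding
    trace.append(size)
    size = conv_out(size, kernel=2, stride=2)  # nn.MaxPool2d(2, 2)
    trace.append(size)

flat_features = 16 * size * size  # 16 channels after the last block
print(trace, flat_features)  # [28, 26, 13, 11, 5, 3, 1] 16
```

Changing `imgsz` or adding padding would change `flat_features`, and the `Linear` input size would have to follow.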


In [3]:
# Set up the Omniglot loader.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

db = OmniglotNShot(
    '/tmp/omniglot-data',
    batchsz=args.task_num,
    n_way=args.n_way,
    k_shot=args.k_spt,
    k_query=args.k_qry,
    imgsz=28,
    device=device,
)

net = nn.Sequential(
    nn.Conv2d(1, 16, 3),
    nn.BatchNorm2d(16, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 16, 3),
    nn.BatchNorm2d(16, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 16, 3),
    nn.BatchNorm2d(16, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),  # 16 channels x 1 x 1 -> 16 features
    nn.Linear(16, args.n_way)).to(device)

# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
meta_opt = optim.Adam(net.parameters(), lr=1e-3)

load from omniglot.npy.
DB: train (1200, 20, 1, 28, 28) test (423, 20, 1, 28, 28)


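## MAML training function: the core of meta-learning

### The bi-level optimization of meta-learning

#### Inner loop: task adaptation
For each sampled task:
1. **Functional copy**: `higher.innerloop_ctx` wraps the network into a functional version `fnet` together with a differentiable optimizer `diffopt`. With `copy_initial_weights=False`, `fnet` starts from the meta-parameters themselves, so gradients can flow back to them.
2. **Fast adaptation**: take `n_inner_iter` gradient steps on the support set.
3. **Parameter update**: each `diffopt.step(spt_loss)` produces new task-specific parameters while recording the update in the computational graph.

#### Outer loop: meta-parameter optimization
1. **Query evaluation**: compute the loss on the query set with the adapted parameters.
2. **Meta-gradient**: `qry_loss.backward()` backpropagates through the inner-loop updates to the initial parameters; gradients accumulate over all tasks in the meta-batch.
3. **Meta-update**: `meta_opt.step()` updates the meta-parameters so that the model adapts more easily to new tasks.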
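To make the bi-level structure concrete without PyTorch or `higher`, here is a minimal sketch on a hypothetical 1-D linear-regression toy problem (not Omniglot, not the CNN). Each task asks the model ŷ = w·x to fit y = a·x for its own slope a. Because the toy loss is quadratic, the derivative of the adapted weight with respect to the initial weight has a closed form; this is the quantity `higher` obtains automatically by differentiating through `diffopt.step`. All function names are made up for illustration, and the meta-update uses plain SGD instead of Adam.

```python
import random

def loss_and_derivs(w, a, xs):
    """Toy task: fit y = a * x with the 1-D model y_hat = w * x.

    L(w) = mean((w*x - a*x)^2) = (w - a)^2 * mean(x^2), so the gradient
    and the (constant) second derivative have closed forms.
    """
    s = sum(x * x for x in xs) / len(xs)
    return (w - a) ** 2 * s, 2 * (w - a) * s, 2 * s  # loss, dL/dw, d2L/dw2

def adapt(w, a, xs_spt, alpha, n_inner):
    """Inner loop: n_inner SGD steps on the support loss.

    Also returns d(w_adapted)/d(w_initial), chained through every step.
    """
    jac = 1.0
    for _ in range(n_inner):
        _, g, h = loss_and_derivs(w, a, xs_spt)
        w = w - alpha * g
        jac *= 1.0 - alpha * h  # derivative of one SGD step w.r.t. its input
    return w, jac

def meta_loss_and_grad(w0, tasks, alpha, n_inner):
    """Outer loop: query losses summed over tasks (like calling
    qry_loss.backward() once per task) and their exact gradient w.r.t. w0."""
    total, grad = 0.0, 0.0
    for a, xs_spt, xs_qry in tasks:
        w_fast, jac = adapt(w0, a, xs_spt, alpha, n_inner)
        q_loss, q_grad, _ = loss_and_derivs(w_fast, a, xs_qry)
        total += q_loss
        grad += q_grad * jac  # back through the inner-loop updates
    return total, grad

rng = random.Random(0)
# 4 tasks with random slopes, 3 support and 6 query inputs each.
tasks = [(rng.uniform(-2, 2),
          [rng.uniform(-1, 1) for _ in range(3)],
          [rng.uniform(-1, 1) for _ in range(6)]) for _ in range(4)]

w0, alpha, n_inner, meta_lr = 0.5, 0.1, 3, 0.05
loss_before, meta_grad = meta_loss_and_grad(w0, tasks, alpha, n_inner)
w0_new = w0 - meta_lr * meta_grad  # meta-update (plain SGD here)
loss_after, _ = meta_loss_and_grad(w0_new, tasks, alpha, n_inner)
print(loss_before, loss_after)
```

Treating `jac` as 1 gives the first-order approximation of MAML; keeping it corresponds to the full second-order meta-gradient that the training function obtains by backpropagating through the unrolled inner loop.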


In [4]:
def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 3
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(
                    dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100. * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })


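## Test function: evaluating the generalization of meta-learning

### Test procedure
1. **Sample new tasks**: sample tasks from the test split, whose characters never appear during training
2. **Fast adaptation**: starting from the meta-learned initial parameters, take a few gradient steps on each new task's support set
3. **Generalization**: evaluate the adapted model on the query set

### Differences from training
- **No meta-gradient tracking**: `track_higher_grads=False`, because the meta-parameters are not updated at test time.
- **Meta-parameters left untouched**: the inner-loop steps only update `fnet`'s parameters, and the query logits are `.detach()`ed, so testing never changes `net`'s meta-parameters. The support-set gradient steps are still taken, so this is not pure forward inference.
- **Realistic setting**: each test task is independent, simulating a new task encountered in practice.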
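The evaluation step also aggregates differently from training: `test` concatenates per-sample results across all tasks with `torch.cat` before averaging, whereas `train` averages per-task accuracies. The plain-Python sketch below uses hypothetical helper names and booleans as stand-ins for correct/incorrect query predictions; it shows that the two metrics coincide when every task has the same number of queries (as here, `n_way * k_qry = 30`) and differ otherwise.

```python
def per_task_mean(results):
    """Training-style metric: accuracy of each task, then the mean over tasks."""
    return 100.0 * sum(sum(r) / len(r) for r in results) / len(results)

def pooled_mean(results):
    """Test-style metric: concatenate per-sample results of all tasks, then
    average (plain-Python analogue of torch.cat(qry_accs).float().mean())."""
    flat = [c for r in results for c in r]
    return 100.0 * sum(flat) / len(flat)

# Equal query-set sizes: both metrics give the same value.
equal = [[True, True, False], [True, False, False]]
# Unequal sizes: the pooled mean weights larger tasks more heavily.
unequal = [[True, True, True, True], [False]]

print(per_task_mean(equal), pooled_mean(equal))
print(per_task_mean(unequal), pooled_mean(unequal))
```

In the unequal case the per-task mean is 50 while the pooled mean is 80, so the choice only matters if query sets differ in size.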

In [5]:
def test(db, net, device, epoch, log):
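    # For simplicity, testing does *not* fine-tune the meta-parameters:
    # the only adaptation is the inner loop below (n_inner_iter steps on
    # each task's support set). Most research papers using MAML for this
    # task do an extra stage of fine-tuning here that should be added if
    # you are adapting this code for research.
    # BatchNorm stays in training mode, so it normalizes with the
    # statistics of the current batch rather than running averages.
    net.train()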
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })



A simple visualization function shows how the training and test accuracy evolve:

In [6]:
import os


def plot(log):
    df = pd.DataFrame(log)
    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy (%)')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = './figure/maml-accs.png'
    # savefig does not create missing directories, so create it first.
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)
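The training curve in this plot has one point per meta-batch, so it is noisy (within the first epoch alone, the printed accuracy ranges from 69.17 to 95.00). A minimal sketch of one way to summarize it, assuming the log format built by `train` and `test` (dicts with `'epoch'`, `'acc'` and `'mode'` keys); `epoch_means` is a hypothetical helper, not part of the notebook:

```python
from collections import defaultdict

def epoch_means(log, mode='train'):
    """Average the 'acc' entries of one mode within each whole epoch.

    Train entries carry fractional epochs (epoch + batch_idx / n_train_iter),
    so int() groups them by the epoch they belong to; test entries already
    use whole numbers (epoch + 1).
    """
    buckets = defaultdict(list)
    for row in log:
        if row['mode'] == mode:
            buckets[int(row['epoch'])].append(row['acc'])
    return {e: sum(v) / len(v) for e, v in sorted(buckets.items())}

# A tiny hand-made log in the same format as the one built by train/test.
toy_log = [
    {'epoch': 0.00, 'acc': 70.0, 'mode': 'train'},
    {'epoch': 0.50, 'acc': 80.0, 'mode': 'train'},
    {'epoch': 1,    'acc': 90.0, 'mode': 'test'},
    {'epoch': 1.00, 'acc': 85.0, 'mode': 'train'},
]
print(epoch_means(toy_log))          # {0: 75.0, 1: 85.0}
print(epoch_means(toy_log, 'test'))  # {1: 90.0}
```

The resulting per-epoch values could be plotted alongside the raw training curve.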

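## Running meta-learning training

The following code runs the full MAML training process:

1. **Meta-training**:
   - repeatedly sample new tasks from the training split
   - update the meta-parameters through bi-level optimization
   - monitor loss and accuracy during training

2. **Meta-testing**:
   - evaluate generalization after each epoch
   - measure how quickly the model adapts to completely new tasks

3. **Visualization of the results**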
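Each "epoch" here is a fixed number of randomly sampled meta-batches, tied to the number of characters: `train` runs `db.x_train.shape[0] // db.batchsz` meta-batches and labels progress with the fractional value `epoch + batch_idx / n_train_iter`, printing every 4th meta-batch. Plugging in the shapes from the loader output (1200 training and 423 test characters) and `task_num = 4` reproduces the epoch labels in the log below; this is only an arithmetic check, not code from the notebook.

```python
n_train_iter = 1200 // 4  # db.x_train.shape[0] // db.batchsz
n_test_iter = 423 // 4    # db.x_test.shape[0] // db.batchsz

epoch = 0
labels = [f'{epoch + batch_idx / n_train_iter:.2f}'
          for batch_idx in range(n_train_iter)
          if batch_idx % 4 == 0]  # train() prints every 4th meta-batch

print(n_train_iter, n_test_iter, len(labels))  # 300 105 75
print(labels[:6])  # ['0.00', '0.01', '0.03', '0.04', '0.05', '0.07']
```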


In [7]:
log = []
for epoch in range(10):
    train(db, net, device, meta_opt, epoch, log)
    test(db, net, device, epoch, log)
    plot(log)

[Epoch 0.00] Train Loss: 0.97 | Acc: 72.50 | Time: 0.11
[Epoch 0.01] Train Loss: 0.92 | Acc: 77.50 | Time: 0.08
[Epoch 0.03] Train Loss: 0.99 | Acc: 69.17 | Time: 0.09
[Epoch 0.04] Train Loss: 1.00 | Acc: 70.00 | Time: 0.08
[Epoch 0.05] Train Loss: 1.01 | Acc: 75.83 | Time: 0.08
[Epoch 0.07] Train Loss: 0.97 | Acc: 78.33 | Time: 0.09
[Epoch 0.08] Train Loss: 0.91 | Acc: 80.00 | Time: 0.10
[Epoch 0.09] Train Loss: 0.99 | Acc: 69.17 | Time: 0.09
[Epoch 0.11] Train Loss: 0.93 | Acc: 73.33 | Time: 0.09
[Epoch 0.12] Train Loss: 0.91 | Acc: 74.17 | Time: 0.08
[Epoch 0.13] Train Loss: 0.90 | Acc: 77.50 | Time: 0.10
[Epoch 0.15] Train Loss: 0.66 | Acc: 95.00 | Time: 0.11
[Epoch 0.16] Train Loss: 0.78 | Acc: 85.83 | Time: 0.13
[Epoch 0.17] Train Loss: 0.75 | Acc: 88.33 | Time: 0.10
[Epoch 0.19] Train Loss: 0.73 | Acc: 85.83 | Time: 0.10
[Epoch 0.20] Train Loss: 0.73 | Acc: 85.83 | Time: 0.11
[Epoch 0.21] Train Loss: 0.73 | Acc: 84.17 | Time: 0.11
[Epoch 0.23] Train Loss: 0.72 | Acc: 80.00 | Tim