Meta Learning : Learn how to learn
Machine Learning = Looking for a function
Step 1: Function with unknown
Step 2: Define loss function
Step 3: Optimization

Learning algorithm:Hand-crafted
机器自己学出来的就是learnable components

Meta Learning是跨任务学习Across-task Training,Machine Learning是Within-task Testing

In [1]:
workspace_dir = '/kaggle/working/'

# Download dataset
!wget https://www.dropbox.com/s/pqeym3n4jly5e89/Omniglot.tar.gz?dl=1 \
    -O "{workspace_dir}/Omniglot.tar.gz"
!wget https://www.dropbox.com/s/nlvokertmksfc42/Omniglot-test.tar.gz?dl=1 \
    -O "{workspace_dir}/Omniglot-test.tar.gz"

# Use `tar' command to decompress
!tar -zxf "{workspace_dir}/Omniglot.tar.gz" -C "{workspace_dir}/"
!tar -zxf "{workspace_dir}/Omniglot-test.tar.gz" -C "{workspace_dir}/"

--2025-07-03 07:52:26--  https://www.dropbox.com/s/pqeym3n4jly5e89/Omniglot.tar.gz?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.dropbox.com/scl/fi/peae3bis6c9i96zsmmbzc/Omniglot.tar.gz?rlkey=v9ljhktg1wiy3x9otdz3p7k8c&dl=1 [following]
--2025-07-03 07:52:26--  https://www.dropbox.com/scl/fi/peae3bis6c9i96zsmmbzc/Omniglot.tar.gz?rlkey=v9ljhktg1wiy3x9otdz3p7k8c&dl=1
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucc8b4901ceb6b316a8bbf2bb94c.dl.dropboxusercontent.com/cd/0/inline/Csx3YKOvieFHZWQeYexq-yQSQXgGPQ8RkBrJn9CRnG0U-lDcuiFOYMt-ZQTbN1RP6l8xzEOD9VQ54O-BaWPISvo-IQAYFXmOO-F6C3RzYMb4N81nMl4yDaIn-HMNfF0H0J-_4ZHtqMMLIRxxDVwrfIlt/file?dl=1# [following]
--2025-07-03 07:52:26--  https://ucc8b4901ceb6b316a8bbf2bb94c.dl.dropbox

In [2]:
# Import modules we need
import glob, random
from collections import OrderedDict

import numpy as np
from tqdm.auto import tqdm

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from PIL import Image
from IPython.display import display

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEVICE = {device}")

# Fix random seeds
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

DEVICE = cuda


In [4]:
def ConvBlock(in_ch: int, out_ch: int):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )


def ConvBlockFunction(x, w, b, w_bn, b_bn):
    x = F.conv2d(x, w, b, padding=1)
    x = F.batch_norm(
        x, running_mean=None, running_var=None, weight=w_bn, bias=b_bn, training=True
    )
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x

In [6]:
class Classifier(nn.Module):
    def __init__(self, in_ch, k_way):
        super(Classifier, self).__init__()
        self.conv1 = ConvBlock(in_ch, 64)
        self.conv2 = ConvBlock(64, 64)
        self.conv3 = ConvBlock(64, 64)
        self.conv4 = ConvBlock(64, 64)
        self.logits = nn.Linear(64, k_way)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.shape[0], -1)
        x = self.logits(x)
        return x

    def functional_forward(self, x, params):
        """
        Arguments:
        x: input images [batch, 1, 28, 28]
        params: model parameters,
                i.e. weights and biases of convolution
                     and weights and biases of
                                   batch normalization
                type is an OrderedDict

        Arguments:
        x: input images [batch, 1, 28, 28]
        params: The model parameters,
                i.e. weights and biases of convolution
                     and batch normalization layers
                It's an `OrderedDict`
        """
        for block in [1, 2, 3, 4]:
            x = ConvBlockFunction(
                x,
                params[f"conv{block}.0.weight"],
                params[f"conv{block}.0.bias"],
                params.get(f"conv{block}.1.weight"),
                params.get(f"conv{block}.1.bias"),
            )
        x = x.view(x.shape[0], -1)
        x = F.linear(x, params["logits.weight"], params["logits.bias"])
        return x

In [7]:
def create_label(n_way, k_shot):
    return torch.arange(n_way).repeat_interleave(k_shot).long()


# Try to create labels for 5-way 2-shot setting
create_label(5, 2)

tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])

In [8]:
def calculate_accuracy(logits, labels):
    """utility function for accuracy calculation"""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc

In [9]:
# Dataset for train and val
class Omniglot(Dataset):
    def __init__(self, data_dir, k_shot, q_query, task_num=None):
        self.file_list = [
            f for f in glob.glob(data_dir + "**/character*", recursive=True)
        ]
        # limit task number if task_num is set
        if task_num is not None:
            self.file_list = self.file_list[: min(len(self.file_list), task_num)]
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.n = k_shot + q_query

    def __getitem__(self, idx):
        # For random sampling the characters we want.
        img_path = self.file_list[idx]
        img_list = [f for f in glob.glob(img_path + "**/*.png", recursive=True)]
        img_list.sort()
        
        sample = np.arange(len(img_list))
        np.random.shuffle(sample)
        
        # `k_shot + q_query` examples for each character
        imgs = [self.transform(Image.open(img_list[idx])) for idx in sample[:self.n]]
        imgs = torch.stack(imgs)
        return imgs

    def __len__(self):
        return len(self.file_list)

In [10]:
def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        if train:
            """ training loop """
            # Use the support set to calculate loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)

            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update model with support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())


            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)

                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Perform SGD
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )

            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)

                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)

    if train:
        # Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return batch_loss, task_acc

In [11]:
def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        # Copy the params for inner loop
        fast_weights = OrderedDict(model.named_parameters())

        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            # Simply training
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Inner gradients update! vvvvvvvvvvvvvvvvvvvv #
            """ Inner Loop Update """
            # TODO: Finish the inner loop update rule
            raise NotImplementedError
            # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #

        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:
            """ training / validation """
            val_label = create_label(n_way, q_query).to(device)

            # Collect gradients for outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            """ testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    # Update outer loop
    model.train()
    optimizer.zero_grad()

    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        """ Outer Loop Update """
        # TODO: Finish the outer loop update
        raise NotimplementedError

    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc

In [12]:
n_way = 5
k_shot = 1
q_query = 1
train_inner_train_step = 1
val_inner_train_step = 3
inner_lr = 0.4
meta_lr = 0.001
meta_batch_size = 32
max_epoch = 30
eval_batches = 20
train_data_path = "/kaggle/working/Omniglot/images_background"

In [14]:
def dataloader_init(datasets, shuffle=True, num_workers=2):
    train_set, val_set = datasets
    train_loader = DataLoader(
        train_set,
        # The "batch_size" here is not \
        #    the meta batch size, but  \
        #    how many different        \
        #    characters in a task,     \
        #    i.e. the "n_way" in       \
        #    few-shot classification.
        batch_size=n_way,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set, batch_size=n_way, num_workers=num_workers, shuffle=shuffle, drop_last=True
    )

    train_iter = iter(train_loader)
    val_iter = iter(val_loader)
    return (train_loader, val_loader), (train_iter, val_iter)

In [16]:
def model_init():
    meta_model = Classifier(1, n_way).to(device)
    optimizer = torch.optim.Adam(meta_model.parameters(), lr=meta_lr)
    loss_fn = nn.CrossEntropyLoss().to(device)
    return meta_model, optimizer, loss_fn

In [17]:
def get_meta_batch(meta_batch_size, k_shot, q_query, data_loader, iterator):
    data = []
    for _ in range(meta_batch_size):
        try:
            # a "task_data" tensor is representing \
            #     the data of a task, with size of \
            #     [n_way, k_shot+q_query, 1, 28, 28]
            task_data = next(iterator)
        except StopIteration:
            iterator = iter(data_loader)
            task_data = next(iterator)
        train_data = task_data[:, :k_shot].reshape(-1, 1, 28, 28)
        val_data = task_data[:, k_shot:].reshape(-1, 1, 28, 28)
        task_data = torch.cat((train_data, val_data), 0)
        data.append(task_data)
    return torch.stack(data).to(device), iterator

In [21]:
# 设置求解器类型（基础版或元学习版）
solver = 'base'  # 可选 'base' 或 'meta'
meta_model, optimizer, loss_fn = model_init()

# 根据求解器类型初始化数据集和加载器
if solver == 'base':
    max_epoch = 5  # 基础求解器只需要5个epoch
    Solver = BaseSolver
    
    # 创建数据集并检查实际长度
    dataset = Omniglot(train_data_path, k_shot, q_query, task_num=10)
    actual_length = len(dataset)
    print(f"数据集实际长度: {actual_length}")  # 调试信息
    
    # 确保训练集和验证集分割正确
    if actual_length >= 10:  # 如果数据集足够大
        train_size = 8  # 训练集大小（可调整）
        val_size = 2    # 验证集大小
    else:  # 如果数据集比预期小
        train_size = int(0.8 * actual_length)  # 80%训练
        val_size = actual_length - train_size  # 20%验证
    
    # 随机分割数据集
    train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    # 初始化数据加载器
    (train_loader, val_loader), (train_iter, val_iter) = dataloader_init(
        (train_set, val_set), 
        shuffle=False  # 不打乱顺序
    )

elif solver == 'meta':
    Solver = MetaSolver
    dataset = Omniglot(train_data_path, k_shot, q_query)
    train_split = int(0.8 * len(dataset))  # 80%训练
    val_split = len(dataset) - train_split  # 20%验证
    train_set, val_set = torch.utils.data.random_split(dataset, [train_split, val_split])
    (train_loader, val_loader), (train_iter, val_iter) = dataloader_init((train_set, val_set))
else:
    raise NotImplementedError("不支持的求解器类型")

# 主训练循环
for epoch in range(max_epoch):
    print(f"\n第 {epoch + 1} 轮训练")
    train_meta_loss = []
    train_acc = []
    
    # 使用tqdm显示进度条
    for step in tqdm(range(max(1, len(train_loader) // meta_batch_size)), desc="训练进度"):
        try:
            # 获取一个元批量的数据
            x, train_iter = get_meta_batch(
                meta_batch_size, 
                k_shot, 
                q_query, 
                train_loader, 
                train_iter
            )
            
            # 使用求解器训练
            meta_loss, acc = Solver(
                meta_model,
                optimizer,
                x,
                n_way,
                k_shot,
                q_query,
                loss_fn, 
                inner_train_step=train_inner_train_step
            )
            
            train_meta_loss.append(meta_loss.item())
            train_acc.append(acc)
            
        except StopIteration:
            print("\n警告: 数据迭代器已耗尽，重新初始化...")
            train_iter = iter(train_loader)
            continue
    
    # 打印训练结果
    print("  平均损失: %.3f" % np.mean(train_meta_loss), end="\t")
    print("  准确率: %.2f%%" % (np.mean(train_acc) * 100))

    # 验证阶段
    val_acc = []
    for eval_step in tqdm(range(max(1, len(val_loader) // eval_batches)), desc="验证进度"):
        try:
            x, val_iter = get_meta_batch(
                eval_batches, 
                k_shot, 
                q_query, 
                val_loader, 
                val_iter
            )
            _, acc = Solver(
                meta_model,
                optimizer,
                x,
                n_way,
                k_shot,
                q_query,
                loss_fn,
                inner_train_step=val_inner_train_step,
                train=False,  # 验证模式
            )
            val_acc.append(acc)
            
        except StopIteration:
            print("\n警告: 验证数据迭代器已耗尽，重新初始化...")
            val_iter = iter(val_loader)
            continue
    
    print("  验证准确率: %.2f%%" % (np.mean(val_acc) * 100))

数据集实际长度: 0

第 1 轮训练


训练进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 数据迭代器已耗尽，重新初始化...
  平均损失: nan	  准确率: nan%


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


验证进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 验证数据迭代器已耗尽，重新初始化...
  验证准确率: nan%

第 2 轮训练


训练进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 数据迭代器已耗尽，重新初始化...
  平均损失: nan	  准确率: nan%


验证进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 验证数据迭代器已耗尽，重新初始化...
  验证准确率: nan%

第 3 轮训练


训练进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 数据迭代器已耗尽，重新初始化...
  平均损失: nan	  准确率: nan%


验证进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 验证数据迭代器已耗尽，重新初始化...
  验证准确率: nan%

第 4 轮训练


训练进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 数据迭代器已耗尽，重新初始化...
  平均损失: nan	  准确率: nan%


验证进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 验证数据迭代器已耗尽，重新初始化...
  验证准确率: nan%

第 5 轮训练


训练进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 数据迭代器已耗尽，重新初始化...
  平均损失: nan	  准确率: nan%


验证进度:   0%|          | 0/1 [00:00<?, ?it/s]


警告: 验证数据迭代器已耗尽，重新初始化...
  验证准确率: nan%


In [23]:
import os

# test dataset
class OmniglotTest(Dataset):
    def __init__(self, test_dir):
        self.test_dir = test_dir
        self.n = 5

        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, idx):
        support_files = [
            os.path.join(self.test_dir, "support", f"{idx:>04}", f"image_{i}.png")
            for i in range(self.n)
        ]
        query_files = [
            os.path.join(self.test_dir, "query", f"{idx:>04}", f"image_{i}.png")
            for i in range(self.n)
        ]

        support_imgs = torch.stack(
            [self.transform(Image.open(e)) for e in support_files]
        )
        query_imgs = torch.stack([self.transform(Image.open(e)) for e in query_files])

        return support_imgs, query_imgs

    def __len__(self):
        return len(os.listdir(os.path.join(self.test_dir, "support")))

In [24]:
test_inner_train_step = 10 # you can change this

test_batches = 20
test_dataset = OmniglotTest("Omniglot-test")
test_loader = DataLoader(test_dataset, batch_size=test_batches, shuffle=False)

output = []
for _, batch in enumerate(tqdm(test_loader)):
    support_set, query_set = batch
    x = torch.cat([support_set, query_set], dim=1)
    x = x.to(device)

    labels = Solver(
        meta_model,
        optimizer,
        x,
        n_way,
        k_shot,
        q_query,
        loss_fn,
        inner_train_step=test_inner_train_step,
        train=False,
        return_labels=True,
    )

    output.extend(labels)

  0%|          | 0/32 [00:00<?, ?it/s]

In [25]:
# write to csv
with open("output.csv", "w") as f:
    f.write(f"id,class\n")
    for i, label in enumerate(output):
        f.write(f"{i},{label}\n")