# Training a model with episodic training

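Episodic training attracted a lot of attention in the early days of few-shot learning research. Some papers still use it today and call it "meta-learning".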

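Recent works decouple few-shot classifiers from the framework used to train them, so starting with EasyFSL v1.0, the logic for episodic training has been moved out of the `FewShotClassifier` class. Instead, this notebook provides an example showing how to perform episodic training on a few-shot classifier.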

Feel free to use it, copy it, modify it, and get creative.

## Getting started
First we're going to do some imports (this is not the interesting part).

In [1]:
import copy
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for all random packages that we could possibly use, plus some other stuff to make CUDA deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).

I strongly encourage you to do this in **all your scripts**.

In [2]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
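The same seeding can be packaged as a reusable helper so every script starts the same way. This is a minimal sketch with a hypothetical `set_seed` name; numpy and torch are treated as optional imports here so the helper also runs where they are not installed:

```python
import random


def set_seed(seed: int) -> None:
    """Seed every RNG this notebook may touch (hypothetical helper)."""
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass  # numpy not installed: nothing to seed
    try:
        import torch

        torch.manual_seed(seed)  # seeds the RNG on all devices
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass  # torch not installed: nothing to seed


set_seed(0)
first_draws = [random.random() for _ in range(3)]
set_seed(0)
second_draws = [random.random() for _ in range(3)]
print(first_draws == second_draws)  # True: same seed, same sequence
```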

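Then we're gonna set the shape of our few-shot tasks: the number of classes per task (`n_way`), and the number of support images (`n_shot`) and query images (`n_query`) per class.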

We also define our setup, such as the device (change it if you don't have CUDA) and the number of workers for data loading.

In [3]:
n_way = 5
n_shot = 5
n_query = 10

DEVICE = "cuda"
n_workers = 12

## Training

First we define our data loaders for training and validation. I chose to use CUB in this notebook because it's a small dataset, so we can get good results quite quickly. We use `CUB` and `TaskSampler`, which are built-in objects from EasyFSL.

In [4]:
from easyfsl.datasets import CUB
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader

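import os

# Directory to switch to, so that the dataset files can be found and read
new_working_directory = "/ML/Mashuai/few-shot-learning"

# Switch to the new working directory
os.chdir(new_working_directory)

# Print the current working directory
print(os.getcwd())
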
n_tasks_per_epoch = 500
n_validation_tasks = 100

# Instantiate the datasets
train_set = CUB(split="train", training=True)
val_set = CUB(split="val", training=False)

# Those are special batch samplers that sample few-shot classification tasks with a pre-defined shape
train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

# Finally, the DataLoader. We customize the collate_fn so that batches are delivered
# in the shape: (support_images, support_labels, query_images, query_labels, class_ids)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)
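To make "sampling a few-shot task" concrete, here is a simplified, framework-free sketch of what an episodic sampler does with a list of labels. The `sample_task` helper is hypothetical and is not EasyFSL's `TaskSampler`; it only mirrors the idea behind the batches described in the collate comment above: draw `n_way` classes, draw `n_shot + n_query` disjoint items per class, and remap the task's classes to labels `0..n_way-1` while keeping their original ids.

```python
import random
from collections import defaultdict


def sample_task(labels, n_way, n_shot, n_query, rng=random):
    """Draw one few-shot task from a list of integer labels (hypothetical helper).

    Returns (support_indices, support_labels, query_indices, query_labels, class_ids):
    task labels are remapped to 0..n_way-1, class_ids keeps each task class's original label.
    """
    indices_per_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_label[label].append(index)

    # Only classes with enough items for both the support and the query set can be drawn
    eligible = [
        label
        for label, indices in indices_per_label.items()
        if len(indices) >= n_shot + n_query
    ]
    class_ids = rng.sample(eligible, n_way)

    support_indices, support_labels, query_indices, query_labels = [], [], [], []
    for task_label, original_label in enumerate(class_ids):
        # One draw per class: the first n_shot items go to the support set, the rest to the query set
        picked = rng.sample(indices_per_label[original_label], n_shot + n_query)
        support_indices += picked[:n_shot]
        support_labels += [task_label] * n_shot
        query_indices += picked[n_shot:]
        query_labels += [task_label] * n_query
    return support_indices, support_labels, query_indices, query_labels, class_ids


# Toy dataset: 20 classes with 30 items each
toy_labels = [i // 30 for i in range(600)]
support_idx, support_lab, query_idx, query_lab, class_ids = sample_task(
    toy_labels, n_way=5, n_shot=5, n_query=10, rng=random.Random(0)
)
print(len(support_idx), len(query_idx))  # 25 50
```

With the shape set earlier, every task therefore holds 5 × 5 = 25 support images and 5 × 10 = 50 query images.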

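Then we define the network. Here I chose Prototypical Networks on a ResNet backbone (such as EasyFSL's `resnet12`), because it's easy; that option is kept in the commented-out line below. Note that the cell as written instead builds Relation Networks with `feature_dimension=3` and no backbone.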

In [20]:
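from easyfsl.methods import PrototypicalNetworks, FewShotClassifier, RelationNetworks
from easyfsl.modules import resnet12, default_relation_module


# No backbone: RelationNetworks below is built without one, so its relation module
# works directly on the 3-channel input images (hence feature_dimension=3)
convolutional_network = None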

In [21]:
# few_shot_classifier = PrototypicalNetworks(convolutional_network).to(DEVICE)
few_shot_classifier = RelationNetworks(feature_dimension=3).to(DEVICE)
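For the Prototypical Networks option mentioned above, the two calls used in the training loop map onto a simple idea: processing the support set computes one prototype per class (the mean of its support embeddings), and classifying a query scores it against every prototype. A framework-free sketch with hypothetical helper names, using toy 2-D vectors in place of backbone embeddings and plain (not squared) negative Euclidean distance as the score:

```python
import math


def compute_prototypes(support_embeddings, support_labels, n_way):
    """Mean support embedding of each class (its 'prototype')."""
    dim = len(support_embeddings[0])
    sums = [[0.0] * dim for _ in range(n_way)]
    counts = [0] * n_way
    for embedding, label in zip(support_embeddings, support_labels):
        counts[label] += 1
        for d, value in enumerate(embedding):
            sums[label][d] += value
    return [[value / counts[c] for value in sums[c]] for c in range(n_way)]


def prototype_scores(query_embeddings, prototypes):
    """One row per query: negative Euclidean distance to each prototype (higher = closer)."""
    return [[-math.dist(query, proto) for proto in prototypes] for query in query_embeddings]


def predict(scores):
    """Index of the highest score, i.e. the closest prototype, for each query."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# Toy 2-way, 2-shot task
protos = compute_prototypes([[0, 0], [0, 2], [4, 4], [4, 6]], [0, 0, 1, 1], n_way=2)
print(protos)  # [[0.0, 1.0], [4.0, 5.0]]
print(predict(prototype_scores([[0, 1.5], [4, 4]], protos)))  # [0, 1]
```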

Now let's define our training helpers! I chose to use Stochastic Gradient Descent for 200 epochs, with a scheduler that divides the learning rate by 10 after 120 and 160 epochs. The strategy is derived from [this repo](https://github.com/fiveai/on-episodes-fsl).

We're also gonna use TensorBoard, because it's always good to see what your training curves look like.

In [22]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 200
scheduler_milestones = [120, 160]
scheduler_gamma = 0.1
learning_rate = 1e-2
tb_logs_dir = Path("../logs")

train_optimizer = SGD(
    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))
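The schedule configured above has a simple closed form: the learning rate is the base rate times `gamma` raised to the number of milestones already reached. A small sketch (the `multistep_lr` helper is hypothetical) that assumes `train_scheduler.step()` is called once at the end of every epoch, as in the training loop:

```python
from bisect import bisect_right


def multistep_lr(epoch, base_lr=1e-2, milestones=(120, 160), gamma=0.1):
    """Learning rate in effect during 0-indexed `epoch` under a MultiStepLR-style schedule."""
    # bisect_right counts how many milestones are <= epoch
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


for epoch in (0, 119, 120, 159, 160, 199):
    print(epoch, f"{multistep_lr(epoch):.0e}")
```

This gives 1e-2 for epochs 0 to 119, 1e-3 for epochs 120 to 159 and 1e-4 for epochs 160 to 199.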

And now let's get to it! Here we define the function that performs a training epoch.

We use tqdm to monitor the training in real time in our logs.

In [23]:
def training_epoch(
    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)
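Each episode's loss comes from `nn.CrossEntropyLoss`, which (with its default mean reduction) applies a log-softmax to the raw classification scores and averages the negative log-likelihood of the true query labels. A pure-Python sketch of that computation, with a hypothetical `cross_entropy` helper, also gives a useful reference point: a 5-way classifier that scores all classes identically has loss ln 5 ≈ 1.609, which is about where the logged training loss sits in the first epochs.

```python
import math


def cross_entropy(scores, labels):
    """Mean over queries of -log(softmax(scores)[label]), computed stably."""
    total = 0.0
    for row, label in zip(scores, labels):
        shift = max(row)  # subtract the max before exp() to avoid overflow
        log_sum_exp = shift + math.log(sum(math.exp(s - shift) for s in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


# Identical scores for all 5 classes, 10 queries
uniform_loss = cross_entropy([[0.0] * 5] * 10, [0, 1, 2, 3, 4] * 2)
print(round(uniform_loss, 3), round(math.log(5), 3))  # 1.609 1.609
```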

And we have everything we need! To perform validation we'll just use the built-in `evaluate` function from `easyfsl.utils`.

This is now the time to **start training**.

I also keep track of the state of the model that gives the best accuracy on the validation set.
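Why the loop deep-copies the state: `state_dict()` returns a new dictionary whose values share memory with the live parameters, so later optimizer steps would silently overwrite a "best" snapshot kept without copying. A toy analogy with plain lists standing in for tensors (all names here are illustrative):

```python
import copy

# Toy stand-in for a state dict: parameter name -> mutable values
parameters = {"weight": [1.0, 2.0], "bias": [0.5]}

# Like state_dict(): a new dict, but its values are the live parameter objects
shallow = dict(parameters)
# Like copy.deepcopy(state_dict()): an independent snapshot
snapshot = copy.deepcopy(parameters)

# The optimizer keeps updating the parameters in place after the "best" epoch
parameters["weight"][0] = 99.0

print(shallow["weight"])   # [99.0, 2.0]: drifted with the model
print(snapshot["weight"])  # [1.0, 2.0]: still the best epoch's values
```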

In [24]:
from easyfsl.utils import evaluate


best_state = copy.deepcopy(few_shot_classifier.state_dict())
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = copy.deepcopy(few_shot_classifier.state_dict())
        # state_dict() returns a reference to the still evolving model's state so we deepcopy
        # https://pytorch.org/tutorials/beginner/saving_loading_models
        print("Ding ding ding! We found a new best model!")

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()

Epoch 0


Training: 100%|██████████████████████████| 500/500 [00:20<00:00, 24.66it/s, loss=1.6]
Validation: 100%|████████████████████| 100/100 [00:04<00:00, 23.94it/s, accuracy=0.3]

Ding ding ding! We found a new best model!
Epoch 1



Training: 100%|█████████████████████████| 500/500 [00:18<00:00, 26.32it/s, loss=1.59]
Validation:  72%|█████████████▋     | 72/100 [00:03<00:01, 22.20it/s, accuracy=0.294]


KeyboardInterrupt: 

Yay, we successfully performed episodic training! Now, if you want to, you can retrieve the best model's state.

In [11]:
few_shot_classifier.load_state_dict(best_state)

<All keys matched successfully>

## Evaluation

Now that our model is trained, we want to test it.

First step: we fetch the test data.

In [12]:
n_test_tasks = 1000

test_set = CUB(split="test", training=False)
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Second step: we run the few-shot classifier on the test data.

In [13]:
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

 13%|████▏                          | 134/1000 [00:08<00:51, 16.66it/s, accuracy=0.4]


KeyboardInterrupt: 
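`evaluate` reports a single accuracy over all test tasks. A framework-free sketch of such a metric, assuming it pools correct predictions across tasks (the helpers are hypothetical, not EasyFSL's code). Because every task here has the same number of queries (`n_way * n_query = 50`), the pooled value equals the mean of the per-task accuracies.

```python
def count_correct(scores, labels):
    """(number of queries whose highest score is on the true label, number of queries)."""
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == label
        for row, label in zip(scores, labels)
    )
    return correct, len(labels)


def pooled_accuracy(tasks):
    """Total correct predictions over total queries, across (scores, labels) tasks."""
    correct_total, query_total = 0, 0
    for scores, labels in tasks:
        correct, total = count_correct(scores, labels)
        correct_total += correct
        query_total += total
    return correct_total / query_total


# Two toy 2-way tasks with 2 queries each: 2/2 correct, then 1/2 correct
tasks = [
    ([[0.9, 0.1], [0.2, 0.8]], [0, 1]),
    ([[0.7, 0.3], [0.6, 0.4]], [0, 1]),
]
print(pooled_accuracy(tasks))  # 0.75
```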

Congrats! You performed Episodic Training using EasyFSL. If you want to compare with a model trained using classical training, look at [this other example notebook](classical_training.ipynb).
