# 使用情景训练法训练模型



## Getting started


In [1]:
# 导包
import copy
from pathlib import Path
import random
from statistics import mean
import numpy as np
import torch
from torch import nn
from tqdm import tqdm

In [2]:
#设置随机种子
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# 设置采样策略,训练设备信息等
n_way = 5
n_shot = 5
n_query = 10

DEVICE = "cuda"
n_workers = 12

In [4]:
from easyfsl.datasets import CUB
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader

import os
# 指定你想切换到的目录路径，这里是为了读数据，因为目前这个notebook在子目录下面，
new_working_directory = "/ML/Mashuai/few-shot-learning"

# 切换到新的工作目录
os.chdir(new_working_directory)

# 打印当前工
n_tasks_per_epoch = 500
n_validation_tasks = 100

# 实例化数据集
train_set = CUB(split="train", training=True)
val_set = CUB(split="val", training=False)

# 这些是特殊的批量采样器，对具有预定义形状的few-shot分类任务进行采样
train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

# 最后是DataLoader。我们定制collate_fn以便批量交付
# 格式为: (support_images, support_labels, query_images, query_labels, class_ids)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

然后我们定义网络。在这里，我选择了Prototypical Networks和PyTorch内置的ResNet18，因为它很容易。

In [5]:
from easyfsl.methods import PrototypicalNetworks, FewShotClassifier,RelationNetworks
from easyfsl.modules import resnet12,default_relation_module
convolutional_network = None

In [6]:
# few_shot_classifier = PrototypicalNetworks(convolutional_network).to(DEVICE)
few_shot_classifier = RelationNetworks(feature_dimension=3).to(DEVICE)

In [7]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 200
scheduler_milestones = [120, 160]
scheduler_gamma = 0.1
learning_rate = 1e-2
tb_logs_dir = Path("../logs")

train_optimizer = SGD(few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
train_scheduler = MultiStepLR(train_optimizer,milestones=scheduler_milestones,gamma=scheduler_gamma,)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

In [8]:
def training_epoch(model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer):
    all_loss = []
    model.train()
    with tqdm(enumerate(data_loader), total=len(data_loader), desc="Training") as tqdm_train:
        for episode_index, (support_images,support_labels,query_images,query_labels,_,) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(support_images.to(DEVICE), support_labels.to(DEVICE) )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

我们有我们需要的一切!要执行验证，我们将使用内置的来自`easyfsl.methods.utils` 的`evaluate` 函数 .

现在是时候 **start training 原神，启动！**.

我添加了一些东西来记录在验证集上提供最佳性能的模型的状态。

In [9]:
from easyfsl.utils import evaluate


best_state = few_shot_classifier.state_dict()
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation")

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = copy.deepcopy(few_shot_classifier.state_dict())
        print("Ding ding ding! We found a new best model!")

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()

Epoch 0


Training: 100%|█████████████████████| 500/500 [00:20<00:00, 24.11it/s, loss=1.61]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.04it/s, accuracy=0.216]

Ding ding ding! We found a new best model!
Epoch 1



Training: 100%|██████████████████████| 500/500 [00:18<00:00, 27.05it/s, loss=1.6]
Validation: 100%|███████████████| 100/100 [00:04<00:00, 24.63it/s, accuracy=0.29]

Ding ding ding! We found a new best model!
Epoch 2



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.01it/s, loss=1.58]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.19it/s, accuracy=0.363]

Ding ding ding! We found a new best model!
Epoch 3



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.83it/s, loss=1.56]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.59it/s, accuracy=0.363]

Ding ding ding! We found a new best model!
Epoch 4



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.42it/s, loss=1.55]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.79it/s, accuracy=0.356]

Epoch 5



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.49it/s, loss=1.54]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.35it/s, accuracy=0.321]

Epoch 6



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.32it/s, loss=1.53]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.74it/s, accuracy=0.354]

Epoch 7



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.59it/s, loss=1.53]
Validation: 100%|███████████████| 100/100 [00:04<00:00, 24.00it/s, accuracy=0.39]

Ding ding ding! We found a new best model!
Epoch 8



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.35it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.45it/s, accuracy=0.373]

Epoch 9



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.22it/s, loss=1.53]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.89it/s, accuracy=0.372]

Epoch 10



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.87it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 23.99it/s, accuracy=0.392]

Ding ding ding! We found a new best model!
Epoch 11



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.14it/s, loss=1.53]
Validation: 100%|███████████████| 100/100 [00:03<00:00, 26.29it/s, accuracy=0.37]

Epoch 12



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.98it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.21it/s, accuracy=0.373]

Epoch 13



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.23it/s, loss=1.53]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.70it/s, accuracy=0.391]

Epoch 14



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.09it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.57it/s, accuracy=0.379]

Epoch 15



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.53it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.45it/s, accuracy=0.388]

Epoch 16



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.21it/s, loss=1.53]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 27.71it/s, accuracy=0.373]

Epoch 17



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.03it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.64it/s, accuracy=0.384]

Epoch 18



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.38it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.30it/s, accuracy=0.385]

Epoch 19



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.68it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.95it/s, accuracy=0.377]

Epoch 20



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.92it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.38it/s, accuracy=0.367]

Epoch 21



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.51it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 23.71it/s, accuracy=0.381]

Epoch 22



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.09it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.79it/s, accuracy=0.391]

Epoch 23



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.11it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.85it/s, accuracy=0.399]

Ding ding ding! We found a new best model!
Epoch 24



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.34it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.18it/s, accuracy=0.393]

Epoch 25



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.77it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.46it/s, accuracy=0.395]

Epoch 26



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.07it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.56it/s, accuracy=0.387]

Epoch 27



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.15it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 23.43it/s, accuracy=0.367]

Epoch 28



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.37it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.57it/s, accuracy=0.402]

Ding ding ding! We found a new best model!
Epoch 29



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.03it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.70it/s, accuracy=0.361]

Epoch 30



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.80it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.42it/s, accuracy=0.408]

Ding ding ding! We found a new best model!
Epoch 31



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.03it/s, loss=1.52]
Validation: 100%|███████████████| 100/100 [00:03<00:00, 26.21it/s, accuracy=0.34]

Epoch 32



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.64it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.29it/s, accuracy=0.392]

Epoch 33



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.68it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.00it/s, accuracy=0.389]

Epoch 34



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.94it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 23.89it/s, accuracy=0.369]

Epoch 35



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.33it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.14it/s, accuracy=0.349]

Epoch 36



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.93it/s, loss=1.52]
Validation: 100%|████████████████| 100/100 [00:04<00:00, 22.87it/s, accuracy=0.4]

Epoch 37



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.97it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.17it/s, accuracy=0.386]

Epoch 38



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.55it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.85it/s, accuracy=0.395]

Epoch 39



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.01it/s, loss=1.52]
Validation: 100%|███████████████| 100/100 [00:03<00:00, 27.29it/s, accuracy=0.39]

Epoch 40



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.48it/s, loss=1.52]
Validation: 100%|████████████████| 100/100 [00:04<00:00, 23.89it/s, accuracy=0.4]

Epoch 41



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.50it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.84it/s, accuracy=0.392]

Epoch 42



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.21it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.16it/s, accuracy=0.399]

Epoch 43



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.72it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.85it/s, accuracy=0.334]

Epoch 44



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.54it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.74it/s, accuracy=0.383]

Epoch 45



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.44it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.61it/s, accuracy=0.399]

Epoch 46



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.11it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.96it/s, accuracy=0.385]

Epoch 47



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.57it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 23.50it/s, accuracy=0.415]

Ding ding ding! We found a new best model!
Epoch 48



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.27it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.17it/s, accuracy=0.386]

Epoch 49



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.02it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.26it/s, accuracy=0.378]

Epoch 50



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.86it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 27.07it/s, accuracy=0.403]

Epoch 51



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.66it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.15it/s, accuracy=0.355]

Epoch 52



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.31it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.90it/s, accuracy=0.399]

Epoch 53



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.63it/s, loss=1.51]
Validation: 100%|████████████████| 100/100 [00:04<00:00, 24.13it/s, accuracy=0.4]

Epoch 54



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.08it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.36it/s, accuracy=0.374]

Epoch 55



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.92it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.24it/s, accuracy=0.364]

Epoch 56



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.22it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.18it/s, accuracy=0.416]

Ding ding ding! We found a new best model!
Epoch 57



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.15it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.37it/s, accuracy=0.382]

Epoch 58



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.59it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.35it/s, accuracy=0.407]

Epoch 59



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.77it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.52it/s, accuracy=0.384]

Epoch 60



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.70it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.47it/s, accuracy=0.369]

Epoch 61



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.95it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.26it/s, accuracy=0.367]

Epoch 62



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 26.46it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.73it/s, accuracy=0.346]

Epoch 63



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.27it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.61it/s, accuracy=0.401]

Epoch 64



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.82it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.69it/s, accuracy=0.404]

Epoch 65



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.98it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.87it/s, accuracy=0.377]

Epoch 66



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.51it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.16it/s, accuracy=0.335]

Epoch 67



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.37it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 23.85it/s, accuracy=0.397]

Epoch 68



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.13it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.60it/s, accuracy=0.397]

Epoch 69



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.31it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.06it/s, accuracy=0.382]

Epoch 70



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.65it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.58it/s, accuracy=0.382]

Epoch 71



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.28it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.76it/s, accuracy=0.387]

Epoch 72



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.08it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.12it/s, accuracy=0.367]

Epoch 73



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.92it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.62it/s, accuracy=0.385]

Epoch 74



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 27.88it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.84it/s, accuracy=0.381]

Epoch 75



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.74it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 27.80it/s, accuracy=0.384]

Epoch 76



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.73it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.04it/s, accuracy=0.354]

Epoch 77



Training: 100%|█████████████████████| 500/500 [00:17<00:00, 28.73it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 26.85it/s, accuracy=0.368]

Epoch 78



Training: 100%|█████████████████████| 500/500 [00:18<00:00, 27.16it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 24.19it/s, accuracy=0.404]

Epoch 79



Training: 100%|█████████████████████| 500/500 [00:20<00:00, 24.50it/s, loss=1.51]
Validation: 100%|██████████████| 100/100 [00:04<00:00, 23.87it/s, accuracy=0.359]

Epoch 80



Training: 100%|█████████████████████| 500/500 [00:23<00:00, 21.24it/s, loss=1.52]
Validation: 100%|██████████████| 100/100 [00:03<00:00, 25.66it/s, accuracy=0.388]

Epoch 81



Training:  24%|█████                | 120/500 [00:06<00:19, 19.85it/s, loss=1.53]
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f9ab990f2e0>>
Traceback (most recent call last):
  File "/root/anaconda3/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 

KeyboardInterrupt



我们成功完成了情景训练!现在，如果您愿意，您可以检索最佳模型的状态。

In [11]:
few_shot_classifier.load_state_dict(best_state)

<All keys matched successfully>

## Evaluation

现在我们的模型已经训练好了，我们想要测试它。
第一步:获取测试数据。

In [12]:
n_test_tasks = 1000

test_set = CUB(split="test", training=False)
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Second step: we run the few-shot classifier on the test data.

In [13]:
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

 13%|████▏                          | 134/1000 [00:08<00:51, 16.66it/s, accuracy=0.4]


KeyboardInterrupt: 

Congrats! You performed Episodic Training using EasyFSL. If you want to compare with a model trained using classical training, look at [this other example notebook](classical_training.ipynb).
