# 使用情景训练法训练模型



## Getting started


In [1]:
# 导包
import copy
from pathlib import Path
import random
from statistics import mean
import numpy as np
import torch
from torch import nn
from tqdm import tqdm

In [2]:
# Fix every source of randomness (Python, NumPy, PyTorch) so runs are repeatable
random_seed = 0
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(random_seed)

# Ask cuDNN for deterministic kernels and switch off its benchmark auto-tuner
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Sampling strategy for the few-shot tasks: these three values are passed to
# TaskSampler as n_way / n_shot / n_query when the episodes are built.
n_way = 5
n_shot = 5
n_query = 10

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing on the first `.to(DEVICE)` call.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Number of worker processes handed to every DataLoader
n_workers = 12

In [4]:
from easyfsl.datasets import CUB
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader

import os
# Directory to switch into before reading the data (this notebook currently lives
# in a sub-directory of the project). The location can be overridden with the
# FSL_PROJECT_ROOT environment variable; the original path remains the default.
new_working_directory = os.environ.get("FSL_PROJECT_ROOT", "/ML/Mashuai/few-shot-learning")

# Switch to the new working directory
os.chdir(new_working_directory)

# Number of tasks (episodes) drawn per training epoch and per validation pass
n_tasks_per_epoch = 500
n_validation_tasks = 100

# Instantiate the datasets
train_set = CUB(split="train", training=True)
val_set = CUB(split="val", training=False)

# Special batch samplers that draw few-shot classification tasks of a
# predefined shape; both splits share the same n_way / n_shot / n_query.
task_shape = dict(n_way=n_way, n_shot=n_shot, n_query=n_query)
train_sampler = TaskSampler(train_set, n_tasks=n_tasks_per_epoch, **task_shape)
val_sampler = TaskSampler(val_set, n_tasks=n_validation_tasks, **task_shape)

# Finally the DataLoaders. The collate_fn is customised so that each batch is
# delivered as:
#   (support_images, support_labels, query_images, query_labels, class_ids)
loader_options = dict(num_workers=n_workers, pin_memory=True)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    collate_fn=train_sampler.episodic_collate_fn,
    **loader_options,
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    collate_fn=val_sampler.episodic_collate_fn,
    **loader_options,
)

然后我们定义网络。在这里，我选择了 Prototypical Networks 和 easyfsl 内置的 ResNet12 作为骨干网络，因为这样最容易上手。

In [5]:
from easyfsl.methods import PrototypicalNetworks, FewShotClassifier,RelationNetworks
from easyfsl.modules import resnet12,default_relation_module
# One output unit per distinct training label
n_train_classes = len(set(train_set.get_labels()))
# ResNet12 backbone with its fully-connected head enabled, placed on the training device.
# NOTE(review): with use_fc=True the classifier presumably builds its prototypes from the
# FC outputs rather than from the pooled features — confirm this is intended.
backbone = resnet12(use_fc=True, num_classes=n_train_classes)
convolutional_network = backbone.to(DEVICE)

In [6]:
# Wrap the backbone in DataParallel so batches can be split across the available GPUs.
# The isinstance guard keeps this cell idempotent: re-running it must not nest one
# DataParallel wrapper inside another.
if not isinstance(convolutional_network, nn.DataParallel):
    convolutional_network = nn.DataParallel(convolutional_network)

In [7]:
# Build the Prototypical Networks classifier on top of the backbone,
# then place it on the training device.
few_shot_classifier = PrototypicalNetworks(convolutional_network)
few_shot_classifier = few_shot_classifier.to(DEVICE)


In [8]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


# Training schedule
n_epochs = 200
learning_rate = 1e-2
scheduler_gamma = 0.1
scheduler_milestones = [120, 160]

# Directory that receives the TensorBoard event files
tb_logs_dir = Path("../logs")

# Loss applied to the classification scores of the query images
LOSS_FUNCTION = nn.CrossEntropyLoss()

# SGD with momentum and weight decay; MultiStepLR multiplies the learning
# rate by `scheduler_gamma` at each milestone epoch.
train_optimizer = SGD(
    few_shot_classifier.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

In [9]:
def training_epoch(model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer):
    """Run one epoch of episodic training and return the mean episode loss.

    For every episode the support set is handed to the model, the query images
    are scored, and a single optimizer step is taken on the resulting loss.
    Tensors are moved to the module-level ``DEVICE`` and the loss is computed
    with the module-level ``LOSS_FUNCTION``.

    Args:
        model: Few-shot classifier exposing ``process_support_set`` and callable
            on a batch of query images.
        data_loader: Iterable with a length, yielding 5-tuples
            ``(support_images, support_labels, query_images, query_labels, _)``;
            the last element is ignored.
        optimizer: Optimizer updating the model's parameters.

    Returns:
        The mean of the per-episode loss values over the epoch.

    Raises:
        statistics.StatisticsError: If ``data_loader`` yields no episode.
    """
    all_loss = []
    # Running sum for the progress bar, so the displayed average does not require
    # re-averaging the whole loss history at every step.
    running_total = 0.0
    model.train()
    with tqdm(data_loader, total=len(data_loader), desc="Training") as tqdm_train:
        for support_images, support_labels, query_images, query_labels, _ in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(support_images.to(DEVICE), support_labels.to(DEVICE))
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            episode_loss = loss.item()
            all_loss.append(episode_loss)
            running_total += episode_loss

            tqdm_train.set_postfix(loss=running_total / len(all_loss))

    # The returned value still comes from statistics.mean, as before.
    return mean(all_loss)

我们已经具备了所需的一切！要执行验证，我们将使用 `easyfsl.utils` 中内置的 `evaluate` 函数。

现在是时候**开始训练**了——原神，启动！

我添加了一些东西来记录在验证集上提供最佳性能的模型的状态。

In [10]:
from easyfsl.utils import evaluate


# Snapshot the starting weights. state_dict() hands back references to the live
# tensors, so without a deep copy this "best" state would silently follow every
# later update whenever no epoch beats the starting accuracy.
best_state = copy.deepcopy(few_shot_classifier.state_dict())
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation")

    # Keep an independent copy of the weights each time validation accuracy improves
    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = copy.deepcopy(few_shot_classifier.state_dict())
        print("Ding ding ding! We found a new best model!")

    # Log the epoch's training loss and validation accuracy to TensorBoard
    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()

Epoch 0


Training: 100%|█████████████████████| 500/500 [01:26<00:00,  5.80it/s, loss=1.53]
Validation: 100%|██████████████| 100/100 [00:05<00:00, 19.99it/s, accuracy=0.447]


Ding ding ding! We found a new best model!
Epoch 1


Training: 100%|█████████████████████| 500/500 [01:22<00:00,  6.06it/s, loss=1.47]
Validation: 100%|██████████████| 100/100 [00:05<00:00, 17.73it/s, accuracy=0.482]

Ding ding ding! We found a new best model!
Epoch 2



Training: 100%|█████████████████████| 500/500 [01:22<00:00,  6.07it/s, loss=1.41]
Validation: 100%|██████████████| 100/100 [00:05<00:00, 19.99it/s, accuracy=0.472]

Epoch 3



Training: 100%|█████████████████████| 500/500 [01:21<00:00,  6.16it/s, loss=1.34]
Validation: 100%|██████████████| 100/100 [00:05<00:00, 19.78it/s, accuracy=0.488]

Ding ding ding! We found a new best model!
Epoch 4



Training:  35%|███████▎             | 173/500 [00:28<00:54,  6.05it/s, loss=1.32]


KeyboardInterrupt: 

我们成功完成了情景训练!现在，如果您愿意，您可以检索最佳模型的状态。

In [11]:
# Restore the weights saved when validation accuracy was at its best
few_shot_classifier.load_state_dict(state_dict=best_state)

<All keys matched successfully>

## Evaluation

现在我们的模型已经训练好了，我们想要测试它。
第一步:获取测试数据。

In [12]:
# Number of tasks (episodes) sampled from the test split
n_test_tasks = 1000

test_set = CUB(split="test", training=False)

# Same task shape as in training, only the number of sampled tasks differs
test_sampler = TaskSampler(
    test_set,
    n_tasks=n_test_tasks,
    n_way=n_way,
    n_shot=n_shot,
    n_query=n_query,
)

# Batches are collated into the episodic format by the sampler's collate function
test_loader = DataLoader(
    test_set,
    collate_fn=test_sampler.episodic_collate_fn,
    batch_sampler=test_sampler,
    pin_memory=True,
    num_workers=n_workers,
)

Second step: we run the few-shot classifier on the test data.

In [13]:
# Score the classifier on the test tasks, then report the result as a percentage
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
accuracy_percent = 100 * accuracy
print(f"Average accuracy : {accuracy_percent:.2f} %")

 13%|████▏                          | 134/1000 [00:08<00:51, 16.66it/s, accuracy=0.4]


KeyboardInterrupt: 

Congrats! You performed Episodic Training using EasyFSL. If you want to compare with a model trained using classical training, look at [this other example notebook](classical_training.ipynb).
