# Train a model with Episodic Training
Episodic training attracted a lot of interest in the early years of Few-Shot Learning research. Some papers still use it and refer to it as "meta-learning".

Recent works distinguish the few-shot classifier from the training framework. As of v1.0 of EasyFSL, the methods to episodically train a classifier were therefore taken out of the `FewShotClassifier` class. Instead, this notebook provides an example of how to perform episodic training on a few-shot classifier.

Use it, copy it, change it, get crazy.

## Getting started
First we're going to do some imports (this is not the interesting part).

In [None]:
try:
    import google.colab
    colab = True
except ImportError:
    colab = False

In [None]:
if colab is True:
    # Running in Google Colab
    # Clone the repo
    !git clone https://github.com/sicara/easy-few-shot-learning
    %cd easy-few-shot-learning
    !pip install .
else:
    # Run locally
    # Ensure working directory is the project's root
    # Make sure easyfsl is installed!
    %cd ..

In [None]:
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for every random package we could possibly use, plus two cuDNN flags that make GPU computations deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).

I strongly encourage you to do this in **all your scripts**.

In [None]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
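
To see why this matters, here is a small standard-library sketch. It is not part of EasyFSL, `draw_task_classes` is a hypothetical helper, and it seeds a private generator instead of the global one used above. The point is only that the same seed replays the same random draw.

```python
import random


def draw_task_classes(seed: int, n_classes: int = 100, n_way: int = 5) -> list:
    """Hypothetical helper: seed a private generator, then sample n_way class ids."""
    rng = random.Random(seed)
    return rng.sample(range(n_classes), n_way)


first_draw = draw_task_classes(seed=0)
second_draw = draw_task_classes(seed=0)

# Same seed, same classes: the experiment can be replayed exactly.
print(first_draw == second_draw)
```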

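Then we're gonna set the shape of our few-shot tasks: the number of classes per task (`n_way`), the number of support images per class (`n_shot`) and the number of query images per class (`n_query`).

We also define our set-up: the device (change it if you don't have CUDA) and the number of workers for data loading.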

In [None]:
n_way = 5
n_shot = 5
n_query = 10

DEVICE = "cuda"
n_workers = 12
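
To make the shape concrete, this is the arithmetic behind one task with the values above. The variable names are only illustrative.

```python
n_way, n_shot, n_query = 5, 5, 10

support_images_per_task = n_way * n_shot  # labeled examples given to the classifier
query_images_per_task = n_way * n_query  # images the classifier must classify
images_per_task = support_images_per_task + query_images_per_task

print(support_images_per_task, query_images_per_task, images_per_task)  # 25 50 75
```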

## Training

First we define our data loaders for training and validation. I chose to use CUB in this notebook because it's a small dataset, so we can get good results quite quickly. We use `CUB` and `TaskSampler`, which are built-in objects from EasyFSL.

In [None]:
# Download the CUB dataset
!make download-cub

In [None]:
from easyfsl.datasets import CUB
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader


n_tasks_per_epoch = 500
n_validation_tasks = 100

# Instantiate the datasets
train_set = CUB(split="train", training=True)
val_set = CUB(split="val", training=False)

# These are special batch samplers that sample few-shot classification tasks with a pre-defined shape
train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

# Finally, the DataLoader. We customize the collate_fn so that batches are delivered
# in the shape: (support_images, support_labels, query_images, query_labels, class_ids)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)
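
If you wonder what such a sampler does, here is a simplified standard-library sketch of the idea. It is not EasyFSL's actual implementation: `sample_task` and the toy `toy_labels` list are made up for illustration. For each task we pick `n_way` classes, then `n_shot + n_query` item indices per class.

```python
import random
from collections import defaultdict


def sample_task(labels, n_way, n_shot, n_query, rng):
    """Return {class_id: (support_indices, query_indices)} for one task."""
    indices_per_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_class[label].append(index)

    task = {}
    for class_id in rng.sample(sorted(indices_per_class), n_way):
        # Draw without replacement so support and query never share an item
        picked = rng.sample(indices_per_class[class_id], n_shot + n_query)
        task[class_id] = (picked[:n_shot], picked[n_shot:])
    return task


# Toy dataset: 8 classes with 20 items each
toy_labels = [class_id for class_id in range(8) for _ in range(20)]
toy_task = sample_task(toy_labels, n_way=5, n_shot=5, n_query=10, rng=random.Random(0))

for class_id, (support, query) in toy_task.items():
    print(class_id, len(support), len(query))
```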

And then we define the network. Here I chose Prototypical Networks and the built-in ResNet12 from EasyFSL because it's easy.

In [None]:
from easyfsl.methods import PrototypicalNetworks, FewShotClassifier
from easyfsl.modules import resnet12


convolutional_network = resnet12()
few_shot_classifier = PrototypicalNetworks(convolutional_network).to(DEVICE)
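
As a reminder of what Prototypical Networks compute, here is a minimal sketch on random features. It is a simplification, not EasyFSL's code: `prototype_scores` is a hypothetical helper, and the real class first passes the images through the backbone. Each class prototype is the mean of its support features, and a query's score for a class is its negative Euclidean distance to that prototype.

```python
import torch


def prototype_scores(support_features, support_labels, query_features):
    """Score each query against each class prototype (higher means closer)."""
    class_ids = torch.unique(support_labels)  # sorted class ids
    prototypes = torch.stack(
        [support_features[support_labels == c].mean(dim=0) for c in class_ids]
    )
    # Shape (n_queries, n_classes): negative Euclidean distances
    return -torch.cdist(query_features, prototypes)


torch.manual_seed(0)
toy_support_labels = torch.arange(5).repeat_interleave(5)  # 5-way, 5-shot
toy_support_features = torch.randn(25, 16)
toy_query_features = torch.randn(50, 16)

scores = prototype_scores(toy_support_features, toy_support_labels, toy_query_features)
print(scores.shape)
```

These scores can be fed directly to a cross-entropy loss, which is what the training loop in this notebook does with the classifier's output.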

Now let's define our training helpers! I chose to use Stochastic Gradient Descent for 200 epochs, with a scheduler that divides the learning rate by 10 after 120 and 160 epochs. The strategy is derived from [this repo](https://github.com/fiveai/on-episodes-fsl).

We're also gonna use TensorBoard because it's always good to see what your training curves look like.

In [None]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 200
scheduler_milestones = [120, 160]
scheduler_gamma = 0.1
learning_rate = 1e-2
tb_logs_dir = Path(".")

train_optimizer = SGD(
    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))
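
With these settings, and a scheduler step once per epoch, the learning rate follows a simple staircase. The sketch below reproduces the arithmetic in plain Python. The helper `learning_rate_at` is hypothetical and does not call PyTorch: the base rate is multiplied by `gamma` once for each milestone already reached.

```python
from bisect import bisect_right


def learning_rate_at(epoch, base_lr=1e-2, milestones=(120, 160), gamma=0.1):
    """Learning rate used during `epoch` (0-indexed), stepping once per epoch."""
    n_decays = bisect_right(milestones, epoch)
    return base_lr * gamma**n_decays


for epoch in (0, 119, 120, 159, 160, 199):
    print(epoch, f"{learning_rate_at(epoch):.0e}")
```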

And now let's get to it! Here we define the function that performs a training epoch.

We use tqdm to monitor the training in real time in our logs.

In [None]:
def training_epoch(
    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

And we have everything we need! To perform validations we'll just use the built-in `evaluate` function from `easyfsl.utils`.

This is now the time to **start training**.

The loop also keeps the state of the model that gave the best performance on the validation set.

In [None]:
import copy

from easyfsl.utils import evaluate


# state_dict() returns references to the live parameters, so we deep-copy it
# to keep a real snapshot of the best weights
best_state = copy.deepcopy(few_shot_classifier.state_dict())
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = copy.deepcopy(few_shot_classifier.state_dict())
        print("Ding ding ding! We found a new best model!")

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()

Yay, we successfully performed Episodic Training! Now, if you want to, you can retrieve the best model's state.

In [None]:
few_shot_classifier.load_state_dict(best_state)
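
If you want to keep the weights for later, one option is to write the state dict to disk with `torch.save` and read it back with `torch.load`. The sketch below uses a tiny stand-in module and a temporary directory. The file name `best_model.pt` is an arbitrary choice, not something EasyFSL expects.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn

toy_model = nn.Linear(4, 2)  # stand-in for few_shot_classifier

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_path = Path(tmp_dir) / "best_model.pt"
    torch.save(toy_model.state_dict(), checkpoint_path)

    # A fresh module with the same architecture receives the saved weights
    restored_model = nn.Linear(4, 2)
    restored_model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

print(torch.equal(toy_model.weight, restored_model.weight))
```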

## Evaluation

Now that our model is trained, we want to test it.

First step: we fetch the test data.

In [None]:
n_test_tasks = 1000

test_set = CUB(split="test", training=False)
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Second step: we run the few-shot classifier on the test data.

In [None]:
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

Congrats! You performed Episodic Training using EasyFSL. If you want to compare with a model trained using classical training, look at [this other example notebook](classical_training.ipynb).
