# Train a model with Episodic Training
Episodic training has attracted a lot of interest in the early years of Few-Shot Learning research. Some papers still use it, and refer to it as "meta-learning".

Recent works distinguish the Few-Shot Classifier from the training framework, so as from v1.0 of EasyFSL, methods to episodically train a classifier were taken out of the logic of the FewShotClassifier class. Instead, we provide in this notebook an example of how to perform episodic training on a few-shot classifier.

Use it, copy it, change it, get crazy.

## Getting started
First we're going to do some imports (this is not the interesting part).

In [1]:
try:
    import google.colab
    colab = True
except:
    colab = False

In [2]:
!pip install easyfsl

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting easyfsl
  Downloading easyfsl-1.3.0-py3-none-any.whl (52 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.4/52.4 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: easyfsl
Successfully installed easyfsl-1.3.0


In [3]:
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import GTSRB
from torchvision.models import resnet18
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for all random packages that we could possibly use, plus some other stuff to make CUDA deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).

I strongly encourage that you do this in **all your scripts**.

In [4]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Then we're gonna set the shape of our problem.

Also we define our set-up, like the device (change it if you don't have CUDA) or the number of workers for data loading.

In [5]:
n_way = 5
n_shot = 5
n_query = 10

DEVICE = "cuda"
n_workers = 12

## Training

First we define our data loaders for training and validation. You can see that I chose tu use CUB in this notebook, because it's a small dataset, so we can have good results quite quickly. We use `CUB` and `TaskSampler` which are built-in objects from EasyFSL.

In [6]:
from keras.preprocessing.image import ImageDataGenerator
import keras
from urllib.request import urlretrieve
import zipfile
import time
import matplotlib.pyplot as plt
import numpy as np

In [7]:
from torchvision.datasets import GTSRB
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader


n_tasks_per_epoch = 500
n_validation_tasks = 100
image_size = 28

# Instantiate the datasets
train_set = GTSRB(
    root="./data",
    # background=True,
    split="train",
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)

separate = int(np.shape(train_set)[0] * 0.8)
train_set, valid_set = torch.utils.data.random_split(train_set, [separate, np.shape(train_set)[0] - separate])

train_set.get_labels = lambda: [instance[1] for instance in train_set]

valid_set.get_labels = lambda: [instance[1] for instance in valid_set]

# Those are special batch samplers that sample few-shot classification tasks with a pre-defined shape
train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)

val_sampler = TaskSampler(
    valid_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

# Finally, the DataLoader. We customize the collate_fn so that batches are delivered
# in the shape: (support_images, support_labels, query_images, query_labels, class_ids)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

val_loader = DataLoader(
    valid_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

Downloading https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip to data/gtsrb/GTSRB-Training_fixed.zip


100%|██████████| 187490228/187490228 [00:07<00:00, 24493160.00it/s]


Extracting data/gtsrb/GTSRB-Training_fixed.zip to data/gtsrb


  result = asarray(a).shape
  result = asarray(a).shape


In [8]:
!mkdir /content/data/gtsrb/GTSRB/Testing

In [9]:
!mv /content/data/gtsrb/GTSRB/Training/00021 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00037 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00042 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00016 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00006 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00005 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00004 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00035 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00000 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00029 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00010 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00012 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00040 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00009 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00036 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00031 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00041 /content/data/gtsrb/GTSRB/Testing
!mv /content/data/gtsrb/GTSRB/Training/00018 /content/data/gtsrb/GTSRB/Testing

In [10]:
import os

old_folder_name = '/content/data/gtsrb/GTSRB/Training'
new_folder_name = '/content/data/gtsrb/GTSRB/Trainingg'

# Rename the folder
os.rename(old_folder_name, new_folder_name)

old_folder_name = '/content/data/gtsrb/GTSRB/Testing'
new_folder_name = '/content/data/gtsrb/GTSRB/Training'

# Rename the folder
os.rename(old_folder_name, new_folder_name)

In [11]:
n_test_tasks = 1000

test_set = GTSRB(
    root="./data",
    # background=False,
    split="train",
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=False,
)

test_set.get_labels = lambda: [instance[1] for instance in test_set]

test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [12]:
len(train_set), len(valid_set), len(test_set)

(21312, 5328, 10680)

In [None]:
# !wget "http://btsd.ethz.ch/shareddata/BelgiumTSC/BelgiumTSC_Training.zip"
# !wget "https://btsd.ethz.ch/shareddata/BelgiumTSC/BelgiumTSC_Testing.zip"

--2023-05-20 13:18:24--  http://btsd.ethz.ch/shareddata/BelgiumTSC/BelgiumTSC_Training.zip
Resolving btsd.ethz.ch (btsd.ethz.ch)... 129.132.52.168, 2001:67c:10ec:36c2::168
Connecting to btsd.ethz.ch (btsd.ethz.ch)|129.132.52.168|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://btsd.ethz.ch/shareddata/BelgiumTSC/BelgiumTSC_Training.zip [following]
--2023-05-20 13:18:24--  https://btsd.ethz.ch/shareddata/BelgiumTSC/BelgiumTSC_Training.zip
Connecting to btsd.ethz.ch (btsd.ethz.ch)|129.132.52.168|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 174298785 (166M) [application/zip]
Saving to: ‘BelgiumTSC_Training.zip’


2023-05-20 13:18:45 (8.38 MB/s) - ‘BelgiumTSC_Training.zip’ saved [174298785/174298785]

--2023-05-20 13:18:45--  https://btsd.ethz.ch/shareddata/BelgiumTSC/BelgiumTSC_Testing.zip
Resolving btsd.ethz.ch (btsd.ethz.ch)... 129.132.52.168, 2001:67c:10ec:36c2::168
Connecting to btsd.ethz.ch (btsd.ethz.ch)|129.132.52.168

In [13]:
import shutil
folder_path = "/content/data/gtsrb/GTSRB/Training"
shutil.rmtree(folder_path)
# folder_path = "/content/data/gtsrb/GTSRB/Testing"
# shutil.rmtree(folder_path)

In [None]:
import os

old_folder_name = '/content/data/gtsrb/GTSRB/Training'
new_folder_name = '/content/data/gtsrb/GTSRB/Trainingg'

# Rename the folder
os.rename(old_folder_name, new_folder_name)

In [14]:
extract_path = "./data/gtsrb/GTSRB"
validation_data_dir = './data/gtsrb/GTSRB/Training'

def train_data(train_url):
    zip_dir = "./data/BelgiumTSC_Training.zip"

    print("Downloading Belgium TSC Training Dataset\n")
    urlretrieve(train_url, zip_dir)
    zip_ref = zipfile.ZipFile(zip_dir)

    print("Extracting Zip\n")
    zip_ref.extractall(extract_path)
    zip_ref.close()

def test_data(test_url):
    zip_dir = "./data/BelgiumTSC_Training.zip"

    print("Downloading Belgium TSC Testing Dataset\n")
    urlretrieve(test_url, zip_dir)
    zip_ref = zipfile.ZipFile(zip_dir)

    print("Extracting Zip\n")
    zip_ref.extractall(extract_path)
    zip_ref.close()


def download_data():
    print("Download Datasets")
    start = time.time()
    train_data("http://btsd.ethz.ch/shareddata/BelgiumTSC/BelgiumTSC_Training.zip")
    # test_data("http://btsd.ethz.ch/shareddata/BelgiumTSC/BelgiumTSC_Testing.zip")
    end = time.time()
    print("Downloading Datasets 'BelgiumTSC' took ", end - start, 'seconds')

download_data()

Download Datasets
Downloading Belgium TSC Training Dataset

Extracting Zip

Downloading Datasets 'BelgiumTSC' took  22.946743726730347 seconds


In [None]:
import os
import shutil

cntt =0
summ = 0
folder_A_path = '/content/data/gtsrb/GTSRB/Training'

# Iterate over subdirectories in folder_A
for root, dirs, files in os.walk(folder_A_path):
    for dir in dirs:
        folder_B_path = os.path.join(root, dir)
        ppm_file_count = sum(1 for file in os.listdir(folder_B_path) if file.endswith('.ppm'))
        print(f"Folder {folder_B_path} contains {ppm_file_count} .ppm files.")

        summ += ppm_file_count

        if ppm_file_count < 80:
          cntt += 1
          shutil.rmtree(folder_B_path)

print(summ, "SUM")

In [None]:
len(valid_set) + len(train_set)

4575

In [None]:
import random
from typing import Dict, Iterator, List, Tuple, Union

import torch
from torch import Tensor
from torch.utils.data import Sampler

from easyfsl.datasets import FewShotDataset

GENERIC_TYPING_ERROR_MESSAGE = (
    "Check out the output's type of your dataset's __getitem__() method."
    "It must be a Tuple[Tensor, int] or Tuple[Tensor, 0-dim Tensor]."
)

print("HAHA2")
class TaskSamplerr(Sampler):
    """
    Samples batches in the shape of few-shot classification tasks. At each iteration, it will sample
    n_way classes, and then sample support and query images from these classes.
    """

    def __init__(
        self,
        dataset: FewShotDataset,
        n_way: int,
        n_shot: int,
        n_query: int,
        n_tasks: int,
    ):
        """
        Args:
            dataset: dataset from which to sample classification tasks. Must have implement get_labels() from
                FewShotDataset.
            n_way: number of classes in one task
            n_shot: number of support images for each class in one task
            n_query: number of query images for each class in one task
            n_tasks: number of tasks to sample
        """
        print("HAHAAA: ", dataset.get_labels())
        super().__init__(data_source=None)
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_tasks = n_tasks

        self.items_per_label: Dict[int, List[int]] = {}
        
        for item, label in enumerate(dataset.get_labels()):
            print("HAHA3: ", item, label)
            if label in self.items_per_label:
                self.items_per_label[label].append(item)
            else:
                self.items_per_label[label] = [item]

        print("HAHA: ", self.items_per_label) 

        for itemss in self.items_per_label:
          print(itemss, len(self.items_per_label[itemss]))

        self._check_dataset_size_fits_sampler_parameters()

    def __len__(self) -> int:
        return self.n_tasks

    def __iter__(self) -> Iterator[List[int]]:
        """
        Sample n_way labels uniformly at random,
        and then sample n_shot + n_query items for each label, also uniformly at random.
        Yields:
            a list of indices of length (n_way * (n_shot + n_query))
        """
        for _ in range(self.n_tasks):
            yield torch.cat(
                [
                    torch.tensor(
                        random.sample(
                            self.items_per_label[label], self.n_shot + self.n_query
                        )
                    )
                    for label in random.sample(
                        sorted(self.items_per_label.keys()), self.n_way
                    )
                ]
            ).tolist()

    def episodic_collate_fn(
        self, input_data: List[Tuple[Tensor, Union[Tensor, int]]]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[int]]:
        """
        Collate function to be used as argument for the collate_fn parameter of episodic
            data loaders.
        Args:
            input_data: each element is a tuple containing:
                - an image as a torch Tensor of shape (n_channels, height, width)
                - the label of this image as an int or a 0-dim tensor
        Returns:
            tuple(Tensor, Tensor, Tensor, Tensor, list[int]): respectively:
                - support images of shape (n_way * n_shot, n_channels, height, width),
                - their labels of shape (n_way * n_shot),
                - query images of shape (n_way * n_query, n_channels, height, width)
                - their labels of shape (n_way * n_query),
                - the dataset class ids of the class sampled in the episode
        """
        input_data_with_int_labels = self._cast_input_data_to_tensor_int_tuple(
            input_data
        )
        true_class_ids = list({x[1] for x in input_data_with_int_labels})
        all_images = torch.cat([x[0].unsqueeze(0) for x in input_data_with_int_labels])
        all_images = all_images.reshape(
            (self.n_way, self.n_shot + self.n_query, *all_images.shape[1:])
        )
        all_labels = torch.tensor(
            [true_class_ids.index(x[1]) for x in input_data_with_int_labels]
        ).reshape((self.n_way, self.n_shot + self.n_query))
        support_images = all_images[:, : self.n_shot].reshape(
            (-1, *all_images.shape[2:])
        )
        query_images = all_images[:, self.n_shot :].reshape((-1, *all_images.shape[2:]))
        support_labels = all_labels[:, : self.n_shot].flatten()
        query_labels = all_labels[:, self.n_shot :].flatten()
        return (
            support_images,
            support_labels,
            query_images,
            query_labels,
            true_class_ids,
        )

    @staticmethod
    def _cast_input_data_to_tensor_int_tuple(
        input_data: List[Tuple[Tensor, Union[Tensor, int]]]
    ) -> List[Tuple[Tensor, int]]:
        """
        Check the type of the input for the episodic_collate_fn method, and cast it to the right type if possible.
        Args:
            input_data: each element is a tuple containing:
                - an image as a torch Tensor of shape (n_channels, height, width)
                - the label of this image as an int or a 0-dim tensor
        Returns:
            the input data with the labels cast to int
        Raises:
            TypeError : Wrong type of input images or labels
            ValueError: Input label is not a 0-dim tensor
        """
        for image, label in input_data:
            if not isinstance(image, Tensor):
                raise TypeError(
                    f"Illegal type of input instance: {type(image)}. "
                    + GENERIC_TYPING_ERROR_MESSAGE
                )
            if not isinstance(label, int):
                if not isinstance(label, Tensor):
                    raise TypeError(
                        f"Illegal type of input label: {type(label)}. "
                        + GENERIC_TYPING_ERROR_MESSAGE
                    )
                if label.dtype not in {
                    torch.uint8,
                    torch.int8,
                    torch.int16,
                    torch.int32,
                    torch.int64,
                }:
                    raise TypeError(
                        f"Illegal dtype of input label tensor: {label.dtype}. "
                        + GENERIC_TYPING_ERROR_MESSAGE
                    )
                if label.ndim != 0:
                    raise ValueError(
                        f"Illegal shape for input label tensor: {label.shape}. "
                        + GENERIC_TYPING_ERROR_MESSAGE
                    )

        return [(image, int(label)) for (image, label) in input_data]

    def _check_dataset_size_fits_sampler_parameters(self):
        """
        Check that the dataset size is compatible with the sampler parameters
        """
        self._check_dataset_has_enough_labels()
        self._check_dataset_has_enough_items_per_label()

    def _check_dataset_has_enough_labels(self):
        if self.n_way > len(self.items_per_label):
            raise ValueError(
                f"The number of labels in the dataset ({len(self.items_per_label)} "
                f"must be greater or equal to n_way ({self.n_way})."
            )

    def _check_dataset_has_enough_items_per_label(self):
        print("HI")
        number_of_samples_per_label = [
            len(items_for_label) for items_for_label in self.items_per_label.values()
        ]
        print("SAMPLES: ", number_of_samples_per_label)
        minimum_number_of_samples_per_label = min(number_of_samples_per_label)
        print("MIN: ", minimum_number_of_samples_per_label)
        label_with_minimum_number_of_samples = number_of_samples_per_label.index(
            minimum_number_of_samples_per_label
        )

        print("LABEL: ", label_with_minimum_number_of_samples)
        if self.n_shot + self.n_query > minimum_number_of_samples_per_label:
            raise ValueError(
                f"Label {label_with_minimum_number_of_samples} has only {minimum_number_of_samples_per_label} samples"
                f"but all classes must have at least n_shot + n_query ({self.n_shot + self.n_query}) samples."
            )


HAHA2


In [None]:
import random
from torch.utils.data import Dataset, Subset

def uniform_random_split(TRAIN_SET):
  from torch.utils.data import random_split

  # Assuming you have a dataset called TRAIN_SET
  total_samples = len(TRAIN_SET)

  # Create a list of classes and their respective counts
  class_counts = {c: 0 for _, c in TRAIN_SET}

  # Iterate over the TRAIN_SET and count the samples for each class
  for _, c in TRAIN_SET:
      class_counts[c] += 1

  # Determine the minimum number of samples per class in each set
  min_samples_per_class = 15

  # Calculate the number of samples needed in the validation set for each class
  valid_counts = {c: max(min_samples_per_class - count, 0) for c, count in class_counts.items()}

  # Split the TRAIN_SET into train_set and valid_set
  train_samples = []
  valid_samples = []
  for sample in TRAIN_SET:
      _, c = sample

      # Check if the sample is needed in the validation set for its class
      if valid_counts[c] > 0:
          valid_samples.append(sample)
          valid_counts[c] -= 1
      else:
          train_samples.append(sample)

  # Split the train samples and create train_set and valid_set
  tr_size = int(0.8 * total_samples)
  tst_size = total_samples - tr_size
  train_set, valid_set = random_split(train_samples, [tr_size, tst_size])

  # Concatenate the valid_set with the remaining validation samples
  valid_set = valid_set + valid_samples

  return train_set, valid_set

In [29]:
flag = 0
found = 0

In [31]:
for target_val in range(50, 100, 5):
  import os
  import shutil

  cntt =0
  summ = 0
  folder_A_path = '/content/data/gtsrb/GTSRB/Training'

  # Iterate over subdirectories in folder_A
  for root, dirs, files in os.walk(folder_A_path):
      for dir in dirs:
          folder_B_path = os.path.join(root, dir)
          ppm_file_count = sum(1 for file in os.listdir(folder_B_path) if file.endswith('.ppm'))
          print(f"Folder {folder_B_path} contains {ppm_file_count} .ppm files.")

          summ += ppm_file_count
          print("VAL: ", target_val)
          if ppm_file_count < target_val:
            cntt += 1
            shutil.rmtree(folder_B_path)

  trains_set = GTSRB(
      root="./data",
      # background=True,
      split="train",
      transform=transforms.Compose(
          [
              transforms.Grayscale(num_output_channels=3),
              transforms.RandomResizedCrop(image_size),
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
          ]
      ),
      download=False,
  )

  train_set_size = int(len(trains_set) * 0.8)
  val_set_size = len(trains_set) - train_set_size

  for vall in range(2):
    train_set, valid_set = torch.utils.data.random_split(trains_set, [train_set_size, val_set_size])
    train_set.get_labels = lambda: [instance[1] for instance in train_set]
    valid_set.get_labels = lambda: [instance[1] for instance in valid_set]

    train_dct: dict[int, list[int]] = {}
    for item, label in enumerate(train_set.get_labels()):
      if label in train_dct:
          train_dct[label].append(item)
      else:
          train_dct[label] = [item]
          
    number_of_samples_per_label = [
                len(items_for_label) for items_for_label in train_dct.values()
            ]
    minimum_number_of_samples_per_label = min(number_of_samples_per_label)
    print("MIN: ", minimum_number_of_samples_per_label)
    label_with_minimum_number_of_samples = number_of_samples_per_label.index(
        minimum_number_of_samples_per_label
    )

    print("LABEL: ", label_with_minimum_number_of_samples)
    if n_shot + n_query < minimum_number_of_samples_per_label:
      flag = 1

    valid_dct: dict[int, list[int]] = {}
    for item, label in enumerate(valid_set.get_labels()):
      if label in valid_dct:
          valid_dct[label].append(item)
      else:
          valid_dct[label] = [item]
          
    number_of_samples_per_label = [
                len(items_for_label) for items_for_label in valid_dct.values()
            ]
    minimum_number_of_samples_per_label = min(number_of_samples_per_label)
    print("MIN: ", minimum_number_of_samples_per_label)
    label_with_minimum_number_of_samples = number_of_samples_per_label.index(
        minimum_number_of_samples_per_label
    )

    print("LABEL: ", label_with_minimum_number_of_samples)
    if n_shot + n_query > minimum_number_of_samples_per_label:
      flag = 0
    
    print("Step 1: ", flag, found)
    if flag == 1:
      found = 1
      break
  print("Step 2: ", flag, found)
  if found == 1: break
  found = 0

Folder /content/data/gtsrb/GTSRB/Training/00013 contains 90 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00037 contains 98 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00045 contains 74 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00053 contains 199 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00056 contains 95 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00022 contains 375 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00041 contains 148 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00028 contains 125 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00040 contains 242 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00001 contains 110 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00017 contains 79 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00019 contains 231 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00038 contains 285 .ppm files.
Folder /content/data/gtsrb/GTSRB/Training/00018 contains

In [23]:
target_val

50

In [None]:
train_set.get_labels = lambda: [instance[1] for instance in train_set]

valid_set.get_labels = lambda: [instance[1] for instance in valid_set]

# Those are special batch samplers that sample few-shot classification tasks with a pre-defined shape
train_sampler = TaskSamplerr(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
print("Done Train")
val_sampler = TaskSamplerr(
    valid_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

In [33]:
from torchvision.datasets import GTSRB
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader


n_tasks_per_epoch = 500
n_validation_tasks = 100
image_size = 28

# # Instantiate the datasets
# train_set = GTSRB(
#     root="./data",
#     # background=True,
#     split="train",
#     transform=transforms.Compose(
#         [
#             transforms.Grayscale(num_output_channels=3),
#             transforms.RandomResizedCrop(image_size),
#             transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#         ]
#     ),
#     download=False,
# )

# train_set_size = int(len(train_set) * 0.8)
# val_set_size = len(train_set) - train_set_size
# train_set, valid_set = torch.utils.data.random_split(train_set, [train_set_size, val_set_size])

# train_set, valid_set = uniform_random_split(train_set)

# train_set.get_labels = lambda: [instance[1] for instance in train_set]

# valid_set.get_labels = lambda: [instance[1] for instance in valid_set]

# # Those are special batch samplers that sample few-shot classification tasks with a pre-defined shape
train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
print("Done Train")
val_sampler = TaskSampler(
    valid_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

# Finally, the DataLoader. We customize the collate_fn so that batches are delivered
# in the shape: (support_images, support_labels, query_images, query_labels, class_ids)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

val_loader = DataLoader(
    valid_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

Done Train




In [35]:
len(valid_set)

643

In [40]:
l_lst = []
for i in range(len(valid_set)):
  l_lst.append(valid_set[i][1])

In [41]:
import numpy as np
print(len(np.unique(l_lst)))

17


And then we define the network. Here I chose Prototypical Networks and the built-in ResNet18 from PyTorch because it's easy.

In [None]:
#Only used this encoder for RelationNetworks few shot algorithm

class CNNEncoder(nn.Module):
    """docstring for ClassName"""
    def __init__(self):
        super(CNNEncoder, self).__init__()
        self.layer1 = nn.Sequential(
                        nn.Conv2d(3,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=0),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))
        self.layer3 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=1),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU())
        self.layer4 = nn.Sequential(
                        nn.Conv2d(64,64,kernel_size=3,padding=1),
                        nn.BatchNorm2d(64, momentum=1, affine=True),
                        nn.ReLU())

    def forward(self,x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        #out = out.view(out.size(0),-1)
        return out # 64

In [48]:
from easyfsl.methods import RelationNetworks, FewShotClassifier
from easyfsl.modules import resnet12


convolutional_network = CNNEncoder()
few_shot_classifier = RelationNetworks(convolutional_network, feature_dimension = 64).to(DEVICE)

Now let's define our training helpers ! I chose to use Stochastic Gradient Descent on 200 epochs with a scheduler that divides the learning rate by 10 after 120 and 160 epochs. The strategy is derived from [this repo](https://github.com/fiveai/on-episodes-fsl).

We're also gonna use a TensorBoard because it's always good to see what your training curves look like.

In [49]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 200
scheduler_milestones = [120, 160]
scheduler_gamma = 0.1
learning_rate = 1e-2
tb_logs_dir = Path(".")

train_optimizer = SGD(
    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

And now let's get to it! Here we define the function that performs a training epoch.

We use tqdm to monitor the training in real time in our logs.

In [50]:
def training_epoch(
    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

And we have everything we need! To perform validations we'll just use the built-in `evaluate` function from `easyfsl.methods.utils`.

This is now the time to **start training**.

I added something to log the state of the model that gave the best performance on the validation set.

In [52]:
from easyfsl.utils import evaluate


best_state = few_shot_classifier.state_dict()
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = few_shot_classifier.state_dict()
        print("Ding ding ding! We found a new best model!")

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()

Epoch 0


Training: 100%|██████████| 500/500 [00:34<00:00, 14.56it/s, loss=1.22]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.64it/s, accuracy=0.731]

Ding ding ding! We found a new best model!
Epoch 1



Training: 100%|██████████| 500/500 [00:32<00:00, 15.19it/s, loss=1.13]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.95it/s, accuracy=0.77]

Ding ding ding! We found a new best model!
Epoch 2



Training: 100%|██████████| 500/500 [00:32<00:00, 15.44it/s, loss=1.1]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.69it/s, accuracy=0.805]

Ding ding ding! We found a new best model!
Epoch 3



Training: 100%|██████████| 500/500 [00:32<00:00, 15.48it/s, loss=1.08]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.99it/s, accuracy=0.774]

Epoch 4



Training: 100%|██████████| 500/500 [00:35<00:00, 14.14it/s, loss=1.06]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.95it/s, accuracy=0.822]

Ding ding ding! We found a new best model!
Epoch 5



Training: 100%|██████████| 500/500 [00:32<00:00, 15.33it/s, loss=1.05]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.93it/s, accuracy=0.853]


Ding ding ding! We found a new best model!
Epoch 6


Training: 100%|██████████| 500/500 [00:32<00:00, 15.22it/s, loss=1.04]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.29it/s, accuracy=0.847]

Epoch 7



Training: 100%|██████████| 500/500 [00:33<00:00, 14.98it/s, loss=1.03]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.01it/s, accuracy=0.871]

Ding ding ding! We found a new best model!
Epoch 8



Training: 100%|██████████| 500/500 [00:32<00:00, 15.32it/s, loss=1.03]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.12it/s, accuracy=0.869]

Epoch 9



Training: 100%|██████████| 500/500 [00:33<00:00, 15.04it/s, loss=1.02]
Validation: 100%|██████████| 100/100 [00:05<00:00, 18.41it/s, accuracy=0.884]


Ding ding ding! We found a new best model!
Epoch 10


Training: 100%|██████████| 500/500 [00:34<00:00, 14.51it/s, loss=1.02]
Validation: 100%|██████████| 100/100 [00:05<00:00, 18.04it/s, accuracy=0.886]

Ding ding ding! We found a new best model!
Epoch 11



Training: 100%|██████████| 500/500 [00:33<00:00, 14.78it/s, loss=1.01]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.36it/s, accuracy=0.865]

Epoch 12



Training: 100%|██████████| 500/500 [00:33<00:00, 15.02it/s, loss=1.01]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.63it/s, accuracy=0.885]

Epoch 13



Training: 100%|██████████| 500/500 [00:32<00:00, 15.17it/s, loss=1]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.15it/s, accuracy=0.882]


Epoch 14


Training: 100%|██████████| 500/500 [00:34<00:00, 14.36it/s, loss=1]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.72it/s, accuracy=0.892]


Ding ding ding! We found a new best model!
Epoch 15


Training: 100%|██████████| 500/500 [00:35<00:00, 14.14it/s, loss=1]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.99it/s, accuracy=0.895]


Ding ding ding! We found a new best model!
Epoch 16


Training: 100%|██████████| 500/500 [00:35<00:00, 14.21it/s, loss=0.998]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.85it/s, accuracy=0.878]

Epoch 17



Training: 100%|██████████| 500/500 [00:33<00:00, 14.72it/s, loss=0.994]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.61it/s, accuracy=0.903]

Ding ding ding! We found a new best model!
Epoch 18



Training: 100%|██████████| 500/500 [00:35<00:00, 14.08it/s, loss=0.995]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.15it/s, accuracy=0.891]

Epoch 19



Training: 100%|██████████| 500/500 [00:36<00:00, 13.55it/s, loss=0.987]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.51it/s, accuracy=0.903]

Ding ding ding! We found a new best model!
Epoch 20



Training: 100%|██████████| 500/500 [00:35<00:00, 13.93it/s, loss=0.991]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.91it/s, accuracy=0.901]


Epoch 21


Training: 100%|██████████| 500/500 [00:36<00:00, 13.66it/s, loss=0.983]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.96it/s, accuracy=0.901]

Epoch 22



Training: 100%|██████████| 500/500 [00:34<00:00, 14.43it/s, loss=0.985]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.96it/s, accuracy=0.908]


Ding ding ding! We found a new best model!
Epoch 23


Training: 100%|██████████| 500/500 [00:34<00:00, 14.29it/s, loss=0.987]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.39it/s, accuracy=0.917]


Ding ding ding! We found a new best model!
Epoch 24


Training: 100%|██████████| 500/500 [00:34<00:00, 14.39it/s, loss=0.987]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.85it/s, accuracy=0.917]

Ding ding ding! We found a new best model!
Epoch 25



Training: 100%|██████████| 500/500 [00:33<00:00, 14.97it/s, loss=0.982]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.58it/s, accuracy=0.905]

Epoch 26



Training: 100%|██████████| 500/500 [00:33<00:00, 14.74it/s, loss=0.98]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.54it/s, accuracy=0.89]

Epoch 27



Training: 100%|██████████| 500/500 [00:34<00:00, 14.64it/s, loss=0.982]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.67it/s, accuracy=0.9]


Epoch 28


Training: 100%|██████████| 500/500 [00:34<00:00, 14.67it/s, loss=0.979]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.33it/s, accuracy=0.921]


Ding ding ding! We found a new best model!
Epoch 29


Training: 100%|██████████| 500/500 [00:35<00:00, 14.23it/s, loss=0.975]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.30it/s, accuracy=0.921]

Epoch 30



Training: 100%|██████████| 500/500 [00:34<00:00, 14.31it/s, loss=0.977]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.85it/s, accuracy=0.922]

Ding ding ding! We found a new best model!
Epoch 31



Training: 100%|██████████| 500/500 [00:38<00:00, 12.85it/s, loss=0.977]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.18it/s, accuracy=0.931]

Ding ding ding! We found a new best model!
Epoch 32



Training: 100%|██████████| 500/500 [00:35<00:00, 14.24it/s, loss=0.973]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.49it/s, accuracy=0.906]

Epoch 33



Training: 100%|██████████| 500/500 [00:35<00:00, 14.28it/s, loss=0.972]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.82it/s, accuracy=0.921]

Epoch 34



Training: 100%|██████████| 500/500 [00:35<00:00, 14.28it/s, loss=0.971]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.51it/s, accuracy=0.911]


Epoch 35


Training: 100%|██████████| 500/500 [00:34<00:00, 14.49it/s, loss=0.976]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.88it/s, accuracy=0.891]


Epoch 36


Training: 100%|██████████| 500/500 [00:34<00:00, 14.29it/s, loss=0.972]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.87it/s, accuracy=0.913]


Epoch 37


Training: 100%|██████████| 500/500 [00:34<00:00, 14.67it/s, loss=0.972]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.91it/s, accuracy=0.929]


Epoch 38


Training: 100%|██████████| 500/500 [00:33<00:00, 15.02it/s, loss=0.97]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.76it/s, accuracy=0.907]


Epoch 39


Training: 100%|██████████| 500/500 [00:33<00:00, 14.98it/s, loss=0.968]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.49it/s, accuracy=0.922]

Epoch 40



Training: 100%|██████████| 500/500 [00:33<00:00, 14.95it/s, loss=0.97]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.37it/s, accuracy=0.883]


Epoch 41


Training: 100%|██████████| 500/500 [00:32<00:00, 15.20it/s, loss=0.971]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.43it/s, accuracy=0.92]

Epoch 42



Training: 100%|██████████| 500/500 [00:32<00:00, 15.22it/s, loss=0.965]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.73it/s, accuracy=0.914]

Epoch 43



Training: 100%|██████████| 500/500 [00:32<00:00, 15.35it/s, loss=0.967]
Validation: 100%|██████████| 100/100 [00:07<00:00, 14.24it/s, accuracy=0.925]


Epoch 44


Training: 100%|██████████| 500/500 [00:32<00:00, 15.29it/s, loss=0.967]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.97it/s, accuracy=0.923]


Epoch 45


Training: 100%|██████████| 500/500 [00:32<00:00, 15.30it/s, loss=0.966]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.97it/s, accuracy=0.919]

Epoch 46



Training: 100%|██████████| 500/500 [00:33<00:00, 15.08it/s, loss=0.968]
Validation: 100%|██████████| 100/100 [00:07<00:00, 14.21it/s, accuracy=0.928]

Epoch 47



Training: 100%|██████████| 500/500 [00:32<00:00, 15.34it/s, loss=0.965]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.25it/s, accuracy=0.928]

Epoch 48



Training: 100%|██████████| 500/500 [00:32<00:00, 15.20it/s, loss=0.963]
Validation: 100%|██████████| 100/100 [00:05<00:00, 18.28it/s, accuracy=0.931]

Ding ding ding! We found a new best model!
Epoch 49



Training: 100%|██████████| 500/500 [00:33<00:00, 14.93it/s, loss=0.963]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.92it/s, accuracy=0.915]

Epoch 50



Training: 100%|██████████| 500/500 [00:34<00:00, 14.63it/s, loss=0.966]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.69it/s, accuracy=0.924]


Epoch 51


Training: 100%|██████████| 500/500 [00:33<00:00, 14.72it/s, loss=0.966]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.50it/s, accuracy=0.941]


Ding ding ding! We found a new best model!
Epoch 52


Training: 100%|██████████| 500/500 [00:33<00:00, 14.82it/s, loss=0.963]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.76it/s, accuracy=0.925]

Epoch 53



Training: 100%|██████████| 500/500 [00:32<00:00, 15.33it/s, loss=0.964]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.14it/s, accuracy=0.928]

Epoch 54



Training: 100%|██████████| 500/500 [00:32<00:00, 15.29it/s, loss=0.966]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.35it/s, accuracy=0.936]

Epoch 55



Training: 100%|██████████| 500/500 [00:32<00:00, 15.32it/s, loss=0.963]
Validation: 100%|██████████| 100/100 [00:07<00:00, 14.06it/s, accuracy=0.939]


Epoch 56


Training: 100%|██████████| 500/500 [00:32<00:00, 15.35it/s, loss=0.966]
Validation: 100%|██████████| 100/100 [00:07<00:00, 14.04it/s, accuracy=0.889]

Epoch 57



Training: 100%|██████████| 500/500 [00:32<00:00, 15.42it/s, loss=0.964]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.76it/s, accuracy=0.926]


Epoch 58


Training: 100%|██████████| 500/500 [00:32<00:00, 15.23it/s, loss=0.964]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.52it/s, accuracy=0.938]


Epoch 59


Training: 100%|██████████| 500/500 [00:32<00:00, 15.32it/s, loss=0.961]
Validation: 100%|██████████| 100/100 [00:05<00:00, 18.08it/s, accuracy=0.939]

Epoch 60



Training: 100%|██████████| 500/500 [00:32<00:00, 15.22it/s, loss=0.961]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.14it/s, accuracy=0.93]

Epoch 61



Training: 100%|██████████| 500/500 [00:34<00:00, 14.63it/s, loss=0.963]
Validation: 100%|██████████| 100/100 [00:05<00:00, 18.07it/s, accuracy=0.905]


Epoch 62


Training: 100%|██████████| 500/500 [00:35<00:00, 14.10it/s, loss=0.964]
Validation: 100%|██████████| 100/100 [00:05<00:00, 17.47it/s, accuracy=0.908]

Epoch 63



Training: 100%|██████████| 500/500 [00:34<00:00, 14.56it/s, loss=0.961]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.63it/s, accuracy=0.934]


Epoch 64


Training: 100%|██████████| 500/500 [00:34<00:00, 14.38it/s, loss=0.959]
Validation: 100%|██████████| 100/100 [00:05<00:00, 18.01it/s, accuracy=0.916]

Epoch 65



Training: 100%|██████████| 500/500 [00:32<00:00, 15.19it/s, loss=0.961]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.65it/s, accuracy=0.915]


Epoch 66


Training: 100%|██████████| 500/500 [00:32<00:00, 15.33it/s, loss=0.96]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.87it/s, accuracy=0.94]

Epoch 67



Training: 100%|██████████| 500/500 [00:32<00:00, 15.18it/s, loss=0.96]
Validation: 100%|██████████| 100/100 [00:07<00:00, 14.22it/s, accuracy=0.911]

Epoch 68



Training: 100%|██████████| 500/500 [00:33<00:00, 15.02it/s, loss=0.959]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.35it/s, accuracy=0.907]


Epoch 69


Training:  22%|██▏       | 109/500 [00:08<00:49,  7.93it/s, loss=0.965]Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360><function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>
Traceback (most recent call last):

Exception ignored in: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>Exception ignored in:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>    self._shutdown_workers()Traceback (most recent call last):

Epoch 70


Training:  17%|█▋        | 84/500 [00:07<00:32, 12.71it/s, loss=0.959]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__

    Traceback (most recent call last):
self._shutdown_workers()  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__

      File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
self._shutdown_workers()    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive

    assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
Exception igno

Epoch 71



Training:   7%|▋         | 35/500 [00:03<00:29, 15.72it/s, loss=0.958]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
Exception ignored in:   File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>    assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    

self._shutdown_workers()AssertionError
Traceback (most recent

Epoch 72



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
Training:   0%|          | 0/500 [00:00<?, ?it/s]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f31502e1360>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_aliv

Epoch 73



Training: 100%|██████████| 500/500 [00:36<00:00, 13.71it/s, loss=0.959]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.22it/s, accuracy=0.93]


Epoch 74


Training: 100%|██████████| 500/500 [00:36<00:00, 13.57it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.62it/s, accuracy=0.926]


Epoch 75


Training: 100%|██████████| 500/500 [00:36<00:00, 13.57it/s, loss=0.96]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.25it/s, accuracy=0.932]


Epoch 76


Training: 100%|██████████| 500/500 [00:37<00:00, 13.49it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.64it/s, accuracy=0.936]

Epoch 77



Training: 100%|██████████| 500/500 [00:35<00:00, 14.09it/s, loss=0.96]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.94it/s, accuracy=0.93]

Epoch 78



Training: 100%|██████████| 500/500 [00:35<00:00, 14.17it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.52it/s, accuracy=0.925]

Epoch 79



Training: 100%|██████████| 500/500 [00:33<00:00, 15.02it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.18it/s, accuracy=0.906]


Epoch 80


Training: 100%|██████████| 500/500 [00:34<00:00, 14.36it/s, loss=0.959]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.16it/s, accuracy=0.912]

Epoch 81



Training: 100%|██████████| 500/500 [00:33<00:00, 15.14it/s, loss=0.954]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.52it/s, accuracy=0.93]

Epoch 82



Training: 100%|██████████| 500/500 [00:34<00:00, 14.55it/s, loss=0.96]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.61it/s, accuracy=0.915]

Epoch 83



Training: 100%|██████████| 500/500 [00:34<00:00, 14.49it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.50it/s, accuracy=0.917]


Epoch 84


Training: 100%|██████████| 500/500 [00:35<00:00, 13.95it/s, loss=0.957]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.08it/s, accuracy=0.869]


Epoch 85


Training: 100%|██████████| 500/500 [00:35<00:00, 14.26it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.82it/s, accuracy=0.929]

Epoch 86



Training: 100%|██████████| 500/500 [00:35<00:00, 14.00it/s, loss=0.957]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.01it/s, accuracy=0.921]

Epoch 87



Training: 100%|██████████| 500/500 [00:35<00:00, 14.13it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.83it/s, accuracy=0.902]


Epoch 88


Training: 100%|██████████| 500/500 [00:35<00:00, 13.90it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.54it/s, accuracy=0.92]

Epoch 89



Training: 100%|██████████| 500/500 [00:35<00:00, 13.96it/s, loss=0.959]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.29it/s, accuracy=0.915]


Epoch 90


Training: 100%|██████████| 500/500 [00:35<00:00, 13.92it/s, loss=0.957]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.56it/s, accuracy=0.921]


Epoch 91


Training: 100%|██████████| 500/500 [00:37<00:00, 13.47it/s, loss=0.955]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.74it/s, accuracy=0.904]


Epoch 92


Training: 100%|██████████| 500/500 [00:35<00:00, 13.99it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.71it/s, accuracy=0.909]

Epoch 93



Training: 100%|██████████| 500/500 [00:35<00:00, 13.91it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.39it/s, accuracy=0.924]

Epoch 94



Training: 100%|██████████| 500/500 [00:35<00:00, 13.99it/s, loss=0.955]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.68it/s, accuracy=0.916]

Epoch 95



Training: 100%|██████████| 500/500 [00:34<00:00, 14.33it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.97it/s, accuracy=0.919]

Epoch 96



Training: 100%|██████████| 500/500 [00:34<00:00, 14.49it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.59it/s, accuracy=0.892]

Epoch 97



Training: 100%|██████████| 500/500 [00:35<00:00, 14.24it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.44it/s, accuracy=0.935]

Epoch 98



Training: 100%|██████████| 500/500 [00:34<00:00, 14.51it/s, loss=0.959]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.94it/s, accuracy=0.922]

Epoch 99



Training: 100%|██████████| 500/500 [00:34<00:00, 14.32it/s, loss=0.955]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.26it/s, accuracy=0.857]

Epoch 100



Training: 100%|██████████| 500/500 [00:34<00:00, 14.44it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.01it/s, accuracy=0.926]

Epoch 101



Training: 100%|██████████| 500/500 [00:34<00:00, 14.50it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.30it/s, accuracy=0.91]

Epoch 102



Training: 100%|██████████| 500/500 [00:34<00:00, 14.33it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.75it/s, accuracy=0.916]

Epoch 103



Training: 100%|██████████| 500/500 [00:35<00:00, 14.17it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.19it/s, accuracy=0.911]

Epoch 104



Training: 100%|██████████| 500/500 [00:34<00:00, 14.29it/s, loss=0.953]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.61it/s, accuracy=0.931]

Epoch 105



Training: 100%|██████████| 500/500 [00:35<00:00, 14.15it/s, loss=0.954]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.29it/s, accuracy=0.924]

Epoch 106



Training: 100%|██████████| 500/500 [00:35<00:00, 13.89it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.64it/s, accuracy=0.92]


Epoch 107


Training: 100%|██████████| 500/500 [00:37<00:00, 13.37it/s, loss=0.959]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.88it/s, accuracy=0.901]

Epoch 108



Training: 100%|██████████| 500/500 [00:38<00:00, 13.13it/s, loss=0.953]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.15it/s, accuracy=0.914]

Epoch 109



Training: 100%|██████████| 500/500 [00:36<00:00, 13.63it/s, loss=0.955]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.52it/s, accuracy=0.918]

Epoch 110



Training: 100%|██████████| 500/500 [00:37<00:00, 13.38it/s, loss=0.953]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.64it/s, accuracy=0.935]

Epoch 111



Training: 100%|██████████| 500/500 [00:36<00:00, 13.56it/s, loss=0.957]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.45it/s, accuracy=0.916]


Epoch 112


Training: 100%|██████████| 500/500 [00:36<00:00, 13.81it/s, loss=0.952]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.74it/s, accuracy=0.889]


Epoch 113


Training: 100%|██████████| 500/500 [00:35<00:00, 13.91it/s, loss=0.954]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.77it/s, accuracy=0.915]


Epoch 114


Training: 100%|██████████| 500/500 [00:37<00:00, 13.18it/s, loss=0.955]
Validation: 100%|██████████| 100/100 [00:07<00:00, 14.25it/s, accuracy=0.872]

Epoch 115



Training: 100%|██████████| 500/500 [00:37<00:00, 13.40it/s, loss=0.956]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.44it/s, accuracy=0.898]

Epoch 116



Training: 100%|██████████| 500/500 [00:37<00:00, 13.43it/s, loss=0.955]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.42it/s, accuracy=0.913]


Epoch 117


Training: 100%|██████████| 500/500 [00:37<00:00, 13.39it/s, loss=0.953]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.03it/s, accuracy=0.874]

Epoch 118



Training: 100%|██████████| 500/500 [00:37<00:00, 13.32it/s, loss=0.954]
Validation: 100%|██████████| 100/100 [00:07<00:00, 14.28it/s, accuracy=0.908]

Epoch 119



Training: 100%|██████████| 500/500 [00:36<00:00, 13.71it/s, loss=0.958]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.88it/s, accuracy=0.914]

Epoch 120



Training: 100%|██████████| 500/500 [00:34<00:00, 14.29it/s, loss=0.943]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.62it/s, accuracy=0.928]

Epoch 121



Training: 100%|██████████| 500/500 [00:35<00:00, 14.14it/s, loss=0.938]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.31it/s, accuracy=0.933]

Epoch 122



Training: 100%|██████████| 500/500 [00:38<00:00, 12.99it/s, loss=0.937]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.98it/s, accuracy=0.931]


Epoch 123


Training: 100%|██████████| 500/500 [00:38<00:00, 12.94it/s, loss=0.936]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.37it/s, accuracy=0.934]

Epoch 124



Training: 100%|██████████| 500/500 [00:37<00:00, 13.37it/s, loss=0.936]
Validation: 100%|██████████| 100/100 [00:08<00:00, 11.88it/s, accuracy=0.942]


Ding ding ding! We found a new best model!
Epoch 125


Training: 100%|██████████| 500/500 [00:38<00:00, 12.94it/s, loss=0.936]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.37it/s, accuracy=0.937]

Epoch 126



Training: 100%|██████████| 500/500 [00:38<00:00, 12.91it/s, loss=0.935]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.45it/s, accuracy=0.948]

Ding ding ding! We found a new best model!
Epoch 127



Training: 100%|██████████| 500/500 [00:36<00:00, 13.52it/s, loss=0.933]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.38it/s, accuracy=0.945]


Epoch 128


Training: 100%|██████████| 500/500 [00:37<00:00, 13.45it/s, loss=0.934]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.97it/s, accuracy=0.953]

Ding ding ding! We found a new best model!
Epoch 129



Training: 100%|██████████| 500/500 [00:38<00:00, 13.13it/s, loss=0.932]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.23it/s, accuracy=0.953]


Ding ding ding! We found a new best model!
Epoch 130


Training: 100%|██████████| 500/500 [00:36<00:00, 13.54it/s, loss=0.933]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.71it/s, accuracy=0.923]

Epoch 131



Training: 100%|██████████| 500/500 [00:35<00:00, 14.14it/s, loss=0.932]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.89it/s, accuracy=0.95]

Epoch 132



Training: 100%|██████████| 500/500 [00:34<00:00, 14.43it/s, loss=0.933]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.55it/s, accuracy=0.946]

Epoch 133



Training: 100%|██████████| 500/500 [00:34<00:00, 14.43it/s, loss=0.932]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.43it/s, accuracy=0.947]

Epoch 134



Training: 100%|██████████| 500/500 [00:34<00:00, 14.40it/s, loss=0.931]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.30it/s, accuracy=0.933]

Epoch 135



Training: 100%|██████████| 500/500 [00:34<00:00, 14.35it/s, loss=0.933]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.14it/s, accuracy=0.956]


Ding ding ding! We found a new best model!
Epoch 136


Training: 100%|██████████| 500/500 [00:34<00:00, 14.61it/s, loss=0.932]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.47it/s, accuracy=0.952]


Epoch 137


Training: 100%|██████████| 500/500 [00:34<00:00, 14.48it/s, loss=0.932]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.13it/s, accuracy=0.928]


Epoch 138


Training: 100%|██████████| 500/500 [00:34<00:00, 14.53it/s, loss=0.931]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.46it/s, accuracy=0.948]

Epoch 139



Training: 100%|██████████| 500/500 [00:34<00:00, 14.43it/s, loss=0.931]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.18it/s, accuracy=0.947]


Epoch 140


Training: 100%|██████████| 500/500 [00:35<00:00, 14.19it/s, loss=0.93]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.37it/s, accuracy=0.922]


Epoch 141


Training: 100%|██████████| 500/500 [00:38<00:00, 12.90it/s, loss=0.93]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.21it/s, accuracy=0.952]


Epoch 142


Training: 100%|██████████| 500/500 [00:38<00:00, 12.88it/s, loss=0.931]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.97it/s, accuracy=0.945]

Epoch 143



Training: 100%|██████████| 500/500 [00:36<00:00, 13.57it/s, loss=0.929]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.20it/s, accuracy=0.946]


Epoch 144


Training: 100%|██████████| 500/500 [00:37<00:00, 13.17it/s, loss=0.93]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.17it/s, accuracy=0.942]


Epoch 145


Training: 100%|██████████| 500/500 [00:38<00:00, 12.90it/s, loss=0.93]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.05it/s, accuracy=0.947]

Epoch 146



Training: 100%|██████████| 500/500 [00:38<00:00, 13.06it/s, loss=0.93]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.45it/s, accuracy=0.957]

Ding ding ding! We found a new best model!
Epoch 147



Training: 100%|██████████| 500/500 [00:36<00:00, 13.56it/s, loss=0.93]
Validation: 100%|██████████| 100/100 [00:08<00:00, 11.90it/s, accuracy=0.952]

Epoch 148



Training: 100%|██████████| 500/500 [00:37<00:00, 13.33it/s, loss=0.931]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.27it/s, accuracy=0.937]

Epoch 149



Training: 100%|██████████| 500/500 [00:36<00:00, 13.59it/s, loss=0.93]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.08it/s, accuracy=0.939]


Epoch 150


Training: 100%|██████████| 500/500 [00:36<00:00, 13.56it/s, loss=0.929]
Validation: 100%|██████████| 100/100 [00:05<00:00, 16.94it/s, accuracy=0.951]

Epoch 151



Training: 100%|██████████| 500/500 [00:37<00:00, 13.45it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.29it/s, accuracy=0.916]

Epoch 152



Training: 100%|██████████| 500/500 [00:37<00:00, 13.47it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.10it/s, accuracy=0.95]


Epoch 153


Training: 100%|██████████| 500/500 [00:35<00:00, 13.89it/s, loss=0.929]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.47it/s, accuracy=0.953]


Epoch 154


Training: 100%|██████████| 500/500 [00:37<00:00, 13.39it/s, loss=0.929]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.07it/s, accuracy=0.959]

Ding ding ding! We found a new best model!
Epoch 155



Training: 100%|██████████| 500/500 [00:36<00:00, 13.57it/s, loss=0.929]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.16it/s, accuracy=0.946]

Epoch 156



Training: 100%|██████████| 500/500 [00:36<00:00, 13.66it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.03it/s, accuracy=0.937]

Epoch 157



Training: 100%|██████████| 500/500 [00:37<00:00, 13.41it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.79it/s, accuracy=0.951]

Epoch 158



Training: 100%|██████████| 500/500 [00:37<00:00, 13.35it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.24it/s, accuracy=0.934]

Epoch 159



Training: 100%|██████████| 500/500 [00:37<00:00, 13.17it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.85it/s, accuracy=0.947]

Epoch 160



Training: 100%|██████████| 500/500 [00:38<00:00, 12.85it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.16it/s, accuracy=0.95]


Epoch 161


Training: 100%|██████████| 500/500 [00:37<00:00, 13.39it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.05it/s, accuracy=0.936]

Epoch 162



Training: 100%|██████████| 500/500 [00:35<00:00, 14.04it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.10it/s, accuracy=0.945]


Epoch 163


Training: 100%|██████████| 500/500 [00:35<00:00, 14.21it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.99it/s, accuracy=0.935]

Epoch 164



Training: 100%|██████████| 500/500 [00:34<00:00, 14.33it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.60it/s, accuracy=0.951]

Epoch 165



Training: 100%|██████████| 500/500 [00:37<00:00, 13.34it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.04it/s, accuracy=0.954]


Epoch 166


Training: 100%|██████████| 500/500 [00:38<00:00, 12.95it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.05it/s, accuracy=0.956]

Epoch 167



Training: 100%|██████████| 500/500 [00:37<00:00, 13.18it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.22it/s, accuracy=0.937]

Epoch 168



Training: 100%|██████████| 500/500 [00:37<00:00, 13.50it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.69it/s, accuracy=0.954]


Epoch 169


Training: 100%|██████████| 500/500 [00:37<00:00, 13.49it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.95it/s, accuracy=0.953]


Epoch 170


Training: 100%|██████████| 500/500 [00:37<00:00, 13.36it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.74it/s, accuracy=0.956]


Epoch 171


Training: 100%|██████████| 500/500 [00:37<00:00, 13.50it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.64it/s, accuracy=0.953]


Epoch 172


Training: 100%|██████████| 500/500 [00:36<00:00, 13.64it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.52it/s, accuracy=0.935]

Epoch 173



Training: 100%|██████████| 500/500 [00:36<00:00, 13.67it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.88it/s, accuracy=0.951]


Epoch 174


Training: 100%|██████████| 500/500 [00:36<00:00, 13.54it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.06it/s, accuracy=0.941]

Epoch 175



Training: 100%|██████████| 500/500 [00:38<00:00, 12.92it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.25it/s, accuracy=0.952]


Epoch 176


Training: 100%|██████████| 500/500 [00:38<00:00, 12.83it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.09it/s, accuracy=0.952]

Epoch 177



Training: 100%|██████████| 500/500 [00:37<00:00, 13.38it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.15it/s, accuracy=0.96]


Ding ding ding! We found a new best model!
Epoch 178


Training: 100%|██████████| 500/500 [00:39<00:00, 12.75it/s, loss=0.925]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.43it/s, accuracy=0.948]


Epoch 179


Training: 100%|██████████| 500/500 [00:38<00:00, 12.83it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.85it/s, accuracy=0.946]

Epoch 180



Training: 100%|██████████| 500/500 [00:37<00:00, 13.51it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.65it/s, accuracy=0.954]

Epoch 181



Training: 100%|██████████| 500/500 [00:35<00:00, 13.93it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.46it/s, accuracy=0.94]


Epoch 182


Training: 100%|██████████| 500/500 [00:35<00:00, 13.90it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.26it/s, accuracy=0.956]


Epoch 183


Training: 100%|██████████| 500/500 [00:36<00:00, 13.57it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:06<00:00, 16.37it/s, accuracy=0.941]


Epoch 184


Training: 100%|██████████| 500/500 [00:37<00:00, 13.43it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.77it/s, accuracy=0.946]


Epoch 185


Training: 100%|██████████| 500/500 [00:38<00:00, 12.98it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.01it/s, accuracy=0.94]

Epoch 186



Training: 100%|██████████| 500/500 [00:37<00:00, 13.31it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.38it/s, accuracy=0.946]

Epoch 187



Training: 100%|██████████| 500/500 [00:36<00:00, 13.79it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:07<00:00, 12.50it/s, accuracy=0.943]


Epoch 188


Training: 100%|██████████| 500/500 [00:39<00:00, 12.72it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:06<00:00, 15.28it/s, accuracy=0.953]

Epoch 189



Training: 100%|██████████| 500/500 [00:39<00:00, 12.81it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:07<00:00, 14.10it/s, accuracy=0.943]

Epoch 190



Training: 100%|██████████| 500/500 [00:38<00:00, 13.05it/s, loss=0.928]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.15it/s, accuracy=0.917]

Epoch 191



Training: 100%|██████████| 500/500 [00:39<00:00, 12.53it/s, loss=0.925]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.45it/s, accuracy=0.946]

Epoch 192



Training: 100%|██████████| 500/500 [00:39<00:00, 12.64it/s, loss=0.925]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.68it/s, accuracy=0.94]

Epoch 193



Training: 100%|██████████| 500/500 [00:38<00:00, 13.02it/s, loss=0.924]
Validation: 100%|██████████| 100/100 [00:08<00:00, 11.85it/s, accuracy=0.953]

Epoch 194



Training: 100%|██████████| 500/500 [00:38<00:00, 13.04it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.54it/s, accuracy=0.956]

Epoch 195



Training: 100%|██████████| 500/500 [00:39<00:00, 12.64it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.63it/s, accuracy=0.951]

Epoch 196



Training: 100%|██████████| 500/500 [00:38<00:00, 13.04it/s, loss=0.925]
Validation: 100%|██████████| 100/100 [00:08<00:00, 11.93it/s, accuracy=0.952]


Epoch 197


Training: 100%|██████████| 500/500 [00:39<00:00, 12.60it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:06<00:00, 14.91it/s, accuracy=0.958]

Epoch 198



Training: 100%|██████████| 500/500 [00:39<00:00, 12.51it/s, loss=0.926]
Validation: 100%|██████████| 100/100 [00:07<00:00, 13.06it/s, accuracy=0.934]

Epoch 199



Training: 100%|██████████| 500/500 [00:37<00:00, 13.44it/s, loss=0.927]
Validation: 100%|██████████| 100/100 [00:08<00:00, 12.27it/s, accuracy=0.941]


Yay we successfully performed Episodic Training! Now if you want to you can retrieve the best model's state.

In [53]:
few_shot_classifier.load_state_dict(best_state)

<All keys matched successfully>

## Evaluation

Now that our model is trained, we want to test it.

First step: we fetch the test data.

In [54]:
import os

old_folder_name = '/content/data/gtsrb/GTSRB/Training'
new_folder_name = '/content/data/gtsrb/GTSRB/Trainingggg'

# Rename the folder
os.rename(old_folder_name, new_folder_name)

old_folder_name = '/content/data/gtsrb/GTSRB/Trainingg'
new_folder_name = '/content/data/gtsrb/GTSRB/Training'

# Rename the folder
os.rename(old_folder_name, new_folder_name)

# For n_test_tasks = 1000

In [55]:
n_test_tasks = 1000

test_set = GTSRB(
    root="./data",
    # background=False,
    split="train",
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=False,
)

test_set.get_labels = lambda: [instance[1] for instance in test_set]

test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Second step: we run the few-shot classifier on the test data.

In [56]:
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

100%|██████████| 1000/1000 [00:44<00:00, 22.60it/s, accuracy=0.761]

Average accuracy : 76.09 %





# For n_test_tasks = 10000

In [57]:
n_test_tasks = 10000

test_set = GTSRB(
    root="./data",
    # background=False,
    split="train",
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=False,
)

test_set.get_labels = lambda: [instance[1] for instance in test_set]

test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [58]:
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

100%|██████████| 10000/10000 [07:13<00:00, 23.08it/s, accuracy=0.759]

Average accuracy : 75.93 %





In [59]:
torch.save(few_shot_classifier, 'model.pth')

In [60]:
nahian = torch.load('model.pth')

In [61]:
few_shot_classifier.state_dict()

OrderedDict([('backbone.layer1.0.weight',
              tensor([[[[-6.3456e-02, -3.2317e-02, -1.0411e-01],
                        [ 8.6533e-02, -2.8199e-02, -1.3228e-01],
                        [ 1.9502e-01,  7.3431e-02, -7.6206e-02]],
              
                       [[-6.4705e-02, -2.8337e-02, -1.0290e-01],
                        [ 9.2724e-02, -3.2615e-02, -1.3101e-01],
                        [ 2.0591e-01,  8.1902e-02, -6.7469e-02]],
              
                       [[-6.6839e-02, -2.2988e-02, -9.2954e-02],
                        [ 9.5857e-02, -2.9953e-02, -1.2360e-01],
                        [ 2.0846e-01,  8.1817e-02, -6.5237e-02]]],
              
              
                      [[[-7.3603e-02, -1.1134e-01, -1.1691e-01],
                        [ 3.4300e-02,  1.6304e-02, -3.4418e-02],
                        [ 1.1333e-01,  1.1498e-01,  5.2232e-02]],
              
                       [[-6.5942e-02, -1.0008e-01, -1.2895e-01],
                        [ 3.3578e

In [62]:
few_shot_classifier.load_state_dict(few_shot_classifier.state_dict())

<All keys matched successfully>

In [63]:
train_optimizer.state_dict()

{'state': {0: {'momentum_buffer': tensor([[[[-0.0277, -0.0242, -0.0221],
             [-0.0255, -0.0236, -0.0190],
             [-0.0232, -0.0248, -0.0231]],
   
            [[-0.0277, -0.0242, -0.0221],
             [-0.0255, -0.0237, -0.0190],
             [-0.0231, -0.0248, -0.0231]],
   
            [[-0.0277, -0.0241, -0.0221],
             [-0.0255, -0.0237, -0.0190],
             [-0.0231, -0.0248, -0.0231]]],
   
   
           [[[-0.0530, -0.0573, -0.0631],
             [-0.0533, -0.0593, -0.0662],
             [-0.0587, -0.0630, -0.0691]],
   
            [[-0.0529, -0.0573, -0.0631],
             [-0.0533, -0.0593, -0.0662],
             [-0.0587, -0.0630, -0.0691]],
   
            [[-0.0529, -0.0573, -0.0631],
             [-0.0532, -0.0592, -0.0662],
             [-0.0587, -0.0630, -0.0692]]],
   
   
           [[[-0.0026, -0.0012, -0.0024],
             [-0.0048, -0.0025, -0.0029],
             [-0.0076, -0.0059, -0.0047]],
   
            [[-0.0027, -0.0012, -0.0025],


In [64]:
torch.save({
            'model_state_dict': few_shot_classifier.state_dict(),
            'optimizer_state_dict': train_optimizer.state_dict(),
            }, 'final.pth')

Congrats! You performed Episodic Training using EasyFSL. If you want to compare with a model trained using classical training, look at [this other example notebook](classical_training.ipynb).
