<a href="https://colab.research.google.com/github/NadineML/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/master/working-old-Prototypical_Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Table of Contents
* [Setup](#scrollTo=T0ay-ybnOqDn)
* [Configuration](#scrollTo=WKwMFW2Mnf_a&uniqifier=37)
  * **interactive** [mount drive](#scrollTo=g9vZGh5SgFmQ)
  * **interactive** [decide if you want to set parameters manually or load config.json](#scrollTo=UyCwYDbpR7Jv) 
  * **interactive** [set parameters manually, if desired](#scrollTo=UyCwYDbpR7Jv) 
  
* [Datasets](#scrollTo=yq7s0Hiy4F_s)
* [Execution for manually set parameters](#scrollTo=y7Tf7K6OyW_i)
* [Execution for automatically set parameters](#scrollTo=Ra1IpSVMOJuW)
* [Sources](#scrollTo=HexvGfNtzwfV)



## Setup

### Imports and Installs

In [None]:
!pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-0.2.0-py3-none-any.whl (24 kB)
Collecting loguru>=0.5.0
  Downloading loguru-0.5.3-py3-none-any.whl (57 kB)
[K     |████████████████████████████████| 57 kB 3.4 MB/s 
Installing collected packages: loguru, easyfsl
Successfully installed easyfsl-0.2.0 loguru-0.5.3


In [None]:
#@title import necessary modules { form-width: "15%", display-mode: "form" }
import torch
from torch import nn, optim
from torch.utils.data import DataLoader,Sampler, Dataset
from torchvision import transforms, datasets
import random
from typing import List, Tuple

from torchvision.models import resnet18, alexnet, squeezenet1_0, googlenet
from tqdm import tqdm

from easyfsl.utils import plot_images, sliding_average

import json
import os
import random
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap
import numpy as np

import subprocess
import sys

import csv
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

### Define Classes

In [None]:
#@title StatsTracker Class  { form-width: "15%", display-mode: "form" }
class StatsTracker:
    """Hold the configuration of one run together with the statistics it produces.

    The constructor arguments are stored verbatim as attributes
    (``pretrained_net`` is stored as ``backbone``). On top of that the object
    tracks:

    * ``base_performance`` - starts at 0,
    * ``loss_list`` - list of loss values; assigning it also sets ``epochs``
      to the index of the last entry,
    * ``epochs`` - starts at -1 (no loss recorded yet),
    * ``acc_logging`` - dict mapping the value of ``epochs`` at assignment
      time to the assigned accuracy.
    """

    def __init__(
        self,
        n_way: int,
        n_shot: int,
        n_query: int,
        n_evaluation_tasks: int,
        image_size: int,
        batch_size: int,
        n_training_episodes: int,
        n_validation_tasks: int,
        pretrained_net: str,
    ):
        super().__init__()

        # Run configuration.
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_evaluation_tasks = n_evaluation_tasks
        self.image_size = image_size
        self.batch_size = batch_size
        self.n_training_episodes = n_training_episodes
        self.n_validation_tasks = n_validation_tasks
        self.backbone = pretrained_net

        # Collected statistics.
        self.base_performance = 0
        self._loss_list = []
        self.epochs = -1
        self._acc_logging = {}

    @property
    def loss_list(self):
        """List of recorded loss values."""
        return self._loss_list

    @loss_list.setter
    def loss_list(self, value):
        # Replacing the list also moves the epoch counter to its last index.
        self._loss_list = value
        self.epochs = len(value) - 1

    @property
    def acc_logging(self):
        """Dict of accuracies keyed by the epoch counter at logging time."""
        return self._acc_logging

    @acc_logging.setter
    def acc_logging(self, acc):
        # Assignment does NOT replace the dict: it records `acc` under the
        # current epoch counter.
        self._acc_logging[self.epochs] = acc
        





In [None]:
#@title PrototypicalNetworks Class { form-width: "15%", display-mode: "form" }
class PrototypicalNetworks(nn.Module):
    """Few-shot classifier that scores queries by their distance to class prototypes.

    Args:
        backbone: module used to turn images into feature vectors.

    Note:
        The constructor reads the module-level globals ``VERBOSE`` and
        ``tracker`` (only for console output), so both must exist before a
        model is created.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        if VERBOSE > 0:
            print("Created Prototypical Network model with pretrained {} as a backbone\n".format(
                tracker.backbone))
        if VERBOSE == 2:
            print("\nNetwork architecture: \n{}".format(backbone))

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.

        Args:
            support_images: images of the support set.
            support_labels: label of every support image; prototypes are built
                for the labels ``0 .. n_way - 1``, where ``n_way`` is the
                number of distinct values in this tensor.
            query_images: images to classify.

        Returns:
            Negated euclidean distances between every query feature and every
            prototype (one row per query, one column per prototype); a higher
            score means a closer prototype.
        """
        # Call the module itself instead of `.forward()` so that hooks
        # registered on the backbone are executed as well.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))

        # Prototype i is the mean of all support features whose label equals i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        # Smaller distance to a prototype means higher classification score
        return -torch.cdist(z_query, z_proto)


In [None]:
#@title ModTaskSampler Class { form-width: "15%", display-mode: "form" }
class ModTaskSampler(Sampler):
    """
    This is a modified version of easyfsl.data_tools.TaskSampler.
    Samples batches in the shape of few-shot classification tasks. At each iteration, it will sample
    n_way classes, and then sample support and query images from these classes.
    """

    def __init__(
        self, dataset: Dataset, n_way: int, n_shot: int, n_query: int, n_tasks: int
    ):
        """
        Args:
            dataset: dataset from which to sample classification tasks. Must have a field 'items_per_label': a
                dict of the structure {label1 : [idx_occurence1, idx_occurence2, ...], label2 : [...]}
            n_way: number of classes in one task
            n_shot: number of support images for each class in one task
            n_query: number of query images for each class in one task
            n_tasks: number of tasks to sample
        """
        try:
            super().__init__(data_source=None)
        except TypeError:
            # Fallback in case the installed torch release does not accept the
            # `data_source` argument on Sampler.__init__.
            super().__init__()
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_tasks = n_tasks

        assert hasattr(
            dataset, "items_per_label"
        ), "TaskSampler needs a dataset with a field 'items_per_label' containing the labels of all images and their occurrences."
        self.items_per_label = dataset.items_per_label

    def __len__(self):
        """Number of tasks (batches) produced by one pass over the sampler."""
        return self.n_tasks

    def __iter__(self):
        """Yield, per task, a 1-D tensor of n_way * (n_shot + n_query) dataset indices.

        The indices are grouped by class: the first n_shot + n_query entries
        belong to the first sampled class, and so on.
        """
        for _ in range(self.n_tasks):
            # random.sample needs a sequence: a dict key view is rejected with
            # a TypeError on Python 3.11+, hence the explicit list.
            task_labels = random.sample(list(self.items_per_label), self.n_way)
            yield torch.cat(
                [
                    # pylint: disable=not-callable
                    torch.tensor(
                        random.sample(
                            self.items_per_label[label], self.n_shot + self.n_query
                        )
                    )
                    # pylint: enable=not-callable
                    for label in task_labels
                ]
            )

    def episodic_collate_fn(
        self, input_data: List[Tuple[torch.Tensor, int]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
        """
        Collate function to be used as argument for the collate_fn parameter of episodic
            data loaders.
        Args:
            input_data: each element is a tuple containing:
                - an image as a torch Tensor
                - the label of this image
              The elements are expected in the class-grouped order produced by __iter__.
        Returns:
            tuple(Tensor, Tensor, Tensor, Tensor, list[int]): respectively:
                - support images,
                - their labels,
                - query images,
                - their labels,
                - the dataset class ids of the class sampled in the episode
            The returned labels are positions in the list of class ids, not the
            dataset labels themselves.
        """

        true_class_ids = list({x[1] for x in input_data})

        # Stack all images, then view them as (class, image within class, ...).
        all_images = torch.cat([x[0].unsqueeze(0) for x in input_data])
        all_images = all_images.reshape(
            (self.n_way, self.n_shot + self.n_query, *all_images.shape[1:])
        )
        # pylint: disable=not-callable
        all_labels = torch.tensor(
            [true_class_ids.index(x[1]) for x in input_data]
        ).reshape((self.n_way, self.n_shot + self.n_query))
        # pylint: enable=not-callable

        # Per class, the first n_shot images form the support set, the rest the query set.
        support_images = all_images[:, : self.n_shot].reshape(
            (-1, *all_images.shape[2:])
        )
        query_images = all_images[:, self.n_shot :].reshape((-1, *all_images.shape[2:]))
        support_labels = all_labels[:, : self.n_shot].flatten()
        query_labels = all_labels[:, self.n_shot :].flatten()

        return (
            support_images,
            support_labels,
            query_images,
            query_labels,
            true_class_ids,
        )


In [None]:
#@title CombinedDataset Class  { form-width: "15%", display-mode: "form" }

import bisect
import functools

from torch.utils.data.dataset import Dataset, IterableDataset
from typing import (
    Callable,
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
)
T_co = TypeVar('T_co', covariant=True)


class CombinedDataset(Dataset[T_co]):
    r"""Dataset as a combination of multiple datasets.

    This class is a modified version of torch.utils.data.dataset.ConcatDataset,
    which is useful to assemble different existing datasets.

    Every wrapped dataset must provide ``classes`` (list of class names),
    ``imgs`` (list of ``(item, local_label)`` tuples) and ``__len__``.
    The classes of all datasets are concatenated, so the local label ``k`` of
    a dataset becomes the global label ``k + <number of classes before it>``.

    Attributes set by the constructor:
        cumulative_sizes: running total of the dataset lengths.
        possible_labels: ``[0, ..., total number of classes - 1]``.
        grouped_classes: one list of class names per dataset.
        classes: all class names as one flat list.
        mapped_labels: ``{global label: (class name, local label)}``.
        items_per_label: ``{global label: [global item indices]}``.
        labels: global label of every item, in item order.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running total of ``len(e)`` over the elements of ``sequence``."""
        totals, running = [], 0
        for element in sequence:
            running += len(element)
            totals.append(running)
        return totals

    @staticmethod
    def images_labels_classes(sequence):
        """Return ``(all global labels, list of per-dataset class lists)``."""
        classes = [ds.classes for ds in sequence]
        n_labels = sum(len(group) for group in classes)
        return list(range(n_labels)), classes

    @staticmethod
    def list_labels_for_all_images(datasets, classes):
        """Return the global label of every item of every dataset, in order.

        ``classes`` holds one class list per dataset; its lengths define the
        label offset of each dataset.
        """
        labels = []
        offset = 0
        for ds, ds_classes in zip(datasets, classes):
            labels.extend(instance[1] + offset for instance in ds.imgs)
            offset += len(ds_classes)
        return labels

    @staticmethod
    def mapped_labels(labels, classes):
        """Map each label to ``(class name, index of the class within its dataset)``.

        Labels are assigned to classes in order; surplus classes (or surplus
        labels) are ignored.
        """
        flat_classes = (
            (name, local_idx)
            for group in classes
            for local_idx, name in enumerate(group)
        )
        return dict(zip(labels, flat_classes))

    @staticmethod
    def mapped_items_per_label(datasets, map):
        """Return ``{global label: [global indices of the items with that label]}``.

        ``map`` is the dict produced by ``mapped_labels``; its second tuple
        entry is the local label that is looked up in each dataset.
        (The parameter name shadows the builtin but is kept for callers that
        pass it by keyword.)
        """
        items_per_label = {}
        first_label = 0  # global label of the current dataset's first class
        first_item = 0   # global index of the current dataset's first item

        for ds in datasets:
            local_labels = [instance[1] for instance in ds.imgs]

            # Group the item positions by local label in a single pass instead
            # of rescanning all items once per class.
            positions = {}
            for position, local_label in enumerate(local_labels):
                positions.setdefault(local_label, []).append(position + first_item)

            n_classes = len(ds.classes)
            for label in range(first_label, first_label + n_classes):
                items_per_label[label] = list(positions.get(map.get(label)[1], []))

            # Advance by this dataset's own class count (a dataset without
            # classes must not reuse the count of its predecessor).
            first_label += n_classes
            first_item += len(local_labels)
        return items_per_label

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "CombinedDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

        self.possible_labels, self.grouped_classes = self.images_labels_classes(self.datasets)
        self.classes = [item for sublist in self.grouped_classes for item in sublist]
        # From here on the instance attribute `mapped_labels` (a dict) hides
        # the static method of the same name on this instance.
        self.mapped_labels = self.mapped_labels(self.possible_labels, self.grouped_classes)
        self.items_per_label = self.mapped_items_per_label(self.datasets, self.mapped_labels)
        self.labels = self.list_labels_for_all_images(self.datasets, self.grouped_classes)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return ``(imgs[...][0] of the owning dataset, global label)`` for item ``idx``.

        Negative indices count from the end.

        Raises:
            ValueError: if a negative index exceeds the dataset length.
        """
        # NOTE(review): the first tuple entry is taken straight from `imgs`
        # and is not loaded or transformed here - confirm that the wrapped
        # datasets store ready-to-use tensors there.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return (self.datasets[dataset_idx].imgs[sample_idx][0], self.labels[idx])
        



### Define functions

In [None]:
#@title def create_subfilesystem  { form-width: "15%", display-mode: "form" }
def create_subfilesystem(base_path, config):
    """Create the result directory tree described by ``config`` below ``base_path``.

    For every remaining top-level key ``key``, every sub-key ``subcatkey`` and
    every parameter dict ``p_dict`` listed under it, the directories
    ``base_path/key/subcatkey/str(p_dict[subcatkey])/checkpoints`` and
    ``.../figures`` are created. Directories that already exist are reused, so
    the function can be called again for the same tree.

    Args:
        base_path: directory below which the tree is created.
        config: configuration dict. NOTE: its "params" entry is removed
            (popped) from the dict passed in.

    Returns:
        tuple ``(path_dict, conf_dict)`` where ``path_dict[key][subcatkey]`` is
        the list of created run directories (in config order) and ``conf_dict``
        is the popped "params" entry.

    Raises:
        KeyError: if ``config`` has no "params" entry or a parameter dict lacks
            its ``subcatkey`` entry.
    """
    conf_dict = config.pop("params")
    path_dict = {}
    for key, subcategories in config.items():
        key_dir = os.path.join(base_path, key)
        os.makedirs(key_dir, exist_ok=True)
        path_dict[key] = {}
        for subcatkey, param_dicts in subcategories.items():
            subcat_dir = os.path.join(key_dir, subcatkey)
            os.makedirs(subcat_dir, exist_ok=True)
            run_dirs = []
            path_dict[key][subcatkey] = run_dirs
            for p_dict in param_dicts:
                run_dir = os.path.join(subcat_dir, str(p_dict[subcatkey]))
                # exist_ok: re-running must not fail on existing directories.
                for leaf in ("checkpoints", "figures"):
                    os.makedirs(os.path.join(run_dir, leaf), exist_ok=True)
                run_dirs.append(run_dir)
    return path_dict, conf_dict

In [None]:
#@title functions to save to *.csv { form-width: "15%", display-mode: "form" }
def save_performance_as_csv(save_to_path, epochs, train_losses, train_accs, test_losses, test_accs):
    """Write one CSV row per epoch with its train/test loss and accuracy.

    The file at ``save_to_path`` is overwritten. Row ``i`` combines
    ``epochs[i]`` with the ``i``-th entry of each of the other sequences, so
    they must be at least as long as ``epochs``. A confirmation is printed
    afterwards.
    """
    columns = ['epoch', 'train_loss', 'train_acc', 'test_loss', 'test_acc']
    with open(save_to_path, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=columns)
        writer.writeheader()
        for row_idx, epoch in enumerate(epochs):
            writer.writerow({
                'epoch': epoch,
                'train_loss': train_losses[row_idx],
                'train_acc': train_accs[row_idx],
                'test_loss': test_losses[row_idx],
                'test_acc': test_accs[row_idx],
            })
    print("Perfomance over time was saved at "+save_to_path+"\n")

  
def save_report_as_csv(save_to_path, report_dict):
    """Write a classification-report dict to ``save_to_path`` as CSV.

    The file gets two sections: first a table with one row per key of
    ``report_dict`` other than 'accuracy' (columns label, precision, recall,
    f1-score, support), then a single-column 'accuracy' table holding
    ``report_dict['accuracy']``. A confirmation is printed afterwards.
    """
    per_label_columns = ['label', 'precision', 'recall', 'f1-score', 'support']
    with open(save_to_path, 'w', newline='') as csvfile:
        label_writer = csv.DictWriter(csvfile, fieldnames=per_label_columns)
        accuracy_writer = csv.DictWriter(csvfile, fieldnames=['accuracy'])

        label_writer.writeheader()
        for label, metrics in report_dict.items():
            if label == 'accuracy':
                # Scalar entry; it is written in its own section below.
                continue
            row = {'label': label}
            for column in per_label_columns[1:]:
                row[column] = metrics[column]
            label_writer.writerow(row)

        accuracy_writer.writeheader()
        accuracy_writer.writerow({'accuracy': report_dict['accuracy']})
    print("Classification report was saved at "+save_to_path+"\n")

def save_track_record_as_csv(save_to_path, t):
    """Dump the tracker object ``t`` to ``save_to_path`` as three stacked CSV tables.

    1. one row with the run configuration read from ``t``
       (``t.backbone`` is written in the 'pretrained_net' column),
    2. one 'epochs,accuracy' row per entry of ``t.acc_logging``,
    3. one 'epochs,loss' row per entry of ``t.loss_list`` (row index as epoch).
    """
    config_columns = ['n_way', 'n_shot', 'n_query', 'n_evaluation_tasks', 'image_size', 'batch_size', 'n_validation_tasks', 'pretrained_net']
    with open(save_to_path, 'w', newline='') as csvfile:
        config_writer = csv.DictWriter(csvfile, fieldnames=config_columns)
        accuracy_writer = csv.DictWriter(csvfile, fieldnames=['epochs', 'accuracy'])
        loss_writer = csv.DictWriter(csvfile, fieldnames=['epochs', 'loss'])

        # Section 1: configuration.
        config_writer.writeheader()
        config_writer.writerow({
            'n_way': t.n_way,
            'n_shot': t.n_shot,
            'n_query': t.n_query,
            'n_evaluation_tasks': t.n_evaluation_tasks,
            'image_size': t.image_size,
            'batch_size': t.batch_size,
            'n_validation_tasks': t.n_validation_tasks,
            'pretrained_net': t.backbone,
        })

        # Section 2: logged accuracies.
        accuracy_writer.writeheader()
        accuracies = t.acc_logging
        for epoch in accuracies.keys():
            accuracy_writer.writerow({'epochs': epoch, 'accuracy': accuracies[epoch]})

        # Section 3: loss history.
        losses = t.loss_list
        loss_writer.writeheader()
        for epoch, loss in enumerate(losses):
            loss_writer.writerow({'epochs': epoch, 'loss': loss})

In [None]:
#@title # def fetch_model  { form-width: "15%", display-mode: "form" }
#@markdown This function is used to load model checkpoints and save information about the loaded checkpoint in the StatsTracker object tracker.
def fetch_model(load_path_checkpoint):
    """Restore a training checkpoint.

    A new PrototypicalNetworks model is built around the global
    ``convolutional_network`` and receives the stored model state; the global
    ``optimizer`` receives the stored optimizer state. The global ``tracker``
    is updated with the stored loss list and epoch.

    Args:
        load_path_checkpoint: path of a checkpoint containing the keys
            'model_state_dict', 'optimizer_state_dict', 'loss' and 'epoch'.

    Returns:
        tuple ``(model, optimizer, loss list, epoch)``.
    """
    model = PrototypicalNetworks(convolutional_network).to(device)
    # map_location lets a checkpoint written on a GPU be restored on a
    # CPU-only runtime (and vice versa).
    checkpoint = torch.load(load_path_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Assigning loss_list also derives tracker.epochs from the list length;
    # the stored epoch is set afterwards so that it takes precedence.
    tracker.loss_list = checkpoint['loss']
    print(checkpoint['loss'])
    tracker.epochs = checkpoint['epoch']

    if VERBOSE > 0: print("Fetched model {}\n".format(load_path_checkpoint))
    return model, optimizer, checkpoint['loss'], checkpoint['epoch']

In [None]:
#@title # def find_best_model { form-width: "15%", display-mode: "form" }
#@markdown This function builds a path to the model with the smallest loss saved in the StatsTracker object's attribute loss_list.
def find_best_model():
    """Return the checkpoint path that belongs to the smallest loss in ``tracker.loss_list``.

    The path is ``cwd/checkpoints/<file>`` where the file name encodes the
    index of the smallest loss, the loss value and the episode shape stored in
    the global ``tracker``. The path is only built, not checked for existence.
    """
    losses = tracker.loss_list
    best_index = np.argmin(losses)
    checkpoint_name = "model-epoch_{:05}-loss_{:0.3f}-N_way_{}-N_shot_{}-N_query_{}.pt".format(
        best_index,
        losses[best_index],
        tracker.n_way,
        tracker.n_shot,
        tracker.n_query,
    )
    return os.path.join(cwd, "checkpoints", checkpoint_name)

In [None]:
#@title #functions to evaluate the model { form-width: "15%", display-mode: "form" }
def evaluate_on_one_task_properly(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> Tuple[int, int, List[Tuple[torch.Tensor, torch.Tensor]]]:
    """
    Classify the query images of one task with the global ``model``.

    Returns:
        tuple of
            - the number of correct predictions of query labels,
            - the total number of predictions,
            - a list with one ``(true label, predicted label)`` pair per query
              image; both entries are zero-dimensional CPU tensors.
    """
    scores = model(
        support_images.to(device), support_labels.to(device), query_images.to(device)
    ).detach()
    # The predicted label is the index of the highest score of each query.
    predicted_labels = torch.max(scores, 1)[1].cpu()
    true_labels = query_labels.cpu()
    n_correct = (predicted_labels == true_labels).sum().item()
    # Pair the ground truth with the model's prediction (not with the support
    # labels), so that downstream metrics reflect what the model predicted.
    return n_correct, len(query_labels), list(zip(true_labels, predicted_labels))


def evaluate_properly(data_loader: DataLoader):
    """Evaluate the global ``model`` on every task of ``data_loader``.

    Each batch must be a tuple (support images, support labels, query images,
    query labels, class ids). The label pairs returned by
    evaluate_on_one_task_properly are translated to dataset class ids via
    ``class_ids``; the first entry of a pair is treated as ground truth, the
    second as prediction. The accuracy is also stored in
    ``tracker.acc_logging``.

    Returns:
        tuple ``(y_true, y_pred, accuracy)``: two lists of dataset class ids
        and the accuracy in percent.

    Raises:
        ValueError: if the loader yields no predictions at all.
    """
    # We'll count everything and compute the ratio at the end
    total_predictions = 0
    correct_predictions = 0
    y_true = []
    y_pred = []

    # eval mode affects the behaviour of some layers (such as batch normalization or dropout)
    # no_grad() tells torch not to keep in memory the whole computational graph (it's more lightweight this way)
    model.eval()
    with torch.no_grad():
        for (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in tqdm(data_loader, total=len(data_loader)):

            correct, total, classifications = evaluate_on_one_task_properly(
                support_images, support_labels, query_images, query_labels
            )
            # Task-local labels are positions in class_ids.
            for true_label, predicted_label in classifications:
                y_true.append(class_ids[true_label.item()])
                y_pred.append(class_ids[predicted_label.item()])
            total_predictions += total
            correct_predictions += correct

    if total_predictions == 0:
        raise ValueError("data_loader did not yield any predictions to evaluate")
    accuracy = 100 * correct_predictions/total_predictions

    tracker.acc_logging = accuracy
    if VERBOSE == 2:
        # Class names come from the dataset that produced the class ids.
        class_names = data_loader.dataset.classes
        print("Ground Truth / Predicted")
        # Iterate over the collected pairs themselves so the loop can never
        # run past the end of the lists.
        for true_id, predicted_id in zip(y_true, y_pred):
            print(
                f"{class_names[true_id]} / {class_names[predicted_id]}"
            )
        print("\n\n")
    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {accuracy:.2f}%"
    )
    return y_true, y_pred, accuracy



In [None]:
#@title # def get_metrics { form-width: "15%", display-mode: "form" }
def get_metrics(data_loader: DataLoader):
    """Evaluate the model on ``data_loader`` and report the result.

    With VERBOSE > 0 a row-normalised confusion matrix is drawn and saved as
    SVG below ``cwd/figures`` and the classification report is printed.
    While ``tracker.epochs`` is still -1 the accuracy is stored as
    ``tracker.base_performance``; afterwards the change relative to that
    baseline is printed instead.
    """
    labels, predictions, acc = evaluate_properly(data_loader)
    dataset = data_loader.dataset
    verbose = VERBOSE > 0

    if verbose:
        _, ax = plt.subplots(figsize=(20,20))
        ConfusionMatrixDisplay.from_predictions(
            y_true=labels,
            y_pred=predictions,
            labels=dataset.possible_labels,
            sample_weight=None,
            normalize='true',
            display_labels=dataset.classes,
            include_values=True,
            xticks_rotation='vertical',
            values_format=None,
            cmap=faps_colours,
            ax=ax,
            colorbar=True,
        )
        save_to_path = "figures"
        confusion_matrix_filename = "confusion-matrix_n-way_{}_n-shot_{}_n-query_{}_training-epochs_{}_validation-tasks_{}.svg".format(
            tracker.n_way, tracker.n_shot, tracker.n_query, tracker.epochs, tracker.n_validation_tasks)
        plt.savefig(os.path.join(cwd, save_to_path, confusion_matrix_filename))

    # NOTE: this name is currently not used; writing the report to disk is disabled.
    classification_report_filename = "classification-report_n-way_{}_n-shot_{}_n-query_{}_training-epochs_{}_validation-tasks_{}.svg".format(
        tracker.n_way, tracker.n_shot, tracker.n_query, tracker.epochs, tracker.n_validation_tasks)

    if verbose:
        print("\n\n")
        print(classification_report(
            labels,
            predictions,
            labels=dataset.possible_labels,
            target_names=dataset.classes,
            output_dict=False,
            zero_division=0,
        ))

    if tracker.epochs != -1:
        print("\nModel trained for {} epochs. Accuracy has changed from baseline {:.2f}% to  {:.2f}%.\nThis means improvement by {:.3f}%\n".format(tracker.epochs, tracker.base_performance, acc, acc - tracker.base_performance))
    else:
        # No training recorded yet: this evaluation is the baseline.
        tracker.base_performance = acc


In [None]:
#@title # def fit { form-width: "15%", display-mode: "form" }
def prepare_training(lr: float = 0.001):
    """Create the loss function and the optimizer for the global ``model``.

    Args:
        lr: learning rate of the Adam optimizer; the default keeps the
            previously hard-coded value.

    Returns:
        tuple ``(criterion, optimizer)``: a cross-entropy loss and an Adam
        optimizer over ``model.parameters()``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    return criterion, optimizer

def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    """Run one optimisation step on a single task and return its loss value.

    Uses the global ``model``, ``criterion``, ``optimizer`` and ``device``.
    """
    # Gradients of the previous step must not leak into this one.
    optimizer.zero_grad()

    scores = model(
        support_images.to(device),
        support_labels.to(device),
        query_images.to(device),
    )
    targets = query_labels.to(device)

    task_loss = criterion(scores, targets)
    task_loss.backward()
    optimizer.step()

    return task_loss.item()


In [None]:
#@title #def training { form-width: "15%", display-mode: "form" }

def training(model, tracker, train_loader, test_loader):
    """Train on every task of ``train_loader`` and evaluate on ``test_loader``.

    Per task one optimisation step is done via ``fit`` and the loss is
    appended to ``tracker.loss_list``. Every 500 tasks the metrics are
    computed on ``test_loader``. A checkpoint is written to
    ``cwd/checkpoints`` when one of the following holds:

    * the (offset) task index is a multiple of 150 and the loss is below 0.1,
    * the loss is below the best loss so far (which starts at 0.02),
    * the task is the last one of the loader.

    After the loop the metrics are computed once more.

    NOTE(review): ``fit`` and ``get_metrics`` operate on the module-level
    ``model``/``optimizer``; the ``model`` argument is presumably the same
    object - confirm at the call site.
    """
    log_update_frequency = 5
    save_model_frequency = 150
    get_metrics_frequency = 500
    _print = (VERBOSE > 0)
    all_loss = tracker.loss_list
    # Continue counting from the episodes the tracker already knows about.
    offset = max(tracker.epochs, 0)
    min_loss = 0.02
    loss_cutoff = 0.1

    model.train()
    with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
        # Let the progress bar start at the offset instead of zero.
        tqdm_train.n = offset
        tqdm_train.last_print_n = offset
        tqdm_train.update()
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            loss_value = fit(support_images, support_labels, query_images, query_labels)

            all_loss.append(loss_value)

            if episode_index % get_metrics_frequency == 0 and episode_index > 0:
                tracker.loss_list = all_loss
                get_metrics(test_loader)
                # The evaluation switches the model to eval mode; switch back,
                # otherwise all remaining episodes would train in eval mode.
                model.train()

            if episode_index % log_update_frequency == 0 and episode_index > 0:
                tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))
                tracker.loss_list = all_loss

            periodic_save = (
                (episode_index + offset) % save_model_frequency == 0
                and loss_value < loss_cutoff
            )
            new_best = loss_value < min_loss
            last_episode = episode_index == tqdm_train.total - 1
            if periodic_save or new_best or last_episode:
                if new_best:
                    min_loss = loss_value
                    if _print: print("\nnew min_loss: "+str(min_loss)+"\n")
                torch.save({
                'epoch': episode_index + offset,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': all_loss,
                }, os.path.join(cwd,"checkpoints","model-epoch_{:05}-loss_{:0.3f}-N_way_{}-N_shot_{}-N_query_{}.pt".format(episode_index + offset, loss_value, tracker.n_way, tracker.n_shot, tracker.n_query)))

            # Re-assigning the list keeps tracker.epochs in sync with its length.
            tracker.loss_list = all_loss
    get_metrics(test_loader)

In [None]:
#@title the ACTUAL algorithm { form-width: "15%", display-mode: "code" }
def run_for_one_param_setting(save_to_path, param_dict):
    """Run baseline evaluation, training and final evaluation for one parameter set.

    The function (re)binds the module-level globals ``VERBOSE``, ``cwd``,
    ``tracker``, ``convolutional_network``, ``model``, ``criterion`` and
    ``optimizer``, because the helper functions of this notebook read them.
    It also relies on the globals ``faps_data_train``, ``faps_data_test`` and
    ``device``.

    Args:
        save_to_path: directory of this run; checkpoints are written to its
            "checkpoints" and figures/records to its "figures" sub-directory.
        param_dict: dict with the keys "n_way", "n_shot", "n_query",
            "n_evaluation_tasks", "image_size", "batch_size",
            "n_training_episodes", "n_validation_tasks", "pretrained_net" and
            "verbose" (one of "minimal output", "moderate output", "verbose").

    Raises:
        ValueError: if "verbose" or "pretrained_net" holds an unknown value.
    """
    # Imported here because the notebook's import cell does not provide it.
    from torchvision.models import densenet161

    global VERBOSE, cwd, tracker, convolutional_network, model, criterion, optimizer

    # read parameters from dict
    N_WAY = param_dict["n_way"]
    N_SHOT = param_dict["n_shot"]
    N_QUERY = param_dict["n_query"]
    N_EVALUATION_TASKS = param_dict["n_evaluation_tasks"]
    IMAGE_SIZE = param_dict["image_size"]
    BATCH_SIZE = param_dict["batch_size"]
    N_TRAINING_EPISODES = param_dict["n_training_episodes"]
    N_VALIDATION_TASKS = param_dict["n_validation_tasks"]
    PRETRAINED_NET = param_dict["pretrained_net"]

    verbose_levels = ["minimal output", "moderate output", "verbose"]
    VERBOSE = verbose_levels.index(param_dict["verbose"])
    cwd = save_to_path

    # create tracker object
    tracker = StatsTracker(N_WAY, N_SHOT, N_QUERY, N_EVALUATION_TASKS, IMAGE_SIZE, BATCH_SIZE, N_TRAINING_EPISODES, N_VALIDATION_TASKS, PRETRAINED_NET)

    # get pretrained backbone
    backbone_factories = {
        "resnet18": resnet18,
        "alexnet": alexnet,
        "squeezenet": squeezenet1_0,
        "googlenet": googlenet,
        "densenet161": densenet161,
    }
    if PRETRAINED_NET not in backbone_factories:
        raise ValueError(
            "Unknown pretrained_net {!r}; expected one of {}".format(
                PRETRAINED_NET, list(backbone_factories)
            )
        )
    convolutional_network = backbone_factories[PRETRAINED_NET](pretrained=True)
    # NOTE(review): only replaces the classification head of backbones that
    # expose it as `.fc`; for the others this presumably just adds an unused
    # attribute - confirm that this is intended.
    convolutional_network.fc = nn.Flatten()

    # create model with backbone
    model = PrototypicalNetworks(convolutional_network).to(device)

    criterion, optimizer = prepare_training()

    # create DataLoader object for testing
    test_sampler = ModTaskSampler(faps_data_test, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS)
    test_loader = DataLoader(
        faps_data_test,
        batch_sampler=test_sampler,
        num_workers=2,
        pin_memory=True,
        collate_fn=test_sampler.episodic_collate_fn,
    )

    # create DataLoader object for training
    train_sampler = ModTaskSampler(faps_data_train, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES)
    train_loader = DataLoader(
        faps_data_train,
        batch_sampler=train_sampler,
        num_workers=2,
        pin_memory=True,
        collate_fn=train_sampler.episodic_collate_fn,
    )

    # visualize one task
    if VERBOSE > 0:
        (
            example_support_images,
            example_support_labels,
            example_query_images,
            example_query_labels,
            example_class_ids
        ) = next(iter(test_loader))
        print("This task contains the following {} classes {}\n".format(N_WAY, example_class_ids))
        plot_images(example_support_images, "support images", images_per_row=N_SHOT)
        plot_images(example_query_images, "query images", images_per_row=N_QUERY)

    # establish baseline
    get_metrics(test_loader)

    # train the model and save results
    training(model, tracker, train_loader, test_loader)

    # Re-evaluate the best model if its loss is below 0.02 - the initial
    # "best loss" of training(), i.e. the bound below which a checkpoint is
    # written for every new best loss.
    if min(tracker.loss_list) < 0.02:
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        load_path_checkpoint = find_best_model()
        model, optimizer, _, _ = fetch_model(load_path_checkpoint)
        get_metrics(test_loader)

    # save info in tracker object
    record_filename = "record_net_{}_n-way_{}_n-shot_{}_n-query_{}_training-epochs_{}_validation-tasks_{}.csv".format(PRETRAINED_NET, N_WAY, N_SHOT, N_QUERY, tracker.epochs, N_VALIDATION_TASKS)
    s_path = os.path.join(cwd, "figures", record_filename)
    save_track_record_as_csv(s_path, tracker)



### Define constants

In [None]:
#@title colourscheme faps_colours { form-width: "15%", display-mode: "form" }
# Colour palette as RGB tuples with components scaled to the range 0..1.
faps_green = (151 / 255, 193 / 255, 57 / 255)
faps_dark_green = (93 / 255, 119 / 255, 35 / 255)
faps_light_green = (205 / 255, 226 / 255, 158 / 255)
faps_light_yellow = (1, 234 / 255, 147 / 255)
faps_yellow = (1, 204 / 255, 0)
faps_dark_yellow = (200 / 255, 162 / 255, 0)

# Colormap used for the confusion matrix plot; entries are listed from the
# first (white) to the last (dark green) colour.
faps_colours = ListedColormap(
    [
        "white",
        faps_light_yellow,
        faps_yellow,
        faps_light_green,
        faps_green,
        faps_dark_green,
    ]
)

## Configuration: Manual input required

In [None]:
#@title Mount Google Drive { form-width: "15%", display-mode: "form" }
#@markdown This block requires you to go through the login process of your Google Account to access Google Drive where your dataset should be stored.
from google.colab import drive
# Directory at which Google Drive is mounted.
base_path = '/content/data'
drive.mount(base_path)
# "MyDrive" folder below the mount point; the configuration cell resolves the
# config file path and the output folder relative to this directory.
wd = os.path.join(base_path, "MyDrive")

Mounted at /content/data


In [None]:
#@title parameter configuration
#@markdown Select this option and provide a config to skip the manual parameter entry of the next cell 
set_parameters_from_config_file = True #@param {type:"boolean"}
#@markdown Please provide a path relative to the "root" level of your Google Drive account including the filename and its filename extension
rel_config_path = "peds3_conf.json" #@param {type:"string"}

if set_parameters_from_config_file:
  if rel_config_path == "":
    print("Please provide rel_config_path or disable the set_parameters_from_config_file option!")
  config_path = os.path.join(wd, rel_config_path)
  with open(config_path) as json_file:
    config_json = json.load(json_file)

  p = os.path.join(wd, "PEDS3/PrototypicalNetwork")
  print(config_json.keys())
  path_dict, conf_dict = create_subfilesystem(p, config_json)
  print(conf_dict.keys())
  data_path = conf_dict["data_path"]
  drive_path = os.path.join(base_path, "MyDrive", data_path)

  preferred_device = conf_dict["preferred_device"]
  if preferred_device == "GPU" and torch.cuda.is_available():
    device = 'cuda'
    gpu_info = !nvidia-smi
    gpu_info = '\n'.join(gpu_info)
    if gpu_info.find('failed') >= 0:
      print('Could not connect to GPU, connected to CPU instead!')
      device = 'cpu'
    else:
      print("Connected to the following GPU:\n")
      print(gpu_info)
  
  image_size = conf_dict["image_size"]
  

dict_keys(['params', 'resnet', 'alexnet', 'squeezenet', 'googlenet'])
dict_keys(['data_path', 'preferred_device', 'image_size'])
Connected to the following GPU:

Sat Jan  8 15:33:06 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    25W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
      

In [None]:
#@title configuration pt 2: manually { form-width: "15%", display-mode: "form" }

# Manual mode: every value below comes from the form fields of this cell and
# is only applied when the config-file option of the previous cell is off.
if not set_parameters_from_config_file:
  #@markdown #### Number of images per class in the support set
  n_shot = 4 #@param {type:"slider", min:1, max:5, step:1}
  #@markdown #### Number of classes in a task
  n_way = 6 #@param {type:"slider", min:1, max:15, step:1}
  #@markdown #### Number of images per class in the query set
  n_query = 1 #@param {type:"slider", min:1, max:5, step:1}
  n_evaluation_tasks = 100 #@param {type:"slider", min:1, max:100, step:1}
  image_size = 256 #@param {type:"slider", min:16, max:512, step:128}
  batch_size = 8 #@param {type:"slider", min:1, max:16, step:1}
  n_training_episodes = 5000 #@param {type:"slider", min:100, max:7500, step:100}
  n_validation_tasks = 50 #@param {type:"slider", min:10, max:1000, step:10}
  #@markdown ---
  #@markdown ### Choose if you want to use GPU or CPU:
  preferred_device = "GPU" #@param ["CPU", "GPU"]
  #@markdown ---
  #@markdown ### Enter a Google Drive path to your dataset:
  data_path = "modified_data_set" #@param {type:"string"}
  #@markdown ---

  # Use CUDA only when it is both requested and available; otherwise (or when
  # the nvidia-smi output contains 'failed') fall back to the CPU.
  if preferred_device == "GPU" and torch.cuda.is_available():
    device = 'cuda'
    # IPython shell escape: captures the output lines of nvidia-smi.
    gpu_info = !nvidia-smi
    gpu_info = '\n'.join(gpu_info)
    if gpu_info.find('failed') >= 0:
      print('Could not connect to GPU, connected to CPU instead!')
      device = 'cpu'
    else:
      print("Connected to the following GPU:\n")
      print(gpu_info)
  else:
    device = 'cpu'
    print("Connected to CPU, this might be slow. Consider connecting to GPU and executing this code cell again!\n")

  #@markdown ### Choose a pretrained model to start with:
  pretrained_net = "resnet18" #@param ["resnet18", "alexnet", "squeezenet", "googlenet", "densenet161"]

  #@markdown ---
  #@markdown ### How much output do you want to generate:
  verbose = "verbose" #@param ["minimal output", "moderate output", "verbose"]
  verbose_levels = ["minimal output", "moderate output", "verbose"]
  
  # Upper-case aliases of the form values, assigned at cell (module) level.
  N_WAY = n_way
  N_SHOT = n_shot
  N_QUERY = n_query
  N_EVALUATION_TASKS = n_evaluation_tasks
  N_TRAINING_EPISODES = n_training_episodes
  N_VALIDATION_TASKS = n_validation_tasks
  # Verbosity as an integer: position of the chosen label in verbose_levels
  # (0 = minimal output, 2 = verbose).
  VERBOSE = verbose_levels.index(verbose)
  # Location of the dataset inside the mounted Google Drive.
  drive_path = os.path.join(base_path, "MyDrive", data_path)


## Run Code

In [None]:
#@title Create datasets { form-width: "15%", display-mode: "form" }

# Training pipeline: grayscale replicated to three channels, then a random
# resized crop and a random horizontal flip before conversion to a tensor.
train_data_transform = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
)

# Test pipeline: no random operations; resize to 115 % of the target size
# and cut out the centre so the result is image_size x image_size.
test_data_transform = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
)


def _load_split(split, transform):
  """Build one ImageFolder dataset per sub-folder of ``drive_path/<split>``.

  Only directories are considered: a stray file inside the split folder
  would otherwise be handed to ImageFolder as ``root`` and abort the cell.

  Args:
    split: name of the split folder below ``drive_path`` ("train" or "test").
    transform: transform pipeline applied to every image of the split.

  Returns:
    Tuple ``(folders, split_datasets)``: the sorted sub-folder names and the
    ImageFolder datasets created from them, in the same order.
  """
  split_root = os.path.join(drive_path, split)
  folders = sorted(
      name for name in os.listdir(split_root)
      if os.path.isdir(os.path.join(split_root, name))
  )
  split_datasets = [
      datasets.ImageFolder(root=os.path.join(split_root, name), transform=transform)
      for name in folders
  ]
  return folders, split_datasets


# create train dataset from all folders in the specified path
folder_list, dataset_list = _load_split("train", train_data_transform)
faps_data_train = CombinedDataset(dataset_list)

# create test dataset from all folders in the specified path
folder_list, dataset_test_list = _load_split("test", test_data_transform)
faps_data_test = CombinedDataset(dataset_test_list)


FileNotFoundError: ignored

In [None]:
#@title RUN manually  { form-width: "15%", display-mode: "form" }
# Manual mode only: bundle the form values into one parameter dictionary,
# create the output folders for this setting and start a single run.
if not set_parameters_from_config_file:
  param_dict = {  
      "n_way": N_WAY,
      "n_shot": N_SHOT,
      "n_query": N_QUERY,
      "n_evaluation_tasks": N_EVALUATION_TASKS,
      "image_size": image_size,
      "batch_size": batch_size,
      "n_validation_tasks": N_VALIDATION_TASKS,
      "pretrained_net": pretrained_net,
      "verbose": VERBOSE
  }
  foldername = "{}-n_way_{}-n_shot_{}-n_query-{}".format(pretrained_net, N_WAY, N_SHOT, N_QUERY)
  # NOTE(review): `cwd` is not assigned in the cells around this one (the
  # mount cell defines `wd`) - confirm it exists before this cell runs.
  p1 = os.path.join(cwd, foldername)
  p2 = os.path.join(p1, "figures")
  p3 = os.path.join(p1, "checkpoints")
  # makedirs with exist_ok creates p1 along the way and, unlike os.mkdir,
  # does not raise FileExistsError when the cell is re-run with the same
  # settings.
  os.makedirs(p2, exist_ok=True)
  os.makedirs(p3, exist_ok=True)
  run_for_one_param_setting(p1, param_dict)



In [None]:
#@title RUN automatically  { form-width: "15%", display-mode: "form" }
# Config-file mode only (mirrors the guard of the manual run cell): without
# it, running this cell in manual mode fails on the undefined config_json or
# silently replays a config left over from an earlier session.
if set_parameters_from_config_file:
  # config_json is walked as key -> sub-category -> list of parameter
  # settings; path_dict is indexed the same way to find the output folder
  # belonging to each setting.
  for key in config_json:
    for subcatkey in config_json[key]:
      for param_dict_idx, setting in enumerate(config_json[key][subcatkey]):
        run_for_one_param_setting(path_dict[key][subcatkey][param_dict_idx], setting)


# Sources

Code skeleton taken from [here](https://github.com/sicara/easy-few-shot-learning/blob/master/notebooks/my_first_few_shot_classifier.ipynb)

The class ModTaskSampler is a modified version of [this](https://github.com/sicara/easy-few-shot-learning/blob/master/easyfsl/data_tools/task_sampler.py)

The class CombinedDataset is a modified version of [this](https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html)
