## An attempt to bring Neural Architecture Search to Few-shot learning by combining MAML and ENAS

Author: Aleksandr Chebykin

Basic description of MAML (Model-Agnostic Meta-Learning) and ENAS (Efficient Neural Architecture Search):

1) MAML: train a network on a variety of tasks such that it would be able to quickly adapt to a new task. [https://arxiv.org/abs/1703.03400]

2) ENAS: train a controller network to generate a task-specific network architecture. [https://arxiv.org/abs/1802.03268]

The idea of this mini-project was to combine the two algorithms: train a controller network such that it would be able to quickly generate a network architecture for a new task.

This idea didn't work out. Maybe it was a bad idea, or maybe I just didn't find enough tricks and good hyperparameters to make it work.

Most of the code was taken from the following repositories: https://github.com/oscarknagg/few-shot/ (MAML), https://github.com/TDeVries/enas_pytorch (ENAS). I am thankful to the authors, and can recommend these implementations as both well-written and actually working.

## Description of the approach

For starters, here's a description of the training loops of vanilla MAML and ENAS:

1) MAML: at the start of the outer loop a batch of tasks is sampled; in the inner loop a copy of a network is trained for a small number of steps on each task; at the end of the outer loop the network parameters are updated using the updated parameters from the inner loop.

2) ENAS: in each epoch, the controller parameters are first frozen while the parameters of the child networks are trained, and then vice versa.
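The MAML loop described above can be sketched on a toy problem. This is a minimal illustration under loud assumptions: it uses the first-order variant of MAML, a single scalar parameter `w`, and tasks of the form "minimise `(w - target)**2`". The names `inner_adapt` and `fomaml` are hypothetical and do not come from the repositories mentioned above.

```python
def inner_adapt(w, target, inner_lr, steps):
    """Inner loop: a few gradient steps on the task loss (w - target)**2."""
    for _ in range(steps):
        w = w - inner_lr * 2.0 * (w - target)
    return w


def fomaml(targets, meta_lr=0.1, inner_lr=0.1, inner_steps=1, outer_steps=200):
    """Outer loop of first-order MAML over a fixed batch of toy tasks."""
    w = 0.0
    for _ in range(outer_steps):
        meta_grad = 0.0
        for target in targets:
            # Adapt a copy of the parameter to the task
            w_adapted = inner_adapt(w, target, inner_lr, inner_steps)
            # First-order approximation: gradient of the task loss at the adapted parameter
            meta_grad += 2.0 * (w_adapted - target)
        # Update the shared initialisation with the averaged gradient
        w -= meta_lr * meta_grad / len(targets)
    return w


w_meta = fomaml([1.0, 2.0, 3.0])
print(w_meta)
```

For these symmetric toy tasks the learned initialisation ends up at the mean of the targets, which is the point from which every task can be reached equally quickly.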

It is worth noting that there are two versions of MAML: one for RL and one for non-RL problems. The RL version seems more fitting, since ENAS is optimized as a reinforcement learning problem. However, there is not really a chain of state-action-reward tuples to sample, because there is no real state and the reward can only be given at the end (the reward is the loss of the generated model on the validation set). Additionally, I aimed at few-shot classification on MiniImageNet, a non-RL problem. So in the end I went with a combination of RL and non-RL MAML.

I faced a number of problems when combining MAML and ENAS. For example, ENAS has child networks' parameters trained for a whole epoch before the controller is updated. But in the case of few-shot learning we simply don't have enough samples to do that, and the task keeps changing from batch to batch. The only way to proceed is simultaneous learning of both networks.

Another problem is that to get the ENAS reward for the currently generated network, we first need to train it on the training set and then measure its accuracy on the validation set. Thanks to weight sharing between child networks this computation is not wasted, but it does mean that we need to sample twice as much data. Why? Because the original MAML requires train+val data for each iteration. With ENAS this becomes (train_train+train_val)+(val_train+val_val).
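To make the data requirement concrete, here is a small sketch of how the samples of one task could be partitioned into four disjoint subsets. It is an illustration only: the helper `split_task_samples`, the subset names `train1`/`val1`/`train2`/`val2` (standing for train_train, train_val, val_train and val_val), and the fake sample ids are all assumptions, not the code used in this notebook.

```python
import random


def split_task_samples(samples_per_class, n, q, rng):
    """Partition each class's samples into four disjoint subsets.

    samples_per_class maps a class id to a list of at least 2 * (n + q) sample ids.
    """
    splits = {"train1": [], "val1": [], "train2": [], "val2": []}
    for class_id, samples in samples_per_class.items():
        picked = rng.sample(samples, 2 * (n + q))  # sampling without replacement
        splits["train1"] += picked[:n]
        splits["val1"] += picked[n:n + q]
        splits["train2"] += picked[n + q:2 * n + q]
        splits["val2"] += picked[2 * n + q:]
    return splits


rng = random.Random(0)
# 5 classes with 20 fake sample ids each
data = {c: [f"class{c}_img{i}" for i in range(20)] for c in range(5)}
splits = split_task_samples(data, n=1, q=3, rng=rng)
print({name: len(ids) for name, ids in splits.items()})
```

With `n=1` and `q=3` each class contributes 8 samples instead of the 4 that plain MAML would need.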

In the end, pseudocode of my approach is this:

```
for _ in range(epochs):
    sample batch of tasks {T_i} and data for them (2 train sets + 2 validation sets)
    for T_i in {T_i}:
        clone controller to controller_clone
        # training controller-clone:
        sample child_net
        train child_net on train1
        evaluate child_net on val1
        update controller_clone by PG with validation accuracy as reward
        # evaluating controller-clone:
        sample child_net
        train child_net on train2
        evaluate child_net on val2
        backpropagate through controller_clone by PG with validation accuracy as reward
        save gradients of controller_clone
    update controller with all the saved gradients of different controller_clones
    
```



The resulting scheme is pretty complicated, so it is not surprising that it fails to train: the networks perform no better than random guessing. I also don't have the hardware to do a hyperparameter search, and both algorithms are sensitive to their many hyperparameters. Moreover, for the hyperparameters that the two algorithms share (e.g. the learning rate), they use different values.

As for the reinforcement learning part, I used the same improvements to the policy gradient (PG) as ENAS: adding the entropy to the reward, and subtracting a baseline (the average reward over the current epoch) from the reward.
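The two policy-gradient tweaks can be written down for a single categorical decision. This is a hedged sketch in plain Python rather than the notebook's PyTorch code: the function names and the `entropy_weight` value are assumptions, and a real controller makes many decisions per architecture and sums their log-probabilities and entropies.

```python
import math


def log_prob_and_entropy(probs, action):
    """Log-probability of the sampled action and entropy of the distribution."""
    log_prob = math.log(probs[action])
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return log_prob, entropy


def pg_loss(probs, action, val_acc, baseline, entropy_weight=1e-4):
    """REINFORCE loss with an entropy bonus added to the reward and a baseline subtracted."""
    log_prob, entropy = log_prob_and_entropy(probs, action)
    reward = val_acc + entropy_weight * entropy
    loss = -log_prob * (reward - baseline)
    return loss, reward


# One decision with two equally likely options; the baseline is a running average of past rewards
loss, reward = pg_loss([0.5, 0.5], action=0, val_acc=0.6, baseline=0.4, entropy_weight=0.0)
print(loss, reward)
```

Subtracting the baseline only changes the scale and sign of the update for each sample, so an architecture that scores below the average reward has its probability pushed down rather than up.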

## Loading the MiniImageNet data

In [0]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
DATA_PATH = '/content/drive/My Drive'
LOG_PATH = '/content/drive/My Drive/mamlenas/logs/'
MODELS_PATH = '/content/drive/My Drive/mamlenas/models/'

In [0]:
import os

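# Create the output folders; exist_ok avoids an error when they already exist,
# and each folder is created independently of the other
os.makedirs(LOG_PATH, exist_ok=True)
os.makedirs(MODELS_PATH, exist_ok=True)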

In [0]:
# Link to the dataset: https://drive.google.com/file/d/1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk/view

import zipfile
with zipfile.ZipFile(DATA_PATH + "/mini-imagenet.zip","r") as zip_ref:
    zip_ref.extractall(DATA_PATH + '/miniImageNet/')

In [0]:
import numpy as np
from torch.distributions.categorical import Categorical
import torch
import shutil
from typing import List, Iterable, Callable, Tuple, Dict, Union
from tqdm.auto import tqdm  # needed by move_data_to_proper_folder_structure below

def move_data_to_proper_folder_structure():
    # Find class identities
    classes = []
    for root, _, files in os.walk(DATA_PATH + '/miniImageNet/images/'):
        for f in files:
            if f.endswith('.jpg'):
                classes.append(f[:-12])

    classes = sorted(set(classes))  # sort so that the seeded shuffle below is reproducible
    print(classes)

    # Train/test split
    np.random.seed(0)
    np.random.shuffle(classes)
    background_classes, evaluation_classes = classes[:80], classes[80:]

    # Create class folders
    for c in background_classes:
        os.makedirs(DATA_PATH + f'/miniImageNet/images_background/{c}', exist_ok=True)

    for c in evaluation_classes:
        os.makedirs(DATA_PATH + f'/miniImageNet/images_evaluation/{c}', exist_ok=True)

    # Move images to correct location
    for root, _, files in os.walk(DATA_PATH + '/miniImageNet/images'):
        for f in tqdm(files, total=600*100):
            if f.endswith('.jpg'):
                class_name = f[:-12]
                image_name = f[-12:]
                # Send to correct folder
                subset_folder = 'images_evaluation' if class_name in evaluation_classes else 'images_background'
                src = f'{root}/{f}'
                dst = DATA_PATH + f'/miniImageNet/{subset_folder}/{class_name}/{image_name}'
                shutil.copy(src, dst)

move_data_to_proper_folder_structure()

[]


In [0]:
from torch.utils.data import DataLoader
from torch import nn
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from skimage import io as io_skimage
from tqdm.auto import tqdm
import pandas as pd

class MiniImageNet(Dataset):
    def __init__(self, subset):
        """Dataset class representing miniImageNet dataset
        # Arguments:
            subset: Whether the dataset represents the background or evaluation set
        """
        if subset not in ('background', 'evaluation'):
            raise ValueError('subset must be one of (background, evaluation)')
        self.subset = subset

        self.df = pd.DataFrame(self.index_subset(self.subset))

        # Index of dataframe has direct correspondence to item in dataset
        self.df = self.df.assign(id=self.df.index.values)

        # Convert arbitrary class names of dataset to ordered 0-(num_classes - 1) integers
        self.unique_classnames = sorted(self.df['class_name'].unique())
        self.class_name_to_id = {self.unique_classnames[i]: i for i in range(self.num_classes())}
        self.df = self.df.assign(class_id=self.df['class_name'].apply(lambda c: self.class_name_to_id[c]))

        # Create dicts
        self.datasetid_to_filepath = self.df.to_dict()['filepath']
        self.datasetid_to_class_id = self.df.to_dict()['class_id']

        # Setup transforms
        self.transform = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.Resize(84),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __getitem__(self, item):
        instance = Image.open(self.datasetid_to_filepath[item])
        instance = self.transform(instance)
        label = self.datasetid_to_class_id[item]
        return instance, label

    def __len__(self):
        return len(self.df)

    def num_classes(self):
        return len(self.df['class_name'].unique())

    @staticmethod
    def index_subset(subset):
        """Index a subset by looping through all of its files and recording relevant information.
        # Arguments
            subset: Name of the subset
        # Returns
            A list of dicts containing information about all the image files in a particular subset of the
            miniImageNet dataset
        """
        images = []
        print('Indexing {}...'.format(subset))
        # Quick first pass to find total for tqdm bar
        subset_len = 0
        # Count every file, since the loop below advances the bar once per file
        for root, folders, files in os.walk(DATA_PATH + '/miniImageNet/images_{}/'.format(subset)):
            subset_len += len(files)

        progress_bar = tqdm(total=subset_len)
        for root, folders, files in os.walk(DATA_PATH + '/miniImageNet/images_{}/'.format(subset)):
            if len(files) == 0:
                continue

            class_name = root.split('/')[-1]

            for f in files:
                progress_bar.update(1)
                images.append({
                    'subset': subset,
                    'class_name': class_name,
                    'filepath': os.path.join(root, f)
                })

        progress_bar.close()
        return images

## Reusable general callbacks for training, logging, and evaluation

The MAML implementation by oscarknagg (https://github.com/oscarknagg/few-shot) has a pretty interesting structure based on decomposing functionality into callbacks. I found it pretty cool and stuck with it, although it turned out to introduce quite a lot of infrastructure code.
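As a rough picture of the callback idea, the training loop itself only calls hooks at fixed points, and each callback decides what to do there. The sketch below is a stripped-down illustration with hypothetical names (`RecordingCallback`, `toy_fit`) and a fake loss; it is not the code from the repository.

```python
class RecordingCallback:
    """Hypothetical callback that just records which hooks were called."""

    def __init__(self):
        self.events = []

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end:{batch} loss={logs['loss']}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")


def toy_fit(callbacks, epochs, batches_per_epoch):
    """Training loop skeleton: the real training step is replaced by a fake loss."""
    for epoch in range(1, epochs + 1):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            batch_logs = {"loss": 1.0 / (epoch + batch)}  # stand-in for a training step
            for cb in callbacks:
                cb.on_batch_end(batch, batch_logs)
        for cb in callbacks:
            cb.on_epoch_end(epoch)


recorder = RecordingCallback()
toy_fit([recorder], epochs=1, batches_per_epoch=2)
print(recorder.events)
```

Logging, checkpointing and evaluation can then be added or removed by changing the list of callbacks, without touching the loop.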

In [0]:
from torch.utils.data import Sampler
from collections import OrderedDict
from collections.abc import Iterable  # `collections.Iterable` was removed in Python 3.10
from torch.optim import Optimizer
import pickle
import torch.nn.functional as F

def categorical_accuracy(y, y_pred):
    """Calculates categorical accuracy.
    # Arguments:
        y_pred: Prediction probabilities or logits of shape [batch_size, num_categories]
        y: Ground truth categories. Must have shape [batch_size,]
    """
    return torch.eq(y_pred.argmax(dim=-1), y).sum().item() / y_pred.shape[0]

NAMED_METRICS = {
    'categorical_accuracy': categorical_accuracy
}

In [0]:
import numpy as np
import torch
import warnings
import csv
import io

from torch.nn import Module
from typing import Union
def evaluate(model: Module, dataloader: DataLoader, prepare_batch: Callable, metrics: List[Union[str, Callable]],
             loss_fn: Callable = None, prefix: str = 'val_', suffix: str = ''):
    """Evaluate a model on one or more metrics on a particular dataset
    # Arguments
        model: Model to evaluate
        dataloader: Instance of torch.utils.data.DataLoader representing the dataset
        prepare_batch: Callable to perform any desired preprocessing
        metrics: List of metrics to evaluate the model with. Metrics must either be a named metric (see `metrics.py`) or
            a Callable that takes predictions and ground truth labels and returns a scalar value
        loss_fn: Loss function to calculate over the dataset
        prefix: Prefix to prepend to the name of each metric - used to identify the dataset. Defaults to 'val_' as
            it is typical to evaluate on a held-out validation dataset
        suffix: Suffix to append to the name of each metric.
    """
    logs = {}
    seen = 0
    totals = {m: 0 for m in metrics}
    if loss_fn is not None:
        totals['loss'] = 0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            x, y = prepare_batch(batch)
            y_pred = model(x)

            seen += x.shape[0]

            if loss_fn is not None:
                totals['loss'] += loss_fn(y_pred, y).item() * x.shape[0]

            for m in metrics:
                if isinstance(m, str):
                    v = NAMED_METRICS[m](y, y_pred)
                else:
                    # Assume metric is a callable function
                    v = m(y, y_pred)

                totals[m] += v * x.shape[0]

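    # Iterate over what was actually accumulated: 'loss' is only present when loss_fn is given
    for m, total in totals.items():
        # Metrics may be given as names or as callables; use the callable's name as the log key
        name = m if isinstance(m, str) else m.__name__
        logs[prefix + name + suffix] = total / seen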

    return logs

class CallbackList(object):
    """Container abstracting a list of callbacks.
    # Arguments
        callbacks: List of `Callback` instances.
    """
    def __init__(self, callbacks):
        self.callbacks = [c for c in callbacks]

    def set_params(self, params):
        for callback in self.callbacks:
            callback.set_params(params)

    def set_model(self, model):
        for callback in self.callbacks:
            callback.set_model(model)

    def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch.
        # Arguments
            epoch: integer, index of epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.
        # Arguments
            epoch: integer, index of epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)

    def on_batch_begin(self, batch, logs=None):
        """Called right before processing a batch.
        # Arguments
            batch: integer, index of batch within the current epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_begin(batch, logs)

    def on_batch_end(self, batch, logs=None):
        """Called at the end of a batch.
        # Arguments
            batch: integer, index of batch within the current epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_end(batch, logs)

    def on_train_begin(self, logs=None):
        """Called at the beginning of training.
        # Arguments
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs=None):
        """Called at the end of training.
        # Arguments
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_end(logs)


class Callback(object):
    def __init__(self):
        self.model = None

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class DefaultCallback(Callback):
    """Records metrics over epochs by averaging over each batch.
    """
    def on_epoch_begin(self, epoch, logs=None):
        self.seen = 0
        self.totals = {}
        self.metrics = ['loss'] + self.params['metrics']

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 1) or 1
        self.seen += batch_size

        for k, v in logs.items():
            if k in self.totals:
                self.totals[k] += v * batch_size
            else:
                self.totals[k] = v * batch_size

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            for k in self.metrics:
                if k in self.totals:
                    # Make value available to next callbacks.
                    logs[k] = self.totals[k] / self.seen


class ProgressBarLogger(Callback):
    """TQDM progress bar that displays the running average of loss and other metrics."""
    def __init__(self):
        super(ProgressBarLogger, self).__init__()

    def on_train_begin(self, logs=None):
        self.num_batches = self.params['num_batches']
        self.verbose = self.params['verbose']
        self.metrics = ['loss'] + self.params['metrics']

    def on_epoch_begin(self, epoch, logs=None):
        self.target = self.num_batches
        self.pbar = tqdm(total=self.target, desc='Epoch {}'.format(epoch))
        self.seen = 0

    def on_batch_begin(self, batch, logs=None):
        self.log_values = {}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.seen += 1

        for k in self.metrics:
            if k in logs:
                self.log_values[k] = logs[k]

        # Skip progbar update for the last batch;
        # will be handled by on_epoch_end.
        if self.verbose and self.seen < self.target:
            self.pbar.update(1)
            self.pbar.set_postfix(self.log_values)

    def on_epoch_end(self, epoch, logs=None):
        # Update log values
        self.log_values = {}
        for k in self.metrics:
            if k in logs:
                self.log_values[k] = logs[k]

        if self.verbose:
            self.pbar.update(1)
            self.pbar.set_postfix(self.log_values)

        self.pbar.close()


class CSVLogger(Callback):
    """Callback that streams epoch results to a csv file.
    Supports all values that can be represented as a string,
    including 1D iterables such as np.ndarray.
    # Arguments
        filename: filename of the csv file, e.g. 'run/log.csv'.
        separator: string used to separate elements in the csv file.
        append: True: append if file exists (useful for continuing
            training). False: overwrite existing file.
    """

    def __init__(self, filename, separator=',', append=False):
        self.sep = separator
        self.filename = filename
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        self.file_flags = ''
        self._open_args = {'newline': '\n'}
        super(CSVLogger, self).__init__()

    def on_train_begin(self, logs=None):
        if self.append:
            if os.path.exists(self.filename):
                with open(self.filename, 'r' + self.file_flags) as f:
                    self.append_header = not bool(len(f.readline()))
            mode = 'a'
        else:
            mode = 'w'

        self.csv_file = io.open(self.filename,
                                mode + self.file_flags,
                                **self._open_args)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        def handle_value(k):
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, str):
                return k
            elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
                return '"[%s]"' % (', '.join(map(str, k)))
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())

        if not self.writer:
            class CustomDialect(csv.excel):
                delimiter = self.sep
            fieldnames = ['epoch'] + self.keys
            self.writer = csv.DictWriter(self.csv_file,
                                         fieldnames=fieldnames,
                                         dialect=CustomDialect)
            if self.append_header:
                self.writer.writeheader()

        row_dict = OrderedDict({'epoch': epoch})
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        self.csv_file.close()
        self.writer = None


class EvaluateMetrics(Callback):
    """Evaluates metrics on a dataset after every epoch.
    # Arguments
        dataloader: torch.DataLoader of the dataset on which the model will be evaluated
        prefix: Prefix to prepend to the names of the metrics when they are logged. Defaults to 'val_' but can be changed
        if the model is to be evaluated on many datasets separately.
        suffix: Suffix to append to the names of the metrics when they are logged.
    """
    def __init__(self, dataloader, prefix='val_', suffix=''):
        super(EvaluateMetrics, self).__init__()
        self.dataloader = dataloader
        self.prefix = prefix
        self.suffix = suffix

    def on_train_begin(self, logs=None):
        self.metrics = self.params['metrics']
        self.prepare_batch = self.params['prepare_batch']
        self.loss_fn = self.params['loss_fn']

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs.update(
            evaluate(self.model, self.dataloader, self.prepare_batch, self.metrics, self.loss_fn, self.prefix, self.suffix)
        )


class ModelCheckpoint(Callback):
    """Save the model after every epoch.
    `filepath` can contain named formatting options, which will be filled with the value of `epoch` and keys in `logs`
    (passed in `on_epoch_end`).
    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints will be saved
    with the epoch number and the validation loss in the filename.
    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        Note: only the model's `state_dict` is saved (via `torch.save`);
            there is no `save_weights_only` option in this implementation.
        period: Interval (number of epochs) between checkpoints.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0, save_best_only=False, mode='auto', period=1):
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            raise ValueError('Mode must be one of (auto, min, max).')

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            # 'auto': infer the direction from the name of the monitored quantity
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch + 1, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        torch.save(self.model.state_dict(), filepath)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve from %0.5f' %
                                  (epoch + 1, self.monitor, self.best))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
                torch.save(self.model.state_dict(), filepath)
        
class CosineAnnealingLRSchedulerCallback(Callback):
    def __init__(self, scheduler):
        super(CosineAnnealingLRSchedulerCallback, self).__init__()
        self.scheduler = scheduler

    def on_epoch_end(self, epoch, logs=None):
        self.scheduler.step(epoch)

In [0]:
class NShotTaskSampler(Sampler):
    def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 episodes_per_epoch: int = None,
                 n: int = None,
                 k: int = None,
                 q: int = None,
                 num_tasks: int = 1,
                 fixed_tasks = None):
        """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.
        Each n-shot task contains a "support set" of `k` sets of `n` samples and a "query set" of `k` sets
        of `q` samples. The support set and the query set are all grouped into one Tensor such that the first n * k
        samples are from the support set while the remaining q * k samples are from the query set.
        The support and query sets are sampled such that they are disjoint i.e. do not contain overlapping samples.
        # Arguments
            dataset: Instance of torch.utils.data.Dataset from which to draw samples
            episodes_per_epoch: Arbitrary number of batches of n-shot tasks to generate in one epoch
            n: int. Number of samples for each class in the n-shot classification tasks.
            k: int. Number of classes in the n-shot classification tasks.
            q: int. Number of query samples for each class in the n-shot classification tasks.
            num_tasks: Number of n-shot tasks to group into a single batch
            fixed_tasks: If this argument is specified this Sampler will always generate tasks from
                the specified classes
        """
        super(NShotTaskSampler, self).__init__(dataset)
        self.episodes_per_epoch = episodes_per_epoch
        self.dataset = dataset
        if num_tasks < 1:
            raise ValueError('num_tasks must be >= 1.')

        self.num_tasks = num_tasks
        self.k = k
        self.n = n
        self.q = q
        self.fixed_tasks = fixed_tasks

        self.i_task = 0

    def __len__(self):
        return self.episodes_per_epoch

    def __iter__(self):
        for _ in range(self.episodes_per_epoch):
            batch = []

            for task in range(self.num_tasks):
                if self.fixed_tasks is None:
                    # Get random classes
                    episode_classes = np.random.choice(self.dataset.df['class_id'].unique(), size=self.k, replace=False)
                else:
                    # Loop through classes in fixed_tasks
                    episode_classes = self.fixed_tasks[self.i_task % len(self.fixed_tasks)]
                    self.i_task += 1

                df = self.dataset.df[self.dataset.df['class_id'].isin(episode_classes)]

                support_k = {k: None for k in episode_classes}
                for k in episode_classes:
                    # Select support examples
                    support = df[df['class_id'] == k].sample(self.n)
                    support_k[k] = support

                    for i, s in support.iterrows():
                        batch.append(s['id'])

                for k in episode_classes:
                    # Select queries that are not in the support set
                    query = df[(df['class_id'] == k) & (~df['id'].isin(support_k[k]['id']))].sample(self.q)
                    for i, q in query.iterrows():
                        batch.append(q['id'])

            yield np.stack(batch)


class EvaluateFewShot(Callback):
    """Evaluate a network on n-shot, k-way classification tasks after every epoch.
    # Arguments
        shared_cnn: Network that is passed to `eval_fn` together with the model.
        eval_fn: Callable to perform few-shot classification.
        num_tasks: int. Number of n-shot classification tasks to evaluate the model with.
        n_shot: int. Number of samples for each class in the n-shot classification tasks.
        k_way: int. Number of classes in the n-shot classification tasks.
        q_queries: int. Number of query samples for each class in the n-shot classification tasks.
        taskloader: torch.utils.data.DataLoader yielding batches of n-shot tasks.
        prepare_batch: function. The preprocessing function to apply to samples from the dataset.
        prefix: str. Prefix to identify dataset.
    """

    def __init__(self,
                 shared_cnn,
                 eval_fn: Callable,
                 num_tasks: int,
                 n_shot: int,
                 k_way: int,
                 q_queries: int,
                 taskloader: torch.utils.data.DataLoader,
                 prepare_batch: Callable,
                 prefix: str = 'val_',
                 **kwargs):
        super(EvaluateFewShot, self).__init__()
        self.eval_fn = eval_fn
        self.num_tasks = num_tasks
        self.n_shot = n_shot
        self.k_way = k_way
        self.q_queries = q_queries
        self.taskloader = taskloader
        self.prepare_batch = prepare_batch
        self.prefix = prefix
        self.kwargs = kwargs
        self.metric_name = f'{self.prefix}{self.n_shot}-shot_{self.k_way}-way_acc'
        self.shared_cnn = shared_cnn

    def on_train_begin(self, logs=None):
        self.loss_fn = self.params['loss_fn']
        self.optimiser = self.params['optimiser']
        

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        seen = 0
        totals = {'loss': 0, self.metric_name: 0}
        baseline=None
        for batch_index, batch in enumerate(self.taskloader):
            x, y = self.prepare_batch(batch)

            loss, y_pred, baseline = self.eval_fn(self.model, self.shared_cnn, self.optimiser, self.loss_fn, x, y,  n_shot=self.n_shot,
                k_way=self.k_way,
                q_queries=self.q_queries,
                train=False, baseline=baseline,
                **self.kwargs)

            seen += y_pred.shape[0]

            totals['loss'] += loss.item() * y_pred.shape[0]
            totals[self.metric_name] += categorical_accuracy(y, y_pred) * y_pred.shape[0]

        logs[self.prefix + 'loss'] = totals['loss'] / seen
        logs[self.metric_name] = totals[self.metric_name] / seen


def prepare_nshot_task(n: int, k: int, q: int) -> Callable:
    """Typical n-shot task preprocessing.
    # Arguments
        n: Number of samples for each class in the n-shot classification task
        k: Number of classes in the n-shot classification task
        q: Number of query samples for each class in the n-shot classification task
    # Returns
        prepare_nshot_task_: A Callable that processes a few shot tasks with specified n, k and q
    """
    def prepare_nshot_task_(batch: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create 0-k label and move to GPU.
        TODO: Move to arbitrary device
        """
        x, y = batch
        x = x.double().cuda()
        # Create dummy 0-(num_classes - 1) label
        y = create_nshot_task_label(k, q).cuda()
        return x, y

    return prepare_nshot_task_


def create_nshot_task_label(k: int, q: int) -> torch.Tensor:
    """Creates an n-shot task label.
    Label has the structure:
        [0]*q + [1]*q + ... + [k-1]*q
    # TODO: Test this
    # Arguments
        k: Number of classes in the n-shot classification task
        q: Number of query samples for each class in the n-shot classification task
    # Returns
        y: Label vector for n-shot task of shape [q * k, ]
    """
    # Integer-only construction; avoids relying on a floating-point step of 1 / q
    y = torch.arange(k).repeat_interleave(q)
    return y
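
To make the label layout concrete, here is a small dependency-free sketch of the structure described in the docstring above. The helper name `nshot_label_list` is hypothetical, and it returns a plain Python list rather than a `torch.Tensor`.

```python
def nshot_label_list(k: int, q: int) -> list:
    """Return [0]*q + [1]*q + ... + [k-1]*q as a plain list (illustration only)."""
    return [class_id for class_id in range(k) for _ in range(q)]


labels = nshot_label_list(k=2, q=5)
print(labels)       # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
print(len(labels))  # 10, i.e. q * k
```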

In [0]:
def batch_metrics(model: Module, y_pred: torch.Tensor, y: torch.Tensor, metrics: List[Union[str, Callable]],
                  batch_logs: dict):
    """Calculates metrics for the current training batch
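    # Arguments
        model: Model being fit
        y_pred: predictions for a particular batch
        y: labels for a particular batch
        metrics: Metric names (keys of NAMED_METRICS) or callables taking (y, y_pred)
        batch_logs: Dictionary of logs for the current batch
    """
    model.eval()
    for m in metrics:
        if isinstance(m, str):
            batch_logs[m] = NAMED_METRICS[m](y, y_pred)
        else:
            # Assume metric is a callable; store its value under its name
            # instead of overwriting the whole log dictionary
            batch_logs[getattr(m, '__name__', repr(m))] = m(y, y_pred)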

    return batch_logs


def fit(model: Module, shared_cnn, optimiser: Optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader,
        prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, callbacks: List[Callback] = None,
        verbose: bool =True, fit_function: Callable = None, fit_function_kwargs: dict = {}):
    """Function to abstract away training loop.
    The benefit of this function is that it allows training scripts to be much more readable and allows for easy re-use of
    common training functionality, provided it is written as a subclass of Callback (following the
    Keras API).
    # Arguments
        model: Model to be fitted.
        shared_cnn: Shared child network, passed through to `fit_function`
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model
            checkpointing, learning rate scheduling etc...
        verbose: All print output is muted if this argument is `False`
        fit_function: Function that performs one training step; it must return `(loss, y_pred, baseline)`
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()
    baseline=None

    for epoch in range(1, epochs+1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            loss, y_pred, baseline = fit_function(model, shared_cnn, optimiser, loss_fn, x, y, baseline=baseline, **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            # Loops through all metrics
            batch_logs = batch_metrics(model, y_pred, y, metrics or [], batch_logs)

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## Shared CNN of child architectures

In [0]:
'''
Parameters related to ENAS. Most keep their default values, except for:
- a shallower child network (4 layers instead of 12),
- 32 output filters,
- no penalty on skip connections (skip weight set to None).
'''

enas_args_search_for = 'macro'
enas_args_resume = ''
enas_args_cutout = 0
enas_args_fixed_arc = False
enas_args_child_num_layers = 4  # default: 12
enas_args_child_out_filters = 32
enas_args_child_grad_bound = 5.0
enas_args_child_l2_reg = 0.00025
enas_args_child_num_branches = 6
enas_args_child_keep_prob = 0.9
enas_args_child_lr_max = 0.05
enas_args_child_lr_min = 0.0005
enas_args_child_lr_T = 10
enas_args_controller_lstm_size = 64
enas_args_controller_lstm_num_layers = 1
enas_args_controller_entropy_weight = 0.0001
enas_args_controller_lr = 0.001
enas_args_controller_tanh_constant = 1.5
enas_args_controller_skip_target = 0.4
enas_args_controller_skip_weight = None  # default: 0.8 (None disables the skip penalty)
enas_args_controller_bl_dec = 0.99

The shared network from ENAS is implemented as follows: each layer offers 6 possible operations (3x3 convolution, 3x3 separable convolution, 5x5 convolution, 5x5 separable convolution, average pooling, max pooling), i.e. 6 "branches". For a given forward pass, only one branch per layer is chosen and executed.
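
As an illustration of how a sampled architecture selects one branch per layer, the sketch below decodes an architecture dictionary into readable text. It assumes the format produced by the controller in this notebook (`{'0': [branch], '1': [branch, skips], ...}`) and the branch order `branch_0` ... `branch_5` of `ENASLayer`. The helper `describe_arc` and the example architecture are hypothetical; plain integers stand in for the tensors the controller emits.

```python
# Branch order follows branch_0 ... branch_5 in ENASLayer.
BRANCH_NAMES = ["conv 3x3", "separable conv 3x3", "conv 5x5",
                "separable conv 5x5", "avg pool", "max pool"]


def describe_arc(sample_arc: dict) -> list:
    """Turn {'0': [branch], '1': [branch, skips], ...} into readable strings."""
    lines = []
    for layer_id in sorted(sample_arc, key=int):
        entry = sample_arc[layer_id]
        branch = BRANCH_NAMES[int(entry[0])]
        # Layer 0 has no earlier layers, so it carries no skip flags.
        skips = entry[1] if len(entry) > 1 else []
        sources = [i for i, flag in enumerate(skips) if int(flag) == 1]
        lines.append(f"layer {layer_id}: {branch}, skip inputs from {sources}")
    return lines


example_arc = {"0": [1], "1": [4, [1]], "2": [2, [0, 1]]}  # made-up architecture
for line in describe_arc(example_arc):
    print(line)
# layer 0: separable conv 3x3, skip inputs from []
# layer 1: avg pool, skip inputs from [0]
# layer 2: conv 5x5, skip inputs from [1]
```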

In [0]:
class ENASLayer(nn.Module):
    '''
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L245
    '''
    def __init__(self, layer_id, in_planes, out_planes):
        super(ENASLayer, self).__init__()

        self.layer_id = layer_id
        self.in_planes = in_planes
        self.out_planes = out_planes

        self.branch_0 = ConvBranch(in_planes, out_planes, kernel_size=3)
        self.branch_1 = ConvBranch(in_planes, out_planes, kernel_size=3, separable=True)
        self.branch_2 = ConvBranch(in_planes, out_planes, kernel_size=5)
        self.branch_3 = ConvBranch(in_planes, out_planes, kernel_size=5, separable=True)
        self.branch_4 = PoolBranch(in_planes, out_planes, 'avg')
        self.branch_5 = PoolBranch(in_planes, out_planes, 'max')

        self.bn = nn.BatchNorm2d(out_planes, track_running_stats=False)

    def forward(self, x, prev_layers, sample_arc):
        layer_type = sample_arc[0]
        if self.layer_id > 0:
            skip_indices = sample_arc[1]
        else:
            skip_indices = []

        if layer_type == 0:
            out = self.branch_0(x)
        elif layer_type == 1:
            out = self.branch_1(x)
        elif layer_type == 2:
            out = self.branch_2(x)
        elif layer_type == 3:
            out = self.branch_3(x)
        elif layer_type == 4:
            out = self.branch_4(x)
        elif layer_type == 5:
            out = self.branch_5(x)
        else:
            raise ValueError("Unknown layer_type {}".format(layer_type))

        for i, skip in enumerate(skip_indices):
            if skip == 1:
                out += prev_layers[i]

        out = self.bn(out)
        return out


class FixedLayer(nn.Module):
    '''
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L245
    '''
    def __init__(self, layer_id, in_planes, out_planes, sample_arc):
        super(FixedLayer, self).__init__()

        self.layer_id = layer_id
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.sample_arc = sample_arc

        self.layer_type = sample_arc[0]
        if self.layer_id > 0:
            self.skip_indices = sample_arc[1]
        else:
            self.skip_indices = torch.zeros(1)

        if self.layer_type == 0:
            self.branch = ConvBranch(in_planes, out_planes, kernel_size=3)
        elif self.layer_type == 1:
            self.branch = ConvBranch(in_planes, out_planes, kernel_size=3, separable=True)
        elif self.layer_type == 2:
            self.branch = ConvBranch(in_planes, out_planes, kernel_size=5)
        elif self.layer_type == 3:
            self.branch = ConvBranch(in_planes, out_planes, kernel_size=5, separable=True)
        elif self.layer_type == 4:
            self.branch = PoolBranch(in_planes, out_planes, 'avg')
        elif self.layer_type == 5:
            self.branch = PoolBranch(in_planes, out_planes, 'max')
        else:
            raise ValueError("Unknown layer_type {}".format(self.layer_type))

        # Use concatenation instead of addition in the fixed layer for some reason
        in_planes = int((torch.sum(self.skip_indices).item() + 1) * in_planes)
        self.dim_reduc = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(out_planes, track_running_stats=False))

    def forward(self, x, prev_layers, sample_arc):
        out = self.branch(x)

        res_layers = []
        for i, skip in enumerate(self.skip_indices):
            if skip == 1:
                res_layers.append(prev_layers[i])
        prev = res_layers + [out]
        prev = torch.cat(prev, dim=1)

        out = self.dim_reduc(prev)
        return out


class FactorizedReduction(nn.Module):
    '''
    Reduce both spatial dimensions (width and height) by a factor of 2, and
    potentially change the number of output filters
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L129
    '''

    def __init__(self, in_planes, out_planes, stride=2):
        super(FactorizedReduction, self).__init__()

        assert out_planes % 2 == 0, ("Need even number of filters when using this factorized reduction.")

        self.in_planes = in_planes
        self.out_planes = out_planes
        self.stride = stride

        if stride == 1:
            self.fr = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_planes, track_running_stats=False))
        else:
            self.path1 = nn.Sequential(
                nn.AvgPool2d(1, stride=stride),
                nn.Conv2d(in_planes, out_planes // 2, kernel_size=1, bias=False))

            self.path2 = nn.Sequential(
                nn.AvgPool2d(1, stride=stride),
                nn.Conv2d(in_planes, out_planes // 2, kernel_size=1, bias=False))
            self.bn = nn.BatchNorm2d(out_planes, track_running_stats=False)

    def forward(self, x):
        if self.stride == 1:
            return self.fr(x)
        else:
            path1 = self.path1(x)

            # pad the right and the bottom, then crop to include those pixels
            path2 = F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0.)
            path2 = path2[:, :, 1:, 1:]
            path2 = self.path2(path2)

            out = torch.cat([path1, path2], dim=1)
            out = self.bn(out)
            return out


class SeparableConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, bias):
        super(SeparableConv, self).__init__()
        padding = (kernel_size - 1) // 2
        self.depthwise = nn.Conv2d(in_planes, in_planes, kernel_size=kernel_size,
                                   padding=padding, groups=in_planes, bias=bias)
        self.pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class ConvBranch(nn.Module):
    '''
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L483
    '''
    def __init__(self, in_planes, out_planes, kernel_size, separable=False):
        super(ConvBranch, self).__init__()
        assert kernel_size in [3, 5], "Kernel size must be either 3 or 5"

        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.separable = separable

        self.inp_conv1 = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_planes, track_running_stats=False),
            nn.ReLU())

        if separable:
            self.out_conv = nn.Sequential(
                SeparableConv(in_planes, out_planes, kernel_size=kernel_size, bias=False),
                nn.BatchNorm2d(out_planes, track_running_stats=False),
                nn.ReLU())
        else:
            padding = (kernel_size - 1) // 2
            self.out_conv = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                          padding=padding, bias=False),
                nn.BatchNorm2d(out_planes, track_running_stats=False),
                nn.ReLU())

    def forward(self, x):
        out = self.inp_conv1(x)
        out = self.out_conv(out)
        return out


class PoolBranch(nn.Module):
    '''
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_child.py#L546
    '''
    def __init__(self, in_planes, out_planes, avg_or_max):
        super(PoolBranch, self).__init__()

        self.in_planes = in_planes
        self.out_planes = out_planes
        self.avg_or_max = avg_or_max

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_planes, track_running_stats=False),
            nn.ReLU())

        if avg_or_max == 'avg':
            self.pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        elif avg_or_max == 'max':
            self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        else:
            raise ValueError("Unknown pool {}".format(avg_or_max))

    def forward(self, x):
        out = self.conv1(x)
        out = self.pool(out)
        return out


class SharedCNN(nn.Module):
    def __init__(self,
                k_way=5,
                num_layers=enas_args_child_num_layers,
                num_branches=enas_args_child_num_branches,
                out_filters=enas_args_child_out_filters,
                keep_prob=enas_args_child_keep_prob,
                fixed_arc=None
                # maml_test_arc=None,
                ):
        super(SharedCNN, self).__init__()
        print(num_layers)

        # self.maml_test_arc = maml_test_arc

        self.num_layers = num_layers
        self.num_branches = num_branches
        self.out_filters = out_filters
        self.keep_prob = keep_prob
        self.fixed_arc = fixed_arc

        pool_distance = self.num_layers // 3
        self.pool_layers = [pool_distance - 1, 2 * pool_distance - 1]

        self.stem_conv = nn.Sequential(
            nn.Conv2d(3, out_filters, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_filters, track_running_stats=False))

        self.layers = nn.ModuleList([])
        self.pooled_layers = nn.ModuleList([])

        for layer_id in range(self.num_layers):
            if self.fixed_arc is None:
                layer = ENASLayer(layer_id, self.out_filters, self.out_filters)
            else:
                layer = FixedLayer(layer_id, self.out_filters, self.out_filters, self.fixed_arc[str(layer_id)])
            self.layers.append(layer)

            if layer_id in self.pool_layers:
                for i in range(len(self.layers)):
                    if self.fixed_arc is None:
                        self.pooled_layers.append(FactorizedReduction(self.out_filters, self.out_filters))
                    else:
                        self.pooled_layers.append(FactorizedReduction(self.out_filters, self.out_filters * 2))
                if self.fixed_arc is not None:
                    self.out_filters *= 2

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=1. - self.keep_prob)
        self.classify = nn.Linear(self.out_filters, k_way)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x, sample_arc):
        x = self.stem_conv(x)

        prev_layers = []
        pool_count = 0
        for layer_id in range(self.num_layers):
            x = self.layers[layer_id](x, prev_layers, sample_arc[str(layer_id)])
            prev_layers.append(x)
            if layer_id in self.pool_layers:
                for i, prev_layer in enumerate(prev_layers):
                    # Go through the outputs of all previous layers and downsample them
                    prev_layers[i] = self.pooled_layers[pool_count](prev_layer)
                    pool_count += 1
                x = prev_layers[-1]

        x = self.global_avg_pool(x)
        x = x.view(x.shape[0], -1)
        # x = self.dropout(x)
        out = self.classify(x)

        return out
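
The downsampling bookkeeping in `SharedCNN.__init__` is easy to misread, so here is a minimal sketch that reproduces only the arithmetic for the searched (non-fixed) case. The function name `pool_schedule` is hypothetical; it returns the layer indices after which downsampling happens and the total number of `FactorizedReduction` modules created.

```python
def pool_schedule(num_layers: int):
    """Mirror the pooling bookkeeping of SharedCNN.__init__ (searched case only)."""
    pool_distance = num_layers // 3
    pool_layers = [pool_distance - 1, 2 * pool_distance - 1]
    num_reductions = 0
    for layer_id in range(num_layers):
        if layer_id in pool_layers:
            # One FactorizedReduction per layer output accumulated so far.
            num_reductions += layer_id + 1
    return pool_layers, num_reductions


print(pool_schedule(4))   # ([0, 1], 3)
print(pool_schedule(12))  # ([3, 7], 12)
```

With the 4-layer setting used in this notebook, the outputs are therefore downsampled after layers 0 and 1, so the last two layers operate on feature maps that have been through two stride-2 reductions.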

## LSTM controller network

In [0]:
class Controller(nn.Module):
    '''
    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py
    '''
    def __init__(self,
                 search_for="macro",
                 search_whole_channels=True,
                 num_layers=enas_args_child_num_layers,
                 num_branches=6,
                 out_filters=32,
                 lstm_size=32,
                 lstm_num_layers=2,
                 tanh_constant=1.5,
                 temperature=None,
                 skip_target=0.4,
                 skip_weight=0.8):
        super(Controller, self).__init__()

        self.search_for = search_for
        self.search_whole_channels = search_whole_channels
        self.num_layers = num_layers
        self.num_branches = num_branches
        self.out_filters = out_filters

        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature

        self.skip_target = skip_target
        self.skip_weight = skip_weight

        self._create_params()

    def _create_params(self):
        '''
        https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L83
        '''
        self.w_lstm = nn.LSTM(input_size=self.lstm_size,
                              hidden_size=self.lstm_size,
                              num_layers=self.lstm_num_layers)

        self.g_emb = nn.Embedding(1, self.lstm_size)  # Learn the starting input

        self.w_emb = nn.Embedding(self.num_branches, self.lstm_size) # transforming network output to input for the next timestep
        self.w_soft = nn.Linear(self.lstm_size, self.num_branches, bias=False)

        self.w_attn_1 = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.w_attn_2 = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)

        self._reset_params()

    def _reset_params(self):
        for m in self.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.uniform_(m.weight, -0.1, 0.1)

        nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
        nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)

    def forward(self):
        '''
        https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L126
        '''
        h0 = None  # setting h0 to None will initialize LSTM state with 0s

        anchors = []
        anchors_w_1 = []

        arc_seq = {}
        entropys = []
        log_probs = []
        skip_count = []
        skip_penaltys = []

        inputs = self.g_emb.weight
        skip_targets = torch.tensor([1.0 - self.skip_target, self.skip_target]).cuda()

        for layer_id in range(self.num_layers):
            # First, sample the operation to use
            inputs = inputs.unsqueeze(0)
            output, hn = self.w_lstm(inputs, h0)
            output = output.squeeze(0)
            h0 = hn

            logit = self.w_soft(output)
            if self.temperature is not None:
                logit /= self.temperature
            if self.tanh_constant is not None:
                logit = self.tanh_constant * torch.tanh(logit)

            branch_id_dist = Categorical(logits=logit)
            branch_id = branch_id_dist.sample()

            arc_seq[str(layer_id)] = [branch_id]

            log_prob = branch_id_dist.log_prob(branch_id)
            log_probs.append(log_prob.view(-1))
            entropy = branch_id_dist.entropy()
            entropys.append(entropy.view(-1))

            inputs = self.w_emb(branch_id)
            inputs = inputs.unsqueeze(0)

            # Second, sample which previous layers to connect to (skip connections)
            output, hn = self.w_lstm(inputs, h0)
            output = output.squeeze(0)

            if layer_id > 0:
                query = torch.cat(anchors_w_1, dim=0)
                query = torch.tanh(query + self.w_attn_2(output))
                query = self.v_attn(query)
                logit = torch.cat([-query, query], dim=1)
                if self.temperature is not None:
                    logit /= self.temperature
                if self.tanh_constant is not None:
                    logit = self.tanh_constant * torch.tanh(logit)

                skip_dist = Categorical(logits=logit)
                skip = skip_dist.sample()
                skip = skip.view(layer_id)

                arc_seq[str(layer_id)].append(skip)

                skip_prob = torch.sigmoid(logit)
                kl = skip_prob * torch.log(skip_prob / skip_targets)
                kl = torch.sum(kl)
                skip_penaltys.append(kl)

                log_prob = skip_dist.log_prob(skip)
                log_prob = torch.sum(log_prob)
                log_probs.append(log_prob.view(-1))

                entropy = skip_dist.entropy()
                entropy = torch.sum(entropy)
                entropys.append(entropy.view(-1))

                # Calculate average hidden state of all nodes that got skips
                # and use it as input for next step
                skip = skip.type(torch.float)
                skip = skip.view(1, layer_id)
                skip_count.append(torch.sum(skip))
                inputs = torch.matmul(skip, torch.cat(anchors, dim=0))
                inputs /= (1.0 + torch.sum(skip))
            else:
                inputs = self.g_emb.weight

            anchors.append(output)
            anchors_w_1.append(self.w_attn_1(output))

        self.sample_arc = arc_seq

        entropys = torch.cat(entropys)
        self.sample_entropy = torch.sum(entropys)

        log_probs = torch.cat(log_probs)
        self.sample_log_prob = torch.sum(log_probs)

        skip_count = torch.stack(skip_count)
        self.skip_count = torch.sum(skip_count)

        skip_penaltys = torch.stack(skip_penaltys)
        self.skip_penaltys = torch.mean(skip_penaltys)
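
The controller bounds its logits with `tanh_constant * tanh(logit)` before sampling, which limits how peaked the sampling distribution can become. The dependency-free sketch below illustrates that effect for the 6-way branch choice. The helper names and the raw scores are hypothetical, and plain Python floats replace the tensors and `Categorical` distribution used in the class.

```python
import math


def bounded_logits(logits, tanh_constant=1.5, temperature=None):
    """Apply the optional temperature and tanh squashing used before sampling."""
    if temperature is not None:
        logits = [value / temperature for value in logits]
    if tanh_constant is not None:
        logits = [tanh_constant * math.tanh(value) for value in logits]
    return logits


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


raw = [8.0, -3.0, 0.5, 0.0, 2.0, -9.0]  # made-up scores for the 6 branches
probs = softmax(bounded_logits(raw))
print(max(probs) / min(probs) <= math.exp(3.0) + 1e-9)  # True: ratio capped at e^(2 * 1.5)
print(entropy(probs) <= math.log(6))                    # True: uniform is the maximum
```

Because every bounded logit lies in [-1.5, 1.5], no branch can be more than e^3 (about 20) times as likely as another, which keeps some exploration even when the raw scores are extreme.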

## MAMLENAS: meta-gradient step
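
The function in this section cuts every task batch into four parts: a (train, val) pair for the inner step and a second (train, val) pair for evaluating the adapted controller. The sketch below only reproduces the index arithmetic; `task_slices` is a hypothetical helper and assumes each task batch holds exactly `2*n*k + 2*q*k` samples.

```python
def task_slices(n_shot: int, k_way: int, q_queries: int) -> dict:
    """Index ranges used to cut one task batch of size 2*n*k + 2*q*k."""
    support = n_shot * k_way
    query = q_queries * k_way
    return {
        "train_train": slice(0, support),
        "val_train": slice(support, 2 * support),
        "train_val": slice(2 * support, 2 * support + query),
        "val_val": slice(2 * support + query, 2 * support + 2 * query),
    }


slices = task_slices(n_shot=5, k_way=2, q_queries=5)
for name, s in slices.items():
    print(name, s.start, s.stop)
# train_train 0 10
# val_train 10 20
# train_val 20 30
# val_val 30 40
```

Whether each slice covers all `k_way` classes depends on the order in which `NShotTaskSampler` emits samples, which this sketch does not check.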

In [0]:
def meta_gradient_step_mamlenas(controller: Module,
                       shared_cnn: Module,
                       controller_optimiser: Optimizer,
                       loss_fn: Callable,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       n_shot: int,
                       k_way: int,
                       q_queries: int,
                       order: int,
                       inner_train_steps: int,
                       inner_lr: float,
                       train: bool,
                       device: Union[str, torch.device],
                       shared_cnn_optimizer, enas_args_child_grad_bound, 
                       enas_args_controller_skip_weight, enas_args_controller_entropy_weight,
                       enas_args_controller_bl_dec, baseline):
    """
    Perform a gradient step on a MAMLENAS meta-learner.
    # Arguments
        x: Input samples for all few shot tasks
        y: Input labels of all few shot tasks
    """
    data_shape = x.shape[2:]
    create_graph = False # because order == 1

    task_gradients = []
    task_losses = []
    task_predictions = []
    shared_cnn_steps = 1
    baseline_meter = AverageMeter()

    '''
    Training shared_cnn on all meta-batches before starting the proper MAML: didn't help.
    '''

    # controller.eval()
    # shared_cnn.train()
    # for meta_batch in x1:
    #     x_task_train = meta_batch[:n_shot * k_way]
    #     y = create_nshot_task_label(k_way, n_shot).to(device)
    #     for i in range(shared_cnn_steps):
    #         with torch.no_grad():
    #             # sample architecture:
    #             controller()
    #         sample_arc = controller.sample_arc
    #         shared_cnn.zero_grad()
    #         shared_cnn_optimizer.zero_grad()
    #         logits = shared_cnn(x_task_train, sample_arc)
    #         loss = loss_fn(logits, y)
    #         loss.backward()
    #         # grad_norm = torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), enas_args_child_grad_bound)
    #         shared_cnn_optimizer.step()


    for task_batch in x:
        # By construction x is a 5D tensor of shape (meta_batch_size, n*k*2 + q*k*2, channels, width, height).
        # Iterating over the first dimension therefore iterates over the tasks in the meta-batch.

        #these are train_train and train_val:
        x_task_train = task_batch[:n_shot * k_way]
        x_task_val = task_batch[n_shot * k_way * 2: n_shot * k_way * 2 + q_queries * k_way]

        # Clone a controller to fast_controller using the current meta model weights
        fast_controller = pickle.loads(pickle.dumps(controller)).to(device, dtype=torch.float)
        fast_opt = torch.optim.SGD(fast_controller.parameters(), lr=inner_lr)

        # Take `inner_train_steps` of the fast_controller (these are MAML steps; for RL tasks they use 1, and so do I)
        for inner_batch in range(inner_train_steps):
            # train shared_cnn
            y = create_nshot_task_label(k_way, n_shot).to(device)
            fast_controller.eval()
            shared_cnn.train()
            #IMPORTANT difference to the reference ENAS: only sample child architecture once and train it for multiple steps
            with torch.no_grad():
                # sample architecture:
                fast_controller()
            for i in range(shared_cnn_steps):
                sample_arc = fast_controller.sample_arc
                shared_cnn.zero_grad()
                logits = shared_cnn(x_task_train, sample_arc)
                loss = loss_fn(logits, y)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), enas_args_child_grad_bound)
                shared_cnn_optimizer.step()
            # train fast_controller (just 1 step)
            y = create_nshot_task_label(k_way, q_queries).to(device)
            fast_controller.train()
            shared_cnn.eval()
            fast_opt.zero_grad()

            fast_controller()  # perform forward pass to generate a new architecture
            sample_arc = fast_controller.sample_arc

            with torch.no_grad():
                pred = shared_cnn(x_task_val, sample_arc)

            val_acc = torch.mean((torch.max(pred, 1)[1] == y).type(torch.float))
            reward = val_acc.clone().detach()
            reward += enas_args_controller_entropy_weight * fast_controller.sample_entropy

            if baseline is None:
                baseline = val_acc
            else:
                baseline += (1 - enas_args_controller_bl_dec) * (reward - baseline)
                # detach to make sure that gradients are not backpropped through the baseline
                baseline = baseline.detach()

            baseline_meter.update(baseline.item())

            loss = -1 * fast_controller.sample_log_prob * (reward - baseline)

            if enas_args_controller_skip_weight is not None:
                loss += enas_args_controller_skip_weight * fast_controller.skip_penaltys

            #ignore gradient aggregation from reference ENAS

            loss.backward()
            fast_opt.step()
            fast_controller.zero_grad()

        # Evaluate the fast_controller that has taken some steps

        #these are val_train and val_val
        x_task_train = task_batch[n_shot * k_way:n_shot * k_way * 2]
        x_task_val = task_batch[n_shot * k_way * 2 + q_queries * k_way:]
        
        y = create_nshot_task_label(k_way, n_shot).to(device)
        fast_controller.eval()
        shared_cnn.train()

        with torch.no_grad():
            fast_controller()
        for i in range(shared_cnn_steps):
            sample_arc = fast_controller.sample_arc
            shared_cnn_optimizer.zero_grad()
            logits = shared_cnn(x_task_train, sample_arc)
            loss = loss_fn(logits, y)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), enas_args_child_grad_bound)
            shared_cnn_optimizer.step()

        # compute the controller loss on the held-out split (gradients only; no optimiser step is taken here)
        y = create_nshot_task_label(k_way, q_queries).to(device)
        fast_controller.train()
        shared_cnn.eval()
        fast_opt.zero_grad()

        fast_controller()  # perform forward pass to generate a new architecture
        sample_arc = fast_controller.sample_arc

        with torch.no_grad():
            pred = shared_cnn(x_task_val, sample_arc)
        val_acc = torch.mean((torch.max(pred, 1)[1] == y).type(torch.float))

        reward = val_acc.clone().detach()
        reward += enas_args_controller_entropy_weight * fast_controller.sample_entropy
        baseline += (1 - enas_args_controller_bl_dec) * (reward - baseline)
        # detach to make sure that gradients are not backpropped through the baseline
        baseline = baseline.detach()
            
        baseline_meter.update(baseline.item())

        loss = -1 * fast_controller.sample_log_prob * (reward - baseline)

        if enas_args_controller_skip_weight is not None:
            loss += enas_args_controller_skip_weight * fast_controller.skip_penaltys

        loss.backward(retain_graph=True)

        y_pred = pred.softmax(dim=1)
        task_predictions.append(y_pred)

        # Accumulate losses and gradients
        task_losses.append(loss)

        grads = torch.autograd.grad(loss, fast_controller.parameters(), create_graph=create_graph)
        named_grads = {name: g for ((name, _), g) in zip(fast_controller.named_parameters(), grads)}
        # named_grads = {name: parameter.grad.clone().detach() for (name, parameter) in fast_model.named_parameters()}
        task_gradients.append(named_grads)
                
    # Finally update the controller
    assert order == 1
    if train:
        sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                                for k in task_gradients[0].keys()}
        controller.train()
        controller_optimiser.zero_grad()

        for name, param in controller.named_parameters():
            param.grad = sum_task_gradients[name]

        controller_optimiser.step()

    return torch.stack(task_losses).mean(), torch.cat(task_predictions), baseline
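
The controller update above is a REINFORCE step with an exponential-moving-average baseline. The following dependency-free sketch isolates that part. The function names and reward values are hypothetical, the entropy bonus and skip penalty are left out, and the first reward seeds the baseline (the notebook seeds it with the validation accuracy).

```python
def update_baseline(baseline, reward, bl_dec=0.99):
    """Exponential moving average of rewards; the first reward seeds it."""
    if baseline is None:
        return reward
    return baseline + (1.0 - bl_dec) * (reward - baseline)


def reinforce_loss(sample_log_prob, reward, baseline):
    """Loss whose gradient w.r.t. the log-probability is -(reward - baseline)."""
    return -sample_log_prob * (reward - baseline)


baseline = None
for reward in [0.5, 0.7, 0.6]:  # made-up validation accuracies
    # As in the notebook, the baseline is updated before the advantage is computed.
    baseline = update_baseline(baseline, reward)
    loss = reinforce_loss(sample_log_prob=-4.0, reward=reward, baseline=baseline)
    print(f"reward={reward:.2f} baseline={baseline:.4f} loss={loss:.4f}")
```

With `bl_dec = 0.99` the baseline moves only 1% of the way towards each new reward, so early in training the advantage `reward - baseline` is dominated by whatever value seeded the baseline.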

## Running the experiment

In [0]:
from torch.optim.lr_scheduler import CosineAnnealingLR
assert torch.cuda.is_available()
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True


##############
# Parameters #
##############

args_dataset = 'MiniImageNet'
args_n = 5
args_k = 2
args_q = 5
args_inner_train_steps = 1 
args_inner_val_steps = 1
args_inner_lr = enas_args_controller_lr # 'alpha' from the original paper
args_meta_lr = enas_args_controller_lr # 'beta' from the original paper
args_meta_batch_size = 5
args_order = 1 
args_epochs = 100 
args_epoch_len = 50
args_eval_batches = 10

dataset_class = MiniImageNet
fc_layer_size = None  # 1600; this parameter is ignored
num_input_channels = 3

param_str = f'{args_dataset}_order={args_order}_n={args_n}_k={args_k}_metabatch={args_meta_batch_size}_' \
            f'train_steps={args_inner_train_steps}_val_steps={args_inner_val_steps}'
print(param_str)


###################
# Create datasets #
###################
# n and q get multiplied by 2 because we need 2 pairs of (train + val)
background = dataset_class('background')
background_taskloader = DataLoader(
    background,
    batch_sampler=NShotTaskSampler(background, args_epoch_len, n=args_n * 2, k=args_k, q=args_q * 2,
                                   num_tasks=args_meta_batch_size),
    num_workers=8
)
evaluation = dataset_class('evaluation')
evaluation_taskloader = DataLoader(
    evaluation,
    batch_sampler=NShotTaskSampler(evaluation, args_eval_batches, n=args_n * 2, k=args_k, q=args_q * 2,
                                   num_tasks=args_meta_batch_size),
    num_workers=8
)


############
# Training #
############
print(f'Training MAML on {args_dataset}...')
shared_cnn = SharedCNN(args_k).to(device, dtype=torch.float32)
loss_fn = nn.CrossEntropyLoss().to(device)
controller = Controller(search_for=enas_args_search_for,
                    search_whole_channels=True,
                    num_layers=enas_args_child_num_layers,
                    num_branches=enas_args_child_num_branches,
                    out_filters=enas_args_child_out_filters,
                    lstm_size=enas_args_controller_lstm_size,
                    lstm_num_layers=enas_args_controller_lstm_num_layers,
                    tanh_constant=enas_args_controller_tanh_constant,
                    temperature=None,
                    skip_target=enas_args_controller_skip_target,
                    skip_weight=enas_args_controller_skip_weight).to(device, dtype=torch.float32)

controller_optimizer = torch.optim.Adam(params=controller.parameters(),
                                    lr=enas_args_controller_lr,
                                    betas=(0.0, 0.999),
                                    eps=1e-3)

shared_cnn_optimizer = torch.optim.AdamW(params=shared_cnn.parameters(),
                            lr=enas_args_child_lr_max, weight_decay=enas_args_child_l2_reg)

shared_cnn_scheduler = CosineAnnealingLR(optimizer=shared_cnn_optimizer,
                                     T_max=enas_args_child_lr_T,
                                     eta_min=enas_args_child_lr_min)

def prepare_meta_batch(n, k, q, meta_batch_size):
    def prepare_meta_batch_(batch):
        x, y = batch
        # Reshape to `meta_batch_size` number of tasks. Each task contains two
        # (train + val) pairs, so 2*n*k support samples to train the fast model on
        # and 2*q*k query samples to evaluate it on and generate meta-gradients
        x = x.reshape(meta_batch_size, n*k*2 + q*k*2, num_input_channels, x.shape[-2], x.shape[-1])
        # Move to device
        x = x.float().to(device)
        # Create label
        y = create_nshot_task_label(k, q).to(device).repeat(meta_batch_size)
        return x, y

    return prepare_meta_batch_

# Callbacks to run after every training epoch
callbacks = [
    EvaluateFewShot(
        shared_cnn=shared_cnn,
        eval_fn=meta_gradient_step_mamlenas,
        num_tasks=args_eval_batches,
        n_shot=args_n,
        k_way=args_k,
        q_queries=args_q,
        taskloader=evaluation_taskloader,
        prepare_batch=prepare_meta_batch(args_n, args_k, args_q, args_meta_batch_size),
        # MAML kwargs
        inner_train_steps=args_inner_val_steps,
        inner_lr=args_inner_lr,
        device=device,
        order=args_order,
        # MAMLENAS kwargs 
        shared_cnn_optimizer=shared_cnn_optimizer, 
        enas_args_child_grad_bound=enas_args_child_grad_bound, 
        enas_args_controller_skip_weight=enas_args_controller_skip_weight, 
        enas_args_controller_entropy_weight=enas_args_controller_entropy_weight,
        enas_args_controller_bl_dec=enas_args_controller_bl_dec
    ),
    ModelCheckpoint(
        filepath=MODELS_PATH + f'{param_str}.pth',
        monitor=f'val_{args_n}-shot_{args_k}-way_acc'
    ),
    # CosineAnnealingLRSchedulerCallback(shared_cnn_scheduler),
    CSVLogger(LOG_PATH + f'{param_str}.csv'),
]

fit(
    controller,
    shared_cnn,
    controller_optimizer,
    loss_fn,
    epochs=args_epochs,
    dataloader=background_taskloader,
    prepare_batch=prepare_meta_batch(args_n, args_k, args_q, args_meta_batch_size),
    callbacks=callbacks,
    metrics=['categorical_accuracy'],
    fit_function=meta_gradient_step_mamlenas,
    fit_function_kwargs={'n_shot': args_n, 'k_way': args_k, 'q_queries': args_q,
                         'train': True,
                         'order': args_order, 'device': device, 'inner_train_steps': args_inner_train_steps,
                         'inner_lr': args_inner_lr, 'shared_cnn_optimizer': shared_cnn_optimizer, 
                        'enas_args_child_grad_bound': enas_args_child_grad_bound, 
                        'enas_args_controller_skip_weight': enas_args_controller_skip_weight, 
                        'enas_args_controller_entropy_weight': enas_args_controller_entropy_weight,
                         'enas_args_controller_bl_dec': enas_args_controller_bl_dec},
)

MiniImageNet_order=1_n=5_k=2_metabatch=5_train_steps=1_val_steps=1
Indexing background...


Indexing evaluation...


Training MAML on MiniImageNet...
4
Begin training...

