# Few-Shot learning with Reptile

**Author:** [ADMoreau](https://github.com/ADMoreau)<br>
**Date created:** 2020/05/21<br>
**Last modified:** 2020/05/30<br>
**Description:** Few-shot classification of the Omniglot dataset using Reptile.

## Introduction

The [Reptile](https://arxiv.org/abs/1803.02999) algorithm was developed by OpenAI to
perform model-agnostic meta-learning. Specifically, this algorithm was designed to
quickly learn to perform new tasks with minimal training (few-shot learning).
The algorithm works by performing Stochastic Gradient Descent on the initial weights. The
update direction is the difference between the weights obtained after training on a
mini-batch of never-before-seen data and the model weights prior to that training. This is
repeated over a fixed number of meta-iterations.
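The following minimal NumPy sketch makes the update rule concrete. It runs serial Reptile on a hypothetical family of 1-D linear-regression tasks, not on Omniglot. The names `sample_task`, `inner_sgd` and `reptile` are illustrative only. The inner loop takes a few full-batch gradient steps instead of mini-batch SGD. The outer step moves the shared initialization `phi` a fraction `eps` of the way toward the task-adapted weights. `eps` is annealed linearly, as this notebook does later.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_task(n_points=20):
    """Hypothetical toy task: y = a * x + b with a random slope and intercept."""
    a, b = rng.uniform(-2.0, 2.0, size=2)
    x = rng.uniform(-1.0, 1.0, size=n_points)
    return x, a * x + b


def inner_sgd(w, x, y, lr=0.1, steps=5):
    """A few full-batch gradient steps on MSE for the model y_hat = w[0] * x + w[1]."""
    w = w.copy()
    for _ in range(steps):
        err = w[0] * x + w[1] - y
        grad = np.array([2.0 * np.mean(err * x), 2.0 * np.mean(err)])
        w -= lr * grad
    return w


def reptile(meta_iters=1000, meta_step_size=0.25):
    phi = np.zeros(2)  # meta-parameters: the shared initialization
    for it in range(meta_iters):
        # Linearly anneal the outer (meta) step size toward zero.
        eps = (1 - it / meta_iters) * meta_step_size
        x, y = sample_task()
        w = inner_sgd(phi, x, y)  # adapt a copy of phi to one sampled task
        phi = phi + eps * (w - phi)  # move the initialization toward the adapted weights
    return phi


phi = reptile()
print(phi)  # exact values depend on the randomly sampled tasks
```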


Code sources:

- Few-shot learning with Reptile: https://keras.io/examples/vision/rept...
- On First-Order Meta-Learning Algorithms: https://arxiv.org/pdf/1803.02999.pdf
- MAML: https://arxiv.org/pdf/1703.03400.pdf
- Generative Teaching Networks: https://arxiv.org/pdf/1912.07768.pdf
- Teaching with Commentaries: https://arxiv.org/pdf/2011.03037.pdf
- Meta Pseudo Labels: https://arxiv.org/pdf/2003.10580.pdf

*The code also keeps its original comments in English.

```python
import matplotlib.pyplot as plt
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
```


## Define the Hyperparameters


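```python
# Hyperparameters for the outer (meta) loop and the inner training loop.

learning_rate = 0.003  # SGD learning rate for the inner updates
meta_step_size = 0.25  # initial meta step size, annealed linearly toward 0

inner_batch_size = 25  # batch size for meta-training tasks
eval_batch_size = 25  # batch size for evaluation tasks

meta_iters = 2000  # number of meta-iterations
eval_iters = 5  # passes over each evaluation task's support set

# Number of passes over each sampled task's data during meta-training
inner_iters = 4

eval_interval = 1  # evaluate every `eval_interval` meta-iterations

# Shots (samples per class) and classes per sampled task
train_shots = 20  # shots per class during meta-training
shots = 5  # shots per class during evaluation
classes = 5  # n in n-way classification
```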
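Two of these values drive the outer loop. `meta_step_size` is the initial meta step size $\epsilon_0$. The training loop anneals it linearly toward zero over `meta_iters` ($T$) iterations. At each meta-iteration $t$, the loop moves the weights before the inner updates, $\phi$, toward the weights after them, $\tilde{\phi}$:

$$
\epsilon_t = \left(1 - \frac{t}{T}\right)\epsilon_0,
\qquad
\phi \leftarrow \phi + \epsilon_t \left(\tilde{\phi} - \phi\right)
$$

For example, with $\epsilon_0 = 0.25$ and $T = 2000$, the step size is $0.125$ halfway through training ($t = 1000$).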


## Prepare the data

The [Omniglot dataset](https://github.com/brendenlake/omniglot/) is a dataset of 1,623
characters taken from 50 different alphabets, with 20 examples for each character.
The 20 samples for each character were drawn online via Amazon's Mechanical Turk. For the
few-shot learning task, `k` samples (or "shots") are drawn randomly from `n` randomly chosen
classes. Each of these `n` classes is given a new temporary label. The resulting task tests
the model's ability to learn something new from only a few examples. In other words, if you
are training on 5 classes, your new class labels will be either 0, 1, 2, 3, or 4.
Omniglot is a great dataset for this task since there are many different classes to draw
from, with a reasonable number of samples for each class.
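The relabeling described above can be sketched in plain Python. `sample_episode` is a hypothetical helper. It works on a dictionary that maps class names to examples, and strings stand in for the Omniglot images. The helper uses `random.sample` rather than `random.choices`, so it draws without replacement. No class is picked twice, and each held-out query example is distinct from that class's support shots.

```python
import random


def sample_episode(data, n_way, k_shot, seed=None):
    """Hypothetical helper: build one n-way, k-shot task with temporary labels.

    `data` maps an original class name to a list of examples. Returns the chosen
    class names, k_shot (example, temp_label) support pairs per class, and one
    held-out (example, temp_label) query pair per class.
    """
    rng = random.Random(seed)
    chosen = rng.sample(sorted(data), n_way)  # n distinct classes
    support, query = [], []
    for temp_label, name in enumerate(chosen):  # temporary labels 0 .. n_way - 1
        picks = rng.sample(data[name], k_shot + 1)  # k_shot + 1 distinct examples
        support += [(example, temp_label) for example in picks[:-1]]
        query.append((picks[-1], temp_label))
    return chosen, support, query


# Toy stand-in for Omniglot: 50 "characters" with 20 examples each.
toy_data = {f"char{i}": [f"char{i}_img{j}" for j in range(20)] for i in range(50)}
chosen, support, query = sample_episode(toy_data, n_way=5, k_shot=5, seed=0)
print(chosen)
print(query)
```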


In [None]:

```python
# Define the data loader: a class that samples few-shot tasks from the Omniglot dataset.
class Dataset:
    # This class will facilitate the creation of a few-shot dataset
    # from the Omniglot dataset that can be sampled from quickly while also
    # allowing new labels to be created at the same time.

    def __init__(self, training):
        # Select the split to load.
        split = "train" if training else "test"
        # Load Omniglot from TensorFlow Datasets.
        ds = tfds.load("omniglot", split=split, as_supervised=True, shuffle_files=False)
        # Iterate over the dataset to get each individual image and its class,
        # and put that data into a dictionary.
        self.data = {}

        def extraction(image, label):
            # This function will shrink the Omniglot images to the desired size,
            # scale pixel values and convert the RGB image to grayscale
            image = tf.image.convert_image_dtype(image, tf.float32)  # uint8 -> float32 in [0, 1]
            image = tf.image.rgb_to_grayscale(image)
            image = tf.image.resize(image, [28, 28])
            return image, label

        # ds.map applies `extraction` to every (image, label) pair as we iterate.
        for image, label in ds.map(extraction):
            image = image.numpy()
            # Use the label as a string dictionary key. The original Omniglot label
            # only groups images by character; tasks use temporary labels instead.
            label = str(label.numpy())
            if label not in self.data:
                self.data[label] = []
            self.data[label].append(image)
        self.labels = list(self.data.keys())

    def get_mini_dataset(
        self, batch_size, repetitions, shots, num_classes, split=False
    ):
        # Sample `num_classes` characters, relabel them 0 .. num_classes - 1, and
        # build a small classification task from `shots` images of each.
        # Placeholder arrays for the task's labels and images:
        temp_labels = np.zeros(shape=(num_classes * shots))
        temp_images = np.zeros(shape=(num_classes * shots, 28, 28, 1))
        if split:
            test_labels = np.zeros(shape=(num_classes))
            test_images = np.zeros(shape=(num_classes, 28, 28, 1))

        # Get a random subset of labels from the entire label set, without
        # replacement so that no character appears twice in one task.
        label_subset = random.sample(self.labels, k=num_classes)
        for class_idx, class_obj in enumerate(label_subset):
            # Use enumerated index value as a temporary label for mini-batch in
            # few shot learning.
            temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx
            # If creating a split dataset for testing, select an extra sample from each
            # label to create the test dataset.
            if split:
                test_labels[class_idx] = class_idx
                # Sample without replacement so the held-out test image is not
                # also one of the training shots.
                images_to_split = random.sample(self.data[class_obj], k=shots + 1)
                test_images[class_idx] = images_to_split[-1]
                temp_images[
                    class_idx * shots : (class_idx + 1) * shots
                ] = images_to_split[:-1]
            else:
                # For each index in the randomly selected label_subset, sample the
                # necessary number of images.
                temp_images[
                    class_idx * shots : (class_idx + 1) * shots
                ] = random.sample(self.data[class_obj], k=shots)

        # Wrap the task in a tf.data pipeline. Labels are cast to integers
        # (0, 1, ..., num_classes - 1) for sparse categorical cross-entropy.
        dataset = tf.data.Dataset.from_tensor_slices(
            (temp_images.astype(np.float32), temp_labels.astype(np.int32))
        )
        dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)
        if split:
            return dataset, test_images, test_labels
        return dataset


import urllib3

urllib3.disable_warnings()  # Disable SSL warnings that may happen during download.
train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)
```
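`get_mini_dataset` batches the task before repeating it, and `tf.data` keeps a final partial batch by default. The number of optimizer steps per task therefore follows directly from the hyperparameters. Below is a quick check, using a hypothetical helper `gradient_steps` written only for illustration:

```python
import math


def gradient_steps(num_classes, shots, batch_size, repetitions):
    """Optimizer steps for a task of num_classes * shots examples that is
    batched with `batch_size` (partial last batch kept) and then repeated
    `repetitions` times."""
    batches_per_pass = math.ceil(num_classes * shots / batch_size)
    return batches_per_pass * repetitions


# Meta-training: classes=5, train_shots=20, inner_batch_size=25, inner_iters=4
print(gradient_steps(5, 20, 25, 4))  # 16
# Evaluation: classes=5, shots=5, eval_batch_size=25, eval_iters=5
print(gradient_steps(5, 5, 25, 5))  # 5
```

So each meta-iteration takes 16 inner SGD steps. Each evaluation fine-tunes for only 5.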


## Visualize some examples from the dataset


```python
# Visualize some images from the Omniglot training set.

# Create a 5x5 grid of subplots.
_, axarr = plt.subplots(nrows=5, ncols=5, figsize=(20, 20))

# Each row shows five images of one character class from the training loader.
sample_keys = list(train_dataset.data.keys())

for a in range(5):
    for b in range(5):
        temp_image = train_dataset.data[sample_keys[a]][b]
        # Replicate the single grayscale channel into RGB and rescale to uint8.
        temp_image = np.stack((temp_image[:, :, 0],) * 3, axis=2)
        temp_image *= 255
        temp_image = np.clip(temp_image, 0, 255).astype("uint8")
        if b == 2:
            axarr[a, b].set_title("Class : " + sample_keys[a])
        axarr[a, b].imshow(temp_image, cmap="gray")
        axarr[a, b].xaxis.set_visible(False)
        axarr[a, b].yaxis.set_visible(False)
plt.show()
```


## Build the model


```python
# A conv_bn block: 3x3 convolution (stride 2) -> batch normalization -> ReLU.
# The model stacks four of these blocks, then flattens and classifies.
def conv_bn(x):
    x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    return layers.ReLU()(x)


inputs = layers.Input(shape=(28, 28, 1))
x = conv_bn(inputs)
x = conv_bn(x)
x = conv_bn(x)
x = conv_bn(x)

# Flatten the feature maps into a vector.
x = layers.Flatten()(x)

# Softmax head for the `classes`-way (here 5-way) classification problem.
outputs = layers.Dense(classes, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile()
optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
```
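Each `conv_bn` block uses `strides=2` with `padding="same"`, which gives an output of ceil(n / 2) along each spatial dimension. Four blocks therefore shrink a 28×28 input to 2×2 before `Flatten`. The following arithmetic can serve as a sanity check against `model.summary()`:

```python
import math


def same_conv_output(size, stride=2):
    """Spatial size after a Conv2D with padding="same": ceil(size / stride)."""
    return math.ceil(size / stride)


sizes = [28]
for _ in range(4):  # four conv_bn blocks
    sizes.append(same_conv_output(sizes[-1]))
print(sizes)  # [28, 14, 7, 4, 2]

flat = sizes[-1] ** 2 * 64  # 64 filters in every block
print(flat)  # 256
print(flat * 5 + 5)  # 1285 weights and biases in the 5-way Dense head
```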


## Train the model


```python
# Lists that store the evaluation accuracies on training and test classes.
training = []
testing = []

# Loop over the meta-iterations.
for meta_iter in range(meta_iters):
    frac_done = meta_iter / meta_iters
    # Linearly anneal the meta step size from meta_step_size toward 0.
    cur_meta_step_size = (1 - frac_done) * meta_step_size
    # Temporarily save the weights from the model.
    old_vars = model.get_weights()
    # Get a sample from the full dataset: one mini-dataset (task) per meta-update.
    mini_dataset = train_dataset.get_mini_dataset(
        inner_batch_size, inner_iters, train_shots, classes
    )

    # Inner loop: plain SGD on the sampled task.
    for images, labels in mini_dataset:
        with tf.GradientTape() as tape:
            preds = model(images)
            loss = keras.losses.sparse_categorical_crossentropy(labels, preds)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
    new_vars = model.get_weights()
    # Perform SGD for the meta step: move the old (global) weights a fraction
    # cur_meta_step_size of the way toward the task-adapted weights, i.e.
    # old + (new - old) * cur_meta_step_size. The difference (new - old) is the
    # update direction, playing the role of a negative gradient.
    for var in range(len(new_vars)):
        new_vars[var] = old_vars[var] + (
            (new_vars[var] - old_vars[var]) * cur_meta_step_size
        )
    # After the meta-learning step, reload the newly-trained weights into the model.
    model.set_weights(new_vars)

    # Evaluation loop: fine-tune on a few shots, measure accuracy, then restore
    # the weights, so evaluation never updates the meta-learned initialization.
    if meta_iter % eval_interval == 0:
        accuracies = []
        for dataset in (train_dataset, test_dataset):
            # Sample a mini dataset from the full dataset.
            train_set, test_images, test_labels = dataset.get_mini_dataset(
                eval_batch_size, eval_iters, shots, classes, split=True
            )
            old_vars = model.get_weights()
            # Train on the samples and get the resulting accuracies.
            for images, labels in train_set:
                with tf.GradientTape() as tape:
                    preds = model(images)
                    loss = keras.losses.sparse_categorical_crossentropy(labels, preds)
                grads = tape.gradient(loss, model.trainable_weights)
                optimizer.apply_gradients(zip(grads, model.trainable_weights))
            # Predict the held-out image of each class (argmax over the class axis).
            test_preds = model.predict(test_images, verbose=0)
            test_preds = tf.argmax(test_preds, axis=1).numpy()
            num_correct = (test_preds == test_labels).sum()
            # Reset the weights after getting the evaluation accuracies.
            model.set_weights(old_vars)
            accuracies.append(num_correct / classes)
        training.append(accuracies[0])
        testing.append(accuracies[1])
        if meta_iter % 100 == 0:
            print(
                "batch %d: train=%f test=%f" % (meta_iter, accuracies[0], accuracies[1])
            )
```
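`tf.argmax` reduces over axis 0 when no `axis` is given. For a (query images × classes) prediction matrix, however, the predicted class of each image comes from axis 1. The following NumPy illustration uses made-up probabilities for five query images whose true labels are 0 to 4. It shows how reducing over the wrong axis can inflate the accuracy:

```python
import numpy as np

labels = np.arange(5)
# Hypothetical softmax outputs: rows are query images, columns are classes.
probs = np.array([
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.60, 0.30, 0.05, 0.03, 0.02],  # image 1 is misclassified as class 0
    [0.05, 0.05, 0.80, 0.05, 0.05],
    [0.10, 0.10, 0.10, 0.60, 0.10],
    [0.05, 0.05, 0.05, 0.05, 0.80],
])

per_image = np.argmax(probs, axis=1)  # predicted class for each image
per_class = np.argmax(probs, axis=0)  # most confident image for each class

print(per_image, (per_image == labels).mean())  # predictions [0 0 2 3 4], accuracy 0.8
print(per_class, (per_class == labels).mean())  # [0 1 2 3 4], a misleading 1.0
```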


## Visualize Results


```python
# First, some preprocessing to smooth the training and testing arrays for display.
window_length = 100
train_s = np.r_[
    training[window_length - 1 : 0 : -1], training, training[-1:-window_length:-1]
]
test_s = np.r_[
    testing[window_length - 1 : 0 : -1], testing, testing[-1:-window_length:-1]
]
w = np.hamming(window_length)
train_y = np.convolve(w / w.sum(), train_s, mode="valid")
test_y = np.convolve(w / w.sum(), test_s, mode="valid")

# Plot the smoothed test and train accuracies.
x = np.arange(0, len(test_y), 1)
plt.plot(x, test_y, x, train_y)
plt.legend(["test", "train"])
plt.grid()
plt.show()

# Fine-tune on one new task drawn from the test classes and show its predictions.
train_set, test_images, test_labels = test_dataset.get_mini_dataset(
    eval_batch_size, eval_iters, shots, classes, split=True
)
for images, labels in train_set:
    with tf.GradientTape() as tape:
        preds = model(images)
        loss = keras.losses.sparse_categorical_crossentropy(labels, preds)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
test_preds = model.predict(test_images, verbose=0)
test_preds = tf.argmax(test_preds, axis=1).numpy()

_, axarr = plt.subplots(nrows=1, ncols=5, figsize=(20, 20))

for i, ax in zip(range(5), axarr):
    temp_image = np.stack((test_images[i, :, :, 0],) * 3, axis=2)
    temp_image *= 255
    temp_image = np.clip(temp_image, 0, 255).astype("uint8")
    ax.set_title(
        "Label : {}, Prediction : {}".format(int(test_labels[i]), test_preds[i])
    )
    ax.imshow(temp_image, cmap="gray")
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
plt.show()
```
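The smoothing step mirrors `window_length - 1` points onto each end of the accuracy lists. As a result, the smoothed curves come out `window_length - 1` samples longer than `training` and `testing`, which shifts the x-axis. The following sketch is one possible variant that keeps the original length and centering. `smooth` is a hypothetical helper, not part of the original notebook. It pads with `np.pad(mode="reflect")` and applies the same normalized Hamming window.

```python
import numpy as np


def smooth(values, window_length=100):
    """Hamming-window moving average with mirrored edges.

    Output has the same length as the input; point i is centered (to within
    half a sample) on values[i].
    """
    values = np.asarray(values, dtype=float)
    w = np.hamming(window_length)
    left = (window_length - 1) // 2
    right = window_length - 1 - left
    padded = np.pad(values, (left, right), mode="reflect")  # mirror the edges
    return np.convolve(padded, w / w.sum(), mode="valid")
```

With this helper, `train_y = smooth(training)`, `test_y = smooth(testing)` and `x = np.arange(len(training))` line up one point per evaluation.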
