In [1]:
import os
import sys
# Prepend ../tensorflow_fewshot/ (resolved against the kernel's current working
# directory) to the module search path, presumably so the `tensorflow_fewshot.*`
# imports in the next cell resolve without installing the package — TODO confirm layout.
sys.path.insert(0, os.path.abspath('../tensorflow_fewshot/'))

In [2]:
import tensorflow_fewshot.models.prototypical_network as ptn
import tensorflow_fewshot.models.utils as models_utils

In [3]:
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from skimage.transform import resize
from skimage.color import rgb2gray

import tensorflow_datasets as tfds

from IPython.display import display, clear_output

## Dataset

In [4]:
# Split train/test
# Load Omniglot from TFDS and bring every image to 28x28x1 grayscale floats.
# Downscaling is important, otherwise the network isn't powerful enough to train.
# The training set is additionally augmented with 90/180/270-degree rotations,
# each rotation being treated as a brand-new class (4x images, 4x classes).

IMAGE_SIZE = 28
resize_batch_size = 256


def _to_gray_small(images, batch_size=resize_batch_size, label=''):
    """Convert RGB images to an (N, IMAGE_SIZE, IMAGE_SIZE, 1) float grayscale array.

    Images are processed in chunks of ``batch_size`` to bound peak memory. The
    final partial chunk is processed too, so no trailing image is left as zeros.

    Args:
        images: array of shape (N, H, W, 3).
        batch_size: number of images resized per chunk.
        label: prefix of the progress message shown while converting.

    Returns:
        float64 array of shape (N, IMAGE_SIZE, IMAGE_SIZE, 1).
    """
    n_images = images.shape[0]
    out = np.zeros((n_images, IMAGE_SIZE, IMAGE_SIZE, 1))
    # Ceil division so the partial last chunk is counted (and processed).
    n_batches = -(-n_images // batch_size)
    for b, start in enumerate(range(0, n_images, batch_size)):
        clear_output(wait=True)
        display(label + str(b * 100 // n_batches) + '%')
        chunk = images[start:start + batch_size]
        # Use the real chunk length: the last chunk may be smaller than batch_size.
        out[start:start + chunk.shape[0]] = resize(
            rgb2gray(chunk)[:, :, :, None],
            (chunk.shape[0], IMAGE_SIZE, IMAGE_SIZE, 1))
    return out


train_ds = tfds.load("omniglot", split=tfds.Split.TRAIN, batch_size=-1)
train_np = tfds.as_numpy(train_ds)
omniglot_X = train_np['image']
base_train_Y = train_np['label']
base_train_X = _to_gray_small(omniglot_X, label='resize + grayscale: ')

# Rotation k (k = 0..3) of every image is its own class. Labels are offset by the
# number of base classes (max label + 1); offsetting by max(label) alone would make
# the last class of one rotation share an id with class 0 of the next rotation.
n_base_classes = int(base_train_Y.max()) + 1
train_X = np.concatenate([np.rot90(base_train_X, k, axes=(1, 2)) for k in range(4)])
train_Y = np.concatenate([base_train_Y + k * n_base_classes for k in range(4)])

test_ds = tfds.load("omniglot", split=tfds.Split.TEST, batch_size=-1)
test_np = tfds.as_numpy(test_ds)
omniglot_test_X = test_np['image']
test_Y = test_np['label']
test_X = _to_gray_small(omniglot_test_X, label='Test set ')

# Every image must have exactly one label.
assert train_X.shape[0] == train_Y.shape[0] == 4 * omniglot_X.shape[0]
assert test_X.shape[0] == test_Y.shape[0]

# Display the split
print("Train X shape:", train_X.shape)
print("Train Y shape:", train_Y.shape)
print("Test X shape:", test_X.shape)
print("Test Y shape:", test_Y.shape)

'Test set 98%'

Train X shape: (77120, 28, 28, 1)
Train Y shape: (77120,)
Test X shape: (13180, 28, 28, 1)
Test Y shape: (13180,)


## Training

In [5]:
# The encoder consumes the 28x28 single-channel images built in the Dataset section.
image_shape = (28, 28, 1)
encoder = models_utils.create_imageNetCNN(input_shape=image_shape)

In [6]:
# Wrap the encoder defined above in a PrototypicalNetwork; its `meta_train`,
# `fit` and `predict` methods are used in the following cells.
protonet = ptn.PrototypicalNetwork(encoder=encoder)

In [7]:
# Sanity check: embed five test images and inspect the resulting embedding shape.
sample_images = test_X[200:205]
protonet.encoder(sample_images).shape

TensorShape([5, 64])

In [8]:
N_EPISODES = 200
# Meta-train on the rotation-augmented training set. The three returned values are
# presumably training accuracy, loss and test accuracy — TODO confirm against
# PrototypicalNetwork.meta_train. Note: `acc` is reassigned in the Test section.
acc, loss, test_acc = protonet.meta_train(train_X, train_Y, n_episode=N_EPISODES)

Episode 0; lr: 1.00e-03, training loss: 1222.8706, train accuracy: 0.53


## Test

In [15]:
# Sample one N-way K-shot evaluation episode from the held-out (test) classes.
# For each of N_WAY random classes, N_SHOT images form the support set passed to
# `protonet.fit` (fit_*) and N_QUERY disjoint images form the query set (test_*).
N_WAY, N_SHOT, N_QUERY = 5, 5, 5
EPISODE_SEED = 0
# Seeded generator so the sampled episode (and the reported accuracy) is reproducible.
episode_rng = np.random.default_rng(EPISODE_SEED)

test_classes = episode_rng.choice(np.unique(test_Y), size=N_WAY, replace=False)
print("Test classes are:", test_classes)

fit_parts, query_parts = [], []
for cls in test_classes:
    # Drawn without replacement, so support and query images never overlap.
    class_indices = episode_rng.choice(
        np.flatnonzero(test_Y == cls),
        size=N_SHOT + N_QUERY,
        replace=False,
    )
    fit_parts.append(class_indices[:N_SHOT])
    query_parts.append(class_indices[N_SHOT:])

# Shuffle so that examples are not grouped by class.
fit_indices = episode_rng.permutation(np.concatenate(fit_parts)).astype(np.int32)
test_indices = episode_rng.permutation(np.concatenate(query_parts)).astype(np.int32)

fit_data = test_X[fit_indices, :, :, :]
fit_labels = test_Y[fit_indices]

test_data = test_X[test_indices, :, :, :]
test_labels = test_Y[test_indices]

Test classes are: [1570 1241 1284 1326  971]


In [16]:
# Hand the support set to the network as float32 together with its class ids.
support_X = fit_data.astype(np.float32)
protonet.fit(support_X, fit_labels)

In [17]:
# Predict class ids for the query images. Cast to float32 so the queries have the
# same dtype as the support set passed to protonet.fit(...) (test_X is float64).
test_preds = protonet.predict(test_data.astype(np.float32))

In [18]:
# Predicted class ids for the query images (compare with `test_labels` in the next cell).
test_preds

array([1284,  971, 1326, 1326, 1326, 1570, 1570,  971, 1284, 1241, 1241,
       1284, 1570, 1570, 1241, 1326,  971, 1326,  971, 1570,  971, 1284,
       1570, 1570, 1241], dtype=int32)

In [19]:
# Ground-truth class ids for the query images, in the same order as `test_preds`.
test_labels

array([1284,  971, 1326, 1326, 1326, 1570, 1570,  971, 1284, 1241, 1241,
       1284, 1241, 1284, 1241, 1326,  971, 1326,  971, 1570,  971, 1284,
       1570, 1570, 1241])

In [20]:
# Query-set accuracy: fraction of query images whose predicted class id equals the label.
# NOTE: this overwrites the `acc` returned by meta_train in the Training section.
# Guard against silent broadcasting: e.g. shapes (25,) vs (25, 1) would compare all pairs.
assert test_preds.shape == test_labels.shape, (test_preds.shape, test_labels.shape)
acc = np.mean(test_preds == test_labels)
print("Test accuracy is:", acc)

Test accuracy is: 0.92
