In [1]:
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from skimage.transform import resize
from skimage.color import rgb2gray

import tensorflow_datasets as tfds

from IPython.display import display, clear_output

In [3]:
# Split train/test
train_ds = tfds.load("omniglot", split=tfds.Split.TRAIN, batch_size=-1)
# Downscaling is important, otherwise the network isn't powerful enough to train
omniglot_X = tfds.as_numpy(train_ds)['image']
train_X = np.zeros((4 * omniglot_X.shape[0], 28, 28, 1))
train_Y = tfds.as_numpy(train_ds)['label']
train_Y = np.tile(train_Y, 4) + np.repeat([0, 1, 2, 3], train_Y.shape[0]) * np.max(train_Y)
resize_batch_size = 256
for i in range(omniglot_X.shape[0] // resize_batch_size):
    if i % 1 == 0:
        clear_output(wait=True)
        display('resize + grayscale: ' + str(i * 100 // (omniglot_X.shape[0] // resize_batch_size)) + '%')
    train_X[i * resize_batch_size:(i + 1) * resize_batch_size, :, :, :] = resize(
        rgb2gray(
            omniglot_X[i * resize_batch_size:(i + 1) * resize_batch_size, :, :, :]
        )[:, :, :, None],
        (resize_batch_size, 28, 28, 1))

for i in range(omniglot_X.shape[0] // resize_batch_size):
    clear_output(wait=True)
    display('rotations ' + str(i * 100 // (omniglot_X.shape[0] // resize_batch_size)) + '%')
    for rot in [1, 2, 3]:
        train_X[
        omniglot_X.shape[0] * rot + i * resize_batch_size:
        omniglot_X.shape[0] * rot + (i + 1) * resize_batch_size,
        :, :, :
        ] = np.rot90(train_X[
                     i * resize_batch_size:(i + 1) * resize_batch_size, :, :, :
                     ], rot, (1, 2))

test_ds = tfds.load("omniglot", split=tfds.Split.TEST, batch_size=-1)
omniglot_test_X = tfds.as_numpy(test_ds)['image']
test_X = np.zeros((omniglot_test_X.shape[0], 28, 28, 1))
test_Y = tfds.as_numpy(test_ds)['label']
for i in range(test_X.shape[0] // resize_batch_size):
    if i % 1 == 0:
        clear_output(wait=True)
        display('Test set ' + str(i * 100 // (test_X.shape[0] // resize_batch_size)) + '%')
    test_X[i * resize_batch_size:(i + 1) * resize_batch_size, :, :, :] = resize(
        rgb2gray(
            omniglot_test_X[i * resize_batch_size:(i + 1) * resize_batch_size, :, :, :]
        )[:, :, :, None],
        (resize_batch_size, 28, 28, 1))

# Display the split
print("Train X shape:", train_X.shape)
print("Train Y shape:", train_Y.shape)
print("Test X shape:", test_X.shape)
print("Test Y shape:", test_Y.shape)

'Test set 98%'

Train X shape: (77120, 28, 28, 1)
Train Y shape: (77120,)
Test X shape: (13180, 28, 28, 1)
Test Y shape: (13180,)


In [191]:
import tensorflow_fewshot.models.prototypical_network as ptn

import importlib
importlib.reload(ptn)

<module 'tensorflow_fewshot.models.prototypical_network' from '../tensorflow_fewshot/models/prototypical_network.py'>

In [192]:
protonet = ptn.PrototypicalNetwork()

In [193]:
protonet.encoder(test_X[200:205,:,:,:]).shape



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



TensorShape([5, 64])

In [125]:
protonet.meta_train(train_X, train_Y, n_episode=2)



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



Episode 0; lr: 1.00e-03, training loss: 1136.0100, train accuracy: 0.65


In [194]:
protonet.train(test_X[:200,:,:,:].astype(np.float32), test_Y[:200])

In [196]:
protonet.predict(test_X[200:205,:,:,:])

array([60, 60, 60, 60, 60])