In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from temporal_basis_transformation_network.keras import TemporalBasisTrafo
import dlop_ldn_function_bases as bases

In [3]:
def normalize_img(image, label):
    """Convert an (image, label) pair's image to float32 and scale it by 1/255.

    The label is returned untouched. The signature matches a two-argument
    `tf.data.Dataset.map` callback for supervised datasets, but this helper is
    not applied anywhere in the visible notebook (the permuted pipeline below
    rescales pixels to [-1, 1] on its own instead).
    """
    scaled_image = tf.cast(image, tf.float32) / 255.0
    return scaled_image, label

# Load the original (unpermuted) MNIST splits from TFDS as (image, label)
# pairs, together with the dataset metadata; `ds_info` is used later for the
# training-set size (shuffle buffer).
_mnist_load_kwargs = dict(
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)
(ds_train, ds_test), ds_info = tfds.load('mnist', **_mnist_load_kwargs)

In [4]:
# Permute the images and reshape.
#
# Permuted MNIST: each image is flattened to a length-784 vector, and its
# pixels are reordered by ONE fixed permutation shared by train and test.
# The permutation must come from the seeded `rng`. The global NumPy RNG was
# used before, so the seed had no effect and every run saw a different
# permutation.
rng = np.random.RandomState(49291)
idcs = rng.permutation(28 * 28)


def _permuted_arrays(ds, idcs):
    """Materialize a supervised TFDS split as permuted, rescaled NumPy arrays.

    Each image is flattened, its pixels are reordered by `idcs`, and values are
    mapped from [0, 255] to [-1, 1].

    Args:
        ds: a finite dataset of (image, label) pairs whose length is known
            (`len(ds)` must work).
        idcs: 1-D integer permutation applied to the flattened pixels.

    Returns:
        (images, labels): `images` is float32 of shape (len(ds), len(idcs));
        `labels` is int64 of shape (len(ds),).
    """
    n = len(ds)
    # float32 is what Keras computes in anyway; float64 would double the
    # memory of the full training set for no benefit.
    images = np.zeros((n, len(idcs)), dtype=np.float32)
    # `np.int` was removed in NumPy 1.24; use an explicit integer dtype.
    labels = np.zeros(n, dtype=np.int64)
    for i, (image, label) in enumerate(tfds.as_numpy(ds)):
        images[i] = 2.0 * (image.flatten()[idcs] / 255.0) - 1.0
        labels[i] = label
    return images, labels


train_imgs_perm, train_lbls = _permuted_arrays(ds_train, idcs)
test_imgs_perm, test_lbls = _permuted_arrays(ds_test, idcs)

# Wrap the permuted NumPy arrays as batched tf.data pipelines.
# Only the training stream is shuffled (buffer = whole training split);
# both streams prefetch with an autotuned buffer.
_BATCH_SIZE = 128
_AUTOTUNE = tf.data.experimental.AUTOTUNE

ds_train_perm = (
    tf.data.Dataset.from_tensor_slices((train_imgs_perm, train_lbls))
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(_BATCH_SIZE)
    .prefetch(_AUTOTUNE)
)

ds_test_perm = (
    tf.data.Dataset.from_tensor_slices((test_imgs_perm, test_lbls))
    .batch(_BATCH_SIZE)
    .prefetch(_AUTOTUNE)
)

In [8]:
# Network hyper-parameters (values identical to the original cell).
SEQ_LEN = 784        # flattened 28x28 image, fed as a length-784 sequence
N_BASIS = 256        # first argument of mk_dlop_basis; presumably the basis order — confirm
N_UNITS = 16         # channels per timestep handed to TemporalBasisTrafo
N_HIDDEN = 256
N_CLASSES = 10
LEARNING_RATE = 0.001
N_EPOCHS = 20

H = bases.mk_dlop_basis(N_BASIS, SEQ_LEN)

# Shapes in the comments are per example (batch axis omitted).
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Reshape((SEQ_LEN, 1)))                    # (784, 1)
model.add(tf.keras.layers.Dense(N_HIDDEN, activation='relu'))       # (784, 256)
model.add(tf.keras.layers.Dense(N_UNITS, use_bias=False))           # (784, 16)
model.add(TemporalBasisTrafo(H, n_units=N_UNITS, pad=False))        # (1, 16 * 256)
model.add(tf.keras.layers.Dense(N_HIDDEN, activation='relu'))       # (1, 256)
model.add(tf.keras.layers.Dense(N_CLASSES, use_bias=False))         # (1, 10)
model.add(tf.keras.layers.Reshape((N_CLASSES,)))                    # (10)

# The last layer has no softmax, so the loss consumes raw logits.
model.compile(
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train_perm,
    epochs=N_EPOCHS,
    validation_data=ds_test_perm,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
101/469 [=====>........................] - ETA: 34s - loss: 0.0610 - sparse_categorical_accuracy: 0.9827

KeyboardInterrupt: 