In [None]:
from dl_stash import image

In [None]:
import tensorflow as tf
from tensorflow import keras
from keras import layers

import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

In [None]:
# Multiplies pixel values by 1/255. Shared by the augmentation pipeline
# (model inputs) and by the clean targets built from the same batches.
scaling = layers.Rescaling(scale=1 / 255)

# Corruption pipeline applied to the raw uint8 28x28x1 images:
# rescale, then a random rotation followed by a random perspective warp.
_augmentation_steps = [
    layers.Input(shape=(28, 28, 1), dtype=tf.uint8),
    scaling,
    # factor=1 -- per the Keras docs the factor is a fraction of a full turn,
    # so rotations are drawn from the whole circle (TODO confirm intent).
    layers.RandomRotation(factor=1),
    layers.RandomPerspective(),
]
augmentation = keras.Sequential(_augmentation_steps)


In [None]:
# Build the MNIST dataset builder and expose its splits as tf.data datasets.
mnist = tfds.image_classification.MNIST()
# as_dataset() reads the prepared files from disk, so the data must be
# generated first. Without this call the notebook fails on a fresh machine;
# when the data is already present the call reuses it instead of rebuilding.
mnist.download_and_prepare()
mnist_ds = mnist.as_dataset()

In [None]:
# Show the split names available in the mapping returned by as_dataset().
mnist_ds.keys()

In [None]:
# Pull a single example from the test split and keep only its image tensor.
_first_test_example = next(iter(mnist_ds["test"].take(1)))
im = _first_test_example["image"]

In [None]:
# Display the sampled test image (converted to a NumPy array for matplotlib).
plt.imshow(im.numpy())

In [None]:
# Number of images per batch in the tf.data pipelines.
BATCH_SIZE = 128
# Learning rate passed to the Adam optimizer when the model is compiled.
LR = 1e-3

In [None]:
# Size of the shuffle buffer used for the training split.
SHUFFLE_BUFFER = 10_000


def _make_pair_ds(split, shuffle=False):
    """Build a batched dataset of (augmented input, clean target) image pairs.

    Args:
        split: Key into ``mnist_ds`` (e.g. ``"train"`` or ``"test"``).
        shuffle: When True, shuffle individual examples before batching so
            the batch composition changes from epoch to epoch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(augmentation(batch), scaling(batch))``
        tuples: the randomly transformed images as model input and the
        rescaled, untransformed images as the target.
    """
    ds = mnist_ds[split].map(lambda example: example["image"])
    if shuffle:
        ds = ds.shuffle(SHUFFLE_BUFFER)
    ds = ds.batch(BATCH_SIZE)
    ds = ds.map(
        # training=True is passed explicitly so the random layers are applied
        # no matter how the training flag would otherwise be resolved inside
        # Dataset.map.
        lambda batch: (augmentation(batch, training=True), scaling(batch))
    )
    # Overlap input preprocessing with model execution.
    return ds.prefetch(tf.data.AUTOTUNE)


# Both splits get augmented inputs: the model is trained (and evaluated) on
# recovering the clean image from a transformed one.
train_ds = _make_pair_ds("train", shuffle=True)
val_ds = _make_pair_ds("test")

In [None]:
class AffinedTransform(layers.Layer):
    """Warp a batch of images with per-sample affine parameters.

    The layer receives ``[images, params]`` where ``params`` holds six values
    per sample. These are added to a flattened 2x3 identity matrix, reshaped
    to ``(batch, 2, 3)`` and handed to ``image.affine_transform``.

    Args:
        normalize_displacement: When True, the 3rd and 6th parameters are
            multiplied by the image width and height respectively before
            use (presumably so translations are expressed as a fraction of
            the image size -- confirm against ``image.affine_transform``).
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, normalize_displacement=True, **kwargs):
        super().__init__(**kwargs)
        self.normalize_displacement = normalize_displacement
        # Flattened 2x3 identity matrix; the incoming parameters are offsets
        # added on top of it.
        self.default_transform = tf.constant(
            [[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]], dtype=tf.float32
        )

    def call(self, inputs):
        """Apply the affine warp.

        Args:
            inputs: Pair ``(im, params)``; ``im`` is
                (batch_size, height, width, channels) and ``params`` is
                (batch_size, 6).

        Returns:
            Whatever ``image.affine_transform`` returns for the image batch
            and the (batch_size, 2, 3) transformation matrices.
        """
        im, params = inputs
        dyn_shape = tf.shape(im)

        if self.normalize_displacement:
            h = tf.cast(dyn_shape[1], dtype=tf.float32)
            w = tf.cast(dyn_shape[2], dtype=tf.float32)
            # Only the last entry of each matrix row is rescaled.
            per_param_scale = tf.convert_to_tensor(
                [[1.0, 1.0, w, 1.0, 1.0, h]], dtype=tf.float32
            )
            params = params * per_param_scale

        full_params = self.default_transform + params
        matrices = tf.reshape(full_params, [dyn_shape[0], 2, 3])
        return image.affine_transform(im, matrices)

In [None]:
# affine = AffinedTransform()


In [None]:
# outa = affine(
#     [
#         tf.expand_dims(tf.cast(im, tf.float32), axis=0),
#         tf.convert_to_tensor([[-1, 1, 0, 1, -1, 0]], dtype=tf.float32)
#     ]
# )[0]

# plt.imshow(outa.numpy())


In [None]:
def _conv3x3(filters):
    """Return a 3x3, same-padded, ReLU-activated convolution layer."""
    return layers.Conv2D(filters, 3, activation="relu", padding="same")


# Encoder: maps a 28x28x1 float image to 6 raw affine parameters.
encoder = keras.Sequential(
    [
        layers.Input(shape=(28, 28, 1), dtype=tf.float32),
        _conv3x3(16),
        _conv3x3(32),
        layers.MaxPooling2D(),
        _conv3x3(64),
        _conv3x3(64),
        layers.MaxPooling2D(),
        _conv3x3(128),
        _conv3x3(256),
        layers.GlobalAveragePooling2D(),
        # Linear head: one output per affine parameter.
        layers.Dense(units=6),
    ]
)

# Full model: predict the parameters from the input image, then warp that
# same image with them.
im_input = layers.Input(shape=(28, 28, 1), dtype=tf.float32)
z_params = encoder(im_input)

affine = AffinedTransform()
im_output = affine([im_input, z_params])

model = keras.Model(inputs=im_input, outputs=im_output)

# Pixel-wise mean squared error between the warped input and the target.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LR),
    loss="mse",
    metrics=["mse"],
)


In [None]:
# Number of passes over the training dataset.
EPOCHS = 50

# Train on the (augmented, clean) pairs and evaluate on the test-split pairs
# after every epoch.
history = model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
)

In [None]:
# Take one batch from the validation dataset: x is the augmented input,
# y the clean target.
_first_val_batch = val_ds.take(1)
x, y = next(iter(_first_val_batch))

In [None]:
# Run the model on the validation inputs to get its output images.
y_pred = model(x)

In [None]:
# NOTE(review): scratch cell -- evaluates to [2, 3]; nothing else in the
# visible notebook uses the result. Candidate for removal.
list(range(2, 4))

In [None]:
# Visualize predictions: for each selected sample show the augmented input,
# the clean target and the model output side by side.
offset = 16  # index of the first sample to display
n = 16  # number of samples to display

# Clamp to the batch length so a short batch cannot cause an IndexError.
stop = min(offset + n, len(y_pred))

panel_titles = ("Augmented input", "Target", "Prediction")
for i in range(offset, stop):
    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    for ax, img, title in zip(axes, (x[i], y[i], y_pred[i]), panel_titles):
        ax.imshow(img)
        ax.set_title(f"{title} (sample {i})")
    # Render each figure as it is built instead of leaving them all open
    # until the end of the cell.
    plt.show()
