In [None]:
!git clone https://github.com/Jacobiano/morpholayers.git

In [None]:
# Download the tutorial data archive from Google Drive, unpack it and list the working directory.
# NOTE(review): the archive presumably provides the pretrained models loaded later
# (scaleEquivariantExample_*Model) and the scale_crosscorrelation module — confirm.
# NOTE(review): `-c` (resume) likely has no effect when output goes to stdout via `-O -`.
!wget -cO - "https://drive.google.com/uc?export=download&id=14AFm92AM5I-oYm9S85AFxzzxM-oOiRK1" > scaleEquivariantTutorialData.tar.gz
!tar -xzf scaleEquivariantTutorialData.tar.gz
!ls

In [None]:
!pip install scikit-image==0.18

In [None]:
import tensorflow as tf
import numpy as np

from morpholayers.layers import QuadraticDilation2D

from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import Initializer

import matplotlib.pyplot as plt

# Make "gray" the default colormap for every imshow call in this notebook.
plt.set_cmap("gray")

In [None]:
import skimage

In [None]:
# Display the loaded scikit-image version, to check that the pinned 0.18 install is active.
skimage.__version__

In [None]:
from skimage import data
from skimage.color import rgb2gray

In [None]:
# Load scikit-image's sample cat photo and convert it to a single-channel grayscale image.
im = rgb2gray(data.cat())
plt.imshow(im)
plt.title("input image (grayscale cat)");

In [None]:
# A minimal model: one quadratic (parabolic) dilation filter with a 63x63 kernel,
# applied to single-channel images of any height and width.
inputs = layers.Input(shape=(None, None, 1))
quad_dilation_layer = QuadraticDilation2D(num_filters=1, kernel_size=[63, 63])
x = quad_dilation_layer(inputs)
quad_dilation = Model(inputs=inputs, outputs=x)

In [None]:
# Show the parabolic structuring element: slice [:, :, 0, 0] of the layer's `data` tensor.
# NOTE(review): presumably laid out as (kernel_h, kernel_w, in_channels, filters) — confirm.
plt.imshow(quad_dilation_layer.data[:, :, 0, 0])
plt.title("parabolic structuring element");

In [None]:
# Add batch and channel axes, run the dilation model, then drop those two axes again.
im_out = quad_dilation(im[None, ..., None])[0, ..., 0]

In [None]:
# Plot the dilated image, then its pixel-wise difference from the input.
plt.figure()
plt.imshow(im_out)
plt.title("quadratic dilation of the input")
plt.figure()
plt.imshow(im_out - im)
plt.title("dilation minus input");

In [None]:
class ScaleConsistency(Constraint):
    """Tie per-scale biases together so that ``b * scale**2`` is shared across scales.

    On every application the constraint:
    1. averages ``b * scales**2`` over the scale (last) axis;
    2. clips that mean to ``[vmin, vmax]``;
    3. writes back ``mean / scales**2``.
    As a result, every scale uses the same underlying value rescaled by ``1 / scale**2``.

    Args:
        scales: 1-D sequence of scale factors, one per filter.
        vmax: upper bound for the shared, scale-normalised value.
        vmin: lower bound for the shared, scale-normalised value.
    """

    def __init__(self, scales, vmax=2.0, vmin=0.5, **kwargs):
        super().__init__(**kwargs)
        self.scales = tf.constant(scales)[tf.newaxis, :]
        self.vmin = vmin
        self.vmax = vmax

    def __call__(self, b):
        """Project ``b`` onto the scale-consistent set; returns a tensor shaped like ``b``."""
        # Keras applies constraints through __call__. A method named only `call` is never
        # invoked, which silently disabled this constraint.
        scales_sq = tf.cast(self.scales, b.dtype) ** 2
        bmean = tf.reduce_mean(b * scales_sq, axis=-1, keepdims=True)
        # Use the stored bounds; bare `vmin`/`vmax` were undefined names.
        bmean = tf.clip_by_value(bmean, self.vmin, self.vmax)
        # Broadcasting against the (1, n_scales) scales row would add a leading axis to a
        # 1-D bias, so restore the incoming variable's shape before returning.
        return tf.reshape(bmean / scales_sq, tf.shape(b))

    # Backward-compatible alias for code that called `constraint.call(b)` directly.
    call = __call__

    def get_config(self):
        """Serialise the constructor arguments so the constraint can be re-created."""
        return {
            "scales": self.scales.numpy()[0].tolist(),
            "vmax": self.vmax,
            "vmin": self.vmin,
        }

In [None]:
class ScaleInitializer(Initializer):
    """Initialise per-scale biases to ``1 / scale**2``.

    Args:
        scales: 1-D array-like of scale factors, one per filter.
    """

    def __init__(self, scales, **kwargs):
        super().__init__(**kwargs)
        # np.asarray accepts plain Python lists as well as numpy arrays;
        # `list[tf.newaxis, :]` would raise a TypeError.
        self.scales = tf.constant(np.asarray(scales)[np.newaxis, :])

    def __call__(self, shape, dtype=None):
        """Return ``ones(shape) / scales**2`` in the dtype requested by Keras (float32 by default)."""
        # Honour the requested dtype instead of always producing float32.
        dtype = tf.float32 if dtype is None else tf.as_dtype(dtype)
        scales_sq = tf.cast(self.scales, dtype) ** 2
        return tf.ones(shape, dtype=dtype) / scales_sq

    def get_config(self):
        """Serialise the constructor arguments so the initializer can be re-created."""
        return {"scales": self.scales.numpy()[0].tolist()}

In [None]:
# Geometric ladder of scales, zero_scale * 2**k for k = 0 .. n_scales-1, stored as float32.
# S, the largest scale, sizes the lifting kernel.
n_scales = 4
zero_scale = 0.75
scales = (zero_scale * np.power(2.0, np.arange(n_scales))).astype(np.float32)
S = np.max(scales)

In [None]:
# Lifting layer: maps an image to a function over translations and a set of scales.
# There is one quadratic-dilation filter per scale (num_filters=n_scales). The biases start
# at 1 / scale**2 (ScaleInitializer) and are tied together by ScaleConsistency.
inputs = layers.Input(shape=(None, None, 1))
qd_layer = QuadraticDilation2D(
    num_filters=n_scales,
    kernel_size=[int(4 * S)] * 2,
    scale=zero_scale,
    bias_initializer=ScaleInitializer(scales),
    bias_constraint=ScaleConsistency(scales),
)
x = qd_layer(inputs)
# Append a trailing singleton axis: presumably (batch, H, W, n_scales) -> (batch, H, W, n_scales, 1).
x = layers.Lambda(lambda v: tf.expand_dims(v, axis=-1))(x)
lifting = Model(inputs=inputs, outputs=x, name="lifting")

In [None]:
# Inspect the lifting layer's bias values (initialised by ScaleInitializer to 1 / scale**2).
print(qd_layer.bias)

In [None]:
# Layer-by-layer summary of the lifting model.
lifting.summary()

In [None]:
# 31x31 impulse image: a single unit pixel at the centre, used to visualise the
# response of each lifting scale.
im = np.zeros((31, 31), dtype=np.float32)
im[im.shape[0] // 2, im.shape[1] // 2] = 1.0

In [None]:
# Lift the impulse: add batch and channel axes before calling the model.
im_lifted = lifting(im[None, ..., None])

In [None]:
# Show the impulse next to its lifted response at every scale. The grid width follows
# n_scales; a hard-coded 5 columns fails as soon as n_scales != 4.
plt.figure()
plt.subplot(1, n_scales + 1, 1)
plt.imshow(im)
plt.title("input")
for i in range(n_scales):
    plt.subplot(1, n_scales + 1, i + 2)
    plt.imshow(im_lifted[0, :, :, i, 0], vmin=0.0, vmax=1.0)
    plt.title(f"scale {scales[i]:g}")
plt.tight_layout()

In [None]:
# Load MNIST. The second split returned by load_data serves as the validation set here.
(x_train, y_train), (x_val, y_val) = tf.keras.datasets.mnist.load_data()

In [None]:
# Convert integer pixel values (0..255) to float32 in [0, 1].
# The integer-dtype guard makes the cell idempotent: re-running it must not divide the
# already-normalised arrays by 255 a second time.
if np.issubdtype(x_train.dtype, np.integer):
    x_train = x_train.astype(np.float32) / 255
if np.issubdtype(x_val.dtype, np.integer):
    x_val = x_val.astype(np.float32) / 255

In [None]:
from scale_crosscorrelation import *

In [None]:
# Scale-invariant classifier:
# - lift the input over scales;
# - apply two ScaleConv -> BatchNorm -> ReLU blocks;
# - global max-pool the 5-D feature map and map it to 10 logits.
inputs = layers.Input((None, None, 1))
x = lifting(inputs)
for n_filters in (16, 32):
    x = ScaleConv(n_filters, (3, 3, 1), n_scales)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

x = layers.GlobalMaxPooling3D()(x)
x = layers.Dense(10)(x)

model_invariant = Model(inputs, x)
model_invariant.summary()

model_invariant.compile(
    optimizer=tf.keras.optimizers.Adam(1e-2),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [None]:
# model_invariant.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_val, y_val), verbose=1)

In [None]:
# model_invariant.save('scaleEquivariantExample_invariantModel')

In [None]:
# Replace the freshly built, untrained model with the saved one (training above is commented out).
# NOTE(review): the saved model presumably comes from the downloaded tutorial archive — confirm.
model_invariant = tf.keras.models.load_model("scaleEquivariantExample_invariantModel")

In [None]:
from skimage.transform import rescale

In [None]:
# Zoomed-in test set: each validation digit upscaled 2x with skimage's rescale.
# Labels are unchanged.
x_test = np.stack([rescale(digit, 2) for digit in x_val], axis=0)
y_test = y_val

In [None]:
# Zoomed-out test set: each validation digit downscaled by a factor of 2, i.e. a smaller
# scale (not a bigger one). Labels are unchanged.
x_test2 = np.stack([rescale(digit, 1 / 2) for digit in x_val], axis=0)
y_test2 = y_val

In [None]:
# Show the first digit of each test set, one figure per set.
# The loop replaces three copy-pasted panels and leaves no stray Text repr as cell output.
for images, caption in (
    (x_val, "image from the original test set"),
    (x_test, "image from the test set with 2x zoom in"),
    (x_test2, "image from the test set with 2x zoom out"),
):
    plt.figure()
    plt.imshow(images[0, ...])
    plt.title(caption)

In [None]:
# Reference accuracy of the invariant model on the original-scale validation set,
# evaluated against x_val's own labels (y_val).
model_invariant.evaluate(x_val, y_val)

In [None]:
# Evaluate the invariant model on both rescaled test sets, pairing each set with its own labels.
print("testing the equivariant model in the 2x zoomed in test set")
model_invariant.evaluate(x_test, y_test)
print("testing the equivariant model in the 2x zoomed out test set")
model_invariant.evaluate(x_test2, y_test2)

In [None]:
# Baseline without scale handling:
# - two plain Conv2D -> BatchNorm -> ReLU blocks;
# - global spatial max-pooling;
# - a 10-way logit layer.
inputs = layers.Input((None, None, 1))
x = inputs
for n_filters in (16, 32):
    x = layers.Conv2D(n_filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

x = layers.GlobalMaxPooling2D()(x)
x = layers.Dense(10)(x)

model_noninvariant = Model(inputs, x)
model_noninvariant.summary()

model_noninvariant.compile(
    optimizer=tf.keras.optimizers.Adam(1e-2),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [None]:
# Replace the freshly built, untrained baseline with the saved one (its training cell is commented out).
# NOTE(review): the saved model presumably comes from the downloaded tutorial archive — confirm.
model_noninvariant = tf.keras.models.load_model("scaleEquivariantExample_baselineModel")

In [None]:
# model_noninvariant.fit(x_train, y_train, epochs=15, batch_size=64, validation_data=(x_val, y_val), verbose=1)

In [None]:
# model_noninvariant.save('scaleEquivariantExample_baselineModel')

In [None]:
# Reference accuracy of the baseline on the original-scale validation set.
model_noninvariant.evaluate(x=x_val, y=y_val)

In [None]:
# Evaluate the baseline on both rescaled test sets, pairing each set with its own labels
# (matching the invariant-model evaluation).
print("testing non-equivariant model in the 2x zoomed in test set")
model_noninvariant.evaluate(x_test, y_test)
print("testing the non-equivariant model in the 2x zoomed out test set")
model_noninvariant.evaluate(x_test2, y_test2)

In [None]:
# Inspect the weights of the model's second layer.
# NOTE(review): layers[0] is presumably the Input layer, so layers[1] should be the
# "lifting" sub-model (dilation kernel + scale biases). Confirm this, or look the layer up by name.
print(model_invariant.layers[1].weights)