In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt

from deep_topology import layers, data

In [3]:
generator, steps_per_epoch = data.get_ordered_mnist_generator()

## Numpy version

In [5]:
num_classes = 10
inner_batch_size = 8
lambda_ = 1

inp = tf.keras.Input(shape=(28, 28, 1))
y_true = tf.keras.Input(shape=(1,))

feats = inp
for n_filters in [8, 32, 64, 128]:
    feats = tf.keras.layers.Conv2D(n_filters, 3, padding="same")(feats)
    feats = tf.keras.layers.LeakyReLU(alpha=0.1)(feats)
    feats = tf.keras.layers.MaxPool2D((2, 2), 2)(feats)
feats = tf.keras.layers.Flatten()(feats)

topo_reg = layers.TopologicallyDenseRegularization(
    beta=0.2, inner_batch_size=inner_batch_size, num_classes=num_classes
)(feats)

out = tf.keras.layers.Dense(num_classes)(feats)


def regularized_mse(y_true, out, reg, lambda_):
    return tf.losses.mean_squared_error(y_true, out) + lambda_ * reg


loss = tf.keras.layers.Lambda(
    lambda x, lambda_: regularized_mse(x[0], x[1], x[2], lambda_=lambda_), arguments={"lambda_": lambda_}
)([y_true, out, topo_reg])

model = tf.keras.Model([inp, y_true], loss)
model.compile(optimizer="adam", loss="mse")

In [6]:
model.fit(generator, epochs=1, steps_per_epoch=steps_per_epoch, shuffle=False)



<tensorflow.python.keras.callbacks.History at 0x13a9aecd0>

## Pure tensorflow version

In [7]:
num_classes = 10
inner_batch_size = 8
lambda_ = 1

inp = tf.keras.Input(shape=(28, 28, 1), batch_size=80)
y_true = tf.keras.Input(shape=(1,), batch_size=80)

feats = inp
for n_filters in [8, 32, 64, 128]:
    feats = tf.keras.layers.Conv2D(n_filters, 3, padding="same")(feats)
    feats = tf.keras.layers.LeakyReLU(alpha=0.1)(feats)
    feats = tf.keras.layers.MaxPool2D((2, 2), 2)(feats)
feats = tf.keras.layers.Flatten()(feats)

topo_reg = layers.TopologicallyDenseRegularization(
    beta=0.2, inner_batch_size=inner_batch_size, num_classes=num_classes, numpy=False
)(feats)

out = tf.keras.layers.Dense(num_classes)(feats)


def regularized_mse(y_true, out, reg, lambda_):
    return tf.losses.mean_squared_error(y_true, out) + lambda_ * reg


loss = tf.keras.layers.Lambda(
    lambda x, lambda_: regularized_mse(x[0], x[1], x[2], lambda_=lambda_), arguments={"lambda_": lambda_}
)([y_true, out, topo_reg])

model = tf.keras.Model([inp, y_true], loss)
model.compile(optimizer="adam", loss="mse")

In [8]:
model.fit(generator, epochs=1, steps_per_epoch=steps_per_epoch, shuffle=False)



<tensorflow.python.keras.callbacks.History at 0x12c86f5d0>