In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
%matplotlib inline
from IPython import display
import pandas as pd
from tensorflow_probability.python.distributions import Chi2



In [2]:
TRAIN_BUF = 60000
BATCH_SIZE = 64
TEST_BUF = 10000
DIMS = (28, 28, 1)
N_Z = 128
EPOCH = 50

In [3]:
(train_images, _), (test_images, _) = keras.datasets.fashion_mnist.load_data()

In [4]:
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype("float32") / 255.0
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype("float32") / 255.0

In [5]:
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)

In [6]:
def unet_convblock_down(_input, channels=16, kernel=(3, 3), activation="relu", pool_size=(2, 2), kernel_initializer="he_normal"):
    conv = keras.layers.Conv2D(channels, kernel, activation=activation, padding="same", kernel_initializer=kernel_initializer)(_input)
    conv = keras.layers.Conv2D(channels, kernel, activation=activation, padding="same", kernel_initializer=kernel_initializer)(conv)
    pool = keras.layers.MaxPooling2D(pool_size=pool_size)(conv)
    return conv, pool

In [7]:
def unet_convblock_up(last_conv, cross_conv, channels=16, kernel=(3, 3), activation="relu", pool_size=(2, 2),kernel_initializer="he_normal"):
    up_conv = keras.layers.UpSampling2D(size=(2, 2))(last_conv)
    merge = keras.layers.concatenate([up_conv, cross_conv], axis=3)
    conv = keras.layers.Conv2D(channels, kernel, activation=activation, padding="same", kernel_initializer=kernel_initializer)(merge)
    conv = keras.layers.Conv2D(channels, kernel, activation=activation, padding="same", kernel_initializer=kernel_initializer)(conv)
    return conv

In [8]:
def unet_mnist():
    inputs = keras.layers.Input(shape=(28, 28, 1))
    up_1, pool_1 = unet_convblock_down(inputs, channels=32)
    up_2, pool_2 = unet_convblock_down(pool_1, channels=64)
    conv_middle = keras.layers.Conv2D(128, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same")(pool_2)
    conv_middle = keras.layers.Conv2D(128, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same")(conv_middle)
    down_2 = unet_convblock_up(conv_middle, up_2, channels=64)
    down_1 = unet_convblock_up(down_2, up_1, channels=32)
    outputs = keras.layers.Conv2D(1, (1, 1), activation="sigmoid")(down_1)
    return inputs, outputs

In [9]:
encoder = [
    keras.layers.InputLayer(input_shape=DIMS),
    keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2), activation="relu"),
    keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2), activation="relu"),
    keras.layers.Flatten(),
    keras.layers.Dense(units=N_Z*2),
]

In [10]:
decoder = [
    keras.layers.Dense(units=7 * 7 * 64, activation="relu"),
    keras.layers.Reshape(target_shape=(7, 7, 64)),
    keras.layers.Conv2DTranspose(filters=64, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),
    keras.layers.Conv2DTranspose(filters=32, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),
    keras.layers.Conv2DTranspose(filters=1, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"),
]

In [11]:
def sigmoid(x, shift=0.0, mult=20):
    return tf.constant(1.0) / (tf.constant(1.0) + tf.exp(-tf.constant(1.0) * ((x + tf.constant(shift)) * mult)))

In [12]:
class GAIA(tf.keras.Model):
    def __init__(self, **kwargs):
        super(GAIA, self).__init__()
        self.__dict__.update(kwargs)

        self.enc = tf.keras.Sequential(self.enc)
        self.dec = tf.keras.Sequential(self.dec)

        inputs, outputs = self.unet_function()
        self.disc = tf.keras.Model(inputs=[inputs], outputs=[outputs])

    def encode(self, x):
        return self.enc(x)

    def decode(self, z):
        return self.dec(z)

    def discriminate(self, x):
        return self.disc(x)

    def regularization(self, x1, x2):
        return tf.reduce_mean(tf.square(x1 - x2))

    @tf.function
    def network_pass(self, x):
        z = self.encode(x)
        xg = self.decode(z)
        zi = self._interpolate_z(z)
        xi = self.decode(zi)
        d_xi = self.discriminate(xi)
        d_x = self.discriminate(x)
        d_xg = self.discriminate(xg)
        return z, xg, zi, xi, d_xi, d_x, d_xg

    @tf.function
    def compute_loss(self, x):
        z, xg, zi, xi, d_xi, d_x, d_xg = self.network_pass(x)

        xg_loss = self.regularization(x, xg)
        d_xg_loss = self.regularization(x, d_xg)
        d_xi_loss = self.regularization(xi, d_xi)
        d_x_loss = self.regularization(x, d_x)
        return d_xg_loss, d_xi_loss, d_x_loss, xg_loss

    @tf.function
    def compute_gradients(self, x):
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            d_xg_loss, d_xi_loss, d_x_loss, xg_loss = self.compute_loss(x)

            gen_loss = d_xg_loss + d_xi_loss
            disc_loss = d_xg_loss + d_x_loss - tf.clip_by_value(d_xi_loss, 0, d_x_loss)

        gen_gradients = gen_tape.gradient(gen_loss, self.enc.trainable_variables + self.dec.trainable_variables) 
        disc_gradients = disc_tape.gradient(disc_loss, self.disc.trainable_variables)
        return gen_gradients, disc_gradients

    @tf.function
    def apply_gradients(self, gen_gradients, disc_gradients):
        self.gen_optimizer.apply_gradients(zip(gen_gradients, self.enc.trainable_variables + self.dec.trainable_variables,))
        self.disc_optimizer.apply_gradients(zip(disc_gradients, self.disc.trainable_variables))
    
    @tf.function
    def train(self, x):
        gen_gradients, disc_gradients = self.compute_gradients(x)
        self.apply_gradients(gen_gradients, disc_gradients)

    def _interpolate_z(self, z):
        if self.chsq.df != z.shape[0]:
            self.chsq = Chi2(df=1 / z.shape[0])
        ip = self.chsq.sample((z.shape[0], z.shape[0]))
        ip = ip / tf.reduce_sum(ip, axis=0)
        zi = tf.transpose(tf.tensordot(tf.transpose(z), ip, axes=1))
        return zi

In [13]:
gen_optimizer = keras.optimizers.Adam(1e-3, beta_1=0.5)
disc_optimizer = keras.optimizers.RMSprop(1e-3)

In [14]:
model = GAIA(enc=encoder, dec=decoder, unet_function=unet_mnist, gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer,
             chsq=Chi2(df=1/BATCH_SIZE))

In [15]:
example_data = next(iter(train_dataset))

def plot_reconstruction(model, example_data, nex=5, zm=3):
    z, xg, zi, xi, d_xi, d_x, d_xg = model.network_pass(example_data)
    fig, axs = plt.subplots(ncols=6, nrows = nex, figsize=(zm*6, zm*nex))
    for axi, (dat, lab) in enumerate(zip(
            [example_data, d_x, xg, d_xg, xi, d_xi],
            ["data", "disc data","gen", "disc gen", "interp", "disc interp"])):
        for ex in range(nex):
            axs[ex, axi].matshow(dat.numpy()[ex].squeeze(), cmap = plt.cm.Greys, vmin=0, vmax=1)
            axs[ex, axi].axis('off')
        axs[0, axi].set_title(lab)     
    plt.show()

In [16]:
losses = pd.DataFrame(columns = ['d_xg_loss', 'd_xi_loss', 'd_x_loss', 'xg_loss'])

In [17]:
for epoch in range(EPOCH):
    for batch, train_x in enumerate(train_dataset):
        model.train(train_x)
    loss = []
    for batch, test_x in enumerate(test_dataset):
        loss.append(model.compute_loss(train_x))
    losses.loc[len(losses)] = np.mean(loss, axis=0)
    display.clear_output()
    print("Epoch: {}".format(epoch))
    plot_reconstruction(model, example_data)

TypeError: in converted code:

    <ipython-input-12-f5e9049f6859>:64 train  *
        gen_gradients, disc_gradients = self.compute_gradients(x)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\eager\def_function.py:568 __call__
        result = self._call(*args, **kwds)
    <ipython-input-12-f5e9049f6859>:48 compute_gradients  *
        d_xg_loss, d_xi_loss, d_x_loss, xg_loss = self.compute_loss(x)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\eager\def_function.py:568 __call__
        result = self._call(*args, **kwds)
    <ipython-input-12-f5e9049f6859>:37 compute_loss  *
        z, xg, zi, xi, d_xi, d_x, d_xg = self.network_pass(x)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\eager\def_function.py:568 __call__
        result = self._call(*args, **kwds)
    <ipython-input-12-f5e9049f6859>:28 network_pass  *
        zi = self._interpolate_z(z)
    <ipython-input-12-f5e9049f6859>:68 _interpolate_z  *
        if self.chsq.df != z.shape[0]:
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\autograph\operators\control_flow.py:918 if_stmt
        basic_symbol_names, composite_symbol_names)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\autograph\operators\control_flow.py:956 tf_if_stmt
        error_checking_orelse)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\util\deprecation.py:507 new_func
        return func(*args, **kwargs)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\ops\control_flow_ops.py:1174 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\ops\cond_v2.py:83 cond_v2
        op_return_value=pred)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\framework\func_graph.py:983 func_graph_from_py_func
        expand_composites=True)
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\util\nest.py:568 map_structure
        structure[0], [func(*x) for x in entries],
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\util\nest.py:568 <listcomp>
        structure[0], [func(*x) for x in entries],
    C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_core\python\framework\func_graph.py:943 convert
        (str(python_func), type(x)))

    TypeError: To be compatible with tf.contrib.eager.defun, Python functions must return zero or more Tensors; in compilation of <function tf_if_stmt.<locals>.error_checking_body at 0x0000011F9F23F1F8>, found return value of type <class 'tensorflow_probability.python.distributions.chi2.Chi2'>, which is not a Tensor.
