In [None]:
# IPython shell escape: run `nvidia-smi` to show the state of the GPUs on this machine.
!nvidia-smi

In [None]:
# you know... don't run this on colab
# (restricts CUDA to device 0 and routes HTTP/HTTPS traffic through a fixed proxy host)
import os

os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "HTTP_PROXY": "http://proxy:3128/",
    "HTTPS_PROXY": "http://proxy:3128/",
})

In [None]:
# imports
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image as mpimage

tfkl = tf.keras.layers


from data.utils import parse_image_example
from modeling.layers import ResidualBlock

In [None]:
# This implements "Minibatch discrimination" from the "Improved Techniques for Training GANs" paper.
# It is not used anymore in this notebook, but is kept for reference.
class MBD(tf.keras.layers.Layer):
    """Minibatch discrimination layer.

    Every example is projected to a ``size_p x size_q`` matrix. For each of the
    ``size_p`` rows, the layer sums ``exp(-L1 distance)`` between that row and
    the corresponding row of every example in the batch (including the example
    itself) and appends the resulting ``size_p`` values to the input features.

    Args:
        size_p: number of rows of the per-example matrix (= number of appended features).
        size_q: number of columns of the per-example matrix.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, size_p, size_q, **kwargs):
        super().__init__(**kwargs)
        self.p = size_p
        self.q = size_q

    def build(self, inp_shape):
        # projection tensor t of shape d x p x q, where d is the size of the last input axis
        self.t = self.add_weight(shape=(inp_shape[-1], self.p, self.q),
                                 initializer="glorot_uniform",
                                 trainable=True)

    def call(self, inp):
        # inp is b x d, t is d x p x q.
        # Contract over the shared d axis to get a batch of matrices (b x p x q).
        # tensordot computes the same sum-of-products as broadcasting inp over p and q
        # and reducing over d, but never materializes the b x d x p x q intermediate.
        per_example = tf.tensordot(inp, self.t, axes=1)
        # broadcast to a b x b x p x q tensor of all (absolute) matrix differences,
        # then sum over the columns (q) to get L1 distances between rows: b x b x p
        pairwise_abs_diff = tf.abs(per_example[tf.newaxis] - per_example[:, tf.newaxis])
        similarity = tf.exp(-tf.reduce_sum(pairwise_abs_diff, axis=-1))
        # finally sum over all examples to arrive at b x p, and append to the input
        return tf.concat([inp, tf.reduce_sum(similarity, axis=1)], axis=1)
    

# This is a much simpler alternative to the above, from Progressive Growing.
# The gradient of sqrt() is infinite at 0, so a small epsilon is added to the
# variances before taking the root; otherwise a feature that is constant over
# the batch produces NaN gradients.
class SDLayer(tf.keras.layers.Layer):
    """Minibatch standard-deviation layer.

    Computes the standard deviation of every feature over the batch axis,
    averages these into one scalar, and appends that scalar to each example
    as one extra entry along the last axis.

    Args:
        epsilon: small constant added to the variances before the square root
            to keep the gradient finite when a variance is exactly zero.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, inputs):
        # per-feature variance over the batch axis
        _, variances = tf.nn.moments(inputs, axes=[0])
        global_value = tf.reduce_mean(tf.math.sqrt(variances + self.epsilon))
        # one extra slice along the last axis, filled with the global statistic
        extra_feature = global_value * tf.ones_like(inputs[..., :1])
        return tf.concat([inputs, extra_feature], axis=-1)

In [None]:
batch_size = 512


def _pad_and_scale(images):
    """Append a channel axis, zero-pad both spatial axes by 2 on each side, cast to float32 and divide by 255."""
    with_channel = images[..., None]
    padded = np.pad(with_channel, ((0, 0), (2, 2), (2, 2), (0, 0)))
    return padded.astype(np.float32) / 255.


# MNIST is BOOORING so let's at least use fashion
(train_images, _), (test_images, _) = tf.keras.datasets.fashion_mnist.load_data()
train_images = _pad_and_scale(train_images)
test_images = _pad_and_scale(test_images)

# only the training set is shuffled; both sets are served in batches of `batch_size`
train_data = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(60000)
    .batch(batch_size)
)
test_data = (
    tf.data.Dataset.from_tensor_slices(test_images)
    .batch(batch_size)
)

In [None]:
# Stitch the batches of the (unshuffled) test dataset back into one array.
test_images = np.concatenate(list(test_data), axis=0)

# Show the first 64 test images on an 8 x 8 grid.
plt.figure(figsize=(15, 15))
for position, picture in enumerate(test_images[:64], start=1):
    plt.subplot(8, 8, position)
    plt.imshow(picture, vmin=0, vmax=1, cmap="Greys")
    plt.axis("off")
plt.show()

In [None]:
class GAN(tf.keras.Model):
    """Keras model bundling a generator and a discriminator with a custom GAN training step.

    Each step first updates the discriminator on a mixed real/generated batch
    (with one-sided label smoothing), then updates the generator on a fresh
    generated batch using the adversarial loss plus a feature-matching loss.

    Args:
        generator: model mapping a noise batch to a batch of generated data.
        discriminator: model returning a sequence whose first element is the
            real/fake output and whose remaining elements are hidden features
            (used for feature matching).
        loss_function: callable ``(labels, predictions) -> loss``.
        noise_function: callable ``(n_samples) -> noise batch`` for the generator.
        generator_optimizer: optimizer applied to the generator variables.
        discriminator_optimizer: optimizer applied to the discriminator variables.
        label_smoothing: target value used for real examples in the discriminator update.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, generator, discriminator,
                 loss_function, noise_function,
                 generator_optimizer, discriminator_optimizer,
                 label_smoothing=0.9, **kwargs):
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator

        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer

        self.loss_function = loss_function
        self.noise_function = noise_function
        self.label_smoothing = label_smoothing

        self.g_loss_tracker = tf.keras.metrics.Mean("generator_loss")
        self.d_loss_tracker = tf.keras.metrics.Mean("discriminator_loss")
        self.feature_loss_tracker = tf.keras.metrics.Mean("feature_matching_loss")

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Only the sub-models, loss and noise function handed to the constructor
        are used here, so the step does not depend on notebook globals.

        Args:
            data: batch of real examples; rescaled via ``2*data - 1``
                (assumes values in [0, 1] -- TODO confirm against the input pipeline).

        Returns:
            dict with the running means of the generator, discriminator and
            feature-matching losses.
        """
        data = 2*data - 1  # scale data to [-1, 1]

        dequantize_scale = 1/128  # +/- 1 pixel if scaled between -1 and 1
        batch_dim = tf.shape(data)[0]

        # prepare mixed batch for discriminator training.
        # NOTE training=True here is doubtful as we are not actually training the generator here.
        # I would recommend avoiding batchnorm, then this question doesn't matter
        generated_batch = self.generator(self.noise_function(batch_dim), training=True)
        # one-sided label smoothing applied here
        real_labels = self.label_smoothing*tf.ones([batch_dim, 1])
        generated_labels = tf.zeros([batch_dim, 1])

        # adding noise makes it more difficult and "de-quantizes" the data
        data = data + tf.random.uniform(tf.shape(data), -dequantize_scale, dequantize_scale)
        # tbh I'm not sure if it's necessary to apply it to the generated data as well
        generated_batch = generated_batch + tf.random.uniform(
            tf.shape(generated_batch), -dequantize_scale, dequantize_scale)

        full_batch = tf.concat((data, generated_batch), axis=0)
        full_labels = tf.concat((real_labels, generated_labels), axis=0)

        d_variables = self.discriminator.trainable_variables
        with tf.GradientTape() as d_tape:
            # index [0] into D output since it returns hidden layers as well (for feature matching)
            d_loss = self.loss_function(full_labels, self.discriminator(full_batch, training=True)[0])
        d_gradients = d_tape.gradient(d_loss, d_variables)
        self.discriminator_optimizer.apply_gradients(zip(d_gradients, d_variables))

        # fresh generated batch for generator training
        g_variables = self.generator.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as g_tape:
            # tape would automatically watch D variables -> wasteful
            for variable in g_variables:
                g_tape.watch(variable)

            gen_only_batch = self.generator(self.noise_function(2*batch_dim), training=True)
            gen_only_batch = gen_only_batch + tf.random.uniform(
                tf.shape(gen_only_batch), -dequantize_scale, dequantize_scale)

            d_output_fake = self.discriminator(gen_only_batch, training=True)
            # since we updated D, we need to re-compute the output for the real data for feature matching
            d_output_real = self.discriminator(data, training=True)
            # no label smoothing for generator training
            g_loss = self.loss_function(tf.ones([2*batch_dim, 1]), d_output_fake[0])

            # feature matching: squared difference between the batch-mean hidden
            # features of generated and real data, summed over all returned layers
            feature_match_loss = 0
            for fake_feature, real_feature in zip(d_output_fake[1:], d_output_real[1:]):
                feature_difference = tf.reduce_mean(fake_feature, axis=0) - tf.reduce_mean(real_feature, axis=0)
                feature_match_loss += tf.reduce_sum(feature_difference**2)

            g_loss_full = g_loss + feature_match_loss
        g_gradients = g_tape.gradient(g_loss_full, g_variables)
        self.generator_optimizer.apply_gradients(zip(g_gradients, g_variables))

        self.g_loss_tracker.update_state(g_loss)
        self.d_loss_tracker.update_state(d_loss)
        self.feature_loss_tracker.update_state(feature_match_loss)

        return {"generator_loss": self.g_loss_tracker.result(),
                "discriminator_loss": self.d_loss_tracker.result(),
                "feature_matching_loss": self.feature_loss_tracker.result()}

    @property
    def metrics(self):
        """The three loss trackers updated in ``train_step``."""
        return [self.g_loss_tracker, self.d_loss_tracker, self.feature_loss_tracker]

In [None]:
def residual_stack(inputs, filters, strides, blocks_per_level, mode, name, normalization):
    """Chain ``ResidualBlock`` layers level by level and collect each level's output.

    Args:
        inputs: tensor fed into the first block.
        filters: per-level filter counts (iterated once, in lockstep with ``strides``).
        strides: per-level stride, applied only by the last block of the level;
            all earlier blocks of a level use stride 1.
        blocks_per_level: number of blocks stacked in every level.
        mode: forwarded unchanged to ``ResidualBlock``
            (this notebook passes "conv" and "transpose").
        name: prefix of the block names; blocks are called
            ``<name>_<level>_<block>`` with 1-based indices.
        normalization: forwarded unchanged to ``ResidualBlock``.

    Returns:
        List holding the output of the final block of every level, in order.
    """
    level_outputs = []
    current = inputs
    for level_number, (n_filters, stride) in enumerate(zip(filters, strides), start=1):
        for block_number in range(1, blocks_per_level + 1):
            # only the closing block of a level changes the resolution
            block_stride = stride if block_number == blocks_per_level else 1
            block_name = "_".join((name, str(level_number), str(block_number)))
            block = ResidualBlock(n_filters,
                                  mode,
                                  strides=block_stride,
                                  name=block_name,
                                  normalization=normalization)
            current = block(current)
        level_outputs.append(current)

    return level_outputs

In [None]:
# build the model

tf.keras.backend.clear_session()


def normalization(**kwargs):
    """Normalization factory passed to every residual block: GroupNormalization with 16 groups."""
    return tfkl.GroupNormalization(groups=16, **kwargs)
# alternative:
#normalization = tfkl.BatchNormalization


blocks_per_level = 2
filters = [32, 64, 128, 256]
strides = [2, 2, 2, 1]


# Discriminator: residual stack over a 32x32x1 input, then one linear output unit.
# The model returns that output first, followed by the output of every level
# of the stack (the hidden features used for feature matching).
discriminator_input = tf.keras.Input((32, 32, 1))
discriminator_outputs = residual_stack(
    discriminator_input,
    filters,
    strides,
    blocks_per_level,
    mode="conv",
    name="discriminator",
    normalization=normalization,
)
discriminator_final = tfkl.Flatten()(discriminator_outputs[-1])
discriminator_final = tfkl.Dense(1)(discriminator_final)

discriminator = tf.keras.Model(
    inputs=discriminator_input,
    outputs=[discriminator_final, *discriminator_outputs],
    name="discriminator",
)


# note: the noise is generated directly in "image shape" a 4x4 image with 16 channels.
# but all "pixels" are taken from the noise function.
code_shape = (4, 4, 16)
generator_input = tf.keras.Input(code_shape)
# Generator: same stack with the filter counts in reverse order; only the last level's output is kept.
generator_output = residual_stack(
    generator_input,
    filters[::-1],
    strides,
    blocks_per_level,
    mode="transpose",
    name="generator",
    normalization=normalization,
)[-1]
# tanh output activation to scale to [-1, 1]
generator_final = tfkl.Conv2D(1, 1, activation=tf.nn.tanh)(generator_output)

generator = tf.keras.Model(inputs=generator_input, outputs=generator_final, name="generator")


for built_model in (discriminator, generator):
    built_model.summary()



In [None]:
def noise_fn(n_samples):
    """Draw ``n_samples`` standard-normal noise codes, each of shape ``code_shape``."""
    batch_shape = (n_samples, *code_shape)
    return tf.random.normal(batch_shape)
    # uniform alternative, should perform similarly:
    #return tf.random.uniform([n_samples, np.prod(code_shape)], -1., 1.)

label_smoothing = 0.9

# the discriminator outputs raw logits (its last layer is a Dense(1) without activation)
loss_fn = tf.losses.BinaryCrossentropy(from_logits=True)

n_steps = 50000
n_data = len(train_images)
# train_data keeps the final partial batch, so an epoch has ceil(n_data / batch_size) steps.
# Flooring here would schedule too many epochs and run training past the end of the
# cosine decay (n_steps), i.e. with the learning rate stuck at its final value.
steps_per_epoch = -(-n_data // batch_size)
n_epochs = n_steps // steps_per_epoch
lr = tf.optimizers.schedules.CosineDecay(0.001, n_steps)

# note the unusual beta, default is 0.9
gen_opt = tf.optimizers.Adam(lr, beta_1=0.5)
disc_opt = tf.optimizers.Adam(lr, beta_1=0.5)

# pass label_smoothing explicitly so the setting above actually takes effect
gan = GAN(generator, discriminator, loss_fn, noise_fn, gen_opt, disc_opt,
          label_smoothing=label_smoothing)

gan.compile(jit_compile=True)

In [None]:
class ImageGenCallback(tf.keras.callbacks.Callback):
    """Plot an 8 x 8 grid of freshly generated samples every ``frequency`` epochs.

    The attached model must expose ``generator`` and ``noise_function``; the
    samples are drawn with the model's own noise function so they always come
    from the same noise distribution that is used for training.

    Args:
        frequency: plot whenever the (0-based) epoch index is a multiple of this value.
        **kwargs: forwarded to ``tf.keras.callbacks.Callback``.
    """

    def __init__(self, frequency, **kwargs):
        super().__init__(**kwargs)
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.frequency:
            return

        noise = self.model.noise_function(64)
        # map the generator output from [-1, 1] (tanh) back to [0, 1] for display
        generated_batch = 0.5 * (self.model.generator(noise) + 1)

        plt.figure(figsize=(15, 15))
        for ind, image in enumerate(generated_batch):
            plt.subplot(8, 8, ind+1)
            plt.imshow(image, vmin=0, vmax=1, cmap="Greys")
            plt.axis("off")
        plt.suptitle("Random generations")
        plt.show()
            
            
# Set to False to skip training and load previously saved weights instead.
do_train = True
weights_path = "weights/weights_assignment05_small.hdf5"

if do_train:
    # create the output directory up front, so a long training run
    # cannot fail at the very end because the directory is missing
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)

    # plot generated samples every 10 epochs
    image_gen_callback = ImageGenCallback(10)

    history = gan.fit(train_data, epochs=n_epochs, callbacks=[image_gen_callback])
    gan.save_weights(weights_path)

else:
    gan.load_weights(weights_path)

In [None]:
# Generate 64 samples from fresh noise and map the tanh output from [-1, 1] to [0, 1].
random_codes = noise_fn(64)
generator_output_raw = generator(random_codes).numpy()
generated = (generator_output_raw + 1) * 0.5

plt.figure(figsize=(15, 15))
for position, sample in enumerate(generated, start=1):
    plt.subplot(8, 8, position)
    plt.imshow(sample, vmin=0, vmax=1, cmap="Greys")
    plt.axis("off")
plt.show()

In [None]:
# potentially better, but less variable, sampling with truncated noise
# NOTE: clip_by_value saturates out-of-range entries at +/- lim; it does not resample them.
lim = 1.  # as you make this smaller, you will get less variety, but possibly better outputs (less noisy, artifact-y)
random_codes = tf.clip_by_value(noise_fn(64), -lim, lim)
generator_output_raw = generator(random_codes).numpy()
generated = (generator_output_raw + 1) * 0.5

plt.figure(figsize=(15, 15))
for position, sample in enumerate(generated, start=1):
    plt.subplot(8, 8, position)
    plt.imshow(sample, vmin=0, vmax=1, cmap="Greys")
    plt.axis("off")
plt.show()

In [None]:
# TODO compare FID scores of truncated vs not

In [None]:
# following: "inference" for GANs via optimization
# the first 32 test images serve as reconstruction targets
some_imgs = test_images[:32]

plt.figure(figsize=(15, 15))
for position, target in enumerate(some_imgs, start=1):
    plt.subplot(8, 4, position)
    plt.imshow(target, cmap="Greys", vmin=0, vmax=1)
    plt.axis("off")
plt.show()

In [None]:
# we start with random noise, and optimize the noise such that the generated images become closer to the target ones.
# afterwards, we could take the optimized noise and claim those as "latent variables" for the generated/target images.
# problem: the optimized noise may not be anything like samples from our noise distribution (here, standard normal)!

@tf.function(jit_compile=True)
def opt_noise(noise, targets):
    """Take one optimizer step on ``noise`` towards reproducing ``targets``.

    Uses the notebook-level ``generator`` and ``noise_opt``. Only ``noise`` is
    watched by the tape, so no generator weights are updated.

    Args:
        noise: variable holding the current noise codes; updated in place by ``noise_opt``.
        targets: images the generator output is compared against.

    Returns:
        Mean squared error between the generated images and ``targets``,
        computed before the update is applied.
    """
    with tf.GradientTape(watch_accessed_variables=False) as noise_tape:
        noise_tape.watch(noise)
        reconstruction = generator(noise)
        residual = reconstruction - targets
        mse = tf.reduce_mean(residual**2)
    grad_wrt_noise = noise_tape.gradient(mse, noise)
    noise_opt.apply_gradients(zip([grad_wrt_noise], [noise]))

    return mse

    
candidate_noise = tf.Variable(noise_fn(32))
noise_opt = tf.optimizers.Adam(0.1)
# NOTE: this rebinds the notebook-level n_steps that the training configuration cell also sets.
n_steps = 25001
report_every = 2500  # plot and print progress every this many steps

# The targets never change, so rescale them to the generator's [-1, 1] range once
# instead of recomputing (and re-converting) them on every optimization step.
noise_targets = tf.constant(2*some_imgs - 1)

for step in range(n_steps):
    report_now = step % report_every == 0

    if report_now:
        print(step)
        current_state = (generator(candidate_noise).numpy() + 1) / 2.

        # left half of each panel: current reconstruction; right half: target image
        plt.figure(figsize=(15, 15))
        for ind, image in enumerate(current_state):
            plt.subplot(8, 4, ind+1)
            plt.imshow(np.concatenate([image, some_imgs[ind]], axis=1), cmap="Greys", vmin=0, vmax=1)
            plt.axis("off")
        plt.show()

    gen_error = opt_noise(candidate_noise, noise_targets)

    if report_now:
        print("error", gen_error.numpy())