In [None]:
# NOTE this large version has several differences from the small one.
# I tried to annotate all of them.

In [None]:
# IPython shell magic: check which GPUs exist on this machine and how busy they are
# before choosing a device via CUDA_VISIBLE_DEVICES below.
!nvidia-smi

In [None]:
# Cluster-specific environment setup -- do NOT run this cell on Colab.
# Pins the process to GPU 4 and routes HTTP(S) traffic through the cluster proxy.
# It has to run before TensorFlow is imported for the GPU selection to apply.
import os

os.environ.update({
    "CUDA_VISIBLE_DEVICES": "4",
    "HTTP_PROXY": "http://proxy:3128/",
    "HTTPS_PROXY": "http://proxy:3128/",
})

In [None]:
# imports
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image as mpimage

tfkl = tf.keras.layers


from data.utils import parse_image_example
from modeling.layers import ResidualBlock

In [None]:
class SDLayer(tf.keras.layers.Layer):
    """Append one channel holding a batch-level diversity statistic.

    For every feature, the variance across the batch axis (axis 0) is computed.
    These variances are averaged into a single scalar, and that scalar is
    concatenated to `inputs` as one extra trailing channel. This is presumably
    meant to let the discriminator notice low sample diversity (cf. the
    minibatch-stddev trick from ProGAN).

    NOTE: despite the name, this uses the *variance*, not the standard deviation.
    It is kept that way so that previously trained weights behave identically.
    """

    def __init__(self, **kwargs):
        # Forward the standard Layer kwargs (name, dtype, ...) to the base class.
        super().__init__(**kwargs)

    def call(self, inputs):
        _, variances = tf.nn.moments(inputs, axes=[0])
        global_value = tf.reduce_mean(variances)
        # Build the extra channel with a single trailing channel directly instead of
        # materialising a tensor as large as `inputs` and slicing it afterwards.
        extra_channel = global_value * tf.ones_like(inputs[..., :1])
        return tf.concat([inputs, extra_channel], axis=-1)

In [None]:
# I trained this on multiple GPUs; this is a hold-over from that.
# MirroredStrategy replicates the model on every visible GPU (here restricted via
# CUDA_VISIBLE_DEVICES) and averages gradients; it also runs with a single device.
strategy = tf.distribute.MirroredStrategy()

In [None]:
# displays the number of replicas (devices) the strategy trains on in sync,
# i.e. the number of GPUs in use
strategy.num_replicas_in_sync

In [None]:
# Input pipeline. Needs the preprocessed Flickr-Faces TFRecords under data/.
# `batch_size` is the GLOBAL batch size; under MirroredStrategy, fit() splits each
# batch across the replicas.
batch_size = 256 * strategy.num_replicas_in_sync

# Parsing runs in parallel and prefetch overlaps input preparation with the training
# step, so the GPUs are not left waiting on the CPU-side pipeline.
train_data = (
    tf.data.TFRecordDataset("data/flickr_64_train.TFR")
    # buffer of 60000 records -- presumably the whole training set (TODO confirm)
    .shuffle(60000)
    .map(parse_image_example, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size, drop_remainder=True)
    # augmentation is applied to whole batches (random_flip_left_right accepts 4-D input)
    .map(tf.image.random_flip_left_right, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
test_data = (
    tf.data.TFRecordDataset("data/flickr_64_test.TFR")
    .map(parse_image_example, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Gather the entire test split into one NumPy array and preview the first 64 images.
test_images = np.concatenate(list(test_data.as_numpy_iterator()), axis=0)

plt.figure(figsize=(15, 15))
for panel, img in enumerate(test_images[:64], start=1):
    plt.subplot(8, 8, panel)
    plt.imshow(img)
    plt.axis("off")
plt.show()

In [None]:
# Same GAN training wrapper as in the small version (as far as I can tell).
class GAN(tf.keras.Model):
    """Generator/discriminator pair trained adversarially with a custom `train_step`.

    The discriminator must return a list whose first element is the real/fake
    logit. Its remaining elements are hidden feature maps, which are used for an
    additional feature-matching term in the generator loss. The loss function
    is supplied via `compile(loss=...)` and accessed as `self.compiled_loss`.

    Args:
        generator: model mapping a batch of latent codes to images; assumed to
            produce values in [-1, 1] like the real data after rescaling (TODO
            confirm for generators other than the tanh one in this notebook).
        discriminator: model returning `[logits, *hidden_features]`.
        noise_function: callable `n -> batch of n latent codes`.
        generator_optimizer: optimizer applied to the generator's variables.
        discriminator_optimizer: optimizer applied to the discriminator's variables.
        label_smoothing: target value for real images in the discriminator loss
            (one-sided label smoothing; fake targets stay 0).
    """

    def __init__(self, generator, discriminator, noise_function,
                 generator_optimizer, discriminator_optimizer,
                 label_smoothing=0.9, **kwargs):
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator

        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer

        self.noise_function = noise_function
        self.label_smoothing = label_smoothing

        # running means reported by fit()
        self.g_loss_tracker = tf.keras.metrics.Mean("generator_loss")
        self.d_loss_tracker = tf.keras.metrics.Mean("discriminator_loss")
        self.feature_loss_tracker = tf.keras.metrics.Mean("feature_matching_loss")

    def call(self, noise_input):
        """Dummy forward pass D(G(z)); calling it builds the model so load_weights() works."""
        return self.discriminator(self.generator(noise_input))

    @staticmethod
    def _dequantize(images, scale):
        """Return `images` plus uniform noise drawn from [-scale, scale)."""
        return images + tf.random.uniform(tf.shape(images), -scale, scale)

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        All model/optimizer access goes through `self`. The previous version
        referenced the notebook-level globals `generator` and `discriminator`,
        which only worked because those globals happened to exist.

        Args:
            data: batch of real images, presumably scaled to [0, 1] by the input
                pipeline; it is mapped to [-1, 1] below (TODO confirm).

        Returns:
            dict with the running means of the generator, discriminator and
            feature-matching losses.
        """
        data = 2 * data - 1

        dequantize_scale = 1 / 128  # +/- 1 pixel step when pixels are scaled to [-1, 1]
        batch_dim = tf.shape(data)[0]

        # ---- discriminator update -------------------------------------------------
        # Mixed batch: the real images followed by the same number of generated ones.
        # NOTE training=True here is doubtful, as the generator is not being trained
        # in this step. Avoiding batchnorm makes the choice irrelevant.
        generated_batch = self.generator(self.noise_function(batch_dim), training=True)
        # one-sided label smoothing: only the real targets are softened
        real_labels = self.label_smoothing * tf.ones([batch_dim, 1])
        generated_labels = tf.zeros([batch_dim, 1])

        # Adding noise makes D's task harder and "de-quantizes" the 8-bit data.
        # It is unclear whether the generated batch needs it as well.
        data = self._dequantize(data, dequantize_scale)
        generated_batch = self._dequantize(generated_batch, dequantize_scale)

        full_batch = tf.concat((data, generated_batch), axis=0)
        full_labels = tf.concat((real_labels, generated_labels), axis=0)

        d_variables = self.discriminator.trainable_variables
        with tf.GradientTape() as d_tape:
            # index [0]: D also returns its hidden feature maps (for feature matching)
            d_logits = self.discriminator(full_batch, training=True)[0]
            d_loss = self.compiled_loss(full_labels, d_logits)
        d_gradients = d_tape.gradient(d_loss, d_variables)
        self.discriminator_optimizer.apply_gradients(zip(d_gradients, d_variables))

        # ---- generator update on a fresh generated batch (2x the real batch) ------
        g_variables = self.generator.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as g_tape:
            # Only watch G's variables; automatically watching D's too would be wasteful.
            for variable in g_variables:
                g_tape.watch(variable)

            gen_only_batch = self.generator(self.noise_function(2 * batch_dim), training=True)
            gen_only_batch = self._dequantize(gen_only_batch, dequantize_scale)

            d_output_fake = self.discriminator(gen_only_batch, training=True)
            # D was just updated, so its features for the real data must be recomputed
            d_output_real = self.discriminator(data, training=True)
            # no label smoothing for the generator's targets
            g_loss = self.compiled_loss(tf.ones([2 * batch_dim, 1]), d_output_fake[0])

            # Feature matching: squared distance between batch-mean hidden features.
            feature_match_loss = 0
            for fake_feature, real_feature in zip(d_output_fake[1:], d_output_real[1:]):
                feature_difference = (tf.reduce_mean(fake_feature, axis=0)
                                      - tf.reduce_mean(real_feature, axis=0))
                feature_match_loss += tf.reduce_sum(feature_difference ** 2)

            g_loss_full = g_loss + feature_match_loss
        g_gradients = g_tape.gradient(g_loss_full, g_variables)
        self.generator_optimizer.apply_gradients(zip(g_gradients, g_variables))

        self.g_loss_tracker.update_state(g_loss)
        self.d_loss_tracker.update_state(d_loss)
        self.feature_loss_tracker.update_state(feature_match_loss)

        return {"generator_loss": self.g_loss_tracker.result(),
                "discriminator_loss": self.d_loss_tracker.result(),
                "feature_matching_loss": self.feature_loss_tracker.result()}

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each epoch.
        return [self.g_loss_tracker, self.d_loss_tracker, self.feature_loss_tracker]

In [None]:
def residual_stack(inputs, filters, strides, blocks_per_level, mode, name, normalization):
    """Apply consecutive levels of ResidualBlocks and collect every level's output.

    Level i consists of `blocks_per_level` blocks with `filters[i]` filters. Only
    the last block of a level uses `strides[i]`; all others use stride 1. Blocks
    are named "<name>_<level>_<block>" with 1-based indices.

    Returns:
        list with the output tensor of each level, in order.
    """
    level_outputs = []
    hidden = inputs
    for level_number, (n_filters, stride) in enumerate(zip(filters, strides), start=1):
        for block_number in range(1, blocks_per_level + 1):
            block_stride = stride if block_number == blocks_per_level else 1
            block_name = f"{name}_{level_number}_{block_number}"
            block = ResidualBlock(n_filters,
                                  mode,
                                  strides=block_stride,
                                  name=block_name,
                                  normalization=normalization)
            hidden = block(hidden)
        level_outputs.append(hidden)
    return level_outputs


# This generator variant "passes through" the early noise (here called `code`)
# to all later blocks, so that every block has direct access to the original noise.
# Whether this actually helps is untested.
def residual_stackD(inputs, code, filters, strides, blocks_per_level, mode, name, normalization):
    """Residual stack that concatenates `code` to the input of (almost) every block.

    Level i has `blocks_per_level` ResidualBlocks with `filters[i]` filters. Only
    the last block of a level uses `strides[i]`; all others use stride 1. Blocks
    are named "<name>_<level>_<block>" with 1-based indices.

    Before each level, `code` is bilinearly upsampled by the stride of the
    previous level so that it matches the resolution of `outputs`. This assumes
    that an "upconv" block with stride s upsamples by s (TODO confirm in
    modeling.layers). The previous version always upsampled by 2, which broke
    with a shape mismatch in the concat when an earlier level used stride 1.

    The very first block receives `inputs` unchanged. In this notebook `inputs`
    and `code` are the same tensor, so concatenating there would only duplicate it.

    Returns:
        the output tensor of the last block.
    """
    outputs = inputs
    previous_stride = 1  # nothing to upsample before the first level
    for level_ind, (level_filters, level_stride) in enumerate(zip(filters, strides)):
        if previous_stride != 1:
            code = tfkl.UpSampling2D(size=previous_stride, interpolation="bilinear")(code)

        for block_ind in range(blocks_per_level):
            if level_ind > 0 or block_ind > 0:
                combined = tf.concat((outputs, code), axis=-1)
            else:
                combined = outputs
            outputs = ResidualBlock(level_filters,
                                    mode,
                                    strides=level_stride if block_ind == (blocks_per_level - 1) else 1,
                                    name="_".join([name, str(level_ind + 1), str(block_ind + 1)]),
                                    normalization=normalization)(combined)
        previous_stride = level_stride

    return outputs

In [None]:
# build the model
# Discriminator: residual "conv" stack on 64x64x3 inputs -> Flatten -> SDLayer
# (appends the batch-mean feature variance) -> Dense(1) logit. The model returns
# [logit] + the output of every residual level; the level outputs feed the
# feature-matching loss in GAN.train_step.
# Generator: 512-d code -> Dense -> 8x8x64 map -> residual "upconv" stack that
# re-injects the 8x8 map at every block -> 1x1 Conv2D with tanh (3 channels).

# reset Keras' global state (e.g. auto-generated layer name counters) so
# re-running this cell starts from a clean slate
tf.keras.backend.clear_session()


# GroupNormalization(groups=32) requires each normalized channel count to be divisible by 32
normalization = lambda **kwargs: tfkl.GroupNormalization(groups=32, **kwargs)
#normalization = tfkl.BatchNormalization

# Large configuration: 4 residual blocks per level, widths up to 768.
# The discriminator uses filters/strides as given. The generator uses the first
# four widths in reverse order (512, 256, 128, 64) together with strides[1:].
blocks_per_level = 4
filters = [64, 128, 256, 512, 768]
strides = [2, 2, 2, 2, 1]

# variables must be created inside the strategy scope so they are mirrored across replicas
with strategy.scope():
    discriminator_input = tf.keras.Input((64, 64, 3))
    discriminator_outputs = residual_stack(discriminator_input, filters, strides, blocks_per_level, "conv", "discriminator",
                                         normalization)
    discriminator_final = tfkl.Flatten()(discriminator_outputs[-1])
    discriminator_final = SDLayer()(discriminator_final)
    # no activation: the loss is configured with from_logits=True
    discriminator_final = tfkl.Dense(1)(discriminator_final)

    # first output is the logit, the rest are per-level feature maps for feature matching
    discriminator = tf.keras.Model(discriminator_input, [discriminator_final] + discriminator_outputs, name="discriminator")


    # NOTE this time I start with a 1D code vector and reshape it into the image shape.
    # I found this to work a bit better, maybe?
    # I also start the image already at 8x8 pixels instead of 4x4.
    # I really have no clue what is optimal here...
    code_shape = (512,)
    generator_input = tf.keras.Input(code_shape)
    generator_front = tfkl.Dense(8*8*64)(generator_input)
    generator_front = tfkl.Reshape((8, 8, 64))(generator_front)
    
    # the 8x8 map serves both as the stack input and as the code re-injected at every block
    generator_output = residual_stackD(generator_front, generator_front, reversed(filters[:-1]), strides[1:], blocks_per_level, "upconv", "generator",
                                     normalization)
    # tanh keeps outputs in [-1, 1], the same range the real data is mapped to in GAN.train_step
    generator_final = tfkl.Conv2D(3, 1, activation=tf.nn.tanh)(generator_output)

    generator = tf.keras.Model(generator_input, generator_final, name="generator")


# Not happy with the discriminator having 2x as many parameters as the generator;
# this may lead to suboptimal performance.
discriminator.summary()
generator.summary()

In [None]:
def noise_fn(n_samples):
    """Draw `n_samples` binary latent codes with the notebook-level `code_shape`.

    Each entry is independently 0 or 1 with probability 0.5 (Bernoulli(0.5)
    noise), returned as float32 of shape (n_samples, *code_shape).
    """
    sample_shape = (n_samples,) + code_shape
    bits = tf.random.uniform(sample_shape, maxval=2, dtype=tf.int32)
    return tf.cast(bits, tf.float32)

# One-sided label smoothing target for real images. It is passed to the GAN below;
# previously it was defined but never used, so editing it had no effect.
label_smoothing = 0.9

with strategy.scope():
    # the discriminator ends in a plain Dense(1), i.e. it outputs logits
    loss_fn = tf.losses.BinaryCrossentropy(from_logits=True)

    n_steps = 100000  # total optimizer steps; also the length of the cosine schedule
    n_data = 60000  # assumed number of training images -- TODO confirm against the TFRecord
    # Training uses drop_remainder=True, so only full batches count as steps.
    # max(1, ...) avoids a ZeroDivisionError if batch_size ever exceeds n_data.
    steps_per_epoch = max(1, n_data // batch_size)
    n_epochs = n_steps // steps_per_epoch
    lr = tf.optimizers.schedules.CosineDecay(0.0002, n_steps)

    gen_opt = tf.optimizers.Adam(lr, beta_1=0.5)
    disc_opt = tf.optimizers.Adam(lr, beta_1=0.5)

    gan = GAN(generator, discriminator, noise_fn, gen_opt, disc_opt,
              label_smoothing=label_smoothing)

    gan.compile(jit_compile=True, loss=loss_fn)

In [None]:
class ImageGenCallback(tf.keras.callbacks.Callback):
    """Show an 8x8 grid of 64 generated images every `frequency` epochs.

    Meant for the `GAN` model of this notebook. It reads `self.model.generator`
    and samples codes with `self.model.noise_function`, i.e. the same latent
    distribution the GAN is trained with. It no longer depends on the global
    `noise_fn`.
    """

    def __init__(self, frequency, **kwargs):
        super().__init__(**kwargs)
        # plot on epochs 0, frequency, 2 * frequency, ... (Keras epochs are 0-based)
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.frequency:
            return

        noise = self.model.noise_function(64)
        # the generator in this notebook ends in tanh; map [-1, 1] -> [0, 1] for imshow
        generated_batch = 0.5 * (self.model.generator(noise) + 1)

        plt.figure(figsize=(15, 15))
        for ind, image in enumerate(generated_batch):
            plt.subplot(8, 8, ind + 1)
            plt.imshow(image)
            plt.axis("off")
        plt.suptitle("Random generations")
        plt.show()
            
            
# Local import: the cell that imports os at the top is cluster-only and may be skipped.
import os

# Train from scratch (do_train = True) or restore previously saved weights.
do_train = True
weights_path = "weights/weights_assignment05_large.hdf5"

if do_train:
    image_gen_callback = ImageGenCallback(5)

    history = gan.fit(train_data, epochs=n_epochs, callbacks=[image_gen_callback])
    # Make sure the target directory exists, so a long training run is not lost at save time.
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)
    gan.save_weights(weights_path)

else:
    # loading weights doesn't work if the model hasn't been called (no variables yet)
    gan(noise_fn(1))
    gan.load_weights(weights_path)

In [None]:
# Sample 64 binary codes and show the generator's images as an 8x8 grid.
random_codes = noise_fn(64)
# the generator ends in tanh, so map [-1, 1] -> [0, 1] for imshow
generated = 0.5 * (generator(random_codes).numpy() + 1)

plt.figure(figsize=(15, 15))
for panel, sample in enumerate(generated, start=1):
    plt.subplot(8, 8, panel)
    plt.imshow(sample)
    plt.axis("off")
plt.show()

In [None]:
# Next: "inference" for GANs via optimization, i.e. searching for a code whose
# generation matches a given test image. This works terribly here because the
# generator was trained on binary noise, which gradient descent cannot optimize.
# First, show the 32 target images.
some_imgs = test_images[:32]

plt.figure(figsize=(15, 15))
for panel, target in enumerate(some_imgs, start=1):
    plt.subplot(8, 4, panel)
    # cmap/vmin/vmax only matter for single-channel images; matplotlib ignores them for RGB
    plt.imshow(target, cmap="Greys", vmin=0, vmax=1)
    plt.axis("off")
plt.show()

In [None]:
@tf.function(jit_compile=True)
def opt_noise(noise, targets):
    """Take one optimizer step on `noise` so that generator(noise) moves towards `targets`.

    Uses the notebook-level `generator` and `noise_opt` (the latter is defined
    further down, before the first call). `noise` must be a tf.Variable: it is
    updated in place and then clipped to [0, 1].

    Returns:
        mean squared error between generator(noise) and `targets`, measured
        before the update.
    """
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(noise)
        reconstruction = generator(noise)
        residual = reconstruction - targets
        gen_error = tf.reduce_mean(residual**2)
    gradient = tape.gradient(gen_error, noise)
    noise_opt.apply_gradients([(gradient, noise)])
    # The codes should really be binary; clipping at least keeps them in the RANGE [0, 1].
    noise.assign(tf.clip_by_value(noise, 0, 1))
    return gen_error

    
# Optimize 32 random codes so that their generations match the 32 target images.
candidate_noise = tf.Variable(noise_fn(32))
# Own name, so the training step count `n_steps` from the setup cell is not clobbered.
n_opt_steps = 25001
report_every = 2500
noise_opt = tf.optimizers.Adam(tf.optimizers.schedules.CosineDecay(0.1, n_opt_steps))
# Targets in the generator's [-1, 1] range. They are loop-invariant, so compute and
# convert them once instead of on every one of the optimization steps.
scaled_targets = tf.constant(2 * some_imgs - 1)

for step in range(n_opt_steps):
    report = step % report_every == 0
    if report:
        print(step)
        current_state = (generator(candidate_noise).numpy() + 1) / 2.

        # each panel: current reconstruction (left) next to its target (right)
        plt.figure(figsize=(15, 15))
        for ind, image in enumerate(current_state):
            plt.subplot(8, 4, ind + 1)
            plt.imshow(np.concatenate([image, some_imgs[ind]], axis=1), cmap="Greys", vmin=0, vmax=1)
            plt.axis("off")
        plt.show()

    gen_error = opt_noise(candidate_noise, scaled_targets)

    if report:
        print("error", gen_error.numpy())

In [None]:
# Latent-bit probe: draw 32 random binary codes and duplicate them. Then set code
# dimension `dim` to 0 in the first copy and to 1 in the second copy. Images i and
# i + 32 therefore differ only in that one bit, which in theory shows what the
# dimension controls. In practice the results weren't informative.
n_probe = 32
dim = 0

random_codes = np.tile(noise_fn(n_probe), [2, 1])
random_codes[:n_probe, dim] = 0
random_codes[n_probe:, dim] = 1

generated = (generator(random_codes).numpy() + 1) * 0.5

In [None]:
# Rows 1-4: codes with dimension `dim` set to 0; rows 5-8: the same codes with it set to 1.
plt.figure(figsize=(15, 15))
for panel, sample in enumerate(generated, start=1):
    plt.subplot(8, 8, panel)
    plt.imshow(sample)
    plt.axis("off")
plt.show()