In [None]:
# show GPU status via the nvidia-smi shell command (IPython `!` syntax); informational only
!nvidia-smi  # check GPU usage -- can ignore this

In [None]:
# ignore this cell -- stuff for our server:
# make only CUDA device 0 visible and route HTTP(S) traffic through the local proxy.
# this runs before tensorflow is imported.
import os

os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    'HTTP_PROXY': 'http://proxy:3128/',
    'HTTPS_PROXY': 'http://proxy:3128/',
})

In [None]:
# imports
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image as mpimage

from data.utils import parse_image_example
from modeling.layers import ConvNormAct, ResidualBlock

tfkl = tf.keras.layers

In [None]:
batch_size = 128

# parse records in parallel and prefetch batches so input preparation overlaps with training.
# the parallel map keeps element order (tf.data is deterministic by default), so test_data
# stays aligned with arrays built from it later in the notebook.
# the shuffle buffer of 60000 is presumably >= the training set size (full shuffle) -- TODO confirm.
train_data = (tf.data.TFRecordDataset("data/flickr_64_train.TFR")
              .shuffle(60000)
              .map(parse_image_example, num_parallel_calls=tf.data.AUTOTUNE)
              .batch(batch_size)
              .prefetch(tf.data.AUTOTUNE))
test_data = (tf.data.TFRecordDataset("data/flickr_64_test.TFR")
             .map(parse_image_example, num_parallel_calls=tf.data.AUTOTUNE)
             .batch(batch_size)
             .prefetch(tf.data.AUTOTUNE))

In [None]:
# stack all test batches into a single numpy array (used for plots and comparisons below)
test_images = np.concatenate(list(test_data.as_numpy_iterator()), axis=0)

# show the first 64 test images in an 8x8 grid
plt.figure(figsize=(15,15))
for position, image in enumerate(test_images[:64], start=1):
    plt.subplot(8, 8, position)
    plt.imshow(image)
    plt.axis("off")
plt.show()

In [None]:
# sanity check of the pixel scale: the bernoulli losses use the images as labels,
# so they are presumably expected to lie in [0, 1]
np.max(test_images)

In [None]:
# a few options for loss functions.
# note: all of these losses return a batch_size vector, i.e. no averaging over the batch dimension at this point.
# also note that we generally SUM over dimensions of the image!

def squared_loss(y_true, y_pred):
    """Sum of squared errors per example.

    Used as the "gaussian_fixed_sigma" option: a Gaussian NLL with fixed sigma,
    up to scaling and additive constants. Both inputs are flattened to
    (batch, -1); one value is returned per example (no batch averaging).
    """
    n_examples = tf.shape(y_true)[0]
    flat_true = tf.reshape(y_true, [n_examples, -1])
    flat_pred = tf.reshape(y_pred, [n_examples, -1])
    residual = flat_true - flat_pred
    return tf.reduce_sum(residual**2, axis=-1)


def logloss(y_true, y_pred, epsilon=0.):
    """Gaussian NLL with sigma chosen optimally PER IMAGE (up to additive constants).

    With sigma^2 = ||r||^2 / D (r = residual, D = values per example) the NLL
    reduces to D * log(||r||) + const. Returns D * log(||r|| + epsilon) per example;
    epsilon > 0 avoids log(0) for a perfect reconstruction.
    """
    n_examples = tf.shape(y_true)[0]
    flat_true = tf.reshape(y_true, [n_examples, -1])
    flat_pred = tf.reshape(y_pred, [n_examples, -1])

    n_values = tf.cast(tf.shape(flat_true)[1], tf.float32)
    residual_norm = tf.norm(flat_true - flat_pred, axis=-1)
    return n_values * tf.math.log(residual_norm + epsilon)


def logloss2(y_true, y_pred, epsilon=0.):
    """Gaussian NLL with sigma chosen optimally PER PIXEL (up to additive constants).

    Returns sum_i log(|r_i| + epsilon) per example. Very unstable >:( -- with
    epsilon=0, any exactly matched value contributes log(0) = -inf.
    """
    n_examples = tf.shape(y_true)[0]
    abs_residual = tf.abs(tf.reshape(y_true, [n_examples, -1])
                          - tf.reshape(y_pred, [n_examples, -1]))
    return tf.reduce_sum(tf.math.log(abs_residual + epsilon), axis=-1)


def bernoulli_loss(y_true, y_pred):
    """Bernoulli NLL (binary cross-entropy) summed over all values of each example.

    y_pred are logits (the sigmoid is applied inside the cross-entropy).
    Returns one value per example.
    """
    n_examples = tf.shape(y_true)[0]
    per_value = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.reshape(y_true, [n_examples, -1]),
        logits=tf.reshape(y_pred, [n_examples, -1]))
    return tf.reduce_sum(per_value, axis=-1)


def continuous_bernoulli_loss(y_true, y_pred):
    """Continuous Bernoulli NLL summed over all values of each example.

    It is the binary cross-entropy (y_pred are logits) minus the log normalizer
    log C(lambda), where lambda = sigmoid(y_pred). Returns one value per example.
    """
    n_examples = tf.shape(y_true)[0]
    flat_true = tf.reshape(y_true, [n_examples, -1])
    flat_logits = tf.reshape(y_pred, [n_examples, -1])

    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=flat_true, logits=flat_logits)
    # keep lambda away from exactly 0/1 so the normalizer stays finite
    lam = tf.clip_by_value(tf.nn.sigmoid(flat_logits), 1e-4, 1-1e-4)
    return tf.reduce_sum(cross_entropy - continuous_bernoulli_log_normalizer(lam), axis=-1)


def continuous_bernoulli_log_normalizer(lam, l_lim=0.49, u_lim=0.51):
    """Element-wise log C(lam) with C(lam) = 2 * atanh(1 - 2*lam) / (1 - 2*lam).

    At lam = 0.5 the closed form is 0/0, so for lam in [l_lim, u_lim] a Taylor
    expansion around 0.5 (terms up to (lam - 0.5)^4) is returned instead.
    Taken from https://github.com/cunningham-lab/cb_and_cc
    """
    use_exact = tf.logical_or(tf.less(lam, l_lim), tf.greater(lam, u_lim))
    # inside the Taylor window, feed a harmless value into the closed form so that
    # tf.where never sees NaN/inf from the 0/0 point (in values or gradients)
    safe_lam = tf.where(use_exact, lam, l_lim * tf.ones_like(lam))
    one_minus_two_lam = 1 - 2.0 * safe_lam
    exact = (tf.math.log(tf.abs(2.0 * tf.math.atanh(one_minus_two_lam)))
             - tf.math.log(tf.abs(one_minus_two_lam)))

    centered = lam - 0.5
    taylor = tf.math.log(2.0) + 4.0 / 3.0 * tf.pow(centered, 2) + 104.0 / 45.0 * tf.pow(centered, 4)
    return tf.where(use_exact, exact, taylor)


def continuous_bernoulli_expected_value(lam, l_lim=0.49, u_lim=0.51):
    """Element-wise mean of a continuous Bernoulli: lam / (2*lam - 1) + 1 / (2 * atanh(1 - 2*lam)).

    The closed form is 0/0 at lam = 0.5, so for lam in [l_lim, u_lim] the mean is
    returned as exactly 0.5.
    """
    use_exact = tf.logical_or(tf.less(lam, l_lim), tf.greater(lam, u_lim))
    # substitute a safe value where the closed form is not used (avoids NaN from 0/0)
    safe_lam = tf.where(use_exact, lam, l_lim * tf.ones_like(lam))
    exact = safe_lam / (2*safe_lam - 1) + 1 / (2*tf.math.atanh(1 - 2*safe_lam))
    return tf.where(use_exact, exact, 0.5*tf.ones_like(exact))


def kl_loss_function(means, log_variances):
    """KL( N(means, exp(log_variances)) || N(0, I) ) per example.

    Both inputs are flattened to (batch, -1) and the KL is summed over all code dimensions.
    """
    n_examples = tf.shape(means)[0]
    mu = tf.reshape(means, [n_examples, -1])
    log_var = tf.reshape(log_variances, [n_examples, -1])
    per_dim = mu**2 - 1 + tf.exp(log_var) - log_var
    return 0.5 * tf.reduce_sum(per_dim, axis=-1)

In [None]:
class VAE(tf.keras.Model):
    """Variational autoencoder built as a functional Keras model with custom train/test steps.

    The forward pass (``model(x)``) encodes ``x`` into means and log variances, draws one
    reparameterized sample per input and decodes it. ``train_step``/``test_step`` use
    ``reconstruction_loss_fn(data, reconstructions)`` (one value per example) plus ``beta``
    times the KL divergence to N(0, I), averaged over the batch.

    Args:
        inputs: Keras input tensor of the encoder.
        encoder: model mapping inputs to ``[means, log_variances]``.
        decoder: model mapping codes to raw reconstructions (e.g. logits).
        reconstruction_loss_fn: ``fn(y_true, y_pred)`` returning a per-example loss vector.
        beta: weight of the KL term; ``beta=1`` gives the standard VAE objective.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, inputs, encoder, decoder, reconstruction_loss_fn, beta=1., **kwargs):
        # sample_codes uses no instance state, so it can be called before super().__init__
        super().__init__(inputs, decoder(self.sample_codes(*encoder(inputs))), **kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.reconstruction_loss_fn = reconstruction_loss_fn

        self.loss_tracker = tf.keras.metrics.Mean("loss")
        self.kl_tracker = tf.keras.metrics.Mean("kld")
        self.recon_tracker = tf.keras.metrics.Mean("recon_loss")

        # to get "standard" VAE, use beta=1
        self.beta = beta

    def _vae_losses(self, data, training):
        """Encode -> sample -> decode and return (total_loss, recon_loss, kl_loss).

        recon_loss and kl_loss are per-example vectors; total_loss is the batch mean
        of recon_loss + beta * kl_loss.
        """
        # use the sub-models stored on this instance, not module-level globals
        means, log_variances = self.encoder(data, training=training)
        sampled_codes = self.sample_codes(means, log_variances)
        reconstructions = self.decoder(sampled_codes, training=training)

        recon_loss = self.reconstruction_loss_fn(data, reconstructions)
        kl_loss = kl_loss_function(means, log_variances)
        total_loss = tf.reduce_mean(recon_loss + self.beta*kl_loss)
        return total_loss, recon_loss, kl_loss

    def _update_vae_trackers(self, total_loss, recon_loss, kl_loss):
        """Update the running-mean metrics and return their current values by name."""
        self.loss_tracker.update_state(total_loss)
        self.kl_tracker.update_state(kl_loss)
        self.recon_tracker.update_state(recon_loss)

        return {"loss": self.loss_tracker.result(),
                "kld": self.kl_tracker.result(),
                "recon_loss": self.recon_tracker.result()}

    def train_step(self, data):
        with tf.GradientTape() as tape:
            total_loss, recon_loss, kl_loss = self._vae_losses(data, training=True)

        variables = self.trainable_variables
        gradients = tape.gradient(total_loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))

        return self._update_vae_trackers(total_loss, recon_loss, kl_loss)

    def test_step(self, data):
        total_loss, recon_loss, kl_loss = self._vae_losses(data, training=False)
        return self._update_vae_trackers(total_loss, recon_loss, kl_loss)

    @property
    def metrics(self):
        # exposing the trackers here lets Keras reset them between epochs/evaluations
        return [self.loss_tracker, self.kl_tracker, self.recon_tracker]

    def sample_codes(self, means, log_variances):
        """Reparameterization trick: means + exp(0.5 * log_variances) * eps, eps ~ N(0, I)."""
        stddevs = tf.exp(0.5*log_variances)
        # draw noise in the same dtype as the codes (e.g. under mixed precision)
        return stddevs * tf.random.normal(tf.shape(means), dtype=means.dtype) + means

In [None]:
def residual_stack(inputs, filters, strides, blocks_per_level, mode, name):
    """Chain `blocks_per_level` ResidualBlocks per level and return the final output.

    Level i uses filters[i]. Only the LAST block of each level uses strides[i];
    all other blocks use stride 1. Blocks are named "<name>_<level>_<block>"
    (both 1-based). `mode` is passed straight through to ResidualBlock.
    """
    x = inputs
    for level, (n_filters, level_stride) in enumerate(zip(filters, strides), start=1):
        for block in range(1, blocks_per_level + 1):
            is_last_block = block == blocks_per_level
            x = ResidualBlock(n_filters,
                              mode,
                              strides=level_stride if is_last_block else 1,
                              name="_".join((name, str(level), str(block))))(x)
    return x

In [None]:
# build the model
likelihood = "continuous_bernoulli"

# maps each likelihood name to its per-example reconstruction loss
loss_dictionary = {"bernoulli": bernoulli_loss,
                   "continuous_bernoulli": continuous_bernoulli_loss,
                   "gaussian_fixed_sigma": squared_loss,
                   "gaussian_image_sigma": logloss,
                   "gaussian_pixel_sigma": logloss2}


if likelihood not in loss_dictionary:
    raise ValueError("Invalid likelihood!")

tf.keras.backend.clear_session()

# NOTE(review): `normalization` is not passed to residual_stack/ResidualBlock below,
# so as far as this notebook shows, changing it does not affect the model.
#normalization = lambda **kwargs: tfkl.GroupNormalization(groups=32, **kwargs)
normalization = tfkl.BatchNormalization

blocks_per_level = 4
filters = [64, 128, 256, 512, 512]
strides = [2, 2, 2, 2, 1]
latent_dim = 512  # size of the code; the encoder head outputs means AND log variances for it

encoder_input = tf.keras.Input((64, 64, 3))
encoder_output = residual_stack(encoder_input, filters, strides, blocks_per_level, "conv", "encoder")
encoder_final = tfkl.Flatten()(encoder_output)
# the final layer starts with zero weights (Dense's bias is zero-initialized by default).
# seems to help with exploding loss in the beginning.
# in some sense this is the "perfect" initialization for the KL divergence,
# as all means will be 0 and all log variances as well (-> variances will be 1).
# initializing here avoids reaching into model.layers[...] by position afterwards.
encoder_final = tfkl.Dense(2 * latent_dim, kernel_initializer="zeros")(encoder_final)

# first half of the head = means, second half = log variances
means, log_variances = tf.split(encoder_final, 2, axis=-1)
encoder = tf.keras.Model(encoder_input, [means, log_variances], name="encoder")
code_shape = encoder.output_shape[0][1:]


decoder_input = tf.keras.Input(code_shape)

# project the code to a 4x4x512 feature map (64 / 2**4 = 4 for the four stride-2 levels)
decoder_front = tfkl.Dense(4*4*512)(decoder_input)
decoder_front = tfkl.Reshape((4, 4, 512))(decoder_front)

# mirrors the encoder: filters reversed, "transpose" blocks presumably upsample by the same strides
decoder_output = residual_stack(decoder_front, reversed(filters), strides, blocks_per_level, "transpose", "decoder")
# 1x1 conv to 3 channels; raw outputs (logits for the bernoulli likelihoods)
decoder_final = tfkl.Conv2D(3, 1)(decoder_output)

decoder = tf.keras.Model(decoder_input, decoder_final, name="decoder")

# beta is set here! IMPORTANT PARAMETER!
model = VAE(encoder_input, encoder, decoder, loss_dictionary[likelihood], beta=4., name="vae")
model.summary(expand_nested=True)

In [None]:
# training budget: convert a target number of optimizer steps into a number of epochs
n_steps = 100000
n_data = 60000  # training set size used for the conversion (hard-coded -- TODO confirm against the data)
n_epochs = n_steps // (n_data // batch_size)  # integer division, partial epochs are dropped

learning_rate = 0.0002
#decay_fn = tf.keras.optimizers.schedules.CosineDecay(learning_rate, n_steps)
optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer)

In [None]:
def map_for_likelihood(inputs, likelihood):
    """Map raw decoder outputs to displayable images for the given likelihood name.

    Depending on the likelihood we choose for our data, we may need to map model outputs
    to something else:
    - names containing "bernoulli": the model outputs logits (no sigmoid in the model, since
      we use cross-entropy with logits), so the sigmoid is applied here.
    - "continuous_bernoulli": the sigmoid gives the parameter lambda, not the mean, so it is
      further mapped to the continuous Bernoulli expected value.
    - anything else (gaussian): outputs are unconstrained, so they are clipped to [0, 1].
    """
    if "bernoulli" not in likelihood:
        return np.clip(inputs, 0, 1)

    probabilities = tf.nn.sigmoid(inputs)
    if likelihood != "continuous_bernoulli":
        return probabilities
    return continuous_bernoulli_expected_value(
        tf.clip_by_value(probabilities, 1e-4, 1-1e-4))


class ImageGenCallback(tf.keras.callbacks.Callback):
    """Every `frequency` epochs (epoch 0 included), decode 64 random N(0, I) codes and plot them.

    Reads the notebook globals `code_shape` and `likelihood`.
    """

    def __init__(self, frequency, **kwargs):
        super().__init__(**kwargs)
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.frequency != 0:
            return

        codes = tf.random.normal((64,) + code_shape)
        samples = map_for_likelihood(self.model.decoder(codes), likelihood)

        plt.figure(figsize=(15,15))
        for position, sample in enumerate(samples, start=1):
            plt.subplot(8, 8, position)
            plt.imshow(sample)
            plt.axis("off")
        plt.suptitle("Random generations")
        plt.show()


class ReconstructionCallback(tf.keras.callbacks.Callback):
    """Every `frequency` epochs (epoch 0 included), plot the first 32 test images next to their reconstructions.

    Reads the notebook globals `test_images` and `likelihood`.
    """

    def __init__(self, frequency, **kwargs):
        super().__init__(**kwargs)
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.frequency != 0:
            return

        originals = test_images[:32]
        recons = map_for_likelihood(self.model(originals), likelihood)

        plt.figure(figsize=(15,15))
        for position, pair in enumerate(zip(originals, recons), start=1):
            # original on the left, reconstruction on the right
            plt.subplot(8, 4, position)
            plt.imshow(np.concatenate(pair, axis=1))
            plt.axis("off")
        plt.suptitle("Test set reconstructions")
        plt.show()

do_train = True
weights_path = "weights/weights_assignment04_large.hdf5"

# In general I'm not super happy with the training :(
# early stopping tends to kick in rather... early, so not many steps are done.
# maybe 40 epochs or so, depending on beta.
# I think the architecture is still not optimal.
# however, tuning it takes quite long.


if do_train:
    lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(factor=0.2, patience=5, verbose=1)
    early_stop = tf.keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True, verbose=1)
    reconstruct = ReconstructionCallback(10)
    image_gen_callback = ImageGenCallback(10)

    history = model.fit(train_data, epochs=n_epochs, validation_data=test_data,
                        callbacks=[lr_schedule, early_stop, reconstruct, image_gen_callback])
    # create the target directory if needed -- otherwise saving can fail after the whole training run
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)
    model.save_weights(weights_path)

    # note: if using the continuous bernoulli loss, the loss will likely be < 0.
    # this might seem a bit weird, but is actually not an issue.
    # recall that this loss is the negative log likelihood.
    # - if the NLL is negative, that means the log likelihood is positive
    # - if the log likelihood is > 0, that implies that the likelihood is > 1
    # - a p > 1 is nothing unusual for continuous distributions, where we are using *density* functions
else:
    model.load_weights(weights_path)

In [None]:
# test-set averages of loss, kld and recon_loss (as computed by VAE.test_step)
model.evaluate(x=test_data)

In [None]:
# compare some inputs and reconstructions.
# only the first n_show test images are displayed, so only those are passed through the model
# (instead of predicting the entire test set and discarding the rest).
n_show = 32  # fills the 8x4 grid below
reconstructions = map_for_likelihood(model.predict(test_images[:n_show]), likelihood)

plt.figure(figsize=(15, 15))
for ind, (original, reconstruction) in enumerate(zip(test_images[:n_show], reconstructions)):
    plt.subplot(8, 4, ind+1)
    # original on the left, reconstruction on the right
    concat = np.concatenate((original, reconstruction), axis=1)
    plt.imshow(concat, vmin=0, vmax=1)
    plt.axis("off")
plt.show()

In [None]:
# decode 64 codes drawn from the standard normal prior
random_codes = tf.random.normal((64, *code_shape))
generated = map_for_likelihood(decoder(random_codes), likelihood)

plt.figure(figsize=(15, 15))
for position, image in enumerate(generated, start=1):
    plt.subplot(8, 8, position)
    plt.imshow(image, vmin=0, vmax=1)
    plt.axis("off")
plt.show()

In [None]:
# encode the whole test set (posterior means and log variances, no sampling).
# can check whether the KL divergence did its job: means near 0, variances near 1.
all_means, all_log_variances = encoder.predict(x=test_data)

In [None]:
# histograms of the posterior parameters over the test set;
# means should cluster around 0 and variances around 1 if the KL term is effective
for values, title in ((all_means, "Distribution of latent means in test set"),
                      (all_log_variances, "Distribution of latent log variances in test set"),
                      (np.exp(all_log_variances), "Distribution of latent variances in test set")):
    plt.hist(values.reshape(-1), bins=100)
    plt.title(title)
    plt.show()

In [None]:
randind1 = np.random.randint(len(test_images))
randind2 = np.random.randint(len(test_images))

# annoying indexing since we only want 1 code but need a batch axis
code1_mean, code1_logvar = encoder(test_images[randind1][None])
code1 = model.sample_codes(code1_mean, code1_logvar)

code2_mean, code2_logvar = encoder(test_images[randind2][None])
code2 = model.sample_codes(code2_mean, code2_logvar)

n_interpolation = 64
interpolation_coeffs = np.linspace(0, 1, n_interpolation)


interpolation = "slerp"  # or "linear"
# spherical interpolation may be more appropriate: high-dimensional standard normal samples
# concentrate near a sphere, and the straight line between two codes cuts through a region
# of unusually small norm that the decoder rarely sees.
if interpolation not in ("linear", "slerp"):
    raise ValueError("invalid interpolation")

# drop the batch axis and work in numpy from here on
start_code = np.asarray(code1[0])
end_code = np.asarray(code2[0])
# column of coefficients, so each row of the result is one interpolated code
coeffs = interpolation_coeffs[:, None]

# angle between the two codes; the cosine is clipped because rounding can push it
# slightly outside [-1, 1], where arccos would return NaN
cos_angle = np.dot(start_code, end_code) / (np.linalg.norm(start_code) * np.linalg.norm(end_code))
angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
sin_angle = np.sin(angle)

if interpolation == "slerp" and sin_angle > 1e-6:
    interpolated_codes = (np.sin((1 - coeffs) * angle) * start_code
                          + np.sin(coeffs * angle) * end_code) / sin_angle
else:
    # linear interpolation; also the fallback for slerp when the codes are (anti-)parallel,
    # where sin(angle) == 0 and the slerp formula would divide by zero
    interpolated_codes = coeffs * end_code + (1 - coeffs) * start_code
interpolated_codes = interpolated_codes.astype(np.float32)


# see the images we are dealing with here
plt.imshow(test_images[randind1])
plt.axis("off")
plt.show()

plt.imshow(test_images[randind2])
plt.axis("off")
plt.show()

In [None]:
# show: first the original, to the right of it the reconstruction from the mean,
# then n_samples reconstructions, each from a freshly sampled code
# (every forward pass of `model` draws new noise per example).
# I only do it for the 1st image here.
n_samples = 14  # 1 original + 1 mean reconstruction + 14 samples = the 4x4 grid below
repeated_image = tf.tile(test_images[randind1:randind1+1], [n_samples, 1, 1, 1])

sampled_recons = map_for_likelihood(model(repeated_image), likelihood)
recon_from_mean = map_for_likelihood(decoder(code1_mean), likelihood)
panel_images = np.concatenate([test_images[randind1][None], recon_from_mean, sampled_recons])

plt.figure(figsize=(8, 8))
for ind, image in enumerate(panel_images):
    plt.subplot(4, 4, ind+1)
    plt.imshow(image, vmin=0, vmax=1)
    plt.axis("off")
plt.show()

In [None]:
# actually look at the interpolations.
# NOTE: the 8x8 grid is only for compactness -- nothing 2D is happening here.
# it is a 1D interpolation, read row by row from top left to bottom right
# (plotting it in a single row or column would be clearer).
interpolated_images = map_for_likelihood(decoder(interpolated_codes), likelihood)
plt.figure(figsize=(15, 15))
for position in range(1, 65):
    plt.subplot(8, 8, position)
    plt.imshow(interpolated_images[position - 1])
    plt.axis("off")
plt.show()

In [None]:
# here's a fun thing we can do to check the usage of the latent space.
# we iterate over the latent dimensions.
# each time, we take the sampled code of the first chosen image (code1) and keep all dimensions fixed
# except the current one.
# the current dimension is swept over a range (here, -2 to 2), sampling n_interpolation points along the way.
# each resulting code is decoded.
# I then compute the mean absolute pixel difference between the FIRST and the LAST image of the sweep
# (not between successive images).
# for some dimensions, there is next to no difference.
# I conclude that these dimensions are effectively unused.
# for dimensions where this difference is larger than some threshold, I plot the "walk" over that dimension.

# an issue here is that I'm only doing this for a single image. ideally, we would sample images randomly, or
# do the entire process for many images.

value_range = np.linspace(-2, 2, n_interpolation)

diffs = []
for latent_dim in range(code_shape[-1]):

    code_repeated = np.tile(code1, [n_interpolation, 1])
    # overwrite only the current dimension, one sweep value per row
    code_repeated[:, latent_dim] = value_range

    interpolated_images = map_for_likelihood(decoder(code_repeated), likelihood)

    # mean over all pixels/channels of |first image - last image|
    avg_abs_diff = np.abs(interpolated_images[0] - interpolated_images[-1]).mean()
    diffs.append(avg_abs_diff)
    # this value is somewhat arbitrary!
    if avg_abs_diff < 0.05:
        print("Dim {} not used".format(latent_dim))
    else:
        print(avg_abs_diff)
        plt.figure(figsize=(15, 15))
        for ind in range(64):
            plt.subplot(8, 8, ind+1)
            plt.imshow(interpolated_images[ind])
            plt.axis("off")
        plt.show()

In [None]:
# here I sort the differences computed above (largest first).
# we can see that there is a sharp dropoff in difference at some point.
# this may be seen as a kind of cut-off point which allows us to (very approximately)
# read off the number of dimensions actually used in the latent space.
# this will vary heavily with beta (higher beta -> fewer dimensions used).
sorted_diffs = sorted(diffs, reverse=True)
# plot against ranks 1..n: with the default 0-based x values, the largest difference sits at
# x = 0, which cannot be placed on a log-scaled axis
ranks = np.arange(1, len(sorted_diffs) + 1)
plt.loglog(ranks, sorted_diffs)
plt.xlabel("Rank (dimensions sorted by difference)")
plt.ylabel("Mean absolute pixel difference")
plt.title("First-vs-last image difference of each latent dimension's walk (sorted)")
plt.show()