In [None]:
"""Imports, define AE model"""
import tensorflow as tf
tfkl = tf.keras.layers
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image as mpimage

from data.utils import parse_image_example

In [None]:
# just a skeleton.
# you fill in the autoencoding stuff ;)
# this is mainly the VQ-related things
# if you find it awkward to work with half-finished code written by someone else,
# feel free to roll your own version.
class VQVAE(tf.keras.Model):
    """Skeleton of a vector-quantised autoencoder (VQ-VAE).

    The VQ-specific parts (codebook lookup, straight-through estimator,
    codebook/commitment losses, usage tracking) are implemented. The
    reconstruction loss is deliberately left as a `...` placeholder in
    `train_step` and `test_step` and has to be filled in.

    Args:
        encoder: Keras model mapping inputs to rank-4 feature maps
            (batch, rows, cols, dim_code). `encoder.output_shape[-1]`
            determines the size of the codebook vectors.
        decoder: Keras model mapping quantised feature maps back to inputs.
        codebook_size: number of vectors in the codebook.
        beta: weight of the commitment loss in the total loss. 0.25 is the
            value suggested in the VQ-VAE paper.
        **kwargs: forwarded to `tf.keras.Model`.
    """

    def __init__(self, encoder, decoder, codebook_size, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        #... (any further autoencoder-specific setup goes here)

        # this is to track unused codebook vectors
        self.usage_tracker = tf.keras.metrics.MeanTensor("codebook_usage")

        dim_code = encoder.output_shape[-1]
        self.codebook = tf.Variable(tf.random.normal([codebook_size, dim_code]))
        self.codebook_size = codebook_size

    def call(self, inputs, training=None):
        """Return only the reconstruction for `inputs`."""
        return self.apply_ae(inputs, training=training)[0]

    def _vq_losses(self, encoder_outputs, codes):
        """Return (codebook_loss, commitment_loss) for one batch.

        Both are mean squared differences between encoder outputs and their
        selected codebook vectors, rescaled by the code dimension. They differ
        only in which side is wrapped in `stop_gradient`: the codebook loss
        moves the codes, the commitment loss moves the encoder.
        """
        # taken from the codebook itself instead of a notebook-level global
        dim_code = self.codebook.shape[-1]
        codebook_loss = tf.reduce_mean(
            tf.square(tf.stop_gradient(encoder_outputs) - codes))*dim_code
        commitment_loss = tf.reduce_mean(
            tf.square(encoder_outputs - tf.stop_gradient(codes)))*dim_code
        return codebook_loss, commitment_loss

    def _track_usage(self, indices):
        """Add this batch's per-entry usage counts to the tracker; return the logs dict."""
        one_hot_indices = tf.one_hot(tf.reshape(indices, (-1,)), depth=self.codebook_size)
        usage_count = tf.reduce_sum(one_hot_indices, axis=0)
        self.usage_tracker.update_state(usage_count)
        return {"codebook_usage": self.usage_tracker.result()}

    def train_step(self, data):
        """Run one optimisation step and report the running codebook usage."""
        with tf.GradientTape() as tape:
            reconstruction, encoder_outputs, codes, indices = self.apply_ae(data, training=True)

            reconstruction_loss = ...  # TODO: fill in (compare `reconstruction` with `data`)
            codebook_loss, commitment_loss = self._vq_losses(encoder_outputs, codes)
            total_loss = reconstruction_loss + codebook_loss + self.beta*commitment_loss

        variables = self.encoder.trainable_variables + self.decoder.trainable_variables + [self.codebook]
        gradients = tape.gradient(total_loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))

        # update with other metrics, i.e. losses
        # see e.g. https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit
        return self._track_usage(indices)

    def test_step(self, data):
        """Evaluate one batch (no weight update) and report the running codebook usage."""
        reconstruction, encoder_outputs, codes, indices = self.apply_ae(data, training=False)

        reconstruction_loss = ...  # TODO: fill in (compare `reconstruction` with `data`)
        codebook_loss, commitment_loss = self._vq_losses(encoder_outputs, codes)
        # not reported yet; kept so loss metrics can be added alongside the usage
        total_loss = reconstruction_loss + codebook_loss + self.beta*commitment_loss

        return self._track_usage(indices)

    @property
    def metrics(self):
        # listing the tracker here lets Keras reset it automatically
        return [self.usage_tracker]

    # vq-vae-specific stuff
    def read_codebook(self, encoder_outputs, with_indices=False):
        """Replace every encoder output vector by its nearest codebook vector.

        Args:
            encoder_outputs: rank-4 tensor whose last axis has the same size
                as the codebook vectors.
            with_indices: if True, also return the selected codebook indices.

        Returns:
            `(straight_through_codes, codes)` or, with `with_indices=True`,
            `(straight_through_codes, codes, indices)`.
        """
        # distance of every position to every codebook entry;
        # the codebook axis is second to last before the reduction
        distances = tf.reduce_mean(tf.square(encoder_outputs[:, :, :, None, :]
                                             - self.codebook[None, None, None, :, :]), axis=-1)
        min_distance_inds = tf.math.argmin(distances, axis=-1)
        codes = tf.gather(self.codebook, min_distance_inds)

        # 1st output are the codes with "straight-through estimator":
        # same values as `codes`, but gradients flow to the encoder outputs.
        # this allows gradients to flow from the decoder into the encoder.
        # to learn the codebook, we also return codes, as these allow gradient flow into the codebook.
        straight_through_codes = encoder_outputs + tf.stop_gradient(codes - encoder_outputs)
        if with_indices:
            return straight_through_codes, codes, min_distance_inds
        else:
            return straight_through_codes, codes

    # reconstructions for reconstruction loss (duh)
    # encoder outputs and codes for codebook and commitment loss
    # indices because we may need them to check usage, create training data for AR model etc
    def apply_ae(self, inputs, training=False):
        """Encode, quantise and decode `inputs`.

        Returns:
            `(reconstruction, encoder_outputs, codes, indices)`, where `codes`
            are the codebook vectors that carry gradients into the codebook.
        """
        encoder_outputs = self.encoder(inputs, training=training)
        straight_through_codes, codes_with_gradients, indices = self.read_codebook(
            encoder_outputs, with_indices=True)
        # use the model's own decoder, not whatever `decoder` is at notebook level
        return (self.decoder(straight_through_codes, training=training),
                encoder_outputs,
                codes_with_gradients,
                indices)

In [None]:
# Placeholders: both halves of the autoencoder are left for you to define.
# VQVAE reads encoder.output_shape[-1] to size its codebook vectors.
encoder = ...  # model
decoder = ... # another model
# codebook_size (number of codebook vectors) is not defined in this cell.
vqvae = VQVAE(encoder, decoder, codebook_size)

In [None]:
# this can be used to reset unused codebook entries after each epoch
# ...or less frequently (frequency argument)
class CodebookResetter(tf.keras.callbacks.Callback):
    """Re-initialises codebook entries that were never selected.

    Every `frequency` epochs (epochs 0, frequency, 2*frequency, ...), entries
    whose logged "codebook_usage" is exactly zero are overwritten with encoder
    outputs taken from random positions of a fixed reference batch.

    Args:
        frequency: run the reset when `epoch % frequency == 0`.
        reference_batch: batch of model inputs to draw replacement vectors
            from. If omitted, the first batch of the notebook-level
            `train_data` is used.
        **kwargs: forwarded to `tf.keras.callbacks.Callback`.
    """

    def __init__(self, frequency, reference_batch=None, **kwargs):
        super().__init__(**kwargs)
        self.frequency = frequency
        if reference_batch is None:
            reference_batch = next(iter(train_data))
        self.reference_batch = reference_batch

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.frequency:
            return

        average_usage = np.asarray(logs["codebook_usage"])
        # entries that were never used this epoch
        unused_code_indices = np.where(average_usage == 0)[0]
        print("\nDEBUG UNUSED INDICES", unused_code_indices)

        num_unused = len(unused_code_indices)
        if not num_unused:
            # nothing to replace, so skip the encoder forward pass entirely
            return

        # we take random encoder outputs for a reference batch
        reference_encodings = self.model.encoder(self.reference_batch, training=False)

        # sample positions from the real shape of the encodings, so a short
        # reference batch or a different code image size cannot index out of range
        num_items, num_rows, num_cols = reference_encodings.shape.as_list()[:3]
        positions = np.stack([np.random.randint(num_items, size=num_unused),
                              np.random.randint(num_rows, size=num_unused),
                              np.random.randint(num_cols, size=num_unused)],
                             axis=1)
        new_codebook_entries = tf.gather_nd(reference_encodings, positions)
        print("DEBUG NEW ENTRIES", new_codebook_entries.shape, "\n")

        # unused entries are replaced by reference encodings
        sparse_update = tf.IndexedSlices(new_codebook_entries,
                                         tf.convert_to_tensor(unused_code_indices, dtype=tf.int32))
        self.model.codebook.scatter_update(sparse_update)


# fit...

In [None]:
# create a dataset of codes
# NOTE this stores both the code vectors AND the indices.
# you may also store only the indices (takes less space)

# Encode the training set batch by batch and keep the quantised code vectors
# together with their codebook indices.
train_code_inds = []
train_codes = []
for img_batch in train_data:
    # inference only: make sure layers such as dropout/batchnorm run in eval mode
    encodings = encoder(img_batch, training=False)
    codes, _, inds = vqvae.read_codebook(encodings, with_indices=True)
    # store as numpy right away so the tensors do not pile up on the device
    train_code_inds.append(inds.numpy())
    train_codes.append(codes.numpy())
train_code_inds = np.concatenate(train_code_inds)
train_codes = np.concatenate(train_codes)

code_data = tf.data.Dataset.from_tensor_slices((train_code_inds, train_codes))

# the shuffle buffer spans the whole dataset, whatever its size,
# instead of relying on a hard-coded example count
code_data_train = code_data.shuffle(len(train_code_inds)).batch(batch_size)

In [None]:
# create a dataset of codes (test set)

# Collect per-batch results first, then merge them into two big arrays.
index_batches, code_batches = [], []
for batch_number, image_batch in enumerate(test_data):
    # progress report on every 50th batch
    if batch_number % 50 == 0:
        print(batch_number)
    quantized, _, nearest = vqvae.read_codebook(encoder(image_batch), with_indices=True)
    index_batches.append(nearest)
    code_batches.append(quantized)

test_code_inds = np.concatenate(index_batches)
test_codes = np.concatenate(code_batches)

# (indices, code vectors) pairs, batched but not shuffled
test_pairs = tf.data.Dataset.from_tensor_slices((test_code_inds, test_codes))
code_data_test = test_pairs.batch(batch_size)

In [None]:
# bonus: check codebook usage. are all 256 vectors used, and how often?
from collections import Counter

# Count how often each codebook index occurs in the training codes.
# minlength keeps entries that never occur in the result (with count 0),
# so unused codebook vectors actually show up instead of being dropped.
usage_count = np.bincount(train_code_inds.reshape((-1,)), minlength=codebook_size)
print("unused codebook entries:", int(np.sum(usage_count == 0)), "of", codebook_size)

# most frequently used entries first
descending = np.sort(usage_count)[::-1]
plt.bar(range(len(descending)), descending)
plt.xlabel("codebook entries, sorted by usage")
plt.ylabel("number of occurrences")
plt.title("Codebook usage on the training set")
plt.show()

In [None]:
# now train a pixelcnn on the codes

In [None]:
# tested for TF 2.11!
# it may not work in other versions :(
from tensorflow.python.keras.layers.convolutional import Conv

class MaskedConv2D(tfkl.Conv2D):
  """2D convolution whose kernel is multiplied by a fixed binary mask.

  The mask keeps all kernel rows above the centre row and, within the centre
  row, all columns left of the centre column. With `mask_type="b"` the centre
  position itself is kept as well; any other value gives the type-"a" mask,
  which excludes the centre.

  Args:
    mask_type: "a" (centre position masked out) or "b" (centre position kept).
    All other arguments are those of `tf.keras.layers.Conv2D` and are
    forwarded to it unchanged.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format=None,
               dilation_rate=1,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               trainable=True,
               name=None,
               mask_type="a",
               **kwargs):
    # forward every argument as given; none of them is replaced by a default
    super(MaskedConv2D, self).__init__(
               filters,
               kernel_size,
               strides=strides,
               padding=padding,
               data_format=data_format,
               dilation_rate=dilation_rate,
               activation=activation,
               use_bias=use_bias,
               kernel_initializer=kernel_initializer,
               bias_initializer=bias_initializer,
               kernel_regularizer=kernel_regularizer,
               bias_regularizer=bias_regularizer,
               activity_regularizer=activity_regularizer,
               kernel_constraint=kernel_constraint,
               bias_constraint=bias_constraint,
               trainable=trainable,
               name=name,
        **kwargs)
    self.mask_type = mask_type

    # the base class has normalised kernel_size to a (height, width) tuple,
    # so both an int and a tuple argument are handled here
    kernel_h, kernel_w = self.kernel_size
    mask = np.zeros([kernel_h, kernel_w, 1, 1], dtype=np.float32)
    mask[:kernel_h//2, :, :, :] = 1
    mask[kernel_h//2, :kernel_w//2, :, :] = 1
    if mask_type == "b":
      mask[kernel_h//2, kernel_w//2] = 1
    # trailing axes of size 1 broadcast over input and output channels
    self.mask = tf.convert_to_tensor(mask)

  def call(self, inputs):
    """Convolve `inputs` with the masked kernel, then add bias and activation."""
    masked_kernel = self.mask * self.kernel
    outputs = self.convolution_op(inputs, masked_kernel)

    if self.use_bias:
      # a Conv2D output is always 4D, so bias_add applies for both data formats
      bias_format = 'NCHW' if self.data_format == 'channels_first' else 'NHWC'
      outputs = tf.nn.bias_add(outputs, self.bias, data_format=bias_format)

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

  def get_config(self):
    """Include `mask_type` so the layer is rebuilt with the same mask."""
    config = super(MaskedConv2D, self).get_config()
    config["mask_type"] = self.mask_type
    return config

In [None]:
# proposal for a pixelcnn class.
# training is actually quite basic. the masked convolutions are set up such that
# the output at a given position is the prediction for that pixel.
# since the 1st layer mask also removes the center pixel, this is fine.
# we don't have to shift inputs vs targets by one step, like in an autoregressive RNN for example.

# note that this model takes the code VECTORS as input, but predicts the indices.
# it would also be possible to change this: take the indices as inputs.
# in that case, the first model layer should probably be an embedding.
# this effectively allows the AR model to learn its own representations for the indices,
# instead of re-using the code vectors from the VQ-VAE.
class PixelCNN(tf.keras.Model):
    """Functional Keras model predicting codebook indices from code vectors.

    Each dataset element is an `(indices, codes)` pair: `codes` is the model
    input, `indices` is the target for the compiled loss.

    Args:
        inputs: input tensor(s) of the functional model.
        outputs: output tensor(s) of the functional model (one logit vector
            per position).
        **kwargs: forwarded to `tf.keras.Model`.
    """

    def __init__(self, inputs, outputs, **kwargs):
        super().__init__(inputs, outputs, **kwargs)

        self.cross_entropy_tracker = tf.keras.metrics.Mean("cross_entropy")

    def train_step(self, data):
        """Run one optimisation step and report the running mean loss."""
        indices, codes = data
        with tf.GradientTape() as tape:
            outputs = self(codes, training=True)
            loss = self.compiled_loss(indices, outputs)

        variables = self.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))

        self.cross_entropy_tracker.update_state(loss)
        return {"cross_entropy": self.cross_entropy_tracker.result()}

    def test_step(self, data):
        """Evaluate one batch (no weight update) and report the running mean loss."""
        indices, codes = data
        outputs = self(codes, training=False)
        loss = self.compiled_loss(indices, outputs)

        self.cross_entropy_tracker.update_state(loss)
        return {"cross_entropy": self.cross_entropy_tracker.result()}

    # generation proceeds pixel by pixel
    def generate(self, num_samples):
        """Sample `num_samples` code images position by position and decode them.

        Relies on the notebook-level names `vqvae` (for its codebook),
        `decoder`, `map_for_likelihood` and `likelihood`.

        Returns:
            Whatever `map_for_likelihood(decoder(code_image), likelihood)` returns.
        """
        # the code image size is taken from the model's own input shape
        _, num_rows, num_cols, dim_code = self.input_shape
        image = np.zeros([num_samples, num_rows, num_cols, dim_code], dtype=np.float32)
        for row in range(num_rows):
            for col in range(num_cols):
                index_logits = self(image, training=False)[:, row, col, :]
                # one categorical draw per sample, straight from the logits
                index_sample = tf.random.categorical(index_logits, num_samples=1)[:, 0]
                image[:, row, col] = tf.gather(vqvae.codebook, index_sample).numpy()

        return map_for_likelihood(decoder(image), likelihood)

    @property
    def metrics(self):
        # listing the tracker here lets Keras reset it automatically
        return [self.cross_entropy_tracker]

In [None]:
# The PixelCNN consumes code vectors; code_shape is not defined in this cell.
inp = tf.keras.Input(code_shape)
x = inp
# First layer uses mask type "a", which also masks out the center position.
x = MaskedConv2D(32, 3, padding="same", mask_type="a")(x)
# you can stack more MaskedConv2D layers, build residual blocks with 1x1 convolutions, etc.
# note that you can use mask_type="b" for all layers except the first one.
# Placeholder: replace `...` with those layers applied to x.
x = ...
# output layer: predict codebook index
# (1x1 convolution without activation, one output channel per codebook entry)
x = tfkl.Conv2D(codebook_size, 1, padding="same")(x)
 
pixel_cnn = PixelCNN(inp, x)


In [None]:
# indices are classes, so use cross-entropy as loss
# sparse: targets are integer indices rather than one-hot vectors;
# from_logits=True: predictions are treated as logits, not probabilities
pixel_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

In [None]:
class ImageGenCallback(tf.keras.callbacks.Callback):
    """Shows an 8x8 grid of 64 freshly generated samples every `frequency` epochs."""

    def __init__(self, frequency, **kwargs):
        super().__init__(**kwargs)
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        # act only on epochs 0, frequency, 2*frequency, ...
        if epoch % self.frequency != 0:
            return

        samples = self.model.generate(64)

        plt.figure(figsize=(15,15))
        for position, sample in enumerate(samples, start=1):
            plt.subplot(8, 8, position)
            plt.imshow(sample)
            plt.axis("off")
        plt.suptitle("Random generations")
        plt.show()
            
# plot generated samples on every 10th epoch
image_gen_callback = ImageGenCallback(10)

# fit() requires a compiled model: the custom train_step uses self.optimizer
# and self.compiled_loss. Adam is a default choice here; swap in your own
# optimizer if you prefer.
pixel_cnn.compile(optimizer=tf.keras.optimizers.Adam(), loss=pixel_loss)

pixel_cnn.fit(code_data_train, validation_data=code_data_test, epochs=n_epochs, callbacks=[image_gen_callback])

In [None]:
# Draw 64 samples from the trained model and show them on an 8x8 grid.
samples = pixel_cnn.generate(64).numpy()

plt.figure(figsize=(15,15))
for slot, sample_image in enumerate(samples, start=1):
    plt.subplot(8, 8, slot)
    plt.imshow(sample_image)
    plt.axis("off")
plt.show()