In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np

# visualization
import matplotlib.pyplot as plt
import PIL
import IPython

# environment configuration
import kaggle_datasets
import os

In [None]:
# Silence TensorFlow's native (C++) log output; '3' keeps errors only.
# NOTE(review): tensorflow is already imported in the previous cell, and the
# native library typically reads TF_CPP_MIN_LOG_LEVEL when it loads — setting it
# here may suppress little. Consider moving this above `import tensorflow`.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

# Configs

# Distributed Computing

In [None]:
class distribute:
    """Namespace of helpers for choosing and using a tf.distribute strategy.

    All members are static helpers; the class is never instantiated.
    """

    @staticmethod
    def build_strategy():
        """Creates a tf.distribute.Strategy object.

        Assesses the calling compute environment, and creates a Strategy object appropriate
        given the available accelerators. Preference will be granted to TPU accelerators,
        followed by GPU accelerators, before falling back on CPU computation.

        Returns:
            A tf.distribute.Strategy object for the calling compute environment.
            This will be of one of the following types
             - TPUStrategy (TPU Accelerator)
             - MirroredStrategy (GPU Accelerator)
             - _DefaultDistributionStrategy (No Accelerator)
        """
        # prefer TPU if available
        try:
            # resolver will throw ValueError over the lack of a tpu address if tpu not found
            tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
            # connect to tpu system
            tf.config.experimental_connect_to_cluster(tpu)
            tf.tpu.experimental.initialize_tpu_system(tpu)
            return tf.distribute.TPUStrategy(tpu)
        except ValueError:
            # no TPU available: fall through to GPU / CPU
            pass

        # connect to GPU if TPU unavailable
        if tf.config.list_physical_devices('GPU'):
            return tf.distribute.MirroredStrategy()

        # fall back on CPU
        # TODO: evaluate MirroredStrategy for cpu
        return tf.distribute.get_strategy()

    @staticmethod
    def is_tpu(strategy):
        """Evaluates whether the strategy uses a TPU cluster.

        Useful for storage interaction, as the TPU cluster will be in a cloud
        environment and will not default to local memory.

        Args:
            strategy (tf.distribute.Strategy): the strategy to be evaluated

        Returns:
            bool: True if the supplied strategy operates on TPU accelerator(s)
        """
        return isinstance(strategy, tf.distribute.TPUStrategy)

    @staticmethod
    def loss(strategy, loss_fn):
        """Wraps a per-instance loss function with a strategy-aware reduction.

        Each instance of y_true / y_pred is flattened to a vector before loss_fn
        is applied. loss_fn is expected to return one value per instance, e.g. a
        Keras loss built with ``reduction=NONE``.

        The per-replica sum of those values is divided by the *global* batch
        size: this replica's actual batch size times
        ``strategy.num_replicas_in_sync``. Summing the results of all replicas
        therefore gives the mean loss over the whole global batch. Because the
        batch size is read from y_true, inputs of any batch size are averaged
        correctly (e.g. a concatenated real+fake discriminator batch).

        Args:
            strategy (tf.distribute.Strategy): The distribution strategy for the
                calling compute environment.
            loss_fn (Callable[[tf.Tensor, tf.Tensor], tf.Tensor]): The loss
                function which takes (y_true, y_pred) as arguments and returns
                the loss of each instance.

        Returns:
            Callable[[tf.Tensor, tf.Tensor], tf.Tensor]: the loss function wrapped
                with a strategy-aware reduction. Just as the input loss function,
                will take (y_true, y_pred) as arguments and return the (scalar) loss.
        """
        # Python int, captured once when the wrapper is built
        num_replicas = strategy.num_replicas_in_sync

        @tf.function
        def reduced_loss(y_true, y_pred):
            # flatten instances to (batch, features); the dynamic batch size also
            # handles a partially-known static shape
            batch = tf.shape(y_true)[0]
            y_true = tf.reshape(y_true, (batch, -1))
            y_pred = tf.reshape(y_pred, (batch, -1))

            # calculate reduced loss
            loss_by_instance = loss_fn(y_true, y_pred)
            global_batch_size = tf.cast(batch * num_replicas, loss_by_instance.dtype)
            return tf.reduce_sum(loss_by_instance) / global_batch_size

        return reduced_loss

# Visualizations

In [None]:
class vis:
    """Namespace of helpers for arranging, displaying and saving image tensors.

    All members are static helpers; the class is never instantiated.
    """

    # total figure width (matplotlib figsize units) used by vis.image_gallery
    max_width = 16.0

    @staticmethod
    def _to_image_grid(images):
        """Transforms tensor of stacked image volumes into a single volume of all images
        concatenated together.

        This function is not meant to be called directly, as it does not support any variety
        of tensor shapes. It is called by vis.image_grid, which handles more data validation.

        Args:
            images (tf.Tensor): tensor of stacked image volumes with rank 5 and indexed by
                (row, column, height, width, channel).

        Returns:
            tf.Tensor: tensor of shape (rows*height, cols*width, channel) with all images
                concatenated into a single volume representing a grid of the images

        Raises:
            ValueError: if ``images`` is not rank 5.
        """
        # validate input
        if len(images.shape) != 5:
            raise ValueError(f'Invalid tensor rank. Expected rank 5, got {len(images.shape)}')

        # (row, col, h, w, c) -> (row, h, col, w, c) so that merging adjacent axes
        # yields (row*height, col*width, channel)
        images = tf.transpose(images, [0, 2, 1, 3, 4])
        nshape = (images.shape[0]*images.shape[1], images.shape[2]*images.shape[3], -1)
        return tf.reshape(images, nshape)

    @staticmethod
    def image_grid(images):
        """Transforms Tensor of one or more image volumes into a single volume of all images
        concatenated together.

        This function takes an input Tensor of one of the following shapes:
            (height, width) - 1 image
            (height, width, channel) - 1 image
            (col, height, width, channel) - col images
            (row, col, height, width, channel) - row*col images
        The input Tensor is then transformed into a single image volume with all input images
        concatenated together as a grid. The output image tensor of rank 3 is indexed by
            (height, width, channel)

        Args:
            images (tf.Tensor): tensor of one or more stacked image volumes. This tensor must
                be indexed by one of the following shapes:
                 - (height, width): 1 image
                 - (height, width, channel): 1 image
                 - (col, height, width, channel): row of col images
                 - (row, col, height, width, channel): grid of row*col images

        Returns:
            tf.Tensor: tensor of shape (rows*height, cols*width, channel) with all images
                concatenated into a single volume representing a grid of the images

        Raises:
            ValueError: if the tensor rank is not 2, 3, 4 or 5.
        """
        # dispatch on the tensor *rank* (number of axes), not on the shape itself
        irank = len(images.shape)
        if irank == 5:
            images = vis._to_image_grid(images)
        elif irank == 4:
            # assume a single horizontal row: add a rows axis before concatenating
            images = vis._to_image_grid(images[None])
        elif irank == 2:
            # add channels axis
            images = images[..., None]
        elif irank != 3:
            raise ValueError(f'Images tensor has improper rank {irank}')

        return images

    @staticmethod
    def to_pil(img):
        """Transforms input image tensor into a PIL Image object.

        Values are clipped to [0, 1] before scaling, so out-of-range values
        saturate instead of wrapping around in uint8. They are then rounded to
        the nearest of the 256 levels. A trailing singleton channel axis is
        dropped, because PIL does not accept (H, W, 1) arrays.

        Args:
            img (Union[tf.Tensor, np.ndarray]): A single image volume of float
                values in the range [0, 1].

        Returns:
            PIL.Image: The same image as a PIL.Image object
        """
        # `import PIL` alone does not guarantee the Image submodule is loaded
        import PIL.Image

        img = tf.clip_by_value(tf.cast(img, tf.float32), 0., 1.)
        img = tf.cast(tf.round(img * 255.), tf.uint8).numpy()
        if img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]
        return PIL.Image.fromarray(img)

    @staticmethod
    def image_gallery(
        images,
        img_titles=None,
        row_titles=None,
        col_titles=None,
        channels=True,
    ):
        """Draws a grid of images with optional titles on a new matplotlib figure.

        Args:
            images: array-like indexed by (row, col, height, width[, channel]).
            img_titles: optional per-image titles (row-major, one per image). They are
                used as axes titles, or as x-axis labels when col_titles is also given.
            row_titles: optional titles drawn as y-labels of the first column.
            col_titles: optional titles assigned to the first len(col_titles) axes
                (the first row when it has ncols entries).
            channels: currently unused.
        """
        images = np.array(images)
        nrows, ncols = images.shape[:2]
        images = images.reshape(-1, *images.shape[2:])

        image_size = vis.max_width / ncols
        _, axes = plt.subplots(figsize=(image_size*ncols, image_size*nrows),
                               nrows=nrows, ncols=ncols)
        axes = np.array(axes).ravel()

        # initialize title arrays to Nones
        titles  = np.full_like(axes, None)
        xtitles = np.full_like(axes, None)
        ytitles = np.full_like(axes, None)

        # evaluate titles
        if col_titles is not None:
            titles[:len(col_titles)] = col_titles
            if img_titles is not None:
                xtitles[:] = img_titles
        elif img_titles is not None:
            titles[:] = img_titles
        if row_titles is not None:
            # reshape returns a view, so this writes into ytitles
            ytitles.reshape(nrows, ncols)[:, 0] = row_titles

        # draw images
        for img, ax, t, xt, yt in zip(images, axes, titles, xtitles, ytitles):
            vis.draw_image(img, ax)
            ax.set_title(t)
            ax.set_xlabel(xt)
            ax.set_ylabel(yt)

    @staticmethod
    def draw_image(img, ax=None):
        """Draws ``img`` on ``ax`` (default: current axes) without ticks or spines.

        Returns:
            matplotlib.axes.Axes: the axes drawn on.
        """
        if ax is None:
            ax = plt.gca()

        ax.set_xticks([])
        ax.set_yticks([])

        for s in ax.spines:
            ax.spines[s].set_visible(False)

        ax.imshow(img)
        return ax

    @staticmethod
    def generator_evolution(
        images,
        filepath=None,
        classes=None,
        epochs=None,
        add_alpha=True,
        separate_classes=True,
        vertical=False,
        display=True,
    ):
        """Visualizes the training evolution of an image generator as one image.

        With the default (horizontal) layout, each class becomes a block with one
        row per test image and one column per epoch. The class blocks are stacked
        top to bottom. With ``vertical=True`` the timeline runs downward instead:
        each class block has one row per epoch and one column per test image, and
        the class blocks are placed side by side.

        Args:
            images (tf.Tensor): rank 6 tensor of image volumes indexed by
                (class, image, epoch, height, width, channel)
            filepath (Optional[str]): filepath to save visual to.
            classes (Optional[Iterable[str]]): list of classes corresponding to the
                first index of the images tensor. Currently unused (annotation is
                not implemented).
            epochs (Union[None, Iterable[int], Iterable[str]]): list of epochs
                corresponding to the third index of the images tensor. Currently
                unused (annotation is not implemented).
            add_alpha (bool): If True, append an opaque alpha channel (value 1.).
            separate_classes (bool): If True, a blank gap (value 1.) one image tall
                (one image wide when ``vertical``) is left between the classes.
            vertical (bool): If True, timeline will extend downward along the visual
                while classes and images are spread along horizontal axis.
            display (bool): If True, display the visual in the notebook.

        Returns:
            tf.Tensor: a single image volume containing the generator evolution
                visual.
        """
        # gap size comes from the images themselves rather than a global setting
        height, width = images.shape[3], images.shape[4]

        # transform each class into an image grid
        if vertical:
            # swap the (image, epoch) axes so each grid row is one epoch
            grids = [vis._to_image_grid(tf.transpose(img_cls, [1, 0, 2, 3, 4]))
                     for img_cls in images]
            class_axis, gap_size = 1, width
        else:
            grids = [vis._to_image_grid(img_cls) for img_cls in images]
            class_axis, gap_size = 0, height

        # add transparency channel if requested
        if add_alpha:
            grids = [tf.concat([grid, tf.ones((*grid.shape[:-1], 1), grid.dtype)], axis=-1)
                     for grid in grids]

        # connect all class grids, optionally with gaps between them
        final_image = grids[0]
        for grid in grids[1:]:
            pieces = [final_image]
            if separate_classes:
                gap_shape = list(grid.shape)
                gap_shape[class_axis] = gap_size
                pieces.append(tf.ones(gap_shape, grid.dtype))
            pieces.append(grid)
            final_image = tf.concat(pieces, axis=class_axis)

        if filepath:
            vis.save_image(final_image, filepath)
        if display:
            vis.display_image(final_image)
        return final_image

    @staticmethod
    def display_image(img):
        """Displays an image volume in the notebook output."""
        IPython.display.display(vis.to_pil(img))

    @staticmethod
    def save_image(img, filepath):
        """Saves an image volume to ``filepath`` (format inferred by PIL from the extension)."""
        vis.to_pil(img).save(filepath)

# Data Pipeline

In [None]:
class data:
    """Namespace for the Monet/photo TFRecord input pipeline and its settings.

    All members are static helpers; the class is never instantiated.
    """
    db_name    = 'gan-getting-started'
    local_path = f'../input/{db_name}'
    monet_dir  = '/monet_tfrec'
    photo_dir  = '/photo_tfrec'
    data_dirs  = [monet_dir, photo_dir]

    batch_size = 8 # 128 # recommended for TPU v3-8... but what about GPU and CPU???
    prefetch = batch_size # number of batches data.load keeps ready ahead of training
    buffer_size = 300 # size of monet dataset should prevent patterns

    image_shape = [256, 256, 3]
    image_max_rgb = 255

    # zip strategies for data.load
    SMALLEST_EPOCH = 0 # no repeat: the zipped dataset ends with its smallest member
    LARGEST_EPOCH = 1  # every dataset repeats indefinitely

    @staticmethod
    def load(
        data_path=local_path,
        data_dirs=data_dirs,
        batch_size=batch_size,
        prefetch=prefetch,
        buffer_size=buffer_size,
        zip_method=LARGEST_EPOCH,
        seed=None
    ):
        """Loads each directory's TFRecords as a shuffled, batched image dataset and zips them.

        Args:
            data_path (str): root path (local or GCS) prepended to each entry of data_dirs.
            data_dirs (Iterable[str]): sub-directories holding ``*.tfrec`` files.
            batch_size (int): batch size; values <= 0 disable batching.
            prefetch (int): number of elements (batches, when batching) to prefetch.
            buffer_size (int): shuffle buffer size.
            zip_method (int): data.LARGEST_EPOCH repeats every dataset indefinitely;
                data.SMALLEST_EPOCH does not repeat.
            seed (Optional[int]): shuffle seed.

        Returns:
            tf.data.Dataset: yields tuples with one element per directory, in data_dirs order.
        """
        datasets = []
        for data_dir in data_dirs:
            ds = data._read_images(data_path + data_dir)
            ds = ds.shuffle(buffer_size, seed=seed)
            if zip_method == data.LARGEST_EPOCH:
                ds = ds.repeat()
            if batch_size > 0:
                ds = ds.batch(batch_size)
            # prefetch last so preparing the next batches overlaps with training
            ds = ds.prefetch(prefetch)
            datasets.append(ds)

        return tf.data.Dataset.zip(tuple(datasets))

    @staticmethod
    def load_subset(
        data_path=local_path,
        data_dirs=data_dirs,
        subset_size=1,
        batch_size=batch_size,
        buffer_size=buffer_size,
        seed=None
    ):
        """Materializes ``subset_size`` shuffled batches from each directory.

        Args:
            data_path (str): root path (local or GCS) prepended to each entry of data_dirs.
            data_dirs (Iterable[str]): sub-directories holding ``*.tfrec`` files.
            subset_size (int): number of batches to take from each directory.
            batch_size (int): batch size; values <= 0 disable batching.
            buffer_size (int): shuffle buffer size.
            seed (Optional[int]): shuffle seed.

        Returns:
            tf.data.Dataset: the sampled batches; zipped into tuples when there is
                more than one directory.
        """
        subsets = []
        for data_dir in data_dirs:
            ds = data._read_images(data_path + data_dir)
            ds = ds.shuffle(buffer_size, seed=seed)
            if batch_size > 0:
                ds = ds.batch(batch_size)
            # eagerly sample subset_size batches so the subset is fixed from now on
            ss = tf.data.Dataset.from_tensor_slices([batch for _, batch in zip(range(subset_size), ds)])
            subsets.append(ss)

        # zip loaded subsets into one
        if len(subsets) > 1:
            return tf.data.Dataset.zip(tuple(subsets))
        return subsets[0]

    @staticmethod
    def _read_images(data_dir):
        """Reads every ``*.tfrec`` file in ``data_dir`` into a dataset of decoded images."""
        records = tf.data.TFRecordDataset(tf.io.gfile.glob(f'{data_dir}/*.tfrec'))
        return records.map(data._tfrec_to_img, num_parallel_calls=tf.data.AUTOTUNE)

    @staticmethod
    def _tfrec_to_img(tfrec):
        """Parses one serialized example into a float32 image scaled to [0, 1]."""
        encoded_image = tf.io.parse_single_example(tfrec, {
            'image': tf.io.FixedLenFeature([], tf.string)
        })['image']
        # fix the channel count and pin the static shape; without it the element
        # shape is fully unknown, which accelerators like TPUs generally reject
        decoded_image = tf.io.decode_jpeg(encoded_image, channels=data.image_shape[-1])
        decoded_image = tf.reshape(decoded_image, data.image_shape)
        return tf.cast(decoded_image, tf.float32) / data.image_max_rgb

    @staticmethod
    def compute_global_batch_size(strategy=None, local_batch_size=batch_size):
        """Returns local_batch_size times the strategy's replica count (1 if no strategy)."""
        if strategy is None:
            return local_batch_size
        return strategy.num_replicas_in_sync * local_batch_size

    @staticmethod
    def get_data_path(tpu_strategy=False):
        """Returns the dataset root: its GCS path on TPU, else the local Kaggle input path."""
        if tpu_strategy:
            # TPUs read from GCS rather than the local filesystem
            return kaggle_datasets.KaggleDatasets().get_gcs_path(data.db_name)
        return data.local_path

# Augmentor

In [None]:
class Augmentor(tf.keras.layers.Layer):
    """Applies one augmentation pipeline jointly to two image batches.

    The batches are concatenated along the batch axis and passed through every
    callable in ``self.augmentations`` in order. The result is then split back
    into two batches of the original sizes.
    """

    def __init__(self):
        super().__init__()
        # callables mapping an image batch to an image batch; empty => identity
        self.augmentations = []

    def call(self, batch1, batch2):
        # record each batch's size so the split is correct even when they differ
        sizes = tf.stack([tf.shape(batch1)[0], tf.shape(batch2)[0]])
        images = tf.concat([batch1, batch2], 0)

        for aug in self.augmentations:
            images = aug(images)

        return tf.split(images, sizes, axis=0)
    

# Generator

In [None]:
class UnetGenerator:
    """Namespace of builders for a U-Net image-to-image generator.

    All members are static helpers; the class is never instantiated.
    """

    # accepted values of the `base` argument of UnetGenerator.build
    # (None = no base stack; an 'inception' base is not implemented)
    _SUPPORTED_BASES = (None, 'residual')

    @staticmethod
    def build(
        unet_depth=5,
        base_depth=4,
        conv_depth=2,
        dropout_rate=0.5,
        activation='leaky_relu',
        base='residual',
        name=None,
    ):
        """Builds a U-Net generator for images of shape data.image_shape.

        The output has the input's spatial size and 3 channels in [0, 1]. It is a
        tanh output rescaled by 0.5 * x + 0.5. The spatial dimensions are assumed
        to be divisible by 2**unet_depth.

        Args:
            unet_depth (int): number of down/up-sampling levels.
            base_depth (int): number of blocks in the base stack.
            conv_depth (int): Conv2D layers per convolution block.
            dropout_rate (float): SpatialDropout2D rate used in the base and up stacks.
            activation: activation of every convolution block (see UnetGenerator.conv).
            base (Optional[str]): 'residual' for a residual base stack, or None for none.
            name (Optional[str]): name of the resulting keras Model.

        Returns:
            tf.keras.Model: the generator.

        Raises:
            ValueError: if ``base`` is not one of UnetGenerator._SUPPORTED_BASES.
        """
        # fail fast, before any layers are created
        if base not in UnetGenerator._SUPPORTED_BASES:
            raise ValueError(
                f'Unsupported base {base!r}; expected one of {UnetGenerator._SUPPORTED_BASES}'
            )

        # Input Layer
        inputs = tf.keras.Input(data.image_shape, name='input_image')

        # Down Stack: convolve, keep the result as a skip connection, then downsample
        dn_stack = inputs
        skips = []
        for level in range(unet_depth):
            filters = UnetGenerator.filters_at_level(level)
            dn_stack = UnetGenerator.conv(dn_stack, filters, conv_depth, activation=activation)
            skips.append(dn_stack)
            dn_stack = UnetGenerator.downsample(dn_stack)

        # Base Stack
        base_stack = dn_stack
        if base == 'residual':
            base_stack = UnetGenerator.residual_base(base_stack, base_depth, dropout_rate=dropout_rate)

        # Up Stack: upsample, merge the matching skip connection, convolve
        up_stack = base_stack
        for level, skip in reversed(list(enumerate(skips))):
            up_stack = UnetGenerator.upsample(up_stack)
            # concatenate skip connection along channel axis
            up_stack = tf.keras.layers.Concatenate(axis=-1)([up_stack, skip])
            filters = UnetGenerator.filters_at_level(level)
            up_stack = UnetGenerator.conv(
                up_stack, filters, conv_depth, dropout_rate=dropout_rate, activation=activation
            )

        # Output
        # convolve back to three channels
        outputs = UnetGenerator.conv(up_stack, 3, conv_depth, activation=activation)
        # concatenate original image
        outputs = tf.keras.layers.Concatenate(axis=-1)([outputs, inputs])
        # convolve reconstructed pixels with original for coloration
        outputs = tf.keras.layers.Conv2D(
            filters=3,
            kernel_size=1,
            padding='same',
            activation='tanh'
        )(outputs)
        # rescale pixel values from [-1, 1] of tanh output to [0, 1]
        outputs = tf.keras.layers.Rescaling(scale=0.5, offset=0.5)(outputs)

        # Assemble Model
        return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    @staticmethod
    def filters_at_level(level):
        """Number of filters at a U-Net level: 32 * 2**level, capped at 512."""
        return min(32 * (2**level), 512)

    @staticmethod
    def downsample(
        input_tensor,
        size=3,
        strides=2,
        activation=None,
        normalize=False,
        initializer='glorot_uniform'
    ):
        """Strided Conv2D ('same' padding) that doubles the channel count.

        Optional InstanceNormalization and activation are applied afterwards.
        """
        ds = tf.keras.layers.Conv2D(
            input_tensor.shape[-1] * 2,
            size,
            strides,
            padding='same',
            kernel_initializer=initializer
        )(input_tensor)

        if normalize:
            ds = tfa.layers.InstanceNormalization()(ds)

        if activation:
            ds = tf.keras.layers.Activation(activation)(ds)

        return ds

    @staticmethod
    def upsample(
        input_tensor,
        size=3,
        strides=2,
        activation=None,
        normalize=False,
        initializer='glorot_uniform'
    ):
        """Strided Conv2DTranspose ('same' padding) that halves the channel count.

        The channel count is halved with integer division. Optional
        InstanceNormalization and activation are applied afterwards.
        """
        us = tf.keras.layers.Conv2DTranspose(
            input_tensor.shape[-1] // 2,
            size,
            strides,
            padding='same',
            kernel_initializer=initializer
        )(input_tensor)

        if normalize:
            us = tfa.layers.InstanceNormalization()(us)

        if activation:
            us = tf.keras.layers.Activation(activation)(us)

        return us

    @staticmethod
    def conv(
        input_tensor,
        filters=-1,
        depth=2,
        size=3,
        dropout_rate=0,
        activation='leaky_relu',
        normalize=True
    ):
        """Builds a convolution block.

        The block is: optional SpatialDropout2D, then ``depth`` consecutive
        Conv2D layers ('same' padding, no activation between them), then
        optional InstanceNormalization and activation.

        Based on the premise of increasing effective kernel size with fewer
        parameters as introduced by VGGNets.

        NOTE(review): with no nonlinearity between the stacked convolutions they
        compose into a single linear convolution; confirm this is intended.

        Args:
            input_tensor: input tensor of the block.
            filters (int): output channels; negative keeps the input's channel count.
            depth (int): number of stacked Conv2D layers.
            size (int): kernel size of each Conv2D.
            dropout_rate (float): SpatialDropout2D rate; 0 disables dropout.
            activation: str or callable for a keras Activation layer; None disables it.
            normalize (bool): whether to apply InstanceNormalization.

        Returns:
            The output tensor of the block.
        """
        block = input_tensor

        if dropout_rate:
            block = tf.keras.layers.SpatialDropout2D(dropout_rate)(block)

        if filters < 0:
            # maintain input channels when in doubt
            filters = input_tensor.shape[-1]
        for _ in range(depth):
            block = tf.keras.layers.Conv2D(filters, size, padding='same')(block)

        if normalize:
            block = tfa.layers.InstanceNormalization()(block)

        if activation:
            block = tf.keras.layers.Activation(activation)(block)

        return block

    @staticmethod
    def residual_base(
        input_tensor,
        depth,
        size=3,
        dropout_rate=0,
        activation='relu',
        preactivation=True,
        bottleneck=True,
        compression_factor=0.5,
        normalize=True
    ):
        """Stack of ``depth`` residual blocks that keep the input's channel count.

        Each block consists of:
        - optional SpatialDropout2D;
        - optional pre-activation;
        - either a bottleneck (1x1 compression to
          ``compression_factor * channels``, a ``size`` conv, and a 1x1
          expansion) or two ``size`` convolutions, with InstanceNormalization
          after each conv;
        - an identity shortcut;
        - post-activation when ``preactivation`` is False.

        ``normalize`` is currently ignored (normalization is always applied).
        """
        stack = input_tensor

        def normalized(tensor):
            return tfa.layers.InstanceNormalization()(tensor)

        def activated(tensor):
            return tf.keras.layers.Activation(activation)(tensor)

        def actnorm(tensor):
            return activated(normalized(tensor))

        filters = input_tensor.shape[-1]
        neck_filters = int(filters * compression_factor)
        for _ in range(depth):
            stack_input = stack

            if dropout_rate:
                stack = tf.keras.layers.SpatialDropout2D(dropout_rate)(stack)

            if preactivation:
                stack = activated(stack)

            if bottleneck:
                # compression conv
                stack = tf.keras.layers.Conv2D(neck_filters, 1)(stack)
                stack = actnorm(stack)
                # bottleneck conv
                stack = tf.keras.layers.Conv2D(neck_filters, size, padding='same')(stack)
                stack = actnorm(stack)
                # expansion conv
                stack = tf.keras.layers.Conv2D(filters, 1)(stack)
                stack = normalized(stack)
            else:
                stack = tf.keras.layers.Conv2D(filters, size, padding='same')(stack)
                stack = actnorm(stack)
                stack = tf.keras.layers.Conv2D(filters, size, padding='same')(stack)
                stack = normalized(stack)

            stack = tf.keras.layers.Add()([stack, stack_input])

            if not preactivation:
                stack = activated(stack)

        return stack

# Discriminator 

In [None]:
class PatchDiscriminator:
    """Namespace of builders for a PatchGAN-style convolutional discriminator.

    All members are static helpers; the class is never instantiated.
    """

    @staticmethod
    def build(depth=3, kernel_widths=3, patch_size=70, init_filters=128, min_filters=8, dropout_rate=0.5):
        """Builds a fully-convolutional discriminator that scores image patches.

        The architecture is currently fixed (see the layer table in the body).
        Each hidden block is SpatialDropout2D -> Conv2D ('valid' padding) ->
        InstanceNormalization -> LeakyReLU. The output block is
        SpatialDropout2D -> Conv2D(1) -> sigmoid, with no normalization.

        TODO: the structural parameters below are accepted for interface
        compatibility but are not used yet.

        Args:
            depth (int): intended number of convolutional layers. Currently unused.
            kernel_widths (Union[int, Iterable[int]]): intended filter size(s).
                Currently unused.
            patch_size (int): intended receptive field of each output neuron.
                Currently unused (the fixed architecture has a 69x69 receptive field).
            init_filters (int): intended filters of the first layer. Currently unused.
            min_filters (int): intended lower bound on filters. Currently unused.
            dropout_rate (float): SpatialDropout2D rate before every convolution.

        Returns:
            tf.keras.Sequential: maps images of shape data.image_shape to a grid
                of patch scores in (0, 1). For 256x256 inputs the output is 6x6x1.
        """
        # (filters, kernel width, stride) of each hidden conv layer
        hidden_specs = [(64, 5, 2), (128, 5, 2), (256, 3, 2), (512, 3, 2)]
        # final conv producing one score per patch; receptive field of 69x69
        out_filters, out_width, out_stride = 1, 3, 2

        layers = [tf.keras.Input(data.image_shape)]
        for filters, width, stride in hidden_specs:
            layers.extend([
                tf.keras.layers.SpatialDropout2D(dropout_rate),
                tf.keras.layers.Conv2D(filters, width, stride),
                tfa.layers.InstanceNormalization(),
                tf.keras.layers.Activation('leaky_relu')
            ])
        # No normalization on the output: instance-normalizing a single-channel
        # map would give every image the same mean patch score (the shared
        # beta), erasing the real-vs-fake signal the discriminator must emit.
        layers.extend([
            tf.keras.layers.SpatialDropout2D(dropout_rate),
            tf.keras.layers.Conv2D(out_filters, out_width, out_stride),
            tf.keras.layers.Activation('sigmoid')
        ])

        return tf.keras.Sequential(layers)

# CycleGAN

In [None]:
class CycleGAN(tf.keras.Model):
    """CycleGAN between two domains: Monet paintings ("m") and photos ("p").

    m_gen translates photos into Monet-style images and p_gen translates Monet
    images into photos. m_dis / p_dis score how real an image of their own
    domain looks. aug is applied jointly to each (real, fake) pair before
    the discriminators.
    """

    def __init__(self, m_gen, p_gen, m_dis, p_dis, aug):
        super().__init__()
        self.m_gen = m_gen
        self.p_gen = p_gen
        self.m_dis = m_dis
        self.p_dis = p_dis
        self.aug = aug

    def compile(self, optimizer, id_loss, cycle_loss, dis_loss):
        """Configures training.

        Args:
            optimizer: one optimizer shared by all four sub-models.
            id_loss, cycle_loss, dis_loss: callables taking (y_true, y_pred) and
                presumably returning a scalar loss (e.g. built with distribute.loss).
        """
        super().compile()
        self.optimizer = optimizer
        self.id_loss = id_loss
        self.cycle_loss = cycle_loss
        self.dis_loss = dis_loss

    def train_step(self, data):
        """Runs one CycleGAN update on a (monet_batch, photo_batch) tuple."""
        m_real, p_real = data

        with tf.GradientTape(persistent=True) as tape:
            # identity: a generator should leave images of its target domain unchanged
            m_id = self.m_gen(m_real, training=True)
            p_id = self.p_gen(p_real, training=True)
            m_id_loss = self.id_loss(m_real, m_id)
            p_id_loss = self.id_loss(p_real, p_id)

            # transfer outputs
            m_fake = self.m_gen(p_real, training=True)
            p_fake = self.p_gen(m_real, training=True)

            # cycle: translating back should recover the original image
            m_cycle = self.m_gen(p_fake, training=True)
            p_cycle = self.p_gen(m_fake, training=True)
            m_cycle_loss = self.cycle_loss(m_real, m_cycle)
            p_cycle_loss = self.cycle_loss(p_real, p_cycle)
            cycle_loss = m_cycle_loss + p_cycle_loss

            # differentiable augmentations
            m_real_aug, m_fake_aug = self.aug(m_real, m_fake)
            p_real_aug, p_fake_aug = self.aug(p_real, p_fake)

            # discriminator outputs
            m_dis_real = self.m_dis(m_real_aug, training=True)
            m_dis_fake = self.m_dis(m_fake_aug, training=True)
            p_dis_real = self.p_dis(p_real_aug, training=True)
            p_dis_fake = self.p_dis(p_fake_aug, training=True)

            # discriminator loss: real -> 1, fake -> 0
            dis_labels = tf.concat([tf.ones_like(m_dis_real), tf.zeros_like(m_dis_fake)], 0)
            m_dis_loss = self.dis_loss(dis_labels, tf.concat([m_dis_real, m_dis_fake], 0))
            p_dis_loss = self.dis_loss(dis_labels, tf.concat([p_dis_real, p_dis_fake], 0))

            # generator adversarial loss: generators want their fakes scored as real.
            # Only the scores of generated images enter this term.
            m_adv_loss = self.dis_loss(tf.ones_like(m_dis_fake), m_dis_fake)
            p_adv_loss = self.dis_loss(tf.ones_like(p_dis_fake), p_dis_fake)

            m_gen_loss = m_adv_loss + m_id_loss + cycle_loss
            p_gen_loss = p_adv_loss + p_id_loss + cycle_loss

        # apply backpropagation: each loss updates only its own model's variables
        updates = (
            (m_gen_loss, self.m_gen),
            (p_gen_loss, self.p_gen),
            (m_dis_loss, self.m_dis),
            (p_dis_loss, self.p_dis),
        )
        for model_loss, model in updates:
            model_vars = model.trainable_variables
            grads = tape.gradient(model_loss, model_vars)
            self.optimizer.apply_gradients(zip(grads, model_vars))
        # release the resources held by the persistent tape
        del tape

        return {
            'monet_id_loss': m_id_loss,
            'photo_id_loss': p_id_loss,
            'monet_cycle_loss': m_cycle_loss,
            'photo_cycle_loss': p_cycle_loss,
            'monet_discriminator_loss': m_dis_loss,
            'photo_discriminator_loss': p_dis_loss,
            'monet_generator_loss': m_gen_loss,
            'photo_generator_loss': p_gen_loss,
        }

    def call(self, x, output_class='monet'):
        """Translates ``x`` into ``output_class`` ('monet' or 'photo').

        Raises:
            ValueError: for any other output_class.
        """
        if output_class == 'monet':
            return self.m_gen(x)
        if output_class == 'photo':
            return self.p_gen(x)
        raise ValueError(f"output_class must be 'monet' or 'photo', got {output_class!r}")

In [None]:
class VisualizeCycleGanEvolution(tf.keras.callbacks.Callback):
    """Records a CycleGAN's translations of fixed test images during training.

    When training ends, the recordings are rendered (and saved) with
    vis.generator_evolution.
    """

    default_filepath='./cycle-gan-evolution.png'

    def __init__(self, test_images, classes=None, frequency=1, filepath=default_filepath,
                 separate_classes=True, show_initial=True):
        """
        :param test_images: tensor containing the batch of images to test on. If multiple
        classes are being visualized, test_images should be an iterable containing a batch
        for each class. Index of batches should match the index of the class in classes
        for which each batch is to be transformed into.
        :param classes: None, str, or Iterable[str]. The name(s) of the classes explored by
        the CycleGAN model. Will each be used as an argument to the __call__ method
        of the CycleGAN. If None, length of classes is assumed to be 1, and the model will
        be called with no other arguments.
        :param frequency: int or Iterable[int]. If single int, test will be run at
        the end of every epoch such that 'epoch % frequency == 0' evaluates to True. If
        Iterable, test will be run whenever 'epoch in frequency' evaluates to True. Epoch
        in this consideration will begin at one - not zero.
        :param filepath: str. The location at which to save the resulting image.
        :param separate_classes: bool. If true, each class will be saved as a separate
        image with the class prepended to the file name.
        :param show_initial: bool. If true, will include initial predictions of the gan
        model (before any training occurs).
        """
        super().__init__()

        # ensure classes is Iterable[str]
        if classes is None or isinstance(classes, str):
            classes = [classes]

        # images tensor is indexed by (epoch, class, image, height, width[, channels]);
        # "epoch" 0 holds the untouched test images
        if len(classes) == 1:
            self.images = test_images[None, None]
        else:
            self.images = tf.stack(test_images)[None]

        # one output file per class when saving classes separately
        if separate_classes and len(classes) > 1:
            name_index = filepath.rfind('/') + 1
            self.filepaths = [filepath[:name_index] + class_name + '-' + filepath[name_index:]
                              for class_name in classes]
        else:
            self.filepaths = [filepath]

        self.classes = classes
        self.frequency = frequency
        self.show_initial = show_initial

    def on_train_begin(self, logs=None):
        # record the untrained model's translations unless the caller opted out
        if self.show_initial:
            self.images = self._collect_images(self.images)

    def on_epoch_end(self, epoch, logs=None):
        # keras epochs are 0-based; `frequency` is documented as 1-based
        epoch += 1
        if (
            type(self.frequency) is int and epoch % self.frequency == 0 or
            hasattr(self.frequency, '__iter__') and epoch in self.frequency
        ):
            self.images = self._collect_images(self.images)

    # TODO: either tf.function-ize this, or reformat to need no arguments or returns
    def _collect_images(self, images):
        """Appends the model's current translation of the test images as a new epoch."""
        # initialize new tensor with shape (0, height, width[, channels])
        new_images = tf.zeros([0, *images.shape[3:]], dtype=images.dtype)

        for c, cla in enumerate(self.classes):
            # original test images of this class
            oimgs = images[0, c]
            # transform image batch (with class name as argument if available)
            nimgs = self.model(oimgs, cla) if cla else self.model(oimgs)
            new_images = tf.concat((new_images, nimgs), axis=0)

        # add epoch and class axes to tensor
        new_images = tf.reshape(new_images, (len(self.classes), -1, *new_images.shape[1:]))[None]
        return tf.concat((images, new_images), axis=0)

    def on_train_end(self, logs=None):
        # ensure channels axis exists
        if len(self.images.shape) == 5:
            self.images = self.images[..., None]

        # reshape images from (epoch, class, image, height, width, channels)
        #                  to (class, image, epoch, height, width, channels)
        images = tf.transpose(self.images, [1, 2, 0, 3, 4, 5])

        if len(self.filepaths) > 1:
            for class_images, class_filepath in zip(images, self.filepaths):
                vis.generator_evolution(class_images[None], class_filepath)
        else:
            vis.generator_evolution(images, self.filepaths[0])

In [None]:
class AlternateTraining(tf.keras.callbacks.Callback):
    """Alternates which sub-models of ``self.model`` are trainable per batch.

    On even-numbered batches the generators (``m_gen``, ``p_gen``) are marked
    trainable and the discriminators (``m_dis``, ``p_dis``) are frozen. On
    odd-numbered batches it is the other way round.

    NOTE(review): if the model's train step is traced once as a tf.function,
    flipping ``trainable`` between batches may have no effect without
    recompiling. Confirm against the CycleGAN training step.
    """

    def on_train_batch_begin(self, batch, logs=None):
        train_generators = batch % 2 == 0
        train_discriminators = not train_generators
        gan_model = self.model
        gan_model.m_gen.trainable = train_generators
        gan_model.p_gen.trainable = train_generators
        gan_model.m_dis.trainable = train_discriminators
        gan_model.p_dis.trainable = train_discriminators

# Implementation

In [None]:
# choose a distribution strategy for the available hardware: TPU if reachable,
# otherwise GPU (MirroredStrategy), otherwise the default CPU strategy
strategy = distribute.build_strategy()

In [None]:
with strategy.scope():
    # Inside a distribute strategy, Keras losses are restricted to NONE or SUM
    # reduction, so per-example losses are built here and passed to distribute.loss.
    NONE = tf.keras.losses.Reduction.NONE

    # two U-Net generators, two patch discriminators and the augmentor
    gan = CycleGAN(
        m_gen=UnetGenerator.build(),
        p_gen=UnetGenerator.build(),
        m_dis=PatchDiscriminator.build(),
        p_dis=PatchDiscriminator.build(),
        aug=Augmentor(),
    )

    # NOTE(review): one Adam instance is handed to compile; check whether
    # CycleGAN.compile shares it across all four sub-models.
    gan.compile(
        optimizer=tf.keras.optimizers.Adam(2e-4),
        id_loss=distribute.loss(strategy, tf.keras.losses.MeanAbsoluteError(reduction=NONE)),
        cycle_loss=distribute.loss(strategy, tf.keras.losses.MeanAbsoluteError(reduction=NONE)),
        dis_loss=distribute.loss(strategy, tf.keras.losses.BinaryCrossentropy(reduction=NONE)),
    )

In [None]:
# Callbacks passed to gan.fit. They are created inside the strategy scope because
# some of them (checkpointing in particular) need to know about the strategy.
#
# Checkpointing is currently disabled. To re-enable it, create the directory
# first, e.g.:
# os.makedirs('./checkpoints', exist_ok=True)
with strategy.scope():
    callbacks = []

    # callbacks.append(tf.keras.callbacks.ModelCheckpoint(
    #     filepath='./checkpoints/{epoch:02d}.ckpt',
    #     save_weights_only=True,
    #     options=tf.train.CheckpointOptions(experimental_io_device='/job:localhost'),
    # ))

    # track generator output on the first subset batch (batch_size=5),
    # collected after every 3rd epoch
    callbacks.append(VisualizeCycleGanEvolution(
        test_images=next(iter(data.load_subset(
            data.get_data_path(distribute.is_tpu(strategy)),
            batch_size=5,
        ))),
        classes=['photo', 'monet'],
        frequency=3,
    ))

    # callbacks.append(AlternateTraining())

In [None]:
# Save architecture diagrams of one generator and one discriminator for the journal.
# (Calling gan.m_gen.summary() / gan.m_dis.summary() prints text versions instead.)
# plot_model requires pydot and graphviz to be installed.
tf.keras.utils.plot_model(gan.m_gen, to_file='generator_model.png', show_shapes=True)
tf.keras.utils.plot_model(gan.m_dis, to_file='discriminator_model.png', show_shapes=True);

In [None]:
# Train the model. steps_per_epoch is a fixed 300 rather than being derived from
# the dataset size (an earlier draft used ceil(6686 / global batch size)).
# Validation and initial_epoch are left at the Keras defaults (none / 0).
history = gan.fit(
    x=data.load(data.get_data_path(distribute.is_tpu(strategy))),
    epochs=24,
    steps_per_epoch=300,
    callbacks=callbacks,
)

# Generate Images

In [None]:
# Translate a fixed (seed=0) batch of 5 photos and show each one next to its output.
test_set = data.load_subset(
    data.get_data_path(distribute.is_tpu(strategy)),
    data_dirs=[data.photo_dir],
    batch_size=5,
    seed=0,
)

photos = next(iter(test_set))
monets = gan(photos)

# pair each photo with its generated counterpart along a new axis 1
images = np.stack([photos, monets], axis=1)

col_titles = ['Photo', 'Monet']
vis.image_gallery(images, col_titles=col_titles)
plt.savefig('final-test.png')
plt.show()