In [None]:
# math and ml
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np

# visualization
import matplotlib.pyplot as plt
import PIL
import IPython

# environment configuration
import kaggle_datasets
import os

# submission handling
import shutil

In [None]:
# mute tensorflow logging
# the environment variable must be a string, hence str(3) -> '3'
# NOTE(review): tensorflow is already imported in the cell above; this variable is
# commonly expected to be set before that import to take effect - confirm it works here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(3)

# Configs

In [None]:
class CallbackConfig:
    """On/off switch plus free-form options for one optional training callback.

    Attributes:
        enabled: The flag passed at construction, stored unchanged.
        kwargs (dict): Every additional keyword argument passed at construction,
            collected into a dictionary.
    """

    def __init__(self, enabled, **kwargs):
        # keep both pieces exactly as supplied; no validation is performed
        self.enabled, self.kwargs = enabled, kwargs

class TrainingConfig:
    """Training hyperparameters and callback switches.

    Most consumers of these constants live outside the cells shown next to
    this class; AUGMENTATIONS is the default augmentation list of Augmentor.
    """
    # training length settings
    # NOTE(review): presumably epochs and steps per epoch - confirm against the fit call
    EPOCHS = 30
    STEPS = 500
    
    # optimizer learning rates
    LEARNING_RATE = 2e-4
    FINAL_RATE = 1e-5 # enable LR_SCHEDULE to use
    
    # names of the Augmentor methods to apply; names that are not attributes
    # of the Augmentor instance are skipped by its constructor
    AUGMENTATIONS = [
        'brightness',
        'saturation',
        'contrast',
        'color',
        'translation',
        'cutout',
    ]
    
    # callbacks
    # each entry pairs an enabled flag with the keyword options for that callback
    CHECKPOINTS = CallbackConfig(
        enabled=False,
        # prepend appropriate number of zeros to epoch number in filepath
        # (with EPOCHS = 30 the template becomes './checkpoints/{epoch:02d}.ckpt')
        filepath='./checkpoints/{epoch:0%dd}.ckpt' % int(np.ceil(np.log10(EPOCHS+1))),
    )
    EVOLUTION_VISUAL = CallbackConfig(
        enabled=True,
        classes=['photo', 'monet'],
        frequency=1,
    )
    ALTERNATE = CallbackConfig(
        enabled=False,
    )
    LR_SCHEDULE = CallbackConfig(
        enabled=True,
    )

In [None]:
class UnetConfig:
    """Default hyperparameters for the Unet generator and its building blocks."""
    # number of downsampling levels (each is mirrored by one upsampling level)
    UNET_DEPTH = 5
    # number of residual blocks stacked in the base of the network
    BASE_DEPTH = 4
    # default for the conv_depth argument of Unet
    CONV_DEPTH = 3
    
    # residual or inception
    BASE_TYPE = 'residual'
    
    # default rate for the SpatialDropout2D layers of the Unet blocks
    DROPOUT_RATE = 0.5
    
    # activation identifier handed to tf.keras.layers.Activation
    ACTIVATION = 'leaky_relu'

In [None]:
class PatchConfig:
    """Default hyperparameters for the patch discriminator."""
    # default dropout_rate argument of PatchDiscriminator.build
    DROPOUT_RATE = 0.5

In [None]:
class GANConfig:
    """Loss weighting constants for the GAN."""
    # weights for generator loss (adversarial, id, cycle)
    # stored as a rank-1 tensor of three floats; all terms are weighted equally here
    LOSS_WEIGHTS = tf.constant([1e0, 1e0, 1e0])

In [None]:
class VisConfig:
    """Constants used by the visualization helpers."""
    # ideal matplotlib figure width for jupyter environment
    # (vis_image_gallery divides this width evenly among its columns)
    CELL_WIDTH = 16.0

In [None]:
class DataConfig:
    """Dataset location, batching and image format constants."""
    # dataset name; also passed to KaggleDatasets.get_gcs_path when running on TPU
    DB_NAME = 'gan-getting-started'
    # path of the dataset when it is read from the local filesystem
    LOCAL_PATH = f'../input/{DB_NAME}'
    
    # sub-directories holding the .tfrec files of each image class
    MONET_DIRNAME = 'monet_tfrec'
    PHOTO_DIRNAME = 'photo_tfrec'
    DATA_DIRNAMES = [MONET_DIRNAME, PHOTO_DIRNAME]
    
    BATCH_SIZE = 8 # 128 # recommended for TPU v3-8 TODO: explore OOM errors
    BUFFER_SIZE = 300 # size of monet dataset should prevent patterns
    
    # image volume shape as [height, width, channels]
    IMAGE_SHAPE = [256, 256, 3]
    # divisor used to scale decoded pixel values into [0, 1]
    IMAGE_MAX_RGB = 255
    
    # NOTE(review): not referenced in the cells visible next to this class - purpose to be confirmed
    SMALLEST_EPOCH = 0
    LARGEST_EPOCH = 1

# Distributed Computing

In [None]:
def distribute_build_strategy():
    """Builds the tf.distribute.Strategy matching the available accelerators.

    Accelerators are tried in order of preference: a TPU cluster first, then
    any GPU, and finally plain CPU execution.

    Returns:
        A tf.distribute.Strategy object for the calling compute environment.
        This will be of one of the following types
         - TPUStrategy (TPU Accelerator)
         - MirroredStrategy (GPU Accelerator)
         - _DefaultDistributionStrategy (No Accelerator)
    """
    try:
        # the resolver raises ValueError when no tpu address can be found
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        # connect to and initialize the tpu system
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        tpu_strategy = tf.distribute.TPUStrategy(resolver)
    except ValueError:
        # no TPU: fall through to the GPU / CPU options below
        tpu_strategy = None

    if tpu_strategy is not None:
        return tpu_strategy

    # use the GPU(s) when at least one is visible
    visible_gpus = tf.config.list_physical_devices('GPU')
    if visible_gpus:
        return tf.distribute.MirroredStrategy()

    # no accelerator: use the default (CPU) strategy
    # TODO: evaluate MirroredStrategy for cpu
    return tf.distribute.get_strategy()

In [None]:
def distribute_is_tpu(strategy):
    """Tells whether a distribution strategy runs on a TPU cluster.

    Storage access depends on the answer: a TPU cluster lives in a cloud
    environment and does not read from local storage by default.

    Args:
        strategy (tf.distribute.Strategy): the strategy to be evaluated

    Returns:
        bool: True if the supplied strategy is a tf.distribute.TPUStrategy
    """
    tpu_strategy_type = tf.distribute.TPUStrategy
    return isinstance(strategy, tpu_strategy_type)

In [None]:
def distribute_loss(loss_fn, strategy=None):
    """Wraps an unreduced loss function with a strategy-aware reduction.

    Args:
        loss_fn (Callable[[tf.Tensor, tf.Tensor], tf.Tensor]): The loss
            function which takes (y_true, y_pred) as arguments and returns
            the loss. It must be usable inside a tf.function and must not
            reduce over the batch itself. For example, a
            tf.keras.losses.Loss object must be built with the reduction
            argument set to tf.keras.losses.Reduction.NONE.
        strategy (tf.distribute.Strategy): The distribution strategy for the
            calling compute environment. Only needs to be supplied when the
            function is not called within strategy.scope

    Returns:
        Callable[[tf.Tensor, tf.Tensor], tf.Tensor]: a function taking
            (y_true, y_pred) that flattens every instance, applies loss_fn
            and returns the summed loss divided by the global batch size
            as a scalar Tensor.
    """
    # fall back on the strategy currently in scope
    strategy = tf.distribute.get_strategy() if strategy is None else strategy

    # DataConfig.BATCH_SIZE is multiplied by the replica count to get the divisor
    replicas = strategy.num_replicas_in_sync
    global_batch_size = replicas * DataConfig.BATCH_SIZE

    @tf.function
    def reduced_loss(y_true, y_pred):
        # collapse every axis after the batch axis into a single one
        values_per_instance = tf.math.reduce_prod(y_true.shape[1:])
        flat_shape = (-1, values_per_instance)
        flat_true = tf.reshape(y_true, flat_shape)
        flat_pred = tf.reshape(y_pred, flat_shape)

        # sum the per-instance losses and average over the global batch
        per_instance = loss_fn(flat_true, flat_pred)
        summed = tf.reduce_sum(per_instance)
        return summed / global_batch_size

    return reduced_loss

# Visualizations

In [None]:
def vis_to_image_grid(images):
    """Merges a rank 5 stack of image volumes into one grid image.

    Only rank 5 input is accepted; vis_image_grid is the more permissive
    entry point and calls this function for the shapes that need it.

    Args:
        images (tf.Tensor): tensor of stacked image volumes with rank 5 and indexed by
            (row, column, height, width, channel).

    Returns:
        tf.Tensor: tensor of shape (rows*height, cols*width, channel) in which
            the images are laid out as a rows x cols grid

    Raises:
        ValueError: if the input does not have rank 5.
    """
    rank = len(images.shape)
    if rank != 5:
        raise ValueError(f'Invalid tensor rank. Expected rank 5, got {rank}')

    # (row, col, height, width, channel) -> (row, height, col, width, channel)
    # so that merging neighbouring axes stitches the images together
    interleaved = tf.transpose(images, [0, 2, 1, 3, 4])
    dims = interleaved.shape
    grid_height = dims[0] * dims[1]
    grid_width = dims[2] * dims[3]

    return tf.reshape(interleaved, (grid_height, grid_width, -1))

In [None]:
def vis_image_grid(images):
    """Transforms Tensor of one or more image volumes into a single volume of all images 
    concatenated together.

    The input is turned into a single rank 3 image volume indexed by
    (height, width, channel), with all input images laid out as a grid.

    Args:
        images (tf.Tensor): tensor of one or more stacked image volumes. This tensor must 
            be indexed by one of the following shapes:
             - (height, width): 1 image
             - (height, width, channel): 1 image
             - (col, height, width, channel): row of col images 
             - (row, col, height, width, channel): grid of row*col images

    Returns:
        tf.Tensor: tensor of shape (rows*height, cols*width, channel) with all images 
            concatenated into a single volume representing a grid of the images

    Raises:
        ValueError: if the input rank is not 2, 3, 4 or 5.
    """
    # the rank is the number of axes, not the shape itself; comparing the
    # shape object against integers never matches and rejects every input
    irank = len(images.shape)
    if irank == 5:
        # tensor shape: (row, col, height, width, channel)
        # concatenate images
        images = vis_to_image_grid(images)
    elif irank == 4:
        # assume horizontal and add empty rows axis before concatenating
        images = images[None]
        images = vis_to_image_grid(images)
    elif irank == 2:
        # add channels axis
        images = images[..., None]
    elif irank != 3:
        raise ValueError(f'Images tensor has improper rank {irank}')

    return images

In [None]:
def vis_to_pil(img):
    """Converts an image tensor into a PIL Image object.

    Args:
        img (Union[tf.Tensor, np.ndarray]): A single image volume of float 
            values in the range [0, 1].

    Returns:
        PIL.Image: The same image as a PIL.Image object
    """
    # scale [0, 1] floats up to the 8-bit range before casting
    scaled = img * tf.constant([255.])
    pixels = tf.cast(scaled, tf.uint8).numpy()
    return PIL.Image.fromarray(pixels)

In [None]:
def vis_image_gallery(
    images, 
    img_titles=None, 
    row_titles=None,
    col_titles=None,
    channels=True,
):
    """Draws a grid of images on a new matplotlib figure.

    The figure is VisConfig.CELL_WIDTH wide and every image gets a square
    cell. Nothing is returned; the figure is left as the current figure.

    Args:
        images: array-like indexed by (row, col, ...image axes). It is
            converted with np.array before drawing.
        img_titles: optional per-image labels in row-major order. They are
            used as axes titles, or as x labels when col_titles is also given.
        row_titles: optional labels set as the y label of the first axes in
            each row.
        col_titles: optional labels set as the titles of the leading axes.
        channels: accepted but not read by this function.
    """
    grid = np.array(images)
    nrows, ncols = grid.shape[:2]
    # flatten the (row, col) axes so images pair up with the raveled axes
    flat_images = grid.reshape(-1, *grid.shape[2:])

    cell_size = VisConfig.CELL_WIDTH / ncols
    _, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(cell_size*ncols, cell_size*nrows),
    )
    flat_axes = np.array(axes).ravel()

    # one label slot per axes, all starting out as None
    titles, xlabels, ylabels = (np.full_like(flat_axes, None) for _ in range(3))

    # column titles take the title slot and push image titles to the x label
    if col_titles is not None:
        titles[:len(col_titles)] = col_titles
        if img_titles is not None:
            xlabels[:] = img_titles
    elif img_titles is not None:
        titles[:] = img_titles
    if row_titles is not None:
        ylabels.reshape(nrows, ncols)[:, 0] = row_titles

    # draw images
    cells = zip(flat_images, flat_axes, titles, xlabels, ylabels)
    for image, ax, title, xlabel, ylabel in cells:
        vis_draw_image(image, ax)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

In [None]:
def vis_draw_image(img, ax=None):
    """Shows an image on a matplotlib axes without ticks or spines.

    Args:
        img: image data forwarded to ax.imshow.
        ax: the axes to draw on. The current axes (plt.gca()) is used when
            omitted.

    Returns:
        The axes the image was drawn on.
    """
    target = plt.gca() if ax is None else ax

    # remove the tick marks on both axes
    for clear_ticks in (target.set_xticks, target.set_yticks):
        clear_ticks([])

    # hide the frame around the image
    for spine in target.spines.values():
        spine.set_visible(False)

    target.imshow(img)
    return target

In [None]:
def vis_generator_evolution(
    images, 
    filepath=None,
    classes=None,
    epochs=None,
    add_alpha=False,
    separate_classes=True,
    vertical=False,
    display=True,
):
    """Visualizes the training evolution of an image generator.

    Every class is rendered as a grid with one row per image and one column
    per epoch, and the class grids are stacked on top of each other.

    TODO: Further description necessary

    Args:
        images (tf.Tensor): rank 6 tensor of image volumes indexed by
            (class, image, epoch, height, width, channel)
        filepath (Optional[str]): filepath to save visual to. 
        classes (Optional[Iterable[str]]): list of classes corresponding
            to the first index of the images tensor for annotation.
            Currently not used by this function.
        epochs (Union[None, Iterable[int], Iterable[str]]): list of epochs 
            corresponding to the third index of the images tensor for annotation.
            Currently not used by this function.
        add_alpha (bool): If True, a channel filled with ones is appended to
            every class grid.
        separate_classes (bool): If True, a gap will be left between the classes.
            The gap is DataConfig.IMAGE_SHAPE[0] pixels tall and filled with ones.
        vertical (bool): If True, timeline will extend downward along the visual 
            while classes and images are spread along horizontal axis.
            Currently not used by this function; the timeline is always horizontal.
        display (bool): If True, the visual is displayed in the notebook.

    Returns:
        tf.Tensor: a single image volume containing the generator evolution
            visual.
    """
    # transform each class into an image grid 
    grids = [vis_to_image_grid(img_cls) for img_cls in images]
    # add transparency channel if requested
    if add_alpha:
        grids = [
            tf.concat([grid, tf.fill((*grid.shape[:-1], 1), 1.)], axis=-1)
            for grid in grids
        ]

    # stack the class grids vertically, optionally separated by blank gaps
    pieces = [grids[0]]
    for grid in grids[1:]:
        if separate_classes:
            pieces.append(tf.fill((DataConfig.IMAGE_SHAPE[0], *grid.shape[1:]), 1.))
        pieces.append(grid)
    final_image = tf.concat(pieces, axis=0)

    if filepath:
        vis_save_image(final_image, filepath)
    if display:
        vis_display_image(final_image)
    return final_image

In [None]:
def vis_display_image(img):
    """Shows an image volume in the notebook output via IPython's display."""
    pil_image = vis_to_pil(img)
    IPython.display.display(pil_image)

def vis_save_image(img, filepath):
    """Converts an image volume to a PIL image and saves it to filepath."""
    pil_image = vis_to_pil(img)
    pil_image.save(filepath)

# Data Pipeline

In [None]:
def data_load(
    data_path=DataConfig.LOCAL_PATH,
    data_dirs=DataConfig.DATA_DIRNAMES, 
    batch_size=DataConfig.BATCH_SIZE, 
    buffer_size=DataConfig.BUFFER_SIZE,
    prefetch=tf.data.AUTOTUNE, 
    repeat=True,
    seed=None
):
    """Load competition dataset.
    
    Args:
        data_path (str): The path to the database. If the strategy is using a TPU
            cluster for distribution, the database should be loaded from the gcs
            path. If the strategy is using local computation units (GPU or CPU),
            the database should be loaded from the local path.
        data_dirs (Iterable[str]): The list of directories within the database to
            load. Should all be directories containing tfrec files.
        batch_size (int): The number of instances that should be sampled at a 
            time. Specify batch_size < 1 to disable batching in the dataset.
        buffer_size (int): The number of instances in a dataset to buffer when
            shuffling. See tf.data.Dataset.shuffle for more information.
        prefetch (int): The number of elements of the returned dataset (batches
            when batching is enabled) to prepare ahead of being requested.
            See tf.data.Dataset.prefetch for more information.
        repeat (bool): Whether or not each dataset should be repeated arbitrarily
            many times (e.g., for training purposes).
        seed (int): The seed supplied to tf.data.Dataset.shuffle. 
        
    Returns:
        tf.data.Dataset: A zipped dataset which includes tfrec data read from the 
            supplied directories. Each iteration will yield a tuple of tf.Tensor
            objects. The first tuple item will be the batch yielded from the 
            first supplied directory, the second item yielded from the second
            directory, and so on.
    """
    # initialize empty list of loaded datasets
    datasets = []
    for data_dir in data_dirs:
        # append directory to database path
        data_dir = os.path.join(data_path, data_dir)
        # create and configure dataset object
        ds = tf.data.TFRecordDataset(tf.io.gfile.glob(f'{data_dir}/*.tfrec'))
        ds = ds.map(data_tfrec_to_img)
        ds = ds.shuffle(buffer_size, seed=seed)
        if repeat:
            ds = ds.repeat()
        if batch_size > 0:
            ds = ds.batch(batch_size)
        # append to running list
        datasets.append(ds)

    # zip all loaded datasets
    dataset = tf.data.Dataset.zip(tuple(datasets))

    # prefetch last so that finished elements are prepared while the consumer
    # is busy; prefetching ahead of the shuffle buffer does not overlap the
    # shuffle / batch work with the consumer
    return dataset.prefetch(prefetch)

In [None]:
def data_load_sample(
    data_path=DataConfig.LOCAL_PATH,
    data_dirs=DataConfig.DATA_DIRNAMES, 
    sample_size=DataConfig.BATCH_SIZE, 
    buffer_size=DataConfig.BUFFER_SIZE,
    seed=None
):
    """Draw one shuffled batch of images from each dataset directory.
    
    Args:
        data_path (str): The path to the database. If the strategy is using a TPU
            cluster for distribution, the database should be loaded from the gcs
            path. If the strategy is using local computation units (GPU or CPU),
            the database should be loaded from the local path.
        data_dirs (Iterable[str]): The list of directories within the database to
            load. Should all be directories containing tfrec files.
        sample_size (int): The number of instances to sample from each directory.
        buffer_size (int): The number of instances in a dataset to buffer when
            shuffling. See tf.data.Dataset.shuffle for more information.
        seed (int): The seed supplied to tf.data.Dataset.shuffle. 
        
    Returns:
        Tuple[tf.Tensor]: The samples from each dataset, in the order of
            data_dirs: the first item comes from the first directory, the
            second item from the second directory, and so on.
    """
    samples = []
    for dirname in data_dirs:
        # every .tfrec file directly inside the directory is read
        data_dir = os.path.join(data_path, dirname)
        tfrec_files = tf.io.gfile.glob(f'{data_dir}/*.tfrec')

        # decode, shuffle and group into batches of sample_size
        batches = (
            tf.data.TFRecordDataset(tfrec_files)
            .map(data_tfrec_to_img)
            .shuffle(buffer_size, seed=seed)
            .batch(sample_size)
        )

        # the first batch is the sample
        samples.append(next(iter(batches)))

    return tuple(samples)

In [None]:
def data_tfrec_to_img(tfrec):
    """Decode the jpeg stored in a tf record into a float image tensor.
    
    Args:
        tfrec (TFRecord): The tensorflow record object to be parsed. Its
            'image' feature must hold the encoded jpeg as a string.
        
    Returns:
        tf.Tensor[tf.float32]: The decoded image with every pixel value divided
            by DataConfig.IMAGE_MAX_RGB, i.e. in the range [0, 1], with rank 3
            shape indexed by (Height, Width, Channels).
    """
    feature_spec = {'image': tf.io.FixedLenFeature([], tf.string)}
    example = tf.io.parse_single_example(tfrec, feature_spec)
    pixels = tf.io.decode_jpeg(example['image'])
    # rescale the integer pixel values to floats in [0, 1]
    return tf.cast(pixels, tf.float32) / DataConfig.IMAGE_MAX_RGB

In [None]:
def data_get_path(strategy=None):
    """Retrieves the database path
    
    Args:
        strategy (tf.distribute.Strategy): The distribution strategy for the 
            calling compute environment. Only needs to be supplied when the
            function is not called within strategy.scope
    
    Returns:
        str: The path to the competition database. This is the GCS path of
            the dataset when the strategy runs on TPU, and
            DataConfig.LOCAL_PATH otherwise.
    """
    if strategy is None:
        strategy = tf.distribute.get_strategy()
    if distribute_is_tpu(strategy):
        # the notebook imports the module (import kaggle_datasets), so the
        # class has to be reached through it; the bare name is undefined
        return kaggle_datasets.KaggleDatasets().get_gcs_path(DataConfig.DB_NAME)
    return DataConfig.LOCAL_PATH

# Image Augmentations

In [None]:
class Augmentor(tf.keras.layers.Layer):
    """Keras layer applying random image augmentations in a random order.

    Args:
        augmentations (Union[str, Iterable[str]]): names of the augmentation
            methods of this class to apply, either as an iterable or as one
            comma-separated string. Names that are not attributes of the
            layer are skipped without error.
        max_brightness_adjustment (float): largest absolute value added to
            all pixels of an image by `brightness`.
        max_saturation_adjustment (float): largest relative change of the
            per-pixel channel spread applied by `saturation`.
        max_contrast_adjustment (float): largest relative change of the
            spread around the image mean applied by `contrast`.
        max_color_adjustment (float): largest absolute value added to a
            single color channel by `color`.
        max_translation_adjustment (float): largest shift applied by
            `translation`, as a fraction of the image size.
        max_cutout_size (float): largest side length of the region zeroed by
            `cutout`, as a fraction of the image size.
        clip_values (bool): whether outputs are clipped to [0, 1].
    """

    def __init__(
        self, 
        augmentations=TrainingConfig.AUGMENTATIONS,
        max_brightness_adjustment=0.5,
        max_saturation_adjustment=0.5,
        max_contrast_adjustment=0.5,
        max_color_adjustment=0.1,
        max_translation_adjustment=0.125,
        max_cutout_size=0.5,
        clip_values=True,
    ):
        super().__init__()
        
        self.max_brightness_adjustment = max_brightness_adjustment
        self.max_saturation_adjustment = max_saturation_adjustment
        self.max_contrast_adjustment = max_contrast_adjustment
        self.max_color_adjustment = max_color_adjustment
        self.max_translation_adjustment = max_translation_adjustment
        self.max_cutout_size = max_cutout_size
        
        self.clip_values = clip_values
        
        # a comma-separated string is shorthand for a list of names
        if isinstance(augmentations, str):
            augmentations = map(str.strip, augmentations.split(','))
        # resolve every name to the bound method implementing it
        self.augmentations = list()
        for augmentation in augmentations:
            if hasattr(self, augmentation):
                self.augmentations.append(getattr(self, augmentation))
    
    def call(self, *batches):
        """Augments one or more batches of images.

        The batches are concatenated along the first axis, every configured
        augmentation is applied once in a freshly shuffled order, and the
        result is split back into len(batches) equally sized tensors.

        Args:
            *batches (tf.Tensor): image batches indexed by
                (image, height, width, channel).

        Returns:
            List[tf.Tensor]: one augmented tensor per input batch.
        """
        images = tf.concat(batches, 0)
        
        # randomize function order in tf.function-compatible manner
        for i in tf.random.shuffle(tf.range(len(self.augmentations))):
            for j in range(len(self.augmentations)):
                if i == j:
                    images = self.augmentations[j](images)
        
        # clipping destroys backprop, but prevents fitting to impossible data
        if self.clip_values:
            images = tf.clip_by_value(images, 0., 1.)
            
        return tf.split(images, len(batches))
    
    def brightness(self, images):
        """Adds one random offset per image to all of its pixels."""
        # adjust mean pixel brightness
        num_images = tf.shape(images)[0]
        adjustment = tf.random.uniform([num_images, 1, 1, 1], -1., 1.) 
        adjustment *= self.max_brightness_adjustment
        images = images + adjustment
        return images
    
    def color(self, images):
        """Adds one random offset per image and color channel (3 channels)."""
        # adjust mean of each color
        num_images = tf.shape(images)[0]
        adjustment = tf.random.uniform([num_images, 1, 1, 3], -1., 1.)
        adjustment *= self.max_color_adjustment
        images = images + adjustment
        return images

    def contrast(self, images):
        """Scales each image's deviation from its own mean by a random factor."""
        # adjust variance of pixel brightness
        num_images = tf.shape(images)[0]
        adjustment = tf.random.uniform([num_images, 1, 1, 1], -1., 1.)
        adjustment = 1 + self.max_contrast_adjustment * adjustment
        # mean over height, width and channel of every image
        brightness = tf.math.reduce_mean(images, axis=[-3, -2, -1], keepdims=True)
        images = (images - brightness) * adjustment + brightness
        return images

    def saturation(self, images):
        """Scales each pixel's deviation from its channel mean by a random factor."""
        # adjust variance of each color
        num_images = tf.shape(images)[0]
        adjustment = tf.random.uniform([num_images, 1, 1, 1], -1., 1.)
        adjustment = 1 + self.max_saturation_adjustment * adjustment
        # mean over the channel axis of every pixel
        avg_colors = tf.math.reduce_mean(images, axis=-1, keepdims=True)
        images = (images - avg_colors) * adjustment + avg_colors
        return images
    
    # TODO: find a way to sample more than one cutout size per batch
    # would be easy with for loop... detrimental to GPU / TPU computation?
    # assumedly, these functions all use loops under the hood anyways
    def cutout(self, images):
        """Zeroes a rectangular region of every image.

        One rectangle size is sampled per call and shared by all images; the
        center is sampled per image. Regions reaching past an image border
        wrap around to the opposite side.
        """
        num_images = tf.shape(images)[0]
        image_size = tf.shape(images)[1:3]
        
        # sample random cutout locations
        centers = tf.random.uniform([num_images, 2]) * tf.cast(image_size, tf.float32)
        centers = tf.cast(centers, tf.int32)
        
        # sample random cutout size
        size = tf.random.uniform([2]) * tf.cast(image_size, tf.float32)
        size *= self.max_cutout_size
        
        # convert to grid indices
        indices = tf.ragged.range(tf.cast(-size / 2, tf.int32), tf.cast(size / 2 + 0.5, tf.int32))
        y_indices, x_indices = indices[0], indices[1]
        y_grid, x_grid = tf.meshgrid(y_indices, x_indices, indexing='ij')
        grid = tf.stack([y_grid, x_grid], axis=-1)
        
        # tile over image index
        indices = tf.tile(grid[None], [num_images, 1, 1, 1])
        
        # offset each image set by centers
        indices = indices + centers[:, None, None]
        # wrap indices that fall outside of the image
        indices %= image_size
        
        # prepend image index 
        image_index = tf.reshape(tf.range(num_images), (-1, 1, 1, 1))
        bcast_shape = tf.concat([tf.shape(indices)[:-1], [1]], axis=0)
        image_index = tf.broadcast_to(image_index, bcast_shape)
        indices = tf.concat([image_index, indices], axis=-1)
        
        # reshape to be index list
        indices = tf.reshape(indices, (-1, tf.shape(indices)[-1]))
        
        # mask
        zeros = tf.zeros((tf.shape(indices)[0], tf.shape(images)[-1]))
        images = tf.tensor_scatter_nd_update(images, indices, zeros)
        
        return images

    def translation(self, images):
        """Shifts every image by a random whole number of pixels.

        The shift is sampled per image and axis. Pixels shifted in from
        outside of the image are taken from zero padding.
        """
        num_images = tf.shape(images)[0]
        image_size = tf.shape(images)[1:3]
        
        # sample integer shifts, rounded to the nearest pixel
        adjustment = tf.random.uniform([num_images, 2], -1., 1.)
        adjustment *= self.max_translation_adjustment * tf.cast(image_size, tf.float32)
        adjustment = tf.cast(adjustment + 0.5, tf.int32)
        adjustment_y, adjustment_x = tf.split(adjustment, 2, axis=1)
        
        # source row / column index for every output row / column
        rows, cols = tf.range(image_size[0])[None], tf.range(image_size[1])[None]
        rows, cols = tf.tile(rows, [num_images, 1]), tf.tile(cols, [num_images, 1])
        rows, cols = rows - adjustment_y, cols - adjustment_x
        
        # shift indices up by one to separate valid from clipped values
        rows = tf.clip_by_value(rows+1, 0, image_size[0]+1)[..., None]
        cols = tf.clip_by_value(cols+1, 0, image_size[1]+1)[..., None]
        
        # add padding which will be selected by clipped values
        aug_images = tf.pad(images, [[0, 0], [1, 1], [1, 1], [0, 0]])
        # gather rows, then swap the spatial axes to gather columns the same way
        aug_images = tf.gather_nd(aug_images, rows, batch_dims=1)
        aug_images = tf.transpose(aug_images, [0, 2, 1, 3])
        aug_images = tf.gather_nd(aug_images, cols, batch_dims=1)
        aug_images = tf.transpose(aug_images, [0, 2, 1, 3])
        images = tf.ensure_shape(aug_images, images.shape)
        return images

In [None]:
# sample five photos and run them through an Augmentor built with default settings
test_images = data_load_sample(data_dirs=[DataConfig.PHOTO_DIRNAME,], sample_size=5)[0]
# the layer returns one tensor per input batch; a single batch was passed
aug_images = Augmentor()(test_images)[0]

# pair every original with its augmented version:
# (original/augmented, image, ...) -> (image, original/augmented, ...)
images = tf.stack([test_images, aug_images])
images = tf.transpose(images, [1, 0, 2, 3, 4])

# draw one row per photo and keep a copy of the figure on disk
vis_image_gallery(images)
plt.savefig('example-augmentations.png')
plt.show()

# Generator

In [None]:
def Unet(
    input_shape=DataConfig.IMAGE_SHAPE,
    unet_depth=UnetConfig.UNET_DEPTH, 
    base_depth=UnetConfig.BASE_DEPTH, 
    conv_depth=UnetConfig.CONV_DEPTH, 
    dropout_rate=UnetConfig.DROPOUT_RATE, 
    activation=UnetConfig.ACTIVATION, 
    base=UnetConfig.BASE_TYPE,
):
    """Builds a U-Net style image-to-image model.

    The network downsamples the input unet_depth times, passes the result
    through a base stack, and upsamples back to the input resolution while
    adding the matching skip connection after every upsampling step. The
    output goes through a tanh block and is rescaled from [-1, 1] to [0, 1].

    Args:
        input_shape (Iterable[int]): shape of one input image as
            (height, width, channel).
        unet_depth (int): number of downsampling (and upsampling) levels.
        base_depth (int): number of blocks in the base stack.
        conv_depth (int): number of convolutional layers in every
            downsampling and upsampling block.
        dropout_rate (float): dropout rate of the downsampling, upsampling
            and residual base blocks. The output block never uses dropout.
        activation (str): activation of the downsampling, upsampling and
            residual base blocks.
        base (str): type of the base stack: 'residual', 'inception' or
            'xception'. Any other value leaves the base stack empty.

    Returns:
        tf.keras.Model: the assembled, uncompiled model.
    """
    # Input Layer
    inputs = tf.keras.Input(input_shape)

    # initialize empty list to track unet skip connections
    skips = list()

    # Down Stack
    dn_stack = inputs
    for level in range(unet_depth):
        # skip
        skips.append(dn_stack)
        # downsample and increase filters
        filters_in = dn_stack.shape[-1]
        filters_out = unet_filters_at_level(level+1)
        dn_stack = unet_downsample(
            filters_in,
            filters_out,
            depth=conv_depth,
            dropout_rate=dropout_rate,
            activation=activation,
        )(dn_stack)

    # Base Stack
    base_stack = dn_stack
    base_shape = dn_stack.shape[1:]
    if base == 'residual':
        base_stack = unet_residual_base(
            base_shape,
            base_depth,
            dropout_rate=dropout_rate,
            activation=activation,
        )(base_stack)
    # TODO: inception base option
    elif base == 'inception':
        base_stack = unet_inception_base(base_shape, base_depth)(base_stack)
    # TODO: xception base option
    elif base == 'xception':
        base_stack = unet_xception_base(base_shape, base_depth)(base_stack)

    # Up Stack
    up_stack = base_stack
    for skip in reversed(skips):
        # upsample and decrease filters to match the skip connection
        filters_in = up_stack.shape[-1]
        filters_out = skip.shape[-1]
        up_stack = unet_upsample(
            filters_in,
            filters_out,
            depth=conv_depth,
            dropout_rate=dropout_rate,
            activation=activation,
        )(up_stack)
        # add the skip connection element-wise (shapes match after upsampling)
        up_stack = tf.keras.layers.Add()([up_stack, skip])

    # Output
    outputs = up_stack
    # convolve reconstructed pixels with originals
    # (sample_rate=1 keeps the resolution; block depth is left at its default)
    outputs = unet_upsample(
        filters_in=3, 
        filters_out=3, 
        dropout_rate=0, 
        sample_rate=1, 
        activation='tanh'
    )(outputs)
    # rescale pixel values from [-1, 1] of tanh output to [0, 1]
    outputs = tf.keras.layers.Rescaling(scale=0.5, offset=0.5)(outputs)

    # Assemble Model
    unet = tf.keras.Model(inputs=inputs, outputs=outputs)
    return unet

# TODO: put in primary config for easy experimentation
def unet_filters_at_level(level, base_filters=32, max_filters=512):
    """Number of convolution filters used at a given level of the unet.

    The filter count doubles with every level and is capped at max_filters.

    Args:
        level (int): the level of the unet, where 0 is the input resolution.
        base_filters (int): the number of filters at level 0.
        max_filters (int): upper bound on the returned number of filters.

    Returns:
        int: min(base_filters * 2**level, max_filters)
    """
    return min(base_filters * (2**level), max_filters)

def unet_downsample(
    filters_in,
    filters_out,
    depth=3,
    kernel_size=3,
    sample_rate=2,
    dropout_rate=UnetConfig.DROPOUT_RATE,
    normalize=True,
    activation='leaky_relu',
    kernel_initializer='glorot_uniform',
):
    """Builds a convolutional block which reduces the spatial resolution.

    Layer order: optional spatial dropout, depth-1 convolutions with
    filters_in filters, one strided convolution with filters_out filters,
    optional instance normalization, optional activation.

    Args:
        filters_in (int): filters of the leading convolutions.
        filters_out (int): filters of the final, strided convolution.
        depth (int): total number of convolutional layers in the block.
        kernel_size (int): kernel size of every convolution.
        sample_rate (int): stride of the final convolution.
        dropout_rate (float): dropout rate; a falsy value disables dropout.
        normalize (bool): whether to append instance normalization.
        activation (str): activation to append; a falsy value disables it.
        kernel_initializer (str): initializer of every convolution kernel.

    Returns:
        tf.keras.Sequential: the assembled block.
    """
    layers = []

    # dropout comes first so it acts on the block input
    if dropout_rate:
        layers.append(tf.keras.layers.SpatialDropout2D(dropout_rate))

    # resolution-preserving convolutions
    layers.extend(
        tf.keras.layers.Conv2D(
            filters_in,
            kernel_size,
            padding='same',
            kernel_initializer=kernel_initializer
        )
        for _ in range(depth-1)
    )

    # the strided convolution performs the downsampling
    layers.append(tf.keras.layers.Conv2D(
        filters_out,
        kernel_size,
        strides=sample_rate,
        padding='same',
        kernel_initializer=kernel_initializer,
    ))

    if normalize:
        layers.append(tfa.layers.InstanceNormalization())
    if activation:
        layers.append(tf.keras.layers.Activation(activation))

    return tf.keras.Sequential(layers)

def unet_upsample(
    filters_in,
    filters_out,
    depth=3,
    kernel_size=3,
    sample_rate=2,
    dropout_rate=UnetConfig.DROPOUT_RATE,
    normalize=True,
    activation='leaky_relu',
    kernel_initializer='glorot_uniform',
):
    """Builds a convolutional block which increases the spatial resolution.

    Layer order: optional spatial dropout, depth-1 convolutions with
    filters_in filters, one strided transposed convolution with filters_out
    filters, optional instance normalization, optional activation.

    Args:
        filters_in (int): filters of the leading convolutions.
        filters_out (int): filters of the final, transposed convolution.
        depth (int): total number of convolutional layers in the block.
        kernel_size (int): kernel size of every convolution.
        sample_rate (int): stride of the transposed convolution.
        dropout_rate (float): dropout rate; a falsy value disables dropout.
        normalize (bool): whether to append instance normalization.
        activation (str): activation to append; a falsy value disables it.
        kernel_initializer (str): initializer of every convolution kernel.

    Returns:
        tf.keras.Sequential: the assembled block.
    """
    layers = []

    # dropout comes first so it acts on the block input
    if dropout_rate:
        layers.append(tf.keras.layers.SpatialDropout2D(dropout_rate))

    # resolution-preserving convolutions
    layers.extend(
        tf.keras.layers.Conv2D(
            filters_in,
            kernel_size,
            padding='same',
            kernel_initializer=kernel_initializer,
        )
        for _ in range(depth-1)
    )

    # the strided transposed convolution performs the upsampling
    layers.append(tf.keras.layers.Conv2DTranspose(
        filters_out,
        kernel_size,
        strides=sample_rate,
        padding='same',
        kernel_initializer=kernel_initializer,
    ))

    if normalize:
        layers.append(tfa.layers.InstanceNormalization())
    if activation:
        layers.append(tf.keras.layers.Activation(activation))

    return tf.keras.Sequential(layers)

def unet_residual_base(
    base_shape, 
    stack_depth, 
    kernel_size=3,
    dropout_rate=UnetConfig.DROPOUT_RATE,
    activation=UnetConfig.ACTIVATION, 
    compression_factor=0.5,
    preactivation=True, 
    bottleneck=True, 
    normalize=True,
):
    """Builds a stack of residual blocks which preserves its input shape.

    Every block applies optional spatial dropout and optional
    pre-activation, then its convolutions, and adds the block input to the
    result. Without pre-activation the sum is activated instead.

    Args:
        base_shape (Iterable[int]): shape of one input volume as
            (height, width, channel).
        stack_depth (int): number of residual blocks.
        kernel_size (int): kernel size of the non-pointwise convolutions.
        dropout_rate (float): dropout rate; a falsy value disables dropout.
        activation (str): activation used throughout the stack.
        compression_factor (float): ratio of bottleneck filters to base
            filters. Only used when bottleneck is True.
        preactivation (bool): if True the block input is activated before
            the convolutions, otherwise the block output is activated after
            the residual addition.
        bottleneck (bool): if True each block is a 1x1 compression, a
            kernel_size convolution and a 1x1 expansion; otherwise two
            kernel_size convolutions at full width.
        normalize (bool): whether instance normalization follows the
            convolutions.

    Returns:
        tf.keras.Model: the residual stack as a functional model.
    """
    stack = stack_input = tf.keras.Input(base_shape)

    def normalized(tensor):
        # normalization is skipped entirely when it is disabled
        if not normalize:
            return tensor
        return tfa.layers.InstanceNormalization()(tensor)

    def activated(tensor):
        return tf.keras.layers.Activation(activation)(tensor)

    def actnorm(tensor):
        # normalize first, then activate
        return activated(normalized(tensor))

    base_filters = base_shape[-1]
    neck_filters = int(base_filters * compression_factor)        
    for _ in range(stack_depth):
        block = block_input = stack

        if dropout_rate:
            block = tf.keras.layers.SpatialDropout2D(dropout_rate)(block)

        if preactivation:
            block = activated(block)

        if bottleneck:
            # compression conv
            block = tf.keras.layers.Conv2D(neck_filters, 1)(block)
            block = actnorm(block)
            # bottleneck conv
            block = tf.keras.layers.Conv2D(neck_filters, kernel_size, padding='same')(block)
            block = actnorm(block)
            # expansion conv
            block = tf.keras.layers.Conv2D(base_filters, 1)(block)
            block = normalized(block)
        else:
            block = tf.keras.layers.Conv2D(base_filters, kernel_size, padding='same')(block)
            block = actnorm(block)
            block = tf.keras.layers.Conv2D(base_filters, kernel_size, padding='same')(block)
            block = normalized(block)

        # residual connection around the block
        stack = tf.keras.layers.Add()([block, block_input])

        if not preactivation:
            stack = activated(stack)
    
    return tf.keras.Model(inputs=stack_input, outputs=stack)

# Discriminator 

In [None]:
class PatchDiscriminator:
    """Namespace holding the builder of the patch-based discriminator model."""

    @staticmethod
    def build(
        dropout_rate=PatchConfig.DROPOUT_RATE,
    ):
        """
        Builds the discriminator as a tf.keras.Sequential of strided convolutions.

        Each stage is SpatialDropout2D -> Conv2D -> InstanceNormalization ->
        activation. Every stage uses 'leaky_relu' except the last one, whose
        activation is replaced by a sigmoid. The last convolution outputs a
        single channel, so the model emits one prediction per patch.

        :param dropout_rate: float. Rate of the SpatialDropout2D layer placed
        before every convolution.
        :return: tf.keras.Sequential taking inputs of DataConfig.IMAGE_SHAPE.
        """
        # A previously tried 5-stage layout (filters 64/128/256/528/1, kernels
        # 5/5/3/3/3, all strides 2) gave a 69x69 receptive field.
        channel_depths = [ 64, 256, 528,   1]
        kernel_widths  = [  5,   5,   3,   3]
        strides        = [  3,   3,   2,   2]
        # final receptive field of 71x71

        # build convolutional model
        layers = [tf.keras.Input(DataConfig.IMAGE_SHAPE)]
        for filters, width, stride in zip(channel_depths, kernel_widths, strides):
            layers.extend([
                tf.keras.layers.SpatialDropout2D(dropout_rate),
                tf.keras.layers.Conv2D(filters, width, stride),
                tfa.layers.InstanceNormalization(),
                tf.keras.layers.Activation('leaky_relu')
            ])
        # the final stage predicts a probability, so swap its activation for a sigmoid
        layers[-1] = tf.keras.layers.Activation('sigmoid')

        model = tf.keras.Sequential(layers)
        return model

# Optimizers

In [None]:
class AdamsFamily:
    """
    A group of Adam optimizers that all share one learning rate variable.

    The family can be iterated to obtain its optimizers in order, and exposes
    an ``lr`` property so the shared learning rate can be read or reassigned
    through a single attribute.
    """
    def __init__(self,
        family_size=4,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-07,
        amsgrad=False,
        name='Adam',
        **kwargs,
    ):
        """
        :param family_size: int. The number of Adam optimizers to create.
        :param learning_rate: initial value of the shared learning rate variable.
        :param beta_1, beta_2, epsilon, amsgrad, name, kwargs: forwarded unchanged
        to every tf.keras.optimizers.Adam instance.
        """
        # one variable shared by every optimizer so a single assign updates them all
        learning_rate = tf.Variable(learning_rate)
        # keyword arguments keep each value bound to the intended parameter even
        # if the optimizer's positional order differs between Keras versions
        self.adams = [tf.keras.optimizers.Adam(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs,
        ) for _ in range(family_size)]
        self.learning_rate = learning_rate

    @property
    def lr(self):
        """The learning rate variable shared by all optimizers in the family."""
        return self.learning_rate

    @lr.setter
    def lr(self, learning_rate):
        # assign in place so the optimizers keep referencing the same variable
        self.learning_rate.assign(learning_rate)

    def __iter__(self):
        # restart iteration from the first optimizer
        self._index = -1
        return self

    def __next__(self):
        next_index = self._index + 1
        # signal exhaustion properly instead of leaking an IndexError
        if next_index >= len(self.adams):
            raise StopIteration
        self._index = next_index
        return self.adams[next_index]

# CycleGAN

In [None]:
class CycleGAN(tf.keras.Model):
    """
    Keras model that jointly trains two generators and two discriminators.

    The ``m_*`` models generate / judge monet images and the ``p_*`` models
    generate / judge photo images.
    """
    def __init__(self, m_gen, p_gen, m_dis, p_dis, aug):
        """
        :param m_gen: generator called to produce monet images.
        :param p_gen: generator called to produce photo images.
        :param m_dis: discriminator applied to real and generated monet images.
        :param p_dis: discriminator applied to real and generated photo images.
        :param aug: callable taking (real, fake) batches and returning the
        augmented (real, fake) pair fed to the discriminators.
        """
        super().__init__()
        self.m_gen = m_gen
        self.p_gen = p_gen
        self.m_dis = m_dis
        self.p_dis = p_dis
        self.aug = aug
        
    def compile(self, optimizer, id_loss, cycle_loss, dis_loss, gen_loss_weights):
        """
        :param optimizer: iterable yielding one optimizer per sub-model, in the
        order (m_gen, p_gen, m_dis, p_dis).
        :param id_loss: callable (real, identity_output) -> loss.
        :param cycle_loss: callable (real, cycled_output) -> loss.
        :param dis_loss: callable (labels, discriminator_output) -> loss. Also
        used for the adversarial part of the generator loss.
        :param gen_loss_weights: tensor of three weights applied to the
        generator's (adversarial, identity, cycle) losses.
        """
        super().compile()
        self.optimizer = optimizer
        self.id_loss = id_loss
        self.cycle_loss = cycle_loss
        self.dis_loss = dis_loss
        self.gen_loss_weights = gen_loss_weights
    
    def train_step(self, data):
        """
        Runs one optimization step over all four sub-models.

        :param data: tuple (monet_batch, photo_batch).
        :return: dict of the identity, cycle and discriminator losses.
        """
        # read tupled batch data = (monet_batch, photo_batch)
        m_real, p_real = data
        
        # Progress batch data through CycleGAN process
        # (persistent: the tape is queried once per sub-model below)
        with tf.GradientTape(persistent=True) as g_tape:
            # identity outputs
            m_id = self.m_gen(m_real, training=True)
            p_id = self.p_gen(p_real, training=True)
            
            # identity loss
            m_id_loss = self.id_loss(m_real, m_id)
            p_id_loss = self.id_loss(p_real, p_id)
            
            # transfer outputs
            m_fake = self.m_gen(p_real, training=True)
            p_fake = self.p_gen(m_real, training=True)
            
            # cycle outputs
            m_cycle = self.m_gen(p_fake, training=True)
            p_cycle = self.p_gen(m_fake, training=True)
            
            # cycle loss (the sum is shared by both generators)
            m_cycle_loss = self.cycle_loss(m_real, m_cycle)
            p_cycle_loss = self.cycle_loss(p_real, p_cycle)
            cycle_loss = m_cycle_loss + p_cycle_loss
            
            # differentiable augmentations
            # (applied after the identity/cycle losses, so only the
            # discriminators see augmented images)
            m_real, m_fake = self.aug(m_real, m_fake)
            p_real, p_fake = self.aug(p_real, p_fake)
            
            # discriminator outputs
            m_dis_real = self.m_dis(m_real, training=True)
            m_dis_fake = self.m_dis(m_fake, training=True)
            p_dis_real = self.p_dis(p_real, training=True)
            p_dis_fake = self.p_dis(p_fake, training=True)
            
            # discriminator loss            
            m_dis = tf.concat([m_dis_real, m_dis_fake], 0)
            p_dis = tf.concat([p_dis_real, p_dis_fake], 0)
            
            # the same labels are reused for both discriminators
            labels_real =  tf.ones_like(m_dis_real)
            labels_fake = tf.zeros_like(m_dis_fake)
            labels = tf.concat([labels_real, labels_fake], 0)
            
            m_dis_loss = self.dis_loss(labels, m_dis)
            p_dis_loss = self.dis_loss(labels, p_dis)
            
            # generator loss
            # (adversarial term: generated images scored against "real" labels)
            m_gen_loss = self.dis_loss(labels_real, m_dis_fake)
            p_gen_loss = self.dis_loss(labels_real, p_dis_fake)
            
            # weighted sum of the (adversarial, identity, cycle) terms
            m_gen_loss = tf.tensordot(
                tf.stack([m_gen_loss, m_id_loss, cycle_loss]),
                self.gen_loss_weights, 1
            )
            p_gen_loss = tf.tensordot(
                tf.stack([p_gen_loss, p_id_loss, cycle_loss]),
                self.gen_loss_weights, 1
            )
            
        # collect model losses and variables
        models = [self.m_gen, self.p_gen, self.m_dis, self.p_dis]
        losses = [m_gen_loss, p_gen_loss, m_dis_loss, p_dis_loss]
        variables = [model.trainable_variables for model in models]
        
        # apply backpropagation (one optimizer per sub-model)
        for model_loss, model_vars, adam in zip(losses, variables, self.optimizer):
            grads = g_tape.gradient(model_loss, model_vars)
            adam.apply_gradients(zip(grads, model_vars))
        
        # return losses and metrics
        return {
            'monet_id_loss': m_id_loss,
            'photo_id_loss': p_id_loss,
            'monet_cycle_loss': m_cycle_loss,
            'photo_cycle_loss': p_cycle_loss,
            'monet_discriminator_loss': m_dis_loss,
            'photo_discriminator_loss': p_dis_loss
        }
    
    def call(self, x, output_class='monet'):
        """
        Transforms a batch of images into the requested class.

        :param x: batch of input images.
        :param output_class: 'monet' or 'photo'. Selects the generator to apply.
        :return: output of the selected generator.
        :raises ValueError: if output_class is neither 'monet' nor 'photo'.
        """
        if output_class == 'monet':
            return self.m_gen(x)
        if output_class == 'photo':
            return self.p_gen(x)
        # fail loudly instead of silently returning None
        raise ValueError(
            f"output_class must be 'monet' or 'photo', got {output_class!r}"
        )

# Callbacks

In [None]:
class VisualizeCycleGanEvolution(tf.keras.callbacks.Callback):
    """
    Callback that records the model's transformations of a fixed set of test
    images over the course of training and renders them when training ends.
    """
    DEFAULT_FILEPATH = './cycle-gan-evolution.png'
    def __init__(self, test_images, classes=None, frequency=1, filepath=DEFAULT_FILEPATH, 
                 separate_classes=True, show_initial=True):
        """
        :param test_images: tensor containing the batch of images to test on. If multiple 
        classes are being visualized, test_images should be an iterable containing a batch 
        for each class. The index of each batch should match the index in classes of the 
        class that batch is to be transformed into.
        :param classes: None, str, or Iterable[str]. The name(s) of the classes explored by
        the CycleGAN model. Each will be used as an argument to the __call__ method
        of the CycleGAN. If None, length of classes is assumed to be 1, and the model will
        be called with no other arguments.
        :param frequency: int or Iterable[int]. If single int, test will be run at 
        the end of every epoch such that 'epoch % frequency == 0' evaluates to True. If 
        Iterable, test will be run whenever 'epoch in frequency' evaluates to True. Epoch
        in this consideration will begin at one - not zero.
        :param filepath: str. The location at which to save the resulting image.
        :param separate_classes: bool. If true, each class will be saved as a separate 
        image with the class prepended to the file name.
        :param show_initial: bool. If true, will include initial predictions of the gan 
        model (before any training occurs).
        """
        super().__init__()
        
        # ensure classes is Iterable[str]
        if classes is None or isinstance(classes, str):
            classes = [classes,]
            
        # images tensor should be of shape (epoch, class, image, height, width[, channels])
        if len(classes) == 1:
            self.images = test_images[None, None]
        else:
            self.images = tf.stack(test_images)[None]
                    
        # process separate_classes and filepath
        if separate_classes and len(classes) > 1:
            # split the path right after its last '/' (index 0 when there is none)
            name_index = max(0, filepath.rfind('/')+1)
            self.filepaths = [filepath[:name_index] + class_name + '-' + filepath[name_index:] 
                              for class_name in classes]
        else:
            self.filepaths = [filepath,]
            
        # assign remaining args to attributes
        self.classes = classes
        self.frequency = frequency
        self.show_initial = show_initial
        
    def on_train_begin(self, logs=None):
        # collect initial transformations (only when requested)
        if self.show_initial:
            self.images = self._collect_images(self.images)
    
    def on_epoch_end(self, epoch, logs=None):
        # check if frequency dictates this epoch to be detailed
        # (keras epochs are zero-based; frequency is documented as one-based)
        epoch += 1
        if (
            type(self.frequency) is int and epoch % self.frequency == 0 or
            hasattr(self.frequency, '__iter__') and epoch in self.frequency
        ):
            self.images = self._collect_images(self.images)
    
    # TODO: either finish tensorflowizing this (@tf.function), or reformat to need no arguments or returns
    def _collect_images(self, images):
        """
        Transforms the original test images with the current model and appends
        the results to images as a new entry along the epoch axis.

        :param images: tensor of shape (epoch, class, image, height, width[, channels])
        whose first epoch entry holds the original test images.
        :return: images with one more entry along the epoch axis.
        """
        # initialize new tensor with shape (0, height, width[, channels])
        new_images = tf.zeros([0, *images.shape[3:]], dtype=images.dtype)
        
        # iterate over classes and images
        for c, cla in enumerate(self.classes):
            # extract original images
            oimgs = images[0, c]
            # transform image batch (with class name as argument if available)
            nimgs = self.model(oimgs, cla) if cla else self.model(oimgs)
            # concatenate along image axis
            new_images = tf.concat((new_images, nimgs), axis=0)
            
        # add epoch and class axes to tensor
        new_images = tf.reshape(new_images, (len(self.classes), -1, *new_images.shape[1:]))[None]
        # concatenate existing epoch data with new
        return tf.concat((images, new_images), axis=0)
    
    def on_train_end(self, logs=None):
        rank = len(self.images.shape)
        # ensure channels axis exists
        if rank == 5:
            self.images = self.images[..., None]
            
        # reshape images from (epoch, class, image, height, width, channels)
        #                  to (class, image, epoch, height, width, channels)
        images = tf.transpose(self.images, [1, 2, 0, 3, 4, 5])
        
        if len(self.filepaths) > 1:
            # one figure per class, each keeping a class axis of length one
            for img, fp in zip(images, self.filepaths):
                vis_generator_evolution(img[None], fp)
        else:
            vis_generator_evolution(images, self.filepaths[0])

In [None]:
class AlternateTraining(tf.keras.callbacks.Callback):
    """
    Callback that flips the ``trainable`` flag of the generators and the
    discriminators before every batch: generators on even batch indices,
    discriminators on odd ones.
    """
    def on_train_batch_begin(self, batch, logs=None):
        train_generators = batch % 2 == 0
        gan = self.model
        for generator in (gan.m_gen, gan.p_gen):
            generator.trainable = train_generators
        for discriminator in (gan.m_dis, gan.p_dis):
            discriminator.trainable = not train_generators

# Implementation

In [None]:
# build the distribution strategy; later cells create callbacks and models
# inside strategy.scope() and pass the strategy to data_get_path
strategy = distribute_build_strategy()

In [None]:
# ensure checkpoints directory exists
if TrainingConfig.CHECKPOINTS.enabled:
    checkpoint_fp = TrainingConfig.CHECKPOINTS.kwargs['filepath']
    # directory part of the checkpoint path template
    checkpoint_dir = os.path.dirname(checkpoint_fp)
    # an empty dirname means the current directory, which needs no creation
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

# assemble the list of callbacks enabled in TrainingConfig
callbacks = []
with strategy.scope(): # for callbacks in need of distribute knowledge (*cough cough* checkpoints *cough*)
    if TrainingConfig.CHECKPOINTS.enabled:
        optkwargs = dict(experimental_io_device='/job:localhost')
        # the options class depends on whether weights or the full model are saved;
        # the dict must be unpacked so the value reaches experimental_io_device
        if TrainingConfig.CHECKPOINTS.kwargs.get('save_weights_only', False):
            options = tf.train.CheckpointOptions(**optkwargs)
        else:
            options = tf.saved_model.SaveOptions(**optkwargs)
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            **TrainingConfig.CHECKPOINTS.kwargs,
            options=options,
        ))
    if TrainingConfig.EVOLUTION_VISUAL.enabled:
        callbacks.append(VisualizeCycleGanEvolution(
            **TrainingConfig.EVOLUTION_VISUAL.kwargs,
            test_images=data_load_sample(
                data_get_path(),
                sample_size=5
            ),
        ))
    if TrainingConfig.ALTERNATE.enabled:
        callbacks.append(AlternateTraining())
    if TrainingConfig.LR_SCHEDULE.enabled:
        def schedule(epoch, lr):
            """
            Decreases learning rate exponentially according to an initial and final rate.

            Returns LEARNING_RATE at epoch 0 and would reach FINAL_RATE at
            epoch == EPOCHS. The incoming lr argument is ignored.
            """
            LR0 = TrainingConfig.LEARNING_RATE
            LRF = TrainingConfig.FINAL_RATE
            EPOCHS = TrainingConfig.EPOCHS
            attenuation = tf.math.pow(LR0 / LRF, epoch / EPOCHS)
            lr = LR0 / attenuation
            return lr
        callbacks.append(tf.keras.callbacks.LearningRateScheduler(
            **TrainingConfig.LR_SCHEDULE.kwargs,
            schedule=schedule
        ))

In [None]:
# build and compile the CycleGAN inside the distribution strategy scope
with strategy.scope():
    # instantiate models
    # (each generator / discriminator is a separate instance)
    gan = CycleGAN(
        m_gen = Unet(),
        p_gen = Unet(),
        m_dis = PatchDiscriminator.build(),
        p_dis = PatchDiscriminator.build(),
        aug   = Augmentor()
    )
    
    # reduction within distribute strategy restricted only to NONE or SUM
    NONE=tf.keras.losses.Reduction.NONE
    
    # instantiate optimizer and losses
    # NOTE(review): distribute_loss presumably performs the reduction of the
    # unreduced per-example losses across replicas; confirm against its definition
    gan.compile(
        optimizer = AdamsFamily(learning_rate=TrainingConfig.LEARNING_RATE),
        id_loss = distribute_loss(tf.keras.losses.MeanAbsoluteError(reduction=NONE)),
        cycle_loss = distribute_loss(tf.keras.losses.MeanAbsoluteError(reduction=NONE)),
        dis_loss = distribute_loss(tf.keras.losses.BinaryCrossentropy(reduction=NONE)),
        gen_loss_weights = GANConfig.LOSS_WEIGHTS,
    )

In [None]:
# document model for journal: render the monet generator and monet
# discriminator architectures to PNG files (text summaries left disabled)
# gan.m_gen.summary()
# gan.m_dis.summary()

tf.keras.utils.plot_model(gan.m_gen, 'generator_model.png', show_shapes=True, expand_nested=True)
tf.keras.utils.plot_model(gan.m_dis, 'discriminator_model.png', show_shapes=True, expand_nested=True);

In [None]:
# train model
history = gan.fit(
    x=data_load(data_get_path(strategy)),
    epochs=TrainingConfig.EPOCHS,
    steps_per_epoch=TrainingConfig.STEPS,#tf.math.ceil(6686 / data.compute_global_batch_size()),
    callbacks=callbacks
)

# Generate Images

In [None]:
# load every photo exactly once (single pass, no repetition)
photos = data_load(
    data_get_path(strategy), 
    data_dirs=[DataConfig.PHOTO_DIRNAME,], 
    buffer_size=1,
    repeat=False,
)

# exist_ok keeps the cell re-runnable after the directory has been created
os.makedirs('../fresh-monets', exist_ok=True)
filepath='../fresh-monets/{:04d}.png'

# transform each photo batch into monets and save them under sequential numbers
for batch_num, photo_batch in enumerate(photos):
    monet_batch = gan(photo_batch)
    # number of images written by the preceding batches
    image_offset = batch_num * DataConfig.BATCH_SIZE
    for image_num, monet in enumerate(monet_batch):
        # NOTE(review): monet[0] takes the first entry along the leading axis of
        # each item; confirm this matches what vis_save_image expects
        vis_save_image(monet[0], filepath.format(image_offset + image_num))
    
# bundle the generated images into ./images.zip for submission
shutil.make_archive('./images', 'zip', '../fresh-monets');

In [None]:
# draw a fixed (seeded) sample of five photos for a final side-by-side check
photos = data_load_sample(
    data_get_path(strategy), 
    data_dirs=[DataConfig.PHOTO_DIRNAME,], 
    sample_size=5,
    seed=0,
)[0]

# transform the sampled photos with the default ('monet') generator
monets = gan(photos)

# stack photos and monets on a new leading axis, then swap the first two axes
# so each row of the gallery holds one (photo, monet) pair
# (assumes both are batched along their first axis; TODO confirm)
images = np.array([photos, monets]).swapaxes(0, 1)

vis_image_gallery(images, col_titles=['Photo', 'Monet'])
plt.savefig('final-test.png')
plt.show()

In [None]:
# persist all four sub-models; exist_ok keeps the cell re-runnable
os.makedirs('./models', exist_ok=True)
gan.m_gen.save('./models/monet-generator')
gan.p_gen.save('./models/photo-generator')
gan.m_dis.save('./models/monet-discriminator')
gan.p_dis.save('./models/photo-discriminator')