## Imports

In [None]:
import os
import sys
import datetime
from enum import Enum
import math
from typing import List, Dict, Optional, Tuple

In [None]:
import tensorflow as tf
from tensorflow.python.keras.utils import conv_utils

In [None]:
import argparse
from create_datetime_str import create_datetime_str
from default_to import default_to
from limit_gpu_memory_usage import limit_gpu_memory_usage
from checkpointer import Checkpointer
import math
import os
from perceptual_difference import create_perceptual_difference_model
import tensorflow as tf
from tensor_ops import pixel_norm, lerp
from train import TrainingState
from typing import List, Optional

In [None]:
import functools
from default_to import default_to
from generator_visualizer_callback import GeneratorVisualizer
from tf_image_records import decode_record_image
from typing import List, Optional, Callable, Tuple

## File IO

In [None]:
def write_text_file(file_path: os.PathLike, content: str, open_mode: str = 'w') -> None:
    """Write ``content`` to ``file_path`` as text.

    The parent directory of ``file_path`` is created first via
    ``ensure_dir_for_file``.

    Args:
        file_path: Destination path.
        content: Text to write.
        open_mode: Mode string handed to ``open`` (``'w'`` by default).
    """
    ensure_dir_for_file(file_path)
    with open(file_path, open_mode) as text_file:
        text_file.write(content)


def write_binary_file(file_path: os.PathLike, content: bytes, open_mode: str = 'wb') -> None:
    """Write ``content`` to ``file_path`` as raw bytes.

    The parent directory of ``file_path`` is created first via
    ``ensure_dir_for_file``.

    Args:
        file_path: Destination path.
        content: Bytes to write.
        open_mode: Mode string handed to ``open`` (``'wb'`` by default).
    """
    ensure_dir_for_file(file_path)
    with open(file_path, open_mode) as binary_file:
        binary_file.write(content)


def write_lines_to_file(file_path: os.PathLike, lines: Iterable[str]):
    """Write ``lines`` to ``file_path``, separated by newline characters.

    No trailing newline is appended after the last line.

    Args:
        file_path: Destination path; its parent directory is created if needed.
        lines: Strings to join with ``'\\n'``.
    """
    ensure_dir_for_file(file_path)
    joined_text = '\n'.join(lines)
    write_text_file(file_path, joined_text)


def read_text_file(file_path: os.PathLike, open_mode: str = 'r') -> str:
    """Return the entire content of ``file_path`` as a string.

    Args:
        file_path: Path of the file to read.
        open_mode: Mode string handed to ``open`` (``'r'`` by default).
    """
    with open(file_path, open_mode) as text_file:
        text = text_file.read()
    return text


def read_binary_file(file_path: os.PathLike, open_mode: str = 'rb') -> bytes:
    """Return the entire content of ``file_path`` as bytes.

    Args:
        file_path: Path of the file to read.
        open_mode: Mode string handed to ``open`` (``'rb'`` by default).
    """
    with open(file_path, open_mode) as binary_file:
        payload = binary_file.read()
    return payload

## Basic Configs

In [None]:
def try_create_tpu_strategy(tpu_name=''):
    """Try to connect to a TPU and build a ``TPUStrategy`` for it.

    Args:
        tpu_name: Value forwarded as ``tpu=`` to ``TPUClusterResolver.connect``.

    Returns:
        A ``tf.distribute.TPUStrategy``, or ``None`` if connecting or building
        the strategy raised ``ValueError``.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu=tpu_name)
        strategy = tf.distribute.TPUStrategy(resolver)
    except ValueError:
        return None
    return strategy


def create_strategy(tpu_name=''):
    """Return a TPU strategy when available, otherwise a ``MirroredStrategy``.

    Args:
        tpu_name: Value forwarded to ``try_create_tpu_strategy``.
    """
    tpu_strategy = try_create_tpu_strategy(tpu_name=tpu_name)
    # MirroredStrategy is only instantiated when no TPU strategy was obtained.
    return default_to_fn(tpu_strategy, tf.distribute.MirroredStrategy)

## Helpers

In [None]:
def create_datetime_str() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD_HH_MM_SS_ffffff``."""
    now = datetime.datetime.now()
    return now.strftime('%Y-%m-%d_%H_%M_%S_%f')


In [None]:
def ensure_dir(dir_name):
    """Create ``dir_name`` (and any missing parents) if it does not exist yet.

    Args:
        dir_name: Directory path to create.

    Raises:
        FileExistsError: If ``dir_name`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of a separate isdir()
    # test: another process creating the directory in between no longer fails.
    os.makedirs(dir_name, exist_ok=True)


def ensure_dir_for_file(file_path):
    """Make sure the directory that will contain ``file_path`` exists.

    Args:
        file_path: Path of a file that is about to be written.
    """
    parent_dir = os.path.dirname(file_path)
    # A bare file name has an empty dirname; os.makedirs('') would raise
    # FileNotFoundError, and the current directory already exists anyway.
    if parent_dir:
        ensure_dir(parent_dir)

In [None]:
T = TypeVar('T')


def default_to_fn(optional_value: Optional[T], make_default: Callable[[], T]) -> T:
    """Return ``optional_value`` unless it is ``None``, else ``make_default()``.

    ``make_default`` is only invoked when ``optional_value`` is ``None``, so an
    expensive default is never built needlessly. Falsy values other than
    ``None`` (``0``, ``''``, ...) are returned as-is.
    """
    if optional_value is not None:
        return optional_value
    return make_default()


def default_to(optional_value: Optional[T], default: T) -> T:
    """Return ``optional_value`` unless it is ``None``, in which case ``default``."""
    return default if optional_value is None else optional_value

In [None]:
def save_image(path: os.PathLike, image: tf.Tensor) -> None:
    """Encode ``image`` as PNG and write it to ``path``.

    Args:
        path: Destination file path.
        image: Image data; anything accepted by ``tf.convert_to_tensor``.
            Non-uint8 input is converted to uint8 with saturation before
            encoding.
    """
    image = tf.convert_to_tensor(image)
    if image.dtype != tf.uint8:
        image = tf.image.convert_image_dtype(image, tf.uint8, saturate=True)
    # Previously called `bk_io.write_binary_file`, but no `bk_io` name is
    # imported in this file; the helper with that name is defined here.
    write_binary_file(path, tf.io.encode_png(image).numpy())

In [None]:
def encode_record_image(image: tf.Tensor):
    """Serialize a uint8, rank-3 image into a ``tf.train.Example`` byte string.

    The record stores the static image shape under ``'image_shape'`` and the
    PNG-encoded pixels under ``'image_bytes'``.

    Args:
        image: uint8 tensor with three fully defined dimensions.

    Returns:
        The serialized ``tf.train.Example``.
    """
    dims = image.shape.as_list()
    assert image.dtype == tf.uint8
    assert len(dims) == 3
    for dim in dims:
        assert dim is not None
        assert type(dim) is int

    png_bytes = tf.io.encode_png(image).numpy()
    shape_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=dims))
    bytes_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[png_bytes]))
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image_shape': shape_feature,
                'image_bytes': bytes_feature,
            }))
    return example.SerializeToString()


def decode_record_image(record_bytes):
    """Parse a record written by ``encode_record_image`` back into an image.

    Args:
        record_bytes: Serialized ``tf.train.Example`` with ``'image_shape'``
            (three int64 values) and ``'image_bytes'`` (encoded image) features.

    Returns:
        The decoded image, reshaped to the shape stored in the record.
    """
    feature_spec = {
        'image_shape': tf.io.FixedLenFeature([3], dtype=tf.int64),
        'image_bytes': tf.io.FixedLenFeature([], dtype=tf.string),
    }
    parsed = tf.io.parse_single_example(record_bytes, feature_spec)
    decoded = tf.io.decode_image(parsed['image_bytes'])
    # Reshape with the stored shape so the result carries the recorded dimensions.
    stored_shape = tf.cast(parsed['image_shape'], tf.int32)
    return tf.reshape(decoded, stored_shape)


In [None]:
def serialize(unserialized: object) -> bytes:
    """Pickle ``unserialized`` and return the resulting bytes."""
    pickled = pickle.dumps(unserialized)
    return pickled


def deserialize(serialized: bytes):
    """Unpickle ``serialized`` and return the reconstructed object.

    NOTE(review): this uses ``pickle.loads``, which can execute arbitrary code;
    only pass bytes that come from a trusted source.
    """
    restored = pickle.loads(serialized)
    return restored


'''
    I'm using np.save() over tf.io.serialize_tensor because parse_tensor
    (https://www.tensorflow.org/api_docs/python/tf/io/parse_tensor) requires the caller to know
    the correct dtype whereas np.save() does not.

    Adapted from:
    https://stackoverflow.com/questions/30698004/how-can-i-serialize-a-numpy-array-while-preserving-matrix-dimensions
'''
def serialize_array(arr: np.ndarray) -> bytes:
    """Return ``arr`` encoded in NumPy's ``.npy`` format.

    The encoding is produced by ``np.save`` into an in-memory buffer.
    """
    with io.BytesIO() as buffer:
        np.save(buffer, arr)
        return buffer.getvalue()


def deserialize_array(serialized: bytes) -> np.ndarray:
    """Rebuild an array from bytes produced by ``serialize_array``.

    The bytes are wrapped in an in-memory buffer and read with ``np.load``.
    """
    return np.load(io.BytesIO(serialized))


def serialize_tensor(t: tf.Tensor) -> bytes:
    """Serialize ``t`` by converting it to NumPy and calling ``serialize_array``."""
    as_array = t.numpy()
    return serialize_array(as_array)


def serialize_variable(v: tf.Variable) -> bytes:
    """Serialize the value of ``v``; delegates to ``serialize_tensor``."""
    serialized = serialize_tensor(v)
    return serialized


def deserialize_tensor(serialized: bytes) -> tf.Tensor:
    """Rebuild a constant tensor from bytes accepted by ``deserialize_array``."""
    restored_array = deserialize_array(serialized)
    return tf.constant(restored_array)


def deserialize_variable(serialized: bytes) -> tf.Variable:
    """Rebuild a ``tf.Variable`` from bytes accepted by ``deserialize_array``."""
    initial_value = deserialize_array(serialized)
    return tf.Variable(initial_value)


'''
    Annoyingly, keras does not provide a way to serialize a model directly to a bytes object. The
    save()/load_model() API requires a local filesystem, which is commonly not available on cloud
    TPUs. The to_json() API does not store weights, and it stores intermediate dtype information,
    so a model trained with bfloat16 cannot be loaded on a machine that doesn't support bfloat16.
    For now, we'll just store the weights and assume the deserializer knows how to create the
    model architecture.
'''
def serialize_model(model: tf.keras.Model) -> bytes:
    """Serialize the weights of ``model`` (not its architecture) to bytes.

    Each array from ``model.get_weights()`` is encoded with ``serialize_array``
    and the list is stored under the ``'model_weights'`` key.
    """
    encoded_weights = [serialize_array(weight) for weight in model.get_weights()]
    return serialize({'model_weights': encoded_weights})


def deserialize_model(serialized: bytes, create_model: Callable[[], tf.keras.Model]) -> tf.keras.Model:
    """Create a model with ``create_model`` and load serialized weights into it.

    Args:
        serialized: Bytes produced by ``serialize_model``.
        create_model: Factory returning a model whose weights match, in order
            and shape, the ones stored in ``serialized``.

    Returns:
        The freshly created model with the stored weights applied.
    """
    payload = deserialize(serialized)
    model = create_model()
    weights = [deserialize_array(encoded) for encoded in payload['model_weights']]
    model.set_weights(weights)
    return model

## Callback

In [None]:
class GeneratorVisualizer(tf.keras.callbacks.Callback):
    """Keras callback that renders a grid of generator outputs after epochs.

    A fixed noise batch is pushed through the generator under the given
    distribution strategy, and the resulting image grid is handed to every
    callback in ``on_image_callbacks``.
    """

    def __init__(
            self,
            strategy: tf.distribute.Strategy,
            grid_size: Tuple[int, int],
            generator: tf.keras.Model,
            noise: tf.Tensor,
            on_image_callbacks: Optional[List[Callable[[int, tf.Tensor], None]]] = None,
            update_interval: int = 1,
            replica_batch_size: int = 8
            ):
        """
            Args:
                strategy: distribution strategy used to run the generator.
                grid_size: the two grid dimensions; their product must equal
                    the number of noise samples.
                generator: model that maps noise samples to images.
                noise: generator inputs, one row per grid cell.
                on_image_callbacks: callables invoked as
                    ``callback(epoch_i, image)``; defaults to no callbacks.
                update_interval: images are produced only for epochs whose
                    index is a multiple of this value.
                replica_batch_size: per-replica batch size used for prediction.
        """
        # Let the Keras base class initialise its own bookkeeping attributes.
        super().__init__()
        self.generator = generator
        self.grid_size = grid_size
        self.image_count = grid_size[0] * grid_size[1]
        self.strategy = strategy
        self.replica_batch_size = replica_batch_size
        self.noise = noise
        assert self.noise.shape[0] == self.image_count
        self.update_interval = update_interval
        # A `list()` default would be one list object shared by every instance.
        self.on_image_callbacks = [] if on_image_callbacks is None else on_image_callbacks

    @tf.function
    def predict(self, samples):
        """Run the generator on ``samples`` in inference mode."""
        return self.generator(samples, training=False)

    def predict_in_batches(self) -> tf.Tensor:
        """Run the generator over ``self.noise`` in strategy-sized batches.

        The noise is zero-padded up to a multiple of the global batch size,
        distributed, predicted per batch, gathered along axis 0, and finally
        truncated back to the original sample count.
        """
        global_batch_size = self.replica_batch_size * self.strategy.num_replicas_in_sync
        sample_count = self.noise.shape[0]
        batch_count = math.ceil(sample_count / global_batch_size)
        padded_sample_count = batch_count * global_batch_size
        padded_input = tf.concat(
            [self.noise,
            tf.zeros((padded_sample_count - sample_count,) + self.noise.shape[1:], dtype=self.noise.dtype)],
            0)

        dataset = self.strategy.experimental_distribute_dataset(
            tf.data.Dataset.from_tensor_slices(padded_input).batch(global_batch_size))

        batch_outputs = []

        for batch_samples in dataset:
            batch_outputs.append(
                self.strategy.gather(
                    self.strategy.run(self.predict, args=(batch_samples,)),
                    0))

        padded_output = tf.concat(batch_outputs, 0)
        # Drop the outputs that correspond to the zero padding.
        return padded_output[:sample_count]

    def generate_image(self) -> tf.Tensor:
        """Return the generator outputs as uint8, reshaped to ``grid_size + image shape``."""
        # Note: calling self.generator.predict(self.noise, batch_size=self.batch_size)
        # caused the next iterations of the training loop to generate NaNs with a
        # high probability, especially on larger resolutions. I didn't find an
        # obvious reason for that. And since it's incredibly tangential to the
        # course topic, I went with a simple work-around.
        images = self.predict_in_batches()
        images = tf.image.convert_image_dtype(images, tf.uint8, saturate=True)

        return tf.reshape(
            images,
            tf.concat(
                [self.grid_size, tf.shape(images)[1:]],
                0))

    def on_epoch_end(self, epoch_i: int, logs=None) -> None:
        """Generate the image grid and notify callbacks on matching epochs."""
        if epoch_i % self.update_interval != 0:
            return

        image = self.generate_image()

        for callback in self.on_image_callbacks:
            callback(epoch_i, image)


In [None]:
def create_blur_filter(dtype):
    """Return the fixed 4x4 blur kernel as a constant tensor of ``dtype``.

    The kernel is symmetric and its entries sum to 1.
    """
    outer_row = [0.015625, 0.046875, 0.046875, 0.015625]
    inner_row = [0.046875, 0.140625, 0.140625, 0.046875]
    kernel_rows = [outer_row, inner_row, inner_row, outer_row]
    return tf.constant(kernel_rows, dtype=dtype)


def blur(x, strides=1):
    """Apply the 4x4 blur kernel to every channel of ``x`` independently.

    Args:
        x: Rank-4 tensor whose last axis (index 3) is the channel axis.
        strides: Spatial stride applied along both middle axes.

    Returns:
        The depthwise-convolved tensor, computed with ``'SAME'`` padding.
    """
    channels = x.shape[3]
    base_kernel = create_blur_filter(x.dtype)[:, :, tf.newaxis, tf.newaxis]
    # One copy of the kernel per input channel, channel multiplier of 1.
    depthwise_kernel = tf.tile(base_kernel, [1, 1, channels, 1])
    stride_spec = [1, strides, strides, 1]
    return tf.nn.depthwise_conv2d(x, depthwise_kernel, strides=stride_spec, padding='SAME')


def upsample(x):
    """Double the height and width of ``x`` by zero insertion followed by blur.

    The input is multiplied by 4 before zero insertion, presumably to keep the
    overall intensity unchanged, since only one in four output pixels carries
    an input value and the blur kernel sums to 1.
    """
    def interleave_zeros(t):
        height, width, channels = t.shape[1], t.shape[2], t.shape[3]
        # Give each pixel its own unit-sized row/column axis, pad those axes
        # with one zero, then fold them back into the spatial axes.
        expanded = tf.reshape(t, [-1, height, 1, width, 1, channels])
        padded = tf.pad(expanded, [[0, 0], [0, 0], [0, 1], [0, 0], [0, 1], [0, 0]])
        return tf.reshape(padded, [-1, height * 2, width * 2, channels])

    scaled = x*4.
    return blur(interleave_zeros(scaled))


def downsample(x):
    """Halve the spatial size of ``x`` by blurring with a stride of 2."""
    halved = blur(x, strides=2)
    return halved


def pixel_norm(x: tf.Tensor, epsilon: float = 1e-7) -> tf.Tensor:
    """Scale ``x`` so the mean square over the last axis is approximately 1.

    The computation runs in float32 and the result is cast back to the dtype
    of ``x``.

    Args:
        x: Input tensor; normalization is applied along its last axis.
        epsilon: Added to the mean square before the square root.
    """
    input_dtype = x.dtype
    values = tf.cast(x, tf.float32)
    mean_square = tf.reduce_mean(tf.square(values), axis=-1, keepdims=True)
    scaled = values / tf.math.sqrt(mean_square + epsilon)
    return tf.cast(scaled, input_dtype)


def lerp(start: tf.Tensor, end: tf.Tensor, factor: tf.Tensor) -> tf.Tensor:
    """Linearly interpolate from ``start`` (factor 0) to ``end`` (factor 1)."""
    delta = end - start
    return start + delta * factor


def reduce_std_nan_safe(x, axis=None, keepdims=False, epsilon=1e-7):
    """Standard deviation of ``x`` along ``axis`` with ``epsilon`` under the root.

    The computation runs in float32 and the result is cast back to the dtype
    of ``x``. ``epsilon`` is added to the variance before the square root.

    Args:
        x: Input tensor.
        axis: Axis or axes to reduce over; ``None`` reduces over all axes.
        keepdims: Whether the reduced axes are kept with size 1.
        epsilon: Added to the variance before taking the square root.
    """
    values = tf.cast(x, tf.float32)
    # The mean always keeps its dims so it broadcasts against `values`.
    centre = tf.reduce_mean(values, axis=axis, keepdims=True)
    squared_deviation = tf.square(values - centre)
    variance = tf.reduce_mean(squared_deviation, axis=axis, keepdims=keepdims)
    std = tf.sqrt(variance + epsilon)
    return tf.cast(std, x.dtype)


def minibatch_stddev(x: tf.Tensor, group_size=4) -> tf.Tensor:
    """Append one channel holding the average per-group standard deviation.

    The batch is split into groups of ``group_size`` samples (or the whole
    batch if it is smaller). For each group the standard deviation across its
    samples is averaged into a single scalar, which is broadcast over all
    positions and concatenated as an extra last-axis channel.

    Shape legend used in the comments:
        N = global sample count, G = group count, M = samples per group,
        H = height, W = width, C = channel count.

    Args:
        x: Input tensor, [NHWC] per the legend above.
        group_size: Upper bound for the number of samples per group.

    Returns:
        Tensor with the dtype of ``x`` and one additional channel.
    """
    input_shape = tf.shape(x)
    input_dtype = x.dtype
    sample_count = input_shape[0]
    group_size = tf.minimum(group_size, sample_count)
    group_count = sample_count // group_size
    tf.Assert(
        group_size * group_count == sample_count,
        ['Sample count was not divisible by group size'])

    # [NHWC] -> [GMHWC]: split the batch into groups.
    grouped_shape = tf.concat([[-1, group_size], input_shape[1:]], 0)
    grouped = tf.reshape(x, grouped_shape)
    grouped = tf.cast(grouped, tf.float32)

    # [G1HWC]: deviation across the samples of each group.
    per_position_std = reduce_std_nan_safe(grouped, axis=1, keepdims=True)

    # [G1111]: one scalar statistic per group.
    group_stat = tf.reduce_mean(
        per_position_std,
        axis=tf.range(1, tf.rank(per_position_std)),
        keepdims=True)

    # [GMHW1]: broadcast the statistic so it can be appended as a channel.
    stat_shape = tf.concat([tf.shape(grouped)[:-1], [1]], 0)
    stat_feature = tf.broadcast_to(group_stat, stat_shape)
    augmented = tf.concat([grouped, stat_feature], axis=-1)

    # Merge the group axes back into a single batch axis.
    flattened = tf.reshape(
        augmented,
        tf.concat([[-1], tf.shape(augmented)[2:]], 0))
    return tf.cast(flattened, input_dtype)


In [None]:
class PixelNorm(tf.keras.layers.Layer):
    """Keras layer wrapper around ``pixel_norm``."""

    def call(self, x: tf.Tensor) -> tf.Tensor:
        normalized = pixel_norm(x)
        return normalized


class Upsample(tf.keras.layers.Layer):
    """Keras layer wrapper around ``upsample``."""

    def call(self, x: tf.Tensor) -> tf.Tensor:
        enlarged = upsample(x)
        return enlarged


class Downsample(tf.keras.layers.Layer):
    """Keras layer wrapper around ``downsample``."""

    def call(self, x: tf.Tensor) -> tf.Tensor:
        reduced = downsample(x)
        return reduced


class MinibatchStddev(tf.keras.layers.Layer):
    """Keras layer wrapper around ``minibatch_stddev`` with its default group size."""

    def call(self, x: tf.Tensor) -> tf.Tensor:
        with_stat_channel = minibatch_stddev(x)
        return with_stat_channel


class ScaledLeakyRelu(tf.keras.layers.Layer):
    """Leaky ReLU whose output is multiplied by a constant gain.

    Args:
        alpha: Slope passed to ``tf.nn.leaky_relu``.
        gain: Factor applied to the activation output (``sqrt(2)`` by default).
    """

    def __init__(self, alpha: float = 0.2, gain: float = math.sqrt(2.), **kwargs):
        super().__init__(**kwargs)
        self.alpha, self.gain = alpha, gain

    def call(self, x: tf.Tensor) -> tf.Tensor:
        activated = tf.nn.leaky_relu(x, self.alpha)
        return activated * self.gain

    def get_config(self):
        """Return the layer config including ``alpha`` and ``gain``."""
        base_config = super().get_config()
        return {**base_config, 'alpha': self.alpha, 'gain': self.gain}


class ImageConversionMode(Enum):
    """Direction of the pixel value range conversion used by ``ImageConversion``."""
    TENSORFLOW_TO_MODEL = 0  # [0, 1] to [-1, 1]
    MODEL_TO_TENSORFLOW = 1  # [-1, 1] to [0, 1]


class ImageConversion(tf.keras.layers.Layer):
    """Layer that maps images between the [0, 1] and [-1, 1] value ranges.

    Args:
        conversion_mode: An ``ImageConversionMode`` or its integer value.
    """

    def __init__(self, conversion_mode: ImageConversionMode, **kwargs):
        super().__init__(**kwargs)
        # Passing through the enum constructor also accepts the raw integer
        # value written by get_config().
        self.conversion_mode = ImageConversionMode(conversion_mode)

    def call(self, image: tf.Tensor) -> tf.Tensor:
        mode = self.conversion_mode
        if mode == ImageConversionMode.MODEL_TO_TENSORFLOW:
            return image * 0.5 + 0.5  # [-1, 1] to [0, 1]
        if mode == ImageConversionMode.TENSORFLOW_TO_MODEL:
            return image * 2. - 1.  # [0, 1] to [-1, 1]
        assert False, f'Unknown conversion mode: {self.conversion_mode}'

    def get_config(self):
        """Return the layer config with the mode stored as its integer value."""
        base_config = super().get_config()
        return {**base_config, 'conversion_mode': self.conversion_mode.value}


class ScaledAdd(tf.keras.layers.Layer):
    """Add two tensors and multiply the sum by a fixed scale.

    Args:
        scale: Factor applied to the sum (``1 / sqrt(2)`` by default).
    """

    def __init__(self, scale: float = 1. / math.sqrt(2.), **kwargs):
        super().__init__(**kwargs)
        self.scale_value = scale

    def build(self, input_shapes):
        assert len(input_shapes) == 2
        a_shape, b_shape = input_shapes
        # Only the non-batch dimensions have to agree.
        assert a_shape[1:] == b_shape[1:], f'{a_shape} != {b_shape}'
        constant_init = tf.keras.initializers.Constant(value=self.scale_value)
        self.scale = self.add_weight(
            name='scale',
            initializer=constant_init,
            trainable=False)

    def call(self, inputs: List[tf.Tensor]) -> tf.Tensor:
        first, second = inputs[0], inputs[1]
        return (first + second) * self.scale

    def get_config(self):
        """Return the layer config including the ``scale`` constructor argument."""
        base_config = super().get_config()
        return {**base_config, 'scale': self.scale_value}


class ScaledConv2d(tf.keras.layers.Layer):
    """2D convolution whose kernel is rescaled at call time.

    The kernel is initialised from a unit normal distribution and multiplied
    by a fixed, non-trainable ``scale`` of ``1 / sqrt(kernel_h * kernel_w *
    in_channels)`` on every call.

    Args:
        channel_count: Number of output channels.
        kernel_size: Kernel size (int or pair).
        strides: Convolution strides (int or pair).
        padding: ``'valid'`` or ``'same'``.
        pre_blur: When true, ``blur`` is applied to the input before the
            convolution.
    """

    def __init__(
            self,
            channel_count: int,
            kernel_size: int,
            strides: int = 1,
            padding: str = 'valid',
            pre_blur: bool = False,
            **kwargs):
        super().__init__(**kwargs)
        self.rank = 2
        self.channel_count = channel_count
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, self.rank, 'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, self.rank, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.pre_blur = pre_blur

    def build(self, input_shape: List[int]) -> None:
        assert len(input_shape) == self.rank + 2
        input_channels = input_shape[-1]
        kernel_shape = self.kernel_size + (input_channels, self.channel_count)

        self.kernel = self.add_weight(
            name='kernel',
            shape=kernel_shape,
            initializer=tf.keras.initializers.random_normal(mean=0., stddev=1.),
            trainable=True)
        self.bias = self.add_weight(
            name='bias',
            shape=(self.channel_count,),
            initializer=tf.keras.initializers.zeros(),
            trainable=True)

        # Product of every kernel dimension except the output channels.
        fan_in = tf.reduce_prod(tf.cast(kernel_shape[:-1], tf.float32))
        self.scale = self.add_weight(
            name='scale',
            shape=(),
            initializer=tf.keras.initializers.constant(1. / tf.sqrt(fan_in)),
            trainable=False)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        features = blur(x) if self.pre_blur else x
        scaled_kernel = self.kernel * self.scale
        convolved = tf.nn.conv2d(
            features, scaled_kernel, strides=self.strides, padding=self.padding.upper())
        return tf.nn.bias_add(convolved, self.bias)

    def get_config(self):
        """Return the layer config including all constructor arguments."""
        base_config = super().get_config()
        return {
            **base_config,
            'channel_count': self.channel_count,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'pre_blur': self.pre_blur,
        }


'''Replaces an upscale2d + conv2d sequence

Based on progan/stylegan code, this is "Faster and uses less memory than performing the operations
separately."

This is designed to upsample by 2x using a 3x3 convolution and does not accept any parameters other
than the channel count
'''
class UpsampleConv2d(tf.keras.layers.Layer):
    """Fused 2x upsampling convolution followed by a blur.

    Uses a stride-2 transposed convolution with a fixed 3x3 kernel and
    ``'same'`` padding. The kernel is multiplied by a fixed, non-trainable
    ``scale`` of ``1 / sqrt(3 * 3 * in_channels)`` on every call.

    Args:
        channel_count: Number of output channels.
    """

    def __init__(
            self,
            channel_count: int,
            **kwargs):
        super().__init__(**kwargs)
        self.rank = 2
        self.channel_count = channel_count
        self.kernel_size = conv_utils.normalize_tuple(3, self.rank, 'kernel_size')
        self.strides = conv_utils.normalize_tuple(2, self.rank, 'strides')
        self.padding = conv_utils.normalize_padding('same')

    def build(self, input_shape: List[int]) -> None:
        assert len(input_shape) == self.rank + 2
        input_channels = input_shape[-1]
        # Transposed-convolution kernels are laid out as
        # (height, width, output channels, input channels).
        kernel_shape = self.kernel_size + (self.channel_count, input_channels)

        self.kernel = self.add_weight(
            name='kernel',
            shape=kernel_shape,
            initializer=tf.keras.initializers.random_normal(mean=0., stddev=1.),
            trainable=True)
        self.bias = self.add_weight(
            name='bias',
            shape=(self.channel_count,),
            initializer=tf.keras.initializers.zeros(),
            trainable=True)

        # Dividing the full product by the output channel count leaves
        # kernel_h * kernel_w * input_channels.
        fan_in = tf.cast(tf.reduce_prod(kernel_shape) // self.channel_count, tf.float32)
        self.scale = self.add_weight(
            name='scale',
            shape=(),
            initializer=tf.keras.initializers.constant(1. / tf.sqrt(fan_in)),
            trainable=False)

    def compute_output_shape(self, input_shape):
        """Return the input shape with both spatial dims doubled and new channels."""
        dims = tf.TensorShape(input_shape).as_list()
        batch, height, width = dims[0], dims[1], dims[2]
        return tf.TensorShape([batch, height * 2, width * 2, self.channel_count])

    def call(self, x: tf.Tensor) -> tf.Tensor:
        dynamic_shape = tf.shape(x)
        batch_size = dynamic_shape[0]
        in_height = dynamic_shape[1]
        in_width = dynamic_shape[2]
        target_shape = (batch_size, in_height * 2, in_width * 2, self.channel_count)

        result = tf.nn.conv2d_transpose(
            x,
            self.kernel * self.scale,
            target_shape,
            self.strides,
            padding=self.padding.upper())

        # Outside eager mode, restore the static shape of the result.
        if not tf.executing_eagerly():
            result.set_shape(self.compute_output_shape(x.shape))

        result = tf.nn.bias_add(result, self.bias)
        return blur(result)

    def get_config(self):
        """Return the layer config including ``channel_count``."""
        base_config = super().get_config()
        return {**base_config, 'channel_count': self.channel_count}


class ScaledDense(tf.keras.layers.Layer):
    """Dense layer whose kernel is rescaled at call time.

    The kernel is initialised from a unit normal distribution and multiplied
    by a fixed, non-trainable ``scale`` of ``1 / sqrt(input features)`` on
    every call.

    Args:
        output_count: Number of output features.
    """

    def __init__(
            self,
            output_count: int,
            **kwargs):
        super().__init__(**kwargs)
        self.output_count = output_count

    def build(self, input_shape):
        assert len(input_shape) == 2
        input_count = input_shape[-1]

        self.kernel = self.add_weight(
            name='kernel',
            shape=[input_count, self.output_count],
            initializer=tf.keras.initializers.random_normal(mean=0., stddev=1.),
            trainable=True)
        self.bias = self.add_weight(
            name='bias',
            shape=(self.output_count,),
            initializer=tf.keras.initializers.zeros(),
            trainable=True)
        self.scale = self.add_weight(
            name='scale',
            shape=(),
            initializer=tf.keras.initializers.constant(1. / math.sqrt(input_count)),
            trainable=False)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        scaled_kernel = self.kernel * self.scale
        projected = tf.matmul(x, scaled_kernel)
        return tf.nn.bias_add(projected, self.bias)

    def get_config(self):
        """Return the layer config including ``output_count``."""
        base_config = super().get_config()
        return {**base_config, 'output_count': self.output_count}


### Model Generation

In [None]:
def validate_resolution(resolution: int) -> None:
    """Assert that ``resolution`` is a power of two between 4 and 1024 inclusive."""
    supported_resolutions = tuple(2 ** exponent for exponent in range(2, 11))
    assert resolution in supported_resolutions


def activate(
        x,
        activation = lambda: ScaledLeakyRelu()):
    """Apply a freshly created activation layer to ``x``.

    Args:
        x: Value to activate.
        activation: Zero-argument factory returning a callable layer, or
            ``None`` to return ``x`` unchanged.
    """
    if activation is None:
        return x
    layer = activation()
    return layer(x)


def conv_2d(
        x,
        channel_count: int,
        kernel_size: int = 3,
        activation = lambda: ScaledLeakyRelu(),
        padding: str = 'same',
        strides: int = 1,
        pre_blur: bool = False,
        name: Optional[str] = None):
    """Apply a ``ScaledConv2d`` layer followed by an optional activation.

    Args:
        x: Input tensor.
        channel_count: Number of output channels.
        kernel_size: Convolution kernel size.
        activation: Zero-argument factory for the activation layer, or
            ``None`` for no activation.
        padding: Padding mode forwarded to ``ScaledConv2d``.
        strides: Strides forwarded to ``ScaledConv2d``.
        pre_blur: Whether the convolution blurs its input first.
        name: Optional name of the convolution layer.
    """
    conv_layer = ScaledConv2d(
        channel_count,
        kernel_size,
        padding=padding,
        strides=strides,
        pre_blur=pre_blur,
        name=name)
    convolved = conv_layer(x)
    return activate(convolved, activation=activation)


def create_generator_body(resolution, latent_vector):
    """Recursively build the generator layers up to ``resolution``.

    Args:
        resolution: Target resolution; must be a key of the channel table below.
        latent_vector: Tensor fed into the 4x4 base block.

    Returns:
        A ``(rgb, block)`` pair: the RGB output accumulated up to this
        resolution and the feature block at this resolution.
    """
    def to_rgb(x):
        # 1x1 convolution without activation producing 3 channels.
        rgb = conv_2d(x, 3, 1, activation=None, name=f'to_rgb_{resolution}x{resolution}')
        return rgb

    resolution_to_channel_counts = {
        4: 512, 8: 512, 16: 512, 32: 512, 64: 512, 128: 256, 256: 128, 512: 64, 1024: 32}
    channel_count = resolution_to_channel_counts[resolution]
    if resolution == 4:
        # Base case: project the latent vector to a 4x4 feature map.
        block = ScaledDense(
            channel_count*4*4,
            name='latent_to_4x4')(latent_vector)
        block = tf.keras.layers.Reshape((4, 4, channel_count))(block)
        block = activate(block)
        block = conv_2d(block, channel_count, name='conv_4x4')
        rgb = to_rgb(block)
        return rgb, block
    else:
        # Build everything below this resolution first, then add one block.
        lower_res_rgb, lower_res_block = create_generator_body(resolution // 2, latent_vector)
        name_base = f'conv_{resolution}x{resolution}'

        block = lower_res_block
        block = activate(UpsampleConv2d(channel_count, name=f'{name_base}_1')(block))
        block = conv_2d(block, channel_count, name=f'{name_base}_2')
        rgb = to_rgb(block)

        # The upsampled lower-resolution RGB is summed with this block's RGB.
        lower_res_rgb = Upsample()(lower_res_rgb)
        rgb = tf.keras.layers.Add()([lower_res_rgb, rgb])
        return rgb, block


def create_generator(output_resolution: int, latent_vector_size: int ) -> tf.keras.Model:
    """Build the generator model.

    Args:
        output_resolution: Output resolution; checked by ``validate_resolution``.
        latent_vector_size: Length of the latent input vector.

    Returns:
        A Keras model mapping a latent vector to an image converted with
        ``ImageConversionMode.MODEL_TO_TENSORFLOW``.
    """
    validate_resolution(output_resolution)

    latent_input = tf.keras.layers.Input((latent_vector_size,))
    normalized_latent = PixelNorm()(latent_input)

    model_range_rgb, _ = create_generator_body(output_resolution, normalized_latent)
    output_image = ImageConversion(ImageConversionMode.MODEL_TO_TENSORFLOW)(model_range_rgb)

    # Cast the output to float32. Needed when using mixed_float16.
    if output_image.dtype != tf.float32:
        cast_layer = tf.keras.layers.Activation(None, dtype=tf.float32)
        output_image = cast_layer(output_image)

    return tf.keras.Model(inputs=latent_input, outputs=output_image)


def make_discriminator_body(tf_format_image, resolution):
    """Recursively build the discriminator layers for ``resolution``.

    The recursion goes upward: when the input image is larger than
    ``resolution``, the body for ``resolution*2`` is built first and its
    output feeds the block created here.

    Args:
        tf_format_image: Input image tensor; ``shape[1]`` is compared against
            ``resolution``.
        resolution: Resolution of the block to build; must be a key of the
            feature count table below.

    Returns:
        The output feature block of this resolution.
    """
    # Maps resolution -> (feature count of first conv, feature count of second conv).
    resolution_to_feature_counts: Dict[int, Tuple[int, int]] = {
        1024: (32, 64),
        512: (64, 128),
        256: (128, 256),
        128: (256, 512),
        64: (512, 512),
        32: (512, 512),
        16: (512, 512),
        8: (512, 512),
        4: (512, 512)}
    feature_counts = resolution_to_feature_counts[resolution]

    incoming_block = None
    if tf_format_image.shape[1] > resolution:
        # Higher-resolution blocks are built first; their output is the input here.
        incoming_block = make_discriminator_body(tf_format_image, resolution*2)
    else:
        # This is the input resolution: convert the value range and lift RGB to features.
        model_format_image = ImageConversion(ImageConversionMode.TENSORFLOW_TO_MODEL)(tf_format_image)
        incoming_block = conv_2d(
            model_format_image,
            feature_counts[0],
            kernel_size=1,
            name=f'from_rgb_{resolution}x{resolution}')

    if resolution == 4:
        # Final block: minibatch statistic, 3x3 conv, then a 4x4 'valid' conv.
        block = incoming_block
        block = MinibatchStddev()(block)
        block = conv_2d(block, feature_counts[0], name=f'conv_4x4_1')
        block = conv_2d(block, feature_counts[1], kernel_size=4, padding='valid', name=f'conv_4x4_2')
        return block
    else:
        name_base = f'conv_{resolution}x{resolution}'
        block = incoming_block
        block = conv_2d(block, feature_counts[0], name=f'{name_base}_1')
        # Stride-2 convolution with pre-blur halves the spatial size.
        block = conv_2d(block, feature_counts[1], strides=2, pre_blur=True, name=f'{name_base}_2')

        # Residual shortcut: downsample the block input and match channel counts if needed.
        shortcut = Downsample()(incoming_block)
        if shortcut.shape[-1] != block.shape[-1]:
            shortcut = conv_2d(
                shortcut,
                block.shape[-1],
                kernel_size=1,
                activation=None,
                name=f'shortcut_{resolution}x{resolution}')
        block = ScaledAdd()([block, shortcut])
        return block


def create_discriminator(input_resolution: int) -> tf.keras.Model:
    """Build the discriminator model.

    Args:
        input_resolution: Side length of the square RGB input; checked by
            ``validate_resolution``.

    Returns:
        A Keras model mapping an image to a single float32 output value.
    """
    validate_resolution(input_resolution)

    image_input = tf.keras.layers.Input((input_resolution, input_resolution, 3))

    features = make_discriminator_body(image_input, 4)
    flat_features = tf.keras.layers.Flatten()(features)
    score = ScaledDense(1, name='to_classification')(flat_features)
    # Make sure the model output is float32 whatever the compute dtype is.
    if score.dtype != tf.float32:
        cast_layer = tf.keras.layers.Activation(None, dtype=tf.float32)
        score = cast_layer(score)

    return tf.keras.Model(inputs=image_input, outputs=score)


### Perception differences

In [None]:
def normalize_channels(x):
    """Divide ``x`` by its L2 norm over the last axis (with a 1e-7 stabiliser)."""
    squared_norm = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
    return x / tf.sqrt(squared_norm + 1e-7)


def calculate_feature_pair_difference(a, b):
    """Mean distance between channel-normalized feature maps ``a`` and ``b``.

    Both inputs are normalized along the last axis, the norm of their
    difference is taken along that axis, and the result is averaged over
    axes 1 and 2.
    """
    unit_a = normalize_channels(a)
    unit_b = normalize_channels(b)

    distances = tf.linalg.norm(unit_a - unit_b, axis=-1)
    return tf.reduce_mean(distances, axis=[1, 2])

def calculate_perceptual_difference(feature_pairs):
    """Average ``calculate_feature_pair_difference`` over all feature pairs.

    Args:
        feature_pairs: Sequence of ``(a, b)`` feature map pairs.
    """
    total = None
    for features_a, features_b in feature_pairs:
        current = calculate_feature_pair_difference(features_a, features_b)
        if total is None:
            total = current
        else:
            total = total + current
    total /= len(feature_pairs)
    return total


def create_feature_extractor():
    """Build a model that maps 224x224x3 images to selected VGG16 features.

    The input is multiplied by 255 and passed through the VGG16 preprocessing
    function before the frozen, ImageNet-weighted VGG16 is applied. The
    outputs are the activations of ``block4_conv3`` and ``block5_conv3``.
    """
    def preprocess(image):
        return tf.keras.applications.vgg16.preprocess_input(image * 255.)

    layer_names = [
        # 'block1_conv2',
        # 'block2_conv2',
        # 'block3_conv3',
        'block4_conv3',
        'block5_conv3',
        ]

    vgg = tf.keras.applications.VGG16(include_top=False, weights='imagenet')
    vgg.trainable = False
    selected_outputs = [vgg.get_layer(layer_name).output for layer_name in layer_names]
    vgg_features = tf.keras.Model([vgg.input], selected_outputs)

    image_input = tf.keras.layers.Input((224, 224, 3))
    preprocessed = tf.keras.layers.Lambda(preprocess)(image_input)
    return tf.keras.Model(image_input, vgg_features(preprocessed))


def create_perceptual_difference_model():
    """Build a model that scores the perceptual difference of two images.

    Both inputs go through the same feature extractor; the matching feature
    maps are paired and reduced with ``calculate_perceptual_difference``.
    """
    extractor = create_feature_extractor()
    image_shape = extractor.input_shape[1:]

    image_a = tf.keras.layers.Input(image_shape)
    image_b = tf.keras.layers.Input(image_shape)

    features_a = extractor(image_a)
    features_b = extractor(image_b)

    # A single-output extractor returns a bare tensor; wrap it so it can be zipped.
    if not isinstance(features_a, list):
        features_a, features_b = [features_a], [features_b]

    paired_features = list(zip(features_a, features_b))
    difference = tf.keras.layers.Lambda(calculate_perceptual_difference)(paired_features)

    return tf.keras.Model([image_a, image_b], difference)


In [None]:
class Checkpointer:
    """Saves, loads and enumerates checkpoint files addressed by an index."""

    def __init__(self, file_format: str):
        '''
            Args:
                file_format: format string with variable {checkpoint_i}.
                    Ex. '/my/path/{checkpoint_i}.checkpoint'
        '''
        super().__init__()
        self.file_format = file_format

    def path_for_checkpoint(self, checkpoint_i: Union[int, str]) -> str:
        """Return ``file_format`` with ``{checkpoint_i}`` substituted."""
        return self.file_format.format(checkpoint_i=checkpoint_i)

    def list_checkpoints(self) -> List[int]:
        """Return the sorted indices of all existing checkpoint files.

        Files matching the pattern whose base name (without extension) is not
        an integer are ignored.
        """
        file_names = tf.io.gfile.glob(self.path_for_checkpoint('*'))

        checkpoint_is = []
        for file_name in file_names:
            stem = os.path.splitext(os.path.basename(file_name))[0]
            try:
                checkpoint_is.append(int(stem))
            except ValueError:
                # Not a checkpoint index; skip unrelated files.
                continue
        checkpoint_is.sort()
        return checkpoint_is

    def save_checkpoint(self, checkpoint_i: int, content: bytes) -> None:
        """Write ``content`` to the file of checkpoint ``checkpoint_i``."""
        # Previously called through `bk_io`, a name that is never imported in
        # this file; the helper with the same name is defined here.
        write_binary_file(self.path_for_checkpoint(checkpoint_i), content)

    def load_checkpoint(self, checkpoint_i: int) -> bytes:
        """Return the bytes stored in the file of checkpoint ``checkpoint_i``."""
        return read_binary_file(self.path_for_checkpoint(checkpoint_i))

In [None]:



def make_real_image_dataset(
        batch_size: int,
        file_pattern: str = 'gs://bk-ffhq/512x512/*.tfrecord',
        randomly_flip: bool = True
        ) -> tf.data.Dataset:
    """Build an endlessly repeating, batched dataset of float32 images.

    Records are decoded with ``decode_record_image``, converted to float32,
    shuffled with a buffer of 1000, repeated, optionally flipped along axis 1
    with probability 0.5, batched and prefetched.

    Args:
        batch_size: Number of images per batch.
        file_pattern: Glob pattern of the TFRecord files to read.
        randomly_flip: Whether the random flip is applied.
    """
    file_names = tf.io.gfile.glob(file_pattern)

    def to_float(image):
        return tf.image.convert_image_dtype(image, tf.float32, saturate=True)

    def apply_flip(x):
        if randomly_flip and tf.random.uniform((), minval=0., maxval=1., dtype=tf.float32) > 0.5:
            x = x[:, ::-1, :]
        return x

    dataset = tf.data.TFRecordDataset(file_names)
    dataset = dataset.map(decode_record_image)
    dataset = dataset.map(to_float)
    dataset = dataset.shuffle(1000)
    dataset = dataset.repeat()
    dataset = dataset.map(apply_flip)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)


def create_visualizer_noise(
        visualization_grid_size: Tuple[int, int],
        latent_size: int) -> tf.Tensor:
    """Create one latent vector per cell of the visualization grid.

    For ``latent_size == 3`` the vectors are points on the unit sphere laid
    out on a latitude/longitude grid; otherwise they are drawn from a standard
    normal distribution.

    Returns:
        Tensor of shape ``(grid_size[0] * grid_size[1], latent_size)``.
    """
    lat_count, lon_count = visualization_grid_size[0], visualization_grid_size[1]

    if latent_size != 3:
        return tf.random.normal((lat_count * lon_count, latent_size))

    latitudes = tf.linspace(-math.pi / 2., math.pi / 2., lat_count)
    longitudes = tf.linspace(-math.pi, math.pi, lon_count)
    lon_grid, lat_grid = tf.meshgrid(longitudes, latitudes)
    lon_flat = tf.reshape(lon_grid, [-1])
    lat_flat = tf.reshape(lat_grid, [-1])

    # Spherical to Cartesian coordinates on the unit sphere.
    xs = tf.cos(lat_flat) * tf.cos(lon_flat)
    ys = tf.cos(lat_flat) * tf.sin(lon_flat)
    zs = tf.sin(lat_flat)
    return tf.stack([xs, ys, zs], axis=-1)


class TrainingOptions:
    """Settings for a single training run.

    The constructor checks that `epoch_sample_count` is a multiple of
    `replica_batch_size` and that `total_sample_count` is a multiple of
    `epoch_sample_count`. When no `visualizer_noise` is supplied one is
    generated with `create_visualizer_noise`; either way its shape must be
    (grid rows * grid columns, latent_size).
    """

    def __init__(
            self,
            visualization_grid_size: Tuple[int, int],
            resolution: int,
            replica_batch_size: int,
            epoch_sample_count: int = 1024 * 16,
            total_sample_count: int = 1024 * 800,
            learning_rate: float = 0.002,
            real_images_file_pattern: str = 'gs://bk-ffhq/1024x1024/*.tfrecord',
            latent_size: int = 512,
            randomly_flip_data: bool = True,
            checkpoint_interval: int = 10,
            visualizer_noise: Optional[tf.Tensor] = None,
            visualization_smoothing_sample_count: int = 10000
            ):
        assert epoch_sample_count % replica_batch_size == 0
        assert total_sample_count % epoch_sample_count == 0

        # Model and data.
        self.resolution = resolution
        self.latent_size = latent_size
        self.real_images_file_pattern = real_images_file_pattern
        self.randomly_flip_data = randomly_flip_data

        # Schedule.
        self.replica_batch_size = replica_batch_size
        self.epoch_sample_count = epoch_sample_count
        self.total_sample_count = total_sample_count
        self.learning_rate = learning_rate
        self.checkpoint_interval = checkpoint_interval

        # Visualization.
        self.visualization_grid_size = visualization_grid_size
        self.visualization_smoothing_sample_count = visualization_smoothing_sample_count
        generated_noise = create_visualizer_noise(visualization_grid_size, latent_size)
        self.visualizer_noise = default_to(visualizer_noise, generated_noise)
        grid_cell_count = visualization_grid_size[0] * visualization_grid_size[1]
        assert self.visualizer_noise.shape == (grid_cell_count, latent_size)

    @property
    def epoch_count(self):
        """Total number of epochs: total_sample_count // epoch_sample_count."""
        return self.total_sample_count // self.epoch_sample_count


class TrainingState:
    """Mutable state of a training run: options, the three models, and the epoch.

    Pickling support: `__getstate__` replaces each model with the result of
    `serialize_model`, and `__setstate__` rebuilds them with
    `deserialize_model`, passing a factory that constructs a model for the
    stored resolution (and latent size, for the generators).
    """

    def __init__(
            self,
            options: TrainingOptions,
            generator: Optional[tf.keras.Model] = None,
            visualization_generator: Optional[tf.keras.Model] = None,
            discriminator: Optional[tf.keras.Model] = None,
            epoch_i: int = 0):
        self.options = options
        self.epoch_i = epoch_i
        self.generator = generator
        self.visualization_generator = visualization_generator
        self.discriminator = discriminator

    def training_is_done(self) -> bool:
        """Return True once the completed epochs cover `total_sample_count` samples."""
        samples_seen = self.epoch_i * self.options.epoch_sample_count
        return samples_seen >= self.options.total_sample_count

    def __getstate__(self):
        snapshot = dict(self.__dict__)
        snapshot['generator'] = serialize_model(self.generator)
        snapshot['visualization_generator'] = serialize_model(self.visualization_generator)
        snapshot['discriminator'] = serialize_model(self.discriminator)
        return snapshot

    def __setstate__(self, state):
        self.__dict__ = state.copy()

        # Factories handed to deserialize_model to construct each network.
        resolution = self.options.resolution
        latent_size = self.options.latent_size
        build_generator = functools.partial(create_generator, resolution, latent_size)
        build_discriminator = functools.partial(create_discriminator, resolution)

        self.generator = deserialize_model(self.generator, build_generator)
        self.visualization_generator = deserialize_model(
            self.visualization_generator,
            build_generator)
        self.discriminator = deserialize_model(self.discriminator, build_discriminator)


class CheckpointStateCallback(tf.keras.callbacks.Callback):
    """Keras callback that tracks the epoch in a TrainingState and checkpoints it."""

    def __init__(
            self,
            state: TrainingState,
            checkpointer: Checkpointer):
        super().__init__()
        self.state = state
        self.checkpointer = checkpointer

    def on_epoch_end(self, epoch_i: int, logs=None) -> None:
        """Record the number of finished epochs and save every `checkpoint_interval`."""
        finished_epochs = epoch_i + 1
        self.state.epoch_i = finished_epochs
        if finished_epochs % self.state.options.checkpoint_interval != 0:
            return
        self.checkpointer.save_checkpoint(finished_epochs, serialize(self.state))


def train(
        strategy: tf.distribute.Strategy,
        checkpointer: Checkpointer,
        state: TrainingState,
        on_visualization_callbacks: Optional[List[Callable[[int, tf.Tensor], None]]] = None
        ) -> None:
    """Run (or continue) training from `state` and save a final checkpoint.

    If `state.generator` is None the generator, visualization generator and
    discriminator are created inside the strategy scope, and the visualization
    generator's trainable variables are zeroed. Training then proceeds through
    `training_loop` with callbacks that update the visualization generator,
    render visualizations, and checkpoint the state.

    Args:
        strategy: distribution strategy (or a stand-in with the same API).
        checkpointer: destination for periodic and final checkpoints.
        state: training state; its models and `epoch_i` are updated in place.
        on_visualization_callbacks: callables forwarded to the
            GeneratorVisualizer as `on_image_callbacks`. Defaults to no
            callbacks.
    """
    # A fresh list per call avoids sharing one default list object between calls.
    if on_visualization_callbacks is None:
        on_visualization_callbacks = []

    options = state.options

    def create_visualizer() -> GeneratorVisualizer:
        # Prefer a small per-replica batch, but fall back to an even split of
        # the grid when that batch does not tile the grid exactly.
        replica_batch_size = min(options.replica_batch_size, 8)
        visualization_image_count = options.visualization_grid_size[0] * options.visualization_grid_size[1]
        assert visualization_image_count % strategy.num_replicas_in_sync == 0
        max_replica_batch_size = visualization_image_count // strategy.num_replicas_in_sync
        if (replica_batch_size > max_replica_batch_size or
            visualization_image_count % (replica_batch_size * strategy.num_replicas_in_sync) != 0):
            replica_batch_size = max_replica_batch_size

        return GeneratorVisualizer(
            strategy,
            options.visualization_grid_size,
            state.visualization_generator,
            options.visualizer_noise,
            update_interval=1,
            on_image_callbacks=on_visualization_callbacks,
            replica_batch_size=replica_batch_size)

    checkpoint_callback = CheckpointStateCallback(state, checkpointer)

    if state.generator is None:
        # Fresh run: build all three networks under the strategy scope.
        with strategy.scope():
            state.generator = create_generator(options.resolution, options.latent_size)
            state.visualization_generator = create_generator(
                options.resolution,
                options.latent_size)
            state.discriminator = create_discriminator(options.resolution)

        @tf.function
        def zero_vars(var):
            for v in var:
                v.assign(tf.zeros_like(v))
        # Start the visualization weights from zero on every replica.
        strategy.run(zero_vars, args=(state.visualization_generator.trainable_variables,))

    global_batch_size = options.replica_batch_size * strategy.num_replicas_in_sync

    # Per-step decay chosen so that it compounds to 0.5 after
    # `visualization_smoothing_sample_count` samples; 0 disables smoothing.
    visualization_weight_decay = (
        0.5 ** (global_batch_size / options.visualization_smoothing_sample_count)
        if options.visualization_smoothing_sample_count > 0 else
        0.0)
    update_visualization_generator_callback = UpdateVisualizationGeneratorCallback(
        strategy,
        state.generator,
        state.visualization_generator,
        visualization_weight_decay)

    visualizer = create_visualizer()

    image_dataset = strategy.experimental_distribute_dataset(
        make_real_image_dataset(
            global_batch_size,
            file_pattern=options.real_images_file_pattern,
            randomly_flip=options.randomly_flip_data))

    state.epoch_i = training_loop(
        strategy,
        state.generator,
        state.discriminator,
        image_dataset,
        state.epoch_i,
        options.epoch_count,
        options.replica_batch_size,
        options.epoch_sample_count,
        learning_rate=options.learning_rate,
        callbacks=[update_visualization_generator_callback, visualizer, checkpoint_callback])
    # Always persist the final state, whatever the checkpoint interval.
    checkpointer.save_checkpoint(state.epoch_i, serialize(state))


In [None]:
class DummyStrategyScope:
    """No-op context manager returned by `DummyStrategy.scope()`."""

    def __enter__(self):
        # Nothing to set up; the `with ... as` target is None.
        return None

    def __exit__(self, *args):
        # A falsy return value lets exceptions from the body propagate.
        return None


class DummyStrategy:
    """Single-replica stand-in for the distribution-strategy methods used in this file.

    Every method simply passes its input through or calls the function directly.
    """

    def __init__(self):
        self.num_replicas_in_sync = 1

    def scope(self):
        """Return a context manager that does nothing."""
        return DummyStrategyScope()

    def experimental_distribute_dataset(self, dataset, *args):
        """Return `dataset` unchanged; extra positional arguments are ignored."""
        return dataset

    def run(self, f, args=()):
        """Call `f` with the positional `args` and return its result."""
        result = f(*args)
        return result

    def reduce(self, mode, value, axis=None):
        """Return `value` unchanged; `mode` and `axis` are ignored."""
        return value

    def gather(self, t, axis):
        """Return `t` unchanged; `axis` is ignored."""
        return t

In [None]:
### Generate Images

In [None]:
def load_generator(checkpoint_folder_path: os.PathLike, checkpoint_i: Optional[int] = None):
    """Load the visualization generator stored in a training checkpoint.

    Args:
        checkpoint_folder_path: folder containing `<index>.checkpoint` files.
        checkpoint_i: checkpoint index to load; when None the highest index
            found in the folder is used.

    Returns:
        The `visualization_generator` of the deserialized TrainingState.

    Raises:
        ValueError: if `checkpoint_i` is None and the folder has no checkpoints.
    """
    checkpointer = Checkpointer(
        os.path.join(checkpoint_folder_path, '{checkpoint_i}.checkpoint'))

    # Only list the folder when the caller did not pick a checkpoint, and fail
    # with an explicit message instead of max()'s "empty sequence" error.
    if checkpoint_i is None:
        available_checkpoints = checkpointer.list_checkpoints()
        if not available_checkpoints:
            raise ValueError(f'No checkpoints found in {checkpoint_folder_path!r}')
        checkpoint_i = max(available_checkpoints)

    # NOTE(review): `deserialize` presumably unpickles the checkpoint bytes;
    # only load checkpoints from trusted sources - confirm.
    training_state: TrainingState = deserialize(checkpointer.load_checkpoint(checkpoint_i))
    return training_state.visualization_generator


def calc_magnitude(v):
    """Return the Euclidean length of `v` along its last axis (axis kept with size 1)."""
    squared_sum = tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)
    return tf.sqrt(squared_sum)


def normalize(v, magnitude=1.0):
    """Rescale each vector along the last axis of `v` to length `magnitude`."""
    scaled = v * magnitude
    return scaled / calc_magnitude(v)


def slerp(a, b, t):
    """Spherically interpolate between two batches of vectors.

    The direction rotates from `a` towards `b` by `t` times the angle between
    them, and the result is rescaled to `lerp(|a|, |b|, t)`. Values of `t`
    outside [0, 1] continue the rotation along the same great circle.

    Args:
        a: start vectors; the last axis is the vector axis.
        b: end vectors, same shape as `a`.
        t: interpolation factor.

    Returns:
        The interpolated vectors, same shape as `a`.
    """
    norm_mag = lerp(calc_magnitude(a), calc_magnitude(b), t)
    a = normalize(a)
    b = normalize(b)
    # Rounding can push the dot product of unit vectors slightly outside
    # [-1, 1], where acos returns NaN, so clamp it first.
    cos_angle = tf.clip_by_value(
        tf.reduce_sum(a * b, axis=-1, keepdims=True),
        -1.0,
        1.0)
    p = t * tf.math.acos(cos_angle)
    # Unit vector orthogonal to `a` in the plane spanned by `a` and `b`.
    c = normalize(b - cos_angle * a)
    d = a * tf.math.cos(p) + c * tf.math.sin(p)
    return normalize(d, norm_mag)


@tf.function
def sample_path_lengths_from(
        generator: tf.keras.Model,
        perceptual_differencer: tf.keras.Model,
        origin_noise: tf.Tensor,
        origin_images: tf.Tensor) -> tf.Tensor:
    """Measure the perceptual change caused by a small random step in latent space.

    Each origin noise vector is moved 10% of the way (by slerp) towards a fresh
    random normal vector, the target images are generated, and both image
    batches are resized and center-cropped before being compared.

    Returns:
        The output of `perceptual_differencer` for the (origin, target) pairs.
    """
    differ_height, differ_width = perceptual_differencer.input_shape[0][1:3]

    def focus_on_center(images: tf.Tensor) -> tf.Tensor:
        # Resize the images to be slightly larger than the differ's input size then crop the
        # center to reduce the amount of background in the image and focus on the face.
        enlarged = tf.image.resize(
            images,
            (256 * differ_height // 224, 256 * differ_width // 224))
        return tf.image.resize_with_crop_or_pad(enlarged, differ_height, differ_width)

    random_noise = tf.random.normal(origin_noise.shape)
    target_noises = slerp(origin_noise, random_noise, 0.1)
    target_images = generator(target_noises)
    origin_images = focus_on_center(origin_images)
    target_images = focus_on_center(target_images)
    return perceptual_differencer([origin_images, target_images])


def calculate_perceptual_path_length(
        generator: tf.keras.Model,
        perceptual_differencer: tf.keras.Model,
        noise: tf.Tensor,
        images: tf.Tensor,
        sample_count: int = 10):
    """Return the largest squared perceptual difference over several random steps.

    `sample_path_lengths_from` is evaluated `sample_count` times; the results
    are squared and reduced with a maximum across the samples.
    """
    samples = []
    for _ in range(sample_count):
        samples.append(
            sample_path_lengths_from(generator, perceptual_differencer, noise, images))
    squared_samples = tf.square(tf.stack(samples, axis=-1))
    largest_squared_diff = tf.reduce_max(squared_samples, axis=-1)
    return largest_squared_diff


def generate_random_images(
        generator: tf.keras.Model,
        args: argparse.Namespace):
    """Sample images from `generator` and save `args.sample_count` of them.

    Images are generated in batches of 4. When `args.threshold` is positive,
    images whose perceptual path length score exceeds it are discarded, so
    generation continues until enough images have passed the filter.
    """
    perceptual_differencer = create_perceptual_difference_model()
    batch_size = 4
    noise_shape = (batch_size, generator.input_shape[-1])
    filtering_enabled = args.threshold > 0

    progress = tf.keras.utils.Progbar(args.sample_count)
    saved_count = 0
    while saved_count < args.sample_count:
        noise = pixel_norm(tf.random.normal(noise_shape))
        images = generator(noise)
        if filtering_enabled:
            scores = calculate_perceptual_path_length(
                generator, perceptual_differencer, noise, images)
        else:
            scores = tf.zeros((batch_size,))

        for image, score in zip(images, scores):
            if filtering_enabled and score > args.threshold:
                continue
            # NOTE(review): `score` is interpolated as-is; if it is a tf.Tensor its
            # repr ends up in the file name - confirm this is intended.
            file_name = os.path.join(args.out, f'{create_datetime_str()}_{score}.png')
            save_image(file_name, image)
            saved_count += 1
            progress.add(1)
            if saved_count == args.sample_count:
                break


def generate_interpolation(
        generator: tf.keras.Model,
        start_noise: tf.Tensor,
        end_noise: tf.Tensor) -> List[tf.Tensor]:
    """Generate 100 images along the slerp path from `start_noise` to `end_noise`.

    Only the first image of each generator output batch is kept.
    """
    factors = tf.linspace(0.0, 1.0, 100)
    return [
        generator(slerp(start_noise, end_noise, factor))[0]
        for factor in factors]


def generate_random_interpolations(
        generator: tf.keras.Model,
        args: argparse.Namespace):
    """Save `args.sample_count` interpolations between random latent points.

    Each interpolation is written as numbered PNG files into its own
    sub-folder of `args.out`, named by `create_datetime_str()`.
    """
    def generate_random_interpolation():
        noise_shape = (1, generator.input_shape[-1])
        start_noise = pixel_norm(tf.random.normal(noise_shape))
        end_noise = pixel_norm(tf.random.normal(noise_shape))
        return generate_interpolation(generator, start_noise, end_noise)

    prog_bar = tf.keras.utils.Progbar(args.sample_count)
    for _ in range(args.sample_count):
        interpolation_folder = os.path.join(args.out, create_datetime_str())
        images = generate_random_interpolation()
        for image_i, image in enumerate(images):
            # `interpolation_folder` is a plain str, so it must be joined with
            # os.path.join; the `/` operator raises TypeError on strings.
            file_path = os.path.join(interpolation_folder, f'{image_i}.png')
            save_image(file_path, image)
        prog_bar.add(1)


def generate_circular_interpolation(
        generator: tf.keras.Model,
        noise_origin: tf.Tensor,
        noise_pass_through_point: tf.Tensor,
        step_count: int = 500) -> List[tf.Tensor]:
    """Generate images along a full circle through two latent points.

    The slerp factor runs from 0 to 2*pi / angle, where the angle is taken as
    acos of the dot product of the two inputs, so the path returns to the
    origin.

    NOTE(review): the angle computation only matches the true angle when both
    inputs are unit length - confirm callers normalize them.

    Args:
        generator: model mapping a noise batch to an image batch.
        noise_origin: start point, batch size 1.
        noise_pass_through_point: second point on the circle, batch size 1.
        step_count: number of images to generate.
    """
    assert noise_origin.shape[0] == 1
    assert noise_pass_through_point.shape[0] == 1

    cos_angle = tf.reduce_sum(noise_origin * noise_pass_through_point, axis=-1)[0]
    angular_distance = tf.math.acos(cos_angle)
    assert tf.abs(angular_distance) > 1e-7

    full_revolution_factor = 2. * math.pi / angular_distance
    factors = tf.linspace(0.0, full_revolution_factor, step_count)
    return [
        generator(slerp(noise_origin, noise_pass_through_point, factor))[0]
        for factor in factors]


def generate_random_circular_interpolations(
        generator: tf.keras.Model,
        args: argparse.Namespace):
    """Save `args.sample_count` circular interpolations through random latent points.

    Each interpolation is written as numbered PNG files into its own
    sub-folder of `args.out`, named by `create_datetime_str()`.
    """
    def generate_random_circular_interpolation():
        noise_shape = (1, generator.input_shape[-1])
        noise_origin = normalize(tf.random.normal(noise_shape))
        noise_pass_through_point = normalize(tf.random.normal(noise_shape))
        return generate_circular_interpolation(generator, noise_origin, noise_pass_through_point)

    prog_bar = tf.keras.utils.Progbar(args.sample_count)
    for _ in range(args.sample_count):
        interpolation_folder = os.path.join(args.out, create_datetime_str())
        images = generate_random_circular_interpolation()
        for image_i, image in enumerate(images):
            # `interpolation_folder` is a plain str, so it must be joined with
            # os.path.join; the `/` operator raises TypeError on strings.
            file_path = os.path.join(interpolation_folder, f'{image_i}.png')
            save_image(file_path, image)
        prog_bar.add(1)


# Callable function
def generate_images():
    """Command line entry point for sampling images from a trained generator.

    Subcommands: `images`, `interpolations` and `circular-interpolations`.
    Arguments are read from `sys.argv`; the selected checkpoint's
    visualization generator is loaded and handed to the subcommand's action.
    """
    limit_gpu_memory_usage(4*1024)

    parser = argparse.ArgumentParser(
        description='Sample images from a trained generator',
        fromfile_prefix_chars='@')

    subparsers = parser.add_subparsers()

    images_parser = subparsers.add_parser('images', help='Sample random images from the generator')
    images_parser.add_argument(
        '--threshold',
        type=float,
        default=0.3,
        help='Threshold for max perceptual path length filter. 0 for unlimited.')
    images_parser.set_defaults(action=generate_random_images)

    interpolation_parser = subparsers.add_parser(
        'interpolations',
        help='Sample interpolations from randomly chosen start and end points in the latent ' +
            'space')
    interpolation_parser.set_defaults(action=generate_random_interpolations)

    circular_parser = subparsers.add_parser(
        'circular-interpolations',
        help='Sample interpolations from a randomly chosen start point in the latent space that ' +
            'travel in a random direction in a circle around the latent space back to the start ' +
            'point.')
    circular_parser.set_defaults(action=generate_random_circular_interpolations)

    for subparser in [images_parser, interpolation_parser, circular_parser]:
        subparser.add_argument('--sample_count', help='number of samples to generate', type=int, default=1)
        subparser.add_argument(
            '--checkpoint',
            help='checkpoint to load. Defaults to the highest checkpoint number.',
            type=int)
        subparser.add_argument('checkpoint_folder', help='folder containing checkpoints')
        subparser.add_argument('out', help='root output folder path')

    args = parser.parse_args()

    # Subcommands are optional by default, so without one `args` has neither
    # `action` nor the shared arguments; report that as a usage error.
    if not hasattr(args, 'action'):
        parser.error('a subcommand is required: images, interpolations or circular-interpolations')

    generator = load_generator(args.checkpoint_folder, args.checkpoint)
    args.action(generator, args)

In [None]:
# Run the image-sampling command line interface; it parses its arguments from
# sys.argv.
generate_images()

### Create TF Records

In [None]:
# Number of tfrecord files a dataset is split into when --record-count is not
# given and the source is not itself a set of tfrecords.
DEFAULT_RECORD_COUNT = 8

def slice_dataset(dataset: tf.data.Dataset, start_index: int, increment: int):
    """Return every `increment`-th element of `dataset`, starting at `start_index`.

    Elements are selected by position: those whose index modulo `increment`
    equals `start_index` are kept, in their original order.
    """
    indexed = dataset.enumerate()
    selected = indexed.filter(lambda index, value: index % increment == start_index)
    # Drop the index again so the result has the same element type as the input.
    return selected.map(lambda index, value: value)


def resize_image(image: tf.Tensor, size: Tuple[int, int]) -> tf.Tensor:
    """Resize `image` to `size` in linear light and return it in its original dtype.

    The image is converted to float32, gamma-decoded assuming a gamma of 2.2,
    resized, gamma-encoded again, and converted back with saturation.
    """
    source_dtype = image.dtype
    as_float = tf.image.convert_image_dtype(image, tf.float32, saturate=True)
    # Computer Color is Broken - minutephysics
    # https://youtu.be/LKnqECcg6Gw
    linear = tf.image.adjust_gamma(as_float, 2.2)  # Gamma decode, assuming a gamma of 2.2
    linear_resized = tf.image.resize(linear, size)
    encoded = tf.image.adjust_gamma(linear_resized, 1./2.2)  # Re-gamma-encode the image
    return tf.image.convert_image_dtype(encoded, source_dtype, saturate=True)


def create_record(dataset: tf.data.Dataset, output_file_name: os.PathLike):
    """Write every image of `dataset` into one tfrecord file at `output_file_name`."""
    ensure_dir_for_file(output_file_name)
    record_path = str(output_file_name)
    with tf.io.TFRecordWriter(record_path) as record_writer:
        for serialized_image in map(encode_record_image, dataset):
            record_writer.write(serialized_image)


def create_records_from_dataset(
        dataset: tf.data.Dataset,
        prepare,
        output_resolution: Optional[int],
        output_dir: os.PathLike,
        record_count: int):
    """Split `dataset` into `record_count` tfrecord files inside `output_dir`.

    Record `i` receives the elements at positions i, i + record_count, ... .

    Args:
        dataset: source elements.
        prepare: optional function mapped over each element to produce an
            image; skipped when None.
        output_resolution: when given, images are resized to a square of this
            size.
        output_dir: folder the `<i>.tfrecord` files are written to.
        record_count: number of record files to create.
    """
    progress = tf.keras.utils.Progbar(record_count)
    target_size = (
        None
        if output_resolution is None else
        (output_resolution, output_resolution))

    for record_i in range(record_count):
        shard = slice_dataset(dataset, record_i, record_count)
        if prepare is not None:
            shard = shard.map(prepare)
        if target_size is not None:
            shard = shard.map(lambda image: resize_image(image, target_size))
        shard = shard.prefetch(tf.data.AUTOTUNE)

        create_record(shard, os.path.join(output_dir, f'{record_i}.tfrecord'))
        progress.add(1)


def dataset_from_image_files(input_file_pattern: str):
    """Create a dataset of image file names plus the function that decodes them.

    Args:
        input_file_pattern: glob pattern matching the image files.

    Returns:
        A `(dataset, prepare)` pair: `dataset` yields the matched file names
        and `prepare` maps a file name to its decoded image tensor.
    """
    def prepare(file_name: os.PathLike):
        # expand_animations=False makes decode_image return a 3-D tensor of
        # known rank (animated GIFs are truncated to their first frame). With
        # the default, the result has no static shape inside Dataset.map, which
        # breaks the later tf.image.resize call.
        return tf.io.decode_image(tf.io.read_file(file_name), expand_animations=False)
    input_file_names = tf.io.gfile.glob(input_file_pattern)
    dataset = tf.data.Dataset.from_tensor_slices(input_file_names)
    return dataset, prepare


def dataset_from_tfrecords(record_file_pattern: str):
    """Return a `(dataset, prepare)` pair for the tfrecords matching the pattern.

    The dataset yields raw records; `decode_record_image` is returned as the
    function that turns a record into an image.
    """
    matched_records = tf.io.gfile.glob(record_file_pattern)
    return tf.data.TFRecordDataset(matched_records), decode_record_image


def extract_record(record_name, output_dir, group_i: int = 0, group_count: int = 1):
    """Save every image of one tfrecord file as a PNG in `output_dir`.

    The i-th image is named `i * group_count + group_i`, so images extracted
    from several records with distinct `group_i` values get distinct names.
    """
    decoded_images = tf.data.TFRecordDataset([record_name]).map(decode_record_image)
    decoded_images = decoded_images.prefetch(tf.data.AUTOTUNE)
    for i, image in enumerate(decoded_images):
        save_image(os.path.join(output_dir, f'{i*group_count+group_i}.png'), image)


def load_mnist_data() -> tf.data.Dataset:
    """Load the MNIST training split as a dataset of 3-channel images.

    Each sample's 'image' entry is broadcast so that its last axis has size 3.
    """
    def to_three_channels(sample):
        grayscale = sample['image']
        three_channel_shape = grayscale.shape[:2] + (3,)
        return tf.broadcast_to(grayscale, three_channel_shape)

    mnist_train = tfds.load('mnist', split='train')
    return mnist_train.map(to_three_channels)


def create_records_from_args(args: argparse.Namespace):
    """Create tfrecords from the source named on the command line.

    `args.source` selects the input: the literal "mnist", a pattern ending in
    "tfrecord", or otherwise a glob pattern of image files. Defaults for the
    resolution and the record count depend on which kind of source it is.
    """
    source_kind = args.source.lower()

    if source_kind == 'mnist':
        print('Creating dataset from MNIST')
        source_dataset = load_mnist_data()
        prepare = None
        resolution = default_to(args.resolution, 32)
        record_count = default_to(args.record_count, DEFAULT_RECORD_COUNT)
    elif source_kind.endswith('tfrecord'):
        print(f'Creating dataset from records: {args.source}')
        source_dataset, prepare = dataset_from_tfrecords(args.source)
        resolution = args.resolution
        # Default to as many output records as there are input records.
        record_count = default_to(args.record_count, len(tf.io.gfile.glob(args.source)))
    else:
        print(f'Creating dataset from images: {args.source}')
        source_dataset, prepare = dataset_from_image_files(args.source)
        resolution = args.resolution
        record_count = default_to(args.record_count, DEFAULT_RECORD_COUNT)

    create_records_from_dataset(
        source_dataset,
        prepare,
        resolution,
        args.output_folder_name,
        record_count)


def extract_records(args: argparse.Namespace):
    """Extract the images of every tfrecord matching `args.source` into PNG files.

    Each record is extracted as its own group so that file names from
    different records do not collide.
    """
    print(f'Extracting records from {args.source}')
    matched_records = tf.io.gfile.glob(args.source)
    total_records = len(matched_records)

    progress = tf.keras.utils.Progbar(total_records)
    for group_index, record_path in enumerate(matched_records):
        extract_record(
            record_path,
            args.output_folder_name,
            group_i=group_index,
            group_count=total_records)
        progress.add(1)


def main():
    """Command line entry point for creating and extracting tfrecords.

    Subcommands: `create` packs a source into tfrecords and `extract` writes
    the images of existing tfrecords to disk. Arguments are read from
    `sys.argv`.

    NOTE(review): another `main` is defined later in this file; if both live in
    the same module the later definition shadows this one - confirm.
    """
    limit_gpu_memory_usage()

    parser = argparse.ArgumentParser(
        description='Create, resize, and extract tfrecords',
        fromfile_prefix_chars='@')
    subparsers = parser.add_subparsers()

    create_parser = subparsers.add_parser(
        'create',
        help='create a set of tf records')
    create_parser.set_defaults(action=create_records_from_args)

    create_parser.add_argument(
        '--resolution',
        help='resolution to resize the images to. By default the images will remain their ' +
            'pre-existing size with the exception of "mnist", which will be resized to 32 by '
            'default.',
        type=int)

    create_parser.add_argument(
        '--record-count',
        help='number of tf records to split the data into. If the source is a set of tfrecords, ' +
            'this will default to the number of records in the source. Otherwise, it will ' +
            f'default to {DEFAULT_RECORD_COUNT}.',
        type=int)

    create_parser.add_argument(
        'source',
        help='a file glob pattern for images or records (*.tfrecord) to pack into the ' +
            'destination records. You can pass the special value of "mnist" to download ' +
            'the mnist dataset and pack it into the records.')

    create_parser.add_argument(
        'output_folder_name',
        help='output directory to write the tf records')

    extract_parser = subparsers.add_parser(
        'extract',
        help='extract the images from a set of tf records',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    extract_parser.set_defaults(action=extract_records)

    extract_parser.add_argument(
        'source',
        help='a file glob pattern for the tf records to be extracted')

    extract_parser.add_argument(
        'output_folder_name',
        help='output directory to write the images')

    args = parser.parse_args()

    # Subcommands are optional by default, so without one `args.action` does
    # not exist; report that as a usage error instead of an AttributeError.
    if not hasattr(args, 'action'):
        parser.error('a subcommand is required: create or extract')

    args.action(args)

### Training

In [None]:
def save_visualization(
        visualization_folder_path: os.PathLike,
        should_spherically_project: bool,
        epoch_i: int,
        images: tf.Tensor,
        visualization_callback: Optional[Callable[[tf.Tensor], None]] = None):
    """Save the visualization image(s) for one epoch and notify the callback.

    The gridded image is always saved. With spherical projection enabled it
    goes into an 'unprojected' sub-folder and a projected image is saved into
    a 'projected' sub-folder; the callback then receives the projected image,
    otherwise the gridded one. Files are named after the 1-based epoch number.
    """
    image_file_name = f'{epoch_i+1}.png'
    gridded_image = images_to_gridded_image(images)

    if should_spherically_project:
        save_image(
            os.path.join(visualization_folder_path, 'unprojected', image_file_name),
            gridded_image)
        projected_image = spherically_project_images_to_grid(images)
        save_image(
            os.path.join(visualization_folder_path, 'projected', image_file_name),
            projected_image)
        callback_image = projected_image
    else:
        save_image(os.path.join(visualization_folder_path, image_file_name), gridded_image)
        callback_image = gridded_image

    if visualization_callback:
        visualization_callback(callback_image)


def create_visualization_callback(
        root_output_path: os.PathLike,
        latent_size: int,
        visualization_callback: Optional[Callable[[tf.Tensor], None]]):
    """Bind `save_visualization` to this run's output folder.

    Images are saved under `<root_output_path>/visualizations`, and spherical
    projection is enabled only for a latent size of 3. The returned callable
    still expects the remaining `epoch_i` and `images` arguments.
    """
    use_spherical_projection = latent_size == 3
    visualizations_folder = os.path.join(root_output_path, 'visualizations')
    return functools.partial(
        save_visualization,
        visualizations_folder,
        use_spherical_projection,
        visualization_callback=visualization_callback)


def create_checkpointer(output_root: os.PathLike):
    """Create a Checkpointer for `<output_root>/checkpoints/{checkpoint_i}.checkpoint`."""
    return Checkpointer(
        os.path.join(output_root, 'checkpoints', '{checkpoint_i}.checkpoint'))


def init_training(
        args: argparse.Namespace,
        visualization_callback: Optional[Callable[[tf.Tensor], None]],
        strategy: Optional[tf.distribute.Strategy]):
    """Start a new training run from the parsed `init` command line arguments.

    Builds the training options and an empty training state from `args`, then
    calls `train`. With `--create-unique-id` the output goes into a
    date/time-named sub-folder of `args.out`. A DummyStrategy is used when no
    strategy is given.
    """
    output_root = (
        os.path.join(args.out, create_datetime_str())
        if args.create_unique_id else
        args.out)
    strategy = default_to(strategy, DummyStrategy())
    checkpointer = create_checkpointer(output_root)

    options = TrainingOptions(
        tuple(args.visualization_grid_size),
        args.resolution,
        args.replica_batch_size,
        epoch_sample_count=args.epoch_sample_count,
        total_sample_count=args.total_sample_count,
        real_images_file_pattern=args.dataset_file_pattern,
        latent_size=args.latent_size,
        randomly_flip_data=(not args.disable_horizontal_flip_data_augmentation),
        checkpoint_interval=args.checkpoint_interval,
        visualization_smoothing_sample_count=args.visualization_smoothing_sample_count)
    fresh_state = TrainingState(options)

    save_and_forward = create_visualization_callback(
        output_root,
        fresh_state.options.latent_size,
        visualization_callback)

    train(
        strategy,
        checkpointer,
        fresh_state,
        on_visualization_callbacks=[save_and_forward])

def resume_training(
        args: argparse.Namespace,
        visualization_callback: Optional[Callable[[tf.Tensor], None]],
        strategy: Optional[tf.distribute.Strategy]):
    """Continue training from a checkpoint stored under `args.out`.

    Args:
        args: parsed `resume` arguments; `args.checkpoint` selects the
            checkpoint, and when it is None the highest checkpoint index is
            used.
        visualization_callback: optional callable receiving each saved
            visualization image.
        strategy: distribution strategy; a DummyStrategy is used when None.

    Raises:
        ValueError: if no checkpoint was requested and none exist.
    """
    strategy = default_to(strategy, DummyStrategy())
    checkpointer = create_checkpointer(args.out)

    # Only list the checkpoints when none was requested, and fail with an
    # explicit message instead of max()'s "empty sequence" error.
    checkpoint_i = args.checkpoint
    if checkpoint_i is None:
        available_checkpoints = checkpointer.list_checkpoints()
        if not available_checkpoints:
            raise ValueError(f'No checkpoints found under {args.out!r}')
        checkpoint_i = max(available_checkpoints)

    print(f'Resuming training from checkpoint {checkpoint_i}')

    # Deserialize inside the strategy scope so the models are rebuilt there.
    with strategy.scope():
        training_state = deserialize(checkpointer.load_checkpoint(checkpoint_i))

    visualization_callback = create_visualization_callback(
        args.out,
        training_state.options.latent_size,
        visualization_callback)

    train(
        strategy,
        checkpointer,
        training_state,
        on_visualization_callbacks=[visualization_callback])


def main(
        raw_arguments: Optional[List[str]] = None,
        visualization_callback: Optional[Callable[[tf.Tensor], None]] = None,
        strategy: Optional[tf.distribute.Strategy] = None):
    """Command line entry point for training a GAN.

    Subcommands: `init` starts a new run and `resume` continues from a
    checkpoint.

    Args:
        raw_arguments: argument list to parse; when None, argparse reads
            `sys.argv`.
        visualization_callback: optional callable forwarded to the selected
            subcommand, which passes each saved visualization image to it.
        strategy: optional distribution strategy forwarded to the subcommand.
    """
    parser = argparse.ArgumentParser(
        description='Train a GAN',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        fromfile_prefix_chars='@')
    subparsers = parser.add_subparsers()

    init_parser = subparsers.add_parser(
        'init',
        help='initialize and begin training',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    init_parser.set_defaults(func=init_training)

    init_parser.add_argument(
        '--create-unique-id',
        action='store_true',
        help='if true, a unique id will be created from the date/time and appended to the output path')

    init_parser.add_argument(
        '--visualization-grid-size',
        type=int,
        nargs=2,
        help='height and width of the visualization grid. Ex: ... --visualization-grid-size 30 60 ...',
        default=[4, 8])

    init_parser.add_argument(
        '--resolution',
        type=int,
        help='resolution of generated images. Must match the resolution of the dataset',
        default=512)

    init_parser.add_argument(
        '--replica-batch-size',
        type=int,
        help='size of batch per replica',
        default=8)

    init_parser.add_argument(
        '--epoch-sample-count',
        type=int,
        help='number of samples per epoch. Must divide evenly into total-sample-count',
        default=16*1024)

    init_parser.add_argument(
        '--total-sample-count',
        type=int,
        help='number of total samples to train on. Must be divisible by epoch-sample-count',
        default=25*1024*1024)

    init_parser.add_argument(
        '--latent-size',
        type=int,
        help='size of the generator\'s latent vector',
        default=512)

    init_parser.add_argument(
        '--checkpoint-interval',
        type=int,
        help='interval of epochs to save a checkpoint',
        default=10)

    init_parser.add_argument(
        '--visualization-smoothing-sample-count',
        type=float,
        help='number of training samples over which the visualization weight decay factor ' +
            'compounds to one half. If 0, no smoothing will be applied.',
        default=10000)

    init_parser.add_argument(
        '--disable-horizontal-flip-data-augmentation',
        action='store_true',
        help='including this option will disable horizontal flip data augmentation. Use when ' +
            'horizontally flipping an image changes its semantics, ex., mnist digits.')

    init_parser.add_argument(
        'dataset_file_pattern',
        help='GLOB pattern for the dataset files. Ex: \'D:/datasets/ffhq/1024x1024/*.tfrecord\'')

    resume_parser = subparsers.add_parser('resume', help='resume training from a checkpoint')
    resume_parser.set_defaults(func=resume_training)

    for subparser in [init_parser, resume_parser]:
        subparser.add_argument(
            '--gpu-mem-limit',
            help='maximum amount of memory to consume on the gpu in GB. 0 for unlimited',
            type=int,
            default=0)
        subparser.add_argument('out', help='root output folder')

    resume_parser.add_argument(
        '--checkpoint',
        help='checkpoint epoch to resume from. Defaults to the largest checkpointed epoch',
        type=int)

    args = parser.parse_args(args=raw_arguments)

    # Subcommands are optional by default, so without one neither `func` nor
    # `gpu_mem_limit` exists; report that as a usage error.
    if not hasattr(args, 'func'):
        parser.error('a subcommand is required: init or resume')

    if args.gpu_mem_limit != 0:
        # The option is given in GB; limit_gpu_memory_usage receives it times 1024.
        limit_gpu_memory_usage(args.gpu_mem_limit * 1024)

    args.func(args, visualization_callback, strategy)


In [None]:
### Training Loop

In [None]:
def training_loop(
        strategy: tf.distribute.Strategy,
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        real_image_dataset: tf.distribute.DistributedDataset,
        epoch_i: int,
        end_epoch_i: int,
        replica_batch_size: int,
        epoch_sample_count: int,
        learning_rate: float = 0.002,
        beta_1: float = 0.0,
        beta_2: float = 0.99,
        d_regularization_interval: int = 16,
        callbacks: Optional[List[tf.keras.callbacks.Callback]] = None
        ) -> int:
    '''
        Alternates discriminator and generator updates from epoch_i until end_epoch_i is reached.

        Every batch runs one discriminator classification step followed by one generator step.
        Whenever the global batch index is a multiple of d_regularization_interval, a
        gradient-penalty step is also run on the discriminator; it pulls one more element from
        real_image_dataset. New Adam optimizers are created and attached to generator.optimizer
        and discriminator.optimizer on each call.

        strategy: used for the optimizer creation scope, for running each step and for summing
            the per-replica statistics.
        generator: the last dimension of its first input is used as the noise size.
        discriminator: its last trainable variable is asserted to receive no gradient from the
            gradient penalty.
        real_image_dataset: iterated once with iter(); presumably yields per-replica batches of
            real images - confirm against the caller.
        epoch_i: index of the first epoch to run.
        end_epoch_i: epoch index at which to stop (exclusive).
        replica_batch_size: number of noise vectors sampled per replica per step.
        epoch_sample_count: samples per epoch. Must be divisible by
            replica_batch_size * strategy.num_replicas_in_sync.
        learning_rate, beta_1, beta_2: Adam settings for the generator. For the discriminator
            the learning rate is multiplied by, and the betas are raised to the power of,
            d_regularization_interval / (d_regularization_interval + 1).
        d_regularization_interval: number of global batches between gradient-penalty steps.
        callbacks: extra Keras callbacks, run after the built-in progress bar. None means no
            extra callbacks.

        Returns the index of the next epoch to run: end_epoch_i, or epoch_i unchanged when it
        was already at or past end_epoch_i.
    '''
    # A None default (instead of a shared mutable []) keeps calls independent of each other.
    extra_callbacks = [] if callbacks is None else list(callbacks)

    global_batch_size = replica_batch_size * strategy.num_replicas_in_sync
    assert epoch_sample_count % global_batch_size == 0
    epoch_batch_count = epoch_sample_count // global_batch_size

    noise_size = generator.inputs[0].shape[-1]

    d_stat_names = ['d_loss', 'd_real', 'd_fake']
    g_stat_names = ['g_loss']
    all_stat_names = g_stat_names + d_stat_names + ['d_grad_reg', 'epoch']

    progbar_callback = tf.keras.callbacks.ProgbarLogger(count_mode='steps')
    progbar_callback.target = epoch_batch_count
    callback_list = tf.keras.callbacks.CallbackList(
        callbacks=[progbar_callback] + extra_callbacks,
        model=generator)

    with strategy.scope():
        generator.optimizer = tf.keras.optimizers.Adam(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2)

        # The discriminator's hyperparameters are rescaled by the regularization interval.
        lazy_ratio = d_regularization_interval / (d_regularization_interval + 1)
        discriminator.optimizer = tf.keras.optimizers.Adam(
            learning_rate=learning_rate*lazy_ratio,
            beta_1=beta_1**lazy_ratio,
            beta_2=beta_2**lazy_ratio)

    def reduce_across_batch(x: tf.Tensor) -> tf.Tensor:
        # Dividing by the global (not per-replica) batch size means that summing the
        # per-replica results yields the mean over the whole global batch.
        return tf.reduce_sum(x) / global_batch_size

    @tf.function
    def take_g_step() -> Dict[str, tf.Tensor]:
        # Generator update: softplus(-D(G(z))) is minimized with respect to the generator only.
        noise = tf.random.normal(shape=(replica_batch_size, noise_size))
        fake_images = generator(noise, training=True)
        fake_classifications = discriminator(fake_images, training=False)
        loss = reduce_across_batch(tf.nn.softplus(-fake_classifications))

        grads = tf.gradients(loss, generator.trainable_variables)

        assert len(generator.trainable_variables) == len(grads)
        generator.optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        return {'g_loss': loss}

    @tf.function
    def take_d_classification_step(real_images) -> Dict[str, tf.Tensor]:
        # Discriminator update: softplus(-D(real)) + softplus(D(fake)).
        noise = tf.random.normal(shape=(replica_batch_size, noise_size))
        fake_images = generator(noise, training=False)
        real_classifications = discriminator(real_images, training=True)
        fake_classifications = discriminator(fake_images, training=True)

        real_loss = reduce_across_batch(tf.nn.softplus(-real_classifications))
        fake_loss = reduce_across_batch(tf.nn.softplus(fake_classifications))
        d_loss = real_loss + fake_loss
        d_grads = tf.gradients(d_loss, discriminator.trainable_variables)
        assert len(d_grads) == len(discriminator.trainable_variables)
        discriminator.optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))

        stats = [d_loss, real_loss, fake_loss]
        assert len(stats) == len(d_stat_names)
        stat_dict = dict(zip(d_stat_names, stats))

        return stat_dict

    @tf.function
    def take_d_reg_step(real_images) -> Dict[str, tf.Tensor]:
        # Penalizes the squared gradient of the discriminator output with respect to the real
        # images.
        real_classifications = discriminator(real_images, training=True)
        # NOTE(review): tf.gradients returns a list (one entry per x), so real_grads is
        # presumably a single-element list here. Everything is summed before the division, so
        # the extra leading axis should not change the value - confirm.
        real_grads = tf.gradients(tf.reduce_sum(real_classifications), real_images)
        gradient_loss = reduce_across_batch(tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3]))
        # Scaled by the interval because this step only runs once every
        # d_regularization_interval batches.
        gradient_penalty_strength = 10. * 0.5 * d_regularization_interval
        gradient_penalty = gradient_loss * gradient_penalty_strength
        reg_grads = tf.gradients(gradient_penalty, discriminator.trainable_variables)
        assert len(reg_grads) == len(discriminator.trainable_variables)

        # The final bias addition has a second derivative of 0 which tf.gradients reports as
        # None. To prevent apply_gradients() from warning about this, we just insert a tensor
        # full of zeros.
        assert reg_grads[-1] is None
        reg_grads[-1] = tf.zeros_like(discriminator.trainable_variables[-1])
        discriminator.optimizer.apply_gradients(zip(reg_grads, discriminator.trainable_variables))

        return {'d_grad_reg': gradient_penalty}

    def tensor_dict_to_numpy(tensor_dict: Dict[str, tf.Tensor]) -> Dict[str, float]:
        # Sums each per-replica statistic across replicas and pulls the result to the host.
        return {key: strategy.reduce('sum', value, axis=None).numpy() for key, value in tensor_dict.items()}

    real_image_iter = iter(real_image_dataset)
    callback_list.on_train_begin()
    while epoch_i < end_epoch_i:
        callback_list.on_epoch_begin(epoch_i)
        epoch_stats = {key: 0. for key in all_stat_names}
        epoch_stat_counts = epoch_stats.copy()
        for batch_i in range(0, epoch_batch_count):
            callback_list.on_train_batch_begin(batch_i)

            batch_stats = {}
            d_stats = strategy.run(take_d_classification_step, args=(next(real_image_iter),))
            batch_stats.update(tensor_dict_to_numpy(d_stats))

            global_batch_i = epoch_i * epoch_batch_count + batch_i
            if global_batch_i % d_regularization_interval == 0:
                d_reg_stats = strategy.run(take_d_reg_step, args=(next(real_image_iter),))
                batch_stats.update(tensor_dict_to_numpy(d_reg_stats))

            g_stats = strategy.run(take_g_step)
            batch_stats.update(tensor_dict_to_numpy(g_stats))

            batch_stats['epoch'] = epoch_i

            # Running mean per statistic. 'd_grad_reg' is only averaged over the batches that
            # actually reported it.
            for name, stat in batch_stats.items():
                epoch_stats[name] = ((epoch_stats[name] * epoch_stat_counts[name] + stat) /
                    (epoch_stat_counts[name] + 1))
                epoch_stat_counts[name] += 1

            callback_list.on_train_batch_end(batch_i, logs=batch_stats)

        callback_list.on_epoch_end(epoch_i, logs=epoch_stats)
        try:
            mem_info = tf.config.experimental.get_memory_info('GPU:0')
            print(f'Memory usage: {mem_info}')
        except Exception:
            # Best-effort diagnostics only: training must not fail when the memory query does.
            # Catching Exception rather than everything lets KeyboardInterrupt and SystemExit
            # still stop the run.
            pass
        epoch_i += 1
    callback_list.on_train_end()
    return epoch_i

# Plotting and Image generation.

In [None]:
def images_to_gridded_image(images: tf.Tensor) -> tf.Tensor:
    '''
        Tiles a 2-D grid of equally sized images into a single image.

        images should have shape [grid_height, grid_width, image_height, image_width, channel_count]

        The result has shape [grid_height * image_height, grid_width * image_width, channel_count].
    '''
    rows, cols, cell_height, cell_width, channels = images.shape
    # Move each grid axis next to its matching pixel axis so that the reshape below merges
    # (row, y) into the output height and (column, x) into the output width.
    interleaved = tf.transpose(images, perm=[0, 2, 1, 3, 4])
    return tf.reshape(interleaved, (rows * cell_height, cols * cell_width, channels))


def spherically_project_images_to_grid(
        images: tf.Tensor,
        background_color: Optional[tf.Tensor] = None) -> tf.Tensor:
    '''
        Arranges a latitude/longitude grid of images so that each row is squeezed towards the
        horizontal centre, then tiles the result into a single image.

        Row y keeps a fraction sin(pi * y / (grid_height - 1)) of the grid width, so the first
        and last rows collapse to a single cell and the middle row keeps its full width. A grid
        with only one row is treated as the middle row and keeps its full width. When several
        source cells land on the same destination cell, the one whose projected position is
        closest to that cell's integer position is kept. Destination cells that receive no
        image are filled with background_color.

        images should have shape [lat, lon, height, width, channels]
        background_color: fill colour with the same dtype as images. Presumably replaced by a
            mid-gray uint8 colour when None (depends on default_to) - confirm.

        Returns the tiled image produced by images_to_gridded_image.
    '''
    grid_height, grid_width, _, _, _ = images.shape

    background_color = default_to(background_color, tf.constant([128, 128, 128], dtype=tf.uint8))

    assert background_color.dtype == images.dtype
    projected = tf.Variable(tf.broadcast_to(background_color, images.shape))
    # Smallest rounding distance seen so far for each destination cell.
    cur_dist = tf.Variable(tf.fill((grid_height, grid_width), math.inf))

    for cell_y in range(grid_height):
        # Fraction of the full width covered by this row. With a single row the latitude
        # formula would divide by zero, so that row is kept at full width.
        if grid_height > 1:
            x_range = math.sin(cell_y / (grid_height - 1) * math.pi)
        else:
            x_range = 1.0

        for unproj_cell_x in range(grid_width):
            proj_cell_x = unproj_cell_x * x_range + grid_width * 0.5 * (1 - x_range)
            int_proj_cell_x = round(proj_cell_x)
            cell_center_dist = tf.abs(int_proj_cell_x - proj_cell_x)
            # Only overwrite a destination cell with a source cell that lands closer to it.
            if cell_center_dist < cur_dist[cell_y, int_proj_cell_x]:
                cur_dist[cell_y, int_proj_cell_x].assign(cell_center_dist)
                projected[cell_y, int_proj_cell_x].assign(images[cell_y, unproj_cell_x])

    return images_to_gridded_image(projected)


In [None]:
def plot_models(create_model: Callable[[int], tf.keras.Model], model_name: str):
    '''
        Writes one architecture diagram per resolution for every power of two from 4 to 1024.

        create_model: builds the model to plot for a given resolution.
        model_name: prefix of each file name. Files are written to
            output/model_plots/<model_name>_<resolution>x<resolution>.png

        Each resolution is printed before its model is built.
    '''
    plot_dir = os.path.join('output', 'model_plots')
    # 2**2 .. 2**10 gives 4, 8, ..., 1024.
    for exponent in range(2, 11):
        resolution = 2 ** exponent
        print(resolution)
        model = create_model(resolution)
        file_path = os.path.join(plot_dir, f'{model_name}_{resolution}x{resolution}.png')
        ensure_dir_for_file(file_path)
        tf.keras.utils.plot_model(model, file_path, show_shapes=True, show_dtype=True)

In [None]:
# Example model build and plot.
# The memory limit is applied before any model is created. The CLI passes a GB value
# multiplied by 1024 to this function, so 1024 presumably means 1 GB - confirm.
limit_gpu_memory_usage(1024)

# 512 is presumably the generator's latent size (it matches the --latent-size default) - confirm.
for example_name, example_factory in (
        ('generator', lambda resolution: create_generator(resolution, 512)),
        ('discriminator', create_discriminator)):
    plot_models(example_factory, example_name)