In [2]:
import os
import io
import tensorflow as tf
import json
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.experimental.numpy as tnp
import pathlib

from datetime import datetime

from tensorflow.data import Dataset
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import normalize
from tensorflow.keras.models import Sequential, model_from_json
from tensorflow.keras.layers import Conv2D, Layer, Conv2DTranspose, Flatten, Dense, Reshape, LeakyReLU, Dropout
from tensorflow.keras.layers import BatchNormalization, Lambda, Input, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback, TensorBoard, ModelCheckpoint
from tensorflow.keras.metrics import BinaryAccuracy, Mean
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.distribute import TPUStrategy
from tensorflow.io.gfile import GFile

# Only let TensorFlow log records of severity ERROR or higher through.
tf.get_logger().setLevel('ERROR')
# Last expression of the cell, so the notebook displays the GPUs TensorFlow can see.
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [None]:
# Colab only: the matplotlib version currently installed on Colab (3.2.2) is incompatible with this code, so upgrade it first.
# !pip install matplotlib==3.3.2 --upgrade

## Constants

In [92]:
# Hyperparameters and run configuration read by the other cells of this notebook.
# Optimizers, metrics and losses are stored as zero-argument factories so that a
# fresh object is created each time a GAN is built.
HP = {
    # Number of samples each replica processes per training step.
    'batch_size_per_replica': 256,
    # Size of the combined MNIST train + test pool; used as the shuffle buffer
    # size and to derive steps_per_epoch.
    'total_samples': 70000,
    # Length of the random noise vector fed to the generator.
    'noise_dim': 100,
    # `learning_rate` is the supported keyword; `lr` is only a deprecated alias.
    'g_optimizer': lambda: Adam(learning_rate=0.0002, beta_1=0.5),
    'd_optimizer': lambda: Adam(learning_rate=0.0002, beta_1=0.5),
    'g_metrics': lambda: [BinaryAccuracy(name='accuracy', threshold=0.5)],
    'd_metrics': lambda: [BinaryAccuracy(name='accuracy', threshold=0.5)],
    # Reduction.NONE keeps per-sample losses; train_step averages them itself
    # over the global batch size.
    'g_loss_fn': lambda: BinaryCrossentropy(reduction=Reduction.NONE),
    # NOTE(review): build_gan() currently passes only 'g_loss_fn' to compile(),
    # so this entry is not read anywhere in the notebook.
    'd_loss_fn': lambda: BinaryCrossentropy(reduction=Reduction.NONE),
    # Number of generated images rendered by the LogGeneratedResults callback.
    'num_images_to_log': 100,
    # Switches dataset loading, log directory and distribution strategy to TPU mode.
    'use_tpu': False
}

## Load Dataset

In [4]:
def load_dataset(tm):
    """Build the shuffled, infinitely repeating MNIST dataset for non-TPU runs.

    The MNIST train and test images are stacked into one pool, scaled to
    [0, 1] and reshaped to ``(28, 28, 1)``. Class labels are discarded; every
    image is paired with a "real" target of 0.9 (a smoothed stand-in for 1.0).

    Args:
        tm: Object exposing ``strategy.num_replicas_in_sync`` (for example a
            ``TrainManager`` after ``init_strategy()``); used to size the
            global batch.

    Returns:
        A ``tf.data.Dataset`` of ``(images, targets)`` batches, each holding
        exactly ``batch_size_per_replica * num_replicas_in_sync`` samples.
    """
    # Load data (labels are not needed to train the GAN)
    (X_train, _), (X_test, _) = mnist.load_data()

    # Preprocess data
    X = tnp.vstack((X_train, X_test))
    X = tf.cast(X, dtype=tf.float32) / 255.0
    X = tf.reshape(X, (-1, 28, 28, 1))
    y = tf.constant(0.9, shape=(X.shape[0], 1), dtype=tf.float32)

    global_batch_size = HP['batch_size_per_replica'] * tm.strategy.num_replicas_in_sync

    ds = Dataset.from_tensor_slices((X, y))
    ds = ds.shuffle(buffer_size=HP['total_samples'])
    # Batching happens before repeat(), so without drop_remainder every pass
    # over the data would end in a short batch. train_step always generates
    # batch_size_per_replica fake samples and scales the loss by the full
    # global batch size, so only full batches are emitted.
    ds = ds.batch(global_batch_size, drop_remainder=True)
    ds = ds.repeat()

    return ds

In [None]:
def load_tpu_dataset(is_training=True):
    """Load MNIST through TensorFlow Datasets, batched per replica.

    Images are cast to float32 and divided by 255; the class label is replaced
    by a constant target of ``[1.0]``.

    Args:
        is_training: Truthy selects the 'train' split (shuffled and repeated
            forever); falsy selects the 'test' split (single ordered pass).
            If this function is handed directly to a strategy's
            distribute-from-function API, the value received here is the
            ``InputContext`` object, which is truthy.

    Returns:
        A ``tf.data.Dataset`` batched by ``HP['batch_size_per_replica']``.
    """
    # Imported here because tensorflow_datasets is only needed for this loader.
    import tensorflow_datasets as tfds

    def scale(image, label):
        # Rescale pixels to [0, 1]; the class label is ignored in favour of a
        # constant "real" target.
        scaled = tf.cast(image, tf.float32) / 255.0
        return scaled, tf.constant([1.0], dtype=tf.float32)

    dataset, info = tfds.load(
        name='mnist',
        split='train' if is_training else 'test',
        with_info=True,
        as_supervised=True,
        try_gcs=True,
    )
    dataset = dataset.map(scale)

    # Shuffle and repeat only for training. An infinite training dataset never
    # yields a partial last batch, so gradients do not need rescaling for a
    # smaller final batch.
    if is_training:
        dataset = dataset.shuffle(10000).repeat()

    return dataset.batch(HP['batch_size_per_replica'])

In [5]:
def get_dataset(tm):
    """Return the training dataset distributed over ``tm.strategy``.

    Args:
        tm: Object exposing an initialised ``strategy`` (for example a
            ``TrainManager`` after ``init_strategy()``).

    Returns:
        The distributed dataset: the TFDS-based loader when ``HP['use_tpu']``
        is set, otherwise the in-memory MNIST loader.
    """
    if HP['use_tpu']:
        # The strategy invokes dataset_fn with a tf.distribute.InputContext.
        # load_tpu_dataset's only parameter is the is_training flag, so adapt
        # the call explicitly instead of letting the context object be
        # interpreted as that flag.
        return tm.strategy.experimental_distribute_datasets_from_function(
            lambda _input_context: load_tpu_dataset(is_training=True))

    return tm.strategy.experimental_distribute_dataset(load_dataset(tm))
    
def get_log_dir():
    """Return the root log directory.

    TPU runs log to the GCS bucket named by the ``GS_BUCKET`` environment
    variable; every other run logs to the local ``logs`` directory.
    """
    if not HP['use_tpu']:
        return 'logs'

    return f'gs://{os.environ["GS_BUCKET"]}/mnist/logs'

## Train

In [93]:
# Training driver cell.
# NOTE: TrainManager, build_gan, GAN and the callbacks are defined in cells that
# appear further down in this notebook; those cells must be executed first.
tm = TrainManager(log_dir=get_log_dir(), experiment='test')

# init_tpu() and init_auth() return immediately when HP['use_tpu'] is False.
tm.init_tpu()
tm.init_auth()
tm.init_strategy()

ds = get_dataset(tm)

# The dataset repeats forever, so fit() needs an explicit epoch length.
steps_per_epoch = HP['total_samples'] // tm.global_batch_size

# Build the models inside the strategy scope.
with tm.strategy.scope():
    gan, last_epoch = tm.init_gan(restore_latest=False, save_architecture=False)
    
# All callbacks are currently disabled; uncomment the entries below to enable them.
callbacks = [
#     LogGeneratedResults(num_images=HP['num_images_to_log'], log_dir=tm.generated_results_dir),
#     TensorBoard(log_dir=tm.tensorboard_dir, update_freq='epoch'),
#     SaveModel(log_dir=tm.model_dir)
]

# Train for `epochs` additional epochs, continuing the numbering from last_epoch.
epochs = 1
gan.fit(ds, epochs=epochs+last_epoch, steps_per_epoch=steps_per_epoch, callbacks=callbacks, initial_epoch=last_epoch)



<tensorflow.python.keras.callbacks.History at 0x7f464fb970a0>

## Train Manager

In [23]:
class TrainManager:
    """Bookkeeping for one training run: directories, TPU setup and strategy.

    Everything for a run lives under ``<log_dir>/<experiment>/<timestamp>/``
    with ``models``, ``tensorboard`` and ``generated_results`` sub-directories.
    """

    def __init__(self, log_dir, experiment='test'):
        """
        Args:
            log_dir: Root directory (local path or ``gs://`` URL) for all experiments.
            experiment: Name of the experiment sub-directory.
        """
        self.log_dir = log_dir
        self.experiment = experiment
        self._tpu_built = False
        self._authenticated = False

    def init_timestamp(self, latest=False):
        """Select the timestamp directory used by this run.

        Args:
            latest: If True, reuse the greatest existing timestamp under
                ``experiment_dir``; otherwise create a new one from the
                current time.

        Raises:
            ValueError: If ``latest`` is True and ``experiment_dir`` is empty.
        """
        if latest:
            all_timestamps = tf.io.gfile.listdir(self.experiment_dir)
            if len(all_timestamps) == 0:
                raise ValueError(f'Cannot restore latest as there are no timestamps in dir: "{self.experiment_dir}"')
            # The timestamp format sorts lexicographically, so max() is the newest.
            timestamp = max(all_timestamps)
        else:
            timestamp = datetime.now().strftime('%Y-%m-%dT%H-%M-%S')

        self._curr_timestamp = timestamp

    def init_gan(self, restore_latest=False, save_architecture=True):
        """Create a new GAN or restore the most recent one.

        Args:
            restore_latest: If True, load architecture and latest weights from
                ``model_dir``; otherwise build a fresh GAN with ``build_gan()``.
            save_architecture: When building a fresh GAN, also write its
                architecture JSON to ``model_dir``.

        Returns:
            Tuple ``(gan, last_epoch)``; ``last_epoch`` is 0 for a fresh GAN.
        """
        if not hasattr(self, '_curr_timestamp'):
            self.init_timestamp(restore_latest)

        if restore_latest:
            gan = GAN.load(self.model_dir, latest=True)
            return gan, self.last_epoch

        gan = build_gan()
        if save_architecture:
            gan.save_architecture(self.model_dir)
        return gan, 0

    def init_strategy(self):
        """Create the distribution strategy (TPUStrategy or TF's current strategy)."""
        if HP['use_tpu'] and not hasattr(self, '_resolver'):
            self.init_tpu()

        self._strategy = TPUStrategy(self._resolver) if HP['use_tpu'] else tf.distribute.get_strategy()

    def init_tpu(self):
        """Connect to and initialise the Colab TPU (no-op when TPUs are not used)."""
        # Do nothing if we already initialized or if we're not using TPUs
        if self._tpu_built or not HP['use_tpu']:
            return

        self._resolver = TPUClusterResolver(tpu=f'grpc://{os.environ["COLAB_TPU_ADDR"]}')
        tf.config.experimental_connect_to_cluster(self._resolver)
        tf.tpu.experimental.initialize_tpu_system(self._resolver)
        self._tpu_built = True

        # Result is discarded; kept from the original cell (presumably used to
        # inspect the TPU devices interactively).
        tf.config.list_logical_devices()

    def init_auth(self):
        """Authenticate the Colab user (no-op when TPUs are not used).

        Outside Colab nothing happens and the manager stays unauthenticated.
        """
        # Do nothing if we already authenticated or if we're not using TPUs
        if self._authenticated or not HP['use_tpu']:
            return

        if self.is_colab:
            from google.colab import auth
            # Authenticates the Colab machine and also the TPU using your
            # credentials so that they can access your private GCS buckets.
            auth.authenticate_user()
            self._authenticated = True

    @property
    def experiment_dir(self):
        """``<log_dir>/<experiment>``."""
        return os.path.join(self.log_dir, self.experiment)

    @property
    def timestamp_dir(self):
        """``<experiment_dir>/<timestamp>``; requires a timestamp to be initialised."""
        return os.path.join(self.experiment_dir, self._curr_timestamp)

    @property
    def model_dir(self):
        """Directory holding the architecture JSON and weight checkpoints."""
        return os.path.join(self.timestamp_dir, 'models')

    @property
    def tensorboard_dir(self):
        """Directory for TensorBoard logs."""
        return os.path.join(self.timestamp_dir, 'tensorboard')

    @property
    def generated_results_dir(self):
        """Directory for generated-image summaries."""
        return os.path.join(self.timestamp_dir, 'generated_results')

    @property
    def last_epoch(self):
        """Epoch number encoded in the latest discriminator checkpoint.

        Raises:
            ValueError: If no checkpoint exists in the discriminator directory.
        """
        # Both the discriminator and generator have the same # of epochs so we can get either
        ckpt_dir = os.path.join(self.model_dir, 'discriminator')
        latest_ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
        if latest_ckpt_path is None:
            # latest_checkpoint returns None when nothing was saved; fail with a
            # clear message instead of an opaque TypeError from PurePath(None).
            raise ValueError(f'Cannot determine the last epoch as there are no checkpoints in dir: "{ckpt_dir}"')

        path = pathlib.PurePath(latest_ckpt_path)
        # Checkpoint prefixes are stored as epoch_x so we split to get the epoch number
        epoch = path.name.split('_')[-1]
        return int(epoch)

    @property
    def strategy(self):
        """The distribution strategy created by ``init_strategy()``."""
        return self._strategy

    @property
    def global_batch_size(self):
        """Per-replica batch size multiplied by the number of replicas in sync.

        Raises:
            ValueError: If ``init_strategy()`` has not been called yet.
        """
        # Check the backing attribute directly: `strategy` is a property, so
        # hasattr(self, 'strategy') would only work by swallowing the
        # AttributeError raised inside the getter.
        if getattr(self, '_strategy', None) is None:
            raise ValueError("'strategy' hasn't been initialized yet.")

        return HP['batch_size_per_replica'] * self.strategy.num_replicas_in_sync

    @property
    def is_colab(self):
        """True when running inside Google Colab."""
        # This is always set on Colab, the value is 0 or 1 depending on GPU presence
        return 'COLAB_GPU' in os.environ

## Build GAN

In [35]:
def build_gan():
    """Assemble and compile a fresh GAN from the settings in ``HP``.

    The generator and discriminator are built for 28x28x1 images, then the
    GAN is compiled with newly created optimizers and metrics. The same loss
    object (``HP['g_loss_fn']``) is handed to both sub-models, and the GAN is
    compiled with ``run_eagerly=True``.

    Returns:
        The compiled ``GAN`` instance.
    """
    noise_dim = HP['noise_dim']

    gan = GAN(
        build_generator(noise_dim),
        build_discriminator(input_shape=(28, 28, 1)),
        noise_dim,
    )

    # Every HP entry is a factory; call each one to get fresh objects.
    compile_kwargs = {
        'g_optimizer': HP['g_optimizer'](),
        'd_optimizer': HP['d_optimizer'](),
        'g_metrics': HP['g_metrics'](),
        'd_metrics': HP['d_metrics'](),
        'loss_fn': HP['g_loss_fn'](),
        'run_eagerly': True,
    }
    gan.compile(**compile_kwargs)

    return gan

## Build Generator

In [8]:
def build_generator(noise_dim):
    """Build the generator: noise vector -> 28x28x1 image with values in (0, 1).

    A dense layer projects the noise to a 7x7 feature map, two stride-2
    transposed convolutions upsample it to 14x14 and then 28x28, and a final
    sigmoid convolution reduces it to a single channel.

    Args:
        noise_dim: Length of the input noise vector.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    num_features = 128

    def upsample():
        # One stride-2 transposed convolution doubles the spatial resolution.
        return Conv2DTranspose(num_features, (4, 4), strides=(2, 2), padding='same')

    return Sequential([
        # 7x7
        Dense(num_features * 7 * 7, input_dim=noise_dim),
        LeakyReLU(alpha=0.02),
        Reshape((7, 7, num_features)),
        # 14x14
        upsample(),
        LeakyReLU(alpha=0.02),
        # 28x28
        upsample(),
        LeakyReLU(alpha=0.02),
        # Collapse to one channel in (0, 1)
        Conv2D(1, (7, 7), activation='sigmoid', padding='same'),
    ])

## Build Discriminator

In [9]:
def build_discriminator(input_shape):
    """Build the discriminator: image -> probability that the image is real.

    Two identical stride-2 convolution blocks (Conv2D, LeakyReLU,
    BatchNormalization, Dropout) are followed by a flatten and a single
    sigmoid unit.

    Args:
        input_shape: Shape of one input image, e.g. ``(28, 28, 1)``.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    def conv_block(**conv_kwargs):
        # One downsampling block; conv_kwargs lets the first block declare the input shape.
        return [
            Conv2D(64, (3, 3), strides=(2, 2), padding='same', **conv_kwargs),
            LeakyReLU(alpha=0.2),
            BatchNormalization(),
            Dropout(0.2),
        ]

    layers = [
        *conv_block(input_shape=input_shape),
        *conv_block(),
        Flatten(),
        # A MinibatchDiscrimination(1) layer could be inserted here (currently disabled).
        Dense(1, activation='sigmoid'),
    ]

    return Sequential(layers)

## GAN Model

In [90]:
class GAN(Model):
    """Keras model that trains a generator and a discriminator adversarially.

    The two sub-models are compiled separately, each with its own optimizer
    and metrics. ``train_step`` performs one discriminator update followed by
    one generator update.
    """

    def __init__(self, generator, discriminator, noise_dim, **kwargs):
        """
        Args:
            generator: Model mapping ``(batch, noise_dim)`` noise to images.
            discriminator: Model mapping images to a probability of being real.
            noise_dim: Length of the generator's input noise vector.
            **kwargs: Forwarded to ``tf.keras.Model.__init__``.
        """
        super(GAN, self).__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator
        self.noise_dim = noise_dim

    def compile(self, g_optimizer, d_optimizer, g_metrics, d_metrics, loss_fn, **kwargs):
        """Compile the GAN and both sub-models.

        Args:
            g_optimizer: Optimizer for the generator.
            d_optimizer: Optimizer for the discriminator.
            g_metrics: List of metrics tracked for the generator.
            d_metrics: List of metrics tracked for the discriminator.
            loss_fn: Loss shared by both sub-models.
            **kwargs: Forwarded to ``tf.keras.Model.compile`` (e.g. ``run_eagerly``).
        """
        super(GAN, self).compile(**kwargs)

        self.generator.compile(optimizer=g_optimizer, metrics=g_metrics, loss=loss_fn)
        self.discriminator.compile(optimizer=d_optimizer, metrics=d_metrics, loss=loss_fn)

        # Save the compiled info so we can deserialize the model later
        self.compiled_config = {
            'g_optimizer': self.generator.optimizer,
            'd_optimizer': self.discriminator.optimizer,
            'g_metrics': g_metrics,
            'd_metrics': d_metrics,
            'loss_fn': loss_fn,
            **kwargs
        }

    def train_step(self, data):
        """Run one discriminator update and one generator update.

        Args:
            data: Tuple ``(X_real, y_real)`` of real images and their targets.

        Returns:
            Dict of metric results, prefixed with ``g_`` (generator) or
            ``d_`` (discriminator).
        """
        X_real, y_real = data
        global_batch_size = HP['batch_size_per_replica'] * self.distribute_strategy.num_replicas_in_sync

        X_fake, y_fake = self.generate_fake_data(HP['batch_size_per_replica'], self.noise_dim, training=True)

        # TODO: consider shuffling so the batch is not all real followed by all fake
        X, y = tf.concat((X_real, X_fake), axis=0), tf.concat((y_real, y_fake), axis=0)

        # Train discriminator
        with tf.GradientTape() as tape:
            # training=True is required inside a custom train_step; without it
            # layers such as Dropout and BatchNormalization run in inference mode.
            y_pred = self.discriminator(X, training=True)
            d_loss_per_replica = self.discriminator.compiled_loss(y, y_pred)
            d_loss = tf.nn.compute_average_loss(d_loss_per_replica, global_batch_size=global_batch_size)
        d_gradients = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.discriminator.optimizer.apply_gradients(zip(d_gradients, self.discriminator.trainable_weights))

        # Update discriminator metrics
        self.discriminator.compiled_metrics.update_state(y, y_pred)

        # Target for the generator step: fakes labelled as "real" (smoothed to 0.9).
        # NOTE(review): with a 0.9 target, an accuracy metric that compares
        # thresholded predictions against the raw target may never register a
        # match — confirm the reported accuracies are meaningful.
        y_gan = tf.constant(0.9, shape=(HP['batch_size_per_replica'], 1), dtype=tf.float32)

        # Train generator
        with tf.GradientTape() as tape:
            X_gan, _ = self.generate_fake_data(HP['batch_size_per_replica'], self.noise_dim, training=True)
            y_pred = self.discriminator(X_gan, training=True)
            g_loss_per_replica = self.generator.compiled_loss(y_gan, y_pred)
            g_loss = tf.nn.compute_average_loss(g_loss_per_replica, global_batch_size=global_batch_size)
        # Only the generator's weights are updated in this step.
        g_gradients = tape.gradient(g_loss, self.generator.trainable_weights)
        self.generator.optimizer.apply_gradients(zip(g_gradients, self.generator.trainable_weights))

        # Update generator metrics
        self.generator.compiled_metrics.update_state(y_gan, y_pred)

        # Rename the metrics to g_ or d_
        g_metrics = { f'g_{m.name}': m.result() for m in self.generator.metrics }
        d_metrics = { f'd_{m.name}': m.result() for m in self.discriminator.metrics }

        return { **g_metrics, **d_metrics }

    def generate_fake_data(self, batch_size, noise_dim, training=None):
        """Sample noise and run it through the generator.

        Args:
            batch_size: Number of images to generate.
            noise_dim: Length of each noise vector.
            training: Forwarded to the generator call; ``None`` (default)
                leaves the decision to Keras.

        Returns:
            Tuple ``(X_fake, y_fake)``: generated images and all-zero targets
            of shape ``(batch_size, 1)``.
        """
        noise_data = tf.random.normal(shape=(batch_size, noise_dim))
        X_fake = self.generator(noise_data, training=training)
        y_fake = tf.zeros((batch_size, 1))

        return X_fake, y_fake

    @property
    def metrics(self):
        """Metrics of both sub-models."""
        # Resets these metrics after each epoch
        return [*self.generator.metrics, *self.discriminator.metrics]

    def save_architecture(self, log_dir):
        """Write the model's JSON description to ``<log_dir>/gan_architecture.json``."""
        # If it's a local path we need to make sure it exists before writing
        if not log_dir.startswith('gs://'):
            pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

        with GFile(os.path.join(log_dir, 'gan_architecture.json'), 'w') as json_file:
            json_file.write(self.to_json())

    def save_weights(self, log_dir, prefix):
        """Save both sub-models' weights under ``<log_dir>/{generator,discriminator}/<prefix>``."""
        self.generator.save_weights(os.path.join(log_dir, 'generator', prefix))
        self.discriminator.save_weights(os.path.join(log_dir, 'discriminator', prefix))

    @classmethod
    def load(cls, log_dir, prefix=None, latest=False):
        """Rebuild a GAN from its saved architecture and load saved weights.

        Args:
            log_dir: Directory written by ``save_architecture``/``save_weights``.
            prefix: Checkpoint prefix to load (ignored when ``latest`` is True).
            latest: If True, load the most recent checkpoint of each sub-model.

        Returns:
            The restored, compiled ``GAN``.

        Raises:
            ValueError: If neither ``prefix`` nor ``latest`` is given, or if
                ``latest`` is True but a sub-model has no checkpoint.
        """
        if not latest and not prefix:
            raise ValueError('Either prefix or latest should be used.')

        with tf.keras.utils.custom_object_scope({'GAN': cls}):
            # Context manager so the file handle is always closed.
            with GFile(os.path.join(log_dir, 'gan_architecture.json'), 'r') as json_file:
                saved_json = json_file.read()
            gan = model_from_json(saved_json)

        if latest:
            generator_ckpt = tf.train.latest_checkpoint(os.path.join(log_dir, 'generator'))
            discriminator_ckpt = tf.train.latest_checkpoint(os.path.join(log_dir, 'discriminator'))
            # latest_checkpoint returns None when the directory holds no checkpoint.
            if generator_ckpt is None or discriminator_ckpt is None:
                raise ValueError(f'Cannot restore latest as there are no checkpoints in dir: "{log_dir}"')
        else:
            generator_ckpt = os.path.join(log_dir, 'generator', prefix)
            discriminator_ckpt = os.path.join(log_dir, 'discriminator', prefix)

        gan.generator.load_weights(generator_ckpt)
        gan.discriminator.load_weights(discriminator_ckpt)

        return gan

    def get_config(self):
        """Return the constructor arguments plus the compile configuration.

        The values are live objects (models, optimizers, metrics, loss);
        ``from_config`` expects their ``class_name``/``config`` dict form —
        presumably produced by Keras' JSON encoder during ``to_json()``.
        Requires ``compile()`` to have been called.
        """
        return {
            'generator': self.generator,
            'discriminator': self.discriminator,
            'noise_dim': self.noise_dim,
            'compiled_config': self.compiled_config
        }

    @classmethod
    def from_config(cls, config):
        """Recreate and compile a GAN from a deserialised ``get_config`` dict."""
        # Work on copies so the caller's dictionaries are left untouched.
        config = dict(config)
        generator = model_from_json(json.dumps(config.pop('generator')))
        discriminator = model_from_json(json.dumps(config.pop('discriminator')))

        compiled_config = dict(config.pop('compiled_config'))
        compiled_config['g_optimizer'] = tf.keras.optimizers.deserialize(compiled_config['g_optimizer'])
        compiled_config['d_optimizer'] = tf.keras.optimizers.deserialize(compiled_config['d_optimizer'])
        compiled_config['g_metrics'] = [tf.keras.metrics.deserialize(c) for c in compiled_config['g_metrics']]
        compiled_config['d_metrics'] = [tf.keras.metrics.deserialize(c) for c in compiled_config['d_metrics']]
        compiled_config['loss_fn'] = tf.keras.losses.deserialize(compiled_config['loss_fn'])

        gan = cls(generator, discriminator, **config)
        gan.compile(**compiled_config)

        return gan

## Log Generated Results Callback

In [11]:
class LogGeneratedResults(Callback):
    """Callback that logs a grid of generated images as an image summary after each epoch."""

    def __init__(self, num_images, log_dir, **kwargs):
        """
        Args:
            num_images: Number of images to generate and plot per epoch.
            log_dir: Directory the image summaries are written to.
            **kwargs: Forwarded to ``Callback.__init__``.
        """
        super(LogGeneratedResults, self).__init__(**kwargs)
        self.num_images = num_images
        self.log_dir = log_dir
        # One summary writer per log directory, created lazily and reused.
        self._file_writers = {}

    def on_epoch_end(self, epoch, logs=None):
        """Generate ``num_images`` samples with the model and log them.

        Raises:
            ValueError: If the model lacks ``generate_fake_data`` or ``noise_dim``.
        """
        if not hasattr(self.model, 'generate_fake_data'):
            raise ValueError('The model should have a generate_fake_data function.')

        if not hasattr(self.model, 'noise_dim'):
            raise ValueError('The model should have a noise_dim property.')

        images, _ = self.model.generate_fake_data(self.num_images, self.model.noise_dim)
        title = f'Total Images: {self.num_images} | Noise Dim: {self.model.noise_dim}'

        self.log_images(images, title, self.log_dir, 'gray', epoch)

    def log_images(self, images, title, log_dir, cmap, epoch):
        """Plot ``images`` on a square grid and write the plot as an image summary.

        Args:
            images: Rank-4 batch of images ``(count, height, width, channels)``.
            title: Summary name.
            log_dir: Directory the summary is written to.
            cmap: Matplotlib colour map used for single-channel images.
            epoch: Step recorded with the summary.
        """
        assert len(images.shape) == 4

        # Work on a NumPy array so indexing and plotting do not depend on the tensor type.
        images = np.asarray(images)
        num_images = images.shape[0]
        # Side length of the smallest square grid that fits all images (at least 1).
        grid_size = max(1, int(np.ceil(np.sqrt(num_images))))

        figure = plt.figure(figsize=(grid_size, grid_size))

        for i in range(num_images):
            axes = figure.add_subplot(grid_size, grid_size, i + 1)
            axes.axis('off')
            image = images[i]
            # Drop a trailing single channel: older matplotlib versions reject
            # (H, W, 1) arrays in imshow.
            if image.shape[-1] == 1:
                image = image[..., 0]
            axes.imshow(image, cmap=cmap)

        # Save the plot to a PNG in memory.
        buf = io.BytesIO()
        figure.savefig(buf, format='png')
        # Closing the figure prevents it from being displayed directly inside
        # the notebook.
        plt.close(figure)
        # Convert PNG buffer to TF image
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        # Add the batch dimension
        image = tf.expand_dims(image, 0)

        file_writer = self._get_file_writer(log_dir)

        with file_writer.as_default():
            tf.summary.image(title, image, step=epoch)
        # Make the summary visible on disk right away.
        file_writer.flush()

    def _get_file_writer(self, log_dir):
        """Return the cached summary writer for ``log_dir``, creating it on first use."""
        if log_dir not in self._file_writers:
            self._file_writers[log_dir] = tf.summary.create_file_writer(log_dir)
        return self._file_writers[log_dir]

## SaveModel

In [12]:
class SaveModel(Callback):
    """Callback that saves the model's weights at the end of every epoch."""

    def __init__(self, log_dir, **kwargs):
        """
        Args:
            log_dir: Directory passed to the model's ``save_weights``.
            **kwargs: Forwarded to ``Callback.__init__``.
        """
        super().__init__(**kwargs)
        self.log_dir = log_dir

    def on_epoch_end(self, epoch, logs):
        # Keras passes zero-based epoch indices; checkpoints are named from 1.
        checkpoint_prefix = f'epoch_{epoch+1}'
        self.model.save_weights(log_dir=self.log_dir, prefix=checkpoint_prefix)

## Minibatch Discrimination Layer

In [None]:
class MinibatchDiscrimination(Layer):
    """Appends per-feature similarity statistics across the batch to the input.

    For each sample and feature, the layer sums ``exp(-L1 distance)`` to every
    sample in the batch and concatenates the result to the input features.
    The layer has no trainable weights.

    NOTE(review): the reshape in ``call`` only keeps the batch axis intact when
    ``kernel_dims == 1`` (assuming a rank-2 input); other values look
    unsupported — confirm before using them.
    """

    def __init__(self, kernel_dims, **kwargs):
        """
        Args:
            kernel_dims: Size of the last axis the input is reshaped to.
            **kwargs: Forwarded to ``Layer.__init__``.
        """
        super(MinibatchDiscrimination, self).__init__(**kwargs)
        self.kernel_dims = kernel_dims

    def build(self, input_shape):
        # Remember the number of input features for the reshape in call().
        self.in_features = input_shape[-1]
        super(MinibatchDiscrimination, self).build(input_shape)

    def call(self, X):
        features = tf.reshape(X, (-1, self.in_features, self.kernel_dims)) # NxBxC

        Mi = tf.expand_dims(features, axis=0) # 1xNxBxC

        Mj = tf.expand_dims(features, axis=1) # Nx1xBxC

        # Broadcasting yields every pairwise difference between samples.
        abs_diff = tf.abs(Mi - Mj) # NxNxBxC

        norm = tf.reduce_sum(abs_diff, axis=3) # NxNxB

        outputs = tf.reduce_sum(tf.exp(-norm), axis=0) # NxB

        # Plain tf.concat instead of instantiating a new Concatenate layer on
        # every call.
        return tf.concat([X, outputs], axis=1) # Nx(B+X.shape[-1])

    def get_config(self):
        """Return the layer config including ``kernel_dims``."""
        config = super(MinibatchDiscrimination, self).get_config()
        config.update({ 'kernel_dims': self.kernel_dims })
        return config

## Other