In [1]:
#@title Import libraries { form-width: "20%" }
import tensorflow as tf
from tqdm import tqdm
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
from tensorflow import keras
import matplotlib.pyplot as plt
from math import ceil
import cv2
from tensorflow.python.ops.numpy_ops import np_config

In [None]:
#@title Install packages to generate GIFs { form-width: "20%" }
# To generate GIFs
!pip install imageio
!pip install git+https://github.com/tensorflow/docs

In [None]:
#@title TPU initialization { form-width: "20%" }
# Resolve the TPU (tpu='' lets the resolver auto-detect it in the runtime),
# connect to it and initialise it. Then build the distribution strategy that
# the rest of the notebook uses for model/optimizer creation and training.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

tpu_devices = tf.config.list_logical_devices('TPU')
print("All devices: ", tpu_devices)

strategy = tf.distribute.TPUStrategy(resolver)

In [None]:
#@title Mount Google Drive { form-width: "20%" }
# Colab only: mount Google Drive at /content/drive. The dataset archive
# (/content/drive/MyDrive/archive.zip) is read from under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@title Unzip dataset { form-width: "20%" }
!unzip /content/drive/MyDrive/archive.zip

In [6]:
#@title Functions to read, write and upload tfrecord files { form-width: "20%" }
def _int_feature(list_of_ints):
    """Wrap a sequence of ints in a `tf.train.Feature` (int64_list)."""
    values = tf.train.Int64List(value=list_of_ints)
    return tf.train.Feature(int64_list=values)

def _float_feature(list_of_floats):
    """Wrap a sequence of floats in a `tf.train.Feature` (float_list, float32)."""
    values = tf.train.FloatList(value=list_of_floats)
    return tf.train.Feature(float_list=values)

def _bytestring_feature(list_of_bytestrings):
    """Wrap a sequence of byte strings in a `tf.train.Feature` (bytes_list)."""
    values = tf.train.BytesList(value=list_of_bytestrings)
    return tf.train.Feature(bytes_list=values)
  
def write_tfrecord_file(path, image_paths):
    """Write every file in `image_paths` into a single TFRecord file at `path`.

    Each file's raw (still-encoded) bytes are stored unchanged under the
    "image" key of one `tf.train.Example`. Decoding happens later, when the
    records are read back.

    Args:
        path: output TFRecord path.
        image_paths: sequence of file paths; a tqdm progress bar tracks it.
    """
    with tf.io.TFRecordWriter(path) as writer:
        for image_path in tqdm(image_paths):
            with open(image_path, 'rb') as fp:
                encoded_bytes = fp.read()

            example = tf.train.Example(
                features=tf.train.Features(
                    feature={"image": _bytestring_feature([encoded_bytes])}
                )
            )
            writer.write(example.SerializeToString())

def upload_tfrecord_file_to_gcs(bucket, path):
    """Copy the local file at `path` to the root of GCS bucket `bucket` with `gsutil cp`.

    Uses an IPython shell magic, so it only works inside a notebook kernel.
    `path` and `bucket` are interpolated unquoted into the shell command, so
    they must not contain spaces or shell metacharacters.
    """
    !gsutil cp {path} gs://{bucket}/

def read_tfrecord(data):
    """Parse one serialized Example into a 128x128 RGB image tensor.

    The Example is expected to hold the encoded image bytes under "image", as
    written by `write_tfrecord_file`. The bytes are JPEG-decoded to 3 channels
    and resized with nearest-neighbour sampling. Per the `tf.image.resize` docs,
    nearest-neighbour keeps the input dtype, so the result is presumably
    uint8 — TODO confirm.
    """
    schema = {"image": tf.io.FixedLenFeature([], tf.string)}
    parsed = tf.io.parse_single_example(data, schema)
    decoded = tf.io.decode_jpeg(parsed['image'], channels=3)
    return tf.image.resize(decoded, [128, 128], antialias=True, method='nearest')

In [None]:
#@title Write and upload tfrecord dataset { form-width: "20%" }
data_path = '/content/train/'
image_paths = []
for root, subdirs, files in os.walk(data_path):
    for f in files:
        image_paths += [os.path.join(root, f)]

project_id = 'bird-gan'

from google.colab import auth
auth.authenticate_user()

!gcloud config set project {project_id}

bucket_name = 'bird_gan_data'
output_file_path = '/content/image_dataset.tfrecords'
num_files = 1
num_shards = 1
dataset_size = int(len(image_paths)/num_shards)
print(dataset_size)

write_tfrecord_file(output_file_path, image_paths)
upload_tfrecord_file_to_gcs(bucket_name, output_file_path)

In [8]:
#@title Create dataset { form-width: "20%" }
def get_dataset(per_replica_batch_size, normalize):
    """Build an infinite, shuffled, batched image pipeline from the GCS TFRecord.

    Args:
        per_replica_batch_size: number of images in each batch.
        normalize: if True, rescale pixels from [0, 255] to [-1, 1] as float32
            (the range of the generator's tanh output). If False, images keep
            the dtype produced by `read_tfrecord`.

    Returns:
        A `tf.data.Dataset` of batches shaped
        (per_replica_batch_size, 128, 128, 3) with a static batch dimension.

    Note:
        Relies on the global `dataset_size` (shuffle buffer), which is set in
        the cell that writes the TFRecord.
    """
    gs_paths = ['gs://bird_gan_data/image_dataset.tfrecords']
    AUTOTUNE = tf.data.AUTOTUNE

    options = tf.data.Options()
    options.experimental_deterministic = False  # order not needed; lets parallel maps run faster

    dataset = tf.data.TFRecordDataset(gs_paths)
    dataset = dataset.with_options(options)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    dataset = dataset.cache()
    dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=True)
    dataset = dataset.repeat()
    # drop_remainder=True gives batches a static batch dimension, which XLA/TPU
    # compilation needs. After an infinite repeat() no examples are dropped.
    dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)

    if normalize:
        # x / 127.5 - 1 is (x / 255) * 2 - 1, done as one vectorised, parallel map.
        dataset = dataset.map(
            lambda x: tf.cast(x, tf.float32) / 127.5 - 1.0,
            num_parallel_calls=AUTOTUNE,
        )

    # Prefetch last, so complete, ready-to-use batches overlap with training.
    return dataset.prefetch(AUTOTUNE)

In [None]:
#@title Display dataset { form-width: "20%" }
# Un-normalised (raw pixel) batch of 16, of which the first 9 are shown in a 3x3 grid.
dataset_demo = get_dataset(16, False)

plt.figure(figsize=(10, 10))
first_batch = next(iter(dataset_demo.take(1)))
for idx in range(9):
    plt.subplot(3, 3, idx + 1)
    plt.imshow(first_batch[idx].numpy())

In [10]:
#@title Define generator model { form-width: "20%" }
def make_generator_model():
    """Build the DCGAN generator: a 256-d noise vector -> a 128x128x3 image.

    The noise is projected to a 4x4x256 map. Five stages of
    UpSampling2D -> 3x3 Conv2D -> BatchNormalization -> ReLU double the spatial
    size each time (4 -> 8 -> 16 -> 32 -> 64 -> 128). A final 3-channel conv
    with tanh puts pixel values in [-1, 1].

    Returns:
        A `tf.keras.Sequential` with output shape (None, 128, 128, 3).
    """
    model = tf.keras.Sequential()

    # Project the noise vector and reshape it into a 4x4 feature map.
    model.add(layers.Dense(4 * 4 * 256, input_shape=(256,)))
    model.add(layers.Activation('relu'))
    model.add(layers.Reshape((4, 4, 256)))
    assert model.output_shape == (None, 4, 4, 256)

    # Output channels of each upsampling stage (spatial: 8, 16, 32, 64, 128).
    for filters in (256, 256, 128, 128, 128):
        model.add(layers.UpSampling2D())
        model.add(layers.Conv2D(filters, kernel_size=3, padding='same'))
        model.add(layers.BatchNormalization(momentum=0.8))
        model.add(layers.Activation('relu'))

    model.add(layers.Conv2D(3, kernel_size=3, padding='same'))
    assert model.output_shape == (None, 128, 128, 3)
    model.add(layers.Activation('tanh'))

    return model

In [12]:
#@title Define discriminator model { form-width: "20%" }
def make_discriminator_model():
    """Build the DCGAN discriminator: a 128x128x3 image -> one real/fake logit.

    Five conv stages with LeakyReLU(0.2). Every stage after the first is
    preceded by Dropout(0.25) and Gaussian noise (stddev 0.1), and uses
    BatchNormalization. The final Dense(1) has no activation, so it outputs a
    raw logit (the losses use `from_logits=True`).

    Returns:
        A built `tf.keras.Sequential` with output shape (None, 1).
    """
    model = tf.keras.Sequential()

    # `input_shape` is only honoured on the FIRST layer of a Sequential model.
    # On any later layer Keras ignores it and the model stays unbuilt until its
    # first call. Declaring it here builds the model (and its variables) now.
    model.add(layers.GaussianNoise(0.1, input_shape=[128, 128, 3]))
    model.add(layers.Conv2D(32, kernel_size=3, strides=2, padding='same'))
    model.add(layers.LeakyReLU(alpha=0.2))
    # -> (None, 64, 64, 32)

    model.add(layers.Dropout(0.25))
    model.add(layers.GaussianNoise(0.1))
    model.add(layers.Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(layers.ZeroPadding2D(padding=((0, 1), (0, 1))))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.LeakyReLU(alpha=0.2))
    # -> (None, 33, 33, 64)

    model.add(layers.Dropout(0.25))
    model.add(layers.GaussianNoise(0.1))
    model.add(layers.Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.LeakyReLU(alpha=0.2))
    # -> (None, 17, 17, 128)

    model.add(layers.Dropout(0.25))
    model.add(layers.GaussianNoise(0.1))
    model.add(layers.Conv2D(256, kernel_size=3, strides=1, padding='same'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Dropout(0.25))
    model.add(layers.GaussianNoise(0.1))
    model.add(layers.Conv2D(512, kernel_size=3, strides=1, padding='same'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Dropout(0.25))
    model.add(layers.GaussianNoise(0.1))
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    assert model.output_shape == (None, 1)

    return model

In [None]:
#@title Loss functions { form-width: "20%" }
with strategy.scope():
    # Reduction.NONE: both losses return one value per example instead of a
    # batch mean, so the caller decides how to scale/average them.
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def discriminator_loss(real_output, fake_output):
        """Per-example discriminator loss with noisy soft labels.

        Real targets are drawn from U(0.9, 1) and fake targets from U(0, 0.1),
        a label-smoothing regulariser for the discriminator.

        Args:
            real_output: discriminator logits for real images.
            fake_output: discriminator logits for generated images.

        Returns:
            Per-example loss (real + fake).
        """
        # Take the label shape from the logits. A hard-coded (32, 1) only
        # matched a per-replica batch of exactly 32.
        real_labels = tf.random.uniform(tf.shape(real_output), minval=0.9, maxval=1)
        fake_labels = tf.random.uniform(tf.shape(fake_output), minval=0, maxval=0.1)
        real_loss = cross_entropy(real_labels, real_output)
        fake_loss = cross_entropy(fake_labels, fake_output)
        return real_loss + fake_loss

    def generator_loss(fake_output):
        """Per-example generator loss: push the discriminator to call fakes real (label 1)."""
        return cross_entropy(tf.ones_like(fake_output), fake_output)

In [None]:
#@title Optimisers { form-width: "20%" }
# Separate Adam instances so the generator and discriminator keep independent
# moment estimates. They are created under the strategy scope for the TPU.
G_lr = 0.0002
D_lr = 0.0002
with strategy.scope():
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate=G_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=D_lr)

In [None]:
#@title Create model objects { form-width: "20%" }
# Build both networks inside the TPU strategy scope so their variables are
# created as distributed variables.
with strategy.scope():
    generator, discriminator = make_generator_model(), make_discriminator_model()

In [None]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
with strategy.scope():
    # Track both networks and both optimizers (including their slot variables)
    # so training state can be saved and restored together.
    tracked_objects = dict(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator,
    )
    checkpoint = tf.train.Checkpoint(**tracked_objects)

In [None]:
# Training hyperparameters.
EPOCHS = 1000
batch_size = 256   # global batch; the training-step cell splits it across TPU replicas
noise_dim = 256    # must match the generator's Dense input_shape=(256,)
num_examples_to_generate = 16
# Fixed noise reused after every epoch so the sample grids are comparable over time.
seed = tf.random.normal(shape=(num_examples_to_generate, noise_dim))

In [None]:
#@title Define training step { form-width: "20%" }
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync # Each worker in the TPU will train on this batch size
print(per_replica_batch_size)
# Examples actually processed per step across all replicas.
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

train_dataset = strategy.distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, True))

@tf.function
def train_step(iterator, steps):
    """Run `steps` distributed GAN updates, drawing one batch per step from `iterator`.

    Args:
        iterator: iterator over the distributed training dataset.
        steps: scalar tensor giving the number of updates in this call. Passing
            a tensor (not a Python int) lets tf.function reuse one trace.
    """
    def step_fn(images):
        # Runs on every TPU replica with that replica's slice of the batch.
        noise = tf.random.normal([per_replica_batch_size, noise_dim])

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator(noise, training=True)

            real_output = discriminator(images, training=True)
            fake_output = discriminator(generated_images, training=True)

            # The loss functions return per-example losses (Reduction.NONE).
            # Taking the gradient of that vector differentiates its SUM. Replica
            # gradients are then summed again, so updates scale with the global
            # batch. Dividing by the global batch size gives the gradient of the
            # mean over the whole global batch.
            gen_loss = tf.nn.compute_average_loss(
                generator_loss(fake_output), global_batch_size=global_batch_size)
            disc_loss = tf.nn.compute_average_loss(
                discriminator_loss(real_output, fake_output), global_batch_size=global_batch_size)

        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # Loop inside the tf.function so a whole epoch's worth of steps runs per call.
    for _ in tf.range(steps):
        strategy.run(step_fn, args=(next(iterator),))

In [None]:
#@title Define training loop { form-width: "20%" }
def train_TPU(train_dataset, EPOCHS):
    """Train for `EPOCHS` epochs, saving and showing a sample grid after each one.

    Each epoch runs `dataset_size // batch_size` steps through `train_step`.
    Sample grids come from `generate_and_save_images` with the fixed `seed`
    noise. A final grid is produced once more after the last epoch.
    Periodic checkpointing is not performed here.

    Args:
        train_dataset: distributed dataset fed to `train_step`.
        EPOCHS: number of epochs to run.
    """
    steps_per_epoch = (dataset_size // batch_size)
    print(steps_per_epoch)
    train_iterator = iter(train_dataset)
    # Pass the step count as a tensor so tf.function does not retrace per epoch.
    steps_tensor = tf.convert_to_tensor(steps_per_epoch)

    for epoch in range(EPOCHS):
        epoch_start = time.time()
        print(f'Epoch: {epoch+1}/{EPOCHS}')

        train_step(train_iterator, steps_tensor)

        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

        elapsed = time.time() - epoch_start
        print('Time for epoch {} is {} sec'.format(epoch + 1, elapsed))

    display.clear_output(wait=True)
    generate_and_save_images(generator, EPOCHS, seed)

In [None]:
#@title Define function to display generator outputs { form-width: "20%" }
def generate_and_save_images(model, epoch, test_input):
    """Render generator samples as a square grid, save it and display it.

    Args:
        model: the generator, called with `training=False`.
        epoch: epoch number used in the filename 'image_at_epoch_NNNN.png'.
        test_input: batch of noise vectors; one image is drawn per row.
    """
    with strategy.scope():
        predictions = model(test_input, training=False)

    # Convert explicitly to NumPy. Calling np_config.enable_numpy_behavior()
    # just to get `.astype` would globally change TF semantics (dtype
    # promotion, indexing) for the rest of the session.
    # The tanh output in [-1, 1] is mapped to [0, 255].
    images = (np.asarray(predictions) * 127.5 + 127.5).astype(np.uint8)

    num_images = images.shape[0]
    grid_size = max(1, ceil(num_images ** 0.5))  # 16 samples -> 4x4 grid

    fig = plt.figure(figsize=(12, 12))
    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i + 1)
        plt.imshow(images[i])
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Release the figure; this is called once per epoch for many epochs.
    plt.close(fig)

In [None]:
# Launch the full training run; a sample grid PNG is saved after every epoch.
train_TPU(train_dataset, EPOCHS)

In [None]:
#@title Display a single image using the epoch number { form-width: "20%" }
def display_image(epoch_no):
    """Open the sample grid saved for epoch `epoch_no` (image_at_epoch_NNNN.png).

    Returns:
        A PIL image, which a notebook displays when it is the cell's last expression.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule. Import it
    # explicitly instead of relying on another library having loaded it.
    from PIL import Image
    return Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))

In [None]:
# Show the saved sample grid of one specific epoch. The matching
# image_at_epoch_0697.png must exist, i.e. training reached epoch 697.
display_image(697)

In [None]:
#@title Create video of training progress { form-width: "20%" }
anim_file = 'bird_dcgan.mp4'

# Only the per-epoch sample grids. The zero-padded epoch numbers make the
# lexicographic sort chronological.
filenames = sorted(glob.glob('image_at_epoch_*.png'))
if not filenames:
    # Fail before creating an empty video file.
    raise FileNotFoundError('No image_at_epoch_*.png frames found; run training first.')

with imageio.get_writer(anim_file, mode='I') as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)
    # Append the last frame a second time so the final result is held one extra frame.
    writer.append_data(image)

In [None]:
#@title Save model { form-width: "20%" }
# Save under MyDrive so the checkpoint persists in Google Drive. A bare
# '/content/drive/' prefix yields files named '-1.index', '-1.data-...' etc.
# directly in the mount root, outside MyDrive.
drive_checkpoint_prefix = os.path.join('/content/drive/MyDrive', 'training_checkpoints', 'ckpt')
os.makedirs(os.path.dirname(drive_checkpoint_prefix), exist_ok=True)
checkpoint.save(file_prefix=drive_checkpoint_prefix)