In [23]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="7"
os.environ["NCCL_DEBUG"]="WARN"
#os.environ["NCCL_P2P_LEVEL"]="NODE"

In [24]:
import sys
sys.path.append("../../../")

In [25]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
import omegaconf
import numpy as np
import matplotlib.pyplot as plt
from ganime.data.experimental import ImageDataset, VideoDataset
from ganime.model.vqgan_clean.vqgan import VQGAN
from ganime.visualization.videos import display_videos
from ganime.visualization.images import display_images
import tensorflow as tf
from datetime import datetime
from tqdm import tqdm
import os
from pyprojroot.pyprojroot import here
#tf.get_logger().setLevel('ERROR')

In [27]:
for device in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(device, True)

In [28]:
strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [29]:
cfg = omegaconf.OmegaConf.load(here("configs/moving_mnist_image.yaml"))

In [30]:
num_workers = len(tf.config.list_physical_devices("GPU"))
batch_size = cfg["trainer"]["batch_size"] 
global_batch_size = batch_size * strategy.num_replicas_in_sync
n_epochs = cfg["trainer"]["n_epochs"] 
sample_batch_size = 8

In [31]:
#dataset_length = 34045 # KNY
#dataset_length = 310153 #KNY full
dataset_length = 20*10000 # MNIST
num_batch = dataset_length / batch_size

In [32]:
# Normalizing the images to [-1, 1]
def normalize(image):
    image = (tf.cast(image, tf.float32) / 127.5) - 1

    return image

In [33]:
def random_crop(image, img_height=64, img_width=128):
    cropped_image = tf.image.random_crop(image, size=[img_height, img_width, 3])

    return cropped_image

In [34]:
@tf.function()
def random_jitter(image):
    # Resizing to 72x142
    image = tf.image.resize(image, [72, 142], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Random cropping back to 64x128
    image = random_crop(image)

    if tf.random.uniform(()) > 0.5:
        # Random mirroring
        image = tf.image.flip_left_right(image)

    return image

In [35]:
def preprocess(image):
    image = tf.reshape(image, (tf.shape(image)[0], tf.shape(image)[1], 3))
    #image = random_jitter(image)
    image = normalize(image)
    return image, image

In [36]:
dataset = ImageDataset("../../../data/mnist_tfrecords").load()
dataset = dataset.shuffle(dataset_length, reshuffle_each_iteration=True, seed=10).map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

In [37]:
train_size = int(dataset_length * 0.8)
validation_size = int(dataset_length * 0.1)
test_size = int(dataset_length * 0.1)

In [38]:
train_ds = dataset.take(train_size)
validation_ds = dataset.skip(train_size).take(validation_size)#.padded_batch(global_batch_size).map(postprocess)
test_ds = dataset.skip(train_size + validation_size).take(validation_size)#.padded_batch(global_batch_size).map(postprocess)

In [39]:
train_sample_data = next(train_ds
                          .batch(sample_batch_size)
                          .prefetch(tf.data.AUTOTUNE).as_numpy_iterator())
validation_sample_data = next(validation_ds.batch(sample_batch_size).as_numpy_iterator())

In [40]:
train_ds = (train_ds.batch(global_batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))
validation_ds = (validation_ds.batch(global_batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))
test_ds = (test_ds.batch(global_batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))

In [41]:
from ganime.visualization.images import display_true_pred

In [42]:
with strategy.scope():
    vqgan = VQGAN(**cfg["model"], num_replicas=strategy.num_replicas_in_sync)

Working with z of shape (1, 512, 16, 16) = 131072 dimensions.
VQLPIPSWithDiscriminator running with hinge loss.


In [43]:
with strategy.scope():
    gen_optimizer = tf.keras.optimizers.Adam(
                learning_rate=cfg["trainer"]["gen_lr"],
                beta_1=cfg["trainer"]["gen_beta_1"],
                beta_2=cfg["trainer"]["gen_beta_2"],
                clipnorm=cfg["trainer"]["gen_clip_norm"],
    )
    disc_optimizer = tf.keras.optimizers.Adam(
        learning_rate=cfg["trainer"]["disc_lr"],
        beta_1=cfg["trainer"]["disc_beta_1"],
        beta_2=cfg["trainer"]["disc_beta_2"],
        clipnorm=cfg["trainer"]["disc_clip_norm"],
    )
    vqgan.compile(gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)

In [46]:
from ganime.metrics.image import calculate_images_metrics

In [47]:
dataset = ImageDataset("../../../data/mnist_tfrecords").load()
dataset = dataset.shuffle(dataset_length, reshuffle_each_iteration=True).map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

batch_size = 128

train_ds = dataset.take(train_size)
validation_ds = dataset.skip(train_size).take(validation_size)#.padded_batch(global_batch_size).map(postprocess)
test_ds = dataset.skip(train_size + validation_size).take(validation_size)#.padded_batch(global_batch_size).map(postprocess)

train_ds = (train_ds.batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))
validation_ds = (validation_ds.batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))
test_ds = (test_ds.batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))

In [48]:
metrics_test = [calculate_images_metrics(test_ds, vqgan, validation_size // batch_size) for i in range(5)]
metrics_test

  0%|          | 0/156 [00:00<?, ?it/s]

  0%|          | 0/156 [00:00<?, ?it/s]

  0%|          | 0/156 [00:00<?, ?it/s]

  0%|          | 0/156 [00:00<?, ?it/s]

  0%|          | 0/156 [00:00<?, ?it/s]

[{'fid': 18.115365820305104, 'ssim': 0.83215195, 'psnr': 23.687765},
 {'fid': 18.127795343861955, 'ssim': 0.8321791, 'psnr': 23.685871},
 {'fid': 17.983672100023036, 'ssim': 0.83247787, 'psnr': 23.699646},
 {'fid': 18.04111774504951, 'ssim': 0.8317932, 'psnr': 23.68214},
 {'fid': 18.159041522908865, 'ssim': 0.83215183, 'psnr': 23.69139}]

In [49]:
metrics_train = [calculate_images_metrics(train_ds, vqgan, train_size // batch_size) for i in range(5)]
metrics_train

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

[{'fid': 17.798508110873158, 'ssim': 0.8321389, 'psnr': 23.689665},
 {'fid': 17.82400717736511, 'ssim': 0.83215195, 'psnr': 23.690237},
 {'fid': 17.834273862523, 'ssim': 0.83214396, 'psnr': 23.689932},
 {'fid': 17.853948306492335, 'ssim': 0.83206, 'psnr': 23.688257},
 {'fid': 17.84624170945504, 'ssim': 0.832055, 'psnr': 23.68829}]