In [None]:
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt
import csv
import os
import sys
import time
import logging
import re
from commons import *
from gan_arch import *
from datetime import datetime

In [None]:
# Image geometry consumed by the discriminator and produced by the generator.
IMG_W, IMG_H, IMG_C = 128, 128, 3


# Dataset sharding / batching.
SHARDS = 1
BATCH_SIZE = 128
SUPER_BATCH = 8  # BATCH_SIZE-sized batches grouped together for vectorized normalization

latent_dim = 256
last_epoch = 0  # replaced by the checkpoint-detection cell

SAMPLES_PATH = "samples_metu"
CHECKPOINT_PATH = "saved_model_metu"
LOG_DIR = f"gan_logs/{datetime.now():%Y%m%d-%H%M%S}"
# Checkpoint path templates; fill in with .format(epoch=N) -> ".../gencp-05.ckpt".
GEN_CPT_NAME = CHECKPOINT_PATH + "/gencp-{epoch:02d}.ckpt"
DIS_CPT_NAME = CHECKPOINT_PATH + "/discp-{epoch:02d}.ckpt"

In [None]:
# Location of the original (un-resized) METU logo dataset.
METU_RAW_PATH = '/qarr/studia/magister/datasets/METU/930k_logo_v3/'
# Directory of pre-resized JPEGs read by the input pipeline. Override per machine
# via the METU_DATASET_PATH environment variable instead of stacking assignments
# (other known location: '/qarr/studia/magister/datasets/METU/resized_930k_logo').
METU_DATASET_PATH = os.environ.get('METU_DATASET_PATH', '/home/zenfur/magister/resized_930k_logo')

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No compatible GPUs found")
else:
    try:
        # Memory growth must be the same on every GPU and must be configured
        # before any GPU is initialized; otherwise a RuntimeError is raised.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

## Define METU dataset loading pipeline

In [None]:
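# Alternative (disabled): load the whole dataset statically into RAM instead of
# streaming it with the tf.data pipeline below.
# metuDataset = load_dataset(METU_DATASET_PATH, nexamples=1024*14)
# metuDataset = metuDataset.shuffle(len(metuDataset)).cache().batch(BATCH_SIZE)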

In [None]:
# String tensor of every *.jpg path found in the dataset directory.
imagesList = tf.io.matching_files(f"{METU_DATASET_PATH}/*.jpg")

In [None]:
@tf.function
def tf_read_image(path):
    """Read the JPEG file at `path` and decode it into a uint8 RGB image tensor.

    Decoding always produces 3 channels and uses the slower but more accurate
    integer DCT. No resizing is done here.
    NOTE(review): callers batch these images, so all files presumably already
    share one size (IMG_H x IMG_W) — confirm for the dataset in use.
    """
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_jpeg(raw_bytes,
                                channels=3,
                                dct_method='INTEGER_ACCURATE')


@tf.function
def tf_convert_and_normalize_img(img):
    """Cast `img` to float32 and rescale it as (img - 128) / 128.

    For uint8 pixels this maps 0 -> -1.0 and 255 -> 0.9921875 (127/128).
    NOTE(review): the range is slightly asymmetric; (x - 127.5) / 127.5 would
    hit [-1, 1] exactly. Left as-is to stay consistent with any inverse
    transform used elsewhere (e.g. in `commons`) — confirm before changing.
    """
    offset = tf.constant(128.0, dtype=tf.float32)
    as_float = tf.cast(img, tf.float32)
    return (as_float - offset) / offset



In [None]:
# Streaming input pipeline: image paths -> shuffled shard -> decoded images ->
# normalized float32 batches of BATCH_SIZE.
pathsDB = tf.data.Dataset.from_tensor_slices(imagesList)
lenMetu = len(pathsDB)
SHARD_IDX = 0
metuDB = (pathsDB.shard(SHARDS, SHARD_IDX)
                # shard(n, i) keeps elements i, i+n, i+2n, ...; the buffer below is
                # exactly this shard's size, i.e. a full shuffle of the path strings.
                .shuffle(lenMetu//SHARDS + (lenMetu % SHARDS > SHARD_IDX))
                #.repeat()
                .map(tf_read_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                # Normalize SUPER_BATCH batches at once to vectorize the cast/scale,
                # then split back into training-sized batches.
                # NOTE(review): batching decoded images requires identical shapes;
                # assumes the dataset was pre-resized to IMG_H x IMG_W — confirm.
                .batch(BATCH_SIZE*SUPER_BATCH)
                .map(tf_convert_and_normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                .prefetch(tf.data.experimental.AUTOTUNE)
                .unbatch()
                .batch(BATCH_SIZE)
                # Prefetch the final training batches too, so the unbatch/re-batch
                # work overlaps with the training step instead of stalling it.
                .prefetch(tf.data.experimental.AUTOTUNE)
         )

train_dataset = metuDB

In [None]:
# Sanity check: decode one file from the dataset and display it.
# Index the string tensor directly rather than calling imagesList.numpy(), which
# would copy the entire (potentially very large) path list to host memory.
plt.imshow(tf_read_image(imagesList[10]).numpy());

## Preparing the model

In [None]:
# Resume support: find the newest epoch for which BOTH generator and
# discriminator weights were saved. File names follow GEN_CPT_NAME / DIS_CPT_NAME
# (e.g. "gencp-05.ckpt.index"). Requiring both prevents resuming from an epoch
# whose save was interrupted after only one of the two models was written.
def _find_last_complete_epoch(checkpoint_dir):
    """Return the highest epoch having both gencp and discp checkpoints, or 0 if none.

    A missing `checkpoint_dir` is treated as "no checkpoints".
    """
    try:
        file_names = os.listdir(checkpoint_dir)
    except FileNotFoundError:
        return 0
    saved = {"gencp": set(), "discp": set()}
    for file_name in file_names:
        match = re.match(r"(gencp|discp)-([0-9]+)\.ckpt", file_name)
        if match:
            saved[match.group(1)].add(int(match.group(2)))
    return max(saved["gencp"] & saved["discp"], default=0)


last_epoch = _find_last_complete_epoch(CHECKPOINT_PATH)
if last_epoch:
    print("Detected {} epoch as last checkpoint".format(last_epoch))
else:
    print("Did not detect any checkpoints to continue from")

In [None]:
# Initialize generator and discriminator models. Both builders receive the same
# N(mean=0, stddev=0.02) weight initializer and the configured image geometry.
weight_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
d_model = build_discriminator_ref(weight_init,
                                  IMG_H=IMG_H, IMG_W=IMG_W, IMG_C=IMG_C)
g_model = build_generator_ref(latent_dim, weight_init,
                              IMG_H=IMG_H, IMG_W=IMG_W, IMG_C=IMG_C)

In [None]:
# Restore both networks from the detected checkpoint epoch, if there is one.
if not last_epoch:
    print("No checkpoint to load from")
else:
    d_model.load_weights(DIS_CPT_NAME.format(epoch=last_epoch))
    g_model.load_weights(GEN_CPT_NAME.format(epoch=last_epoch))
    print(f"Loading models from epoch {last_epoch}")

In [None]:
gan = GAN_old(d_model, g_model, latent_dim)
# BCE computed on raw scores (from_logits=True) with label smoothing 0.1.
# NOTE(review): assumes the discriminator outputs logits (no final sigmoid) — confirm.
bce_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                 label_smoothing=0.1)
# Two separate, identically configured Adam optimizers (lr=2e-4, beta_1=0.5).
d_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
# Positional order presumably (d_optimizer, g_optimizer, loss_fn) per GAN_old.compile — confirm.
gan.compile(d_optimizer, g_optimizer, bce_loss_fn)

In [None]:
# TensorBoard logging under LOG_DIR, with weight histograms computed every epoch.
tboard_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, histogram_freq=1)

In [None]:
# Number of epochs to run in this session, counted onward from `last_epoch`.
train_for: int = 2

In [None]:
# Fixed latent vectors: the first 5 are re-injected into every epoch's sample
# batch so the generator's output on identical inputs can be compared over time.
n_samples = 25
noise = np.random.normal(size=(n_samples, latent_dim))

# Hide TensorFlow log messages below ERROR while training; restored in `finally`.
logging.getLogger('tensorflow').setLevel(logging.ERROR)
try:
    for epoch in range(last_epoch, last_epoch+train_for):
        # One fit() call per epoch so weights can be checkpointed between epochs.
        gan.fit(train_dataset,
                initial_epoch=epoch,
                epochs=epoch+1,
                steps_per_epoch=lenMetu//SHARDS//BATCH_SIZE,
                callbacks=[tboard_callback]
               )
        # Save with the same templates the checkpoint detection/loading cells use.
        g_model.save_weights(GEN_CPT_NAME.format(epoch=epoch+1))
        d_model.save_weights(DIS_CPT_NAME.format(epoch=epoch+1))
        # Both checkpoints are written, so this epoch is complete; update the
        # counter before sampling so an interrupt in predict() doesn't lose it.
        last_epoch = epoch + 1

        new_noise = np.random.normal(size=(n_samples, latent_dim))
        new_noise[0:5] = noise[0:5]
        examples = g_model.predict(new_noise)
        #save_plot(examples, epoch, int(np.sqrt(n_samples)))
except KeyboardInterrupt:
    print("Interrupted")
finally:
    # Restore WARNING level even if training raised something other than an interrupt.
    logging.getLogger('tensorflow').setLevel(logging.WARNING)

In [None]:
# Reset all global state generated by Keras in this kernel (e.g. after training).
tf.keras.backend.clear_session()