# Training

### Imports

In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import preprocessing

from settings import *
import utils
from callbacks import *
from gan import GAN

# Reset random state through the project helper, right after the imports and
# before any data or model code runs.
# NOTE(review): presumably seeds the NumPy/TensorFlow RNGs for reproducible
# runs — confirm against utils.reset_rand.
utils.reset_rand()

### Check GPU

In [None]:
# Enumerate the physical GPUs TensorFlow can see and, if there is at least one,
# restrict this process to the first of them. Otherwise report the CPU fallback.
gpus = tf.config.experimental.list_physical_devices("GPU")

if not gpus:
	print("Using CPU :(")

else:

	try:
		# Only the first detected GPU is made visible to TensorFlow.
		tf.config.experimental.set_visible_devices(gpus[0], "GPU")

	except RuntimeError as err:
		# Best effort: report the problem and carry on with the default devices.
		# NOTE(review): TensorFlow documents RuntimeError here when the runtime
		# has already been initialized — confirm for the installed version.
		print(err)

	else:
		print("Using GPU :)")

### Dataset

In [None]:
# Build the training pipeline: unlabeled RGB images read from DATA_DIR, resized
# to IMAGE_SIZE x IMAGE_SIZE, grouped in batches of BATCH_SIZE, then passed
# through the project's normalisation helper.
dataset = preprocessing.image_dataset_from_directory(
	DATA_DIR,
	shuffle = True,
	batch_size = BATCH_SIZE,
	image_size = (IMAGE_SIZE, IMAGE_SIZE),
	color_mode = "rgb",
	label_mode = None
).map(utils.tf_norm_img)

if FLIP_DATASET:
	# Augment with a horizontally mirrored copy of every element, appended
	# after the originals, then shuffle with a buffer of BATCH_SIZE elements.
	# NOTE(review): the dataset is presumably already batched at this point, so
	# the shuffle buffer would hold batches rather than single images — confirm.
	flipped_dataset = dataset.map(tf.image.flip_left_right)
	dataset = dataset.concatenate(flipped_dataset).shuffle(BATCH_SIZE)

# Reported size = number of entries directly inside DATA_DIR, doubled when the
# mirrored copy is enabled.
# NOTE(review): os.listdir counts every entry (sub-directories and non-image
# files included), so this may differ from the number of images actually loaded.
NB_DATA = len(os.listdir(DATA_DIR)) * (2 if FLIP_DATASET else 1)
print("Dataset final size:", NB_DATA)

### Model

In [None]:
# Instantiate the project's GAN class with its default constructor arguments.
gan = GAN()
# compile() is called with no arguments.
# NOTE(review): presumably GAN.compile sets up its own optimizers and losses —
# confirm in gan.py.
gan.compile()
# Display the model summary.
# NOTE(review): exact output depends on GAN.summary, which is not visible here.
gan.summary()

### First run / Continue

In [None]:
# Resume from a previous run when saved weights exist, otherwise start fresh.
# NOTE(review): gan.load_weights is used as a truthy "a save was found" flag —
# confirm its return contract in gan.py.
save_found = gan.load_weights(MODELS_DIR)

# The fixed sample inputs live next to the other outputs so that a resumed run
# reuses exactly the same inputs as the run that created them.
samples_z_path = os.path.join(OUTPUT_DIR, "samples_z.npy")
samples_noise_path = os.path.join(OUTPUT_DIR, "samples_noise.npy")

if save_found:
	samples_z = np.load(samples_z_path)
	samples_noise = np.load(samples_noise_path)

else:
	# One sample per cell of the OUTPUT_SHAPE[0] x OUTPUT_SHAPE[1] grid.
	nb_samples = OUTPUT_SHAPE[0] * OUTPUT_SHAPE[1]

	# Standard-normal draws; samples_z is drawn first, as before, so results
	# under a fixed seed are unchanged.
	samples_z = np.random.normal(0., 1., (nb_samples, LATENT_DIM))
	samples_noise = np.random.normal(0., 1., ((NB_BLOCKS * 2) - 1, nb_samples, IMAGE_SIZE, IMAGE_SIZE, 1))

	# exist_ok avoids the check-then-create race of a separate os.path.exists
	# test and is a no-op when the directory is already there.
	os.makedirs(OUTPUT_DIR, exist_ok = True)

	np.save(samples_z_path, samples_z)
	np.save(samples_noise_path, samples_noise)

### Training

In [None]:
# Callbacks are instantiated in the same order as they are handed to fit().
# SaveSamples receives the fixed sample inputs prepared earlier.
training_callbacks = [
	Updates(),
	SaveSamples(samples_z, samples_noise),
	SaveModels()
]

# Run training for NB_EPOCHS epochs and keep the value returned by fit().
# NOTE(review): the Keras documentation says batch_size should not be given
# when the input is a dataset (it already yields batches) and that shuffle is
# ignored for such inputs — confirm how GAN.fit treats these arguments before
# removing them.
history = gan.fit(
	dataset,
	epochs = NB_EPOCHS,
	batch_size = BATCH_SIZE,
	shuffle = True,
	callbacks = training_callbacks
)