In [None]:
# -*- coding: utf-8 -*-

"""
 1) IMPORT CURRENT EXPERIMENT
"""
import time
import numpy as np
from statistics import mean
from experiment import exp
dManager = exp.config.dManager  #-- data manager; provides generateBatch() / generateValBatch()
model = exp.config.model        #-- model object exposing the Trainer class instantiated below

"""
 2) INITIALIZE THE TRAINING CLASS
"""
gpu_memory_fraction = exp.config.gpu_memory_fraction
trainer = model.Trainer(
	model=model,  #-- the model taken from the experiment config
	name=exp.config.name,
	#-- Saved model state to recover from.
	path2restore=exp.config.path_to_restore,
	#-- Loss-log destinations for training and validation.
	file2record_loss='./training_log/45k_trainloss.txt',
	testfile2record_loss='./training_log/45k_valloss.txt',
	#-- Run-mode flags.
	is_training=True,
	monitoring=True,
	gpu_memory_fraction=gpu_memory_fraction,
)

"""
 3) TRAINING ITERATIONS
"""
#-- Validation-pass geometry and checkpoint cadence.
#-- NOTE(review): 48800 (validation images) and 104 (validation batch size) are
#-- assumed to match dManager's validation split and batch size -- confirm.
VAL_SET_SIZE = 48800
VAL_BATCH_SIZE = 104
NUM_VAL_BATCHES = VAL_SET_SIZE // VAL_BATCH_SIZE  #-- 469; a trailing partial batch is never requested
CHECKPOINT_PERIOD = 2000  #-- checkpoint + validate every this many iterations

start = time.time()
for step in range(exp.config.maxIterations):
	#------
	#-- 0) Booleans of the current iteration
	#-- NOTE(review): isValidation is never read; the validation pass below runs on
	#-- the checkpoint schedule (CHECKPOINT_PERIOD), not exp.config.validationPeriod.
	isValidation = (step+1) % exp.config.validationPeriod == 0

	#------
	#-- 1) Extract 2 independent image batches
	tensor_images, tensor_features = dManager.generateBatch()
	tensor_discrimination, _ = dManager.generateBatch()

	#------
	#-- 2) Run a training iteration (VAE and Discriminator are trained in parallel)
	trainer.train(
		tensor_images      = tensor_images,
		tensor_images_disc = tensor_discrimination,
		features = tensor_features)

	#------
	#-- 3) On the first iteration and every CHECKPOINT_PERIOD iterations:
	#--    save a checkpoint, report wall-clock time, then run validation.
	if step == 0 or (step + 1) % CHECKPOINT_PERIOD == 0:
		#-- NOTE(review): the meaning of Trainer.save's second argument is not
		#-- visible here; it is kept at its literal value.
		trainer.save("./training_log/checkpoints", 2000)
		#-- Elapsed time since the previous report; it includes the previous
		#-- validation pass and this save, not only training steps.
		print("step: {}, time: {}".format(step, time.time()-start))
		start = time.time()

		#------
		#-- 4) Run a full validation pass of NUM_VAL_BATCHES batches
		val_recon_rgb_losses = []
		val_recon_gan_losses = []
		val_gan_losses = []
		for val_step in range(NUM_VAL_BATCHES):
			val_tensor_images, val_tensor_features = dManager.generateValBatch()
			val_tensor_discrimination, _ = dManager.generateValBatch()
			recon_rgb, recon_gan_loss, gan_loss = trainer.validate(
				tensor_images      = val_tensor_images,
				tensor_images_disc = val_tensor_discrimination,
				features = val_tensor_features)
			val_recon_rgb_losses.append(recon_rgb)
			val_recon_gan_losses.append(recon_gan_loss)
			val_gan_losses.append(gan_loss)

		#-- All three lists receive one entry per validation batch, so they share a
		#-- length; only average and log when at least one batch was evaluated
		#-- (otherwise the divisions below would raise ZeroDivisionError).
		if val_recon_rgb_losses:
			avg_recon_rgb = sum(val_recon_rgb_losses)/len(val_recon_rgb_losses)
			avg_recon_gan = sum(val_recon_gan_losses)/len(val_recon_gan_losses)
			avg_gan = sum(val_gan_losses)/len(val_gan_losses)
			trainer._testfile2record_loss.write('{}: VAL RGB RECON LOSS: {:.4f}, RECON GAN LOSS: {:.4f}, GAN LOSS: {:.4f}\n'.format(step, avg_recon_rgb, avg_recon_gan, avg_gan ))
			trainer._testfile2record_loss.flush()

print("Training done.")
