# 1. Import libraries and modules

In [None]:
import tensorflow as tf
import numpy as np
import argparse
import datetime
from PIL import Image
from tqdm import tqdm
from BatchGenerator import batchgenerator
from StyleTransfer import styletransfer
from Discriminator import discriminator
from Trainer import trainer
from Trainer_with_discriminator import trainer_with_discriminator

# 2. Hyper-parameters and directories.

In [None]:
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool(<string>)``, which is True for
    every non-empty string -- so ``--use_discriminator False`` would *enable*
    the discriminator. This parser accepts the usual spellings instead.

    Args:
        value: The raw command-line string (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='')
# Directories
parser.add_argument('--save_dir', dest='save_dir', type=str, default='./model_save')
parser.add_argument('--load_ckp_dir', dest='load_ckp_dir', type=str, default='./training_ckps/original_model')
parser.add_argument('--save_ckp_dir', dest='save_ckp_dir', type=str, default='./training_ckps/train')
parser.add_argument('--log_dir', dest='log_dir', type=str, default='./training_log/test')
parser.add_argument('--history_dir', dest='history_dir', type=str, default='./history')

# Optimisation schedule
parser.add_argument('--epochs', dest='epochs', type=int, default=80)
parser.add_argument('--img_size', dest='img_size', type=int, default=256)
parser.add_argument('--batch_size', dest='batch_size', type=int, default=8)
parser.add_argument('--learning_rate', dest='learning_rate', type=float, default=1e-4)
parser.add_argument('--learning_rate_decay', dest='learning_rate_decay', type=float, default=5e-5)
parser.add_argument('--max_iteration', dest='max_iteration', type=int, default=2000)

# Loss weights. The misspelled 'gradient_panelty_weight' is kept as-is because
# it is the attribute name other modules read from `args`.
parser.add_argument('--content_loss_weight', dest='content_loss_weight', type=float, default=1)
parser.add_argument('--style_loss_weight', dest='style_loss_weight', type=float, default=10)
parser.add_argument('--tv_loss_weight', dest='tv_loss_weight', type=float, default=0)
parser.add_argument('--gradient_panelty_weight', dest='gradient_panelty_weight', type=float, default=10)

# Boolean switches: str2bool so that an explicit "False" really means False.
parser.add_argument('--use_discriminator', dest='use_discriminator', type=str2bool, default=False)
parser.add_argument('--continue_learn', dest='continue_learn', type=str2bool, default=False)

# nargs='+' collects space-separated layer names into a list; type=list would
# have split a single argument string into individual characters.
parser.add_argument('--style_layers', dest='style_layers', type=str, nargs='+',
                    default=['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1'])
parser.add_argument('--content_layer', dest='content_layer', type=str, default='block4_conv1')
# parse_known_args tolerates extra arguments injected by the notebook kernel.
args, unknown = parser.parse_known_args()

# 3. Create folders if they do not exist

In [None]:
import os

# Create every output folder used later in the notebook. exist_ok=True makes
# this idempotent and avoids the check-then-create race of an
# os.path.exists() test followed by os.makedirs().
for directory in (args.history_dir, args.log_dir, args.save_ckp_dir, args.save_dir):
    os.makedirs(directory, exist_ok=True)

# 4. Load the batch generator, the StyleTransfer model, and the optional discriminator

In [None]:
# Training-data loader and the style-transfer network.
batchgen = batchgenerator('./data/train/', args.batch_size, args.img_size)
styler = styletransfer(args)

# The discriminator is only constructed when adversarial training is enabled;
# otherwise `disc` is explicitly None.
if args.use_discriminator:
    disc = discriminator(args)
else:
    disc = None

In [None]:
import matplotlib.pyplot as plt
import utils

# Sample training images display: the first two images of one content batch
# (top row) and of one style batch (bottom row), min-max scaled for display.
preview_panels = (
    ('content', ('Content input 1', 'Content input 2'), (221, 222)),
    ('style', ('Style input 1', 'Style input 2'), (223, 224)),
)
for domain, titles, positions in preview_panels:
    sample = utils.MinMax_Scale(batchgen.next_batch(domain))
    for index, (title, position) in enumerate(zip(titles, positions)):
        plt.subplot(position)
        plt.title(title)
        plt.imshow(sample[index, :, :, :])
        plt.axis('off')
        plt.colorbar()

# 5. Train the StyleTransfer

In [None]:
# Load trainer: the adversarial variant additionally receives the
# discriminator built in the previous cell.
if args.use_discriminator:
    MyTrainer = trainer_with_discriminator(styler, disc, args)
else:
    MyTrainer = trainer(styler, args)

# Start to train
MyTrainer.train(batchgen)

# 6. Test the StyleTransfer

In [None]:
from os.path import join, splitext
import os


def to_uint8_image(img):
    """Return ``img`` as a uint8 NumPy array that ``PIL.Image.fromarray`` accepts.

    Values are clipped to [0, 255] before the cast. A bare ``np.uint8(x)``
    wraps out-of-range values (e.g. 256 -> 0) and gives platform-dependent
    results for negative floats, which shows up as speckled pixels in the
    saved file. Fractional parts are truncated, exactly as the original cast
    did.

    NOTE(review): assumes the styler emits pixels on a 0-255 scale, which is
    what the original ``np.uint8`` cast implied -- confirm against
    StyleTransfer.
    """
    return np.clip(np.asarray(img), 0, 255).astype(np.uint8)


# load batch generator for test data.
batchgen_test = batchgenerator('./data/test/', args.batch_size, args.img_size)

# Save folder create (exist_ok avoids the check-then-create race).
test_save = join(args.history_dir, 'test')
os.makedirs(test_save, exist_ok=True)

# Keep the most recent inputs/output for the display below. They stay None
# when the test folders contain no content/style pair.
content_batch = style_batch = output = None

# Inference: stylize every test content image with every test style image.
for content_img_name in batchgen_test.content_img_names:
    print('>> Test for the content image ' + content_img_name + '...')
    # load a content image and add a leading batch axis
    content_batch = tf.expand_dims(
        batchgen_test.one_test_img('content', name=content_img_name, crop=False), axis=0)

    for style_img_name in batchgen_test.style_img_names:
        # load a style image and add a leading batch axis
        style_batch = tf.expand_dims(
            batchgen_test.one_test_img('style', name=style_img_name, crop=False), axis=0)

        # Get the synthesized image (drop the batch axis).
        output = styler(content_batch, style_batch)[0, :, :, :]

        # save the synthesized image as <content>_to_<style>.jpg
        content_base = os.path.basename(content_img_name)
        style_base = os.path.basename(style_img_name)
        save_name = splitext(content_base)[0] + '_to_' + splitext(style_base)[0] + '.jpg'
        Image.fromarray(to_uint8_image(output)).save(join(test_save, save_name))

# Show the last example: content, style and synthesized result.
if output is None:
    print('>> No test content/style pairs were found; nothing to display.')
else:
    plt.figure(figsize=(15, 15))
    plt.subplot(2, 2, 1)
    plt.title('Content')
    plt.imshow(utils.MinMax_Scale(content_batch[0, :, :, :]))
    plt.subplot(2, 2, 2)
    plt.title('Style')
    plt.imshow(utils.MinMax_Scale(style_batch[0, :, :, :]))
    plt.subplot(2, 2, 3)
    plt.title('Output')
    plt.imshow(utils.MinMax_Scale(output))

# 7. Model save

In [None]:
# Persist trained weights under save_ckp_dir: the decoder is always saved, the
# discriminator only when it took part in training.
decoder_ckpt = f"{args.save_ckp_dir}/decoder/decoder_ckpt"
styler.dec.net.save_weights(decoder_ckpt)

if args.use_discriminator:
    discriminator_ckpt = f"{args.save_ckp_dir}/discriminator/discriminator_ckpt"
    disc.net.save_weights(discriminator_ckpt)