In [1]:
import h5py
import logging
import numpy as np
import os
import random
import signal
import sys
import tensorflow as tf

from datetime import datetime
from glob import glob
from scipy.misc import imread, imresize, imsave
from tqdm import trange, tqdm

In [2]:
# Names of paired image-to-image datasets. This list is not referenced anywhere
# else in the visible notebook; presumably it enumerates the valid values for
# args['task'] (sub-directories of 'datasets/') - TODO confirm.
datasets = ['maps', 'cityscapes', 'facades', 'edges2handbags', 'edges2shoes']

def read_image(path):
    ''' Loads a side-by-side image and returns its left and right halves, each scaled to [-1.0, 1.0] '''

    pair = imread(path)
    # Unpacking requires a 3-D (height, width, channels) array.
    _, width, _ = pair.shape
    middle = int(width / 2)

    def to_unit_range(half):
        # [0, 255] -> float32 [0.0, 1.0] -> [-1.0, 1.0]
        scaled = half.astype(np.float32) / 255.0
        return scaled * 2.0 - 1.0

    left_half = to_unit_range(pair[:, :middle, :])
    right_half = to_unit_range(pair[:, middle:, :])
    return left_half, right_half

def read_images(base_dir):
    ''' Reads every *.jpg under base_dir/train and base_dir/val.

    Returns a list of (name, images) tuples in the order
    trainA, trainB, valA, valB, where the A/B lists hold the first/second
    value returned by read_image for each file.
    '''
    collected = []
    for split in ['train', 'val']:
        pattern = os.path.join(os.path.join(base_dir, split), '*.jpg')

        halves_a, halves_b = [], []
        for image_path in tqdm(glob(pattern)):
            half_a, half_b = read_image(image_path)
            if half_a is None:
                continue
            halves_a.append(half_a)
            halves_b.append(half_b)

        collected += [(split + 'A', halves_a), (split + 'B', halves_b)]
    return collected
    
def store_h5py(base_dir, dir_name, images, image_size):
    ''' Writes images to '<base_dir>/<dir_name>_<image_size>.hy'.

    Each image is stored in its own group (named by its index) under the
    key 'image'. Images whose height or width differs from image_size are
    resized to (image_size, image_size); others are stored unchanged.

    base_dir:   directory that receives the file
    dir_name:   file name prefix (e.g. 'trainA')
    images:     sequence of float arrays with values in [-1.0, 1.0]
    image_size: target side length in pixels
    '''
    out_path = os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size))
    # The context manager closes the file even if a write or resize fails.
    with h5py.File(out_path, 'w') as f:
        for i, image in enumerate(images):
            grp = f.create_group(str(i))
            if tuple(image.shape[:2]) != (image_size, image_size):
                # imresize rescales non-uint8 input to the full 0..255 range
                # using the image's own min/max, which would stretch the
                # contrast of every image. Map [-1, 1] back to uint8 first so
                # pixel values are preserved.
                pixels = ((image + 1.0) * 127.5 + 0.5).astype(np.uint8)
                resized = imresize(pixels, (image_size, image_size, 3))

                # range of pixel values = [-1.0, 1.0]
                resized = resized.astype(np.float32) / 255.0
                grp['image'] = resized * 2.0 - 1.0
            else:
                grp['image'] = image

def convert_h5py(task_name):
    ''' Converts the jpg images of 'datasets/<task_name>' into .hy files.

    For every split returned by read_images a 128-pixel file is written;
    a 256-pixel file is written as well when the first image of the split
    is 256 pixels high.

    Raises IOError when a split contains no images, instead of failing
    with an IndexError on images[0].
    '''
    print('Generating h5py file')
    base_dir = os.path.join('datasets', task_name)
    for dir_name, images in read_images(base_dir):
        if not images:
            raise IOError('No .jpg images found for {} under {}'.format(dir_name, base_dir))
        if images[0].shape[0] == 256:
            store_h5py(base_dir, dir_name, images, 256)
        store_h5py(base_dir, dir_name, images, 128)

def read_h5py(task_name, image_size):
    ''' Loads the four .hy files (trainA, trainB, valA, valB) of a task.

    If 'datasets/<task_name>' does not contain exactly four
    '*_<image_size>.hy' files, convert_h5py is run first to generate them.

    Returns a list of four lists of float32 arrays, in the order
    trainA, trainB, valA, valB.
    Raises IOError when one of the files cannot be opened.
    '''
    base_dir = os.path.join('datasets', task_name)
    paths = glob(os.path.join(base_dir, '*_{}.hy'.format(image_size)))
    if len(paths) != 4:
        convert_h5py(task_name)

    ret = []
    for dir_name in ['trainA', 'trainB', 'valA', 'valB']:
        file_path = os.path.join(base_dir, '{}_{}.hy'.format(dir_name, image_size))
        try:
            dataset = h5py.File(file_path, 'r')
        except (IOError, OSError):
            # Only file-access failures are translated; anything else
            # (e.g. KeyboardInterrupt) propagates untouched.
            raise IOError('Dataset is not available. Please try it again')

        # Close the file once its contents are copied into memory.
        with dataset:
            # dataset[key]['image'][()] reads the whole array; it replaces
            # the `.value` accessor, which newer h5py releases removed.
            images = [dataset[key]['image'][()].astype(np.float32) for key in dataset]
        ret.append(images)
    return ret

def get_data(task_name, image_size):
    ''' Prints a progress line and returns (train_A, train_B, test_A, test_B).

    The four lists come straight from read_h5py; the "test" lists are the
    'val' split of the dataset.
    '''
    print('Load data %s' % task_name)
    train_A, train_B, test_A, test_B = read_h5py(task_name, image_size)
    return train_A, train_B, test_A, test_B

In [3]:

class Discriminator(object):
    ''' Convolutional discriminator producing one score per batch element.

    Calling the object builds the network under the variable scope `name`;
    the first call creates the variables and later calls reuse them.
    '''
    def __init__(self, name, is_train, norm='instance', activation='leaky', image_size=128):
        logger.info('Init Discriminator %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        self._activation = activation
        self._reuse = False
        self._image_size = image_size

    def __call__(self, input):
        # (filters, scope name, norm) of the stride-2 layers; the first layer
        # is never normalized.
        hidden_specs = [(64, 'C64', None),
                        (128, 'C128', self._norm),
                        (256, 'C256', self._norm)]
        # Three 512-filter layers for 256-pixel inputs, one otherwise.
        extra_layers = 3 if self._image_size == 256 else 1
        hidden_specs += [(512, 'C512_{}'.format(i), self._norm)
                         for i in range(extra_layers)]

        with tf.variable_scope(self.name, reuse=self._reuse):
            D = input
            for filters, scope, norm in hidden_specs:
                D = conv_block(D, filters, scope, 4, 2, self._is_train,
                               self._reuse, norm=norm, activation=self._activation)
            # Single-channel stride-1 output layer with bias, no norm/activation.
            D = conv_block(D, 1, 'C1', 4, 1, self._is_train,
                           self._reuse, norm=None, activation=None, bias=True)
            # Average the score map into one value per batch element.
            D = tf.reduce_mean(D, axis=[1,2,3])

            self._reuse = True
            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
            return D

In [4]:
class Encoder(object):
    ''' Encodes an image into a latent code z with mean mu and log-sigma.

    Calling the object returns (z, mu, log_sigma), where
    z = mu + eps * exp(log_sigma) and eps is standard normal noise.
    `use_resnet` selects between the residual and the plain convolutional
    feature extractor. The first call creates the variables under the
    scope `name`; later calls reuse them.

    NOTE: the `activation` argument is stored but not used by either
    feature extractor (they use 'leaky' / relu directly).
    '''
    def __init__(self, name, is_train, norm='instance', activation='leaky',
                 image_size=128, latent_dim=8, use_resnet=True):
        logger.info('Init Encoder %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        self._activation = activation
        self._reuse = False
        self._image_size = image_size
        self._latent_dim = latent_dim
        self._use_resnet = use_resnet

    def __call__(self, input):
        if self._use_resnet:
            return self._resnet(input)
        else:
            return self._convnet(input)

    def _sample_latent(self, features):
        ''' Builds the mu / log-sigma heads and draws z from them.

        Must be called inside the encoder's variable scope, before
        self._reuse is flipped to True.
        '''
        mu = mlp(features, self._latent_dim, 'FC8_mu', self._is_train, self._reuse,
                 norm=None, activation=None)
        log_sigma = mlp(features, self._latent_dim, 'FC8_sigma', self._is_train, self._reuse,
                        norm=None, activation=None)

        # The noise needs the shape of mu so that every latent dimension of
        # every batch element gets its own sample. tf.shape(self._latent_dim)
        # is the shape of a scalar, i.e. [], which yields one noise value
        # broadcast over the whole code.
        noise = tf.random_normal(shape=tf.shape(mu))
        z = mu + noise * tf.exp(log_sigma)
        return z, mu, log_sigma

    def _convnet(self, input):
        ''' Plain stack of stride-2 convolutions followed by the latent heads. '''
        with tf.variable_scope(self.name, reuse=self._reuse):
            num_filters = [64, 128, 256, 512, 512, 512, 512]
            if self._image_size == 256:
                num_filters.append(512)

            E = input
            for i, n in enumerate(num_filters):
                # The first layer (i == 0) is not normalized.
                E = conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
                                   self._reuse, norm=self._norm if i else None, activation='leaky')
            E = flatten(E)
            z, mu, log_sigma = self._sample_latent(E)

            self._reuse = True
            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
            return z, mu, log_sigma

    def _resnet(self, input):
        ''' One conv layer, then residual blocks each followed by 2x2 average
        pooling, a final 8x8 average pool, and the latent heads. '''
        with tf.variable_scope(self.name, reuse=self._reuse):
            num_filters = [128, 256, 512, 512]
            if self._image_size == 256:
                num_filters.append(512)

            E = input
            E = conv_block(E, 64, 'C{}_{}'.format(64, 0), 4, 2, self._is_train,
                               self._reuse, norm=None, activation='leaky', bias=True)
            for i, n in enumerate(num_filters):
                E = residual(E, n, 'res{}_{}'.format(n, i + 1), self._is_train,
                                 self._reuse, norm=self._norm, bias=True)
                E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
            E = tf.nn.relu(E)
            E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME')
            E = flatten(E)
            z, mu, log_sigma = self._sample_latent(E)

            self._reuse = True
            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
            return z, mu, log_sigma

In [5]:
class Generator(object):
    ''' Encoder-decoder generator with skip connections, conditioned on z.

    Calling the object with an image batch and a latent batch returns a
    3-channel tanh output. The first call creates the variables under the
    scope `name`; later calls reuse them.
    '''
    def __init__(self, name, is_train, norm='batch', image_size=128):
        logger.info('Init Generator %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        self._reuse = False
        self._image_size = image_size

    def __call__(self, input, z):
        with tf.variable_scope(self.name, reuse=self._reuse):
            batch_size = int(input.get_shape()[0])
            latent_dim = int(z.get_shape()[-1])
            down_filters = [64, 128, 256, 512, 512, 512, 512]
            if self._image_size == 256:
                down_filters.append(512)

            # Replicate z over every spatial position and append it to the
            # input image as extra channels.
            z_map = tf.tile(tf.reshape(z, [batch_size, 1, 1, latent_dim]),
                            [1, self._image_size, self._image_size, 1])
            G = tf.concat([input, z_map], axis=3)

            # Downsampling path; every layer's output is kept for the skips.
            skips = []
            for i, n in enumerate(down_filters):
                G = conv_block(G, n, 'C{}_{}'.format(n, i), 4, 2, self._is_train,
                                self._reuse, norm=self._norm if i else None, activation='leaky')
                skips.append(G)

            # The innermost layer has no skip; the upsampling path mirrors
            # the remaining layers from the inside out.
            skips = skips[:-1]
            up_filters = down_filters[:-1][::-1]

            for i, n in enumerate(up_filters):
                G = deconv_block(G, n, 'CD{}_{}'.format(n, i), 4, 2, self._is_train,
                                self._reuse, norm=self._norm, activation='relu')
                G = tf.concat([G, skips.pop()], axis=3)
            G = deconv_block(G, 3, 'last_layer', 4, 2, self._is_train,
                               self._reuse, norm=None, activation='tanh')

            self._reuse = True
            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
            return G

In [6]:
def _norm(input, is_train, reuse=True, norm=None):
    assert norm in ['instance', 'batch', None]
    if norm == 'instance':
        with tf.variable_scope('instance_norm', reuse=reuse):
            eps = 1e-5
            mean, sigma = tf.nn.moments(input, [1, 2], keep_dims=True)
            normalized = (input - mean) / (tf.sqrt(sigma) + eps)
            out = normalized

    elif norm == 'batch':
        with tf.variable_scope('batch_norm', reuse=reuse):
            out = tf.contrib.layers.batch_norm(input,
                                               decay=0.99, center=True,
                                               scale=True, is_training=True)
    else:
        out = input

    return out

def _activation(input, activation=None):
    assert activation in ['relu', 'leaky', 'tanh', 'sigmoid', None]
    if activation == 'relu':
        return tf.nn.relu(input)
    elif activation == 'leaky':
        return tf.contrib.keras.layers.LeakyReLU(0.2)(input)
    elif activation == 'tanh':
        return tf.tanh(input)
    elif activation == 'sigmoid':
        return tf.sigmoid(input)
    else:
        return input

def flatten(input):
    ''' Reshapes `input` to 2-D, keeping the first axis and merging the rest. '''
    trailing_dims = input.get_shape().as_list()[1:]
    flat_size = np.prod(trailing_dims)
    return tf.reshape(input, [-1, flat_size])

def conv2d(input, num_filters, filter_size, stride, reuse=False,
           pad='SAME', dtype=tf.float32, bias=False):
    ''' 2-D convolution with a freshly fetched kernel variable 'w'.

    pad may be 'SAME', 'VALID' or 'REFLECT'; for 'REFLECT' the input is
    reflect-padded by (filter_size - 1) // 2 on each spatial side and then
    convolved with 'VALID' padding. When `bias` is true a zero-initialized
    variable 'b' is added. `reuse` is accepted but not used here.
    '''
    in_channels = input.get_shape()[3]
    kernel = tf.get_variable('w', [filter_size, filter_size, in_channels, num_filters],
                             dtype, tf.random_normal_initializer(0.0, 0.02))
    strides = [1, stride, stride, 1]

    if pad == 'REFLECT':
        border = (filter_size - 1) // 2
        padded = tf.pad(input, [[0,0],[border,border],[border,border],[0,0]], 'REFLECT')
        conv = tf.nn.conv2d(padded, kernel, strides, padding='VALID')
    else:
        assert pad in ['SAME', 'VALID']
        conv = tf.nn.conv2d(input, kernel, strides, padding=pad)

    if not bias:
        return conv
    offset = tf.get_variable('b', [1,1,1,num_filters], initializer=tf.constant_initializer(0.0))
    return conv + offset

def conv2d_transpose(input, num_filters, filter_size, stride, reuse,
                     pad='SAME', dtype=tf.float32):
    ''' Transposed 2-D convolution that multiplies height and width by `stride`.

    Only 'SAME' padding is supported. The input needs a fully known static
    shape, because the output shape is computed from it. `reuse` is
    accepted but not used here.
    '''
    assert pad == 'SAME'
    batch, height, width, in_channels = input.get_shape().as_list()
    target_shape = [batch, height * stride, width * stride, num_filters]

    # For a transposed convolution the kernel is laid out as
    # [k, k, out_channels, in_channels].
    kernel = tf.get_variable('w', [filter_size, filter_size, num_filters, in_channels],
                             dtype, tf.random_normal_initializer(0.0, 0.02))
    return tf.nn.conv2d_transpose(input, kernel, target_shape, [1, stride, stride, 1], pad)

def mlp(input, out_dim, name, is_train, reuse, norm=None, activation=None,
        dtype=tf.float32, bias=True):
    ''' Fully connected layer under the variable scope `name`.

    Computes input @ w (+ b when `bias`), then applies the activation and
    afterwards the normalization. `input` must be 2-D.
    '''
    with tf.variable_scope(name, reuse=reuse):
        _, in_dim = input.get_shape()
        weights = tf.get_variable('w', [in_dim, out_dim], dtype,
                                  tf.random_normal_initializer(0.0, 0.02))
        projected = tf.matmul(input, weights)
        if bias:
            offset = tf.get_variable('b', [out_dim], initializer=tf.constant_initializer(0.0))
            projected = projected + offset
        # Note the order: activation first, normalization second.
        activated = _activation(projected, activation)
        return _norm(activated, is_train, reuse, norm)

def conv_block(input, num_filters, name, k_size, stride, is_train, reuse, norm,
          activation, pad='SAME', bias=False):
    ''' Convolution -> normalization -> activation under the scope `name`. '''
    with tf.variable_scope(name, reuse=reuse):
        convolved = conv2d(input, num_filters, k_size, stride, reuse, pad, bias=bias)
        normalized = _norm(convolved, is_train, reuse, norm)
        return _activation(normalized, activation)

def residual(input, num_filters, name, is_train, reuse, norm, pad='REFLECT',
             bias=False):
    ''' Residual block under the scope `name`.

    Main branch: 3x3 conv -> norm -> relu -> 3x3 conv -> norm.
    Shortcut:    1x1 conv of the input.
    Returns relu(shortcut + main branch).
    '''
    def stride1_conv(tensor, kernel_size):
        # All convolutions in the block share filters, stride 1, pad and bias.
        return conv2d(tensor, num_filters, kernel_size, 1, reuse, pad, bias=bias)

    with tf.variable_scope(name, reuse=reuse):
        with tf.variable_scope('res1', reuse=reuse):
            branch = stride1_conv(input, 3)
            branch = tf.nn.relu(_norm(branch, is_train, reuse, norm))

        with tf.variable_scope('res2', reuse=reuse):
            branch = stride1_conv(branch, 3)
            branch = _norm(branch, is_train, reuse, norm)

        with tf.variable_scope('shortcut', reuse=reuse):
            skip = stride1_conv(input, 1)

        return tf.nn.relu(skip + branch)

def deconv_block(input, num_filters, name, k_size, stride, is_train, reuse,
                 norm, activation):
    ''' Transposed convolution -> normalization -> activation under the scope `name`. '''
    with tf.variable_scope(name, reuse=reuse):
        upsampled = conv2d_transpose(input, num_filters, k_size, stride, reuse)
        normalized = _norm(upsampled, is_train, reuse, norm)
        return _activation(normalized, activation)

In [7]:
# NOTE(review): this call goes to the root logger, whose default level is
# WARNING, so the message itself is likely not shown - confirm whether a
# logging.basicConfig(level=logging.INFO) call was intended here.
logging.info("Start BicycleGAN")
# Logger used by the model classes and by run(). It is created in this cell,
# so this cell must be executed before any of those classes is instantiated.
logger = logging.getLogger('Bicycle-gan')
logger.setLevel(logging.INFO)

def makedirs(path):
    ''' Creates directory `path` (including parents) unless it already exists.

    Safe to call repeatedly and when another process creates the directory
    at the same time. Raises OSError if the directory cannot be created,
    including when `path` exists but is not a directory.
    '''
    try:
        os.makedirs(path)
    except OSError:
        # A separate exists() check followed by makedirs() can race with
        # another creator; an already existing directory is not an error.
        if not os.path.isdir(path):
            raise

In [8]:
from collections import defaultdict
# Run configuration read by BicycleGAN.__init__ and run().
# Because this is a defaultdict, reading a missing or misspelled key returns
# an empty dict instead of raising KeyError.
args = defaultdict(dict)
args.update([
    ('train', True),                     # run training before testing
    ('task', 'facades'),                 # sub-directory of 'datasets/'
    ('coeff_gan', 1.0),                  # loss weights
    ('coeff_vae', 1.0),
    ('coeff_kl', 0.01),
    ('coeff_reconstruct', 10),
    ('coeff_latent', 0.5),
    ('instance_normalization', False),   # False selects batch normalization
    ('log_step', 100),                   # steps between summaries/samples
    ('batch_size', 1),
    ('image_size', 256),
    ('latent_dim', 8),
    ('use_resnet', True),                # encoder variant
    ('load_model', ''),                  # model name to reuse; '' = new name
])

In [9]:
class BicycleGAN(object):
    ''' BicycleGAN graph (generator G, discriminator D, encoder E) together
    with its training and testing loops.

    The constructor builds the whole TensorFlow graph; train() and test()
    run it in a session supplied by the caller.
    '''
    def __init__(self, args):
        ''' Builds placeholders, networks, losses, optimizers and summaries.

        args: mapping with the keys 'log_step', 'batch_size', 'image_size',
              'latent_dim', 'coeff_gan', 'coeff_vae', 'coeff_reconstruct',
              'coeff_latent', 'coeff_kl', 'instance_normalization' and
              'use_resnet'.
        '''
        self._log_step = args['log_step']
        self._batch_size = args['batch_size']
        self._image_size = args['image_size']
        self._latent_dim = args['latent_dim']
        self._coeff_gan = args['coeff_gan']
        self._coeff_vae = args['coeff_vae']
        self._coeff_reconstruct = args['coeff_reconstruct']
        self._coeff_latent = args['coeff_latent']
        self._coeff_kl = args['coeff_kl']
        self._norm = 'instance' if args['instance_normalization'] else 'batch'
        self._use_resnet = args['use_resnet']

        # Images are enlarged by this margin before being randomly cropped
        # back to image_size during training.
        self._augment_size = self._image_size + (30 if self._image_size == 256 else 15)
        self._image_shape = [self._image_size, self._image_size, 3]

        self.is_train = tf.placeholder(tf.bool, name='is_train')
        self.lr = tf.placeholder(tf.float32, name='lr')
        self.global_step = tf.train.get_or_create_global_step(graph=None)

        # All placeholders have a fixed batch dimension of batch_size.
        image_a = self.image_a = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_a')
        image_b = self.image_b = \
            tf.placeholder(tf.float32, [self._batch_size] + self._image_shape, name='image_b')
        z = self.z = \
            tf.placeholder(tf.float32, [self._batch_size, self._latent_dim], name='z')

        # Data augmentation
        # One seed is drawn at graph-construction time and passed to the
        # random ops of both images - presumably so A and B receive the same
        # crop and flip; TODO confirm.
        seed = random.randint(0, 2**31 - 1)
        def augment_image(image):
            image = tf.image.resize_images(image, [self._augment_size, self._augment_size])
            image = tf.random_crop(image, [self._batch_size] + self._image_shape, seed=seed)
            image = tf.map_fn(lambda x: tf.image.random_flip_left_right(x, seed), image)
            return image

        # Augment only while training.
        image_a = tf.cond(self.is_train,
                          lambda: augment_image(image_a),
                          lambda: image_a)
        image_b = tf.cond(self.is_train,
                          lambda: augment_image(image_b),
                          lambda: image_b)

        # Generator
        G = Generator('G', is_train=self.is_train,
                      norm=self._norm, image_size=self._image_size)

        # Discriminator
        D = Discriminator('D', is_train=self.is_train,
                          norm=self._norm, activation='leaky',
                          image_size=self._image_size)

        # Encoder
        E = Encoder('E', is_train=self.is_train,
                    norm=self._norm, activation='relu',
                    image_size=self._image_size, latent_dim=self._latent_dim,
                    use_resnet=self._use_resnet)

        # conditional VAE-GAN: B -> z -> B'
        z_encoded, z_encoded_mu, z_encoded_log_sigma = E(image_b)
        image_ab_encoded = G(image_a, z_encoded)

        # conditional Latent Regressor-GAN: z -> B' -> z'
        image_ab = self.image_ab = G(image_a, z)
        z_recon, z_recon_mu, z_recon_log_sigma = E(image_ab)


        # Discriminate real/fake images
        D_real = D(image_b)
        D_fake = D(image_ab)
        D_fake_encoded = D(image_ab_encoded)

        # Least-squares GAN terms: real scores are pulled towards 0.9,
        # fake scores towards 0.
        loss_vae_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
            tf.reduce_mean(tf.square(D_fake_encoded)))

        # L1 distance between B and its reconstruction through E and G.
        loss_image_cycle = tf.reduce_mean(tf.abs(image_b - image_ab_encoded))

        loss_gan = (tf.reduce_mean(tf.squared_difference(D_real, 0.9)) +
            tf.reduce_mean(tf.square(D_fake)))

        # L1 distance between the sampled z and the z recovered by E.
        loss_latent_cycle = tf.reduce_mean(tf.abs(z - z_recon))

        loss_kl = -0.5 * tf.reduce_mean(1 + 2 * z_encoded_log_sigma - z_encoded_mu ** 2 -
                                       tf.exp(2 * z_encoded_log_sigma))

        # One combined objective: D minimizes `loss`, G and E minimize
        # `-loss`. The reconstruction, latent and KL terms therefore carry a
        # minus sign here so that G and E end up minimizing them.
        loss = self._coeff_vae * loss_vae_gan - self._coeff_reconstruct * loss_image_cycle + \
            self._coeff_gan * loss_gan - self._coeff_latent * loss_latent_cycle - \
            self._coeff_kl * loss_kl

        # Optimizer
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # Only the discriminator step increments global_step.
            self.optimizer_D = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                                .minimize(loss, var_list=D.var_list, global_step=self.global_step)
            self.optimizer_G = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                                .minimize(-loss, var_list=G.var_list)
            self.optimizer_E = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) \
                                .minimize(-loss, var_list=E.var_list)

        # Summaries
        self.loss_vae_gan = loss_vae_gan
        self.loss_image_cycle = loss_image_cycle
        self.loss_latent_cycle = loss_latent_cycle
        self.loss_gan = loss_gan
        self.loss_kl = loss_kl
        self.loss = loss

        tf.summary.scalar('loss/vae_gan', loss_vae_gan)
        tf.summary.scalar('loss/image_cycle', loss_image_cycle)
        tf.summary.scalar('loss/latent_cycle', loss_latent_cycle)
        tf.summary.scalar('loss/gan', loss_gan)
        tf.summary.scalar('loss/kl', loss_kl)
        tf.summary.scalar('loss/total', loss)
        tf.summary.scalar('model/D_real', tf.reduce_mean(D_real))
        tf.summary.scalar('model/D_fake', tf.reduce_mean(D_fake))
        tf.summary.scalar('model/D_fake_encoded', tf.reduce_mean(D_fake_encoded))
        tf.summary.scalar('model/lr', self.lr)
        tf.summary.image('image/A', image_a[0:1])
        tf.summary.image('image/B', image_b[0:1])
        tf.summary.image('image/A-B', image_ab[0:1])
        tf.summary.image('image/A-B_encoded', image_ab_encoded[0:1])
        self.summary_op = tf.summary.merge_all()

    def _generate(self, sess, image_a, z):
        ''' Runs the generator on one image and one latent code.

        image_a: array of shape [1, H, W, 3]
        z:       array of shape [1, latent_dim]
        Returns the generated image as an array of shape [1, H, W, 3].
        '''
        # The placeholders have a fixed batch dimension, so the single
        # example is replicated to fill the batch and only the first result
        # is kept. With batch_size == 1 this feeds the input unchanged.
        batch_a = np.concatenate([image_a] * self._batch_size, axis=0)
        batch_z = np.concatenate([z] * self._batch_size, axis=0)
        generated = sess.run(self.image_ab, feed_dict={self.image_a: batch_a,
                                                       self.z: batch_z,
                                                       self.is_train: False})
        return generated[:1]

    def train(self, sess, summary_writer, data_A, data_B):
        ''' Trains the model, resuming from the session's global step.

        sess:           TensorFlow session holding the model variables
        summary_writer: writer that receives a summary every log_step steps
        data_A, data_B: equally long sequences of paired images

        Every log_step steps a generated sample is also written to
        'results/r_<step>.jpg'.
        Raises ValueError if there are fewer images than batch_size.
        '''
        logger.info('Start training.')
        logger.info('  {} images from A'.format(len(data_A)))
        logger.info('  {} images from B'.format(len(data_B)))

        assert len(data_A) == len(data_B), \
            'Data size mismatch dataA(%d) dataB(%d)' % (len(data_A), len(data_B))
        data_size = len(data_A)
        num_batch = data_size // self._batch_size
        if num_batch == 0:
            raise ValueError('Not enough data (%d images) for batch size %d'
                             % (data_size, self._batch_size))
        # An epoch is 310 steps, but never more steps than there are full
        # batches; otherwise the slices below run past the data and
        # np.stack fails on an empty list.
        epoch_length = min(310, num_batch)

        # Epoch counts: constant learning rate first, then linear decay.
        num_initial_iter = 8
        num_decay_iter = 2
        lr = lr_initial = 0.0002
        lr_decay = lr_initial / num_decay_iter

        initial_step = sess.run(self.global_step)
        num_global_step = (num_initial_iter + num_decay_iter) * epoch_length
        t = trange(initial_step, num_global_step,
                   total=num_global_step, initial=initial_step)

        for step in t:
            epoch = step // epoch_length
            batch_idx = step % epoch_length

            if epoch > num_initial_iter:
                lr = max(0.0, lr_initial - (epoch - num_initial_iter) * lr_decay)

            if batch_idx == 0:
                # Reshuffle the pairs together at the start of every epoch.
                data = list(zip(data_A, data_B))
                random.shuffle(data)
                data_A, data_B = zip(*data)

            start = batch_idx * self._batch_size
            stop = start + self._batch_size
            image_a = np.stack(data_A[start:stop])
            image_b = np.stack(data_B[start:stop])
            sample_z = np.random.normal(size=(self._batch_size, self._latent_dim))

            fetches = [self.loss, self.optimizer_D,
                       self.optimizer_G, self.optimizer_E]
            if step % self._log_step == 0:
                fetches += [self.summary_op]

            fetched = sess.run(fetches, feed_dict={self.image_a: image_a,
                                                   self.image_b: image_b,
                                                   self.is_train: True,
                                                   self.lr: lr,
                                                   self.z: sample_z})

            if step % self._log_step == 0:
                # z must match the placeholder's batch dimension; only the
                # first generated image of the batch is saved.
                z = np.random.normal(size=(self._batch_size, self._latent_dim))
                image_ab = sess.run(self.image_ab, feed_dict={self.image_a: image_a,
                                                            self.z: z,
                                                            self.is_train: False})
                imsave('results/r_{}.jpg'.format(step), image_ab[0])

                summary_writer.add_summary(fetched[-1], step)
                summary_writer.flush()
                t.set_description('Loss({:.3f})'.format(fetched[0]))

    def test(self, sess, data_A, data_B, base_dir):
        ''' Generates outputs for every (A, B) pair and saves them as images.

        For the i-th pair (counting from 1) two files are written to
        base_dir, each showing A, B and one generated image joined along
        the first image axis:
          'random_<i>.jpg' - generated from a z drawn from a standard normal
          'linear_<i>.jpg' - generated from an all-zero z
        '''
        for step, (dataA, dataB) in enumerate(tqdm(zip(data_A, data_B)), 1):
            image_a = np.expand_dims(dataA, axis=0)
            image_b = np.expand_dims(dataB, axis=0)

            z = np.random.normal(size=(1, self._latent_dim))
            images_random = [image_a, image_b, self._generate(sess, image_a, z)]

            z = np.zeros((1, self._latent_dim))
            images_linear = [image_a, image_b, self._generate(sess, image_a, z)]

            images = np.concatenate(images_random, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'random_{}.jpg'.format(step)), images)

            images = np.concatenate(images_linear, axis=1)
            images = np.squeeze(images, axis=0)
            imsave(os.path.join(base_dir, 'linear_{}.jpg'.format(step)), images)
            

In [None]:
def str2bool(v):
    ''' Returns True when `v` spells 'true' in any letter case, else False. '''
    lowered = v.lower()
    return lowered == 'true'

class FastSaver(tf.train.Saver):
    ''' Saver that never writes the meta graph when saving a checkpoint. '''
    def save(self, sess, save_path, global_step=None, latest_filename=None,
             meta_graph_suffix="meta", write_meta_graph=True):
        ''' Same as tf.train.Saver.save, except that `write_meta_graph` is
        accepted for signature compatibility and always replaced by False.

        Returns whatever the base class returns (the checkpoint path), so
        callers that use the result keep working.
        '''
        return super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
                                           meta_graph_suffix, False)


def run(args):
    ''' Loads the data, builds the model, optionally trains it, then tests it.

    args: mapping with at least 'task', 'image_size', 'load_model' and
          'train', plus every key BicycleGAN.__init__ reads.

    Event files and checkpoints go to './logs/<model_name>', test images to
    'results/<model_name>'. model_name is args['load_model'] when given,
    otherwise '<task>_<timestamp>'.
    '''
    logger.info('Read data:')
    train_A, train_B, test_A, test_B = get_data(args['task'], args['image_size'])

    # Swap the two domains: the model's A is the dataset's B and vice versa.
    train_A, train_B = train_B, train_A
    test_A, test_B = test_B, test_A

    logger.info('Build graph:')
    model = BicycleGAN(args)

    variables_to_save = tf.global_variables()
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    logger.info('Trainable vars:')
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    if args['load_model'] != '':
        model_name = args['load_model']
    else:
        model_name = '{}_{}'.format(args['task'], datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    logdir = './logs'
    makedirs(logdir)
    logdir = os.path.join(logdir, model_name)
    logger.info('Events directory: %s', logdir)
    summary_writer = tf.summary.FileWriter(logdir)

    # BicycleGAN.train writes its sample images into this directory.
    makedirs('./results')

    def init_fn(sess):
        logger.info('Initializing all parameters.')
        sess.run(init_all_op)

    sv = tf.train.Supervisor(is_chief=True,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=model.global_step,
                             save_model_secs=300,
                             save_summaries_secs=30)

    # Train and test in the same session. A second managed_session would
    # start from the most recent periodic checkpoint (or from freshly
    # initialized variables if none was written yet) instead of the weights
    # that were just trained.
    with sv.managed_session() as sess:
        if args['train']:
            logger.info("Starting training session.")
            model.train(sess, summary_writer, train_A, train_B)

        logger.info("Starting testing session.")
        base_dir = os.path.join('results', model_name)
        makedirs(base_dir)
        model.test(sess, test_A, test_B, base_dir)

def main():
    ''' Installs SIGINT/SIGTERM handlers and starts run() with the notebook's args. '''

    def shutdown(signum, frame):
        # Log the signal and exit with the status 128 + signal number.
        tf.logging.warn('Received signal %s: exiting', signum)
        sys.exit(128+signum)

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, shutdown)

    run(args)

# Entry point: starts the whole pipeline (see main) when this code is
# executed as the main module rather than imported.
if __name__ == "__main__":
    main()