In [0]:
# !wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
# !unzip ngrok-stable-linux-amd64.zip


In [1]:
# Launch TensorBoard in the background, reading event files from ./summaries
# and listening on port 6006.
get_ipython().system_raw('tensorboard --logdir ./summaries --host 0.0.0.0 --port 6006 &')
# Start an ngrok HTTP tunnel to port 6006 in the background
# (expects the ./ngrok binary in the current working directory).
get_ipython().system_raw('./ngrok http 6006 &')
# Ask ngrok's local API (port 4040) for its tunnels and print the public URL
# of the first one.
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

https://c0ea41bb.ngrok.io


# Setup

In [2]:
# Fetch the datasets repository and install the packages used by the cells below.
# NOTE(review): re-running this cell makes `git clone` fail because ./datasets
# already exists (see the recorded output); the failure is harmless.
# NOTE(review): only tensorflow-gpu is version-pinned; the other installs are not.
! git clone https://github.com/Tymyan1/datasets.git
! pip install tensorflow-gpu==2.0.0-alpha0
! pip install tensorflow-addons
! pip install pydrive
! pip install zipfile36

fatal: destination path 'datasets' already exists and is not an empty directory.


# Util

In [0]:
import pathlib
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

def linear_decayed_lr(epoch, lr, total_epochs, non_decayed_epochs=100):
    """Return the learning rate for ``epoch`` under a constant-then-linear-decay schedule.

    The rate stays at ``lr`` for the first ``non_decayed_epochs`` epochs and then
    falls linearly, reaching 0 at ``total_epochs``.

    :param epoch: current epoch index
    :param lr: base learning rate
    :param total_epochs: epoch at which the rate reaches 0
    :param non_decayed_epochs: number of initial epochs trained at the full rate
    :return: learning rate to use for ``epoch`` (never negative)
    """
    if epoch < non_decayed_epochs:
        return lr

    decay_span = total_epochs - non_decayed_epochs
    if decay_span <= 0:
        # No decay window at all: keep the base rate instead of dividing by zero.
        return lr

    # Fraction of the base rate that is still left; 1 at the start of the decay
    # window, 0 at total_epochs. Clamped so overshooting never yields a negative rate.
    remaining = 1 - (epoch - non_decayed_epochs) / decay_span
    return lr * max(remaining, 0.0)

def tau_thres(input_, tau=0.1):
    """Zero every entry of ``input_`` that is less than or equal to ``tau``.

    Entries strictly greater than ``tau`` are passed through unchanged.
    NOTE(review): the comparison is inclusive here, whereas the ``TauThreshold``
    layer uses a strict ``<`` - confirm which one is intended.
    """
    at_or_below = input_ <= tau
    return np.where(at_or_below, 0, input_)

# @tf.function
def get_bg_map(att_map):
    """Return the background map, i.e. the complement ``1 - att_map`` of an attention map.

    :param att_map: attention map (scalar, array or tensor supporting ``1 - x``)
    :return: ``1 - att_map``
    """
    # The original body read an undefined name (`mapa`) instead of the parameter.
    return 1 - att_map

# @tf.function
def compose_img(att, fg, bg):
    """Blend a foreground into a background: ``att * fg + bg``.

    ``bg`` is added as given; it is not multiplied by ``1 - att`` here.
    """
    masked_fg = att * fg
    return masked_fg + bg

# based on https://github.com/gabrielpierobon/cnnshapes/blob/master/README.md
# given the input img, plots all the activations along the whole network 
def show_flow(model, img):
    """Save one PNG per layer showing every channel of that layer's activation for ``img``.

    A helper model exposing the output of every layer of ``model`` is built and run
    on ``img``. For each 4-D activation the channels of the first sample are
    normalised, tiled into a grid (16 per row) and written to
    ``layer_imgs/genA/<index>_<layer name>.png``.

    :param model: Keras model whose layers are inspected
    :param img: input batch passed to ``predict``; only the first sample is plotted
    """
    import os  # imported locally: this cell does not import os itself

    # build a model that returns the activation of every layer
    layer_outputs = [layer.output for layer in model.layers]
    _model = tf.keras.models.Model(inputs=model.input, outputs=layer_outputs)

    activations = _model.predict(img)

    path = 'layer_imgs/genA/'
    os.makedirs(path, exist_ok=True)

    layer_names = [layer.name for layer in model.layers]
    images_per_row = 16

    for i, (layer_name, layer_activation) in enumerate(zip(layer_names, activations)):
        if layer_activation.ndim != 4:
            # Only (batch, height, width, channels) feature maps can be tiled.
            continue

        n_features = layer_activation.shape[-1]
        height, width = layer_activation.shape[1], layer_activation.shape[2]
        # Ceil division so a trailing, partially filled row is still shown.
        n_rows = max(-(-n_features // images_per_row), 1)
        display_grid = np.zeros((height * n_rows, images_per_row * width))

        for feature_idx in range(n_features):
            grid_row, grid_col = divmod(feature_idx, images_per_row)

            # Work on a float copy so the activations themselves are not modified.
            channel_image = layer_activation[0, :, :, feature_idx].astype('float64')
            # Post-process the feature to make it visually palatable.
            channel_image -= channel_image.mean()
            channel_std = channel_image.std()
            if channel_std > 0:
                # A constant channel has std 0; skip scaling to avoid NaNs.
                channel_image /= channel_std
            channel_image *= 64
            channel_image += 128
            channel_image = np.clip(channel_image, 0, 255).astype('uint8')

            display_grid[grid_row * height: (grid_row + 1) * height,
                         grid_col * width: (grid_col + 1) * width] = channel_image

        # One inch per tile, as in the original scaling by 1/size.
        fig = plt.figure(figsize=(display_grid.shape[1] / width,
                                  display_grid.shape[0] / height))
        plt.title(layer_name)
        plt.grid(False)
        # The grid has to be drawn before saving, otherwise the PNG is blank.
        plt.imshow(display_grid, aspect='auto', cmap='viridis')
        plt.savefig(os.path.join(path, '{:03d}_{}.png'.format(i, layer_name)))
        # Close the figure so memory does not grow with the number of layers.
        plt.close(fig)
    print('done')

# Data Loader

In [0]:
from glob import glob
import numpy as np
import tensorflow as tf

class Data_Loader():
    """Loads an unpaired two-domain image dataset (trainA / trainB) as endless batched streams.

    Images are read from ``ROOT_PATH/<name>/trainA`` and ``.../trainB``, resized to
    the first two entries of ``img_dims``, scaled by ``_load_and_preprocess_image``,
    shuffled, repeated and batched. A fixed pair of sample images is also loaded
    for the datasets listed in ``__init__``.
    """
    ROOT_PATH = './datasets/'

    def __init__(self, name, batch_size, img_shape=(256,256,3), patch=None):
        """
        :param name: dataset directory name under ``ROOT_PATH``
        :param batch_size: number of images per batch
        :param img_shape: (height, width, channels) used when ``patch`` is not given
        :param patch: optional (height, width); when given, images are resized to it
        :raises FileNotFoundError: if trainA or trainB contains no files
        :raises KeyError: if no sample image pair is registered for ``name``
        """
        self.name = name
        self.batch_size = batch_size
        self.img_dims = (tuple(patch) + (3,)) if patch else img_shape

        # locate the training files of both domains
        pathsA = glob(Data_Loader.ROOT_PATH + name + '/trainA/*')
        pathsB = glob(Data_Loader.ROOT_PATH + name + '/trainB/*')
        if not pathsA or not pathsB:
            raise FileNotFoundError(
                'No training images found for dataset {!r} under {}'.format(name, Data_Loader.ROOT_PATH))
        # An epoch is bounded by the smaller of the two domains.
        self.n_batches = min(len(pathsA), len(pathsB)) // self.batch_size

        self.dsA = self._build_dataset(pathsA)
        self.dsB = self._build_dataset(pathsB)

        self.itA = iter(self.dsA)
        self.itB = iter(self.dsB)

        # number of batches handed out so far by load_batch
        self.it_counter = 0

        samples = {
            'horse2zebra': ('datasets/horse2zebra/testA/n02381460_120.jpg', 'datasets/horse2zebra/testB/n02391049_1880.jpg'),
            'sim2larvae': ('datasets/sim2larvae/testM/view3151.png', 'datasets/sim2larvae/testB/img_331_8.png')
        }
        if name not in samples:
            raise KeyError('No sample images registered for dataset {!r}'.format(name))

        target_size = [self.img_dims[0], self.img_dims[1]]
        sample_a, sample_b = samples[name]
        # Each sample gets a leading batch axis of size 1.
        self.samples = np.expand_dims(_load_and_preprocess_image(sample_a, target_size), axis=0), \
                       np.expand_dims(_load_and_preprocess_image(sample_b, target_size), axis=0)

    def _build_dataset(self, paths):
        """Build the shuffled, repeated, batched and prefetched dataset for one domain."""
        target_size = [self.img_dims[0], self.img_dims[1]]
        ds = tf.data.Dataset.from_tensor_slices(paths)
        ds = ds.map(lambda img: _load_and_preprocess_image(img, img_dims=target_size),
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # The shuffle buffer covers the whole domain.
        ds = ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=len(paths)))
        ds = ds.batch(self.batch_size)
        return ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    def load_batch(self):
        """Return the next ``(batchA, batchB)`` pair and count it."""
        self.it_counter += 1
        return next(self.itA), next(self.itB)

    def batches_left(self):
        """Return how many batches remain in the current epoch.

        Equals ``n_batches`` before any batch has been loaded and right after an
        epoch boundary. Returns 0 when the dataset yields no full batch.
        """
        if self.n_batches == 0:
            return 0
        return self.n_batches - (self.it_counter % self.n_batches)

    def sample_batch(self):
        """Return the fixed ``(sampleA, sampleB)`` pair loaded in ``__init__``."""
        return self.samples
    
def _load_and_preprocess_image(path, img_dims):
    """Read an image file, decode it to 3 channels, resize it and scale it.

    :param path: path of the image file
    :param img_dims: target size passed to ``tf.image.resize``
    :return: the resized image divided by 127.5 and shifted by -1
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, img_dims)
    return (resized / 127.5) - 1

# Layers


In [0]:
import numpy as np
import tensorflow as tf

class Pad(tf.keras.layers.Layer):
    """Keras layer wrapping ``tf.pad`` with fixed paddings, mode and fill value."""

    def __init__(self, paddings, mode='CONSTANT', constant_values=0, **kwargs):
        """
        :param paddings: paddings argument forwarded to ``tf.pad``
        :param mode: padding mode forwarded to ``tf.pad``
        :param constant_values: fill value forwarded to ``tf.pad``
        :param kwargs: forwarded to the base ``Layer`` constructor
        """
        super().__init__(**kwargs)
        self.constant_values = constant_values
        self.mode = mode
        self.paddings = paddings

    def call(self, inputs):
        """Pad ``inputs`` using the configuration given at construction."""
        return tf.pad(
            inputs,
            self.paddings,
            mode=self.mode,
            constant_values=self.constant_values,
        )


class TauThreshold(tf.keras.layers.Layer):
    """Keras layer that zeroes every input value strictly below ``tau``."""

    def __init__(self, tau=0.1, **kwargs):
        """
        :param tau: threshold; values below it are replaced by 0
        :param kwargs: forwarded to the base ``Layer`` constructor
        """
        super().__init__(**kwargs)
        self.tau = tau

    def call(self, input_):
        """Return ``input_`` with entries below ``tau`` set to 0."""
        below_tau = tf.less(input_, self.tau)
        return tf.where(below_tau, tf.zeros_like(input_), input_)
    

class UpSample(tf.keras.layers.Layer):
    """Keras layer that resizes its input to a fixed size with nearest-neighbour sampling."""

    def __init__(self, size, **kwargs):
        """
        :param size: target size passed to ``tf.image.resize``
        :param kwargs: forwarded to the base ``Layer`` constructor
        """
        super().__init__(**kwargs)
        self.size = size

    def call(self, input_):
        """Resize ``input_`` to ``self.size``."""
        nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
        return tf.image.resize(input_, size=self.size, method=nearest)

# Learning Rate Scheduler

In [0]:
import tensorflow as tf
import json

class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Linear learning rate decay down to 0 applied after decay_offset_steps steps.

    The rate is ``lr`` while ``step < decay_offset_steps`` and then falls linearly,
    reaching 0 at ``total_steps``.
    """
    def __init__(self, lr, total_steps, decay_offset_steps):
        """
        :param lr: base learning rate
        :param total_steps: step at which the decayed rate reaches 0
        :param decay_offset_steps: number of initial steps run at the full rate
        """
        self.lr = lr
        self.total_steps = total_steps
        self.decay_offset_steps = decay_offset_steps

    @tf.function
    def __call__(self, step):
        """Return the learning rate for ``step``."""
        if step >= self.decay_offset_steps:
            # 1 / (decay window length) * (steps already spent inside the window)
            return self.lr * (1 - 1 / (self.total_steps - self.decay_offset_steps) * (step - self.decay_offset_steps))
        return self.lr

    def get_config(self):
        """Return the constructor arguments as a dict.

        Keras rebuilds a schedule with ``cls(**config)``, so the configuration has
        to be a plain dict rather than a JSON string.
        """
        return {
            'lr': self.lr,
            'total_steps': self.total_steps,
            'decay_offset_steps': self.decay_offset_steps
        }
    


# ItemPool

In [0]:
from tensorflow import stack
import random

class ItemPool(object):
    """Fixed-size history buffer of previously seen items.

    While the buffer is filling, every item is stored and returned unchanged.
    Once full, each call has a 50% chance of swapping the new item with a random
    stored one (returning the stored one) and a 50% chance of returning the new
    item untouched.
    """
    def __init__(self, size=50):
        """
        :param size: maximum number of stored items; a size <= 0 disables the pool
        """
        self.size = size
        self.queue = []

    def call(self, elem):
        """Push ``elem`` through the pool and return either it or an older item.

        Stored items get an extra leading axis (``np.expand_dims(elem, axis=0)``).
        NOTE(review): an item returned from the buffer therefore has one more
        dimension than ``elem`` - confirm callers expect that.
        """
        if self.size <= 0:
            # Nothing can be stored; also avoids randint(0, -1) below.
            return elem

        if len(self.queue) < self.size:
            self.queue.append(np.expand_dims(elem, axis=0))
            return elem

        if random.random() < .5:
            # replace a random element with the new one
            index = random.randint(0, self.size - 1)
            tmp = self.queue[index]
            self.queue[index] = np.expand_dims(elem, axis=0)
            return tmp

        # just return the current element without adding
        return elem


# Generator

In [0]:
import tensorflow as tf
import tensorflow_addons as tfa

def _g_conv_layer(input_, filters, filter_size, norm='instance', strides=2, pad='VALID', relu=True):
    """Generator block: Conv2D, optional instance normalisation, optional ReLU.

    :param input_: input tensor
    :param filters: number of convolution filters
    :param filter_size: convolution kernel size
    :param norm: instance normalisation is applied only when this equals 'instance'
    :param strides: convolution strides
    :param pad: convolution padding mode
    :param relu: a ReLU is appended only when this compares equal to True
    :return: output tensor of the block
    """
    convolution = tf.keras.layers.Conv2D(
        filters, kernel_size=filter_size, padding=pad, strides=strides)
    out = convolution(input_)

    apply_norm = norm == 'instance'
    if apply_norm:
        out = tfa.layers.InstanceNormalization()(out)

    apply_relu = relu == True
    if apply_relu:
        out = tf.keras.layers.ReLU()(out)
    return out

def _g_res_block(input_, norm='instance'):
    """Generator residual block: two reflect-padded 3x3 convolutions plus a skip connection.

    The first convolution is followed by a ReLU, the second is not; the block
    output is ``input_ + conv path``. The number of filters equals the channel
    count of ``input_`` so the addition is shape-compatible.

    :param input_: input tensor
    :param norm: forwarded to ``_g_conv_layer``
    :return: output tensor of the block
    """
    filters = input_.shape[-1]
    out = Pad([[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')(input_)
    # relu=False here: the explicit ReLU below is the only activation needed.
    # Previously a second ReLU was stacked on an already rectified tensor,
    # which is a no-op; this now mirrors _a_res_block.
    out = _g_conv_layer(out, filters, norm=norm, filter_size=3, strides=1, pad='VALID', relu=False)
    out = tf.keras.layers.ReLU()(out)

    out = Pad([[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')(out)
    out = _g_conv_layer(out, filters, norm=norm, filter_size=3, strides=1, pad='VALID', relu=False)
    out = tf.keras.layers.add([input_, out])
    return out

def _g_deconv_layer(input_, filters, filter_size, pad='SAME', norm='instance'):
    """Generator upsampling block: stride-2 Conv2DTranspose, optional instance norm, ReLU.

    :param input_: input tensor
    :param filters: number of output filters
    :param filter_size: transposed-convolution kernel size
    :param pad: padding mode of the transposed convolution
    :param norm: instance normalisation is applied only when this equals 'instance'
    :return: output tensor of the block
    """
    # The unused `size` computation was removed: it only served the disabled
    # resize+conv variant and raised a TypeError for unknown spatial dimensions.
    out = tf.keras.layers.Conv2DTranspose(filters, kernel_size=filter_size, strides=2, padding=pad)(input_)
    if norm == 'instance':
        out = tfa.layers.InstanceNormalization()(out)
    out = tf.keras.layers.ReLU()(out)
    return out


def build_generator(input_shape, name):
    """Build the generator network as a Keras model.

    Layout: reflect-padded 7x7 convolution, two stride-2 convolutions, nine
    residual blocks, two stride-2 transposed convolutions, and a reflect-padded
    7x7 convolution to 3 channels with tanh.

    :param input_shape: shape of one input image (without the batch axis)
    :param name: name given to the returned model
    :return: ``tf.keras.Model`` mapping an image to a generated image
    """
    norm = 'instance'

    input_ = tf.keras.layers.Input(shape=input_shape)

    # c7s1-32-R
    x = Pad([[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')(input_)
    x = _g_conv_layer(x, 32, filter_size=7, strides=1, pad='VALID', norm=norm)

    # downsampling (stride-2 convolutions): c3s2-64-R, c3s2-128-R
    for n_filters in (64, 128):
        x = _g_conv_layer(x, filters=n_filters, filter_size=3, strides=2, pad='SAME', norm=norm)

    # residual blocks: r128 * 9
    for _ in range(9):
        x = _g_res_block(x)

    # upsampling (stride-2 transposed convolutions): tc64s2, tc32s2
    for n_filters in (64, 32):
        x = _g_deconv_layer(x, n_filters, 3, pad='SAME', norm=norm)

    # c7s1-3-T
    x = Pad([[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')(x)
    x = tf.keras.layers.Conv2D(filters=3, kernel_size=7, strides=1, padding='VALID')(x)  # 3 img channels
    x = tf.keras.layers.Activation('tanh')(x)

    return tf.keras.Model(input_, x, name=name)

# Discriminator

In [0]:
import tensorflow as tf
import tensorflow_addons as tfa

def _discriminator_layer(input_, input_nn, filters, filter_size, strides, norm='instance'):
    """
    One discriminator stage applied to two parallel paths that share weights.

    The same convolution and LeakyReLU layers are applied to both inputs; only
    the first path gets instance normalisation in between.

    :param input_: input of the normalised path
    :param input_nn: input of the path without normalisation
    :param filters: number of convolution filters
    :param filter_size: convolution kernel size
    :param strides: convolution strides
    :param norm: the first path is instance-normalised only when this equals 'instance'
    :return: Two discriminator blocks (with/without normalization layers) with shared layers
    """
    # One convolution layer, called on each path. Previously both outputs were
    # computed from `input_`, so `input_nn` was ignored and the no-norm path
    # was fed from the normalised one.
    conv = tf.keras.layers.Conv2D(filters, kernel_size=filter_size, strides=strides, padding='SAME')
    layer = conv(input_)
    layer_nn = conv(input_nn)
    if norm=='instance':
        layer = tfa.layers.InstanceNormalization()(layer)
    relu = tf.keras.layers.LeakyReLU(alpha=0.2)
    layer = relu(layer)
    layer_nn = relu(layer_nn)

    return layer, layer_nn


def build_discriminator(input_shape, name):
    """
    Build a pair of discriminators that share their layers.

    Four ``_discriminator_layer`` stages (64, 128, 256 filters with stride 2, then
    512 filters with stride 1) are followed by one shared 4x4 convolution with a
    single output channel.

    :param input_shape: shape of one input image (without the batch axis)
    :param name: name of the first model; the second is named ``name + '_nn'``
    :return: Two discriminators(with/without normalization) with shared layers
    """
    input_ = tf.keras.layers.Input(shape=input_shape)

    # (filters, strides) per stage: c4s2-64-LR, c4s2-128-LR, c4s2-256-LR, c4s1-512-LR
    stage_specs = ((64, 2), (128, 2), (256, 2), (512, 1))

    dis = dis_nn = input_
    for n_filters, stride in stage_specs:
        dis, dis_nn = _discriminator_layer(dis, dis_nn, n_filters, filter_size=4, strides=stride, norm='instance')

    # c4s1-1, shared by both paths
    head = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='SAME')
    out_norm = head(dis)
    out_no_norm = head(dis_nn)

    model_norm = tf.keras.Model(input_, out_norm, name=name)
    model_no_norm = tf.keras.Model(input_, out_no_norm, name=name+'_nn')
    return model_norm, model_no_norm

# Attention

In [0]:
import tensorflow as tf
import tensorflow_addons as tfa

def _a_conv_layer(input_, filters, filter_size, norm='instance', strides=2, pad='SAME', relu=True):
    """Attention-network block: Conv2D, optional instance normalisation, optional ReLU.

    :param input_: input tensor
    :param filters: number of convolution filters
    :param filter_size: convolution kernel size
    :param norm: instance normalisation is applied only when this equals 'instance'
    :param strides: convolution strides
    :param pad: convolution padding mode
    :param relu: a ReLU is appended only when this compares equal to True
    :return: output tensor of the block
    """
    convolution = tf.keras.layers.Conv2D(
        filters, kernel_size=filter_size, padding=pad, strides=strides)
    out = convolution(input_)

    wants_norm = norm == 'instance'
    if wants_norm:
        out = tfa.layers.InstanceNormalization()(out)

    wants_relu = relu == True
    if wants_relu:
        out = tf.keras.layers.ReLU()(out)
    return out

def _a_res_block(input_, norm='instance'):
    """Attention-network residual block: two reflect-padded 3x3 convolutions plus a skip connection.

    A ReLU follows the first convolution only; the result of the second
    convolution is added to ``input_``.

    :param input_: input tensor
    :param norm: forwarded to ``_a_conv_layer``
    :return: output tensor of the block
    """
    n_channels = input_.shape[-1]

    branch = Pad([[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')(input_)
    branch = _a_conv_layer(branch, n_channels, norm=norm, filter_size=3, strides=1, pad='VALID', relu=False)
    branch = tf.keras.layers.ReLU()(branch)

    branch = Pad([[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')(branch)
    branch = _a_conv_layer(branch, n_channels, norm=norm, filter_size=3, strides=1, pad='VALID', relu=False)

    # skip connection
    return tf.keras.layers.add([input_, branch])


def build_attention_net(input_shape, name):
    """Build the attention network as a Keras model.

    Layout: 7x7 convolution, stride-2 3x3 convolution, one residual block,
    2x nearest-neighbour upsampling, two stride-1 3x3 convolutions and a final
    7x7 convolution to one channel followed by a sigmoid.

    :param input_shape: shape of one input image (without the batch axis)
    :param name: name given to the returned model
    :return: ``tf.keras.Model`` mapping an image to a single-channel map in (0, 1)
    """
    model = input_ = tf.keras.layers.Input(shape=input_shape)

    # c7s1-32-R
    model = _a_conv_layer(model, filters=32, filter_size=7,  strides=1)
    # c3s2-64-R
    model = _a_conv_layer(model, filters=64, filter_size=3)
    # r64
    # NOTE(review): this call used to pass a positional `3`, which binds to
    # `norm` and therefore disables instance normalisation inside the block.
    # That behaviour is kept (and made explicit) so existing weights stay
    # loadable - confirm whether norm='instance' was intended.
    model = _a_res_block(model, norm=None)
    # up2
    model = tf.keras.layers.UpSampling2D(size=2, interpolation='nearest')(model)
    # c3s1-64-R
    model = _a_conv_layer(model, strides=1, filters=64, filter_size=3)
    # c3s1-32-R
    model = _a_conv_layer(model, strides=1, filters=32, filter_size=3)
    # c7s1-1-S
    model = _a_conv_layer(model, strides=1, filters=1, filter_size=7, relu=False)
    # Use an Activation layer (as build_generator does for tanh) instead of
    # calling the raw activation function on a symbolic tensor.
    model = tf.keras.layers.Activation('sigmoid')(model)

    return tf.keras.Model(input_, model, name=name)

# Losses

In [0]:
def cyclic_loss(gen_img, real_img):
    """Mean absolute difference between ``real_img`` and ``gen_img``."""
    abs_diff = tf.abs(real_img - gen_img)
    return tf.reduce_mean(abs_diff)

def adversarial_loss(prediction_on_real, prediction_on_fake, prediction_on_cyclic):
    """Least-squares discriminator loss over real, fake and cyclic predictions.

    The real term (target 1) is weighted by 2; the fake and cyclic terms
    (target 0) each have weight 1.
    """
    real_term = 2 * tf.reduce_mean(tf.math.squared_difference(prediction_on_real, 1))
    fake_term = tf.reduce_mean(tf.math.squared_difference(prediction_on_fake, 0))
    cyclic_term = tf.reduce_mean(tf.math.squared_difference(prediction_on_cyclic, 0))
    return real_term + fake_term + cyclic_term

def generator_adversarial_loss(dis_prediction_of_img):
    """Least-squares generator loss: mean squared distance of the prediction from 1."""
    squared_error = tf.math.squared_difference(dis_prediction_of_img, 1)
    return tf.reduce_mean(squared_error)



# Model

In [0]:
import sys
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import datetime
import os
from zipfile import ZipFile


class Model2(tf.keras.Model):
    def __init__(self, epochs, data_loader, patch=(70,70), load_from_checkpoint=True, use_att=True, tau=0.1, attention_epochs_threshold=30, lr_d=0.002, lr_g=0.005, beta1=0.5, epoch_decay=100, pool_size=50, train_start=0):
        """Build the discriminators, generators, attention networks and the two combined models.

        :param epochs: total number of training epochs
        :param data_loader: data source; ``name``, ``img_dims`` and ``n_batches`` are read here
        :param patch: accepted for interface compatibility; not used in this method
        :param load_from_checkpoint: load weights from ``output/<dataset>/weights.h5`` if that file exists
        :param use_att: stored as ``self.use_att``
        :param tau: threshold applied to the attention maps
        :param attention_epochs_threshold: stored as ``self.att_epoch_thresh``
        :param lr_d: base learning rate of the discriminator optimizer
        :param lr_g: base learning rate of the generator optimizer
        :param beta1: ``beta_1`` of both Adam optimizers
        :param epoch_decay: number of epochs trained at the full learning rate before linear decay
        :param pool_size: stored as ``self.pool_size``
        :param train_start: stored as ``self.train_start``
        """
        # Run the tf.keras.Model initialiser itself; super(tf.keras.Model, self)
        # skipped it and started at its parent class.
        super(Model2, self).__init__()
        self.data = data_loader
        self.weights_path = 'output/' + self.data.name + '/weights.h5'
        self.epochs = epochs
        self.pool_size = pool_size
        self.use_att = use_att
        self.tau = tau
        self.att_epoch_thresh = attention_epochs_threshold
        self.train_start = train_start

        # Loss weights
        # NOTE(review): these weights are not passed to compile() below
        # (loss_weights is disabled), so every loss currently has weight 1.
        self.lambda_cyclic = 10.0  # Cycle-consistency loss weight
        self.lambda_id = 0.1 * self.lambda_cyclic  # Identity loss weight

        # build models
        self.disA, self.disA_no_norm = build_discriminator(self.data.img_dims, 'disA')
        self.disB, self.disB_no_norm = build_discriminator(self.data.img_dims, 'disB')

        # Output shape of D (PatchGAN), read from the built discriminator instead
        # of a hard-coded (32, 32, 1), which only holds for 256x256 inputs.
        self.disc_patch = tuple(self.disA.outputs[0].shape.as_list()[1:])

        self.genA = build_generator(self.data.img_dims, 'genA')
        self.genB = build_generator(self.data.img_dims, 'genB')
        self.attA = build_attention_net(self.data.img_dims, 'attA')
        self.attB = build_attention_net(self.data.img_dims, 'attB')

        # lr schedulers
        # Steps are optimizer updates, i.e. batches. batches_left() was used here
        # before, but it is 0 at construction time, and the decay offset was
        # multiplied by the epoch count instead of the batches per epoch.
        steps_per_epoch = self.data.n_batches
        _total_steps = steps_per_epoch * self.epochs
        _decay_offset_steps = epoch_decay * steps_per_epoch
        if _decay_offset_steps < _total_steps:
            lr_sched_d = LinearDecay(lr_d, _total_steps, _decay_offset_steps)
            lr_sched_g = LinearDecay(lr_g, _total_steps, _decay_offset_steps)
        else:
            # Training ends before decay would start: use constant rates
            # (this also avoids a zero-length decay window).
            lr_sched_d = lr_d
            lr_sched_g = lr_g

        # set the optimizers
        self.optimizer_d = tf.keras.optimizers.Adam(learning_rate=lr_sched_d, beta_1=beta1)
        self.optimizer_g = tf.keras.optimizers.Adam(learning_rate=lr_sched_g, beta_1=beta1)

        #TODO move?
        if load_from_checkpoint:
            if os.path.exists(self.weights_path):
                self.load_weights(self.weights_path)
            else:
                # A first run has no checkpoint yet; start from fresh weights.
                print('No checkpoint found at {}, starting from scratch'.format(self.weights_path))

        # inputs
        imgA = tf.keras.layers.Input(shape=self.data.img_dims)
        imgB = tf.keras.layers.Input(shape=self.data.img_dims)

        # get attention maps
        # attnMapA = toZeroThreshold(AttnA(realA))
        attA = TauThreshold(self.tau)(self.attA(imgA))
        attB = TauThreshold(self.tau)(self.attB(imgB))

        # fgA = attnMapA * realA
        imgA_fg = imgA * attA
        imgB_fg = imgB * attB

        # bgA = (1 - attnMapA) * realA
        imgA_bg = (1 - attA) * imgA
        imgB_bg = (1 - attB) * imgB

        # IMAGE TRANSLATION
        # genB = genA2B(fgA)
        fakeA_fg = self.genA(imgB_fg)
        fakeB_fg = self.genB(imgA_fg)

        # fakeB = (attnMapA * genB) + bgA
        fakeB = fakeB_fg * attA + imgA_bg # s'
        fakeA = fakeA_fg * attB + imgB_bg

        # CYCLIC TRANSLATION
        # get attention maps
        # attnMapfakeB = toZeroThreshold(AttnB(fakeB))
        attA_fake = TauThreshold(self.tau)(self.attA(fakeA))
        attB_fake = TauThreshold(self.tau)(self.attB(fakeB))

        # get the foreground (this rebinds fakeA_fg / fakeB_fg to the attended fakes)
        # fgfakeB = attnMapfakeB * fakeB
        fakeA_fg = fakeA * attA_fake
        fakeB_fg = fakeB * attB_fake

        # get the background
        # bgfakeB = (1 - attnMapfakeB) * fakeB
        # (cyclicA_bg is the background of fakeA, cyclicB_bg that of fakeB)
        cyclicA_bg = (1 - attA_fake) * fakeA
        cyclicB_bg = (1 - attB_fake) * fakeB

        # genA_ = genB2A(fgfakeB)
        cyclicA_fg = self.genA(fakeB_fg)
        cyclicB_fg = self.genB(fakeA_fg)

        # combine
        # A_ = (attnMapfakeB * genA_) + bgfakeB
        cyclicA = attB_fake * cyclicA_fg + cyclicB_bg # s''
        cyclicB = attA_fake * cyclicB_fg + cyclicA_bg # s''

        # compile basic discriminators
        self.disA.compile(loss='mse', optimizer=self.optimizer_d, metrics=['accuracy'])
        self.disB.compile(loss='mse', optimizer=self.optimizer_d, metrics=['accuracy'])
        self.disA_no_norm.compile(loss='mse', optimizer=self.optimizer_d, metrics=['accuracy'])
        self.disB_no_norm.compile(loss='mse', optimizer=self.optimizer_d, metrics=['accuracy'])

        #TODO attention to identity mappings?

        # combined model only trains generators (and attention)
        self.disA.trainable = False
        self.disB.trainable = False
        self.disA_no_norm.trainable = False
        self.disB_no_norm.trainable = False

        # discriminate the fake images
        # stage 1 - whole image with normalisation
        validityA_stage1 = self.disA(fakeA)
        validityB_stage1 = self.disB(fakeB)

        cyclic_valA_stage1 = self.disA(cyclicA)
        cyclic_valB_stage1 = self.disB(cyclicB)

        # use dis without normalisation on the foreground only
        validityA_stage2 = self.disA_no_norm(fakeA_fg)
        validityB_stage2 = self.disB_no_norm(fakeB_fg)

        cyclic_valA_stage2 = self.disA_no_norm(cyclicA_fg)
        cyclic_valB_stage2 = self.disB_no_norm(cyclicB_fg)

        # build and compile combined model for stage 1
        self.combined_model1 = tf.keras.Model(inputs=[imgA, imgB],
                                             outputs=[validityA_stage1, validityB_stage1, cyclic_valA_stage1, cyclic_valB_stage1, cyclicA, cyclicB])
        self.combined_model1.compile(optimizer=self.optimizer_g,
                                    loss=['mse', 'mse', 'mse', 'mse', 'mae', 'mae'])

        # no more attention training in stage 2
        self.attA.trainable = False
        self.attB.trainable = False

        # build and compile combined model for stage 2
        self.combined_model2 = tf.keras.Model(inputs=[imgA, imgB],
                                                    outputs=[validityA_stage2, validityB_stage2, cyclic_valA_stage2, cyclic_valB_stage2, cyclicA, cyclicB])
        self.combined_model2.compile(optimizer=self.optimizer_g,
                                           loss=['mse', 'mse', 'mse', 'mse', 'mae', 'mae'])
#         self.combined_model1.summary()
#         print('_________-___________')
#         self.combined_model2.summary()



    def train(self, sample_interval=5):
        
        start_time = datetime.datetime.now()

        # fake img pools for dis
        self.fakeA_pool = ItemPool(size=self.pool_size)
        self.fakeB_pool = ItemPool(size=self.pool_size)
        self.fakeA_cyclic_pool = ItemPool(size=self.pool_size)
        self.fakeB_cyclic_pool = ItemPool(size=self.pool_size)
        
        self.fakeA_fg_pool = ItemPool(size=self.pool_size)
        self.fakeB_fg_pool = ItemPool(size=self.pool_size)
        self.fakeA_cyclic_fg_pool = ItemPool(size=self.pool_size)
        self.fakeB_cyclic_fg_pool = ItemPool(size=self.pool_size)
        
        # adversarial loss ground truth
        valid = np.ones((self.data.batch_size,) + self.disc_patch)
        fake = np.zeros((self.data.batch_size,) + self.disc_patch)

        is_after_att_thres = False
       
        avg_loss_disA_real = tf.keras.metrics.Mean(name='loss_disA_real', dtype=tf.float32)
        avg_loss_disA_fake = tf.keras.metrics.Mean(name='loss_disA_fake', dtype=tf.float32)
        avg_loss_disA_cyclic = tf.keras.metrics.Mean(name='loss_disA_cyclic', dtype=tf.float32)
        avg_loss_genA_cyclic = tf.keras.metrics.Mean(name='loss_genA_cyclic', dtype=tf.float32)
        avg_loss_disA_total = tf.keras.metrics.Mean(name='loss_disA_total', dtype=tf.float32)
        
        avg_loss_disB_real = tf.keras.metrics.Mean(name='loss_disB_real', dtype=tf.float32)
        avg_loss_disB_fake = tf.keras.metrics.Mean(name='loss_disB_fake', dtype=tf.float32)
        avg_loss_disB_cyclic = tf.keras.metrics.Mean(name='loss_disB_cyclic', dtype=tf.float32)
        avg_loss_genB_cyclic = tf.keras.metrics.Mean(name='loss_genB_cyclic', dtype=tf.float32)
        avg_loss_disB_total = tf.keras.metrics.Mean(name='loss_disB_total', dtype=tf.float32)
        
        avg_loss_dis_total = tf.keras.metrics.Mean(name='loss_dis_total', dtype=tf.float32)
        avg_loss_gen_total = tf.keras.metrics.Mean(name='avg_loss_dis_total', dtype=tf.float32)
        
        os.makedirs('summaries', exist_ok=True)
        summary_writer = tf.summary.create_file_writer('summaries')
        with summary_writer.as_default():
          for epoch in range(self.train_start+1, self.epochs+1):
              self.cur_epoch = epoch
              if epoch >= self.att_epoch_thresh:
                  is_after_att_thres = True

              for batch_i in range(1, self.data.n_batches+1):
                  imgA, imgB = self.data.load_batch()

                  # GENERATE IMAGES
                  # get attention maps
                  # attnMapA = toZeroThreshold(AttnA(realA))
                  attA = tau_thres(self.attA.predict(imgA), tau=self.tau)
                  attB = tau_thres(self.attB.predict(imgB), tau=self.tau)

                  # fgA = attnMapA * realA
                  imgA_fg = imgA * attA
                  imgB_fg = imgB * attB

                  # bgA = (1 - attnMapA) * realA
                  imgA_bg = (1 - attA) * imgA
                  imgB_bg = (1 - attB) * imgB

                  # IMAGE TRANSLATION
                  # genB = genA2B(fgA) 
                  fakeA_fg = self.genA.predict(imgB_fg)
                  fakeB_fg = self.genB.predict(imgA_fg)

                  # fakeB = (attnMapA * genB) + bgA
                  fakeB = fakeB_fg * attA + imgA_bg # s'
                  fakeA = fakeA_fg * attB + imgB_bg

                  # CYCLIC TRANSLATION
                  # get attention maps
                  # attnMapfakeB = toZeroThreshold(AttnB(fakeB))
                  attA_fake = tau_thres(self.attA.predict(fakeA), tau=self.tau)
                  attB_fake = tau_thres(self.attB.predict(fakeB), tau=self.tau)

                  # get the foreground
                  # fgfakeB = attnMapfakeB * fakeB
                  fakeA_fg = fakeA * attA_fake
                  fakeB_fg = fakeB * attB_fake

                  # get the background
                  # bgfakeB = (1 - attnMapfakeB) * fakeB
                  cyclicA_bg = (1 - attA_fake) * fakeA
                  cyclicB_bg = (1 - attB_fake) * fakeB

                  # genA_ = genB2A(fgfakeB)
                  cyclicA_fg = self.genA.predict(fakeB_fg)
                  cyclicB_fg = self.genB.predict(fakeA_fg)

                  # combine
                  # A_ = (attnMapfakeB * genA_) + bgfakeB
                  cyclicA = attB_fake * cyclicA_fg + cyclicB_bg # s''
                  cyclicB = attA_fake * cyclicB_fg + cyclicA_bg # s''

                  # TRAIN DISCRIMINATORS
                  if is_after_att_thres:
                      # pool management
  #                     fakeA_fg_pool = self.fakeA_fg_pool.call(fakeA_fg)
  #                     fakeB_fg_pool = self.fakeB_fg_pool.call(fakeB_fg)
  #                     cyclicA_fg_pool = self.fakeA_cyclic_fg_pool.call(cyclicA_fg)
  #                     cyclicB_fg_pool = self.fakeB_cyclic_fg_pool.call(cyclicB_fg)

                      # discriminate only on the foreground
                      # real losses
                      real_predA = self.disA_no_norm.train_on_batch(imgA_fg, valid, sample_weight=np.array([2]))
                      real_predB = self.disB_no_norm.train_on_batch(imgB_fg, valid, sample_weight=np.array([2]))


                      # fake losses
                      fake_predA = self.disA_no_norm.train_on_batch(fakeA_fg, fake)
                      fake_predB = self.disB_no_norm.train_on_batch(fakeB_fg, fake)

                      # fake cyclic losses
                      cyclic_fake_predA = self.disA_no_norm.train_on_batch(cyclicA_fg, fake)
                      cyclic_fake_predB = self.disB_no_norm.train_on_batch(cyclicB_fg, fake)
                  else:
                      # pool management
  #                     fakeA_pool = self.fakeA_pool.call(fakeA)
  #                     fakeB_pool = self.fakeB_pool.call(fakeB)
  #                     cyclicA_pool = self.fakeA_cyclic_pool.call(cyclicA)
  #                     cyclicB_pool = self.fakeB_cyclic_pool.call(cyclicB)


                      # real loss
                      real_predA = self.disA.train_on_batch(imgA, valid, sample_weight=np.array([2]))
                      real_predB = self.disB.train_on_batch(imgB, valid, sample_weight=np.array([2]))

                      # fake loss
                      fake_predA = self.disA.train_on_batch(fakeA, fake)
                      fake_predB = self.disB.train_on_batch(fakeB, fake) 

                      # cyclic fake loss
                      cyclic_fake_predA = self.disA.train_on_batch(cyclicA, fake)
                      cyclic_fake_predB = self.disB.train_on_batch(cyclicB, fake)

                  # dis losses total
                  disA_loss = adversarial_loss(real_predA, fake_predA, cyclic_fake_predA)
                  disB_loss = adversarial_loss(real_predB, fake_predB, cyclic_fake_predA)
                  dis_total_loss = disA_loss + disB_loss


                  # TRAIN GENERATORS
                  if is_after_att_thres:
                      gen_loss = self.combined_model2.train_on_batch([imgA, imgB],
                                                                           [valid, valid, valid, valid, imgA, imgB])#, imgs1, imgs2])
                  else:
                      gen_loss = self.combined_model1.train_on_batch([imgA, imgB],
                                                                        [valid, valid, valid, valid, imgA, imgB])#, imgs1, imgs2])

                  avg_loss_disA_real.update_state(real_predA)
                  avg_loss_disB_real.update_state(real_predB)
                  avg_loss_disA_fake.update_state(fake_predA)
                  avg_loss_disB_fake.update_state(fake_predB)
                  avg_loss_disA_cyclic.update_state(cyclic_fake_predA)
                  avg_loss_disB_cyclic.update_state(cyclic_fake_predB)
                  avg_loss_disA_total.update_state(disA_loss)
                  avg_loss_disB_total.update_state(disB_loss)
                  avg_loss_genA_cyclic.update_state(gen_loss[5])
                  avg_loss_genB_cyclic.update_state(gen_loss[6])

                  avg_loss_gen_total.update_state(gen_loss[0])
                  avg_loss_dis_total.update_state(dis_total_loss)
                  
                  
                  elapsed_time = datetime.datetime.now() - start_time
                  if batch_i % 50 == 0:
                      print(
                          "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %05f] time: %s " \
                          % (epoch, self.epochs,
                             batch_i, self.data.n_batches,
                             dis_total_loss,
                             gen_loss[0],
                             elapsed_time))
                      
                      
                      tf.summary.scalar('avg_loss_disA_real', avg_loss_disA_real.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_disA_fake', avg_loss_disA_fake.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_disA_cyclic', avg_loss_disA_cyclic.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_genA_cyclic', avg_loss_genA_cyclic.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_disA_total', avg_loss_disA_total.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_disB_real', avg_loss_disB_real.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_disB_fake', avg_loss_disB_fake.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_disB_cyclic', avg_loss_disB_cyclic.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_genB_cyclic', avg_loss_genB_cyclic.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_disB_total', avg_loss_disB_total.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_dis_total', avg_loss_dis_total.result(), step=self.optimizer_g.iterations)
                      tf.summary.scalar('avg_loss_gen_total', avg_loss_gen_total.result(), step=self.optimizer_g.iterations)

                      avg_loss_disA_real.reset_states()
                      avg_loss_disA_fake.reset_states()
                      avg_loss_disA_cyclic.reset_states()
                      avg_loss_genA_cyclic.reset_states()
                      avg_loss_disA_total.reset_states()
                      avg_loss_disB_real.reset_states()
                      avg_loss_disB_fake.reset_states()
                      avg_loss_disB_cyclic.reset_states()
                      avg_loss_genB_cyclic.reset_states()
                      avg_loss_disB_total.reset_states()
                      avg_loss_dis_total.reset_states()
                      avg_loss_gen_total.reset_states()

                  # printing
  #                 print(
  #                     "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f] time: %s " \
  #                     % (epoch + 1, self.epochs,
  #                        batch_i + 1, self.data.n_batches,
  #                        dis_total_loss[0], 100 * dis_total_loss[1],
  #                        gen_loss[0],
  #                        np.mean(gen_loss[1:3]),
  #                        np.mean(gen_loss[3:5]),
  #                        # np.mean(gen_loss[5:6]),
  #                        elapsed_time))

                  # If at save interval => save generated image samples
                  if batch_i % 1000 == 0:
                      it = batch_i // 1000
                      self.sample_images(epoch, it)

                      mod_name = 'full_orig'
                      os.makedirs('model/%s/%s' % (self.data.name, mod_name), exist_ok=True)
          #                 save_time = datetime.datetime.now()

                      self.combined_model1.save_weights('model/{}/{}/combined1_{:03d}_{:d}.h5'.format(self.data.name, mod_name, epoch, it))
                      self.combined_model2.save_weights('model/{}/{}/combined2_{:03d}_{:d}.h5'.format(self.data.name, mod_name, epoch, it))
                      self.disA.save_weights('model/{}/{}/disA{:03d}_{:d}.h5'.format(self.data.name, mod_name, epoch, it))
                      self.disB.save_weights('model/{}/{}/disB{:03d}_{:d}.h5'.format(self.data.name, mod_name, epoch, it))
          #                 self.disA_no_norm.save_weights('model/%s/%s/disA_nn%d' % (self.data.name, mod_name, epoch))
          #                 self.disB_no_norm.save_weights('model/%s/%s/disB_nn%d' % (self.data.name, mod_name, epoch))
          #                 print('Saving models took: ' + str(datetime.datetime.now() - save_time))
                      self.save_things_to_drive(epoch, it)


    def load_weights(self, comb1_p, comb2_p, disA_p, disB_p):
        """Load previously saved weights into the four trainable models.

        Args:
            comb1_p: Weights path handed to ``self.combined_model1``.
            comb2_p: Weights path handed to ``self.combined_model2``.
            disA_p: Weights path handed to ``self.disA``.
            disB_p: Weights path handed to ``self.disB``.
        """
        # (attribute name, path) pairs, restored in this order.
        targets = (
            ('combined_model1', comb1_p),
            ('combined_model2', comb2_p),
            ('disA', disA_p),
            ('disB', disB_p),
        )
        for attr_name, weights_path in targets:
            getattr(self, attr_name).load_weights(weights_path)
        
        
    def sample_images(self, epoch, it):
        """Save a 2x9 diagnostic figure of the attention-guided translation.

        One batch is drawn from ``self.data`` and only its first sample is
        plotted. Row 1 traces A -> B -> A, row 2 traces B -> A -> B. Columns,
        left to right: input image, its thresholded attention map, attended
        foreground, raw generator output for that foreground, composed
        translation, attention map of the translation, attended foreground of
        the translation, raw generator output for that foreground, composed
        cyclic reconstruction.

        The figure is written to ``output/<dataset name>/att<epoch>_<it>.png``;
        the directory is created if it does not exist.

        Args:
            epoch: Integer epoch number, zero-padded to three digits in the
                file name.
            it: Integer sample index within the epoch, used in the file name.
        """
        rows, cols = 2, 9

        def thresholded_attention(att_net, images):
            # Attention prediction passed through TauThreshold(self.tau);
            # a new TauThreshold is built per call, as before.
            return TauThreshold(self.tau)(att_net.predict(images))

        def as_rgb(att_map):
            # Shift the map so the shared `0.5 * x + 0.5` rescale below brings
            # it back to its original values, then repeat it three times along
            # the channel axis (assumes a single-channel map -- TODO confirm).
            att_map = 2 * (att_map - 0.5)
            return np.concatenate([att_map] * 3, axis=3)

        imgA, imgB = self.data.sample_batch()

        # Attention maps of the real images.
        attA = thresholded_attention(self.attA, imgA)
        attB = thresholded_attention(self.attB, imgB)

        # Split each real image into attended foreground and remaining background.
        imgA_fg = imgA * attA
        imgB_fg = imgB * attB
        imgA_bg = (1 - attA) * imgA
        imgB_bg = (1 - attB) * imgB

        # IMAGE TRANSLATION
        # Raw generator outputs. These keep their own names: the cyclic step
        # below also computes a "foreground of the fake", and sharing one name
        # for both made the 'Translated fg' column repeat the 'Fake fg' column.
        transA_fg = self.genA.predict(imgB_fg)
        transB_fg = self.genB.predict(imgA_fg)

        # Paste the translated foreground back onto the untouched background.
        fakeB = transB_fg * attA + imgA_bg
        fakeA = transA_fg * attB + imgB_bg

        # CYCLIC TRANSLATION
        # Attention maps of the translated images.
        attA_fake = thresholded_attention(self.attA, fakeA)
        attB_fake = thresholded_attention(self.attB, fakeB)

        # Foreground / background split of the translated images.
        fakeA_fg = fakeA * attA_fake
        fakeB_fg = fakeB * attB_fake
        fakeA_bg = (1 - attA_fake) * fakeA
        fakeB_bg = (1 - attB_fake) * fakeB

        # Translate the fake foregrounds back to the source domain.
        cyclicA_fg = self.genA.predict(fakeB_fg)
        cyclicB_fg = self.genB.predict(fakeA_fg)

        # Reconstruction = attended generator output + background of the fake
        # it was generated from.
        cyclicA = attB_fake * cyclicA_fg + fakeB_bg
        cyclicB = attA_fake * cyclicB_fg + fakeA_bg

        panels = [
            imgA, as_rgb(attA), imgA_fg, transB_fg, fakeB,
            as_rgb(attB_fake), fakeB_fg, cyclicA_fg, cyclicA,
            imgB, as_rgb(attB), imgB_fg, transA_fg, fakeA,
            as_rgb(attA_fake), fakeA_fg, cyclicB_fg, cyclicB,
        ]
        # One image per grid cell: keep only the first sample of every panel so
        # the row-major indexing below stays valid for batch sizes above 1.
        gen_imgs = np.concatenate([panel[:1] for panel in panels])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Attention', 'Original fg', 'Translated fg', 'Translated', 'Attention', 'Fake fg', 'Cyclic fg', 'Cyclic']
        fig, axs = plt.subplots(rows, cols, figsize=(64, 64))

        # axs.flat walks the grid row by row, matching the order of `panels`.
        for idx, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[idx])
            ax.set_title(titles[idx % cols])
            ax.axis('off')
        fig.tight_layout()

        os.makedirs('output/{}'.format(self.data.name), exist_ok=True)
        fig.savefig("output/{}/att{:0>3d}_{:d}.png".format(self.data.name, epoch, it))
        # Close this figure explicitly. A following plt.clf() would make pyplot
        # create a new empty figure and leave it open.
        plt.close(fig)

    def save_things_to_drive(self, epoch, it):
        """Pack the checkpoint files and sample image for (epoch, it) into a zip.

        The archive ``attGAN_orig_<epoch>_<it>.zip`` is written to the current
        working directory and contains the two combined-model weight files,
        the two discriminator weight files and the attention sample PNG, each
        stored under its relative path. All five files must already exist.

        Args:
            epoch: Integer epoch number used in the file names.
            it: Integer sample index within the epoch used in the file names.
        """
        mod_name = 'full_orig'
        dataset = self.data.name

        # The four weight files differ only in their name template.
        weight_templates = (
            'model/{}/{}/combined1_{:03d}_{:d}.h5',
            'model/{}/{}/combined2_{:03d}_{:d}.h5',
            'model/{}/{}/disA{:03d}_{:d}.h5',
            'model/{}/{}/disB{:03d}_{:d}.h5',
        )
        members = [template.format(dataset, mod_name, epoch, it) for template in weight_templates]
        members.append("output/{}/att{:0>3d}_{:d}.png".format(dataset, epoch, it))

        archive_name = 'attGAN_orig_{:d}_{:d}.zip'.format(epoch, it)
        with ZipFile(archive_name, 'w') as archive:
            for member in members:
                archive.write(member)
                
#         upload = self.drive.CreateFile({"title":name})
#         upload.SetContentFile('/content/' + name)
#         upload.Upload()

# Run

In [0]:
# import os
# import argparse

class Dummy():
    """Empty attribute container used in place of an argparse namespace."""
    pass


# Run configuration: every tunable value for this training run lives on `args`.
args = Dummy()
args.dataset = 'sim2larvae'
# args.datasets_dir = 'datasets'
args.img_shape = 256        # images are loaded as (img_shape, img_shape, 3)
args.patch = None
args.batch_size = 1
args.epochs = 33
args.epoch_decay = 10
args.lr_dis = 0.0002
args.lr_gen = 0.0002
args.beta1 = 0.5
args.att = True
args.att_epochs = 10        # passed to the model as attention_epochs_threshold
args.tau = 0.1
args.pool_size = 50
args.sample_int = 1         # passed to train() as sample_interval
args.train_start = 0 # 0 for starting from scratch

# Per-dataset output directory. makedirs also creates the parent './output',
# and exist_ok makes the cell safe to re-run.
output_dir = './output/' + args.dataset
os.makedirs(output_dir, exist_ok=True)

# # save settings
# with open(output_dir + '/args.yaml', 'w') as f:
#     yaml.dump(args, f)

data = Data_Loader(img_shape=(args.img_shape, args.img_shape, 3),
                   name=args.dataset,
                   patch=args.patch,
                   batch_size=args.batch_size)
print("setting up...")
model = Model2(data_loader=data,
               epochs=args.epochs,
               lr_d=args.lr_dis,
               lr_g=args.lr_gen,
               beta1=args.beta1,
               epoch_decay=args.epoch_decay,
               pool_size=args.pool_size,
               load_from_checkpoint=False,
               use_att=args.att,
               tau=args.tau,
               attention_epochs_threshold=args.att_epochs,
               train_start=args.train_start)

# To resume from saved weights, call model.load_weights(...) here with the four
# checkpoint paths (combined1, combined2, disA, disB) before training.
# p_root =
# model.load_weights()

print("training...")
model.train(sample_interval=args.sample_int)

setting up...
training...


# Tests

In [0]:
# ! zip attGAN_orig_horse.zip output/horse2zebra/*
# Delete ./summaries, the directory the TensorBoard launch at the top of the
# notebook is given as --logdir.
! rm -r summaries

In [0]:

# path = 'model/horse2zebra/att_resize_conv/'
# # model.load_weights(path + 'combined115', path + 'combined215', path + 'disA15', path + 'disB15')
# model.combined_model1.load_weights('weights1.h5')
# Pull the generators and attention networks out of the trained model, plus one
# sampled batch, for ad-hoc inspection.
genA, genB = model.genA, model.genB
attA, attB = model.attA, model.attB
imgA, imgB = model.data.sample_batch()

# Plot genA's per-layer activations for the sampled B batch.
show_flow(genA, imgB)



In [0]:
# from pydrive.auth import GoogleAuth
# from pydrive.drive import GoogleDrive
# from google.colab import auth
# from oauth2client.client import GoogleCredentials

# auth.authenticate_user()
# gauth = GoogleAuth()
# gauth.credentials = GoogleCredentials.get_application_default()
# drive = GoogleDrive(gauth)

In [0]:
# gauth = GoogleAuth()
# gauth.CommandLineAuth()
# drive = GoogleDrive(gauth)

In [0]:
# Recursively archive output/horse2zebra into an archive named orig_horse2zebra.
# NOTE(review): the Run cell sets args.dataset = 'sim2larvae', so that run's
# samples are written under output/sim2larvae -- confirm this is the intended path.
! zip -r orig_horse2zebra output/horse2zebra