# Setup

In [1]:
! git clone https://github.com/Tymyan1/datasets.git
! pip install tensorflow-gpu==2.0.0-alpha0
! pip install tensorflow-addons
! pip install zipfile36
# ! pip install pydrive

fatal: destination path 'datasets' already exists and is not an empty directory.


In [0]:
import pathlib
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import json
from glob import glob
import datetime
import os
import matplotlib.pyplot as plt
from zipfile import ZipFile


# Util

In [0]:
def load_data(path='./datasets/horse2zebra'):
    data_root = pathlib.Path(path)

    all_image_paths1 = list(data_root.glob('testA/*'))
    all_image_paths2 = list(data_root.glob('testB/*'))

    imgs1 = [load_and_pre_process_image(str(path)) for path in all_image_paths1]
    imgs2 = [load_and_pre_process_image(str(path)) for path in all_image_paths2]

    imgs1 = tf.convert_to_tensor(imgs1)
    imgs2 = tf.convert_to_tensor(imgs2)

    return imgs1, imgs2

def linear_decayed_lr(epoch, lr, total_epochs, non_decayed_epochs=100):
    if epoch < non_decayed_epochs:
        return lr
    return lr * (1 - 1 / (epoch - non_decayed_epochs) * (total_epochs - epoch))

def tau_thres(input_, tau=0.1):
    return np.where(input_ <= tau, 0, input_)

class UpSample(tf.keras.layers.Layer):
    def __init__(self, size, **kwargs):
        super(UpSample, self).__init__(**kwargs)
        self.size = size

    def call(self, input_):
        return tf.image.resize(input_, size=self.size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

class Pad(tf.keras.layers.Layer):

    def __init__(self, paddings, mode='CONSTANT', constant_values=0, **kwargs):
        super(Pad, self).__init__(**kwargs)
        self.paddings = paddings
        self.mode = mode
        self.constant_values = constant_values

    def call(self, inputs):
        return tf.pad(inputs, self.paddings, mode=self.mode, constant_values=self.constant_values)
    
class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Linear learning rate decay down to 0 applied after decay_offset_steps steps
    """
    def __init__(self, lr, total_steps, decay_offset_steps):
        self.lr = lr
        self.total_steps = total_steps
        self.decay_offset_steps = decay_offset_steps
        #1.13 version bellow
#         self.current_learning_rate = tf.Variable(initial_value=lr, trainable=False, dtype=tf.float32)

#     def __call__(self, step):
#         self.current_learning_rate.assign(tf.cond(
#             step >= self.decay_offset_steps,
#             true_fn=lambda: self.lr * (
#                         1 - 1 / (self.total_steps - self.decay_offset_steps) * (step - self.decay_offset_steps)),
#             false_fn=lambda: self.lr
#         ))
#         return self.current_learning_rate

    @tf.function
    def __call__(self, step):
        if step >= self.decay_offset_steps:
            return self.lr * (1 - 1 / (self.total_steps - self.decay_offset_steps) * (step - self.decay_offset_steps))
        return self.lr
    
    def get_config(self):
        return json.dumps({
            'lr': self.lr,
            'total_steps': self.total_steps,
            'decay_offset_steps': self.decay_offset_steps
#             'current_learning_rate': self.current_learning_rate
        })
    


# Data loader

In [0]:
class Data_Loader():
    ROOT_PATH = './datasets/'

    def __init__(self, name, batch_size, img_shape=(256,256,3), patch=None):
        self.name = name
        self.batch_size = batch_size
        self.img_dims = (patch + (3,)) if patch else img_shape

        # load in the data
        pathsA = glob(Data_Loader.ROOT_PATH + name + '/trainM/*')
        pathsB = glob(Data_Loader.ROOT_PATH + name + '/trainB/*')
        countA = len(pathsA)
        countB = len(pathsB)
        self.n_batches = int(min(countA, countB) / self.batch_size)
        
        self.dsA = tf.data.Dataset.from_tensor_slices(pathsA)
        self.dsB = tf.data.Dataset.from_tensor_slices(pathsB)
        
        self.dsA = self.dsA.map(lambda img: _load_and_preprocess_image(img, img_dims=[self.img_dims[0], self.img_dims[1]]), num_parallel_calls=tf.data.experimental.AUTOTUNE)
        self.dsB = self.dsB.map(lambda img: _load_and_preprocess_image(img, img_dims=[self.img_dims[0], self.img_dims[1]]), num_parallel_calls=tf.data.experimental.AUTOTUNE)
        
        self.dsA = self.dsA.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=countA))
        self.dsB = self.dsB.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=countB))
        
        self.dsA = self.dsA.batch(batch_size)
        self.dsB = self.dsB.batch(batch_size)
        
        self.dsA = self.dsA.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        self.dsB = self.dsB.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        
        self.itA = iter(self.dsA)
        self.itB = iter(self.dsB)
        
        self.it_counter = 0
        
    def load_batch(self):
        self.it_counter += 1
        return next(self.itA), next(self.itB)

    def batches_left(self):
        return self.it_counter % self.n_batches
    
def _load_and_preprocess_image(path, img_dims):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, img_dims)
    image = (image / 127.5) - 1
    return image    

# Generator

In [0]:
def _g_conv_layer(input_, filters, filter_size, norm='instance', strides=2, pad='VALID', relu=True):
#     input_ = tf.cast(input_, tf.float32)
    conv = tf.keras.layers.Conv2D(filters, kernel_size=filter_size, padding=pad, strides=strides)(input_)
    if norm == 'instance':
        conv = tfa.layers.InstanceNormalization()(conv)
    if relu == True:
        conv = tf.keras.layers.ReLU()(conv)
    return conv

def _g_res_block(input_, norm='instance'):
    filters = input_.shape[-1]
    out = Pad([[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')(input_)
    out = _g_conv_layer(out, filters, norm=norm, filter_size=3, strides=1, pad='VALID', relu=True)
    out = tf.keras.layers.ReLU()(out)

    out = Pad([[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')(out)
    out = _g_conv_layer(out, filters, norm=norm, filter_size=3, strides=1, pad='VALID', relu=False)
    out = tf.keras.layers.add([input_, out])
    out = tf.keras.layers.ReLU()(out)
    return out

def _g_deconv_layer(input_, filters, filter_size, pad='SAME', norm='instance'):
    out = tf.keras.layers.Conv2DTranspose(filters, kernel_size=filter_size, strides=2, padding=pad)(input_)
    if norm == 'instance':
        out = tfa.layers.InstanceNormalization()(out)
    out = tf.keras.layers.ReLU()(out)
    return out


def build_generator(input_shape, name):

    model = input_ = tf.keras.layers.Input(shape=input_shape)
    
    # c7s1-32-R
    model = Pad([[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')(model)
    model = _g_conv_layer(model, 32, filter_size=7, strides=1)
    
    # upsampling
    # c3s2-64-R
    model = _g_conv_layer(model, filters=64, filter_size=3, strides=2, pad='SAME')
    # c3s2-128-R
    model = _g_conv_layer(model, filters=128, filter_size=3, strides=2, pad='SAME')
    
    # residual blocks
    # r128 * 9
    for i in range(9):
        model = _g_res_block(model)
    
    
    # downsampling
    # tc64s2
    model = _g_deconv_layer(model, 64, 3, pad='SAME')
    # tc32s2
    model = _g_deconv_layer(model, 32, 3, pad='SAME')
    
    # c3s1-3-T
#     model = Pad([[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')(model)
    model = tf.keras.layers.Conv2D(filters=3, kernel_size=3, strides=1, padding='SAME')(model) # 3 img channels
    model = tf.keras.layers.Activation('tanh')(model)

    return tf.keras.Model(input_, model, name=name)

In [0]:
def _g_conv_layer2(input_, filters, filter_size, norm='instance', strides=2, pad='VALID', relu=True):
#     input_ = tf.cast(input_, tf.float32)
    conv = tf.keras.layers.Conv2D(filters, kernel_size=filter_size, padding=pad, strides=strides)(input_)
    if norm == 'instance':
        conv = tfa.layers.InstanceNormalization()(conv)
    if relu == True:
        conv = tf.keras.layers.ReLU()(conv)
    return conv

def _g_res_block2(input_, norm='instance'):
    filters = input_.shape[-1]
    out = Pad([[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')(input_)
    out = _g_conv_layer2(out, filters, norm=norm, filter_size=3, strides=1, pad='VALID', relu=True)
    out = tf.keras.layers.ReLU()(out)

    out = Pad([[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')(out)
    out = _g_conv_layer2(out, filters, norm=norm, filter_size=3, strides=1, pad='VALID', relu=False)
    out = tf.keras.layers.add([input_, out])
    out = tf.keras.layers.ReLU()(out)
    return out

def _g_deconv_layer2(input_, filters, filter_size, pad='SAME', norm='instance'):
    out = tf.keras.layers.Conv2DTranspose(filters, kernel_size=filter_size, strides=2, padding=pad)(input_)
    if norm == 'instance':
        out = tfa.layers.InstanceNormalization()(out)
    out = tf.keras.layers.ReLU()(out)
    return out


def build_generator_orig(input_shape, name):

    model = input_ = tf.keras.layers.Input(shape=input_shape)
    
    # c7s1-64-R
    model = Pad([[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')(model)
    model = _g_conv_layer(model, 64, filter_size=7, strides=1)
    
    # upsampling
    # c3s2-128-R
    model = _g_conv_layer2(model, filters=64, filter_size=3, strides=2, pad='SAME')
    # c3s2-256-R
    model = _g_conv_layer2(model, filters=128, filter_size=3, strides=2, pad='SAME')
    
    # residual blocks
    # r256 * 9
    for i in range(9):
        model = _g_res_block2(model)
    
    
    # downsampling
    # tc128s2
    model = _g_deconv_layer2(model, 128, 3, pad='SAME')
    # tc64s2
    model = _g_deconv_layer2(model, 64, 3, pad='SAME')
    
    # c3s1-3-T
    model = Pad([[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')(model)
    model = tf.keras.layers.Conv2D(filters=3, kernel_size=7, strides=1, padding='VALID')(model) # 3 img channels
    model = tf.keras.layers.Activation('tanh')(model)

    return tf.keras.Model(input_, model, name=name)

# Discriminator

In [0]:
def _discriminator_layer(input_, input_nn, filters, filter_size, strides, norm='instance'):
    """
    :param input_:
    :param filters:
    :param filter_size:
    :param norm:
    :return: Two discriminator blocks (with/without normalization layers) with shared layers
    """
    input_ = tf.cast(input_, tf.float32)
    layer = layer_nn = tf.keras.layers.Conv2D(filters, kernel_size=filter_size, strides=strides, padding='SAME')(input_)
    if norm=='instance':
        layer = tfa.layers.InstanceNormalization()(layer)
    relu = tf.keras.layers.LeakyReLU(alpha=0.2)
    layer = relu(layer)
    layer_nn = relu(layer_nn)

    return layer, layer_nn


def build_discriminator(input_shape, name):
    """
    :param input_shape:
    :return: Two discriminators(with/without normalization layers) with shared layers
    """
    
    input_ = tf.keras.layers.Input(shape=input_shape)

    # c4s2-64-LR
    dis, dis_nn = _discriminator_layer(input_, input_, 64, filter_size=4, strides=2, norm='instance')
    # c4s2-128-LR
    dis, dis_nn = _discriminator_layer(dis, dis_nn, 128, filter_size=4, strides=2, norm='instance')
    # c4s2-256-LR
    dis, dis_nn = _discriminator_layer(dis, dis_nn, 256, filter_size=4, strides=2, norm='instance')
    # c4s1-512-LR
    dis, dis_nn = _discriminator_layer(dis, dis_nn, 512, filter_size=4, strides=1, norm='instance')
    
    # c4s1-1
    last_layer = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='SAME')
    dis = last_layer(dis)
    dis_nn = last_layer(dis_nn)

    return tf.keras.Model(input_, dis, name=name), tf.keras.Model(input_, dis_nn, name=name+'_nn')

In [0]:
def _discriminator_layer2(input_, filters, filter_size, strides, norm='instance'):
    """
    :param input_:
    :param filters:
    :param filter_size:
    :param norm:
    :return: Two discriminator blocks (with/without normalization layers) with shared layers
    """
    input_ = tf.cast(input_, tf.float32)
    layer = layer_nn = tf.keras.layers.Conv2D(filters, kernel_size=filter_size, strides=strides, padding='SAME')(input_)
    if norm=='instance':
        layer = tfa.layers.InstanceNormalization()(layer)
    relu = tf.keras.layers.LeakyReLU(alpha=0.2)
    layer = relu(layer)
    layer_nn = relu(layer_nn)

    return layer


def build_discriminator_orig(input_shape, name):
    """
    :param input_shape:
    :return: Two discriminators(with/without normalization layers) with shared layers
    """
    
    input_ = tf.keras.layers.Input(shape=input_shape)

    # c4s2-64-LR
    dis = _discriminator_layer2(input_, 64, filter_size=4, strides=2, norm='instance')
    # c4s2-128-LR
    dis = _discriminator_layer2(dis, 128, filter_size=4, strides=2, norm='instance')
    # c4s2-256-LR
    dis = _discriminator_layer2(dis, 256, filter_size=4, strides=2, norm='instance')
    # c4s1-512-LR
    dis = _discriminator_layer2(dis, 512, filter_size=4, strides=1, norm='instance')
    
    # c4s1-1
    dis = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='SAME')(dis)


    return tf.keras.Model(input_, dis, name=name)

# Losses

In [0]:
def adversarial_loss(prediction_on_real, prediction_on_fake): 
    return tf.reduce_mean(tf.math.squared_difference(prediction_on_real, 1)) + \
                tf.reduce_mean(tf.math.squared_difference(prediction_on_fake, 0))


# Model

In [0]:
class Model(tf.keras.models.Model):
    
    def __init__(self, data_loader, epochs=200, lr_d=2e-4, lr_g=2e-4, epoch_decay=100, beta1=0.5, start=0):
        super(Model, self).__init__()
        self.lambda_cyclic = 10.0  # Cycle-consistency loss weight
        self.lambda_id = 0.1 * self.lambda_cyclic  # Identity loss weight
        self.data = data_loader
        self.weights_path = 'output/' + self.data.name + '/weights.h5'
        self.epochs = epochs
        self.disc_patch = (32,32,1) 
        self.start = start
        

        # build models
        # print(self.data.img_dims)
        self.disA = build_discriminator_orig(self.data.img_dims, 'disA')
        self.disB = build_discriminator_orig(self.data.img_dims, 'disB')
      
        self.genA = build_generator(self.data.img_dims, 'genA')
        self.genB = build_generator(self.data.img_dims, 'genB')

        # lr schedulers
        _total_steps = self.data.batches_left() * self.epochs
        _decay_offset_steps = epoch_decay * self.epochs
        lr_sched_d = LinearDecay(lr_d, _total_steps, _decay_offset_steps)
        lr_sched_g = LinearDecay(lr_g, _total_steps, _decay_offset_steps)
        
        # set the optimizers
        self.optimizer_d = tf.keras.optimizers.Adam(learning_rate=lr_sched_d, beta_1=beta1)
        self.optimizer_g = tf.keras.optimizers.Adam(learning_rate=lr_sched_g, beta_1=beta1)
        
         # inputs
        imgA = tf.keras.layers.Input(shape=self.data.img_dims)
        imgB = tf.keras.layers.Input(shape=self.data.img_dims)

        fakeA = self.genA(imgB)
        fakeB = self.genB(imgA)
        
        cyclicA = self.genA(fakeB)
        cyclicB = self.genB(fakeA)
        
        # Identity mapping of images
        imgA_id = self.genB(imgA)
        imgB_id = self.genA(imgB)
        
        self.disA.compile(loss='mse', optimizer=self.optimizer_d, metrics=['accuracy'])
        self.disB.compile(loss='mse', optimizer=self.optimizer_d, metrics=['accuracy'])
        
        validityA = self.disA(fakeA)
        validityB = self.disB(fakeB)
        
        # build and compile combined model
        self.combined_model = tf.keras.Model(inputs=[imgA, imgB],
                                             outputs=[validityA, validityB, cyclicA, cyclicB, imgA_id, imgB_id])
        
        self.combined_model.compile(optimizer=self.optimizer_g,
                                           loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
                                           loss_weights=[1, 1, self.lambda_cyclic, self.lambda_cyclic, self.lambda_id, self.lambda_id])
        
    
    def train(self):
        os.makedirs('output/%s' % self.data.name, exist_ok=True)
        os.makedirs('log/%s' % self.data.name, exist_ok=True)

        # adversarial loss ground truth
        valid = np.ones((self.data.batch_size,) + self.disc_patch)
        fake = np.zeros((self.data.batch_size,) + self.disc_patch)
        
        start_time = datetime.datetime.now()
        for epoch in range(self.start + 1, self.epochs+1):
            self.cur_epoch = epoch
            for batch_i in range(1, self.data.n_batches+1):
                
                imgsA, imgsB = self.data.load_batch()

                disA_loss_real = self.disA.train_on_batch(imgsA, valid)
                disB_loss_real = self.disB.train_on_batch(imgsB, valid)
                
                fakeA = self.genA.predict(imgsB)
                fakeB = self.genB.predict(imgsA)
                
                disA_loss_fake = self.disA.train_on_batch(fakeA, fake)
                disB_loss_fake = self.disB.train_on_batch(fakeB, fake)
                
                gen_losses = self.combined_model.train_on_batch([imgsA, imgsB], [valid, valid, imgsA, imgsB, imgsA, imgsB])

                elapsed_time = datetime.datetime.now() - start_time

                if batch_i % 50 == 0:
                    print(
                        "[Epoch %d/%d] [Batch %d/%d] time: %s " \
                        % (epoch, self.epochs,
                           batch_i, self.data.n_batches,
                           elapsed_time))
            
            self.sample(imgsA, imgsB)

            os.makedirs('model/%s/%s' % (self.data.name, 'basic'), exist_ok=True)
            self.combined_model.save('model/{}/basic/combined_basic{:03d}'.format(self.data.name, epoch))
            self.disA.save('model/{}/basic/disA{:03d}'.format(self.data.name, epoch))
            self.disB.save('model/{}/basic/disB{:03d}'.format(self.data.name, epoch))
    
            self.save_things_to_drive(epoch)
        
    def sample(self, imgsA, imgsB):
        fakeA = self.genA.predict(imgsB)
        fakeB = self.genB.predict(imgsA)
        
        cyclicA = self.genA.predict(fakeB)
        cyclicB = self.genB.predict(fakeA)
        
        gen_imgs = np.concatenate([imgsA, fakeB, cyclicA, imgsB, fakeA, cyclicB])
        gen_imgs = gen_imgs / 2 + 0.5
        
        r = 2
        c = 3
        titles = ['Original', 'Translated', 'Cyclic']
        fig, axs = plt.subplots(r, c, figsize=(64, 64))
        fig.tight_layout()
        
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        
        fig.savefig("output/{}/basic{:03d}.png".format(self.data.name, self.cur_epoch))
        plt.close()
        plt.clf()
    
    def load_model(self, root_path, epoch):
        self.disA.load_weights('{}/disA{:03d}'.format(root_path, epoch))
        self.disB.load_weights('{}/disB{:03d}'.format(root_path, epoch))
        self.combined_model.load_weights('{}/combined_basic{:03d}'.format(root_path, epoch))
        
    def save_things_to_drive(self, epoch):
        

        file_paths = ['model/{}/basic/disB{:03d}'.format(self.data.name, epoch),
                     'model/{}/basic/disA{:03d}'.format(self.data.name, epoch),
                     'model/{}/basic/combined_basic{:03d}'.format(self.data.name, epoch),
                     'output/{}/basic{:03d}.png'.format(self.data.name, epoch)]
        
        name = 'GAN' + str(epoch) + '.zip'
        with ZipFile(name,'w') as zip: 
            for file in file_paths: 
                zip.write(file)
                
#         upload = self.drive.CreateFile({"title":name})
#         upload.SetContentFile('/content/' + name)
#         upload.Upload()


# Run


In [0]:
data = Data_Loader('sim2larvae', 1, img_shape=(256,256,3), patch=None)
model = Model(data, epochs=100, lr_d=2e-4, lr_g=2e-4, epoch_decay=50, beta1=0.5)
print('training...')
model.train()


In [0]:
# continue training
# ! unzip GAN20.zip
start = 32
data = Data_Loader('sim2larvae', 1, img_shape=(256,256,3), patch=None)
model = Model(data, epochs=100, lr_d=2e-4, lr_g=2e-4, epoch_decay=50, beta1=0.5, start=start)
print('loading...')
model.load_model('model/sim2larvae/basic', epoch=start)
print('training...')
model.train()

loading...
training...
[Epoch 33/100] [Batch 50/3147] time: 0:01:52.930291 
[Epoch 33/100] [Batch 100/3147] time: 0:02:13.536961 
[Epoch 33/100] [Batch 150/3147] time: 0:02:33.825098 
[Epoch 33/100] [Batch 200/3147] time: 0:02:54.065850 
[Epoch 33/100] [Batch 250/3147] time: 0:03:14.826245 
[Epoch 33/100] [Batch 300/3147] time: 0:03:35.251271 
[Epoch 33/100] [Batch 350/3147] time: 0:03:55.791374 
[Epoch 33/100] [Batch 400/3147] time: 0:04:16.897101 
[Epoch 33/100] [Batch 450/3147] time: 0:04:37.745976 
[Epoch 33/100] [Batch 500/3147] time: 0:04:58.217946 
[Epoch 33/100] [Batch 550/3147] time: 0:05:18.775294 
[Epoch 33/100] [Batch 600/3147] time: 0:05:39.149334 
[Epoch 33/100] [Batch 650/3147] time: 0:05:59.942556 
[Epoch 33/100] [Batch 700/3147] time: 0:06:20.242830 
[Epoch 33/100] [Batch 750/3147] time: 0:06:40.500524 
[Epoch 33/100] [Batch 800/3147] time: 0:07:00.923322 
[Epoch 33/100] [Batch 850/3147] time: 0:07:21.720426 
[Epoch 33/100] [Batch 900/3147] time: 0:07:41.970979 
[Epoch

In [22]:
! zip -r samples.zip output/sim2larvae/



  adding: output/sim2larvae/ (stored 0%)
  adding: output/sim2larvae/basic032.png (deflated 12%)
  adding: output/sim2larvae/basic021.png (deflated 12%)
  adding: output/sim2larvae/basic031.png (deflated 12%)
  adding: output/sim2larvae/basic028.png (deflated 12%)
  adding: output/sim2larvae/basic020.png (deflated 12%)
  adding: output/sim2larvae/basic027.png (deflated 12%)
  adding: output/sim2larvae/basic026.png (deflated 12%)
  adding: output/sim2larvae/basic029.png (deflated 13%)
  adding: output/sim2larvae/basic025.png (deflated 12%)
  adding: output/sim2larvae/basic030.png (deflated 12%)
  adding: output/sim2larvae/basic024.png (deflated 12%)
  adding: output/sim2larvae/basic022.png (deflated 12%)
  adding: output/sim2larvae/basic023.png (deflated 12%)


In [0]:
! mv disB020 model/sim2larvae/basic/