## This notebook trains a CycleGAN and outputs NM2M.h5, a generator that translates non-melanoma (NM) lesion images into melanoma (M) images.


In [None]:
## built on top of https://github.com/hasibzunair/adversarial-lesions/blob/master/isic2016_scripts/train_cyclegan.ipynb

# Run this block once to install these libs (pip's stdout is discarded via `>> /dev/null`).
# NOTE(review): nothing visible in this notebook imports keras_contrib; the
# InstanceNormalization layer used below comes from tensorflow_addons. Confirm
# whether this install is still needed and that tensorflow_addons is installed.
!pip install git+https://www.github.com/keras-team/keras-contrib.git >> /dev/null

In [None]:
import tensorflow_addons
# tfa.layers.InstanceNormalization

In [None]:
import cv2
from tensorflow_addons.layers import InstanceNormalization
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os
import pandas as pd
import time
import tensorflow as tf
from glob import glob
import tensorflow.keras.backend as K
import random as r
from matplotlib.pyplot import imread
from copy import deepcopy as cp

In [None]:
# Accelerator to request: "TPU" or "GPU" (anything other than "TPU" uses the
# default CPU / single-GPU strategy).
DEVICE = "GPU"

# Side length, in pixels, of the square images fed to the networks.
TARGET_IMG_SIZE = 256

# Global seed passed to seed_all() below.
SEED = 42

# Dataset folder name (under ../input/isic2020-gan-train-512/).
DATASET_NAME = 'isic2020_gan_train_512'


def seed_all(seed, deterministic_ops=True):
    """Seed every RNG this notebook uses so runs are reproducible.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy and TensorFlow.
        deterministic_ops: whether to request deterministic TensorFlow kernels
            via the ``TF_DETERMINISTIC_OPS`` flag (default ``True``).
    """
    r.seed(seed)
    np.random.seed(seed)
    # Only affects child processes started after this point: the running
    # interpreter's hash seed is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # TF_DETERMINISTIC_OPS is an on/off flag parsed as a boolean ("1"/"0",
    # "true"/"false"); a seed value such as "42" is not a valid setting.
    # NOTE(review): on GPU, newer TF versions may raise for ops that have no
    # deterministic kernel; pass deterministic_ops=False if that happens.
    os.environ['TF_DETERMINISTIC_OPS'] = '1' if deterministic_ops else '0'
    # NOTE(review): TF_KERAS is presumably a "use tf.keras" switch read by
    # third-party Keras add-ons; nothing visible here reads it.
    os.environ['TF_KERAS'] = '1'
    tf.random.set_seed(seed)
    
# Seed all RNGs once, before any model building or data sampling happens.
seed_all(seed=SEED)

In [None]:
# `tpu` stays None unless a TPU is found and initialised, so later cells can
# test it without risking a NameError.
tpu = None

if DEVICE == "TPU":
    print("connecting to TPU...")
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        print("Could not connect to TPU")
        tpu = None

    if tpu:
        try:
            print("initializing  TPU ...")
            tf.config.experimental_connect_to_cluster(tpu)
            tf.tpu.experimental.initialize_tpu_system(tpu)
            strategy = tf.distribute.experimental.TPUStrategy(tpu)
            print("TPU initialized")
        except Exception as err:  # any failure here means "no usable TPU"
            print("failed to initialize TPU", err)
            tpu = None

    if not tpu:
        # Fall back to the default strategy below instead of leaving
        # `strategy` undefined.
        DEVICE = "GPU"

if DEVICE != "TPU":
    print("Using default strategy for CPU and single GPU")
    strategy = tf.distribute.get_strategy()

if DEVICE == "GPU":
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

AUTO     = tf.data.experimental.AUTOTUNE
# Number of replicas; used later to scale the global batch size.
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')


In [None]:
# Optional speed-ups; both are off by default.
MIXED_PRECISION = False
XLA_ACCELERATE = False

if MIXED_PRECISION:
    # bfloat16 on TPU, float16 elsewhere. Test DEVICE rather than `tpu`:
    # `tpu` is only assigned when DEVICE started out as "TPU".
    policy_name = 'mixed_bfloat16' if DEVICE == "TPU" else 'mixed_float16'
    if hasattr(tf.keras.mixed_precision, 'set_global_policy'):
        # Stable API (TF >= 2.4); the experimental module was later removed.
        tf.keras.mixed_precision.set_global_policy(policy_name)
    else:
        # Older TF 2.x releases only ship the experimental API.
        from tensorflow.keras.mixed_precision import experimental as mixed_precision
        mixed_precision.set_policy(mixed_precision.Policy(policy_name))
    print('Mixed precision enabled')

if XLA_ACCELERATE:
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')

In [None]:
gpu_name = tf.test.gpu_device_name()  # empty string when no GPU is visible
print(gpu_name)

def mk_dir(dir_path):
    """Create ``dir_path`` (including missing parents) if needed and return it.

    ``exist_ok=True`` makes an existing directory a no-op, without the
    check-then-create race of testing ``os.path.exists`` first.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


# Input datasets are read from ../input; outputs go under /kaggle/working.
base_path = os.path.abspath("../input")
dataset_path = os.path.join(base_path, "isic2020-gan-train-512", DATASET_NAME, "")
output_base_path = os.path.abspath('/kaggle/working')
# The component must be relative: os.path.join discards everything before an
# absolute component, so "/models" would resolve to the filesystem root.
model_path = mk_dir(os.path.join(output_base_path, "models"))
print(os.listdir(dataset_path))

In [None]:
# Report how many images each split holds. The divisors are presumably 1% of
# each split's expected size (so the second value reads as a percentage) — confirm.
for subset, one_percent in (('trainNM', 321.20), ('testM', 5.81)):
    path = glob('{}/isic2020-gan-train-512/%s/%s/*'.format(base_path) % (DATASET_NAME, subset))
    print(len(path), f'{len(path)/one_percent}%')

In [None]:
# Built on top of https://github.com/eriklindernoren/Keras-GAN/tree/master/cyclegan
def get_cycle_gan_model():
    """Build and return a compiled CycleGAN for NM (non-melanoma) <-> M (melanoma).

    The returned object holds two U-Net generators (``g_NM2M``, ``g_M2NM``),
    two PatchGAN discriminators (``d_NM``, ``d_M``) and a ``combined`` model
    that trains the generators. Call ``train(epochs, batch_size,
    sample_interval)`` on it to run training.

    Relies on notebook globals: ``base_path``, ``DATASET_NAME``,
    ``TARGET_IMG_SIZE``, ``output_base_path`` and ``model_path``.
    """

    class DataLoader():
        """Reads, resizes, augments and rescales images from the dataset folders."""

        def __init__(self, dataset_name, img_res=(TARGET_IMG_SIZE, TARGET_IMG_SIZE)):
            self.dataset_name = dataset_name
            # (rows, cols) of the images fed to the networks.
            self.img_res = img_res

        def _glob(self, subset):
            """Return all file paths in ``<base_path>/isic2020-gan-train-512/<dataset>/<subset>/``."""
            return glob('{}/isic2020-gan-train-512/%s/%s/*'.format(base_path) % (self.dataset_name, subset))

        def _resize(self, img):
            """Resize ``img`` to ``self.img_res`` unless its height and width already match."""
            rows, cols = self.img_res
            if img.shape[:2] != (rows, cols):
                # cv2.resize takes dsize as (width, height).
                img = cv2.resize(img, (cols, rows))
            return img

        @staticmethod
        def _rescale(imgs):
            """Map pixel values in [0, 255] to [-1, 1], the range of the generators' tanh output."""
            return np.asarray(imgs) / 127.5 - 1.

        def load_data(self, domain, batch_size=1, is_testing=False):
            """Return ``batch_size`` random images of one domain as an array in [-1, 1].

            Args:
                domain: domain suffix, ``"NM"`` or ``"M"``.
                batch_size: number of images drawn (with replacement).
                is_testing: read ``test<domain>`` instead of ``train<domain>``
                    and skip the random horizontal flip.
            """
            data_type = "test%s" % domain if is_testing else "train%s" % domain
            path = self._glob(data_type)

            batch_images = np.random.choice(path, size=batch_size)

            imgs = []
            for img_path in batch_images:
                img = self._resize(self.imread(img_path))
                # Random horizontal flip as a training-time augmentation only.
                if not is_testing and np.random.random() > 0.5:
                    img = np.fliplr(img)
                imgs.append(img)

            return self._rescale(imgs)

        def load_batch(self, batch_size=1, is_testing=False):
            """Yield ``(imgs_NM, imgs_M)`` batch pairs covering one epoch.

            Both domains are subsampled without replacement to the same size
            (``n_batches * batch_size``), so each sampled image is used exactly
            once per epoch. Sets ``self.n_batches``.
            """
            data_type = "val" if is_testing else "train"
            path_NM = self._glob("%sNM" % data_type)
            path_M = self._glob("%sM" % data_type)

            self.n_batches = int(min(len(path_NM), len(path_M)) // batch_size)
            total_samples = self.n_batches * batch_size

            print(batch_size, "HERE______", self.n_batches, total_samples)
            path_NM = np.random.choice(path_NM, total_samples, replace=False)
            path_M = np.random.choice(path_M, total_samples, replace=False)

            # Iterate over every one of the n_batches so no sampled image is skipped.
            for i in range(self.n_batches):
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                imgs_NM, imgs_M = [], []
                for path_nm, path_m in zip(path_NM[batch_slice], path_M[batch_slice]):
                    img_NM = self._resize(self.imread(path_nm))
                    img_M = self._resize(self.imread(path_m))

                    # The same random flip is applied to both images (train only).
                    if not is_testing and np.random.random() > 0.5:
                        img_NM = np.fliplr(img_NM)
                        img_M = np.fliplr(img_M)

                    imgs_NM.append(img_NM)
                    imgs_M.append(img_M)

                yield self._rescale(imgs_NM), self._rescale(imgs_M)

        def load_img(self, path):
            """Load one image as a ``(1, rows, cols, channels)`` array in [-1, 1]."""
            img = self._resize(self.imread(path))
            return self._rescale(img)[np.newaxis, :, :, :]

        def imread(self, path):
            """Read an image file as a float64 array."""
            # NOTE(review): matplotlib's ``format`` argument is a file-format
            # hint, not a colour mode. ``_rescale`` assumes 8-bit ([0, 255])
            # values, which holds for JPEG input; PNGs may decode as floats in
            # [0, 1]. Confirm the dataset's file type before changing this.
            return imread(path, format='RGB').astype(np.float64)

    class CycleGAN():
        """CycleGAN with U-Net generators and PatchGAN discriminators."""

        def __init__(self):
            # Input shape
            self.img_rows = TARGET_IMG_SIZE
            self.img_cols = TARGET_IMG_SIZE
            self.channels = 3
            self.img_shape = (self.img_rows, self.img_cols, self.channels)

            # Data loader; reads train{NM,M} / test{NM,M} sub-folders of the dataset.
            self.dataset_name = DATASET_NAME
            self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                          img_res=(self.img_rows, self.img_cols))

            # Output shape of the PatchGAN discriminator: four stride-2
            # convolutions shrink each spatial side by 2**4.
            patch = int(self.img_rows / 2**4)
            self.disc_patch = (patch, patch, 1)

            # Number of filters in the first layer of G and D
            self.gf = 32
            self.df = 64

            # Loss weights
            self.lambda_cycle = 10.0                    # Cycle-consistency loss
            self.lambda_id = 0.1 * self.lambda_cycle    # Identity loss

            def make_optimizer():
                # One Adam instance per compiled model so each keeps its own
                # slots and step counter; sharing a single instance across
                # models fails with the newer (TF >= 2.11) Keras optimizer.
                return tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

            # Build and compile the discriminators
            self.d_NM = self.build_discriminator()
            self.d_M = self.build_discriminator()
            self.d_NM.compile(loss='mse', optimizer=make_optimizer(), metrics=['accuracy'])
            self.d_M.compile(loss='mse', optimizer=make_optimizer(), metrics=['accuracy'])

            # Build the generators
            self.g_NM2M = self.build_generator()
            self.g_M2NM = self.build_generator()

            # Input images from both domains
            img_NM = tf.keras.layers.Input(shape=self.img_shape)
            img_M = tf.keras.layers.Input(shape=self.img_shape)

            # Translate images to the other domain
            fake_M = self.g_NM2M(img_NM)
            fake_NM = self.g_M2NM(img_M)
            # Translate images back to the original domain (cycle consistency)
            reconstr_NM = self.g_M2NM(fake_M)
            reconstr_M = self.g_NM2M(fake_NM)
            # Identity mapping: a generator fed an image already in its target
            # domain should leave it unchanged.
            img_NM_id = self.g_M2NM(img_NM)
            img_M_id = self.g_NM2M(img_M)

            # The combined model only updates the generators. Keras keeps the
            # trainable state captured at compile() time, so the discriminators
            # (compiled above) still train in their own train_on_batch calls.
            self.d_NM.trainable = False
            self.d_M.trainable = False

            # Discriminators judge the validity of translated images
            valid_NM = self.d_NM(fake_NM)
            valid_M = self.d_M(fake_M)

            # Combined model trains the generators to fool the discriminators
            self.combined = tf.keras.models.Model(
                inputs=[img_NM, img_M],
                outputs=[valid_NM, valid_M,
                         reconstr_NM, reconstr_M,
                         img_NM_id, img_M_id])
            self.combined.compile(
                loss=['mse', 'mse',
                      'mae', 'mae',
                      'mae', 'mae'],
                loss_weights=[1, 1,
                              self.lambda_cycle, self.lambda_cycle,
                              self.lambda_id, self.lambda_id],
                optimizer=make_optimizer())

        def build_generator(self):
            """U-Net generator mapping an image to the other domain.

            Four stride-2 convolutions downsample the input. Three upsampling
            stages each concatenate the matching encoder features (skip
            connections). A final upsample plus tanh convolution restores the
            input size, with values in [-1, 1].
            """

            def conv2d(layer_input, filters, f_size=4):
                """Downsampling stage: stride-2 conv -> LeakyReLU -> InstanceNorm."""
                d = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
                d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)
                d = InstanceNormalization()(d)
                return d

            def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
                """Upsampling stage: 2x upsample -> conv(ReLU) -> [dropout] -> InstanceNorm -> concat skip."""
                u = tf.keras.layers.UpSampling2D(size=2)(layer_input)
                u = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
                if dropout_rate:
                    u = tf.keras.layers.Dropout(dropout_rate)(u)
                u = InstanceNormalization()(u)
                u = tf.keras.layers.Concatenate()([u, skip_input])
                return u

            # Image input
            d0 = tf.keras.layers.Input(shape=self.img_shape)

            # Downsampling
            d1 = conv2d(d0, self.gf)
            d2 = conv2d(d1, self.gf*2)
            d3 = conv2d(d2, self.gf*4)
            d4 = conv2d(d3, self.gf*8)

            # Upsampling with skip connections
            u1 = deconv2d(d4, d3, self.gf*4)
            u2 = deconv2d(u1, d2, self.gf*2)
            u3 = deconv2d(u2, d1, self.gf)

            u4 = tf.keras.layers.UpSampling2D(size=2)(u3)
            output_img = tf.keras.layers.Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

            g_model = tf.keras.models.Model(d0, output_img)
            print("##################################################")
            return g_model

        def build_discriminator(self):
            """PatchGAN discriminator producing a (rows/16, cols/16, 1) validity map."""

            def d_layer(layer_input, filters, f_size=4, normalization=True):
                """Stride-2 conv -> LeakyReLU -> optional InstanceNorm."""
                d = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
                d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)
                if normalization:
                    d = InstanceNormalization()(d)
                return d

            img = tf.keras.layers.Input(shape=self.img_shape)

            d1 = d_layer(img, self.df, normalization=False)
            d2 = d_layer(d1, self.df*2)
            d3 = d_layer(d2, self.df*4)
            d4 = d_layer(d3, self.df*8)

            validity = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

            d_model = tf.keras.models.Model(img, validity)
            print("##################################################")
            return d_model

        def train(self, epochs, batch_size=1, sample_interval=50):
            """Train the CycleGAN, logging metrics and saving the models after every epoch.

            Args:
                epochs: number of passes over the subsampled training data.
                batch_size: images per domain per step.
                sample_interval: on epochs divisible by 5, save sample
                    translations every ``sample_interval`` batches.

            Raises:
                ValueError: if an epoch yields no batches (too few training
                    images for ``batch_size``), leaving no losses to log.
            """
            start_time = datetime.datetime.now()

            # Adversarial loss ground truths
            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            # Per-epoch metrics; record.csv is rewritten after every epoch.
            log_columns = ['epoch', 'd_Loss', 'accuracy', 'g_loss', 'adv', 'recon', 'id', 'elapsed_time']
            records = []

            for epoch in range(epochs):
                d_loss = g_loss = None
                for batch_i, (imgs_NM, imgs_M) in enumerate(self.data_loader.load_batch(batch_size)):

                    # ----------------------
                    #  Train Discriminators
                    # ----------------------

                    # Translate images to the opposite domain
                    fake_M = self.g_NM2M.predict(imgs_NM)
                    fake_NM = self.g_M2NM.predict(imgs_M)

                    # Train the discriminators (original images = real / translated = fake)
                    dNM_loss_real = self.d_NM.train_on_batch(imgs_NM, valid)
                    dNM_loss_fake = self.d_NM.train_on_batch(fake_NM, fake)
                    dNM_loss = 0.5 * np.add(dNM_loss_real, dNM_loss_fake)

                    dM_loss_real = self.d_M.train_on_batch(imgs_M, valid)
                    dM_loss_fake = self.d_M.train_on_batch(fake_M, fake)
                    dM_loss = 0.5 * np.add(dM_loss_real, dM_loss_fake)

                    # Total discriminator [mse, accuracy]
                    d_loss = 0.5 * np.add(dNM_loss, dM_loss)

                    # ------------------
                    #  Train Generators
                    # ------------------

                    # Targets: fool both discriminators, reconstruct the inputs,
                    # and leave same-domain inputs unchanged (identity).
                    g_loss = self.combined.train_on_batch([imgs_NM, imgs_M],
                                                          [valid, valid,
                                                           imgs_NM, imgs_M,
                                                           imgs_NM, imgs_M])

                    elapsed_time = datetime.datetime.now() - start_time

                    # If at save interval => save generated image samples
                    if batch_i % sample_interval == 0 and epoch % 5 == 0:
                        self.sample_images(epoch, batch_i)

                if d_loss is None:
                    raise ValueError(
                        "No training batches produced for batch_size=%d; "
                        "check the trainNM/trainM folders." % batch_size)

                # Print updates
                print(epoch, "--------------------", d_loss[0], g_loss[0], 100*d_loss[1])

                # g_loss = [total, adv_NM, adv_M, recon_NM, recon_M, id_NM, id_M]
                records.append({'epoch': epoch,
                                'd_Loss': d_loss[0],
                                'accuracy': 100*d_loss[1],
                                'g_loss': g_loss[0],
                                'adv': np.mean(g_loss[1:3]),
                                'recon': np.mean(g_loss[3:5]),
                                'id': np.mean(g_loss[5:7]),  # both identity terms
                                'elapsed_time': elapsed_time})
                pd.DataFrame(records, columns=log_columns).to_csv("{}/record.csv".format(model_path), index=False)

                # Save the models at the end of every epoch.
                print("Saving model at {} epoch.".format(epoch))
                self.g_NM2M.save(filepath='{}/{}'.format(output_base_path, "NM2M.h5"))
                self.g_M2NM.save(filepath='{}/{}'.format(model_path, "M2NM.h5"))
                self.combined.save(filepath='{}/{}'.format(model_path, "model.h5"))

            print("####################...Training finished...####################")
            print("####################.....Models Saved.....####################")

        def sample_images(self, epoch, batch_i):
            """Save a 2x3 grid of sample translations for one random test image per domain.

            Row 1 shows NM original / translated to M / reconstructed NM.
            Row 2 shows M original / translated to NM / reconstructed M.
            The grid is written to
            ``<output_base_path>/images/<dataset_name>/<epoch>_<batch_i>.png``.
            """
            os.makedirs('{}/images/%s'.format(output_base_path) % self.dataset_name, exist_ok=True)
            n_rows, n_cols = 2, 3

            imgs_NM = self.data_loader.load_data(domain="NM", batch_size=1, is_testing=True)
            imgs_M = self.data_loader.load_data(domain="M", batch_size=1, is_testing=True)

            # Translate images to the other domain
            fake_M = self.g_NM2M.predict(imgs_NM)
            fake_NM = self.g_M2NM.predict(imgs_M)
            # Translate back to the original domain
            reconstr_NM = self.g_M2NM.predict(fake_M)
            reconstr_M = self.g_NM2M.predict(fake_NM)

            gen_imgs = np.concatenate([imgs_NM, fake_M, reconstr_NM, imgs_M, fake_NM, reconstr_M])

            # Rescale images from [-1, 1] to [0, 1] for display
            gen_imgs = 0.5 * gen_imgs + 0.5

            titles = ['Original', 'Translated', 'Reconstructed']
            fig, axs = plt.subplots(n_rows, n_cols)
            # axs.flat walks the grid row by row, matching gen_imgs' order.
            for cnt, ax in enumerate(axs.flat):
                ax.imshow(gen_imgs[cnt])
                ax.set_title(titles[cnt % n_cols])
                ax.axis('off')
            fig.savefig("{}/images/%s/%d_%d.png".format(output_base_path) % (self.dataset_name, epoch, batch_i))
            # Close this specific figure so figures do not pile up across calls.
            plt.close(fig)

    return CycleGAN()


In [None]:
# Wall-clock timer for model construction plus training.
start_time = time.time()

# On TPU, (re)initialise the TPU system before building the model.
if DEVICE == 'TPU' and tpu:
    tf.tpu.experimental.initialize_tpu_system(tpu)

K.clear_session()
# Build (and compile) the models inside the distribution strategy's scope.
with strategy.scope():
    gan = get_cycle_gan_model()
    print("Training...........")

# Global batch size: one image per domain per replica.
gan.train(epochs=45, batch_size=1 * REPLICAS, sample_interval=500)

end_time = time.time()
elapsed_minutes = (end_time - start_time) // 60
print("--- Time taken to train : %s minutes ---" % (elapsed_minutes,))