In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm 
import os
import tensorflow as tf
import gc
import time

In [2]:
'''
    This is the control cell.
    model name: name of the current run; its checkpoints and plots are written to runs_dir/model_name/
    resume model name: is the directory (inside runs_dir) to resume from
    resume epoch: is the epoch from that specific run that you want to resume
    resume from last: set to True to resume the current model from its last saved epoch
    resume: resume from a specific directory with a specific epoch
'''
root_data_dir = "maps/"
runs_dir = "Runs/"
model_name = "LR_0005_Disc_Factor_1_Gen_Factor_25_Truely_dynamic_disc_loss_BN_Everywhere"
resume_model_name = ""
resume_epoch = 0
resume_from_last = False
resume = False
model_dir = runs_dir + model_name + "/"
# Creates runs_dir as well when it is missing, and is a no-op when the run directory already exists.
os.makedirs(model_dir, exist_ok=True)

# os.listdir returns names in arbitrary order, but quick_deformation pairs dmaps[i] with gmaps[i];
# sorting keeps the two lists aligned run after run.
# NOTE(review): this assumes a d-map and its g-map share the same file-name suffix - confirm.
maps = sorted(os.listdir(root_data_dir))
dmaps = [m for m in maps if m.startswith("d")]
gmaps = [m for m in maps if m.startswith("g")]

# Train/validation split of the crumple maps. A fixed seed keeps the split identical when a run
# is resumed in a fresh kernel, so validation crumples never drift into the training set.
split_seed = 0
train_val_split_rate = 0.9
n_crumples = len(dmaps)
idxs = np.random.RandomState(split_seed).permutation(n_crumples)
n_train_crumples = int(n_crumples * train_val_split_rate)
e_crumple_idxs = idxs[:n_train_crumples]
f_crumple_idxs = idxs[n_train_crumples:]

Batch_size = 5
Epochs = 150
Val_Batch_size = Batch_size * 2
total_chunks = 153
images_per_chunk = 1000
# The validation count is derived by subtraction: int(1000 * (1 - 0.9)) evaluates to 99 because
# 1 - 0.9 is slightly below 0.1 in floating point, while data_gen really serves 100 images per chunk.
train_images_per_chunk = int(images_per_chunk * train_val_split_rate)
val_images_per_chunk = images_per_chunk - train_images_per_chunk
steps_per_epoch = train_images_per_chunk * total_chunks // Batch_size
steps_per_val_epoch = val_images_per_chunk * total_chunks // Val_Batch_size


def key_func(x):
    # Epoch number encoded as "<anything>_<epoch>.h5"; every other file gets -1 so that plots and
    # other non-checkpoint files can never sort after a real checkpoint.
    stem, ext = os.path.splitext(x)
    epoch_str = stem.split("_")[-1]
    return int(epoch_str) if ext == ".h5" and epoch_str.isdigit() else -1


saved_models_list = os.listdir(model_dir)
saved_models_list.sort(key=key_func)
if resume_from_last:
    if not saved_models_list or key_func(saved_models_list[-1]) < 0:
        raise FileNotFoundError("resume_from_last is set but no *.h5 checkpoint exists in " + model_dir)
    last_epoch = key_func(saved_models_list[-1])
else:
    last_epoch = -1

In [3]:
def quick_deformation(crumple_id, img):
    """Apply the crumple maps number ``crumple_id`` to a batch of images.

    The ``crumple_id``-th entries of the module-level ``gmaps`` and ``dmaps``
    file lists are loaded from ``root_data_dir``. Pixels of ``img`` are gathered
    along axes 1 and 2 using the two index planes stored in the last axis of the
    d-map, and the gathered values are multiplied by the g-map.

    Returns a new array with the shape and dtype of ``img``; the product is cast
    back to that dtype on assignment.
    """
    scale_map = np.load(root_data_dir + gmaps[crumple_id])
    index_map = np.load(root_data_dir + dmaps[crumple_id])
    axis1_idx = index_map[..., 0]
    axis2_idx = index_map[..., 1]
    gathered = img[:, axis1_idx, axis2_idx, :]
    deformed = np.zeros_like(img)
    # The g-map gets a trailing axis so it broadcasts over the last axis of the gathered batch.
    deformed[...] = gathered * np.expand_dims(scale_map, 2)
    return deformed

In [4]:
def data_gen(training=True, batch_size=1, images_dir="Images/"):
    '''
    Endless generator of (crumpled, clean) image batches.

    Images are read chunk by chunk from "<images_dir><chunk_num>.npy" (memory-mapped).
    Inside every chunk the first int(images_per_chunk * train_val_split_rate) images
    belong to the training split and the remaining ones, up to images_per_chunk, to
    the validation split. For every batch we:
      1. add a trailing channel axis and crop 20 pixels from each border,
      2. apply one crumple map (the same one for the whole batch) with quick_deformation,
      3. subtract 127.5 from both the clean and the crumpled batch (float32).

    training:   True -> training images and e_crumple_idxs,
                False -> validation images and f_crumple_idxs
    batch_size: maximum number of images per batch; the last batch of a chunk is
                shorter when batch_size does not divide the size of the split
    images_dir: directory prefix of the chunk files

    Yields (crumpled_batch, clean_batch).
    '''
    if training:
        crumple_idxs = e_crumple_idxs
        start_batch = 0
        end_batch = int(images_per_chunk * train_val_split_rate)
    else:
        crumple_idxs = f_crumple_idxs
        start_batch = int(images_per_chunk * train_val_split_rate)
        end_batch = images_per_chunk
    while True:
        crumple_i = 0
        for chunk_num in range(total_chunks):
            # Map the chunk once instead of re-opening the file for every batch.
            chunk = np.load(images_dir + str(chunk_num) + ".npy", mmap_mode="r")
            for j in range(start_batch, end_batch, batch_size):
                # Clamp to the end of the split: an unclamped slice would pull validation
                # images into a training batch whenever batch_size does not divide the split.
                stop = min(j + batch_size, end_batch)
                batch = np.expand_dims(chunk[j:stop], -1)
                im = batch[:, 20:-20, 20:-20]
                im_norm = np.subtract(im, 127.5, dtype=np.float32)
                cr_id = crumple_idxs[crumple_i]
                cr_im = quick_deformation(cr_id, im)
                cr_im_norm = np.subtract(cr_im, 127.5, dtype=np.float32)
                yield cr_im_norm, im_norm
                crumple_i = (crumple_i + 1) % len(crumple_idxs)

In [5]:
def _generator_checkpoint(run_dir, epoch=None):
    """Return the path of a generator checkpoint stored in ``run_dir``.

    A checkpoint is any file named "<prefix>_<epoch>.h5" whose prefix does not end
    in "_disc" (those are discriminator checkpoints). With ``epoch=None`` the
    checkpoint with the highest epoch number is returned.

    Raises FileNotFoundError when no matching checkpoint exists.
    """
    by_epoch = {}
    for file_name in os.listdir(run_dir):
        if not file_name.endswith(".h5"):
            continue
        prefix, _, epoch_str = file_name[:-3].rpartition("_")
        # Only the suffix is inspected: the run name itself may contain "_disc_".
        if not epoch_str.isdigit() or prefix.endswith("_disc"):
            continue
        by_epoch[int(epoch_str)] = os.path.join(run_dir, file_name)
    if not by_epoch:
        raise FileNotFoundError("no generator checkpoint (*.h5) found in " + run_dir)
    if epoch is None:
        epoch = max(by_epoch)
    if epoch not in by_epoch:
        raise FileNotFoundError("no generator checkpoint for epoch " + str(epoch) + " in " + run_dir)
    return by_epoch[epoch]


def _build_generator():
    """Build the untrained U-Net style generator: (180, 180, 1) in, (180, 180, 1) out."""
    FACTOR = 2.5  # width multiplier applied to every hidden convolution
    layers = tf.keras.layers

    def conv(x, base_filters):
        return layers.Conv2D(int(base_filters * FACTOR), (3, 3), padding="same",
                             activation=tf.nn.leaky_relu)(x)

    def dual_conv(x, main_filters, side_filters):
        # Two parallel 3x3 convolutions on the same input, concatenated on the channel axis.
        # NOTE(review): the side branch is named "dial" in the original code, but
        # dilation_rate=1 means no dilation; kept as is - confirm whether 2 was intended.
        main = conv(x, main_filters)
        side = layers.Conv2D(int(side_filters * FACTOR), (3, 3), dilation_rate=1, padding="same",
                             activation=tf.nn.leaky_relu)(x)
        return layers.Concatenate(axis=3)([main, side])

    x_inp = layers.Input(shape=(180, 180, 1))
    padded = layers.ZeroPadding2D(padding=(2, 2))(x_inp)  # 184, so three 2x poolings divide evenly

    # Encoder
    c1_out = dual_conv(padded, 24, 8)  # 184
    c2_out = dual_conv(c1_out, 12, 4)
    bn0 = layers.BatchNormalization()(c2_out)
    c3_out = dual_conv(bn0, 12, 4)
    c4_out = dual_conv(c3_out, 12, 4)
    bn1 = layers.BatchNormalization()(c4_out)
    c5_out = dual_conv(bn1, 12, 4)  # skip connection at 184
    p1 = layers.MaxPooling2D()(c5_out)
    c5 = conv(p1, 16)  # 92
    bn2 = layers.BatchNormalization()(c5)
    c6 = conv(bn2, 16)  # skip connection at 92
    p2 = layers.MaxPooling2D()(c6)
    c7 = conv(p2, 32)  # 46
    bn3 = layers.BatchNormalization()(c7)
    c8 = conv(bn3, 32)  # skip connection at 46
    p3 = layers.MaxPooling2D()(c8)
    c9 = conv(p3, 64)  # 23
    bn4 = layers.BatchNormalization()(c9)
    c10 = conv(bn4, 64)  # 23

    # Decoder
    d11 = conv(c10, 64)  # 23
    bn5 = layers.BatchNormalization()(d11)
    d10 = conv(bn5, 64)  # 23; fixed: bn5 was computed but d10 was fed d11
    u4 = layers.UpSampling2D()(d10)  # 46
    cc4 = layers.Concatenate()([u4, c8])
    # Fixed: d9 used to read c8 directly, which left cc4 - and with it the whole 23x23
    # bottleneck (p3 ... u4) - disconnected from the model output.
    d9 = conv(cc4, 32)  # 46
    bn6 = layers.BatchNormalization()(d9)
    d8 = conv(bn6, 32)  # 46
    u3 = layers.UpSampling2D()(d8)  # 92
    cc3 = layers.Concatenate()([u3, c6])
    d7 = conv(cc3, 32)  # 92
    bn7 = layers.BatchNormalization()(d7)
    d6 = conv(bn7, 16)  # 92
    u2 = layers.UpSampling2D()(d6)  # 184
    cc2 = layers.Concatenate()([u2, c5_out])
    d5_out = dual_conv(cc2, 24, 8)  # 184
    bn8 = layers.BatchNormalization()(d5_out)
    d4_out = dual_conv(bn8, 12, 4)
    bn9 = layers.BatchNormalization()(d4_out)
    d3_out = dual_conv(bn9, 12, 4)
    d2_out = dual_conv(d3_out, 8, 4)
    # Fixed: the output convolution used to read only the main branch (d2) and dropped d2_out.
    d1 = layers.Conv2D(1, (3, 3), padding="same", activation=None)(d2_out)  # 184
    cropped = layers.Cropping2D(cropping=((2, 2), (2, 2)))(d1)  # back to 180
    return tf.keras.models.Model(inputs=[x_inp], outputs=[cropped])


def generator_model():
    """Return the generator, freshly built or loaded from disk.

    Controlled by the module-level flags of the control cell:
      * neither ``resume`` nor ``resume_from_last``: build a new, untrained model;
      * ``resume`` only: load epoch ``resume_epoch`` of the run ``resume_model_name``;
      * ``resume_from_last`` (with or without ``resume``): load the newest generator
        checkpoint of the current run ``model_name``.

    Raises FileNotFoundError when a requested checkpoint does not exist.
    """
    if not resume_from_last and not resume:
        return _build_generator()
    if resume and not resume_from_last:
        target = _generator_checkpoint(os.path.join(runs_dir, resume_model_name), int(resume_epoch))
    else:
        target = _generator_checkpoint(os.path.join(runs_dir, model_name))
    # compile=False: training uses a custom loop, so the saved loss/optimizer (which may be a
    # custom loss that is not defined in this notebook) never has to be deserialized.
    # custom_objects: the layers were built with tf.nn.leaky_relu, which is stored by name;
    # passing it explicitly covers Keras versions that cannot resolve that name on their own.
    return tf.keras.models.load_model(target, custom_objects={"leaky_relu": tf.nn.leaky_relu},
                                      compile=False)

In [6]:
def _discriminator_checkpoint(run_dir):
    """Return the path of the newest "<prefix>_disc_<epoch>.h5" file in ``run_dir``.

    Raises FileNotFoundError when the directory holds no discriminator checkpoint.
    """
    by_epoch = {}
    for file_name in os.listdir(run_dir):
        if not file_name.endswith(".h5"):
            continue
        prefix, _, epoch_str = file_name[:-3].rpartition("_")
        # Only the suffix is inspected: the run name itself may contain "_disc_".
        if epoch_str.isdigit() and prefix.endswith("_disc"):
            by_epoch[int(epoch_str)] = os.path.join(run_dir, file_name)
    if not by_epoch:
        raise FileNotFoundError("no discriminator checkpoint (*_disc_<epoch>.h5) found in " + run_dir)
    return by_epoch[max(by_epoch)]


def _build_discriminator():
    """Build the untrained discriminator: (180, 180, 1) image in, one sigmoid score out."""
    FACTOR = 1.0  # width multiplier applied to every convolution
    layers = tf.keras.layers

    def conv(x, base_filters):
        return layers.Conv2D(int(base_filters * FACTOR), (3, 3), padding="same",
                             activation=tf.nn.leaky_relu)(x)

    def dual_conv(x, main_filters, side_filters):
        # Two parallel 3x3 convolutions on the same input, concatenated on the channel axis.
        # NOTE(review): the side branch is named "dial" in the original code, but
        # dilation_rate=1 means no dilation; kept as is - confirm whether 2 was intended.
        main = conv(x, main_filters)
        side = layers.Conv2D(int(side_filters * FACTOR), (3, 3), dilation_rate=1, padding="same",
                             activation=tf.nn.leaky_relu)(x)
        return layers.Concatenate(axis=3)([main, side])

    x_inp = layers.Input(shape=(180, 180, 1))
    x = layers.ZeroPadding2D(padding=(2, 2))(x_inp)  # 184

    # Full-resolution stem: five dual-convolution stages.
    x = dual_conv(x, 24, 8)
    for _ in range(4):
        x = dual_conv(x, 12, 4)

    # Six "pool, conv, conv" stages. With the default (valid) 2x2 pooling the spatial size
    # goes 184 -> 92 -> 46 -> 23 -> 11 -> 5 -> 2.
    for _ in range(6):
        x = layers.MaxPooling2D()(x)
        x = conv(x, 16)
        x = conv(x, 16)

    pooled = layers.GlobalMaxPooling2D()(x)
    out_layer = layers.Dense(1, activation="sigmoid")(pooled)
    return tf.keras.models.Model(inputs=[x_inp], outputs=[out_layer])


def discriminator_model():
    """Return the discriminator, freshly built or loaded from disk.

    When the module-level flag ``resume_from_last`` is set, the newest discriminator
    checkpoint of the current run ``model_name`` is loaded; otherwise a new, untrained
    model is built.

    Raises FileNotFoundError when resuming and no discriminator checkpoint exists.
    """
    if not resume_from_last:
        return _build_discriminator()
    target = _discriminator_checkpoint(os.path.join(runs_dir, model_name))
    # compile=False: training uses a custom loop, nothing compiled needs to be restored.
    # custom_objects: the layers were built with tf.nn.leaky_relu, which is stored by name;
    # passing it explicitly covers Keras versions that cannot resolve that name on their own.
    return tf.keras.models.load_model(target, custom_objects={"leaky_relu": tf.nn.leaky_relu},
                                      compile=False)

In [7]:
# Build a new generator or load a saved one, depending on the resume flags of the control cell.
generator = generator_model()

In [8]:
# Build a new discriminator, or load a saved one when resume_from_last is set.
discriminator = discriminator_model()

In [9]:
# Print the generator's layer table (output shapes, parameter counts, connections).
generator.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 180, 180, 1  0           []                               
                                )]                                                                
                                                                                                  
 zero_padding2d (ZeroPadding2D)  (None, 184, 184, 1)  0          ['input_1[0][0]']                
                                                                                                  
 conv2d (Conv2D)                (None, 184, 184, 60  600         ['zero_padding2d[0][0]']         
                                )                                                                 
                                                                                              

In [10]:
# Print the discriminator's layer table (output shapes, parameter counts, connections).
discriminator.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 180, 180, 1  0           []                               
                                )]                                                                
                                                                                                  
 zero_padding2d_1 (ZeroPadding2  (None, 184, 184, 1)  0          ['input_2[0][0]']                
 D)                                                                                               
                                                                                                  
 conv2d_31 (Conv2D)             (None, 184, 184, 24  240         ['zero_padding2d_1[0][0]']       
                                )                                                           

In [11]:
# Binary cross-entropy on probabilities (from_logits=False; the discriminator ends in a sigmoid).
# reduction="none" skips the batch average, so the loss functions below can combine and
# log the individual terms before reducing them.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False, reduction="none")

In [12]:
def generator_loss(fake_output, generator_output, ground_truth):
    """Generator objective plus its individual terms.

    fake_output:      discriminator scores for the generated images
    generator_output: images produced by the generator
    ground_truth:     target images

    Returns a 4-tuple ``(total, fooling_loss, mae, mse)``:
      * total        - mean over the batch of ``fooling_loss + mae`` (the value to minimise)
      * fooling_loss - cross-entropy of ``fake_output`` against all-ones labels
      * mae          - mean absolute error over axes 1-3, divided by 31.2
      * mse          - mean squared error over axes 1-3, divided by 1000 (not part of total)
    """
    per_image_axes = [1, 2, 3]
    residual = generator_output - ground_truth
    mse = tf.reduce_mean(tf.square(residual), axis=per_image_axes) / 1000
    mae = tf.reduce_mean(tf.abs(residual), axis=per_image_axes) / 31.2
    wanted_scores = tf.ones_like(fake_output)  # the generator wants its fakes scored as real
    fooling_loss = cross_entropy(wanted_scores, fake_output)
    total = tf.reduce_mean(fooling_loss + mae)
    return total, fooling_loss, mae, mse

In [13]:
def discriminator_loss(real_output, fake_output):
    """Discriminator objective plus its two terms.

    real_output: discriminator scores for real images (labelled 1)
    fake_output: discriminator scores for generated images (labelled 0)

    Returns ``(total, real_loss, fake_loss)`` where ``total`` is the batch mean of
    ``real_loss + fake_loss`` and the other two are the unreduced cross-entropies.
    """
    labels_for_real = tf.ones_like(real_output)
    real_loss = cross_entropy(labels_for_real, real_output)
    labels_for_fake = tf.zeros_like(fake_output)
    fake_loss = cross_entropy(labels_for_fake, fake_output)
    total = tf.reduce_mean(real_loss + fake_loss)
    return total, real_loss, fake_loss

In [14]:
# One Adam optimizer per network; the generator's learning rate (5e-5) is ten times
# smaller than the discriminator's (5e-4).
generator_optimizer = tf.keras.optimizers.Adam(0.00005)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0005)

In [None]:
# Adversarial training loop.
# Every step updates the generator. The discriminator is updated only while flag_disc_train
# is True; the flag is re-evaluated once per epoch and switched off for the following epoch
# when the mean discriminator loss of the finished epoch dropped below 0.1.
# After every epoch both models are saved and two PNGs are written to model_dir:
# the loss curves and a grid of sample predictions.
train_gen = data_gen(training=True, batch_size=Batch_size)
gc.collect()
gen_loss_history = []
disc_loss_history = []
flag_disc_train = True
history_metric_gen_fooling_loss_hist = []
history_metric_gen_mse_hist = []
history_metric_gen_mae_hist = []
history_metric_disc_real_loss_hist = []
history_metric_disc_fake_loss_hist = []
for e in range(last_epoch + 1, Epochs):
    print("Epoch", e + 1, "/", Epochs, ":")
    metric_gen_fooling_loss_hist = []
    metric_gen_mse_hist = []
    metric_gen_mae_hist = []
    metric_disc_real_loss_hist = []
    metric_disc_fake_loss_hist = []
    this_epoch_gen_loss = []
    this_epoch_disc_loss = []
    # steps_per_epoch comes from the control cell (same value as before, no hard-coded 1000).
    for step in tqdm(range(steps_per_epoch)):
        noise, images = next(train_gen)
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator(noise, training=True)
            real_output = discriminator(images, training=True)
            fake_output = discriminator(generated_images, training=True)
            gen_loss, batch_gen_fooling_loss, batch_gen_mae, batch_gen_mse = generator_loss(fake_output, generated_images, images)
            disc_loss, batch_disc_real_loss, batch_disc_fake_loss = discriminator_loss(real_output, fake_output)
        # Store plain NumPy values so the per-step tensors are not kept alive for the whole epoch.
        metric_gen_fooling_loss_hist.append(batch_gen_fooling_loss.numpy())
        metric_gen_mse_hist.append(batch_gen_mse.numpy())
        metric_gen_mae_hist.append(batch_gen_mae.numpy())
        metric_disc_real_loss_hist.append(batch_disc_real_loss.numpy())
        metric_disc_fake_loss_hist.append(batch_disc_fake_loss.numpy())
        this_epoch_gen_loss.append(gen_loss.numpy())
        this_epoch_disc_loss.append(disc_loss.numpy())
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        if flag_disc_train:  # decided once per epoch, see the end of the epoch below
            gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
            discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    history_metric_gen_fooling_loss_hist.append(np.mean(metric_gen_fooling_loss_hist))
    history_metric_gen_mse_hist.append(np.mean(metric_gen_mse_hist))
    history_metric_gen_mae_hist.append(np.mean(metric_gen_mae_hist))
    history_metric_disc_real_loss_hist.append(np.mean(metric_disc_real_loss_hist))
    history_metric_disc_fake_loss_hist.append(np.mean(metric_disc_fake_loss_hist))
    generator.save(model_dir + model_name + "_gen_" + str(e) + ".h5")
    discriminator.save(model_dir + model_name + "_disc_" + str(e) + ".h5")
    # Epoch means, consistent with the metric curves (previously only the loss of the
    # very last batch of the epoch was recorded).
    mean_gen_loss = np.mean(this_epoch_gen_loss)
    mean_disc_loss = np.mean(this_epoch_disc_loss)
    gen_loss_history.append(mean_gen_loss)
    disc_loss_history.append(mean_disc_loss)

    # Pause discriminator updates for the next epoch when it has become too strong.
    flag_disc_train = not (mean_disc_loss < 0.1)

    # Loss curves, one panel per series.
    status_panels = (
        ("Generator Loss", gen_loss_history),
        ("Discriminator Loss", disc_loss_history),
        ("Generator Fooling Loss (Metric)", history_metric_gen_fooling_loss_hist),
        ("Generator MSE (Metric)", history_metric_gen_mse_hist),
        ("Generator MAE (Metric)", history_metric_gen_mae_hist),
        ("Discriminator Real Loss (Metric)", history_metric_disc_real_loss_hist),
        ("Discriminator Fake Loss (Metric)", history_metric_disc_fake_loss_hist),
    )
    status_fig, status_axes = plt.subplots(len(status_panels), 1, figsize=(40, 30))
    for ax, (title, series) in zip(status_axes, status_panels):
        ax.plot(series)
        ax.set_title(title)
    status_fig.savefig(model_dir + model_name + "_status" + ".png")
    plt.close(status_fig)

    # Sample predictions: four training samples (columns 0-3) and four validation samples
    # (columns 4-7); rows are input, ground truth and generator output.
    # Fresh generators are created each epoch, so the same samples are shown every time.
    dg_tr = data_gen(training=True, batch_size=1)
    dg_te = data_gen(training=False, batch_size=1)
    # Only one figure is created here; the extra plt.figure() call that used to precede
    # this line was never closed and leaked one figure per epoch.
    fig, axs = plt.subplots(3, 8, figsize=(30, 12))
    for split_name, sample_gen, first_col in (("Train", dg_tr, 0), ("Test", dg_te, 4)):
        for i in range(4):
            sample_in, sample_gt = next(sample_gen)
            # Direct call instead of predict(): cheaper for a single small batch inside a loop.
            prediction = generator(sample_in, training=False).numpy()[0]
            rows = (
                ("Input", sample_in[0]),
                ("GT", sample_gt[0]),
                ("Output", np.clip(prediction + 127.5, 0, 255)),
            )
            for row, (label, image) in enumerate(rows):
                ax = axs[row, first_col + i]
                ax.set_title(split_name + " " + label + " " + str(i + 1))
                ax.imshow(image, cmap='gray')
                ax.axis('off')
    fig.savefig(model_dir + model_name + "_" + str(e) + ".png", bbox_inches='tight')
    plt.close(fig)
    dg_tr.close()
    dg_te.close()
    gc.collect()

100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [3:52:04<00:00,  1.98it/s]


Epoch 29 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [3:46:14<00:00,  2.03it/s]


Epoch 30 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [3:49:26<00:00,  2.00it/s]


Epoch 31 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [3:53:11<00:00,  1.97it/s]


Epoch 32 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [3:58:38<00:00,  1.92it/s]


Epoch 33 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:00:48<00:00,  1.91it/s]


Epoch 34 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:04:04<00:00,  1.88it/s]


Epoch 35 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:06:30<00:00,  1.86it/s]


Epoch 36 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:05:43<00:00,  1.87it/s]


Epoch 37 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:12:26<00:00,  1.82it/s]


Epoch 38 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:18:09<00:00,  1.78it/s]


Epoch 39 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:19:26<00:00,  1.77it/s]


Epoch 40 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:17:00<00:00,  1.79it/s]


Epoch 41 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:20:35<00:00,  1.76it/s]


Epoch 42 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:23:12<00:00,  1.74it/s]


Epoch 43 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:26:18<00:00,  1.72it/s]


Epoch 44 / 150 :


100%|██████████████████████████████████████████████████████████████████████████| 27540/27540 [4:29:27<00:00,  1.70it/s]


Epoch 45 / 150 :


  2%|█▊                                                                          | 662/27540 [06:34<4:28:10,  1.67it/s]