In [1]:
import os
# Set before TensorFlow is imported (next cell) so the setting is already in
# the environment at import time; presumably meant to silence TensorFlow's
# native (C++) log output -- confirm the level against the TensorFlow docs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

In [2]:
import tensorflow as tf

In [3]:
from tensorflow.keras import layers
import cv2
import random
import numpy as np
from matplotlib import pyplot
from tensorflow.keras.optimizers.experimental import RMSprop
from tensorflow.keras.models import Sequential

In [4]:
def cust_data_generator(dir_path, batch_size):
    """Endlessly yield shuffled batches of (images, IMRs) built from PNG tiles.

    Every ``.png`` file in ``dir_path`` is read as a grayscale image. Its IMR
    is a coarse map holding one value per 4x4 pixel cell: the tile is min-max
    scaled to [0, 1], the mean brightness of each cell is computed, and the
    cell is assigned (at random) one of the two levels in
    {0, 0.25, 0.5, 0.75, 1} closest to that mean.

    Args:
        dir_path: Directory containing the ``.png`` tiles.
        batch_size: Maximum number of tiles per yielded batch (must be >= 1).

    Yields:
        ``(batch_images, batch_IMRS)`` as NumPy arrays. ``batch_images`` holds
        the tiles min-max scaled to [-1, 1] with shape (n, height, width);
        ``batch_IMRS`` has shape (n, height // 4, width // 4, 1). ``n`` equals
        ``batch_size`` except for the last batch of each pass over the
        directory, which may be smaller. All tiles are assumed to share one
        resolution (128 x 128 gives 32 x 32 IMRs).

    Raises:
        ValueError: If ``batch_size`` is < 1 or the directory holds no
            ``.png`` file (the original code would spin forever in that case).
        OSError: If a tile cannot be decoded by OpenCV.
    """
    cell = 4  # side length, in pixels, of the square summarised by one IMR value
    imr_values = [0, 0.25, 0.5, 0.75, 1]

    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

    # Get the list of all PNG files in the directory
    file_list = [f for f in os.listdir(dir_path) if f.endswith('.png')]
    num_files = len(file_list)
    if num_files == 0:
        raise ValueError("no .png files found in %r" % (dir_path,))

    while True:
        # Reshuffle at the start of every pass over the directory
        np.random.shuffle(file_list)

        for batch_start in range(0, num_files, batch_size):
            batch_files = file_list[batch_start : batch_start + batch_size]

            batch_images = []
            batch_IMRS = []

            for file in batch_files:
                file_path = os.path.join(dir_path, file)

                img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
                if img is None:
                    # cv2.imread signals failure by returning None, not by raising
                    raise OSError("could not read image %r" % (file_path,))
                arr = cv2.normalize(img, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

                height, width = arr.shape
                grid_height, grid_width = height // cell, width // cell

                IMR = np.zeros((grid_height, grid_width))

                for row in range(grid_height):
                    for col in range(grid_width):
                        # Integer division above guarantees every cell is a full 4x4 block
                        area = arr[row * cell:(row + 1) * cell, col * cell:(col + 1) * cell]
                        avg_brightness = np.mean(area)

                        # Two nearest quantisation levels; pick one at random
                        closest_values = sorted(imr_values, key=lambda x: abs(avg_brightness - x))[:2]
                        IMR[row][col] = random.choice(closest_values)

                # Use the computed grid size instead of a hard-coded 32 x 32
                batch_IMRS.append(IMR.reshape(grid_height, grid_width, 1))
                batch_images.append(
                    cv2.normalize(img, None, alpha=-1, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
                )

            yield np.array(batch_images), np.array(batch_IMRS)


In [5]:
@tf.keras.saving.register_keras_serializable(package="lossfunc", name="wasserstein_loss")
def wasserstein_loss(y_true, y_pred):
    """Wasserstein-style loss: the mean of label-weighted critic scores.

    ``y_true`` acts as a per-sample sign (this notebook uses -1 and +1), so
    minimising the loss pushes scores up for one label and down for the other.
    Registered as a Keras serializable so saved models can reference it by name.
    """
    signed_scores = y_true * y_pred
    return tf.keras.backend.mean(signed_scores)

In [6]:
def make_generator():
    """Build the generator network.

    Inputs:
        * an IMR of shape (32, 32, 1);
        * a noise tensor of shape (128, 8, 1), reshaped to (32, 32, 1).

    The IMR features and the reshaped noise are concatenated, upsampled twice
    (32 -> 64 -> 128) and passed through a stack of transposed convolutions.
    The output is a (128, 128, 1) image squashed to [-1, 1] by ``tanh``.

    Returns:
        An uncompiled ``tf.keras.Model`` taking ``[imr, noise]``.
    """
    def norm_act(tensor):
        # BatchNormalization followed by LeakyReLU, as used after every hidden conv
        tensor = layers.BatchNormalization()(tensor)
        return layers.LeakyReLU()(tensor)

    # IMR branch
    imr_in = layers.Input(shape=(32, 32, 1))
    imr_feat = layers.Conv2DTranspose(128, kernel_size=3, dilation_rate=2, padding='same')(imr_in)
    imr_feat = norm_act(imr_feat)

    # Noise branch: same element count as a 32 x 32 map, just laid out differently
    noise_in = layers.Input(shape=(128, 8, 1))
    noise_map = layers.Reshape((32, 32, 1))(noise_in)

    # Merge both branches along the channel axis, then upsample 4x in total
    hidden = layers.Concatenate()([imr_feat, noise_map])
    for _ in range(2):
        hidden = layers.UpSampling2D()(hidden)

    # Resolution-preserving transposed-conv stack
    for n_filters in (64, 128, 256, 128, 64, 32, 16):
        hidden = layers.Conv2DTranspose(n_filters, kernel_size=3, strides=1, padding='same')(hidden)
        hidden = norm_act(hidden)

    image_out = layers.Conv2DTranspose(1, kernel_size=5, strides=1, padding='same', activation='tanh')(hidden)

    return tf.keras.Model(inputs=[imr_in, noise_in], outputs=image_out)

In [7]:
def make_conversion_critic():
    """Build and compile the conversion critic.

    The critic scores an (IMR, image) pair. The (32, 32, 1) IMR and the
    (128, 128, 1) image are each reduced to a 32 x 32 x 32 feature map, the
    two maps are concatenated, and a small conv stack plus a single linear
    unit produce the unbounded score.

    Returns:
        A ``tf.keras.Model`` taking ``[imr, image]``, compiled with
        ``wasserstein_loss`` and RMSprop (learning rate 0.00025).
    """
    def conv_block(tensor, filters, kernel_size, **conv_kwargs):
        # 'same'-padded Conv2D -> BatchNormalization -> LeakyReLU
        tensor = layers.Conv2D(filters, kernel_size=kernel_size, padding='same', **conv_kwargs)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.LeakyReLU()(tensor)

    # IMR branch: stays at 32 x 32
    imr_in = layers.Input(shape=(32, 32, 1))
    imr_feat = imr_in
    for _ in range(2):
        imr_feat = conv_block(imr_feat, 32, 4)

    # Image branch: one dilated conv, then two stride-2 convs (128 -> 64 -> 32)
    image_in = layers.Input(shape=(128, 128, 1))
    image_feat = conv_block(image_in, 32, 3, dilation_rate=2)
    for _ in range(2):
        image_feat = conv_block(image_feat, 32, 4, strides=2)

    # Joint trunk over the concatenated 32 x 32 feature maps
    joint = layers.Concatenate()([imr_feat, image_feat])
    for filters, kernel_size in ((32, 4), (32, 4), (16, 5)):
        joint = conv_block(joint, filters, kernel_size)

    flat = layers.Flatten()(joint)
    score = layers.Dense(1)(flat)

    model = tf.keras.Model(inputs=[imr_in, image_in], outputs=score)
    model.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.00025))
    return model

In [8]:
def make_realism_critic():
    """Build and compile the realism critic.

    The critic scores a single (128, 128, 1) image. Five unpadded 4x4
    convolutions with a shrinking filter count (256 down to 16), each followed
    by BatchNormalization and LeakyReLU, feed a single linear output unit.

    Returns:
        A ``tf.keras.Model`` compiled with ``wasserstein_loss`` and RMSprop
        (learning rate 0.0005).
    """
    # Image to be judged (real tile or generator output)
    image_in = layers.Input(shape=(128, 128, 1))

    features = image_in
    for n_filters in (256, 128, 64, 32, 16):
        features = layers.Conv2D(n_filters, kernel_size=4, strides=1)(features)
        features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU()(features)

    flat = layers.Flatten()(features)
    score = layers.Dense(1)(flat)

    model = tf.keras.Model(inputs=image_in, outputs=score)
    model.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.0005))
    return model

In [9]:
# define the combined generator and critic model for ConversionGAN
def define_ConvGAN(generator, critic):
    """Stack the generator and the conversion critic into one trainable model.

    Side effect: every layer of ``critic`` that is not a BatchNormalization
    layer has its ``trainable`` flag set to False on the passed-in object.

    Args:
        generator: Model mapping ``[imr, noise]`` to an image.
        critic: Conversion critic taking ``[imr, image]``.

    Returns:
        A model named "ConversionGAN" mapping ``[imr, noise]`` to the critic's
        score, compiled with ``wasserstein_loss`` and RMSprop (lr 0.00025).
    """
    non_bn_layers = [lyr for lyr in critic.layers if not isinstance(lyr, layers.BatchNormalization)]
    for lyr in non_bn_layers:
        lyr.trainable = False

    imr_in = layers.Input(shape=(32, 32, 1))
    noise_in = layers.Input(shape=(128, 8, 1))

    generated = generator([imr_in, noise_in])
    # The conversion critic judges the generated image together with its IMR
    score = critic([imr_in, generated])

    gan = tf.keras.Model(inputs=[imr_in, noise_in], outputs=score, name="ConversionGAN")
    gan.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.00025))
    return gan

In [10]:
# define the combined generator and critic model for RealismGAN
def define_RealGAN(generator, critic):
    """Stack the generator and the realism critic into one trainable model.

    Side effect: every layer of ``critic`` that is not a BatchNormalization
    layer has its ``trainable`` flag set to False on the passed-in object.

    Args:
        generator: Model mapping ``[imr, noise]`` to an image.
        critic: Realism critic taking a single image.

    Returns:
        A model named "RealismGAN" mapping ``[imr, noise]`` to the critic's
        score, compiled with ``wasserstein_loss`` and RMSprop (lr 0.0005).
    """
    non_bn_layers = [lyr for lyr in critic.layers if not isinstance(lyr, layers.BatchNormalization)]
    for lyr in non_bn_layers:
        lyr.trainable = False

    imr_in = layers.Input(shape=(32, 32, 1))
    noise_in = layers.Input(shape=(128, 8, 1))

    generated = generator([imr_in, noise_in])
    # The realism critic only looks at the generated image
    score = critic(generated)

    gan = tf.keras.Model(inputs=[imr_in, noise_in], outputs=score, name="RealismGAN")
    gan.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.0005))
    return gan

In [11]:
def get_vis_imrs(dir_path):
    """Build one IMR per PNG tile in ``dir_path`` (used for visualisation).

    Each tile is read as grayscale and min-max scaled to [0, 1]. For every
    4x4 pixel cell the mean brightness is snapped, at random, to one of its
    two nearest levels in {0, 0.25, 0.5, 0.75, 1}.

    Args:
        dir_path: Directory containing the ``.png`` tiles.

    Returns:
        NumPy array of shape (n_files, height // 4, width // 4, 1), i.e.
        (n_files, 32, 32, 1) for 128 x 128 tiles. Files are processed in
        sorted name order so the result order is reproducible (``os.listdir``
        order is arbitrary). All tiles are assumed to share one resolution.

    Raises:
        OSError: If a tile cannot be decoded by OpenCV.
    """
    cell = 4  # side length, in pixels, of the square summarised by one IMR value
    imr_values = [0, 0.25, 0.5, 0.75, 1]

    file_list = sorted(f for f in os.listdir(dir_path) if f.endswith('.png'))

    vis_IMRS = []

    for file in file_list:
        file_path = os.path.join(dir_path, file)

        img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising
            raise OSError("could not read image %r" % (file_path,))
        arr = cv2.normalize(img, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

        height, width = arr.shape
        grid_height, grid_width = height // cell, width // cell

        IMR = np.zeros((grid_height, grid_width))

        for row in range(grid_height):
            for col in range(grid_width):
                # Integer division above guarantees every cell is a full 4x4 block
                area = arr[row * cell:(row + 1) * cell, col * cell:(col + 1) * cell]
                avg_brightness = np.mean(area)

                # Two nearest quantisation levels; pick one at random
                closest_values = sorted(imr_values, key=lambda x: abs(avg_brightness - x))[:2]
                IMR[row][col] = random.choice(closest_values)

        # Use the computed grid size instead of a hard-coded 32 x 32
        vis_IMRS.append(IMR.reshape(grid_height, grid_width, 1))

    return np.array(vis_IMRS)

In [12]:
def generate_fake_samples(generator, imrs, z_vectors, n_samples):
    """Run the generator on (IMR, noise) pairs and label the outputs as fake.

    Args:
        generator: Object exposing ``predict([imrs, z_vectors], verbose=0)``.
        imrs: Batch of IMRs fed to the generator.
        z_vectors: Batch of noise tensors fed to the generator.
        n_samples: Number of label rows to create; the caller is expected to
            pass the batch length.

    Returns:
        ``(images, labels)`` where ``labels`` is an (n_samples, 1) array of
        ones (the "fake" label used in this notebook).
    """
    generated_images = generator.predict([imrs, z_vectors], verbose=0)
    fake_labels = np.ones((n_samples, 1))
    return generated_images, fake_labels

In [13]:
def summarize(step, g_model, c_critic, r_critic, imrs, z_vectors, n_samples=10):
    """Save a grid of generated samples and checkpoint all three models.

    Args:
        step: Zero-based epoch index; file names use ``step + 1``.
        g_model: Generator; its images (tanh range [-1, 1]) are rescaled to
            [0, 1] for display.
        c_critic: Conversion critic to checkpoint.
        r_critic: Realism critic to checkpoint.
        imrs: IMRs used to produce the preview images.
        z_vectors: Noise tensors paired with ``imrs``.
        n_samples: Maximum number of previews to generate. The plot is a fixed
            2 x 5 grid, so at most 10 images are shown; unused cells stay blank.

    Side effects:
        Writes ``generated_plot_XXXX.png`` under ./Model/plots/, the generator
        under ./Model/savedModel/Generator/ and both critics under
        ./Model/savedModel/Critics/. Missing directories are created.
    """
    plots_path = './Model/plots/'
    weights_path = './Model/savedModel/Generator/'
    critics_weights_path = './Model/savedModel/Critics/'
    # Create output folders up front: a missing folder must not abort training
    # after many epochs of work.
    for out_dir in (plots_path, weights_path, critics_weights_path):
        os.makedirs(out_dir, exist_ok=True)

    # Feed the generator equally many IMRs and noise tensors, capped at n_samples
    count = min(n_samples, len(imrs), len(z_vectors))
    X, _ = generate_fake_samples(g_model, imrs[:count], z_vectors[:count], count)
    X = (X + 1) / 2.0

    fig, axs = pyplot.subplots(2, 5, figsize=(10, 4))
    grid_axes = axs.flatten()
    for ax in grid_axes:
        ax.axis('off')
    # zip stops at the shorter sequence, so fewer than 10 previews is fine
    for ax, img in zip(grid_axes, X):
        ax.imshow(img, cmap='gray')
    fig.tight_layout()

    # save plot to file
    filename1 = 'generated_plot_%04d.png' % (step + 1)
    fig.savefig(os.path.join(plots_path, filename1))
    pyplot.close(fig)

    # save the models
    filename2 = 'model_%04d.tf' % (step + 1)
    c_critic_fname = 'conv_%04d.tf' % (step + 1)
    r_critic_fname = 'real_%04d.tf' % (step + 1)
    g_model.save(os.path.join(weights_path, filename2))
    c_critic.save(os.path.join(critics_weights_path, c_critic_fname))
    r_critic.save(os.path.join(critics_weights_path, r_critic_fname))

    print('Saved: %s, %s, %s and %s' % (filename1, filename2, c_critic_fname, r_critic_fname))

In [14]:
def plot_history(c1_hist, c2_hist, r1_hist, r2_hist, gc_hist, gr_hist):
    """Save one loss line plot per GAN under ./Model/loss/.

    Args:
        c1_hist: Conversion critic loss on real images, one value per epoch.
        c2_hist: Conversion critic loss on fake images.
        r1_hist: Realism critic loss on real images.
        r2_hist: Realism critic loss on fake images.
        gc_hist: Generator loss through the ConversionGAN.
        gr_hist: Generator loss through the RealismGAN.

    Side effects:
        Writes ``ConversionGAN_line_plot_loss.png`` and
        ``RealismGAN_line_plot_loss.png``; the directory is created if missing
        so the history is not lost at the very end of training.
    """
    losses_path = './Model/loss/'
    os.makedirs(losses_path, exist_ok=True)

    plots = (
        ('ConversionGAN', 'ConversionGAN_line_plot_loss.png', c1_hist, c2_hist, gc_hist),
        ('RealismGAN', 'RealismGAN_line_plot_loss.png', r1_hist, r2_hist, gr_hist),
    )
    for title, filename, real_hist, fake_hist, gen_hist in plots:
        # One explicit figure per plot instead of relying on pyplot's implicit state
        fig, ax = pyplot.subplots()
        ax.plot(real_hist, label='critic_real')
        ax.plot(fake_hist, label='critic_fake')
        ax.plot(gen_hist, label='generator')
        ax.set_title(title)
        ax.set_xlabel('epoch')
        ax.set_ylabel('loss')
        ax.legend()
        fig.savefig(os.path.join(losses_path, filename))
        pyplot.close(fig)

### Newer `train_GAN` function

In [15]:
def train_GAN(generator, conversion_critic, realism_critic, ConvGAN, RealGAN, vis_imr, start, epochs, n_critic=2, batch_size=32):
    """Alternately train one of two critics and the generator, batch by batch.

    For every batch, both critics are first evaluated on freshly generated
    images. The conversion critic is trained when its loss exceeds twice the
    realism critic's loss; otherwise the realism critic is trained. The
    generator is then updated through the GAN wrapping the chosen critic.
    Labels follow this notebook's convention: -1 for real, +1 for fake, and the
    generator is trained towards -1.

    Args:
        generator: Generator model taking ``[imrs, noise]``.
        conversion_critic: Compiled critic taking ``[imrs, images]``.
        realism_critic: Compiled critic taking ``images``.
        ConvGAN: Compiled generator + conversion critic stack.
        RealGAN: Compiled generator + realism critic stack.
        vis_imr: IMRs used for the periodic preview plot.
        start: Zero-based index of the first epoch to run.
        epochs: Exclusive upper bound of the epoch index.
        n_critic: Critic updates per batch. Each update uses one half of the
            batch; the halves alternate (and repeat when ``n_critic`` > 2).
        batch_size: Number of tiles per training batch (must be >= 2).

    Raises:
        ValueError: If ``batch_size`` < 2 or ./Tiles/Train/ holds fewer than
            ``batch_size`` PNG tiles.

    Side effects:
        Prints per-epoch mean losses, calls ``summarize`` every 10th epoch and
        ``plot_history`` once at the end.
    """
    def epoch_mean(values):
        # A critic may never be selected during an epoch; report NaN instead of
        # triggering NumPy's "mean of empty slice" warning.
        return float(np.mean(values)) if values else float('nan')

    half_batch = batch_size // 2
    if half_batch == 0:
        raise ValueError("batch_size must be >= 2, got %r" % (batch_size,))

    dir_path = './Tiles/Train/'
    # Count only what the data generator serves (.png), not every directory entry
    num_files = sum(1 for f in os.listdir(dir_path) if f.endswith('.png'))
    steps_per_epoch = num_files // batch_size
    if steps_per_epoch == 0:
        raise ValueError(
            "need at least %d .png tiles in %r, found %d" % (batch_size, dir_path, num_files)
        )

    conv1_hist, conv2_hist, real1_hist, real2_hist, gc_hist, gr_hist = [], [], [], [], [], []

    # Initialize the custom data generator
    data_generator = cust_data_generator(dir_path, batch_size)
    # NOTE(review): np.random.normal(loc, scale) -- this draws noise with MEAN -1
    # and std 1. It looks like (low, high) may have been intended; left unchanged
    # because existing checkpoints were trained with this distribution. Confirm.
    vis_vectors = np.random.normal(-1, 1, (10, 128, 8, 1))

    for epoch in range(start, epochs):
        c1_tmp, c2_tmp, r1_tmp, r2_tmp, gc_tmp, gr_tmp = [], [], [], [], [], []

        # iterate batch by batch
        for step in range(steps_per_epoch):
            # Get a batch of real images and IMRs. The data generator emits a
            # short batch at the end of each pass over the directory; skip it,
            # because noise and labels below are sized for a full batch.
            x_real, imrs = next(data_generator)
            while len(x_real) < batch_size:
                x_real, imrs = next(data_generator)

            # generate a batch of fake images from the batch of imrs and z_vectors
            z_vectors = np.random.normal(-1, 1, (batch_size, 128, 8, 1))
            x_fake, y_fake_full = generate_fake_samples(generator, imrs, z_vectors, n_samples=batch_size)

            # test_on_batch gives the critics' losses without updating their weights
            c_loss = conversion_critic.test_on_batch([imrs, x_fake], y_fake_full)
            r_loss = realism_critic.test_on_batch(x_fake, y_fake_full)

            # choose the critic to be trained in this batch
            if c_loss > (2 * r_loss):
                C = conversion_critic
                critic = 'Conversion'
            else:
                C = realism_critic
                critic = 'Realism'

            y_real = -np.ones((half_batch, 1))
            y_fake = np.ones((half_batch, 1))

            for i in range(n_critic):
                # Alternate between the two halves so the slice never runs past
                # the batch (identical to the original for n_critic <= 2).
                lo = (i % 2) * half_batch
                hi = lo + half_batch

                # Half a batch of all the inputs
                imr = imrs[lo:hi]
                real_images = x_real[lo:hi]
                fake_images = x_fake[lo:hi]

                # Train the critic
                if critic == 'Conversion':
                    c1_tmp.append(C.train_on_batch([imr, real_images], y_real))
                    c2_tmp.append(C.train_on_batch([imr, fake_images], y_fake))
                else:
                    r1_tmp.append(C.train_on_batch(real_images, y_real))
                    r2_tmp.append(C.train_on_batch(fake_images, y_fake))

            # Train the Generator through the GAN that wraps the chosen critic
            y_gen = -np.ones((batch_size, 1))

            if critic == 'Conversion':
                gc_tmp.append(ConvGAN.train_on_batch([imrs, z_vectors], y_gen))
            else:
                gr_tmp.append(RealGAN.train_on_batch([imrs, z_vectors], y_gen))

        conv1_hist.append(epoch_mean(c1_tmp))
        conv2_hist.append(epoch_mean(c2_tmp))
        real1_hist.append(epoch_mean(r1_tmp))
        real2_hist.append(epoch_mean(r2_tmp))
        gc_hist.append(epoch_mean(gc_tmp))
        gr_hist.append(epoch_mean(gr_tmp))

        print('epoch %d > conv1=%.3f,        conv2=%.3f,        real1=%.3f,        real2=%.3f,        convgen=%.3f,        realgen=%.3f' % (epoch + 1, conv1_hist[-1], conv2_hist[-1], real1_hist[-1], real2_hist[-1], gc_hist[-1], gr_hist[-1]))
        # summarize
        if (epoch + 1) % 10 == 0:
            summarize(epoch, generator, conversion_critic, realism_critic, vis_imr, vis_vectors, 10)
    plot_history(conv1_hist, conv2_hist, real1_hist, real2_hist, gc_hist, gr_hist)


In [None]:
# Create Models
# Order matters: each critic is compiled inside its make_* function, i.e.
# before define_RealGAN / define_ConvGAN mark the critic's
# non-BatchNormalization layers as non-trainable for the stacked models.
generator = make_generator()
conversion_critic = make_conversion_critic()
realism_critic = make_realism_critic()
RealGAN = define_RealGAN(generator, realism_critic)
ConvGAN = define_ConvGAN(generator, conversion_critic)

# Train the GANs
# Arguments after vis_IMRS: start epoch 0, stop before epoch 50,
# n_critic=2 critic updates per batch, batch_size=32.
vis_path = './Tiles/vis/'
vis_IMRS = get_vis_imrs(vis_path)
train_GAN(generator, conversion_critic, realism_critic, ConvGAN, RealGAN, vis_IMRS, 0, 50, 2, 32)

### **Loading and Training**


In [None]:
# Resume training from the checkpoints written by summarize().
c_path = './Model/savedModel/Critics/conv_0050.tf'
r_path = './Model/savedModel/Critics/real_0050.tf'
# summarize() stores the generator under savedModel/Generator/, not savedModel/
g_path = './Model/savedModel/Generator/model_0050.tf'


def _load_trainable_critic(path, learning_rate):
    """Load a saved critic, make all its layers trainable again and compile it.

    The critics are saved after define_ConvGAN / define_RealGAN have set their
    non-BatchNormalization layers to ``trainable = False``, and that flag is
    part of the saved layer configuration. Compiling the model as loaded would
    therefore leave only the BatchNormalization layers trainable. Recompiling
    starts from a fresh RMSprop state (the saved optimizer state is dropped).
    """
    critic = tf.keras.saving.load_model(path, compile=False)
    for layer in critic.layers:
        layer.trainable = True
    critic.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=learning_rate))
    return critic


# Learning rates match make_conversion_critic / make_realism_critic
c = _load_trainable_critic(c_path, 0.00025)
r = _load_trainable_critic(r_path, 0.0005)
# The generator is only ever trained through the stacked GAN models
g = tf.keras.saving.load_model(g_path, compile=False)
cg = define_ConvGAN(g, c)
rg = define_RealGAN(g, r)

vis_path = './Tiles/vis/'
vis_IMRS = get_vis_imrs(vis_path)

# Training Call
# NOTE(review): the checkpoints loaded above are from epoch 50, yet counting
# resumes at 60, so later checkpoint names overstate the epochs trained.
# Confirm whether start_epoch should be 50 (or the *_0060 files loaded).
start_epoch = 60
end_epoch = 100
train_GAN(g, c, r, cg, rg, vis_IMRS, start_epoch, end_epoch, 2, 32)