In [None]:
from matplotlib import pyplot as plt
from keras.models import Model
from tensorflow.image import ssim
from tensorflow.keras import backend as K
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
import tensorflow_probability as tfp
import numpy as np
from IPython import display
import cv2
import tensorflow as tf
import os
import pathlib
import time
import datetime
import glob
import gc

DATASET_PATH = 'task_3/dataset/processed_hk_norm_unenhanced_iris_dataset_64x240_png/'

BATCH_SIZE = 4

IMG_HEIGHT = 64
IMG_WIDTH = 240
# --------------------------------------------------------------------------------------------------------------- #

# data distribution

# 209 subjects
# 150 training   72%
# 30 validation  14%
# 29 testing     14%

# Files of the on-disk 'test' folder whose subject id is <= this value form the
# validation split; the remaining ones form the actual test split.
LAST_VALIDATION_SUBJECT = 180


def subject_id_from_path(path):
    """Return the subject id encoded in characters [-11:-8] of `path`.

    The slice is taken relative to the end of *this* path, so it does not
    depend on the length of any other file name.
    """
    return int(path[-11:-8])


training_files = glob.glob(DATASET_PATH + 'train/*.png')
trainingset_size = len(training_files)

# Sorted so that both derived splits keep a deterministic order.
held_out_files = sorted(glob.glob(DATASET_PATH + 'test/*.png'))

validation_files = [f for f in held_out_files
                    if subject_id_from_path(f) <= LAST_VALIDATION_SUBJECT]
test_files = [f for f in held_out_files
              if subject_id_from_path(f) > LAST_VALIDATION_SUBJECT]

validationset_size = len(validation_files)
testset_size = len(test_files)

# From here on DATASET_PATH is a pathlib.Path (the globs above need the str form).
DATASET_PATH  = pathlib.Path(DATASET_PATH)

# --------------------------------------------------------------------------------------------------------------- #

# Training length is expressed in steps (NSTEPS) instead of epochs.
EPOCHS = 80

NSTEPS = trainingset_size * EPOCHS

# Early-stopping budget. Original note: "10 epochs * 4"
# (presumably 4 == BATCH_SIZE -- TODO confirm).
N_EPOCH_EARLY_STOPPING = 40

INPUT_CHANNELS = 1
OUTPUT_CHANNELS = 4

# buffer size is equal to training set size
BUFFER_SIZE = trainingset_size

# log directory
LOG_DIR = "logs/" + '_nsteps_' + str(NSTEPS) + '_batchsize_' + str(BATCH_SIZE)  + '/'

In [None]:
def load(image_file):
    """Read a PNG that holds two pictures side by side and return both halves.

    Args:
        image_file: path of the PNG file, as accepted by `tf.io.read_file`.

    Returns:
        (input_image, real_image): the left and the right half of the decoded
        picture, each cast to float32. The cut is made at `width // 2`, so for
        an odd width the right half is one column wider.
    """
    decoded = tf.io.decode_png(tf.io.read_file(image_file))

    # Column index at which the picture is cut in two.
    half_width = tf.shape(decoded)[1] // 2

    left_half = tf.cast(decoded[:, :half_width, :], tf.float32)
    right_half = tf.cast(decoded[:, half_width:, :], tf.float32)

    return left_half, right_half

In [None]:
def resize(input_image, real_image, height, width):
    """Resize both images to `height` x `width` using nearest-neighbour sampling.

    Returns:
        (input_image, real_image) resized to the same target size.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    resized_input = tf.image.resize(input_image, target_size, method=nearest)
    resized_real = tf.image.resize(real_image, target_size, method=nearest)

    return resized_input, resized_real

In [None]:
def normalize(input_image, real_image):
    """Linearly map both images from [0, 255] to [-1, 1].

    Returns:
        (input_image, real_image) after the mapping `x / 127.5 - 1`.
    """
    def _to_signed_unit(pixels):
        return (pixels / 127.5) - 1

    return _to_signed_unit(input_image), _to_signed_unit(real_image)

def denormalize(input_image, real_image): # to [0, 255]
    """Map `real_image` from [-1, 1] back to [0, 255].

    Args:
        input_image: accepted for call compatibility only; it is neither
            rescaled nor returned.
        real_image: image with values in [-1, 1].

    Returns:
        `real_image` after the mapping `(x + 1) * 127.5`.
    """
    # Only the target image is returned, so `input_image` is deliberately
    # left untouched instead of being rescaled and thrown away.
    return (real_image + 1) * 127.5

In [None]:
# augmentation step 

def random_crop(input_image, real_image):
    """Cut the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

    The two images are stacked along a new leading axis so that a single
    random offset is drawn for the pair. The crop size hard-codes 3 channels.

    Returns:
        (input_image, real_image) cropped with the shared offset.
    """
    pair = tf.stack([input_image, real_image], axis=0)
    crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
    cropped_pair = tf.image.random_crop(pair, size=crop_shape)

    return cropped_pair[0], cropped_pair[1]

@tf.function()
def random_jitter(input_image, real_image):
    """Augment an image pair: enlarge, random-crop, then randomly mirror.

    Returns:
        (input_image, real_image) of size IMG_HEIGHT x IMG_WIDTH, both
        transformed in the same way.
    """
    # Enlarge both images by a fixed margin so the crop below has room to move.
    margin = 30
    enlarged_input, enlarged_real = resize(
        input_image, real_image, IMG_HEIGHT + margin, IMG_WIDTH + margin)

    # Random crop back to IMG_HEIGHT x IMG_WIDTH.
    input_image, real_image = random_crop(enlarged_input, enlarged_real)

    # Mirror both images together when the uniform draw exceeds 0.5.
    mirror = tf.random.uniform(()) > 0.5
    if mirror:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image

In [None]:
# daugman feature extraction 

def tf_ProcessSingleChannel(channel):
    """Replace grey levels that fall in the tails of a Gaussian fitted to the histogram.

    A 256-bin histogram of `channel` is summarised by its mean and standard
    deviation. Every grey level whose Gaussian density is below 10% of the
    peak density is flagged, and pixels holding a flagged level are set to
    `mean + std`.

    Args:
        channel: tensor of grey levels; the histogram covers the range
            [0, 255]. The only caller in this file passes int32 data.

    Returns:
        float32 tensor shaped like `channel`.
    """
    h = tf.histogram_fixed_width(channel, value_range=(0, 255), nbins=256)

    h = tf.cast(h, tf.float32)
    pixel_values = tf.range(256, dtype=tf.float32)

    weighted_sum = tf.reduce_sum(pixel_values * h)
    total_pixels = tf.reduce_sum(h)
    mean_val = weighted_sum / total_pixels

    # Compute variance and standard deviation
    variance = tf.reduce_sum(((pixel_values - mean_val) ** 2) * h) / total_pixels
    std_dev = tf.sqrt(variance)

    # Compute Gaussian values
    gaussian_vals = (1 / (std_dev * tf.sqrt(2 * np.pi))) * tf.exp(-0.5 * ((pixel_values - mean_val) / std_dev) ** 2)

    # Set threshold
    threshold = tf.reduce_max(gaussian_vals) * 0.1  # 10% of the maximum

    # One flag per grey level 0..255
    to_eliminate = gaussian_vals < threshold

    ProcessedChannel = tf.cast(channel, dtype=tf.float32)

    # Look the flag of every pixel up in one gather instead of testing the
    # 256 grey levels one by one. A pixel only counts as holding level `i`
    # when it equals `i` exactly and `i` lies in 0..255, which is what the
    # per-level comparison `channel == i` used to select.
    levels = tf.cast(channel, tf.int32)
    holds_valid_level = tf.logical_and(
        tf.logical_and(levels >= 0, levels <= 255),
        tf.equal(tf.cast(levels, channel.dtype), channel))
    # Clip only to keep the gather index legal; out-of-range pixels are
    # already excluded by `holds_valid_level`.
    level_flagged = tf.gather(to_eliminate, tf.clip_by_value(levels, 0, 255))

    return tf.where(tf.logical_and(holds_valid_level, level_flagged),
                    mean_val + std_dev,
                    ProcessedChannel)

def tf_GaussHistCut(image):
    """Apply `tf_ProcessSingleChannel` to a single-channel or an RGB image.

    Args:
        image: tensor of rank 2, or of rank > 2 with the channels on the last
            axis. Only a static channel count of exactly 3 takes the RGB
            path; anything else is processed as one channel.

    Returns:
        float32 tensor shaped like `image`.
    """
    channels = 1
    if len(image.shape) > 2:
        channels = image.shape[-1]

    if channels == 3:  # RGB image
        # Tensors do not support item assignment, so the corrected image is
        # assembled by stacking the per-channel results. The result is
        # float32, like the single-channel path.
        return tf.stack(
            [tf_ProcessSingleChannel(image[..., ch]) for ch in range(channels)],
            axis=-1)

    # Grayscale image
    return tf_ProcessSingleChannel(image)

def tf_rescale(data):
    """Min-max scale `data` so that its minimum maps to 0 and its maximum to 1.

    There is no guard for constant input: when max == min the division is 0/0.
    """
    lowest = tf.reduce_min(data)
    value_span = tf.reduce_max(data) - lowest
    return (data - lowest) / value_span

def tf_mad_normalize(channel):
    """Centre on the median, divide by the median absolute deviation, rescale to [0, 1].

    When the MAD is zero the centred/scaled values are replaced by zeros
    before the final `tf_rescale`.
    """
    median = tfp.stats.percentile(channel, 50)
    centred = channel - median
    mad = tfp.stats.percentile(tf.abs(centred), 50)

    mad_is_zero = tf.equal(mad, 0)
    scaled = tf.where(mad_is_zero, tf.zeros_like(channel), centred / mad)
    return tf_rescale(scaled)

def tf_daugman_normalization(image) : #(image):
    """Normalise the first channel slice of `image`.

    The last axis is split into three equal parts and only the first part is
    processed: histogram tail cut, MAD normalisation, then NaN/Inf clean-up.
    The other two parts are ignored.

    Returns:
        The normalised first channel slice (the last axis is kept).
    """
    first_channel, _, _ = tf.split(image, num_or_size_splits=3, axis=-1)

    # uint8 is not supported by the histogram op, hence the int32 cast.
    tail_cut = tf_GaussHistCut(tf.cast(first_channel, dtype=tf.int32))

    scaled = tf_mad_normalize(tail_cut)

    # Replace NaN and Inf values with 0
    not_finite = tf.math.is_nan(scaled) | tf.math.is_inf(scaled)
    return tf.where(not_finite, tf.zeros_like(scaled), scaled)
    
def tf_gaborconvolve(im, nscale, minWaveLength, mult, sigmaOnf):
    """Filter every row of `im` with a bank of 1-D log-Gabor filters.

    Args:
        im: real 2-D tensor; rows [0, IMG_HEIGHT) and columns [0, IMG_WIDTH)
            are used.
        nscale: number of filter scales.
        minWaveLength: wavelength of the first filter.
        mult: factor applied to the wavelength between successive scales.
        sigmaOnf: its logarithm sets the width of the log-Gaussian radial
            profile.

    Returns:
        (EO, filtersum):
        EO is a list of `nscale` complex128 tensors of shape
        [IMG_HEIGHT, IMG_WIDTH], one independent filter response per scale.
        filtersum is the fft-shifted float32 sum of all filters.
    """
    rows = IMG_HEIGHT #im.shape[0]
    cols = IMG_WIDTH #im.shape[1]
    ndata = cols
    n_freqs = ndata // 2 + 1

    # Frequency values 0 - 0.5. The DC entry is set to 1 so that the log()
    # below stays finite; the filter value at DC is forced to 0 afterwards.
    radius = tf.range(0, n_freqs, dtype=tf.float64) / (ndata // 2) / 2
    radius = tf.tensor_scatter_nd_update(
        radius, tf.constant([[0]]), tf.constant([1.0], dtype=tf.float64))

    # tf.signal.fft works on the innermost axis, so all rows are transformed
    # in one call. ensure_shape pins the static shape, which callers rely on
    # (e.g. tf.zeros(E0[0].shape)) even when `im` has an unknown static shape.
    row_block = tf.ensure_shape(im[0:rows, 0:ndata], [rows, ndata])
    imagefft = tf.signal.fft(tf.cast(row_block, dtype=tf.complex128))

    filtersum = tf.zeros(cols, dtype=tf.float32)
    EO = [None] * nscale

    wavelength = minWaveLength  # Initialize filter wavelength

    for s in range(nscale):
        # Construct the filter - first calculate the radial filter component
        fo = 1.0 / wavelength  # Centre frequency of filter

        radial = tf.exp( tf.cast( - tf.pow((tf.math.log(radius/fo)), 2), dtype=tf.float32) / (2 * tf.pow(tf.math.log(sigmaOnf), 2)))

        # Full-length filter: zero at DC, the radial profile on the first
        # half of the spectrum, zeros on the remaining bins.
        log_gabor = tf.concat(
            [tf.zeros([1], dtype=tf.float32),
             radial[1:],
             tf.zeros([ndata - n_freqs], dtype=tf.float32)],
            axis=0)

        filtersum = filtersum + log_gabor

        # Each scale gets its own response; nothing is accumulated from the
        # previous scales.
        EO[s] = tf.signal.ifft(imagefft * tf.cast(log_gabor, dtype=tf.complex128))

        wavelength *= mult  # Finally calculate the wavelength of the next filter

    filtersum = tf.signal.fftshift(filtersum)

    return EO, filtersum

def tf_encode(polar_array, nscales, minWaveLength, mult, sigmaOnf):
    """Turn the log-Gabor response of `polar_array` into quadrant codes.

    Each position is coded by the signs of the real and imaginary parts:
        0: re <= 0, im <= 0      1: re <= 0, im > 0
        2: re >  0, im <= 0      3: re >  0, im > 0
    Scales are processed in order and each one overwrites the codes of the
    previous one.

    Returns:
        float32 tensor with the shape of one filter response.
    """
    # Convolve the normalised region with the Gabor filters
    responses, _ = tf_gaborconvolve(polar_array, nscales, minWaveLength, mult, sigmaOnf)

    H = tf.zeros(responses[0].shape)
    for scale in range(nscales):
        real_part = tf.math.real(responses[scale])
        imag_part = tf.math.imag(responses[scale])

        re_nonpos, re_pos = real_part <= 0, real_part > 0
        im_nonpos, im_pos = imag_part <= 0, imag_part > 0

        coded_quadrants = (
            (0.0, tf.math.logical_and(re_nonpos, im_nonpos)),
            (1.0, tf.math.logical_and(re_nonpos, im_pos)),
            (2.0, tf.math.logical_and(re_pos, im_nonpos)),
            (3.0, tf.math.logical_and(re_pos, im_pos)),
        )
        for code, in_quadrant in coded_quadrants:
            H = tf.where(in_quadrant, code, H)

    return H

def tf_GaborBitStreamSTACKED(AR): #polarImage):
    """Encode one normalised channel into a uint8 quadrant-code image.

    Size-1 axes of `AR` are squeezed away, the result is encoded with a
    single-scale filter (wavelength 24, multiplier 1, sigmaOnf 0.5), and a
    channel axis is appended at position 2.

    Returns:
        uint8 tensor with a trailing axis of size 1.
    """
    quadrant_codes = tf_encode(
        tf.squeeze(AR),
        nscales=1,
        minWaveLength=24,
        mult=1,
        sigmaOnf=0.5,
    )

    return tf.expand_dims(tf.cast(quadrant_codes, dtype=tf.uint8), axis=2)

def tf_daugman_feature_extractor(inp):
    """Return the feature image of `inp` by delegating to `tf_GaborBitStreamSTACKED`."""
    return tf_GaborBitStreamSTACKED(inp)

In [None]:
# printing input image - feature image

# Load and report the same sample, so the printed path matches the plots.
sample_file = training_files[1]
inp,tar = load(sample_file)
print(sample_file)
norm_inp = tf_daugman_normalization(inp)
norm_tar = tf_daugman_normalization(tar)

fi_inp= tf_daugman_feature_extractor(norm_inp)
fi_tar = tf_daugman_feature_extractor(norm_tar)

# Third panel shows the feature image of the target half (fi_tar).
display_list = [inp/255.0 , norm_inp, fi_tar]
title = ['Input Image', 'Norm_Input 1D', 'Feature Image']
plt.figure(figsize=(15, 15))

for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    plt.imshow(display_list[i])
    #plt.axis('off')

plt.show()

In [None]:
def load_image_train(image_file):
    """Load one training sample.

    Returns:
        (input_image, real_image): the normalised input half, and the feature
        image computed from the normalised target half.
    """
    raw_input, raw_target = load(image_file)

    model_input = tf_daugman_normalization(raw_input)
    target_features = tf_daugman_feature_extractor(
        tf_daugman_normalization(raw_target))

    return model_input, target_features

In [None]:
def load_image_validation(image_file):
    """Load one validation sample.

    Returns:
        (input_image, real_image): the normalised input half, and the feature
        image computed from the normalised target half.
    """
    raw_input, raw_target = load(image_file)

    model_input = tf_daugman_normalization(raw_input)
    target_features = tf_daugman_feature_extractor(
        tf_daugman_normalization(raw_target))

    return model_input, target_features

In [None]:
def load_image_test(image_file):
    """Load one test sample.

    Returns:
        (input_image, real_image): the normalised input half, and the feature
        image computed from the normalised target half.
    """
    raw_input, raw_target = load(image_file)

    model_input = tf_daugman_normalization(raw_input)
    target_features = tf_daugman_feature_extractor(
        tf_daugman_normalization(raw_target))

    return model_input, target_features

In [None]:
# Training pipeline: list the PNGs, preprocess in parallel, shuffle over the
# whole training set, then batch.
train_dataset = (
    tf.data.Dataset.list_files(str(DATASET_PATH / 'train/*.png'))
    .map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [None]:
# Validation pipeline. The preprocessing is deterministic and this dataset is
# iterated once per validation pass, so the mapped samples are computed in
# parallel (as for the train/test pipelines) and cached after the first pass.
validation_dataset = tf.data.Dataset.from_tensor_slices(validation_files)
validation_dataset = validation_dataset.map(load_image_validation,
                                            num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = validation_dataset.cache()
validation_dataset = validation_dataset.batch(BATCH_SIZE)

In [None]:
# Test pipeline: preprocess in parallel and serve one sample per batch.
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_files)
    .map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(1)
)

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Build an encoder block: stride-2 Conv2D -> [BatchNorm] -> LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_batchnorm: insert a BatchNormalization layer after the convolution.

    Returns:
        A `tf.keras.Sequential` holding the block.
    """
    block_layers = [
        tf.keras.layers.Conv2D(
            filters, size, strides=2, padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False),
    ]

    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())

    block_layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(block_layers)

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a decoder block: stride-2 Conv2DTranspose -> BatchNorm -> [Dropout] -> ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: insert a Dropout(0.5) layer after the batch normalisation.

    Returns:
        A `tf.keras.Sequential` holding the block.
    """
    block_layers = [
        tf.keras.layers.Conv2DTranspose(
            filters, size, strides=2, padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False),
        tf.keras.layers.BatchNormalization(),
    ]

    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))

    block_layers.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(block_layers)

In [None]:
def Generator():
    """Build the U-Net style generator.

    Layout:
      * input of shape [IMG_HEIGHT, IMG_WIDTH, INPUT_CHANNELS];
      * four stride-2 encoder blocks;
      * three stride-2 decoder blocks, each concatenated with the encoder
        output of matching resolution (the deepest encoder output has no
        skip connection);
      * a final stride-2 transposed convolution with a softmax over
        OUTPUT_CHANNELS.

    Returns:
        A `tf.keras.Model` mapping the input to per-pixel class probabilities.
    """
    inputs = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, INPUT_CHANNELS])

    down_stack = [
        downsample(64, 4, apply_batchnorm=False),
        downsample(128, 4),
        downsample(256, 4),
        downsample(512, 4),
    ]

    # One decoder block per skip connection. A fourth block would never be
    # applied: the zip() below stops after the three available skips, and the
    # last doubling of resolution is done by `last`.
    up_stack = [
        upsample(512, 4, apply_dropout=True),
        upsample(256, 4),
        upsample(128, 4),
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='softmax')

    # --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- #
    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # Drop the bottleneck output and pair the rest with the decoder, deepest first.
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Module-level generator instance; train_step() and fit() refer to this global.
generator = Generator()
#tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
# -------- GENERATOR SMOKE CALL -------- #

# Run the preview sample (norm_inp) through the generator in inference mode,
# adding a leading batch axis. Nothing is displayed unless the lines below
# are uncommented.
gen_output = generator(norm_inp[tf.newaxis, ...], training=False)
#print(gen_output.shape)
#plt.imshow(gen_output[0, ...])

In [None]:
def DiceLoss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss computed jointly over every pixel and class.

    Args:
        y_true: class indices with a singleton axis 3; one-hot encoded here
            with a hard-coded depth of 4.
        y_pred: predicted per-class scores, last axis = class.
        smooth: added to numerator and denominator.

    Returns:
        Scalar `1 - dice`.
    """
    # Class indices -> one-hot on a new last axis
    class_ids = tf.cast(K.squeeze(y_true, 3), "int32")
    one_hot_true = K.cast(tf.one_hot(class_ids, 4, axis=-1), 'float32')
    scores = K.cast(y_pred, 'float32')

    # Flatten label and prediction tensors
    flat_pred = K.flatten(scores)
    flat_true = K.flatten(one_hot_true)

    overlap = K.sum(flat_true * flat_pred)
    denominator = K.sum(flat_true) + K.sum(flat_pred) + smooth
    return 1 - (2*overlap + smooth) / denominator


#Tensorflow / Keras 
def IoULoss(y_true, y_pred, smooth=1e-6):
    """Soft IoU (Jaccard) loss computed jointly over every pixel and class.

    Args:
        y_true: class indices with a singleton axis 3; one-hot encoded here
            with a hard-coded depth of 4.
        y_pred: predicted per-class scores, last axis = class.
        smooth: added to numerator and denominator.

    Returns:
        Scalar `1 - IoU`.
    """
    # Class indices -> one-hot on a new last axis
    class_ids = tf.cast(K.squeeze(y_true, 3), "int32")
    one_hot_true = K.cast(tf.one_hot(class_ids, 4, axis=-1), 'float32')
    scores = K.cast(y_pred, 'float32')

    # Flatten label and prediction tensors
    flat_pred = K.flatten(scores)
    flat_true = K.flatten(one_hot_true)

    overlap = K.sum(flat_true * flat_pred)
    union = (K.sum(flat_true) + K.sum(flat_pred)) - overlap

    return 1 - (overlap + smooth) / (union + smooth)

# Sparse categorical cross-entropy on probabilities (from_logits=False).
# Not referenced by the code visible in this notebook: generator_loss uses IoULoss.
scce_loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False,
    reduction='auto')

In [None]:
def generator_loss(gen_output, target):
    """Return the IoU loss of the generator output against the target.

    Note the argument order: this takes (prediction, target) while IoULoss
    takes (y_true, y_pred), so the two are swapped in the call.
    """
    return IoULoss(target, gen_output)

In [None]:
def create_mask(pred_mask):
    """Collapse per-class scores into a class-index map.

    Takes the argmax over the last axis and re-appends a trailing axis of size 1.
    """
    class_ids = tf.argmax(pred_mask, axis=-1)
    return tf.expand_dims(class_ids, axis=-1)

In [None]:
def validation_step(generator, validation_ds) :
    """Return the mean `generator_loss` over all batches of `validation_ds`.

    The generator is run in inference mode. The result is the unweighted
    mean of the per-batch losses, as computed by `np.mean`.
    """
    batch_losses = [
        generator_loss(generator(batch_input, training=False), batch_target)
        for batch_input, batch_target in validation_ds
    ]

    return np.mean(batch_losses)

In [None]:
def save_models(string, generator) :
    """Save `generator` to the path 'models/' + `string`.

    Args:
        string: file name appended to the 'models/' prefix.
        generator: object exposing a `save(path)` method.
    """
    target_path = 'models/' + string
    generator.save(target_path)

In [None]:
# Optimizer applied by train_step(). Plain SGD with the Keras default
# settings is active; the Adam configuration is kept commented out as an
# alternative.
#generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_optimizer = tf.keras.optimizers.SGD()

In [None]:
def generate_images(model, test_input, tar):
    """Show input, ground truth and predicted class map of the first sample side by side.

    The model is run in inference mode and its output is reduced to a class
    map with `create_mask`.
    """
    predicted_mask = create_mask(model(test_input, training=False))

    plt.figure(figsize=(15, 15))

    panels = (
        ('Input Image', test_input[0]),
        ('Ground Truth', tar[0]),
        ('Predicted Image', predicted_mask[0]),
    )

    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        plt.imshow(picture)
        plt.axis('off')

    plt.show()

In [None]:
def save_images(model, test_input, tar, step):
    """Save input, ground truth and predicted class map of the first sample to disk.

    The figure is written to 'image_<step>.jpg' in the current working
    directory and then closed, so repeated calls do not accumulate open
    figures.

    Args:
        model: callable model producing per-class scores.
        test_input: batch of inputs; only the first element is plotted.
        tar: batch of targets; only the first element is plotted.
        step: value embedded in the output file name.
    """
    # Inference mode, made explicit to match generate_images().
    prediction = model(test_input, training=False)
    mask_prediction = create_mask(prediction)
    fig = plt.figure(figsize=(15, 15))

    display_list = [test_input[0], tar[0], mask_prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']

    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        plt.imshow(display_list[i])
        plt.axis('off')

    filename = 'image_' + str(step) + '.jpg'
    fig.savefig(filename)
    # Release the figure: it is never shown, so nothing else would close it.
    plt.close(fig)

In [None]:
# TensorBoard writer used by train_step(); each run logs into its own
# timestamped sub-directory of LOG_DIR.
summary_writer = tf.summary.create_file_writer(
    LOG_DIR + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
@tf.function
def train_step(input_image, target, step, pw_val_error):
    """Run one optimisation step on the module-level generator and log scalars.

    Args:
        input_image: batch of generator inputs.
        target: batch of targets passed to `generator_loss`.
        step: global step counter; summaries are indexed by
            `step // trainingset_size`.
        pw_val_error: latest validation error, logged alongside the loss.
    """
    with tf.GradientTape() as tape:
        prediction = generator(input_image, training=True)
        loss_value = generator_loss(prediction, target)

    trainable = generator.trainable_variables
    gradients = tape.gradient(loss_value, trainable)
    generator_optimizer.apply_gradients(zip(gradients, trainable))

    # Both scalars share the same summary index.
    summary_index = step // trainingset_size
    with summary_writer.as_default():
        tf.summary.scalar("loss", loss_value, step=summary_index)
        tf.summary.scalar('pw_val_error', pw_val_error, step=summary_index)


In [None]:
def fit(train_ds, validation_ds, test_ds, steps, patience=N_EPOCH_EARLY_STOPPING):
    """Train the module-level `generator` for at most `steps` batches.

    A "round" is `trainingset_size` consecutive batches. At the start of each
    round a preview of the first test sample is displayed. At the end of each
    round the validation error is computed; when it improves on the best
    value so far, a preview image and the model are saved.

    Args:
        train_ds: batched training dataset; it is repeated as needed.
        validation_ds: dataset passed to `validation_step`.
        test_ds: dataset whose first batch is used for the previews.
        steps: maximum number of batches to train on.
        patience: stop after this many consecutive rounds without a new best
            validation error. Pass None to disable early stopping.
    """
    example_input, example_target = next(iter(test_ds.take(1)))
    start = time.time()
    min_val_error = float("inf")
    pw_val_error = float("inf")
    count_stopping = 0

    # One progress dot every ~5% of a round; never 0, which would make the
    # modulo below fail for small training sets.
    progress_every = max(1, int(trainingset_size * 0.05))

    for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
        # Plain int for the bookkeeping; the tensor `step` is still what
        # train_step receives.
        count = int(step)

        if count % trainingset_size == 0:
            display.clear_output(wait=True)

            if count != 0:
                print(f'Time taken for {trainingset_size} steps: {time.time()-start:.2f} sec\n')

            start = time.time()

            generate_images(generator, example_input, example_target)
            print(f"Step: {count//trainingset_size}k")

        train_step(input_image, target, step, pw_val_error)

        if ((count + 1) % trainingset_size) == 0:
            pw_val_error = validation_step(generator, validation_ds)
            print("pw_error : ", pw_val_error, "   at step : ", count)

            if pw_val_error < min_val_error:
                print("updating min_val_error..")
                count_stopping = 0
                min_val_error = pw_val_error
                filename = 'best_' + str(count + 1) + '.h5'

                save_images(generator, example_input, example_target, count+1)
                save_models(filename, generator)
            else:
                count_stopping = count_stopping + 1
                # Early stopping: give up once the validation error has not
                # improved for `patience` consecutive rounds.
                if patience is not None and count_stopping >= patience:
                    print("early stopping at step : ", count)
                    break

        # Training step
        if (count + 1) % progress_every == 0:
            print('.', end='', flush=True)

In [None]:
# Run the training loop for NSTEPS batches, then save the generator as it is
# at the end of training (in addition to the 'best_*' checkpoints written by fit()).
fit(train_dataset, validation_dataset, test_dataset, steps= NSTEPS)
save_models('last_model_.h5', generator)