In [None]:
from tensorflow.keras.utils import save_img
from matplotlib import pyplot as plt
import tensorflow_probability as tfp
from keras import backend as K

from os import mkdir, path
import tensorflow as tf
import numpy as np
import pathlib
import glob
import cv2


# Folder containing the PNG test images (normalised iris strips).
DATASET_PATH = 'matlab_script/tasks/task_1/hk_test/synthetic_norm_irises/'
# Folder containing the trained generator files.
MODELS_PATH = 'task_3/feature_extractor/models/'
# Root folder under which one output sub-folder per model is created.
GENERATED_IMAGES_PATH = 'task_3/feature_extractor/models/'

# Model names WITHOUT file extension: the inference cell appends '.h5' when
# loading and uses the bare name for the output folder. Keeping the extension
# here would yield '<name>.h5.h5' on load and an output folder named like the
# model file itself.
models_list = ['best_pix2pix_bs4_dice_loss_LP_5',
               'best_unet_bs4_iou_loss_ep_80_estop_10']

# Image size (pixels) used by the Gabor feature-extraction helpers.
IMG_WIDTH = 240
IMG_HEIGHT = 64

In [None]:
# Collect the test PNGs in a deterministic (lexicographically sorted) order;
# later cells index into this list to recover each image's file name.
testset_files = sorted(glob.glob(DATASET_PATH + '/*.png'))
PATHLIB_DATASET_PATH  = pathlib.Path(DATASET_PATH)
LENGTH_IMAGE_PATH = len(testset_files)

# Fail early with an explicit message instead of an IndexError further down.
if not testset_files:
    raise FileNotFoundError(f"No PNG images found in '{DATASET_PATH}'")
print(testset_files[0])

In [None]:
def load(image_file):
    """Read a PNG file and return its pixels as a float32 tensor.

    Args:
        image_file: path of the PNG file (string or string tensor).

    Returns:
        The decoded image cast to ``tf.float32``. Values are not rescaled by
        this function; the channel count is whatever ``tf.io.decode_png``
        yields with its default arguments.
    """
    raw_bytes = tf.io.read_file(image_file)
    decoded = tf.io.decode_png(raw_bytes)
    return tf.cast(decoded, tf.float32)

In [None]:
def resize(input_image, height, width):
    """Resize an image to ``height`` x ``width`` with nearest-neighbour sampling.

    Args:
        input_image: image tensor accepted by ``tf.image.resize``.
        height: target number of rows.
        width: target number of columns.

    Returns:
        The resized image tensor.
    """
    target_size = [height, width]
    return tf.image.resize(
        input_image,
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

In [None]:
def normalize(input_image):
    """Map pixel values linearly so that 0 -> -1 and 255 -> 1.

    Args:
        input_image: number, array or tensor supporting ``/`` and ``-``.

    Returns:
        ``input_image / 127.5 - 1`` with the same container type as the input.
    """
    half_range = 127.5
    scaled = input_image / half_range
    return scaled - 1

In [None]:
# daugman feature extraction 

def tf_ProcessSingleChannel(channel):
    """Replace rare grey levels of a channel using a Gaussian histogram cut.

    A 256-bin histogram of ``channel`` over the range [0, 255] gives the mean
    and standard deviation of the grey levels. A Gaussian with those
    parameters is evaluated at every level 0..255; levels whose Gaussian value
    is below 10% of the Gaussian's maximum are marked for elimination, and
    every pixel equal to such a level is replaced by ``mean + std``.

    Args:
        channel: float32 tensor of pixel values (the comparison with the
            integer levels 0..255 only matches integer-valued pixels).

    Returns:
        A tensor shaped like ``channel`` with the marked levels replaced.
    """
    hist = tf.histogram_fixed_width(channel, value_range=(0, 255), nbins=256)
    hist = tf.cast(hist, tf.float32)
    pixel_values = tf.range(256, dtype=tf.float32)

    # Histogram-weighted mean and standard deviation of the grey levels.
    total_pixels = tf.reduce_sum(hist)
    mean_val = tf.reduce_sum(pixel_values * hist) / total_pixels
    variance = tf.reduce_sum(((pixel_values - mean_val) ** 2) * hist) / total_pixels
    std_dev = tf.sqrt(variance)

    # Gaussian evaluated at each of the 256 levels.
    gaussian_vals = (1 / (std_dev * tf.sqrt(2 * np.pi))) * tf.exp(-0.5 * ((pixel_values - mean_val) / std_dev) ** 2)

    # Levels falling below 10% of the Gaussian peak are eliminated.
    threshold = tf.reduce_max(gaussian_vals) * 0.1
    to_eliminate = gaussian_vals < threshold

    # Vectorised lookup replacing a 256-step loop of full-image tf.where calls:
    # map every pixel to its level index and gather the per-level flag.
    level_idx = tf.cast(tf.clip_by_value(channel, 0, 255), tf.int32)
    # Only pixels that are exactly one of the integer levels may be replaced,
    # which is what the per-level equality test `channel == i` selected.
    is_exact_level = tf.equal(tf.cast(level_idx, channel.dtype), channel)
    cut_mask = tf.logical_and(tf.gather(to_eliminate, level_idx), is_exact_level)

    return tf.where(cut_mask, mean_val + std_dev, channel)

def tf_GaussHistCut(image):
    """Apply the Gaussian histogram cut to an image.

    A tensor of rank > 2 whose last dimension is 3 is treated as RGB: each
    channel is processed independently by ``tf_ProcessSingleChannel`` and the
    results are stacked back along the last axis. Any other input is passed
    to ``tf_ProcessSingleChannel`` as a single channel.

    Args:
        image: image tensor (2-D, or 3-D with channels last).

    Returns:
        The processed image, same shape as ``image``.
    """
    channels = 1
    if len(image.shape) > 2:
        channels = image.shape[-1]

    if channels == 3:  # RGB image
        # Tensors are immutable (no item assignment), so collect the processed
        # channels and stack them instead of writing into a preallocated
        # buffer.
        processed = []
        for ch in range(channels):
            processed.append(tf_ProcessSingleChannel(image[..., ch]))
        return tf.stack(processed, axis=-1)

    # Grayscale / single-channel image
    return tf_ProcessSingleChannel(image)

def tf_rescale(data):
    """Min-max rescale ``data`` to the range [0, 1].

    Args:
        data: floating-point tensor.

    Returns:
        ``(data - min) / (max - min)``. When all values are equal (zero
        range) the result is all zeros rather than NaN from 0/0.
    """
    data_min = tf.reduce_min(data)
    data_max = tf.reduce_max(data)
    # divide_no_nan yields 0 wherever the denominator is 0.
    return tf.math.divide_no_nan(data - data_min, data_max - data_min)

def tf_mad_normalize(channel):
    """Centre a channel on its median, scale by its MAD, then rescale to [0, 1].

    The median is taken as the 50th percentile (``tfp.stats.percentile``) and
    the MAD as the 50th percentile of the absolute deviations from it. When
    the MAD is zero the normalised channel is all zeros.

    Args:
        channel: floating-point tensor.

    Returns:
        The output of ``tf_rescale`` applied to the MAD-normalised channel.
    """
    # Compute the median once and reuse it for both the MAD and the centring.
    median = tfp.stats.percentile(channel, 50)
    centered = channel - median
    mad = tfp.stats.percentile(tf.abs(centered), 50)
    # Safe division: gives zeros when mad == 0, matching the zero-MAD branch.
    normalized = tf.math.divide_no_nan(centered, mad)
    return tf_rescale(normalized)

def tf_daugman_normalization(AR):
    """Normalise a single image channel.

    Steps, in order: Gaussian histogram cut (``tf_GaussHistCut``), MAD
    normalisation with rescaling (``tf_mad_normalize``), and replacement of
    any NaN / Inf entries by 0.

    Only one channel is handled per call; callers that need all three colour
    channels invoke this function once per channel.

    Args:
        AR: tensor holding one image channel.

    Returns:
        The normalised channel.
    """
    hist_cut = tf_GaussHistCut(AR)
    scaled = tf_mad_normalize(hist_cut)

    # Zero out non-finite entries (NaN or +/-Inf).
    non_finite = tf.math.is_nan(scaled) | tf.math.is_inf(scaled)
    cleaned = tf.where(non_finite, tf.zeros_like(scaled), scaled)

    return cleaned
    
def tf_gaborconvolve(im, nscale, minWaveLength, mult, sigmaOnf):
    """Filter every image row with 1-D log-Gabor filters in the frequency domain.

    The image is read as ``IMG_HEIGHT`` rows of ``IMG_WIDTH`` samples. For
    each scale a log-Gabor transfer function is built on the non-negative
    frequencies (zero at DC and on the remaining bins), multiplied with the
    FFT of each row, and transformed back with the inverse FFT.

    Args:
        im: 2-D tensor with at least ``IMG_HEIGHT`` rows and ``IMG_WIDTH``
            columns; only that top-left window is used.
        nscale: number of filter scales.
        minWaveLength: wavelength of the first (smallest-scale) filter.
        mult: factor applied to the wavelength between successive scales.
        sigmaOnf: value whose logarithm sets the width of the log-Gabor
            Gaussian (used as ``2 * log(sigmaOnf) ** 2`` in the exponent).

    Returns:
        Tuple ``(EO, filtersum)`` where ``EO`` is a list of ``nscale``
        complex128 tensors of shape ``[IMG_HEIGHT, IMG_WIDTH]`` (one filtered
        image per scale) and ``filtersum`` is the fft-shifted float32 sum of
        all filters.
    """
    rows = IMG_HEIGHT
    ndata = IMG_WIDTH
    n_freq = ndata // 2 + 1  # number of non-negative frequency bins

    # FFT of all rows at once (tf.signal.fft works on the innermost axis).
    # ensure_shape keeps a static [rows, ndata] shape for downstream code even
    # when `im` comes from a graph with unknown dimensions.
    signal = tf.ensure_shape(im[:rows, :ndata], [rows, ndata])
    imagefft = tf.signal.fft(tf.cast(signal, dtype=tf.complex128))

    # Frequency values 0 - 0.5; the DC entry is set to 1 so log() is finite.
    radius = tf.range(0, n_freq, dtype=tf.float64) / (ndata // 2) / 2
    radius = tf.tensor_scatter_nd_update(
        radius, tf.constant([[0]]), tf.constant([1.0], dtype=tf.float64))

    filtersum = tf.zeros(ndata, dtype=tf.float32)
    EO = [None] * nscale
    wavelength = minWaveLength  # Initialize filter wavelength

    for s in range(nscale):
        fo = 1.0 / wavelength  # Centre frequency of this scale's filter

        # Radial log-Gabor component on the non-negative frequencies.
        radial = tf.exp(tf.cast(-tf.pow(tf.math.log(radius / fo), 2), dtype=tf.float32)
                        / (2 * tf.pow(tf.math.log(sigmaOnf), 2)))

        # Full-length filter: zero at DC, radial values on the positive
        # frequencies, zero on the remaining bins.
        log_gabor = tf.concat(
            [tf.zeros([1], tf.float32),
             radial[1:],
             tf.zeros([ndata - n_freq], tf.float32)],
            axis=0)

        filtersum = filtersum + log_gabor

        # Each scale gets its own result; the filter broadcasts over rows.
        EO[s] = tf.signal.ifft(imagefft * tf.cast(log_gabor, dtype=tf.complex128))

        wavelength *= mult  # Wavelength of the next filter

    filtersum = tf.signal.fftshift(filtersum)

    return EO, filtersum

def tf_encode(polar_array, nscales, minWaveLength, mult, sigmaOnf):
    """Quantise the phase of Gabor-filtered data into quadrant codes 0..3.

    The normalised region is convolved with Gabor filters via
    ``tf_gaborconvolve``; each complex response is then coded by the signs of
    its real and imaginary parts:

        real <= 0, imag <= 0 -> 0
        real <= 0, imag >  0 -> 1
        real >  0, imag <= 0 -> 2
        real >  0, imag >  0 -> 3

    NOTE(review): the same output tensor is rewritten on every scale, and the
    four cases cover every finite response, so with ``nscales > 1`` only the
    last scale's code is kept — confirm this is intended.

    Args:
        polar_array: 2-D normalised image passed on to ``tf_gaborconvolve``.
        nscales, minWaveLength, mult, sigmaOnf: forwarded unchanged to
            ``tf_gaborconvolve``.

    Returns:
        Float tensor of quadrant codes with the shape of one filter response.
    """
    # Convolve the normalised region with the Gabor filters.
    E0, _ = tf_gaborconvolve(polar_array, nscales, minWaveLength, mult, sigmaOnf)

    H = tf.zeros(E0[0].shape)
    for scale_idx in range(nscales):
        response = E0[scale_idx]
        re_part = tf.math.real(response)
        im_part = tf.math.imag(response)

        # (mask, code) pairs applied in the order 0, 1, 2, 3.
        quadrants = (
            (tf.math.logical_and(re_part <= 0, im_part <= 0), 0.0),
            (tf.math.logical_and(re_part <= 0, im_part > 0), 1.0),
            (tf.math.logical_and(re_part > 0, im_part <= 0), 2.0),
            (tf.math.logical_and(re_part > 0, im_part > 0), 3.0),
        )
        for mask, code in quadrants:
            H = tf.where(mask, code, H)

    return H

def tf_GaborBitStreamSTACKED(AR):
    """Encode one image channel into a uint8 map of Gabor phase codes.

    The channel is squeezed, encoded with ``tf_encode`` using fixed filter
    settings (1 scale, minimum wavelength 24, multiplier 1, sigmaOnf 0.5),
    cast to uint8 and given a trailing axis of size 1.

    Only a single channel is processed per call.

    Args:
        AR: tensor holding one image channel.

    Returns:
        uint8 tensor of codes with an extra last dimension (axis 2).
    """
    gabor_settings = dict(nscales=1, minWaveLength=24, mult=1, sigmaOnf=0.5)

    plane = tf.squeeze(AR)
    codes = tf_encode(plane, **gabor_settings)
    codes_u8 = tf.cast(codes, dtype=tf.uint8)

    return tf.expand_dims(codes_u8, axis=2)

def tf_daugman_feature_extractor(inp):
    """Return the Gabor phase-code map of ``inp``.

    Thin wrapper that forwards ``inp`` to ``tf_GaborBitStreamSTACKED`` and
    returns its result unchanged.
    """
    features = tf_GaborBitStreamSTACKED(inp)
    return features

In [None]:
import scipy.io
import os 

def save_image(image, images_path, i, grayscale):
    """Save the first element of a batch as a MATLAB ``.mat`` file.

    The output file is named after the ``i``-th entry of the module-level
    ``testset_files`` list, with its extension replaced by ``.mat``. It is
    written to ``<images_path>/grayscale`` or ``<images_path>/rgb`` (created
    if missing), and stores the array under the key ``'matrix'``.

    Args:
        image: batched tensor; ``image[0].numpy()`` is what gets saved.
        images_path: root output directory for the current model.
        i: index into ``testset_files`` of the source image.
        grayscale: truthy to save under ``grayscale``, otherwise under ``rgb``.
    """
    subfolder = 'grayscale' if grayscale else 'rgb'
    color_images_path = os.path.join(images_path, subfolder)
    # exist_ok avoids the check-then-create race and makes re-runs harmless.
    os.makedirs(color_images_path, exist_ok=True)

    # basename/splitext handle any path separator and any extension length,
    # unlike splitting on '/' and dropping the last four characters.
    file_stem = os.path.splitext(os.path.basename(testset_files[i]))[0]
    file_path = os.path.join(color_images_path, file_stem + '.mat')

    img = image[0].numpy()
    scipy.io.savemat(file_path, {'matrix': img})


In [None]:
# Preview: show the normalised first channel of the first test image.

inp = load(testset_files[0])
ar_inp, _, _ = tf.split(inp, num_or_size_splits=3, axis=-1)
norm_ar = tf_daugman_normalization(ar_inp)

display_list = [norm_ar]
title = ['Ground Truth']

plt.figure(figsize=(6, 6))

# One panel per (title, image) pair, laid out on a 1 x 2 grid.
for i, (panel_title, panel_image) in enumerate(zip(title, display_list)):
    plt.subplot(1, 2, i + 1)
    plt.title(panel_title)
    plt.imshow(panel_image)
    plt.axis('off')

plt.show()

In [None]:
def load_image_test(image_file):
    """Load a test image and normalise each of its three channels separately.

    Args:
        image_file: path of the PNG file to load.

    Returns:
        Tuple ``(normR, normG, normB)``: the three channels obtained by
        splitting the loaded image along its last axis, each passed through
        ``tf_daugman_normalization``.
    """
    red, green, blue = tf.split(load(image_file), num_or_size_splits=3, axis=-1)

    return (
        tf_daugman_normalization(red),
        tf_daugman_normalization(green),
        tf_daugman_normalization(blue),
    )

In [None]:
# Input pipeline: file paths -> per-channel normalised images -> batches of 1.
test_dataset = (
    tf.data.Dataset.from_tensor_slices(testset_files)
    .map(load_image_test)
    .batch(1)
)

In [None]:
# Run every model over the test set and save its outputs as .mat files,
# once for a grayscale input and once channel-by-channel (rgb).
for model_name in models_list:
    # Accept names listed with or without the '.h5' suffix so the model file
    # is never looked up as '<name>.h5.h5' and the output folder never takes
    # the model file's own name.
    model_stem = model_name[:-len('.h5')] if model_name.endswith('.h5') else model_name

    generator = tf.keras.models.load_model(MODELS_PATH + model_stem + '.h5', compile=False)
    model_images_path = GENERATED_IMAGES_PATH + model_stem
    # exist_ok lets the cell be re-run without failing on an existing folder.
    os.makedirs(model_images_path, exist_ok=True)

    # Single pass over the dataset: the per-image normalisation in the input
    # pipeline is computed once and reused for both output variants.
    for n_image, (inpR, inpG, inpB) in enumerate(test_dataset):
        # grayscale: average of the three normalised channels
        new_img = (inpR + inpG + inpB) / 3
        gen_output = generator(new_img, training=False)
        save_image(gen_output, model_images_path, n_image, True)

        # rgb: run the generator on each channel and stack the outputs
        gen_channels = [generator(channel, training=False) for channel in (inpR, inpG, inpB)]
        gen_output = tf.concat(gen_channels, axis=-1)
        save_image(gen_output, model_images_path, n_image, False)
