Hyperparameter tuning (learning rate, optimizer, loss function) for the sinogram U-Net.

In [None]:
import itertools
import os
import warnings

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from sklearn.model_selection import train_test_split
import cv2
import matplotlib.pyplot as plt
from tensorflow.image import ssim
from tensorflow.keras.optimizers import Adam


# loss functions
def ssim_loss(y_true, y_pred):
    """Structural-dissimilarity loss: ``1 - mean(SSIM)``.

    ``ssim`` is evaluated with ``max_val=1.0``, i.e. pixel values are expected
    in ``[0, 1]``. All SSIM values it returns are averaged into one scalar, so
    a perfect reconstruction (SSIM == 1) gives a loss of 0.
    """
    per_image_ssim = ssim(y_true, y_pred, max_val=1.0)
    batch_ssim = tf.reduce_mean(per_image_ssim)
    return 1 - batch_ssim

# PSNR loss function
def psnr_loss(y_true, y_pred):
    """Negated PSNR, so that minimizing the loss maximizes PSNR.

    ``tf.image.psnr`` reduces only the image axes, so this returns one value
    per image and leaves the batch reduction to Keras. ``max_val=1.0`` assumes
    pixels in ``[0, 1]``.

    NOTE(review): identical images give PSNR = inf, i.e. a loss of -inf.
    """
    peak_snr = tf.image.psnr(y_true, y_pred, max_val=1.0)
    return -peak_snr

# load and preprocess sinogram images
def _sorted_image_files(folder):
    """Return the sorted names of regular files (not sub-directories) in *folder*."""
    return sorted(
        name for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )


def load_and_preprocess_sinograms(input_folder, output_folder, target_size=(128, 128)):
    """Load paired input/target sinogram images, resized and scaled to [0, 1].

    Files are paired by position after sorting each folder's file names: the
    i-th file of ``input_folder`` is matched with the i-th file of
    ``output_folder``. Sub-directories (e.g. Jupyter's ``.ipynb_checkpoints``)
    are ignored so they cannot shift that alignment. A warning is issued when
    the folders hold different numbers of files, because positional pairing is
    then likely misaligned. A pair is skipped when OpenCV cannot decode either
    image.

    Args:
        input_folder: Directory containing the network-input sinograms.
        output_folder: Directory containing the target sinograms.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        Tuple ``(inputs, outputs)`` of float32 arrays of shape
        ``(n_pairs, height, width)`` with values in ``[0, 1]``. When no pair
        loads, both arrays have shape ``(0, height, width)``.
    """
    width, height = target_size
    input_filenames = _sorted_image_files(input_folder)
    output_filenames = _sorted_image_files(output_folder)

    if len(input_filenames) != len(output_filenames):
        warnings.warn(
            f"{input_folder!r} has {len(input_filenames)} files but {output_folder!r} "
            f"has {len(output_filenames)}; only the first "
            f"{min(len(input_filenames), len(output_filenames))} positional pairs are "
            "used and they may be misaligned.",
            stacklevel=2,
        )

    input_images = []
    output_images = []
    for input_filename, output_filename in zip(input_filenames, output_filenames):
        input_img = cv2.imread(os.path.join(input_folder, input_filename), cv2.IMREAD_GRAYSCALE)
        output_img = cv2.imread(os.path.join(output_folder, output_filename), cv2.IMREAD_GRAYSCALE)

        # imread returns None for unreadable files; drop the whole pair so the
        # two lists stay aligned.
        if input_img is None or output_img is None:
            continue

        # cv2.resize takes dsize as (width, height); IMREAD_GRAYSCALE yields 8-bit data.
        input_images.append(cv2.resize(input_img, (width, height)).astype('float32') / 255.0)
        output_images.append(cv2.resize(output_img, (width, height)).astype('float32') / 255.0)

    if not input_images:
        # np.array([]) would have shape (0,) and dtype float64; keep the image rank instead.
        empty = np.empty((0, height, width), dtype=np.float32)
        return empty, empty.copy()

    return np.array(input_images), np.array(output_images)

input_folder = '/jupyter/work/fyp/data/sinograms/1st_set/lsd'
output_folder = '/jupyter/work/fyp/data/sinograms/1st_set/hsd'

input_images, output_images = load_and_preprocess_sinograms(input_folder, output_folder)

# Fail early with a clear message instead of an opaque train_test_split error.
if len(input_images) == 0:
    raise ValueError(f"No readable sinogram pairs found in {input_folder!r} / {output_folder!r}")

# Add an explicit trailing channel axis so the arrays match the model's
# (128, 128, 1) input and single-channel output. Without it Keras' loss wrapper
# squeezes the prediction to (batch, 128, 128) before calling the custom losses,
# and tf.image.ssim/psnr then read that as ONE image whose height is the batch
# axis and whose 128 "channels" are the image columns.
if input_images.ndim == 3:
    input_images = input_images[..., np.newaxis]
if output_images.ndim == 3:
    output_images = output_images[..., np.newaxis]

train_input_images, test_input_images, train_output_images, test_output_images = train_test_split(
    input_images, output_images, test_size=0.2, random_state=42
)

# Architecture size held fixed during this search; presumably the result of an
# earlier width/depth search -- TODO confirm these match that search.
best_width = 128
best_depth = 5

# Hyperparameter grid: every (learning rate, optimizer, loss) combination is trained.
learning_rates = [0.001, 0.0001]
optimizers = [Adam, tf.keras.optimizers.SGD]
loss_functions = [psnr_loss, ssim_loss]


def build_sinogram_model(width, depth, input_shape=(128, 128, 1)):
    """Build the encoder-decoder used for every run of the search.

    The encoder applies ``depth`` (Conv2D -> 2x2 max-pool) stages. A Conv2D
    bottleneck follows. The decoder mirrors the encoder with ``depth``
    (stride-2 transposed conv -> Conv2D) stages, then a 1x1 sigmoid conv
    produces one output channel. Input height and width must be divisible by
    ``2 ** depth`` for the output to match the input size.

    NOTE(review): despite the "unet" name there are no skip connections
    between encoder and decoder stages. Kept as-is so results stay comparable
    with the search that produced ``best_width``/``best_depth``.
    """
    inputs = tf.keras.Input(shape=input_shape)
    x = inputs
    for _ in range(depth):
        x = layers.Conv2D(width, 3, activation='relu', padding='same')(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(width, 3, activation='relu', padding='same')(x)
    for _ in range(depth):
        x = layers.Conv2DTranspose(width, 2, strides=(2, 2), padding='same')(x)
        x = layers.Conv2D(width, 3, activation='relu', padding='same')(x)
    outputs = layers.Conv2D(1, 1, activation='sigmoid', padding='same')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name=f"sinogram_unet_{width}_{depth}")


def _evaluate_reconstruction(model, images, targets):
    """Return ``(mean SSIM, mean PSNR in dB)`` of *model* on *images* vs *targets*.

    Both metrics use ``max_val=1.0`` and do not depend on the training loss,
    so runs that optimized different losses can be compared.
    """
    preds = model.predict(images, verbose=0)
    # Reshape targets to the prediction shape so a missing trailing channel
    # axis cannot make tf.image treat the batch axis as image height.
    targets = np.reshape(targets, preds.shape).astype(preds.dtype)
    mean_ssim = float(tf.reduce_mean(tf.image.ssim(targets, preds, max_val=1.0)))
    mean_psnr = float(tf.reduce_mean(tf.image.psnr(targets, preds, max_val=1.0)))
    return mean_ssim, mean_psnr


# One (lr, optimizer name, loss name, test loss) tuple per combination.
loss_combinations = []
# One (mean SSIM, mean PSNR) tuple per combination, in the same order.
quality_scores = []

# itertools.product iterates in the same order as the nested lr/optimizer/loss loops.
for lr, optimizer, loss_func in itertools.product(learning_rates, optimizers, loss_functions):
    # Drop graph/layer state left over from the previous model so memory use
    # does not keep growing across the grid.
    tf.keras.backend.clear_session()

    model = build_sinogram_model(best_width, best_depth)
    # 'accuracy' is meaningless for continuous-valued image targets; track MAE instead.
    model.compile(optimizer=optimizer(learning_rate=lr), loss=loss_func, metrics=['mae'])

    print(f"Training model with LR={lr}, Optimizer={optimizer.__name__}, Loss={loss_func.__name__}")
    history = model.fit(train_input_images, train_output_images, epochs=50, batch_size=32, validation_split=0.2, verbose=1)

    # return_dict avoids depending on the position of the loss in the result list.
    test_results = model.evaluate(test_input_images, test_output_images, verbose=1, return_dict=True)
    test_loss = test_results['loss']
    loss_combinations.append((lr, optimizer.__name__, loss_func.__name__, test_loss))
    print(f"Test Loss: {test_loss}")

    mean_ssim, mean_psnr = _evaluate_reconstruction(model, test_input_images, test_output_images)
    quality_scores.append((mean_ssim, mean_psnr))
    print(f"Test SSIM: {mean_ssim:.4f}, Test PSNR: {mean_psnr:.2f} dB")

# Test losses come from different loss functions (negated PSNR vs 1 - SSIM), so
# their values are not comparable with each other. Rank by the shared SSIM metric.
for (lr, opt_name, loss_name, test_loss), (mean_ssim, mean_psnr) in zip(loss_combinations, quality_scores):
    print(f"LR={lr}, Optimizer={opt_name}, Loss={loss_name}, Test Loss={test_loss}, "
          f"SSIM={mean_ssim:.4f}, PSNR={mean_psnr:.2f} dB")

best_index = max(range(len(quality_scores)), key=lambda i: quality_scores[i][0])
best_lr, best_opt_name, best_loss_name, _ = loss_combinations[best_index]
print(f"Best by test SSIM: LR={best_lr}, Optimizer={best_opt_name}, Loss={best_loss_name}")

# Bar charts, since the combinations are categorical rather than a continuous axis.
labels = [f"{lr}, {opt_name}, {loss_name}" for lr, opt_name, loss_name, _ in loss_combinations]
positions = list(range(len(labels)))
panels = [
    ([loss for _, _, _, loss in loss_combinations], 'Test Loss\n(own loss; not comparable)'),
    ([s for s, _ in quality_scores], 'Test SSIM\n(higher is better)'),
    ([p for _, p in quality_scores], 'Test PSNR [dB]\n(higher is better)'),
]
fig, axes = plt.subplots(len(panels), 1, figsize=(10, 10), sharex=True)
for ax, (values, ylabel) in zip(axes, panels):
    ax.bar(positions, values)
    ax.set_ylabel(ylabel)
    ax.grid(True, axis='y')
axes[0].set_title('Test Metrics for Different Hyperparameter Combinations')
axes[-1].set_xlabel('Combination')
axes[-1].set_xticks(positions)
axes[-1].set_xticklabels(labels, rotation=90)
plt.tight_layout()
plt.show()
