In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import os
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint

def load_and_preprocess_images(folder_path, size=(128, 128)):
    """Load every readable image in ``folder_path`` as a normalised grayscale array.

    Files are visited in sorted filename order. ``os.listdir`` order is arbitrary,
    and callers pair arrays from two folders by index, so a deterministic order is
    required for the pairing to be meaningful. Files that ``cv2.imread`` cannot
    decode (it returns ``None``) are skipped.

    Args:
        folder_path: Directory to read images from (not recursive).
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults to
            ``(128, 128)``.

    Returns:
        float32 array of shape ``(n_images, height, width)``, pixel values scaled
        by 1/255. When no image could be read, an empty array of shape
        ``(0, height, width)`` is returned, so downstream shape logic still works.
    """
    width, height = size
    images = []
    for filename in sorted(os.listdir(folder_path)):
        img_path = os.path.join(folder_path, filename)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Not an image, or unreadable: skip it.
            continue
        img = cv2.resize(img, (width, height))  # cv2 takes dsize as (width, height)
        images.append(img.astype('float32') / 255.0)
    if not images:
        return np.empty((0, height, width), dtype='float32')
    return np.array(images)

# Input (model x) and target (model y) image folders.
# NOTE(review): the folder names refer to noise levels (4 vs 16) while the
# variables say low/high resolution; confirm which is intended.
low_res_folder = '/home/hinata/code/fyp/images/ml_images/rec/4_noise_img'
high_res_folder = '/home/hinata/code/fyp/images/ml_images/rec/16_noise_img'

low_res_images = load_and_preprocess_images(low_res_folder)
high_res_images = load_and_preprocess_images(high_res_folder)

# Samples are paired purely by position, so both folders must yield the same
# number of images in the same order. Fail early with a readable message rather
# than an opaque error from train_test_split.
if len(low_res_images) == 0:
    raise ValueError(f'No readable images found in {low_res_folder}')
if len(low_res_images) != len(high_res_images):
    raise ValueError(
        f'Unpaired data: {len(low_res_images)} input images vs '
        f'{len(high_res_images)} target images'
    )

# Add an explicit trailing channel axis so the arrays match the model's
# single-channel (H, W, 1) input instead of relying on implicit expansion.
low_res_images = low_res_images[..., np.newaxis]
high_res_images = high_res_images[..., np.newaxis]

train_low_images, test_low_images, train_high_images, test_high_images = train_test_split(
    low_res_images, high_res_images, test_size=0.2, random_state=42
)

# U-Net model architecture
def unet(input_shape):
    """Build a 3-level U-Net mapping a single-channel image to a single-channel image.

    Encoder: at each level, two 3x3 ReLU convolutions (64, then 128, then 256
    filters) followed by 2x2 max-pooling. Bottleneck: two 512-filter 3x3 ReLU
    convolutions. Decoder: mirrors the encoder with stride-2 transposed
    convolutions. Each is concatenated with the matching encoder features on
    axis 3 (the channel axis under the default channels-last layout) and
    followed by two 3x3 ReLU convolutions. A 1x1 sigmoid convolution produces
    the single-channel output.

    Args:
        input_shape: Shape of one sample, e.g. ``(128, 128, 1)``.
            NOTE(review): with three 2x2 poolings, height and width presumably
            need to be divisible by 8 for the skip concatenations to line up.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"img_to_img_unet"``.
    """
    def double_conv(tensor, filters):
        # Two same-padded 3x3 ReLU convolutions; spatial size is unchanged.
        for _ in range(2):
            tensor = layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    level_filters = (64, 128, 256)

    inputs = tf.keras.Input(shape=input_shape)
    features = inputs

    # Encoder: remember each level's pre-pool features for the skip connections.
    skip_tensors = []
    for filters in level_filters:
        features = double_conv(features, filters)
        skip_tensors.append(features)
        features = layers.MaxPooling2D(pool_size=(2, 2))(features)

    # Bottleneck
    features = double_conv(features, 512)

    # Decoder: walk back up, pairing each level with its encoder skip.
    for filters, skip in zip(reversed(level_filters), reversed(skip_tensors)):
        features = layers.Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(features)
        features = layers.concatenate([features, skip], axis=3)
        features = double_conv(features, filters)

    outputs = layers.Conv2D(1, 1, activation='sigmoid')(features)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name="img_to_img_unet")

input_shape = (128, 128, 1)  # Adjust input shape based on your data
model = unet(input_shape)

model.compile(optimizer='adam', loss='mean_squared_error')

model_dir = '/home/hinata/code/fyp_final_collation/Machine_Learning'
# Best model by validation loss, written by the ModelCheckpoint callback.
checkpoint_path = os.path.join(model_dir, 'trained_unet_img_to_img_model.h5')
# Last-epoch model is saved to a separate file. Previously it was saved to
# checkpoint_path, which silently overwrote the best checkpoint.
final_model_path = os.path.join(model_dir, 'trained_unet_img_to_img_model_final.h5')

model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True)

# validation_split holds out the last 20% of the training arrays
# (train_test_split has already shuffled them).
model.fit(
    train_low_images, train_high_images,
    epochs=250, batch_size=32, validation_split=0.2, callbacks=[model_checkpoint],
)

model.save(final_model_path)

# Report test loss for the checkpointed (best validation loss) model, which is
# the one kept at checkpoint_path.
best_model = tf.keras.models.load_model(checkpoint_path)
test_loss = best_model.evaluate(test_low_images, test_high_images)
test_loss


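Test the saved model on a random sample of held-out test images: for each sample, plot the input alongside its ground truth and the model's prediction.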

In [None]:
import matplotlib.pyplot as plt

model_path = '/home/hinata/code/fyp_final_collation/Machine_Learning/trained_unet_img_to_img_model.h5'
model = tf.keras.models.load_model(model_path)

# Seeded generator so the same test images are shown on every re-run.
rng = np.random.default_rng(42)
# Cap at the test-set size; choice(..., replace=False) raises if asked for more.
num_samples = min(5, len(test_low_images))
random_indices = rng.choice(len(test_low_images), num_samples, replace=False)
sample_low_images = test_low_images[random_indices]
sample_high_images = test_high_images[random_indices]

# Make predictions on the random sample
predictions = model.predict(sample_low_images)

# Three rows (input / ground truth / prediction), one column per sample.
# The previous code used a 2-row grid but addressed a third row, which made
# plt.subplot raise an out-of-range index error.
row_titles = ('Low Resolution Input', 'High Resolution Ground Truth', 'Predicted High Resolution')
fig, axes = plt.subplots(3, num_samples, figsize=(12, 8), squeeze=False)
for col in range(num_samples):
    panels = (sample_low_images[col], sample_high_images[col], predictions[col])
    for row, (image, title) in enumerate(zip(panels, row_titles)):
        ax = axes[row, col]
        # squeeze drops a trailing channel axis of size 1, if present.
        ax.imshow(np.squeeze(image), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

fig.tight_layout()
plt.show()
