This script will compare the performance of these models:

RDN, RRDN, EDSR, SRGAN, RealESRGAN, CycleGAN, DRCT.

In [2]:
import pandas as pd
import os
from tensorflow.image import ssim, psnr
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras import Model
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random
from istari_tools import (create_test_dataset,
                          calculate_metrics,
                          create_dataset,
                          load_and_preprocess_valid_data)
from mithril_sharp import (build_rdn,
                           build_rrdn,
                           build_edsr,
                           build_srgan_generator,
                           build_real_esrgan_generator,
                           build_cyclegan_generator,
                           build_drct_decoder,
                           build_drct_encoder)


In [3]:
lr_dir = './data/DIV2K_train_LR_bicubic_X4_extracted/DIV2K_train_LR_bicubic/X4'
hr_dir = './data/DIV2K_train_HR_extracted/DIV2K_train_HR'

test_lr_dir = './data/DIV2K_valid_LR_bicubic_X4_extracted/DIV2K_valid_LR_bicubic/X4'
test_hr_dir = './data/DIV2K_valid_HR_extracted/DIV2K_valid_HR'

In [4]:
image_size = 256
scale_factor = 4
batch_size = 16
num_train_images = 800

In [5]:
train_dataset = tf.keras.utils.image_dataset_from_directory(
    lr_dir,
    labels="inferred",
    label_mode=None, 
    image_size=(image_size // scale_factor, image_size // scale_factor),
    batch_size=batch_size,
    shuffle=True
)

train_dataset = train_dataset.map(
    lambda x: (x, tf.image.resize(x, (image_size, image_size)))
)

Found 800 files.


2024-11-02 18:12:06.166279: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [6]:
# split 80% for training 20% for validation
train_size = int(0.8 * len(train_dataset))

train_dataset = train_dataset.take(train_size)
val_dataset = train_dataset.skip(train_size)

## Trial with edsr

In [7]:
model = build_edsr(image_size, scale_factor=4)

In [8]:
model.compile(optimizer='adam', loss='mse')

In [None]:
model.fit(train_dataset, epochs=50, validation_data=val_dataset)

Epoch 1/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m220s[0m 5s/step - loss: 15899.9326
Epoch 2/50


2024-11-02 18:15:55.074276: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
  self.gen.throw(value)


[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m202s[0m 5s/step - loss: 15715.9795
Epoch 3/50


2024-11-02 18:19:16.759359: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m210s[0m 5s/step - loss: 15727.2441
Epoch 4/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m219s[0m 5s/step - loss: 15826.5449
Epoch 5/50


2024-11-02 18:26:24.910460: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m219s[0m 5s/step - loss: 15705.3027
Epoch 6/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m224s[0m 6s/step - loss: 15873.6777
Epoch 7/50
[1m24/40[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m1:28[0m 6s/step - loss: 15711.8604

In [None]:
model.save('./models_save_states/trial_edst.h5')

In [None]:
# test the model 
test_lr_files = [f for f in os.listdir(test_lr_dir) if f.endswith('.png')]
test_hr_files = [f for f in os.listdir(test_hr_dir) if f.endswith('.png')]

test_dataset = create_test_dataset(test_lr_files, test_hr_files, test_hr_dir, test_lr_dir, image_size, scale_factor)

sr_images = []
hr_images = []

for lr_img, hr_img in test_dataset:
    sr_img = model(lr_img)
    sr_images.append(sr_img[0].numpy())
    hr_images.append(hr_img[0].numpy())

# PSNR and SSIM
metrics = calculate_metrics(hr_images, sr_images)
print(metrics)

## Loop over all models

In [4]:
model_builders = [build_rdn, build_rrdn, build_edsr, build_srgan_generator, build_real_esrgan_generator,
                build_cyclegan_generator, build_drct_decoder]
model_names = ['RDN', 'RRDN', 'EDSR', 'SRGAN', 'RealESRGAN', 'CycleGAN', 'DRCT']

In [5]:
results_df = pd.DataFrame(columns=['Model', 'Train_Loss', 'Train_PSNR', 'Train_SSIM', 
                                   'Test_Loss', 'Test_PSNR', 'Test_SSIM'])

In [None]:
for model_builder, model_name in zip(model_builders, model_names):
    print(f'Training {model_name} model...')

    if model_name == 'DRCT':
        model = build_drct_encoder(image_size, scale_factor=4)
        decoder = build_drct_decoder(image_size, scale_factor=4)
        # loss function for the encoder and a loss function for the decoder (may need modifications for DRCT) 
        loss_fn_encoder = tf.keras.losses.MeanSquaredError()
        loss_fn_decoder = tf.keras.losses.MeanSquaredError()  
        optimizer_encoder = tf.keras.optimizers.Adam(learning_rate=1e-4)
        optimizer_decoder = tf.keras.optimizers.Adam(learning_rate=1e-4)

        model.compile(optimizer=optimizer_encoder, loss=loss_fn_encoder)
        decoder.compile(optimizer=optimizer_decoder, loss=loss_fn_decoder)
    
        model = tf.keras.models.Sequential([model, decoder])
        model.compile(optimizer='adam', loss='mse')

    else:
        model = model_builder(image_size, scale_factor=4)
        # Loss and Optimizer (shared for all models except DRCT)
        loss_fn = tf.keras.losses.MeanSquaredError() 
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
        model.compile(optimizer=optimizer, loss=loss_fn)  

    model.fit(train_dataset, epochs=50, validation_data=val_dataset)

    print(f'Evaluating {model_name} model...')

    # evaluate on training data
    train_metrics = {'Loss': [], 'PSNR': [], 'SSIM': []}
    for eval_lr_batch, eval_hr_batch in train_dataset:
        eval_sr_batch = model.predict(eval_lr_batch)
        train_metrics['Loss'].append(loss_fn(eval_hr_batch, eval_sr_batch).numpy())
        train_metrics['PSNR'].append(psnr(eval_hr_batch, eval_sr_batch, max_val=1.0).numpy())
        train_metrics['SSIM'].append(ssim(eval_hr_batch, eval_sr_batch, max_val=1.0).numpy())

    # evaluate on test data
    test_metrics = {'Loss': [], 'PSNR': [], 'SSIM': []}
    for eval_lr_batch, eval_hr_batch in test_dataset:
        eval_sr_batch = model.predict(eval_lr_batch)
        test_metrics['Loss'].append(loss_fn(eval_hr_batch, eval_sr_batch).numpy())
        test_metrics['PSNR'].append(psnr(eval_hr_batch, eval_sr_batch, max_val=1.0).numpy())
        test_metrics['SSIM'].append(ssim(eval_hr_batch, eval_sr_batch, max_val=1.0).numpy())
    
    results_df = results_df.append({
        'Model': model_name,
        'Train_Loss': np.mean(train_metrics['Loss']),
        'Train_PSNR': np.mean(train_metrics['PSNR']),
        'Train_SSIM': np.mean(train_metrics['SSIM']),
        'Test_Loss': np.mean(test_metrics['Loss']),
        'Test_PSNR': np.mean(test_metrics['PSNR']),
        'Test_SSIM': np.mean(test_metrics['SSIM'])
    }, ignore_index=True)

    # visualize the output on a single image
    eval_lr_batch, eval_hr_batch = next(iter(train_dataset))
    eval_sr_batch = model.predict(eval_lr_batch)

    plt.figure(figsize=(15, 15))
    plt.subplot(1, 3, 1)
    plt.imshow(eval_lr_batch[0])
    plt.title('Low-Resolution Input')
    plt.subplot(1, 3, 2)
    plt.imshow(eval_sr_batch[0])
    plt.title(f'{model_name} Output')
    plt.subplot(1, 3, 3)
    plt.imshow(eval_hr_batch[0])
    plt.title('High-Resolution Ground Truth')
    plt.show()

    model.save(f'./models_save_states/{model_name}.h5')

In [None]:
print("\nModel Performance Comparison:")
print(results_df.to_string(index=False))

In [None]:
# comparative metrics
plt.figure(figsize=(15, 5))

# PSNR comparison
plt.subplot(1, 3, 1)
plt.bar(results_df['Model'], results_df['Test_PSNR'])
plt.title('PSNR Comparison')
plt.xticks(rotation=45)

# SSIM comparison
plt.subplot(1, 3, 2)
plt.bar(results_df['Model'], results_df['Test_SSIM'])
plt.title('SSIM Comparison')
plt.xticks(rotation=45)

# Loss comparison
plt.subplot(1, 3, 3)
plt.bar(results_df['Model'], results_df['Test_Loss'])
plt.title('Loss Comparison')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()