This notebook trains the following image super-resolution models on DIV2K (×4 bicubic downscaling) and compares their performance:

RDN, RRDN, EDSR, SRGAN, RealESRGAN, CycleGAN, DRCT.

In [11]:
import pandas as pd
import os
from tensorflow.image import ssim, psnr
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras import Model
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random
from istari_tools import (display_image_pair,
                          calculate_metrics,
                          create_dataset,
                          load_and_preprocess_valid_data)
from mithril_sharp import (residual_block,
                           build_rdn,
                           build_rrdn,
                           build_edsr,
                           build_srgan_generator,
                           build_real_esrgan_generator,
                           build_cyclegan_generator,
                           build_drct_decoder)


In [12]:
# --- Hyper-parameters shared by every model in the comparison ---
batch_size = 16       # batch size applied to the test dataset
image_size = 128      # forwarded to the model builders and the data-loading helpers
scale_factor = 4      # upscaling factor; matches the X4 bicubic folders below
num_channels = 3      # forwarded to the model builders

# How many validation image pairs are used as the held-out test set.
num_test_images = 32

# --- DIV2K locations: training pairs (high-resolution targets / low-resolution inputs) ---
hr_dir = './data/DIV2K_train_HR_extracted/DIV2K_train_HR'
lr_dir = './data/DIV2K_train_LR_bicubic_X4_extracted/DIV2K_train_LR_bicubic/X4'

# --- DIV2K locations: validation pairs, used here for testing ---
test_hr_dir = './data/DIV2K_valid_HR_extracted/DIV2K_valid_HR'
test_lr_dir = './data/DIV2K_valid_LR_bicubic_X4_extracted/DIV2K_valid_LR_bicubic/X4'

In [3]:
# Alphabetically first `num_test_images` entries of each validation folder.
# os.listdir yields bare file names, so these are names, not full paths.
test_lr_paths, test_hr_paths = [
    sorted(os.listdir(folder))[:num_test_images]
    for folder in (test_lr_dir, test_hr_dir)
]

In [4]:
# Each display name sits next to its builder so the two derived lists
# below always stay in the same order.
_model_registry = [
    ('RDN', build_rdn),
    ('RRDN', build_rrdn),
    ('EDSR', build_edsr),
    ('SRGAN', build_srgan_generator),
    ('RealESRGAN', build_real_esrgan_generator),
    ('CycleGAN', build_cyclegan_generator),
    ('DRCT', build_drct_decoder),
]
model_builders = [builder for _label, builder in _model_registry]
model_names = [label for label, _builder in _model_registry]

In [5]:
# Empty comparison table: the model name followed by loss / PSNR / SSIM
# for the training split and then for the test split.
_result_columns = [
    'Model',
    'Train_Loss', 'Train_PSNR', 'Train_SSIM',
    'Test_Loss', 'Test_PSNR', 'Test_SSIM',
]
results_df = pd.DataFrame(columns=_result_columns)

In [7]:
# Build the test pipeline from the first `num_test_images` validation pairs.
# Sorting the bare names and then joining gives the same order as sorting the
# joined paths, because every path in a folder shares the same prefix.
test_lr_images = [
    os.path.join(test_lr_dir, name)
    for name in sorted(os.listdir(test_lr_dir))[:num_test_images]
]
test_hr_images = [
    os.path.join(test_hr_dir, name)
    for name in sorted(os.listdir(test_hr_dir))[:num_test_images]
]


def _load_test_pair(lr_path, hr_path):
    """Load and preprocess one (low-res, high-res) validation pair."""
    return load_and_preprocess_valid_data(lr_path, hr_path, image_size, scale_factor)


# paths -> preprocessed image pairs -> batches, with background prefetching
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_lr_images, test_hr_images))
    .map(_load_test_pair, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [10]:
def _evaluate_model(model, dataset, loss_fn, max_batches=10):
    """Average loss, PSNR and SSIM of `model` over the first batches of `dataset`.

    Args:
        model: callable Keras model mapping a low-resolution batch to a
            super-resolved batch.
        dataset: tf.data dataset yielding (lr_batch, hr_batch) pairs.
        loss_fn: loss callable invoked as loss_fn(hr_batch, sr_batch).
        max_batches: number of leading batches to evaluate.

    Returns:
        dict with keys 'Loss', 'PSNR', 'SSIM' holding the mean of the
        per-batch values (NaN if the dataset yields no batches).
    """
    per_batch = {'Loss': [], 'PSNR': [], 'SSIM': []}
    for lr_batch, hr_batch in dataset.take(max_batches):
        # training=False: BatchNorm / Dropout layers run in inference mode.
        sr_batch = model(lr_batch, training=False)
        batch_metrics = calculate_metrics(hr_batch, sr_batch)

        per_batch['Loss'].append(float(loss_fn(hr_batch, sr_batch)))
        per_batch['PSNR'].append(batch_metrics['PSNR'])
        per_batch['SSIM'].append(batch_metrics['SSIM'])
    return {name: np.mean(values) for name, values in per_batch.items()}


# Loop-invariant settings: every model trains for the same number of epochs on
# the same input pipeline, so build them once instead of once per model.
epochs = 50
train_dataset = create_dataset(lr_dir, hr_dir, image_size, scale_factor)

# Rows are collected here and the table is rebuilt from them, so re-running
# this cell does not duplicate rows. (DataFrame.append, used previously, was
# removed in pandas 2.0.)
result_rows = []

# NOTE(review): the last recorded run of this cell stopped right after
# "Training RDN model..." with a 64-vs-32 channel shape mismatch. That looks
# like it originates in the model code rather than in this loop -- confirm in
# mithril_sharp.
for model_builder, model_name in zip(model_builders, model_names):
    print(f'Training {model_name} model...')
    model = model_builder(image_size, num_channels=num_channels, scale_factor=scale_factor)

    # define loss and optimizer (fresh optimizer state for every model)
    loss_fn = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

    model.compile(loss=loss_fn, optimizer=optimizer, metrics=['mse'])

    for epoch in range(epochs):
        for lr_batch, hr_batch in train_dataset:
            with tf.GradientTape() as tape:
                # training=True is required in a custom loop; without it
                # BatchNorm / Dropout layers behave as in inference.
                sr_batch = model(lr_batch, training=True)
                loss = loss_fn(hr_batch, sr_batch)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # print progress every few epochs (loss of the last batch of the epoch)
        if (epoch + 1) % 5 == 0:
            print(f'Epoch: {epoch + 1}/{epochs}, Loss: {loss.numpy()}')

    print('Evaluating the model...')

    # Same procedure on both splits: 10 batches each.
    train_metrics = _evaluate_model(model, train_dataset, loss_fn)
    test_metrics = _evaluate_model(model, test_dataset, loss_fn)

    # Add results to DataFrame; rebuilt after every model so partial results
    # survive if a later model fails.
    result_rows.append({
        'Model': model_name,
        'Train_Loss': train_metrics['Loss'],
        'Train_PSNR': train_metrics['PSNR'],
        'Train_SSIM': train_metrics['SSIM'],
        'Test_Loss': test_metrics['Loss'],
        'Test_PSNR': test_metrics['PSNR'],
        'Test_SSIM': test_metrics['SSIM']
    })
    results_df = pd.DataFrame(result_rows, columns=list(results_df.columns))

    # Qualitative check: first image of one training batch, shown as
    # input / model output / ground truth.
    eval_lr_batch, eval_hr_batch = next(iter(train_dataset))
    eval_sr_batch = model.predict(eval_lr_batch)

    panels = [
        (eval_lr_batch[0], 'Low-Resolution Input'),
        (eval_sr_batch[0], f'{model_name} Output'),
        (eval_hr_batch[0], 'High-Resolution Ground Truth'),
    ]
    plt.figure(figsize=(15, 15))
    for position, (image, heading) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image)
        plt.title(heading)
    plt.show()

    model.save(f'{model_name}.h5')

Training RDN model...


ValueError: Inputs have incompatible shapes. Received shapes (32, 32, 64) and (32, 32, 32)

In [None]:
# Plain-text summary: one row per model, index column omitted.
summary_table = results_df.to_string(index=False)
print("\nModel Performance Comparison:", summary_table, sep="\n")

In [None]:
# Comparative test-set metrics: one bar chart per metric, one bar per model.
plt.figure(figsize=(15, 5))

metric_panels = [
    ('Test_PSNR', 'PSNR Comparison'),
    ('Test_SSIM', 'SSIM Comparison'),
    ('Test_Loss', 'Loss Comparison'),
]
for position, (column, heading) in enumerate(metric_panels, start=1):
    plt.subplot(1, 3, position)
    plt.bar(results_df['Model'], results_df[column])
    plt.title(heading)
    plt.xticks(rotation=45)

plt.tight_layout()
plt.show()