## A sample file, implementing basic operations of the paper
### see data files and link to them appropriately

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate,BatchNormalization, Activation, Dense, Add
from tensorflow.keras.models import Model
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import imageio.v2 as imageio
%matplotlib inline

## Load image data

In [11]:
# Define folder paths
coarse_folder = '...coarse'
medium_folder = '...medium'
fine_folder = '...fine'


coarse_images_train=[]
medium_images_train = []
fine_images_train = []

for i in range(1, 9):

    coarse_png_path = os.path.join(coarse_folder, f'img_{i}.png')
    medium_png_path = os.path.join(medium_folder, f'img_{i}.png')
    fine_png_path = os.path.join(fine_folder, f'img_{i}.png')

    coarse_image = tf.keras.preprocessing.image.load_img(coarse_png_path, target_size=(265,2121)) #(358, 1612))
    medium_image = tf.keras.preprocessing.image.load_img(medium_png_path, target_size=(265,2121)) #(358, 1612))
    fine_image = tf.keras.preprocessing.image.load_img(fine_png_path, target_size=(265,2121)) #(358, 1612))

    coarse_image = tf.keras.preprocessing.image.img_to_array(coarse_image)
    medium_image = tf.keras.preprocessing.image.img_to_array(medium_image)
    fine_image = tf.keras.preprocessing.image.img_to_array(fine_image)

    coarse_images_train.append(coarse_image / 255.0)
    medium_images_train.append(medium_image / 255.0)
    fine_images_train.append(fine_image / 255.0)

# Convert training data to TensorFlow tensors
coarse_images_train = tf.convert_to_tensor(coarse_images_train)
medium_images_train = tf.convert_to_tensor(medium_images_train)
fine_images_train = tf.convert_to_tensor(fine_images_train)

# Convert SVG images t format for training data
coarse_images_test=[]
medium_images_test = []
fine_images_test = []

for i in range(9, 11):

    coarse_png_path = os.path.join(coarse_folder, f'img_{i}.png')
    medium_png_path = os.path.join(medium_folder, f'img_{i}.png')
    fine_png_path = os.path.join(fine_folder, f'img_{i}.png')

    coarse_image = tf.keras.preprocessing.image.load_img(coarse_png_path, target_size=(265,2121)) #(358, 1612))
    medium_image = tf.keras.preprocessing.image.load_img(medium_png_path, target_size=(265,2121)) #(358, 1612))
    fine_image = tf.keras.preprocessing.image.load_img(fine_png_path, target_size=(265,2121)) #(358, 1612))

    coarse_image = tf.keras.preprocessing.image.img_to_array(coarse_image)
    medium_image = tf.keras.preprocessing.image.img_to_array(medium_image)
    fine_image = tf.keras.preprocessing.image.img_to_array(fine_image)

    coarse_images_test.append(coarse_image / 255.0)
    medium_images_test.append(medium_image / 255.0)
    fine_images_test.append(fine_image / 255.0)

# Convert training data to TensorFlow tensors
coarse_images_test = tf.convert_to_tensor(coarse_images_test)
medium_images_test = tf.convert_to_tensor(medium_images_test)
fine_images_test = tf.convert_to_tensor(fine_images_test)



fine_images_train_np = np.array(fine_images_train)
fine_images_test_np = np.array(fine_images_test)

coarse_images_train_np = np.array(coarse_images_train)
coarse_images_test_np = np.array(coarse_images_test)

medium_images_train_np = np.array(medium_images_train)
medium_images_test_np = np.array(medium_images_test)

# Check the number of samples in each dataset
num_fine_images = len(fine_images_train_np)
print('train_images',num_fine_images)

num_fine_images = len(fine_images_test_np)
print('test_images',num_fine_images)


# Define the new resolution (coarser grid)
dim1 = (int(2121/10),int(265/10))

# List to store the resized coarse images
coarse2_images_train = []

# Loop through each fine image in the fine_images_train list
for fine_image in fine_images_train_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim1, interpolation=cv2.INTER_LINEAR)
    coarse2_images_train.append(coarse2_image)

# Convert the list of coarse images to a numpy array
coarse2_images_train_np = np.array(coarse2_images_train)

blurred_images_train = []
dim = (2121,265)
for fine_image in coarse2_images_train_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    blurred_image = cv2.resize(fine_image, dim, interpolation=cv2.INTER_LINEAR)
    blurred_images_train.append(blurred_image)
    
coarse2_images_test = []

for fine_image in fine_images_test_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim1, interpolation=cv2.INTER_LINEAR)
    coarse2_images_test.append(coarse2_image)

# Convert the list of coarse images to a numpy array
coarse2_images_test_np = np.array(coarse2_images_test)

blurred_images_test = []

for fine_image in coarse2_images_test_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim, interpolation=cv2.INTER_LINEAR)
    blurred_images_test.append(coarse2_image)

blurred_images_train_np = np.array(blurred_images_train)
blurred_images_test_np = np.array(blurred_images_test)

print('Test coarse:', coarse2_image.shape)

# Check the number of samples in each dataset
num_blurred_images = len(blurred_images_train_np)
num_fine_images = len(fine_images_train_np)
print(num_blurred_images,num_fine_images)

num_blurred_images = len(blurred_images_test_np)
num_fine_images = len(fine_images_test_np)
print(num_blurred_images,num_fine_images)

#Smaller blur ratio
# Define the new resolution (coarser grid)
dim1 = (int(2121/5),int(265/5))

# List to store the resized coarse images
coarse2_images_train = []

# Loop through each fine image in the fine_images_train list
for fine_image in fine_images_train_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim1, interpolation=cv2.INTER_LINEAR)
    coarse2_images_train.append(coarse2_image)

# Convert the list of coarse images to a numpy array
coarse2_images_train_np = np.array(coarse2_images_train)

sblurred_images_train = []
dim = (2121,265)
for fine_image in coarse2_images_train_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    blurred_image = cv2.resize(fine_image, dim, interpolation=cv2.INTER_LINEAR)
    sblurred_images_train.append(blurred_image)

print('Train coarse:', coarse_image.shape)
    
coarse2_images_test = []

for fine_image in fine_images_test_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim1, interpolation=cv2.INTER_LINEAR)
    coarse2_images_test.append(coarse2_image)

# Convert the list of coarse images to a numpy array
coarse2_images_test_np = np.array(coarse2_images_test)

sblurred_images_test = []

for fine_image in coarse2_images_test_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim, interpolation=cv2.INTER_LINEAR)
    sblurred_images_test.append(coarse2_image)

sblurred_images_train_np = np.array(sblurred_images_train)
sblurred_images_test_np = np.array(sblurred_images_test)


#Larger blur ratio
# Define the new resolution (coarser grid)
dim1 = (int(2121/20),int(265/20))

# List to store the resized coarse images
coarse2_images_train = []

# Loop through each fine image in the fine_images_train list
for fine_image in fine_images_train_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim1, interpolation=cv2.INTER_LINEAR)
    coarse2_images_train.append(coarse2_image)

# Convert the list of coarse images to a numpy array
coarse2_images_train_np = np.array(coarse2_images_train)

lblurred_images_train = []
dim = (2121,265)
for fine_image in coarse2_images_train_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    blurred_image = cv2.resize(fine_image, dim, interpolation=cv2.INTER_LINEAR)
    lblurred_images_train.append(blurred_image)

print('Train coarse:', coarse_image.shape)
    
coarse2_images_test = []

for fine_image in fine_images_test_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim1, interpolation=cv2.INTER_LINEAR)
    coarse2_images_test.append(coarse2_image)

# Convert the list of coarse images to a numpy array
coarse2_images_test_np = np.array(coarse2_images_test)

lblurred_images_test = []

for fine_image in coarse2_images_test_np:
    # Resize the fine image to the new resolution using the INTER_LINEAR interpolation method
    coarse2_image = cv2.resize(fine_image, dim, interpolation=cv2.INTER_LINEAR)
    lblurred_images_test.append(coarse2_image)

lblurred_images_train_np = np.array(lblurred_images_train)
lblurred_images_test_np = np.array(lblurred_images_test)

train_images 8
test_images 2
Test coarse: (265, 2121, 3)
8 8
2 2
Train coarse: (265, 2121, 3)
Train coarse: (265, 2121, 3)


In [None]:
#If there si available, load pre-run models
model = tf.keras.models.load_model('model_blurred_5X5.h5')

In [6]:
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate, MaxPooling2D,Add,UpSampling2D,ZeroPadding2D,Cropping2D

def build_model_EDR():
    # Coarse Input
    coarse_input = Input(shape=(265, 2121, 3))
    coarse_x = Conv2D(32, (3, 3), activation='relu', padding='same')(coarse_input)
    coarse_x = Conv2D(32, (3, 3), activation='relu', padding='same')(coarse_x)
    coarse_x_residual = Conv2D(32, (1, 1), activation='relu', padding='same')(coarse_input)  # Add 1x1 conv for residual
    coarse_x_residual = Add()([coarse_x_residual, coarse_x])
    
    # Medium Input
    medium_input = Input(shape=(265, 2121, 3))
    medium_x = Conv2D(32, (3, 3), activation='relu', padding='same')(medium_input)
    medium_x = Conv2D(32, (3, 3), activation='relu', padding='same')(medium_x)
    medium_x_residual = Conv2D(32, (1, 1), activation='relu', padding='same')(medium_input)  # Add 1x1 conv for residual
    medium_x_residual = Add()([medium_x_residual, medium_x])

    # Blurred1 Input
    blur1_input = Input(shape=(265, 2121, 3))
    blur1_x = Conv2D(32, (3, 3), activation='relu', padding='same')(blur1_input)
    blur1_x = Conv2D(32, (3, 3), activation='relu', padding='same')(blur1_x)
    blur1_x_residual = Conv2D(32, (1, 1), activation='relu', padding='same')(blur1_input)  # Add 1x1 conv for residual
    blur1_x_residual = Add()([blur1_x_residual, blur1_x])

    # Blurred2 Input
    blur2_input = Input(shape=(265, 2121, 3))
    blur2_x = Conv2D(32, (3, 3), activation='relu', padding='same')(blur2_input)
    blur2_x = Conv2D(32, (3, 3), activation='relu', padding='same')(blur2_x)
    blur2_x_residual = Conv2D(32, (1, 1), activation='relu', padding='same')(blur2_input)  # Add 1x1 conv for residual
    blur2_x_residual = Add()([blur2_x_residual, blur2_x])

    # Blurred3 Input
    blur3_input = Input(shape=(265, 2121, 3))
    blur3_x = Conv2D(32, (3, 3), activation='relu', padding='same')(blur3_input)
    blur3_x = Conv2D(32, (3, 3), activation='relu', padding='same')(blur3_x)
    blur3_x_residual = Conv2D(32, (1, 1), activation='relu', padding='same')(blur3_input)  # Add 1x1 conv for residual
    blur3_x_residual = Add()([blur3_x_residual, blur3_x])
    
    
    # Concatenate the features from both inputs
    merged = Concatenate()([coarse_x_residual, medium_x_residual,blur1_x_residual,blur2_x_residual,blur3_x_residual])

    # Decoder (transpose convolution)
    x = Conv2DTranspose(3, (3, 3), activation='relu', padding='same')(merged)
    #x = Conv2DTranspose(3, (3, 3), activation='relu', padding='same')(x)

    model = Model(inputs=[coarse_input, medium_input,blur1_input,blur2_input,blur3_input], outputs=x)
    return model

# Build the model
model = build_model_EDR()
model.compile(optimizer='adam', loss='mse',metrics=['accuracy'])
#model.summary()

## Keras callback - interactive loss plot

In [None]:
from IPython.display import clear_output
from tensorflow import keras

class PlotLearning(keras.callbacks.Callback):
    """
    Callback to plot the learning curves of the model during training.
    """
    def on_train_begin(self, logs={}):
        self.metrics = {}
        for metric in logs:
            self.metrics[metric] = []
            

    def on_epoch_end(self, epoch, logs={}):
        # Storing metrics
        for metric in logs:
            if metric in self.metrics:
                self.metrics[metric].append(logs.get(metric))
            else:
                self.metrics[metric] = [logs.get(metric)]
        
        # Plotting
        metrics = [x for x in logs if 'val' not in x]
        
        f, axs = plt.subplots(1, len(metrics), figsize=(12,4), dpi= 400)
        clear_output(wait=True)

        for i, metric in enumerate(metrics):
            axs[i].plot(range(1, epoch + 2), 
                        self.metrics[metric], 
                        label=metric)
            if logs['val_' + metric]:
                axs[i].plot(range(1, epoch + 2), 
                            self.metrics['val_' + metric], 
                            label='val_' + metric)
                
            axs[i].legend()
            axs[i].grid()
            axs[i].set_xlabel('epochs')
            axs[i].set_ylabel('metric')
        plt.tight_layout()
        plt.show()
        

callbacks_list = [PlotLearning()]

   
test_list=[coarse_images_test_np,medium_images_test_np,sblurred_images_test_np,blurred_images_test_np,lblurred_images_test_np]
train_list=[coarse_images_train_np,medium_images_train_np,sblurred_images_train_np,blurred_images_train_np,lblurred_images_train_np]

import time
# Start timing
start_time = time.time()

history=model.fit(train_list,
          fine_images_train,
          epochs = 3,
          batch_size = 1,
          validation_data=(test_list, 
                           [fine_images_test_np]),
          verbose=1,
          shuffle=True,
          #class_weight=class_weight,
          callbacks=callbacks_list
          )
end_time = time.time()

# Calculate elapsed time
elapsed_time = end_time - start_time
print("Elapsed time:", elapsed_time, "seconds")

from skimage.metrics import mean_squared_error

predicted_all = model.predict(test_list)

# Calculate error metrics
Allpsnr0 = tf.image.psnr(fine_images_test_np[0], predicted_all[0], max_val=1.0)
Allpsnr1 = tf.image.psnr(fine_images_test_np[1], predicted_all[1], max_val=1.0)
print(Allpsnr0,Allpsnr1)


In [None]:
plt.rcParams['figure.dpi']=100

plt.figure(figsize=(12, 4),dpi=100)
plt.subplot(3, 1, 1)
plt.title(r'Predicted, psnr=%.3f'%Allpsnr0.numpy())
plt.imshow(predicted_all[0])
plt.axis('off')
plt.subplot(3, 1, 2)
plt.title(r'Given')
plt.imshow(medium_images_test[0])
plt.axis('off')
plt.subplot(3, 1, 3)
plt.title('Original')
plt.imshow(fine_images_test[0])
plt.axis('off')
plt.tight_layout()
plt.show()

plt.figure(figsize=(12, 4),dpi=100)
plt.subplot(3, 1, 1)
plt.title(r'Predicted, psnr=%.3f'%Allpsnr1.numpy())
plt.imshow(predicted_all[1])
plt.axis('off')
plt.subplot(3, 1, 2)
plt.title(r'Given')
plt.imshow(medium_images_test[1])
plt.axis('off')
plt.subplot(3, 1, 3)
plt.title('Original')
plt.imshow(fine_images_test[1])
plt.axis('off')
plt.tight_layout()
plt.show()