## SET UP

In [None]:
# Mount Google Drive so the IVUS image folders under MyDrive are readable from this runtime.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
!git clone https://github.com/LauraMJanssen/esrgan.git
%cd esrgan/
!pip install -r requirements.txt

## For transfer learning:

In [None]:
# Create the checkpoint directory for the pretrained weights. Use an absolute path so
# this does not depend on the current working directory (matches the unzip/mv cells).
! mkdir -p /content/esrgan/checkpoints/esrgan

In [None]:
# Extract the pretrained ESRGAN weights. -o overwrites existing files without asking,
# so re-running the cell does not hang on unzip's interactive "replace?" prompt.
! unzip -o "/content/esrgan/pretrainedModel/esrgan_inference.zip" \
        -d "/content/esrgan/pretrainedModel/"

In [None]:
# Move the extracted weights into checkpoints/esrgan. The destination is absolute so it
# does not depend on the current working directory.
# NOTE(review): presumably train_esrgan.py resumes from this directory for transfer
# learning; confirm against configs/esrgan.yaml.
! mv -f /content/esrgan/pretrainedModel/esrgan_inference/* /content/esrgan/checkpoints/esrgan/

## Data Processing
The training and validation images must first be converted to the `.tfrecord` format.

In [None]:
# Rebuild the training tfrecord from the HR / LR(x4) training folders.
# -f: do not error when the record does not exist yet (e.g. first run).
# The output path is the same absolute path that is removed, so the two always agree.
! rm -f /content/esrgan/data/DIV2K800_sub_bin.tfrecord
! python data/convert_train_tfrecord.py \
        --hr_dataset_path='/content/drive/MyDrive/IVUSImages/HR_Train' \
        --lr_dataset_path='/content/drive/MyDrive/IVUSImages/LR_Scale4_Train' \
        --output_path="/content/esrgan/data/DIV2K800_sub_bin.tfrecord" \
        --is_binary=True

In [None]:
# Rebuild the validation tfrecord from the HR / LR(x4) validation folders.
# -f: do not error when the record does not exist yet (e.g. first run).
# The output path is the same absolute path that is removed, so the two always agree.
! rm -f /content/esrgan/data/DIV2K800_sub_bin_valid.tfrecord
! python data/convert_train_tfrecord.py \
        --hr_dataset_path='/content/drive/MyDrive/IVUSImages/HR_Valid' \
        --lr_dataset_path='/content/drive/MyDrive/IVUSImages/LR_Scale4_Valid' \
        --output_path="/content/esrgan/data/DIV2K800_sub_bin_valid.tfrecord" \
        --is_binary=True

## Train

In [None]:
# Train ESRGAN with the repository config (cwd must be /content/esrgan).
# NOTE(review): in upstream esrgan-tf2, --gpu is written to CUDA_VISIBLE_DEVICES.
# Colab exposes a single GPU at index 0, so --gpu=1 would hide it and silently
# train on CPU. Confirm this fork's train_esrgan.py handles the flag the same way.
! python train_esrgan.py --cfg_path="./configs/esrgan.yaml" \
                         --gpu=0

In [None]:
# Archive the trained checkpoints. The repository is cloned to /content/esrgan
# (not /content/esrgan-tf2), so the checkpoint directory lives under it.
!zip -r /content/ESRGAN_model.zip /content/esrgan/checkpoints/esrgan

In [None]:
from google.colab import files

# Send the zipped checkpoint archive to the browser as a file download.
files.download(filename="/content/ESRGAN_model.zip")

## TEST

(1) Define quality metric functions

In [11]:
import math

import numpy as np
from skimage.metrics import structural_similarity as ssim

def psnr(target, ref, max_value=255.):
    """Peak signal-to-noise ratio (PSNR), in dB, between two images.

    PSNR = 20 * log10(max_value / RMSE), where the RMSE is taken over every
    element of the images (rows, columns and channels).

    Args:
        target: image to evaluate (array-like).
        ref: reference image (array-like), same shape as ``target``.
        max_value: largest possible pixel value; 255 for 8-bit images.

    Returns:
        The PSNR as a Python float, or ``math.inf`` when the images are identical.

    Raises:
        ValueError: if ``target`` and ``ref`` have different shapes.
    """
    # Cast to float before subtracting so uint8 differences cannot wrap around.
    target_data = np.asarray(target, dtype=float)
    ref_data = np.asarray(ref, dtype=float)
    if target_data.shape != ref_data.shape:
        # Refuse to silently broadcast, e.g. (H, W, 1) against (H, W, 3).
        raise ValueError(
            f'psnr: shape mismatch {target_data.shape} vs {ref_data.shape}')

    rmse = math.sqrt(np.mean((ref_data - target_data) ** 2))
    if rmse == 0:
        # Identical images: zero noise power, so the PSNR is unbounded.
        return math.inf
    return 20 * math.log10(max_value / rmse)

def mse(target, ref):
    """Mean squared error (MSE) between two images of the same shape.

    The squared difference is averaged over every element (rows, columns and
    channels). This is the same averaging ``psnr`` uses, so for 8-bit images
    ``psnr == 10 * log10(255**2 / mse)``.

    Args:
        target: image to evaluate (array-like).
        ref: reference image (array-like), same shape as ``target``.

    Returns:
        The MSE as a Python float.

    Raises:
        ValueError: if ``target`` and ``ref`` have different shapes.
    """
    # Cast to float before subtracting so uint8 differences cannot wrap around.
    target_data = np.asarray(target, dtype=float)
    ref_data = np.asarray(ref, dtype=float)
    if target_data.shape != ref_data.shape:
        raise ValueError(
            f'mse: shape mismatch {target_data.shape} vs {ref_data.shape}')

    # Average over all elements, not just H * W. Dividing a channel-summed error by
    # the pixel count alone over-reports C-channel images by a factor of C.
    return float(np.mean((target_data - ref_data) ** 2))

def compare_images(target, ref):
    """Compute PSNR, MSE and SSIM of ``target`` against the reference ``ref``.

    Args:
        target: image to evaluate, as a NumPy array.
        ref: reference image, as a NumPy array of the same shape.

    Returns:
        A list ``[psnr, mse, ssim]``, in that order.

    Raises:
        ValueError: if the two images have different shapes.
    """
    if target.shape != ref.shape:
        raise ValueError(
            f'compare_images: shape mismatch {target.shape} vs {ref.shape}')

    # `multichannel=True` is deprecated since scikit-image 0.19 in favour of
    # `channel_axis`. Use the last axis as channels only for 3-D (H, W, C) images,
    # so 2-D grayscale images are not misread as having W channels.
    channel_axis = -1 if target.ndim == 3 else None
    return [
        psnr(target, ref),
        mse(target, ref),
        ssim(target, ref, channel_axis=channel_axis),
    ]

(2) Run the model on a single image

In [None]:
from IPython.display import Image

# Run the trained model on a single validation image (flag order is irrelevant to absl).
# NOTE(review): an HR image is passed here, and the metrics cell compares
# /content/result.png against this same HR image. Confirm test.py's input/output
# conventions for this fork.
!python test.py --img_path="/content/drive/MyDrive/IVUSImages/HR_Valid/001.png" \
        --cfg_path="./configs/esrgan.yaml"

(3) Print results \
(make sure the HR, LR and SR image paths below refer to the same image that was passed to `test.py` above)

In [None]:
import math

import cv2
import numpy as np

# Image triplet to evaluate. Keep these in sync with the --img_path passed to test.py.
HR_PATH = '/content/drive/MyDrive/IVUSImages/HR_Valid/001.png'
LR_PATH = '/content/drive/MyDrive/IVUSImages/LR_Scale4_Valid/001.png'
SR_PATH = '/content/result.png'


def read_image_or_fail(path):
    """Load an image with OpenCV, raising instead of returning None.

    cv2.imread signals a missing or unreadable file by returning None. Without a
    check, that would only surface later as a confusing error on `.shape`.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    return image


HR = read_image_or_fail(HR_PATH)
LR = read_image_or_fail(LR_PATH)
SR = read_image_or_fail(SR_PATH)

# Baseline "degraded" image: upscale the LR input straight to the HR resolution with
# cv2.resize's default interpolation. Sizing from HR, rather than 4 * LR, gives the
# same result when HR is exactly 4x LR, and still yields matching shapes when the HR
# dimensions are not exact multiples of 4.
hr_height, hr_width = HR.shape[:2]
LR2 = cv2.resize(LR, (hr_width, hr_height))

# image quality calculations: [PSNR, MSE, SSIM] of each image against the HR reference
scores = [compare_images(LR2, HR), compare_images(SR, HR)]

# print all scores for all images
for label, (psnr_score, mse_score, ssim_score) in zip(
        ('Degraded Image', 'Reconstructed Image'), scores):
    print('{}: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'
          .format(label, psnr_score, mse_score, ssim_score))