# Apply a trained SRFlow model

## Reproduce using our Setup

- Run `./setup.sh`.
- It installs all packages from `requirements.txt`.
- If this notebook does not work, copy the code that starts it from `setup.py`.
- The Python interpreter should be `../myenv/bin/python3`.

In [None]:
import sys
# Print the interpreter backing this kernel; per the setup notes it should be ../myenv/bin/python3.
print(sys.executable) 

# Initialize

In [None]:
import natsort, glob, pickle, torch
from collections import OrderedDict
import numpy as np
import os

import options.options as option
from models import create_model
from imresize import imresize

import Measure

def find_files(wildcard):
    """Return every path matching ``wildcard``, naturally sorted.

    ``**`` in the pattern is expanded recursively. Natural sorting keeps
    numbered names in numeric order (``img2`` before ``img10``).
    """
    matches = glob.glob(wildcard, recursive=True)
    return natsort.natsorted(matches)

from PIL import Image
def imshow(array):
    """Render an image array inline in the notebook output.

    ``display`` is not imported in this file; it is presumably provided by the
    Jupyter/IPython kernel, so this helper is meant for notebook use only.
    """
    picture = Image.fromarray(array)
    display(picture)

from test import load_model, fiFindByWildcard, imread

def pickleRead(path):
    """Unpickle and return the object stored in the file at ``path``.

    Security: ``pickle.load`` can execute arbitrary code while loading, so
    only call this on files you trust.
    """
    with open(path, mode='rb') as stream:
        result = pickle.load(stream)
    return result

In [None]:
def t(array):
    """Convert an H x W x C image array (values 0-255) to a float32 tensor.

    The result has shape (1, C, H, W): channels first with a leading batch
    axis. Values are divided by 255, so 0-255 input maps to [0, 1].
    """
    channels_first = array.transpose([2, 0, 1])
    batched = np.expand_dims(channels_first, axis=0)
    return torch.Tensor(batched.astype(np.float32)) / 255

def rgb(t):
    """Convert an image tensor with values in [0, 1] to a uint8 H x W x C array.

    ``t`` may be (C, H, W) or (1, C, H, W). For a 4-D tensor only the first
    item of the batch is converted. Values are clipped to [0, 1], scaled to
    [0, 255] and rounded to the nearest integer before the uint8 cast.

    Rounding matters because ``astype(np.uint8)`` alone truncates. Truncation
    shifts every pixel down (e.g. 0.999 -> 254), and a value that float32
    error leaves just below an integer would drop a whole level. That would
    break the round trip through the tensor-conversion helper ``t``.
    """
    # Drop the batch axis when present, then move channels last.
    chw = t[0] if len(t.shape) == 4 else t
    hwc = chw.detach().cpu().numpy().transpose([1, 2, 0])
    return np.round(np.clip(hwc, 0, 1) * 255).astype(np.uint8)

# List model files

If you do not see models here, download them as in `setup.sh`.

In [None]:
# Model checkpoints (*.pth) in ../models, naturally sorted; as the cell's last
# expression, the list is displayed in the notebook.
find_files("../models/*.pth")

# Configuration files


In [None]:
# Available YAML configuration files; one of them is chosen as conf_path below.
find_files("confs/*.yml")

# List dataset directories


In [None]:
# All directories under ../datasets: with recursive "**", a pattern ending in "/"
# matches directories only, not files.
find_files("../datasets/**/")

In [None]:
# Config passed to load_model below; its dataroot_LR / dataroot_GT entries locate
# the test images. Judging by the file name, this is the CelebA 8x model.
conf_path = './confs/SRFlow_CelebA_8X.yml'

### Instantiate the Model

In [None]:
# load_model comes from test.py and presumably builds the network and loads its
# trained weights. It also returns the options mapping `opt`, whose dataroot_*
# entries are read below.
model, opt = load_model(conf_path)

### Find the png paths

In [None]:
# Collect the *.png files from the LR and GT directories named in the config
# (LR first, then GT, exactly as before).
lq_paths, gt_paths = (
    fiFindByWildcard(os.path.join(opt[root_key], '*.png'))
    for root_key in ('dataroot_LR', 'dataroot_GT')
)
print(lq_paths, gt_paths)  # for CelebA there is just one image of each

In [None]:
# Decode every LR and GT image up front (imread is imported from test.py).
lqs = list(map(imread, lq_paths))
gts = list(map(imread, gt_paths))


In [None]:
# Preview the first low-resolution input and the first high-resolution (GT) image.
print("First LR image")
imshow(lqs[0])

print("First HR image")
imshow(gts[0])

# Super-Resolve using SRFlow for multiple temperatures

In [None]:
# Metric helper from the project's Measure module; below, measure.measure(sr, gt)
# is unpacked into (PSNR, SSIM, LPIPS) for an SR/GT image pair.
measure = Measure.Measure()


In [None]:
lq = lqs[0]
gt = gts[0]

# Draw one SR sample per temperature (0.0, 0.1, ..., 1.0), display it and score
# it against the ground truth.
for temperature in np.linspace(0, 1, num=11):
    # Sample a super-resolution for a low-resolution image
    sr = rgb(model.get_sr(lq=t(lq), heat=temperature))
    imshow(sr)
    psnr, ssim, lpips = measure.measure(sr, gt)
    # SSIM scores are at most 1 (assuming Measure returns standard SSIM), so one
    # decimal hides the differences between temperatures; print three.
    print('Temperature: {:0.2f} - PSNR: {:0.1f}, SSIM: {:0.3f}, LPIPS: {:0.2f}\n\n'.format(temperature, psnr, ssim, lpips))


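# LR Consistency

Super-resolve the LR image, downsample the result back to LR resolution, and feed it in again, five times in total. The displayed outputs show how stable the model stays across these cycles.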

In [None]:
lq = lqs[0]
gt = gts[0]

temperature = 0.9

# Repeatedly super-resolve, then downsample the result back to the input size.
# The downsampling factor is the ratio of input height to output height, not a
# hard-coded 1/8, so this cell still works when conf_path selects a model with
# a different upscaling factor. For an 8x model the ratio is exactly 1/8.
downsampled = lq
for idx in range(5):
    sr = rgb(model.get_sr(lq=t(downsampled), heat=temperature))
    # The right-hand side still sees this iteration's input size.
    downsampled = imresize(sr, downsampled.shape[0] / sr.shape[0])
    
    imshow(sr)
