# Visualization of test-image predictions from a trained model

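Given the directory of a trained UNet model, this notebook predicts the test images and scores them (PSNR and SSIM). It then visualises the 5 lowest-PSNR and the 5 highest-PSNR predictions next to their ground truth.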

In [1]:
#libraries

import torch
import models
import os
import yaml
import numpy as np
from PIL import Image
from patchify import patchify, unpatchify
import matplotlib.pyplot as plt
from skimage.exposure import equalize_hist
import math
from skimage.metrics import structural_similarity as ssim
import torch.nn as nn

In [2]:
# CUDA device used as map_location when the checkpoint is loaded in the model-construction cell.
device=torch.device('cuda')
# Last expression of the cell: displays how many CUDA devices are visible.
torch.cuda.device_count()

3

In [3]:
# Experiment directory holding epoch-best-psnr.pth and config.yaml.
# NOTE(review): `dir` shadows the builtin of the same name; kept because later cells use it.
dir = '/home/gpu/girish/outdoor_disp/_direct_concat'
# Input amplification factor passed to image_transform, and the fraction of
# (sorted) files that precede the test split.
amp, split_ratio = 20, 0.8
# raw: inputs are headerless 8-bit files; monochrome: keep only the first input channel.
raw, monochrome = False, False

In [4]:
#model construction
model_path = os.path.join(dir, 'epoch-best-psnr.pth')
sv_file = torch.load(model_path, map_location=device)
model = models.make(sv_file['model'], load_sd=True).eval().cuda()
model = nn.DataParallel(model)

In [5]:
def patchify_img(image, patch_size=256):
    """Split an image into a stack of non-overlapping square 3-channel patches.

    The image is first cropped from the top-left corner to the largest height and
    width that are multiples of ``patch_size``; only the first three channels are kept.

    Args:
        image: array of shape (H, W, C) with C >= 3.
        patch_size: side length of each square patch.

    Returns:
        ``(patches, size_x, size_y, n_rows, n_cols)`` where ``patches`` has shape
        (n_rows * n_cols, patch_size, patch_size, 3) in row-major grid order,
        ``size_x`` / ``size_y`` are the cropped height / width, and ``n_rows`` /
        ``n_cols`` are the number of patches along the height / width.

    Raises:
        ValueError: if the image has fewer than 3 channels or is smaller than one
            patch along either axis.
    """
    channels = 3
    if image.shape[2] < channels:
        raise ValueError(f"expected at least {channels} channels, got {image.shape[2]}")

    # Height (axis 0) and width (axis 1) rounded down to multiples of the patch size.
    size_x = (image.shape[0] // patch_size) * patch_size
    size_y = (image.shape[1] // patch_size) * patch_size
    if size_x == 0 or size_y == 0:
        raise ValueError(
            f"image of size {image.shape[:2]} is smaller than patch_size={patch_size}"
        )
    n_rows = size_x // patch_size
    n_cols = size_y // patch_size

    cropped = image[:size_x, :size_y, :channels]
    # (rows*ps, cols*ps, C) -> (rows, ps, cols, ps, C) -> (rows, cols, ps, ps, C),
    # then flatten the grid so patch j*n_cols + k is grid cell (j, k).
    patches = (
        cropped.reshape(n_rows, patch_size, n_cols, patch_size, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(-1, patch_size, patch_size, channels)
    )
    # reshape may return a view of `image`; hand back an independent array.
    return patches.copy(), size_x, size_y, n_rows, n_cols

In [6]:
def image_transform(img, amp=False, raw=False, output=False, patch_size=512, raw_shape=(2160, 4096)):
    """Load an image file and return it as a tensor of 3-channel patches.

    Args:
        img: path to the image file.
        amp: if truthy, multiply pixel values by this factor and clamp to [0, 1].
        raw: read ``img`` as headerless uint8 data of shape ``raw_shape`` instead of
            decoding it with PIL. Ignored when ``output`` is true.
        output: the file is a target image; targets are always decoded with PIL.
        patch_size: side length of the square, non-overlapping patches.
        raw_shape: (height, width) of raw input files.

    Returns:
        ``(patches, h, w, p_x, p_y)``: ``patches`` is a float64 tensor of shape
        (p_x * p_y, 3, patch_size, patch_size) holding pixel values divided by 255;
        ``h`` / ``w`` are the cropped height / width and ``p_x`` / ``p_y`` the number
        of patches along them (as returned by ``patchify_img``).
    """
    if raw and not output:
        # np.fromfile accepts the path directly, so no file handle is left open.
        arr = np.fromfile(img, dtype=np.uint8).reshape(raw_shape)
    else:
        # The context manager closes the file once the pixels have been copied out.
        with Image.open(img) as pil_img:
            arr = np.array(pil_img)
    if arr.ndim == 2:
        # Single-channel data is replicated into 3 identical channels.
        arr = np.repeat(arr[:, :, np.newaxis], 3, axis=2)
    res, h, w, p_x, p_y = patchify_img(arr, patch_size=patch_size)
    # (N, H, W, C) -> (N, C, H, W)
    arr = torch.from_numpy(res / 255).permute(0, 3, 1, 2)
    if amp:
        arr = (arr * amp).clamp_(0, 1)
    return arr, h, w, p_x, p_y
    

In [None]:
#load image
# Read the config saved with the checkpoint to locate the validation data directories.
with open(os.path.join(dir, 'config.yaml'), 'r') as f:
    # NOTE(review): yaml.FullLoader is acceptable for a locally written config; prefer
    # yaml.safe_load if this file could come from an untrusted source.
    config = yaml.load(f, Loader=yaml.FullLoader)

data_args = config['val_dataset']['dataset']['args']
out_dir = data_args['root_path_out']
outfile = sorted(os.listdir(out_dir))  # NOTE(review): not used by the cells shown here
imgs, h_s, w_s, pxs, pys, out_imgs = [], [], [], [], [], []

if data_args.get('root_path_inp') is not None:
    # Single-input dataset: files after the split_ratio point form the test set.
    test_dir = data_args['root_path_inp']
    filenames = sorted(os.listdir(test_dir))
    img_files = filenames[math.ceil(len(filenames) * split_ratio):]
    #img_files = filenames[0:]
    for file in img_files:
        res = image_transform(os.path.join(test_dir, file), amp, raw)
        # In monochrome mode only the first input channel is kept.
        imgs.append(res[0][:, :1, :, :] if monochrome else res[0])
        h_s.append(res[1])
        w_s.append(res[2])
        pxs.append(res[3])
        pys.append(res[4])
        # The target has the same file name as the input.
        out_imgs.append(image_transform(os.path.join(out_dir, file), raw=raw, output=True)[0])
else:
    # Two-input dataset: files from both directories are paired by sorted order
    # and concatenated along the channel axis.
    test_dir1 = data_args['root_path_inp1']
    test_dir2 = data_args['root_path_inp2']
    img_files1 = sorted(os.listdir(test_dir1))
    img_files2 = sorted(os.listdir(test_dir2))
    # NOTE(review): unlike the single-input branch, no split_ratio is applied, so every
    # paired file is evaluated — confirm these directories contain only test data.
    for imgf1, imgf2 in zip(img_files1, img_files2):
        res1 = image_transform(os.path.join(test_dir1, imgf1), amp, raw)
        res2 = image_transform(os.path.join(test_dir2, imgf2), amp, raw)
        if monochrome:
            res = torch.cat([res1[0][:, :1, :, :], res2[0][:, :1, :, :]], dim=1)
        else:
            res = torch.cat([res1[0], res2[0]], dim=1)
        imgs.append(res)
        # Patch geometry is taken from the second input.
        h_s.append(res2[1])
        w_s.append(res2[2])
        pxs.append(res2[3])
        pys.append(res2[4])
        # The target has the same file name as the first input.
        out_imgs.append(image_transform(os.path.join(out_dir, imgf1), raw=raw, output=True)[0])
    # The visualisation cells look up target file names through `img_files`;
    # img_files[i] names the target of imgs[i] in this branch as well.
    img_files = img_files1


In [None]:
def psnr(pred, out, rgb_range=1):
    '''Peak signal-to-noise ratio, in dB, between a prediction and its reference.

    pred: patch_count * channels * H * W prediction tensor
    out:  reference tensor, broadcast-compatible with pred (normally the same shape)
    rgb_range: peak pixel value of the scale (1 for images in [0, 1])

    The mean squared error is taken over all elements at once, and the result is
    returned as a 0-dim tensor. Identical inputs give +inf, because the MSE is 0.
    '''
    diff = (pred - out) / rgb_range
    mse = torch.mean(torch.pow(diff, 2))

    return -10 * torch.log10(mse)

In [None]:
# Patch stack of the first test image: (num_patches, channels, patch_H, patch_W).
imgs[0].size()

In [None]:
#predict
# Run the model over every patch of every test image, record the PSNR against the
# target patches, and reassemble each prediction into a full image.
results, psnrs, ssims = [], [], []
with torch.no_grad():  # inference only: no autograd graph is needed
    for i, img in enumerate(imgs):
        pred_patches = []
        print(img.shape)
        for patch in img:
            # Inputs are mapped from [0, 1] to [-1, 1]; outputs are mapped back and clamped.
            pred = model((patch.unsqueeze(0).float() - 0.5) / 0.5)
            pred_patches.append((pred * 0.5 + 0.5).clamp_(0, 1).detach().cpu())
        pred = torch.vstack(pred_patches)
        psnrs.append(psnr(pred, out_imgs[i][:, :3, :, :]).item())
        # Reassemble with image i's own patch grid and cropped size; using image 0's
        # values would break as soon as test images differ in size.
        patch_h, patch_w = pred.shape[-2:]
        result = unpatchify(
            pred.permute(0, 2, 3, 1)
            .reshape(pxs[i], pys[i], 1, patch_h, patch_w, 3)
            .numpy(),
            (h_s[i], w_s[i], 3),
        )
        results.append(result)
    

In [None]:
# SSIM between each reassembled prediction and its ground truth, reassembled the same way.
ssims = []  # recomputed from scratch so re-running this cell does not append duplicates
for i, result in enumerate(results):
    gt_patches = out_imgs[i][:, :3, :, :].permute(0, 2, 3, 1)
    patch_h, patch_w = gt_patches.shape[1:3]
    # Use image i's own patch grid and cropped size (not image 0's).
    gt = unpatchify(
        gt_patches.reshape(pxs[i], pys[i], 1, patch_h, patch_w, 3).numpy(),
        (h_s[i], w_s[i], 3),
    )
    # Predictions are clamped to [0, 1] and targets are pixel values / 255, so state
    # data_range explicitly instead of letting scikit-image infer it from the float dtype.
    ssims.append(ssim(result, gt, channel_axis=2, data_range=1.0))

avg_ssim = np.mean(ssims)
print(avg_ssim)

In [None]:
#visualise
# Show the 5 test images with the lowest PSNR: ground truth (left) vs prediction (right).
sorted_psnrs = sorted(psnrs)
# A stable argsort gives distinct indices even when PSNRs tie (psnrs.index would
# return the first match every time).
ranked = np.argsort(psnrs, kind='stable').tolist()

for index in ranked[:5]:
    score = psnrs[index]  # not named `psnr`, which would shadow the psnr() metric
    plt.figure(figsize=(10,20))
    plt.subplot(1,2,1)
    # Targets are decoded with PIL when loaded (image_transform(..., output=True)),
    # so decode them the same way here; the context manager closes the file.
    with Image.open(os.path.join(out_dir, img_files[index])) as gt_img:
        plt.imshow(gt_img, 'gray')
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.imshow(results[index], 'gray')
    plt.axis('off')
    plt.title('PSNR: ' + str(round(score, 2)) + ' SSIM: ' + str(round(ssims[index], 2)))
    plt.show()



In [None]:
#visualise
# Show the 5 test images with the highest PSNR (ascending): ground truth (left) vs prediction (right).
# A stable argsort gives distinct indices even when PSNRs tie (psnrs.index would
# return the first match every time).
ranked = np.argsort(psnrs, kind='stable').tolist()

for index in ranked[-5:]:
    score = psnrs[index]  # not named `psnr`, which would shadow the psnr() metric
    plt.figure(figsize=(10,20))
    plt.subplot(1,2,1)
    # Targets are decoded with PIL when loaded (image_transform(..., output=True)),
    # so decode them the same way here; the context manager closes the file.
    with Image.open(os.path.join(out_dir, img_files[index])) as gt_img:
        plt.imshow(gt_img, 'gray')
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.imshow(results[index], 'gray')
    plt.axis('off')
    plt.title('PSNR: ' + str(round(score, 2)) + ' SSIM: ' + str(round(ssims[index], 2)))
    plt.show()
