In [None]:
import numpy as np
import os

def rgb_psnr(target, pred):
    """Peak signal-to-noise ratio (in dB) between two images, with peak value 255.

    Args:
        target: reference image (array-like, any numeric dtype).
        pred: predicted image, broadcast-compatible with `target`.

    Returns:
        10 * log10(255**2 / MSE), where the MSE is averaged over every element.
        `inf` is returned when the two images are identical.
    """
    # Work in float64: with unsigned 8-bit inputs, both the subtraction and
    # the squaring wrap around modulo 256 and silently corrupt the error.
    target = np.asarray(target, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    mse = np.mean((target - pred) ** 2)
    if mse == 0:
        # Identical images: avoid the divide-by-zero warning.
        return np.float64(np.inf)
    return 10 * np.log10(255**2 / mse)

# --- Active configuration: Rain100H ---
# Path templates; `{}` is filled with the sample number.
gt_schema = os.path.join(
    '/group/jug/ashesh/data/Rain100H/Rain100HCombined_test/',
    'data_{}.tif',
)
pred_schema = '/group/jug/ashesh/data/paper_stats/All_P128_G16_M1_Sk0/kth_{}/pred_disentangle_2405_D30-M3-S0-L0_28.tif'
end_idx = 100

# --- Alternative configuration: Haze4K (uncomment all three lines to switch) ---
# pred_schema = '/group/jug/ashesh/data/paper_stats/All_P128_G64_M1_Sk0/kth_{}/pred_disentangle_2405_D31-M3-S0-L0_7.tif'
# gt_schema = os.path.join('/group/jug/ashesh/data/Haze4KCombined/test/','data_{}.tif')
# end_idx = 1000

# Ground-truth files are numbered 1..end_idx, prediction folders 0..end_idx-1.
gt_files = list(map(gt_schema.format, range(1, end_idx + 1)))
pred_files = list(map(pred_schema.format, range(end_idx)))



In [None]:
from disentangle.core.tiff_reader import load_tiff
import json

def get_input_target(gt_file):
    """Read a ground-truth TIFF and split it into an (input, target) pair.

    Entries 0-2 along the first axis of the loaded stack form the input and
    entries 3-5 form the target. Each 3-entry slice is transposed with
    (1, 2, 0), i.e. its first axis becomes the last one (presumably
    channels-first -> channels-last; confirm against the data layout).

    Returns:
        Tuple `(input, target)` of the two transposed slices.
    """
    stack = load_tiff(gt_file)
    inp, tgt = (
        stack[lo:hi].transpose(1, 2, 0)
        for lo, hi in ((0, 3), (3, 6))
    )
    return inp, tgt

def get_pred(pred_file):
    """Load a saved prediction and convert it to an 8-bit image.

    The prediction TIFF is accompanied by a JSON file with the same base name
    whose 'offset' entry is added back to the stored values. The result is
    clamped to [0, 255] and cast to uint8 (the cast drops the fractional part).

    Args:
        pred_file: path to the prediction TIFF.

    Returns:
        uint8 array `pred[0, ..., :3]`: the first element along axis 0 and the
        first three entries of the last axis (presumably the RGB channels of
        the first prediction; confirm against the writer of these files).
    """
    pred = load_tiff(pred_file).astype(np.float32)
    # Swap only the file extension. A plain str.replace('.tif', '.json') would
    # also rewrite '.tif' inside directory names and mangle '.tiff' suffixes.
    config_fpath = os.path.splitext(pred_file)[0] + '.json'
    with open(config_fpath) as f:
        offset = float(json.load(f)['offset'])
    pred += offset
    np.clip(pred, 0, 255, out=pred)
    return pred[0, ..., :3].astype(np.uint8)



In [None]:
import matplotlib.pyplot as plt
# Quick visual sanity check on a single sample: input | prediction | target.
sample_idx = 10
pred = get_pred(pred_files[sample_idx])
# `inp` rather than `input`, so the builtin is not shadowed at notebook scope.
inp, target = get_input_target(gt_files[sample_idx])

fig, ax = plt.subplots(figsize=(9, 3), ncols=3)
# get_pred already returns uint8, so no further cast is needed before plotting.
for panel, img, title in zip(ax, (inp, pred, target), ('Input', 'Prediction', 'Target')):
    panel.imshow(img)
    panel.set_title(title)


In [None]:
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim

# Evaluate every (ground truth, prediction) pair. The images are kept in
# memory because the plotting cells further down index into these lists.
psnr_list = []
ssim_list = []
input_list = []
target_list = []
pred_list = []
# zip() has no length, so pass `total` explicitly to get a real progress bar.
for gt_file, pred_file in tqdm(zip(gt_files, pred_files), total=len(gt_files)):
    # `inp` rather than `input`, so the builtin is not shadowed at notebook scope.
    inp, target = get_input_target(gt_file)
    pred = get_pred(pred_file)
    input_list.append(inp)
    target_list.append(target)
    pred_list.append(pred)
    psnr_list.append(rgb_psnr(target, pred))
    # channel_axis=2: images are laid out with the channel axis last.
    ssim_list.append(ssim(pred, target, data_range=255, channel_axis=2))

print('PSNR', np.mean(psnr_list))
print('SSIM', np.mean(ssim_list))


In [None]:
# Print the mean PSNR and SSIM over the per-image values collected in
# psnr_list / ssim_list (same output as the end of the evaluation cell).
print('PSNR', np.mean(psnr_list))
print('SSIM', np.mean(ssim_list))

In [None]:
# Shape of one loaded input image. The index is clamped to the last element:
# with end_idx = 100 the list only has indices 0-99 and [100] raises IndexError.
input_list[min(100, len(input_list) - 1)].shape

In [None]:
import matplotlib.pyplot as plt
from disentangle.analysis.plot_utils import clean_ax

# Figure: `nrows` randomly chosen samples of a given image shape, one per row,
# shown as input | prediction | target.
ncols = 3
nrows = 3
img_sz = 2
shape = (400, 400, 3)
final_shape = (380, 400, 3)
# Crop the same number of rows from the top and the bottom. A positive end
# index is used because a negative one (-h_start) becomes 0 when no cropping
# is needed, which would produce an empty slice.
h_start = (shape[0] - final_shape[0]) // 2
h_end = shape[0] - h_start
factor = final_shape[1] / final_shape[0]

# Only images with exactly the requested shape are eligible. Collecting them up
# front replaces a rejection-sampling loop that never terminated when the
# loaded dataset contained no image of this shape.
candidates = [i for i, img in enumerate(input_list) if img.shape == shape]
if not candidates:
    raise ValueError(f'No loaded image has shape {shape}; adjust `shape` to the active dataset.')

fig, ax = plt.subplots(figsize=(img_sz * ncols * factor, img_sz * nrows), ncols=ncols, nrows=nrows)
for idx in range(nrows):
    pred_idx = candidates[np.random.randint(len(candidates))]
    # Print the chosen index so that a good-looking draw can be reproduced.
    print(pred_idx)
    inp = input_list[pred_idx][h_start:h_end]
    target = target_list[pred_idx][h_start:h_end]
    pred = pred_list[pred_idx][h_start:h_end]
    ax[idx, 0].imshow(inp)
    ax[idx, 1].imshow(pred)
    ax[idx, 2].imshow(target)
    for i in range(ncols):
        ax[idx, i].axis('off')

# restrict the space between the subplots
plt.subplots_adjust(wspace=0.02, hspace=0.02)

In [None]:
import matplotlib.pyplot as plt
from disentangle.analysis.plot_utils import clean_ax

# Figure: `nrows` randomly chosen samples of a given image shape, one per row,
# shown as input | prediction | target.
ncols = 3
nrows = 3
img_sz = 2
shape = (321, 481, 3)
final_shape = (319, 481, 3)
# Crop the same number of rows from the top and the bottom. A positive end
# index is used because a negative one (-h_start) becomes 0 when no cropping
# is needed, which would produce an empty slice.
h_start = (shape[0] - final_shape[0]) // 2
h_end = shape[0] - h_start
factor = final_shape[1] / final_shape[0]

# Only images with exactly the requested shape are eligible. Candidates are
# taken from input_list itself (not pred_files), so the sampled index is always
# valid for the lists that are indexed below. Collecting them up front also
# replaces a rejection-sampling loop that never terminated when the loaded
# dataset contained no image of this shape.
candidates = [i for i, img in enumerate(input_list) if img.shape == shape]
if not candidates:
    raise ValueError(f'No loaded image has shape {shape}; adjust `shape` to the active dataset.')

fig, ax = plt.subplots(figsize=(img_sz * ncols * factor, img_sz * nrows), ncols=ncols, nrows=nrows)
for idx in range(nrows):
    pred_idx = candidates[np.random.randint(len(candidates))]
    # Print the chosen index so that a good-looking draw can be reproduced.
    print(pred_idx)
    inp = input_list[pred_idx][h_start:h_end]
    target = target_list[pred_idx][h_start:h_end]
    pred = pred_list[pred_idx][h_start:h_end]
    ax[idx, 0].imshow(inp)
    ax[idx, 1].imshow(pred)
    ax[idx, 2].imshow(target)
    for i in range(ncols):
        ax[idx, i].axis('off')

# restrict the space between the subplots
plt.subplots_adjust(wspace=0.02, hspace=0.02)

In [None]:
# 76, 39, 24