In [None]:
import numpy as np
import os

def rgb_psnr(target, pred, max_val=255):
    """Peak signal-to-noise ratio (in dB) between a reference and a predicted image.

    Both inputs are promoted to float64 before differencing. With uint8 inputs
    (``get_pred`` returns uint8, and the ground-truth slices may be uint8 too),
    ``target - pred`` would wrap modulo 256. The squared error would also
    overflow, silently producing a wrong PSNR.

    Args:
        target: reference image (array-like).
        pred: predicted image, broadcast-compatible with ``target``.
        max_val: peak pixel value; 255 (the original hard-coded value) for 8-bit images.

    Returns:
        PSNR in decibels computed from the mean squared error over all elements;
        ``inf`` when the two images are identical (MSE == 0).
    """
    diff = np.asarray(target, dtype=np.float64) - np.asarray(pred, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded; avoid a divide-by-zero warning.
        return float('inf')
    return 10 * np.log10(max_val ** 2 / mse)

# Filename templates for the evaluation pairs. Ground-truth stacks are numbered
# data_1 .. data_100, while the matching prediction folders are numbered
# kth_0 .. kth_99, so the prediction index is the GT index minus one.
gt_schema = os.path.join('/group/jug/ashesh/data/Rain100H/Rain100HCombined_test/', 'data_{}.tif')
pred_schema = '/group/jug/ashesh/data/paper_stats/All_P128_G16_M1_Sk0/kth_{}/pred_disentangle_2405_D30-M3-S0-L0_28.tif'
gt_files = [gt_schema.format(k + 1) for k in range(100)]
pred_files = [pred_schema.format(k) for k in range(100)]



In [None]:
from disentangle.core.tiff_reader import load_tiff
import json

def get_input_target(gt_file):
    """Load a ground-truth TIFF stack and split it into (input, target) images.

    Channels 0-2 of the stack become the input image and channels 3-5 the
    target image. Each slice has its first axis moved to the end, i.e.
    channels-first -> channels-last. Assumes ``load_tiff`` returns a 3-D
    (C, H, W) array with at least 6 channels; this is inferred from the
    slicing, so confirm it against the data.
    """
    stack = load_tiff(gt_file)
    chw_to_hwc = (1, 2, 0)
    input_img = stack[0:3].transpose(chw_to_hwc)
    target_img = stack[3:6].transpose(chw_to_hwc)
    return input_img, target_img

def get_pred(pred_file):
    """Load a saved prediction, add back its stored offset, and return it as uint8.

    The companion JSON file (same path with ``.tif`` replaced by ``.json``)
    must contain an ``'offset'`` entry. That value is added to the float32
    prediction, and the result is limited to [0, 255]. Only index 0 of the
    first axis and the first three entries of the last axis are returned,
    cast to uint8.

    NOTE(review): ``astype(np.uint8)`` truncates the fractional part rather
    than rounding. Confirm that this is intended for the PSNR evaluation.
    """
    raw = load_tiff(pred_file).astype(np.float32)
    meta_path = pred_file.replace('.tif','.json')
    with open(meta_path) as fh:
        meta = json.load(fh)
        raw += float(meta['offset'])
    raw = np.clip(raw, 0, 255)
    return raw[0, ..., :3].astype(np.uint8)



In [None]:
import matplotlib.pyplot as plt
# Visual sanity check on a single test image: input | prediction | target.
sample_idx = 10
pred = get_pred(pred_files[sample_idx])
# `input_img` avoids shadowing the builtin `input`.
input_img, target = get_input_target(gt_files[sample_idx])

_, ax = plt.subplots(figsize=(9, 3), ncols=3)
# `pred` is already uint8 (get_pred casts it), so no extra astype is needed.
for axis, image, title in zip(ax, (input_img, pred, target), ('Input', 'Prediction', 'Target')):
    axis.imshow(image)
    axis.set_title(title)
    axis.axis('off')
plt.show()


In [None]:
from tqdm import tqdm
# Per-image PSNR (dB) over the whole test set. GT and prediction files are
# paired by position, so the two lists must line up; zip() would otherwise
# silently drop the unmatched tail.
assert len(gt_files) == len(pred_files), 'gt_files and pred_files must have the same length'
psnr_list = []
# zip() has no len(), so pass total= explicitly to get a real progress bar.
for gt_file, pred_file in tqdm(zip(gt_files, pred_files), total=len(gt_files), desc='PSNR'):
    target = get_input_target(gt_file)[1]  # only the target is needed here
    pred = get_pred(pred_file)
    psnr_list.append(rgb_psnr(target, pred))

In [None]:
# Mean PSNR (dB) over all evaluated test images.
np.asarray(psnr_list).mean()