# Evaluation Metrics

In this notebook, we compute the following evaluation metrics, as described in our paper:
1. LPIPS (Learned Perceptual Image Patch Similarity, implementation from https://github.com/richzhang/PerceptualSimilarity)
2. MSE (Mean Squared Error)
3. LMSE (Local Mean Squared Error, implementation from https://github.com/davidstutz/grosse2009-intrinsic-images)
4. FID50k (Fréchet Inception Distance computed over 50k images, implementation from https://github.com/GaParmar/clean-fid)

In [None]:
import lpips
import os
from PIL import Image 
import torch 
from rewrite_utils import renormalize, show
import glob 
from cleanfid import fid
import numpy as np

def pilim(img):
    """Convert a PIL image with ``renormalize.from_image`` (default target) and return the result."""
    converted = renormalize.from_image(img)
    return converted

# Metric objects shared by the evaluation loop.
lpips_metric = lpips.LPIPS(net='alex')  # LPIPS using the AlexNet ('alex') network
mse_metric = torch.nn.MSELoss()
# All-ones mask for LMSE, so every pixel is counted. ``truth`` does not exist
# yet when this cell runs, so the shape is given explicitly instead of being
# read from ``truth.shape``. This assumes pilim yields (3, 256, 256) arrays for
# the RGB images resized to 256x256 in the evaluation loop -- TODO confirm.
mask = np.ones((3, 256, 256))

#local mean squared error from Intrinsic Image Algorithms (repo linked above)
def local_error(correct, estimate, mask, window_size, window_shift):
    """Return the local mean squared error (LMSE) between two images.

    A ``window_size`` x ``window_size`` window slides over axes 1 and 2 of the
    arrays (axis 0 is kept whole) with a stride of ``window_shift``. In each
    window the estimate is rescaled by the scalar that minimizes its
    sum-squared error against ``correct`` (see ``ssq_error``). The per-window
    errors are summed and divided by the summed masked energy
    ``sum(mask * correct**2)`` of the same windows.

    Args:
        correct: 3-D ground-truth array; windows slide over axes 1 and 2.
        estimate: array with the same shape as ``correct``.
        mask: weight array with the same shape as ``correct``.
        window_size: side length of each square window, in pixels.
        window_shift: step between consecutive window origins, in pixels.

    Returns:
        The normalized local error ``ssq / total``.

    Raises:
        ValueError: if no window fits inside the image, the masked energy of
            ``correct`` over all windows is not positive, or the result is NaN.
    """
    rows, cols = correct.shape[1:]
    ssq = total = 0.
    for i in range(0, rows - window_size + 1, window_shift):
        for j in range(0, cols - window_size + 1, window_shift):
            correct_curr = correct[:, i:i + window_size, j:j + window_size]
            estimate_curr = estimate[:, i:i + window_size, j:j + window_size]
            mask_curr = mask[:, i:i + window_size, j:j + window_size]
            ssq += ssq_error(correct_curr, estimate_curr, mask_curr)
            # Accumulate inside the inner loop so the normalizer covers exactly
            # the same windows that were summed into ``ssq``.
            total += np.sum(mask_curr * correct_curr ** 2)
    # ``not total > 0`` is also true when total is NaN.
    if not total > 0 or np.isnan(ssq):
        raise ValueError(
            "LMSE is undefined: no window fits the image, the masked ground "
            "truth has zero energy, or the inputs contain NaN")
    return ssq / total


def ssq_error(correct, estimate, mask):
    """Return the masked sum-squared error after optimally rescaling ``estimate``.

    A single scalar ``alpha``, shared by all channels, is chosen by least
    squares to minimize ``sum(mask * (correct - alpha * estimate)**2)``. If the
    masked energy ``sum(estimate**2 * mask)`` is at most 1e-5, ``alpha`` is 0.

    Args:
        correct: 3-D array.
        estimate: array with the same shape as ``correct``.
        mask: weight array with the same shape as ``correct``.

    Returns:
        The masked sum-squared error of ``correct - alpha * estimate``.

    Raises:
        ValueError: if ``correct`` is not 3-D.
    """
    if correct.ndim != 3:
        raise ValueError(f"expected a 3-D array, got ndim={correct.ndim}")
    estimate_energy = np.sum(estimate ** 2 * mask)
    if estimate_energy > 1e-5:
        alpha = np.sum(correct * estimate * mask) / estimate_energy
    else:
        # Near-zero estimate: the scale is ill-defined, so fall back to 0.
        alpha = 0.
    return np.sum(mask * (correct - alpha * estimate) ** 2)




In [None]:
test_path = "evaldata/ada_lonoff" #path to folder of generated images
truth_path = "datasets/lonoff_light_all/test_B/" #path to folder of ground truth (LONOFF dataset images)

lpips_all = []
mse_all = []
lmse_all = []

#reshape all images to be 256 x 256 pixels
w, h = (256, 256)


def _load_resized(path, size):
    """Open ``path``, bilinearly resize it to ``size`` = (width, height), and convert it with ``pilim``."""
    # ``with`` closes the underlying file handle; ``resize`` returns a new,
    # independent image, so the result stays valid after the file is closed.
    with Image.open(path) as img:
        return pilim(img.resize(size, Image.BILINEAR))


#iterate through each image in LONOFF (sorted so the run order is reproducible)
for i, img_name in enumerate(sorted(os.listdir(truth_path)), start=1):
    print(f'{i} {img_name}')

    # PIL's resize expects (width, height).
    truth = _load_resized(os.path.join(truth_path, img_name), (w, h))
    truth_np = truth.numpy()  # hoisted: reused for every relit test image
    # All-ones LMSE mask built from this image's own shape, so every pixel counts.
    mask = np.ones(truth_np.shape)

    # Generated files are named after the ground-truth file with its first and
    # last 4 characters stripped -- presumably a filename prefix and a
    # 4-character extension such as ".png"; TODO confirm against the dataset.
    stem = img_name[4:-4]
    pattern = os.path.join(test_path, f'{stem}_output_image*')
    test_paths = sorted(glob.glob(pattern))
    if not test_paths:
        raise FileNotFoundError(f'no generated images match {pattern}')

    lpips_temp = []
    mse_temp = []
    lmse_temp = []

    #compute metrics in comparison to each of the test images relit to different intensities
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for path in test_paths:
            test = _load_resized(path, (w, h))
            lpips_temp.append(torch.squeeze(lpips_metric(test, truth)).item())
            mse_temp.append(torch.squeeze(mse_metric(test, truth)).item())
            lmse_temp.append(local_error(truth_np, test.numpy(), mask, 20, 10))

    #use the best relit version for each metric (the minimum is taken per metric independently)
    lpips_all.append(min(lpips_temp))
    mse_all.append(min(mse_temp))
    lmse_all.append(min(lmse_temp))

if not lpips_all:
    raise ValueError(f'no ground-truth images found in {truth_path}')

lpips_avg = sum(lpips_all)/len(lpips_all)
mse_avg = sum(mse_all)/len(mse_all)
lmse_avg = sum(lmse_all)/len(lmse_all)

print(f'lpips: {lpips_avg}')
print(f'mse: {mse_avg}')
print(f'lmse: {lmse_avg}')

In [None]:
# Compute FID50k between a folder of generated images and a folder of real
# images, using clean-fid's "clean" mode.
test_folder = 'evaldata/e4e_50k'  # 50k generated images
truth_folder = 'evaldata/real_bedrooms_50k'  # 50k real bedroom images
fid_score = fid.compute_fid(
    test_folder,
    truth_folder,
    mode='clean',
)
print(f'FID50k: {fid_score}')