In [None]:
import lpips
import os
from PIL import Image 
import torch 
from rewrite_utils import renormalize, show
import glob 

# Perceptual distance from the `lpips` package, constructed with net='alex'.
lpips_metric = lpips.LPIPS(net='alex')
# Pixel-wise mean-squared-error criterion from torch (constructed with default arguments).
mse_metric = torch.nn.MSELoss()

#local mean squared error from intrinsic image algorithms
import numpy as np

def local_error(correct, estimate, mask, window_size, window_shift):
    """Return the local (windowed) scale-invariant squared error of `estimate`.

    The inputs are 3-D arrays whose first axis is kept whole and whose last
    two axes are tiled by square windows of side `window_size`, spaced
    `window_shift` apart. Inside every window the estimate may be rescaled
    (see `ssq_error`) to minimise the masked sum-squared error. The summed
    window errors are normalised by the summed masked energy of `correct`
    over the *same* windows.

    Args:
        correct: reference array, indexed [channel, row, col].
        estimate: array to score, same shape as `correct`.
        mask: weights with the same shape; entries of 0 exclude a pixel.
        window_size: side length of each square window.
        window_shift: stride between consecutive windows (must be > 0).

    Returns:
        Sum of per-window errors divided by the sum of per-window masked
        energy of `correct`.

    Raises:
        ValueError: if no window of the requested size fits in the image.
        AssertionError: if the result is NaN (zero error over zero energy).
    """
    M, N = correct.shape[1:]
    if M < window_size or N < window_size:
        raise ValueError(
            f"window_size={window_size} does not fit in an image of size {M}x{N}"
        )

    ssq = total = 0.
    for i in range(0, M - window_size + 1, window_shift):
        for j in range(0, N - window_size + 1, window_shift):
            rows = slice(i, i + window_size)
            cols = slice(j, j + window_size)
            correct_curr = correct[:, rows, cols]
            estimate_curr = estimate[:, rows, cols]
            mask_curr = mask[:, rows, cols]
            ssq += ssq_error(correct_curr, estimate_curr, mask_curr)
            # The normaliser has to be accumulated for every window, exactly
            # like the numerator; accumulating it once per row (outside this
            # inner loop) under-counts it and inflates the score.
            total += np.sum(mask_curr * correct_curr**2)

    result = ssq / total
    assert not np.isnan(result)
    return result

def ssq_error(correct, estimate, mask):
    """Masked sum-squared error after optimally rescaling `estimate`.

    A single scalar `alpha` is fitted by least squares over every masked
    entry, so all channels share the same scale factor. When the masked
    energy of `estimate` is at most 1e-5 the fit is skipped and alpha is 0.
    `correct` must be a 3-D array.
    """
    assert correct.ndim == 3
    estimate_energy = np.sum(estimate**2 * mask)
    if estimate_energy > 1e-5:
        alpha = np.sum(correct * estimate * mask) / estimate_energy
    else:
        alpha = 0.
    residual = correct - alpha*estimate
    return np.sum(mask * residual ** 2)

def pilim(img):
    """Pass `img` through `renormalize.from_image` and return the result.

    NOTE(review): callers in this notebook use the result like a torch
    tensor (`.shape`, `.numpy()`); confirm against `rewrite_utils`.
    """
    converted = renormalize.from_image(img)
    return converted


In [None]:
# Score every ground-truth image against all of its generated candidates and
# keep the best (lowest) value of each metric. Each metric takes its own
# minimum, so the three reported values may come from different candidates.
test_path = "evaldata/ada_lonoff"
truth_path = "datasets/lonoff_light_all/test_B/"

lpips_all = []
mse_all = []
lmse_all = []

# Every image is resized to this (width, height) before comparison.
w, h = (256, 256)

for i, img_path in enumerate(os.listdir(truth_path), start=1):
    print(f'{i} {img_path}')
    truth = pilim(Image.open(os.path.join(truth_path, img_path)).resize((w, h), Image.BILINEAR))
    # Drop the first 4 and last 4 characters of the file name; the remainder
    # is the stem used to find the matching outputs under `test_path`.
    img_path = img_path[4:-4]

    lpips_temp = []
    mse_temp = []
    lmse_temp = []
    mask = np.ones(truth.shape)

    for path in glob.glob(os.path.join(test_path, f'{img_path}_output_image*')):
        test = pilim(Image.open(path).resize((w, h), Image.BILINEAR))
        # Pure evaluation: no autograd graph is needed for these scalars.
        with torch.no_grad():
            lpips_temp.append(torch.squeeze(lpips_metric(test, truth)).item())
            mse_temp.append(torch.squeeze(mse_metric(test, truth)).item())
        lmse_temp.append(local_error(truth.numpy(), test.numpy(), mask, 20, 10))

    if not lpips_temp:
        # min() on an empty list would abort the whole run with an opaque
        # error; report the missing outputs and leave this image out instead.
        print(f'no outputs matching {img_path}_output_image* in {test_path}; skipping')
        continue

    lpips_all.append(min(lpips_temp))
    mse_all.append(min(mse_temp))
    lmse_all.append(min(lmse_temp))

if not lpips_all:
    raise RuntimeError(f'no image under {truth_path} had outputs in {test_path}')

lpips_avg = sum(lpips_all)/len(lpips_all)
mse_avg = sum(mse_all)/len(mse_all)
lmse_avg = sum(lmse_all)/len(lmse_all)
print(f'lpips: {lpips_avg}')
print(f'mse: {mse_avg}')
print(f'lmse: {lmse_avg}')

In [None]:
# Recompute the mean of each per-image score list and print the three means.
lpips_avg, mse_avg, lmse_avg = (
    sum(scores) / len(scores)
    for scores in (lpips_all, mse_all, lmse_all)
)
print(f'lpips: {lpips_avg}')
print(f'mse: {mse_avg}')
print(f'lmse: {lmse_avg}')

In [None]:
# Spot-check: local error of two images against `truth`, using whatever
# `truth` / `test` / `original` currently hold in the kernel namespace.
# NOTE(review): `truth` and `test` are only assigned inside the loop of an
# earlier cell, and `original` is not assigned by any cell above this one (its
# first visible assignment is in a later cell), so this cell raises NameError
# on a fresh Restart & Run All -- confirm where `original` should come from.
mask = np.ones(truth.shape)
lmse = local_error(truth.numpy(), test.numpy(), mask, 20, 10)
print(lmse)
lmse = local_error(truth.numpy(), original.numpy(), mask, 20, 10)
print(lmse)

In [None]:
import cv2

def get_edges_canny(im):
    """Return a Canny edge map of `im` as a float tensor of shape (3, H, W).

    Args:
        im: an RGB image (the callers in this notebook pass PIL images) or
            anything `np.array` turns into an (H, W, 3) array. Inputs of any
            other shape are reshaped to (256, 256, 3), as the original
            implementation did for every input.

    Returns:
        A float32 tensor with the single-channel edge map scaled to [0, 1]
        and repeated over three channels (the layout passed to the LPIPS
        metric in this notebook).
    """
    pixels = np.array(im)
    if pixels.ndim != 3 or pixels.shape[2] != 3:
        # Legacy fallback kept for inputs that are not already (H, W, 3).
        pixels = pixels.reshape(256, 256, 3)

    # np.array(PIL image) is RGB-ordered, so use the RGB conversion code;
    # the BGR code would swap the red and blue luminance weights.
    gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)

    edges = cv2.Canny(image=np.uint8(gray), threshold1=100, threshold2=200)
    peak = edges.max()
    if peak > 0:
        # Stretch so the strongest response is 255. Skipped when no edge was
        # found, which would otherwise divide by zero.
        edges = cv2.convertScaleAbs(edges, alpha=255/peak)
    # Scale to [0, 1] and replicate the single channel three times.
    return torch.tensor(edges/255).unsqueeze(0).repeat(3, 1, 1).float()


In [None]:
# Compare Canny edge maps of each input image with the edge maps of the
# model output and of the stylespace target (the baseline).
test_path = "results/modulated_bedrooms_9resnet/test_230_generated/images/"

lpips_baseline_all = []
mse_baseline_all = []
lpips_all = []
mse_all = []

num_imgs = 100
for i in range(num_imgs):
    print(i)
    img_path = f"bedroom_{i}"

    # Load the three variants of this image and reduce each to its edge map.
    stylespace, test, original = (
        get_edges_canny(Image.open(os.path.join(test_path, img_path + suffix)))
        for suffix in ("_stylespace_target.jpg", "_output_image.jpg", "_input_image.jpg")
    )

    if i%10 == 0:
        show([(renormalize.as_image(original), renormalize.as_image(test), renormalize.as_image(stylespace))])

    # Evaluation only: without no_grad the stored LPIPS results may keep an
    # autograd graph alive for every image.
    with torch.no_grad():
        lpips_all.append(torch.squeeze(lpips_metric(original, test)))
        lpips_baseline_all.append(torch.squeeze(lpips_metric(original, stylespace)))

        mse_all.append(torch.squeeze(mse_metric(original, test)))
        mse_baseline_all.append(torch.squeeze(mse_metric(original, stylespace)))

lpips_baseline_avg = torch.mean(torch.stack(lpips_baseline_all))
mse_baseline_avg = torch.mean(torch.stack(mse_baseline_all))
# Misspelled alias kept because the reporting cell reads this name.
mse_basline_avg = mse_baseline_avg
lpips_avg = torch.mean(torch.stack(lpips_all))
mse_avg = torch.mean(torch.stack(mse_all))


In [None]:
# Report the averaged edge-map metrics for the model output and the baseline.
print(f'lpips: {lpips_avg}')
print(f'mse: {mse_avg}')
print(f'lpips baseline: {lpips_baseline_avg}')
print(f'mse baseline: {mse_basline_avg}')

# Local error of the last processed image pair: model output first, then the
# stylespace baseline. The generator is lazy, so each value is printed right
# after it is computed.
mask = np.ones(test.shape)
for lmse in (
    local_error(original.numpy(), candidate.numpy(), mask, 20, 10)
    for candidate in (test, stylespace)
):
    print(lmse)

In [None]:
# Side-by-side visual comparison: input, ground truth, then one output per
# method directory listed in `paths`.
# NOTE(review): `alternate_path` and `ablation_path` hold the same string and
# `ablation_path` is not used in this cell -- confirm the intended directory.
alternate_path = "results/modulated_bedrooms_9resnet/test_500_real_datasets_lonoff_light_all_linspace_on/images"
ablation_path ="results/modulated_bedrooms_9resnet/test_500_real_datasets_lonoff_light_all_linspace_on/images"
ada_path = "evaldata/ada_lonoff"
e4e_path = "evaldata/e4e_lonoff_sweep"
truth_path  = "datasets/lonoff_light_all/test_B/"
input_path = "datasets/lonoff_light_all/test_A"

paths = [alternate_path, ada_path, e4e_path]

w, h = (256, 256)

i = 0
for i, img_path in enumerate(os.listdir(truth_path), start=1):
    print(f'{i} {img_path}')
    truth = Image.open(os.path.join(truth_path, img_path)).resize((h, w), Image.BILINEAR)
    # Strip the first 4 and last 4 characters to get the stem shared by the
    # input file and the per-method output files.
    img_path = img_path[4:-4]

    # Row layout: input image, ground truth, then each method's first output.
    imgs = [
        Image.open(os.path.join(input_path, f'{img_path}.jpg')).resize((h, w), Image.BILINEAR),
        truth,
    ]
    imgs.extend(
        Image.open(os.path.join(path, f'{img_path}_output_image_1.jpg')).resize((h, w), Image.BILINEAR)
        for path in paths
    )

    show([tuple(imgs)])
    

