# Setup

## Imports

In [None]:
import sys
import icecream
import torch
import json
import rp

In [None]:
rp.pip_import('lpips') # https://pypi.org/project/lpips/
import lpips

In [None]:
sys.path.append('./translator')
from translator.easy_translator import EasyTranslator
from translator.pytorch_msssim import numpy_msssim

In [None]:
from IPython.display import clear_output
from IPython.display import Video

## Other Setup

In [None]:
# Evaluate on a GPU when one is available, otherwise fall back to the CPU.
# To force CPU evaluation, uncomment:  device = torch.device('cpu')
GPU_INDEX = 0  # Choose a free GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Only touch the CUDA runtime when CUDA exists; set_device raises on CPU-only machines.
    torch.cuda.set_device(GPU_INDEX)

In [None]:
# Render matplotlib figures inline, at high (retina) resolution.
%matplotlib inline
%config InlineBackend.figure_format='retina'

## Path Configuration

In [None]:
# List the translation output folders (TEST_OUT*) under ./untracked that VARIATION_NAME can point to.
!ls untracked | grep TEST_OUT

In [None]:
####  NOTE: Generate these variations with translator_tester.ipynb

In [None]:
# Name of the translation output folder (under ./untracked) to evaluate.
# Other variations evaluated previously (swap one in to re-evaluate it):
#   'TEST_OUT__alphabet_five_base__pure_munit'
#   'TEST_OUT__alphabet_three_base__only_texture'
VARIATION_NAME = 'TEST_OUT__alphabet_three_base__just_tex_reality__run0'

In [None]:
# Ground-truth test set location, and where the translations being scored were written.
test_root = '/raid/ryan/CleanCode/Datasets/diff_rendering/alphabet_three/test'
translation_folder = f'./untracked/{VARIATION_NAME}'
translation_filetype = 'png'  # file extension of the translated images

In [None]:
# Ground-truth photos, plus the photo -> candidate-translation-names mapping.
photo_folder = rp.path_join(test_root, 'halved_photos')
matches = json.loads(rp.text_file_to_string(rp.path_join(test_root, "matches.json")))

# Read the translation resolution from one sample file. The code assumes all translations
# share it. Take the first matching file in sorted order rather than a random one, so every
# re-run gets the same trans_dims and a stray non-image file can never be picked.
translation_files = sorted(
    path for path in rp.get_all_files(translation_folder)
    if path.endswith('.' + translation_filetype)
)
assert translation_files, 'No .%s translations found in %s' % (translation_filetype, translation_folder)
trans_dims = rp.get_image_file_dimensions(translation_files[0])
icecream.ic(trans_dims);

# Main

####  REMEMBER: Lower LPIPS is better!

In [None]:
loss_fn_alex = lpips.LPIPS(net='alex').to(device)


def _as_lpips_tensor(image):
    """Convert an image into a batched float32 torch tensor scaled to [-1, 1], on `device`."""
    image = rp.as_float_image(rp.as_rgb_image(image))
    image = image * 2 - 1  # [0,1] -> [-1,1], the range LPIPS expects by default
    # [None] adds a leading batch axis. .float() makes the dtype match the float32 LPIPS
    # weights even if rp produced a float64 array.
    return rp.as_torch_image(image)[None].float().to(device)


def perceptual_loss(photo, trans):
    """LPIPS (AlexNet) distance between `photo` and `trans`. Lower is better!

    `photo` is first resized to the dimensions of `trans`, so the two images may
    arrive at different resolutions.

    Returns:
        The LPIPS distance as a Python float.
    """
    photo = rp.cv_resize_image(photo, rp.get_image_dimensions(trans))

    img1 = _as_lpips_tensor(photo)
    img2 = _as_lpips_tensor(trans)
    assert img1.shape == img2.shape

    # Pure evaluation: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        return float(loss_fn_alex(img1, img2))

In [None]:
def load_translation(translation_name):
    """Load translation `translation_name` from translation_folder as a float RGB image.

    The file name is built with rp.with_file_extension(translation_name, translation_filetype).
    """
    path = rp.path_join(
        translation_folder,
        rp.with_file_extension(translation_name, translation_filetype),
    )
    return rp.as_rgb_image(rp.as_float_image(rp.load_image(path)))
def load_photo(photo_filename):
    """Load a ground-truth photo from photo_folder, resized to trans_dims, as a float RGB image.

    Only rp.get_file_name(photo_filename) is used, so the photo is always looked up
    inside photo_folder.
    """
    path = rp.path_join(photo_folder, rp.get_file_name(photo_filename))
    resized = rp.cv_resize_image(rp.load_image(path), trans_dims)
    return rp.as_rgb_image(rp.as_float_image(resized))

In [None]:
# Score caches keyed by translation name; the scoring loop fills in any missing entries.
scores_lpips, scores_l2 = {}, {}

In [None]:
# Score every candidate translation against its photo. Results are cached in
# scores_lpips / scores_l2, so re-running this cell only computes missing scores.
# NOTE(review): the caches are keyed by translation name alone. A translation listed
# under several photos is scored only against the first photo it appears with.
wipe_line = '\r' + ' ' * 100 + '\r'  # blank the current console line (presumably the rp.eta output)
for photo_filename in matches:
    print(wipe_line + photo_filename)
    photo = load_photo(photo_filename)
    translation_names = matches[photo_filename]
    display_eta = rp.eta(len(translation_names))
    for i, translation_name in enumerate(translation_names):
        display_eta(i)
        if translation_name in scores_lpips and translation_name in scores_l2:
            continue  # fully cached: skip the image load entirely
        trans = load_translation(translation_name)
        if translation_name not in scores_lpips:
            scores_lpips[translation_name] = perceptual_loss(photo, trans)
        if translation_name not in scores_l2:
            # Mean squared per-pixel error; stored as a plain float like the LPIPS scores.
            scores_l2[translation_name] = float(((photo - trans) ** 2).mean())
print(wipe_line + 'DONE!')

In [None]:
# Summarize the range and count of the cached scores (lower is better for both metrics).
# NOTE(review): min()/max() raise ValueError if the scoring loop has not filled the caches yet.
icecream.ic(
    min(scores_lpips.values()),
    max(scores_lpips.values()),
    len(scores_lpips),
    min(scores_l2.values()),
    max(scores_l2.values()),
    len(scores_l2),
);

NOTES:
Unsurprisingly, L2 fails to find the correct permutation more often than LPIPS does.
That being said, even when searching for the best L2, its corresponding LPIPS score is still .1x -- far better than any of the other translation methods.
Let's do a flex here: let's force my method to stick to a single permutation for all samples, but let the other algorithms use a different permutation each. In addition, let's compare histograms of the score distributions (best scores only, meaning 14 datapoints per method). IN FACT: The minimum of one method is LARGER than the average of another! (((I BET IT'S BECAUSE OF THE TABLE. MASK THE TABLE!)))

## TODO: Compare to reconstructions! Mask out the cubes!

In [None]:
display_line_graphs = False  # True: also plot each photo's sorted score distribution (bokeh)
print("Displaying individual, uncoordinated best matchings")

def get_best_translation_name(photo_file, scores=None, rank=0):
    """Return the candidate translation of `photo_file` with the `rank`-th lowest score.

    Args:
        photo_file: key of `matches`; its value lists the candidate translation names.
        scores: dict mapping translation name -> score (lower is better). Defaults to
            the global `scores_lpips`, looked up at call time rather than bound at
            definition time, so re-creating that dict in another cell is picked up.
        rank: 0 for the best match, 1 for the second best, etc.

    Raises:
        KeyError: if a candidate has no entry in `scores`.
        IndexError: if `rank` is not smaller than the number of distinct candidates.
    """
    if scores is None:
        scores = scores_lpips
    candidates = dict.fromkeys(matches[photo_file])  # de-duplicate, preserving order
    # sorted() is stable, so ties keep their order in matches[photo_file].
    return sorted(candidates, key=lambda trans_name: scores[trans_name])[rank]

for photo_file in matches:
    best_trans_name = get_best_translation_name(photo_file)
    score_lpips, score_l2 = scores_lpips[best_trans_name], scores_l2[best_trans_name]

    photo, trans = load_photo(photo_file), load_translation(best_trans_name)

    # Panels: photo | translation | absolute difference averaged over axis 2 (the channels).
    panels = rp.horizontally_concatenated_images(photo, trans, abs(photo - trans).mean(2))
    caption = (
        rp.get_file_name(photo_file) + ' : ' + best_trans_name
        + ',  LPIPS = %.3f,  L2 = %.4f' % (score_lpips, score_l2)
    )
    info_image = rp.labeled_image(panels, caption, size=20)

    print(photo_file, best_trans_name, score_lpips)
    rp.display_image(info_image)

    if not display_line_graphs:
        continue

    candidates = matches[photo_file]
    rp.line_graph_via_bokeh(
        {
            'LPIPS': sorted(scores_lpips[name] for name in candidates),
            'L2': sorted(scores_l2[name] for name in candidates),
        },
        title='Score Distribution',
        logy=True,
    )


# Histogram Test

In [None]:

# import numpy as np
# from bokeh.io import show, output_notebook
# from bokeh.plotting import figure
# output_notebook()

# data = np.random.normal(0, 0.5, 1000)
# hist, edges = np.histogram(data, density=True, bins=10)

# p = figure()
# p.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:], line_color="white")

# show(p)

