# Setup

## Imports

In [None]:
import sys
import icecream
import torch
import rp

In [None]:
rp.pip_import('lpips') # https://pypi.org/project/lpips/
import lpips

In [None]:
sys.path.append('./translator')
from translator.easy_translator import EasyTranslator
from translator.pytorch_msssim import numpy_msssim

In [None]:
from IPython.display import clear_output
from IPython.display import Video

## Other Setup

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'  # Uncomment to force CPU execution

# Selecting a CUDA device only makes sense when CUDA is present; calling
# torch.cuda.set_device unconditionally breaks the CPU fallback above.
if torch.cuda.is_available():
    torch.cuda.set_device(0)  # Choose a free GPU

In [None]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

# Load Trainer/Data/Config

In [None]:
# Which pair of trained model versions to compare. The version names select
# the checkpoint folder and config file used by get_translator.
# Known suffixes: '' , '_256' , '_512' (presumably the training resolution -
# TODO confirm). Change VERSION_SUFFIX to switch pairs.
VERSION_SUFFIX = '_512'

VERSION_ONE_NAME = 'three_synth_base' + VERSION_SUFFIX
VERSION_TWO_NAME = 'three_synth_base__no_texture' + VERSION_SUFFIX

# Label values handed to EasyTranslator when the translators are built.
label_values = [0,75,150,255]

# Image folders: one scene folder per translator, plus the ground-truth photos.
scene_folder_path_one = './datasets/three_synth/scenes_anim_steveflipped'
scene_folder_path_two = './datasets/three_synth/scenes_anim'
photo_folder_path     = './datasets/three_synth/photos_anim'

In [None]:
def get_translator(version_name):
    """Build an EasyTranslator for the trained model named ``version_name``.

    The checkpoint folder and the YAML config path are both derived from
    ``version_name``; ``label_values`` and ``device`` are read from the
    notebook's global scope.
    """
    translator_root = './translator'
    config_file = f'{translator_root}/configs/{version_name}.yaml'
    checkpoint_folder = f'{translator_root}/trained_models/outputs/{version_name}/checkpoints'

    return EasyTranslator(label_values, checkpoint_folder, config_file, device)

# Author's note: both translators run in inference mode, so they should not
# need much VRAM and can be kept loaded at the same time.
translator_one, translator_two = (
    get_translator(version_name)
    for version_name in (VERSION_ONE_NAME, VERSION_TWO_NAME)
)

# Wrap each translate method with rp.memoized so repeated calls can reuse
# earlier results (the author was unsure whether this actually speeds up reruns).
for _translator in (translator_one, translator_two):
    _translator.translate = rp.memoized(_translator.translate)
del _translator

In [None]:
# Wrap each image folder in an rp.ImageDataset; they are later indexed by
# position and measured with len().
scene_images_one, scene_images_two, photo_images = map(
    rp.ImageDataset,
    (scene_folder_path_one, scene_folder_path_two, photo_folder_path),
)

In [None]:
# Print the size of each dataset as a quick sanity check.
icecream.ic(
    len(scene_images_one),
    len(scene_images_two),
    len(photo_images    ),
)

# Result(index) reads the same index from all three datasets, so they must be
# the same size. Raise explicitly (instead of a bare assert) so the check
# survives `python -O` and the error reports the offending sizes.
if not (len(scene_images_one) == len(scene_images_two) == len(photo_images)):
    raise ValueError(
        'Dataset sizes differ: scene_images_one=%i, scene_images_two=%i, photo_images=%i' % (
            len(scene_images_one),
            len(scene_images_two),
            len(photo_images),
        )
    )

# Number of frames available for evaluation.
length = len(photo_images)

In [None]:
# LPIPS model with the 'alex' backbone, built once and shared by every
# perceptual_loss call.
loss_fn_alex = lpips.LPIPS(net='alex')

def perceptual_loss(img1, img2):
    """Return the LPIPS value between ``img1`` and ``img2`` as a Python float.

    Each image is converted to an RGB float image, rescaled from [0,1] to
    [-1,1], turned into a torch image with a leading batch axis, and passed
    to ``loss_fn_alex``.

    NOTE(review): assumes both images have the same height and width and that
    the float images lie in [0,1] - confirm against callers.
    """
    def to_lpips_input(image):
        # RGB float image, shifted to the [-1,1] range.
        image = rp.as_float_image(rp.as_rgb_image(image))
        image = image*2-1 # [0,1] -> [-1,1]
        return rp.as_torch_image(image)[None] # [None] adds the batch axis

    # This is evaluation only, so skip autograd bookkeeping.
    with torch.no_grad():
        return float(loss_fn_alex(to_lpips_input(img1), to_lpips_input(img2)))

In [None]:
class Result:
    """Evaluation of both translators on a single dataset index.

    Constructing a ``Result`` translates the two scene images at ``index``,
    masks both translations and the ground-truth photo, compares each
    translation to the photo with L1, MSSSIM and LPIPS, and renders one
    labelled comparison frame.

    Attributes set by ``__init__``:
        index           - the dataset index that was evaluated
        scene_image_one - scene image for translator one, after scaled_input
        scene_image_two - scene image for translator two, after scaled_input
        photo_image     - ground-truth photo, after scaled_input and masking
        mask            - boolean mask applied before computing the metrics
        l1_loss_one/two - mean absolute difference to the photo
        msssim_one/two  - numpy_msssim score against the photo
        lpips_one/two   - perceptual_loss value against the photo
        output_frame    - the rendered comparison image

    Reads the notebook globals scene_images_one, scene_images_two,
    photo_images, translator_one, translator_two, VERSION_ONE_NAME and
    VERSION_TWO_NAME.
    """

    def __init__(self, index):

        scene_image_one = rp.as_float_image(scene_images_one[index])
        scene_image_two = rp.as_float_image(scene_images_two[index])
        photo_image     = rp.as_float_image(photo_images    [index])

        translation_one = translator_one.translate(scene_image_one)
        translation_two = translator_two.translate(scene_image_two)

        # Bring the inputs through each translator's scaled_input so they can
        # be compared with the translations (presumably a resize to the
        # translator's working resolution - TODO confirm).
        photo_image     = translator_one.scaled_input(photo_image)
        scene_image_one = translator_one.scaled_input(scene_image_one)
        scene_image_two = translator_two.scaled_input(scene_image_two)

        # True wherever the blue channel of scene one is below .99.
        # Original author's note: "White everywhere except the table".
        mask = scene_image_one[:,:,2]<.99
        # mask = mask | True # Uncomment this line to disable the mask

        # Multiply out-of-place: translate() is memoized, so an in-place `*=`
        # would write the mask into the cached arrays (and possibly into the
        # dataset's own images).
        channel_mask    = mask[:,:,None]
        translation_one = translation_one * channel_mask
        translation_two = translation_two * channel_mask
        photo_image     = photo_image     * channel_mask

        difference_one = abs(photo_image-translation_one)
        difference_two = abs(photo_image-translation_two)

        l1_loss_one = difference_one.mean()
        l1_loss_two = difference_two.mean()

        msssim_one = numpy_msssim(photo_image,translation_one,normalize=True)
        msssim_two = numpy_msssim(photo_image,translation_two,normalize=True)

        lpips_one = perceptual_loss(photo_image,translation_one)
        lpips_two = perceptual_loss(photo_image,translation_two)

        LOSS_BRIGHTNESS = 2 #How much do we multiply the loss by in the images?

        def indicator(boolean):
            #Puts a * next to the better metric
            return '*' if boolean else ' '

        def panel(image, label):
            #One labelled cell of the comparison grid
            return rp.labeled_image(image, label, size=20)

        # Text panel: one row per version for each metric, with '*' marking
        # the better value (lower for L1/LPIPS, higher for MSSSIM).
        metric_row = "   %s %.5f : %s"
        metrics_template = '\n'.join(
            [
                "Index: %i",
                "",
                "L1 Loss:   ",
                metric_row,
                metric_row,
                "",
                "MSSSIM:   ",
                metric_row,
                metric_row,
                "",
                "LPIPS:   ",
                metric_row,
                metric_row,
            ]
        )
        metrics_text = metrics_template % (
            index,
            indicator(l1_loss_one < l1_loss_two), l1_loss_one, VERSION_ONE_NAME,
            indicator(l1_loss_two < l1_loss_one), l1_loss_two, VERSION_TWO_NAME,
            indicator(msssim_one  > msssim_two ), msssim_one , VERSION_ONE_NAME,
            indicator(msssim_two  > msssim_one ), msssim_two , VERSION_TWO_NAME,
            indicator(lpips_one   < lpips_two  ), lpips_one  , VERSION_ONE_NAME,
            indicator(lpips_two   < lpips_one  ), lpips_two  , VERSION_TWO_NAME,
        )
        metrics_panel = rp.resize_image_to_fit(
            rp.cv_text_to_image(metrics_text, scale=1),
            *rp.get_image_dimensions(photo_image),
        )

        # Grid layout: summary row, then one row per translator showing
        # input scene / translation / brightened absolute error.
        comparison_grid = rp.grid_concatenated_images(
            [
                [
                    metrics_panel,
                    panel(photo_image, 'Ground Truth'),
                    panel(mask, 'Mask'),
                ],
                [
                    panel(scene_image_one, 'Untranslated UVL Scene'),
                    panel(translation_one, VERSION_ONE_NAME),
                    panel(
                        rp.as_grayscale_image(difference_one)*LOSS_BRIGHTNESS,
                        'Ground Truth VS '+VERSION_ONE_NAME,
                    ),
                ],
                [
                    panel(scene_image_two, 'Untranslated UVL Scene'),
                    panel(translation_two, VERSION_TWO_NAME),
                    panel(
                        rp.as_grayscale_image(difference_two)*LOSS_BRIGHTNESS,
                        'Ground Truth VS '+VERSION_TWO_NAME,
                    ),
                ],
            ]
        )

        output_frame = rp.labeled_image(
            comparison_grid,
            'Translation Comparisons',
            size=50,
            text_color=(255,128,255),
        )

        self.index           = index
        self.scene_image_one = scene_image_one
        self.scene_image_two = scene_image_two
        self.photo_image     = photo_image
        self.mask            = mask
        self.l1_loss_one     = l1_loss_one
        self.l1_loss_two     = l1_loss_two
        self.msssim_one      = msssim_one
        self.msssim_two      = msssim_two
        self.lpips_one       = lpips_one
        self.lpips_two       = lpips_two
        self.output_frame    = output_frame

In [None]:
# Preview: build the comparison frame for one randomly chosen index and show it.
index = rp.random_index(length)
preview_frame = Result(index).output_frame
rp.display_image(preview_frame)

In [None]:
# Only every FRAME_STRIDE-th dataset index is evaluated and written to the video.
FRAME_STRIDE = 5

output_video_path   = 'untracked/eval_megavideo__%s__vs__%s.mp4' % (VERSION_ONE_NAME, VERSION_TWO_NAME)
output_video_writer = rp.VideoWriterMP4(output_video_path, video_bitrate='max')

display_eta = rp.eta(length, title='Writing to %s:'%output_video_path)

# Per-frame metric histories, one entry per evaluated index
# (consumed by the summary and graph cells below).
l1_loss_one_vals = []
l1_loss_two_vals = []
msssim_one_vals  = []
msssim_two_vals  = []
lpips_one_vals   = []
lpips_two_vals   = []

try:
    for index in range(0, length, FRAME_STRIDE):
        # display_eta(index)  # progress display is currently disabled

        result = Result(index)

        l1_loss_one_vals.append(result.l1_loss_one)
        l1_loss_two_vals.append(result.l1_loss_two)
        msssim_one_vals .append(result.msssim_one )
        msssim_two_vals .append(result.msssim_two )
        lpips_one_vals  .append(result.lpips_one  )
        lpips_two_vals  .append(result.lpips_two  )

        output_video_writer.write_frame(result.output_frame)
finally:
    # Always finalize the writer, so a failure part-way through the loop
    # does not leave the video unfinished.
    output_video_writer.finish()

clear_output()
print("Done! Download video from", output_video_path)

In [None]:
# Print the mean of every metric list collected by the video-writing loop.
# icecream.ic echoes each argument's source text next to its value, so the
# expressions below double as the printed labels.
# The trailing semicolon presumably keeps the notebook from also echoing
# ic()'s return value as the cell output.
icecream.ic(
    rp.mean(l1_loss_one_vals),
    rp.mean(l1_loss_two_vals),
    rp.mean(msssim_one_vals ),
    rp.mean(msssim_two_vals ),
    rp.mean(lpips_one_vals  ),
    rp.mean(lpips_two_vals  ),
);

In [None]:
#Show some graphs of the metrics over the sampled frames of the video
#
# The metric lists hold one value per evaluated frame, and the video loop only
# evaluates every 5th dataset index, so the x axis is labelled as a sample
# number rather than a frame number (assuming the graph plots each list
# against its position - TODO confirm).

rp.line_graph_via_bokeh(
    dict(
        l1_loss_one = l1_loss_one_vals,
        l1_loss_two = l1_loss_two_vals,
    ),
    title = 'L1 Loss (Lower is better)',
    xlabel = 'Sample Number (every 5th frame)',
    ylabel = 'Loss',
    logy=10,
)

# MSSSIM is a similarity (higher is better), so its y axis is not a loss.
rp.line_graph_via_bokeh(
    dict(
        msssim_one  = msssim_one_vals ,
        msssim_two  = msssim_two_vals ,
    ),
    title = 'MSSSIM (Multiscale Structural Image Similarity - Higher is better)',
    xlabel = 'Sample Number (every 5th frame)',
    ylabel = 'Similarity',
    logy=10,
)

rp.line_graph_via_bokeh(
    dict(
        lpips_one   = lpips_one_vals  ,
        lpips_two   = lpips_two_vals  ,
    ),
    title = 'LPIPS (Perceptual Loss - Lower is better)',
    xlabel = 'Sample Number (every 5th frame)',
    ylabel = 'Loss',
    logy=10,
)