# Setup

## Imports

In [None]:
import sys
import icecream
import torch
import rp

In [None]:
rp.pip_import('lpips') # https://pypi.org/project/lpips/
import lpips

In [None]:
sys.path.append('./translator')
from translator.easy_translator import EasyTranslator
from translator.pytorch_msssim import numpy_msssim

In [None]:
from IPython.display import clear_output
from IPython.display import Video

## Other Setup

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# To force CPU inference, uncomment the next line:
# device = torch.device('cpu')

# Pin CUDA work to GPU 0 (change the index to choose a free GPU).
# Only do this when CUDA is actually in use: torch.cuda.set_device raises on
# machines without a usable CUDA device, which would defeat the CPU fallback.
if device.type == 'cuda':
    torch.cuda.set_device(0)

In [None]:
# Show matplotlib figures inline in the notebook, using the 'retina' (high-DPI) figure format.
%matplotlib inline
%config InlineBackend.figure_format='retina'

# Load Trainer/Data/Config

In [None]:
# Name of the trained run to evaluate. get_translator uses it to locate both
# the run's checkpoint folder and its YAML config.
VERSION_NAME = 'alphabet_three_base__just_tex_reality__run0'

# Label values handed to EasyTranslator together with the checkpoint/config.
label_values = [0, 75, 150, 255]

# Folder holding the input scene images to translate
# (the folder name suggests UV/label EXR renders of the test split — confirm).
scene_folder_path = (
    '/home/Anonymous/CleanCode/Datasets/diff_rendering'
    '/alphabet_three/test/test_uvl/UV_Label_Exr'
)

In [None]:
def get_translator(version_name):
    """Build an EasyTranslator for the trained run named ``version_name``.

    The run's checkpoints are read from
    ``./translator/trained_models/outputs/<version_name>/checkpoints`` and its
    config from ``./translator/configs/<version_name>.yaml``. The module-level
    ``label_values`` and ``device`` are passed through unchanged.
    """
    return EasyTranslator(
        label_values,
        './translator/trained_models/outputs/%s/checkpoints' % version_name,
        './translator/configs/%s.yaml' % version_name,
        device,
    )

#A single translator is built here; per the original note, inference mode should keep its VRAM use modest
translator = get_translator(VERSION_NAME)

#Wrap translate with rp.memoized so repeated calls on the same input can reuse the earlier result
#(e.g. re-running the preview cell on the same index).
#NOTE(review): presumably every distinct input/output pair stays cached for the kernel's lifetime,
#so translating a whole dataset keeps all outputs in memory - confirm against rp.memoized.
translator.translate = rp.memoized(translator.translate)

In [None]:
# Indexable image dataset over scene_folder_path (queried with len() and [index] below).
scene_images = rp.ImageDataset(scene_folder_path)

In [None]:
# Debug-print the number of scene images.
icecream.ic(
    len(scene_images),
)

# Dataset size; used below to pick a random preview index.
length = len(scene_images)

In [None]:
# LPIPS perceptual-distance model using the AlexNet backbone.
loss_fn_alex = lpips.LPIPS(net='alex')
def perceptual_loss(img1, img2):
    """Return the LPIPS (AlexNet) perceptual distance between two images as a float.

    Each image is converted by rp to a float RGB image in [0, 1], rescaled to
    [-1, 1] (the range LPIPS expects), turned into a batched torch tensor and
    scored with ``loss_fn_alex``.

    NOTE(review): both images presumably must share the same height/width — confirm.
    """
    img1 = rp.as_float_image(rp.as_rgb_image(img1))
    img2 = rp.as_float_image(rp.as_rgb_image(img2))
    
    img1 = img1*2-1 # [0,1] -> [-1,1]
    img2 = img2*2-1 # [0,1] -> [-1,1]
    
    # Add a batch dimension and cast to float32 so the inputs match the dtype of
    # the LPIPS weights, in case rp yields float64 data.
    img1 = rp.as_torch_image(img1)[None].float()
    img2 = rp.as_torch_image(img2)[None].float()
    
    # Scoring only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        return float(loss_fn_alex(img1, img2))

In [None]:
class Result:
    """Translation of one scene image taken from ``scene_images``.

    Attributes (all set in ``__init__``):
        scene_image:  ``scene_images[index]`` converted with ``rp.as_float_image``.
        translation:  ``translator.translate(scene_image)``.
        output_frame: the frame to display/save; currently the translation itself.
    """

    def __init__(self, index):
        self.scene_image = rp.as_float_image(scene_images[index])
        # The output frame is simply the translated image (same object).
        self.translation = self.output_frame = translator.translate(self.scene_image)

In [None]:
# Spot-check: translate one randomly chosen scene image and show its output frame.
index = rp.random_index(length)
preview_frame = Result(index=index).output_frame
rp.display_image(preview_frame)

In [None]:
# Translated frames for this run are written under untracked/TEST_OUT__<VERSION_NAME>.
output_folder = rp.path_join('untracked', f'TEST_OUT__{VERSION_NAME}')
rp.make_directory(output_folder)

In [None]:
# Translate every entry of scene_folder_path and save each result as
# <output_folder>/<input file name without extension>.png.
#
# A single name-sorted directory listing is used, so files are processed in a
# deterministic order. All paths share the same folder prefix, so sorting the
# names is equivalent to sorting the joined paths.
all_files = [
    rp.path_join(scene_folder_path, file_name)
    for file_name in sorted(rp.os.listdir(scene_folder_path))
]
display_eta = rp.eta(len(all_files))
for i, path in enumerate(all_files):
    image = rp.as_float_image(rp.load_image(path))
    # NOTE(review): translator.translate is wrapped in rp.memoized, so each
    # output is presumably kept in the cache as well — watch memory on large folders.
    translated = translator.translate(image)
    # get_file_name(path, False): file name with its extension dropped.
    output_path = rp.path_join(output_folder, rp.get_file_name(path, False) + '.png')
    rp.save_image(translated, output_path)
    display_eta(i)