In [None]:
# Widen the notebook container to 80% of the browser window.
# `IPython.display` is the public home of `display`/`HTML`; importing `display`
# from `IPython.core.display` has been deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [None]:
from pathlib import Path
import sys
sys.path.insert(0, "..")
import _pickle as pickle

import cv2
import dlib
import numpy as np
from tqdm import tqdm
from torchvision import transforms
import torch
import matplotlib.pyplot as plt

import data.transformations as transformations
from models.generators import ResnetGenerator, UNetGenerator
from utils import constants, data_utils, general_utils, personal_constants
from data import plot

In [None]:
# Processed per-person videos at the image size configured in `constants`.
DATASET_PERSON_OUTPUT_PATH = Path(f'./local_data/person_processed_dim{constants.DATASET_300VW_IMSIZE}')
# Directory holding the `<model_name>.pickle` checkpoints.
model_path = Path('./local_data/eval/')
# Root directory for generated frames and error plots (one sub-directory per model).
output_path = Path('./local_data/eval/')
# Fall back to CPU instead of failing on the first `.to('cuda')` when no GPU is present.
# NOTE(review): checkpoints pickled from CUDA tensors may still need a GPU to unpickle.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Alternative model sets kept for reference; each maps a checkpoint name to
# (generator class, constructor kwargs). They are NOT used below — copy one into
# `model_name_to_instance_settings` to evaluate it.
MODELS_IMSIZE_64 = {
    'model1': (ResnetGenerator.ResnetGenerator, {'n_hidden': 24, 'use_dropout': False}),
    'hinge1': (UNetGenerator.UNetGenerator, {'n_hidden': 24, 'use_dropout': True}),
}
MODELS_IMSIZE_128 = {
    'stijn1': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    'lossesAll_epoch3': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
}

# Models evaluated by the cells below: checkpoint name -> (generator class, constructor kwargs).
model_name_to_instance_settings = {
    # 'klaus_monday': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    'full_loss': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    'ablated_pp_id': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
}

# Person prefix of the `<person>_test` video directories to evaluate; set to None to use all.
filter_person = 'lady'

In [None]:
def image_to_batch(image: np.ndarray) -> torch.Tensor:
    """Wrap one numpy array as a float32 tensor with a leading batch axis of size 1."""
    tensor = torch.from_numpy(image)
    return tensor.unsqueeze(0).float()


def image_from_batch(batch: torch.Tensor) -> torch.Tensor:
    """Return the first element of a batch, dropping the leading batch axis."""
    first_item = batch[0, ...]
    return first_item


# Preprocessing steps from `data.transformations`, applied in this order to both the
# reference image and every landmark map (their exact effect is defined in that module).
_shared_preprocessing = (
    transformations.Resize._f,
    transformations.RescaleValues._f,
    transformations.ChangeChannels._f,
)

# Raw array -> preprocessed float batch on `device`.
transform_to_input = [
    *_shared_preprocessing,
    image_to_batch,
    # `device` is looked up when this step runs, not when the list is built.
    lambda tensor: tensor.to(device),
]

# Model output batch -> first batch element, then `general_utils` post-processing.
transform_from_input = [image_from_batch, general_utils.de_torch, general_utils.denormalize_picture]


# Test-split directories are named `<person>_test`; optionally keep only `filter_person`.
all_videos = sorted(
    p for p in DATASET_PERSON_OUTPUT_PATH.iterdir()
    if p.is_dir()
    and p.stem.endswith('_test')
    and (filter_person is None or p.stem.split('_')[0] == filter_person)
)
# Without this, a wrong path or a mistyped `filter_person` makes every later cell a silent no-op.
assert all_videos, (
    f'no *_test video directories found in {DATASET_PERSON_OUTPUT_PATH} '
    f'for filter_person={filter_person!r}'
)

In [None]:
# For every model and test video, generate one frame per annotated landmark set.
# The first frame of the same person's `_train` video is the input image, and each
# row of the test video's `annotations.npy` drives one output frame.
# Frames already on disk are skipped, so re-running the cell resumes where it stopped.
model_bar = tqdm(model_name_to_instance_settings.items(), desc='model')
for model_name, (model_class, kwargs) in model_bar:
    model_bar.set_postfix(model=model_name)

    # NOTE: unpickling can execute arbitrary code; only load checkpoints you produced yourself.
    with open(model_path / f'{model_name}.pickle', 'rb') as openfile:
        weights = pickle.load(openfile)
    model = model_class(**kwargs)
    model.load_state_dict(weights['generator'])
    model = model.to(device)
    model = model.eval()

    model_output_path = output_path / model_name
    model_output_path.mkdir(parents=True, exist_ok=True)

    video_bar = tqdm(all_videos, desc='video', leave=False)
    for video_path in video_bar:
        video_bar.set_postfix(model=model_name, video=video_path.stem)

        model_video_path = model_output_path / video_path.stem / 'images'
        model_video_path.mkdir(parents=True, exist_ok=True)
        all_landmarks = np.load(video_path / 'annotations.npy')

        person = video_path.stem.split('_')[0]
        reference_path = video_path.parent / f'{person}_train' / 'images' / '000001.jpg'
        image = cv2.imread(str(reference_path))
        if image is None:
            # cv2.imread returns None instead of raising on a missing/unreadable file.
            raise FileNotFoundError(f'could not read reference image {reference_path}')
        for t in transform_to_input:
            image = t(image)

        image_bar = tqdm(range(1, all_landmarks.shape[0] + 1), desc='image', leave=False)
        for image_index in image_bar:
            image_bar.set_postfix(model=model_name, video=video_path.stem)

            # 1-based, zero-padded names, same scheme as the '000001.jpg' reference frame.
            frame_output_path = model_video_path / f'{image_index:06d}.jpg'
            if frame_output_path.exists():
                continue

            multi_dim_landmarks = data_utils.single_to_multi_dim_landmarks(
                all_landmarks[image_index - 1],
                constants.DATASET_300VW_IMSIZE
            )
            for t in transform_to_input:
                multi_dim_landmarks = t(multi_dim_landmarks)

            output = torch.cat(
                (image, multi_dim_landmarks), dim=constants.CHANNEL_DIM
            )
            # Inference only: skip autograd bookkeeping to save memory and time.
            with torch.no_grad():
                output = model(output)

            for t in transform_from_input:
                output = t(output)

            cv2.imwrite(str(frame_output_path), output, [int(cv2.IMWRITE_JPEG_QUALITY), constants.DATASET_300VW_IMAGE_QUALITY])
    

In [None]:
# Per model and per test video: mean absolute per-pixel error between each ground-truth
# frame and the generated frame with the same file name, averaged over all frames.
# model_to_diffs[model_name][i] belongs to all_videos[i].
model_to_diffs = {}
model_bar = tqdm(model_name_to_instance_settings.items(), desc='model')
for model_name, (model_class, kwargs) in model_bar:
    model_bar.set_postfix(model=model_name)

    model_output_path = output_path / model_name
    diffs = []
    video_bar = tqdm(all_videos, desc='video', leave=False)
    for video_path in video_bar:
        video_bar.set_postfix(model=model_name, video=video_path.stem)

        model_video_path = model_output_path / video_path.stem
        all_images = sorted((video_path / 'images').glob('*.jpg'))
        if not all_images:
            raise FileNotFoundError(f'no ground-truth frames found in {video_path / "images"}')

        error_sum = None
        image_bar = tqdm(all_images, desc='image', leave=False)
        for ground_truth_image_path in image_bar:
            image_bar.set_postfix(model=model_name, video=video_path.stem)

            model_image_path = model_video_path / 'images' / ground_truth_image_path.name
            ground_truth_image = cv2.imread(str(ground_truth_image_path))
            model_image = cv2.imread(str(model_image_path))
            # cv2.imread returns None rather than raising when a file is missing/unreadable.
            if ground_truth_image is None or model_image is None:
                raise FileNotFoundError(
                    f'could not read {ground_truth_image_path} or {model_image_path}'
                )

            # Cast to float before subtracting: cv2.imread yields uint8 arrays, and
            # subtracting those directly wraps around (e.g. 10 - 20 -> 246).
            ground_truth_image = transformations.Resize._f(ground_truth_image).astype(np.float64)
            model_image = transformations.Resize._f(model_image).astype(np.float64)
            abs_error = np.abs(ground_truth_image - model_image)

            if error_sum is None:
                error_sum = np.zeros_like(abs_error)
            error_sum += abs_error

        diffs.append(error_sum / len(all_images))

    diffs = np.asarray(diffs)
    model_to_diffs[model_name] = diffs

In [None]:
# Plot each model's per-video error map (averaged over the last axis, presumably the
# colour channels) and save it as `<model>_<video>_error.png` in the model's output dir.
model_bar = tqdm(model_name_to_instance_settings.items(), desc='model')
for model_name, (model_class, kwargs) in model_bar:
    model_bar.set_postfix(model=model_name)

    model_output_path = output_path / model_name
    model_output_path.mkdir(parents=True, exist_ok=True)

    video_bar = tqdm(all_videos, desc='video', leave=False)
    for video_index, video_path in enumerate(video_bar):
        video_bar.set_postfix(model=model_name, video=video_path.stem)

        # Kept as float (no uint8 cast) so fractional errors are not truncated.
        error_map = np.mean(model_to_diffs[model_name][video_index], axis=-1)

        fig, ax = plt.subplots()
        ax.axis('off')
        # Single draw with a fixed 0..255 scale so maps are comparable across models/videos.
        im = ax.imshow(error_map, vmin=0, vmax=255)
        fig.colorbar(im, ax=ax)
        ax.set_title(f'{model_name} {video_path.stem}')
        fig.tight_layout()

        fig_path = model_output_path / f'{model_name}_{video_path.stem}_error.png'
        fig.savefig(str(fig_path))

        # needs to be after savefig otherwise the saved image will be blank
        plt.show()
        plt.close(fig)