In [None]:
# Widen the notebook's content area to 80% of the browser window (classic Notebook UI).
# `display`/`HTML` are public in `IPython.display`; importing them from
# `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [None]:
from pathlib import Path
import sys
sys.path.insert(0, "..")
import _pickle as pickle

import cv2
import dlib
import numpy as np
from tqdm import tqdm
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
import matplotlib

import data.transformations as transformations
from models.generators import ResnetGenerator, UNetGenerator
from utils import constants, data_utils, general_utils, personal_constants
from data import plot

In [None]:
DATASET_PERSON_OUTPUT_PATH = Path(f'./local_data/person_processed_dim{constants.DATASET_300VW_IMSIZE}')
landmarks_file_name = 'annotations.npy'
# Pickled generator checkpoints are read from `model_path`; generated frames, error maps
# and figures are written below `output_path`.
model_path = Path('./local_data/eval/')
output_path = Path('./local_data/eval/')
device = 'cpu'

# Model name -> (generator class, constructor kwargs). The name is also the checkpoint file
# stem (`<model_path>/<name>.pickle`) and the output sub-directory name.
# Earlier experiment sets keep their own names instead of being assigned to
# `model_name_to_instance_settings` and immediately overwritten.

# IMSIZE 64
MODELS_IMSIZE_64 = {
    'model1': (ResnetGenerator.ResnetGenerator, {'n_hidden': 24, 'use_dropout': False}),
    'hinge1': (UNetGenerator.UNetGenerator, {'n_hidden': 24, 'use_dropout': True}),
}
# IMSIZE 128
MODELS_IMSIZE_128 = {
    'stijn1': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    'lossesAll_epoch3': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
}
# Loss-ablation study (full loss vs. variants with one loss group removed)
MODELS_ABLATION = {
    # 'klaus_monday': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    # 'klaus_wednesday': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    'full_loss': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    'ablated_pp_id': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    'ablated_cons': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
    'ablated_pix': (UNetGenerator.UNetGenerator, {'n_hidden': 64, 'use_dropout': True}),
}

# The model set evaluated by all cells below.
model_name_to_instance_settings = MODELS_ABLATION

# set to None to use all
filter_persons = ['borris', 'cameron', 'glasses']

In [None]:
def image_to_batch(image: np.ndarray) -> torch.Tensor:
    """Turn one array into a batch of size 1.

    :param image: array of any shape.
    :return: float32 tensor of shape ``(1, *image.shape)``.
    """
    batched = np.expand_dims(image, axis=0)
    return torch.from_numpy(batched).to(torch.float32)


def image_from_batch(batch: torch.Tensor) -> torch.Tensor:
    """Return the first sample of ``batch`` (index 0 along dim 0); other samples are ignored."""
    return batch.select(0, 0)


# Forward pipeline: raw cv2 image (or multi-channel landmark map) -> model-ready batch on
# `device`. What each `transformations.*._f` step does is defined in `data.transformations`.
transform_to_input = [
    transformations.Resize._f,
    transformations.RescaleValues._f,
    transformations.ChangeChannels._f,
    image_to_batch,
    lambda batch: batch.to(device),
]

# Inverse pipeline: model output batch -> image array that is passed to cv2.imwrite.
transform_from_input = [
    image_from_batch,
    general_utils.de_torch,
    general_utils.denormalize_picture,
]


# Test-video directories (`<person>_test`), sorted, optionally restricted to `filter_persons`.
all_videos = sorted(
    directory
    for directory in DATASET_PERSON_OUTPUT_PATH.iterdir()
    if directory.is_dir()
    and directory.stem.endswith('_test')
    and (filter_persons is None or directory.stem.split('_')[0] in filter_persons)
)

In [None]:
# Run every configured generator over every selected test video and write one JPEG per
# generated frame to `<output_path>/<model>/<video>/images/<frame:06d>.jpg`.
# Each frame is generated from image 000001.jpg of the same person's *training* video,
# concatenated with the landmarks of the corresponding test frame.
# Frames whose output file already exists are skipped, so an interrupted run can resume.
model_bar = tqdm(model_name_to_instance_settings.items(), desc='model')
for model_name, (model_class, kwargs) in model_bar:
    model_bar.set_postfix(model=model_name)

    # NOTE(review): pickle.load can execute arbitrary code; only load checkpoints you
    # produced yourself. (A checkpoint written with torch.save could instead be read with
    # torch.load(path, map_location=device).)
    with open(model_path / f'{model_name}.pickle', 'rb') as openfile:
        weights = pickle.load(openfile)
    model = model_class(**kwargs)
    model.load_state_dict(weights['generator'])
    model = model.to(device).eval()

    model_output_path = output_path / model_name
    model_output_path.mkdir(parents=True, exist_ok=True)

    video_bar = tqdm(all_videos, desc='video', leave=False)
    for video_path in video_bar:
        video_bar.set_postfix(model=model_name, video=video_path.stem)

        model_video_path = model_output_path / video_path.stem / 'images'
        model_video_path.mkdir(parents=True, exist_ok=True)
        all_landmarks = np.load(video_path / landmarks_file_name)

        # Source appearance: image 000001.jpg of this person's training video.
        person = video_path.stem.split('_')[0]
        source_image_path = video_path.parent / f'{person}_train' / 'images' / '000001.jpg'
        image = cv2.imread(str(source_image_path))
        if image is None:
            # cv2.imread returns None (instead of raising) for missing/unreadable files.
            raise FileNotFoundError(f'Could not read source image {source_image_path}')
        for t in transform_to_input:
            image = t(image)

        image_bar = tqdm(range(1, all_landmarks.shape[0] + 1), desc='image', leave=False)
        image_bar.set_postfix(model=model_name, video=video_path.stem)
        for image_index in image_bar:
            frame_output_path = model_video_path / f'{image_index:06d}.jpg'
            if frame_output_path.exists():
                continue

            # Frame numbers are 1-based; landmark rows are 0-based.
            multi_dim_landmarks = data_utils.single_to_multi_dim_landmarks(
                all_landmarks[image_index - 1],
                constants.DATASET_300VW_IMSIZE
            )
            for t in transform_to_input:
                multi_dim_landmarks = t(multi_dim_landmarks)

            output = torch.cat(
                (image, multi_dim_landmarks), dim=constants.CHANNEL_DIM
            )
            # Pure inference: don't build an autograd graph (saves memory and time).
            with torch.no_grad():
                output = model(output)

            for t in transform_from_input:
                output = t(output)

            written = cv2.imwrite(str(frame_output_path), output, [int(cv2.IMWRITE_JPEG_QUALITY), constants.DATASET_300VW_IMAGE_QUALITY])
            if not written:
                # cv2.imwrite reports failure via its return value, not an exception.
                raise OSError(f'Could not write generated frame {frame_output_path}')
            
    

In [None]:
# For each model and test video, compute the mean absolute error map: the per-pixel,
# per-channel |ground truth - generated| averaged over all frames of the video.
# model_to_diffs[model_name] is these maps stacked in `all_videos` order.
model_to_diffs = {}
model_bar = tqdm(model_name_to_instance_settings, desc='model')
for model_name in model_bar:
    model_bar.set_postfix(model=model_name)

    model_output_path = output_path / model_name
    diffs = []
    video_bar = tqdm(all_videos, desc='video', leave=False)
    for video_path in video_bar:
        video_bar.set_postfix(model=model_name, video=video_path.stem)

        model_video_path = model_output_path / video_path.stem
        all_images = sorted((video_path / 'images').glob('*.jpg'))
        if not all_images:
            raise FileNotFoundError(f'No ground-truth frames found in {video_path / "images"}')

        difference_sum = None
        image_bar = tqdm(all_images, desc='image', leave=False)
        image_bar.set_postfix(model=model_name, video=video_path.stem)
        for ground_truth_image_path in image_bar:
            model_image_path = model_video_path / 'images' / ground_truth_image_path.name
            ground_truth_image = cv2.imread(str(ground_truth_image_path))
            model_image = cv2.imread(str(model_image_path))
            if ground_truth_image is None or model_image is None:
                raise FileNotFoundError(
                    f'Could not read {ground_truth_image_path} or {model_image_path}'
                )

            # cv2.imread yields uint8 arrays, and uint8 subtraction wraps around
            # (10 - 20 -> 246) instead of going negative: cast to float first.
            ground_truth_image = transformations.Resize._f(ground_truth_image).astype(float)
            model_image = transformations.Resize._f(model_image).astype(float)
            frame_error = np.abs(ground_truth_image - model_image)

            if difference_sum is None:
                difference_sum = frame_error
            else:
                difference_sum += frame_error

        diffs.append(difference_sum / len(all_images))

    model_to_diffs[model_name] = np.asarray(diffs)

In [None]:
def sigmoid(x):
    """Numerically stable element-wise logistic function ``1 / (1 + exp(-x))``.

    Only ``exp(-|x|)`` is evaluated. That value lies in (0, 1], so large ``|x|`` never
    overflows. (Evaluating both ``exp(x)`` and ``exp(-x)`` for every element produced inf
    and RuntimeWarnings, even though the overflowing branch was discarded by ``np.where``.)

    :param x: scalar or array-like of real numbers.
    :return: ndarray of the same shape as ``x`` with values in [0, 1].
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + z), z / (1 + z))


def rescale_values_custom(value: np.ndarray, range_values = 2) -> np.ndarray:
    """Linearly map values from [0, 255] to [-range_values / 2, range_values / 2].

    Arrays whose last axis has ``constants.DATASET_300VW_N_LANDMARKS`` entries are treated as
    landmark maps and returned unchanged (the same object, no copy).

    :param value: array to rescale.
    :param range_values: width of the output interval.
    :return: float array (or ``value`` itself for landmark maps).
    """
    is_landmark_map = value.shape[-1] == constants.DATASET_300VW_N_LANDMARKS
    if is_landmark_map:
        return value

    half_range = range_values / 2
    return value.astype(float) / 255 * range_values - half_range


# Plot and save one heat map per (model, test video) of the error map in `model_to_diffs`,
# averaged over colour channels, on a fixed 0-255 colour scale:
# `<output_path>/<model>/<model>_<video>_error.png`.
model_bar = tqdm(model_name_to_instance_settings, desc='model')
for model_name in model_bar:
    model_bar.set_postfix(model=model_name)

    model_output_path = output_path / model_name

    video_bar = tqdm(all_videos, desc='video', leave=False)
    for video_index, video_path in enumerate(video_bar):
        video_bar.set_postfix(model=model_name, video=video_path.stem)

        # Keep the channel mean as float: casting to uint8 truncated fractions and wrapped
        # any negative values around to the top of the colour scale.
        error_map = np.mean(model_to_diffs[model_name][video_index], axis=-1)

        fig, ax = plt.subplots()
        ax.axis('off')
        # One imshow with a fixed colour range; previously imshow and pcolor drew the same
        # data twice on one axes.
        im = ax.imshow(error_map, vmin=0, vmax=255)
        fig.colorbar(im, ax=ax)
        ax.set_title(f'{model_name} {video_path.stem}')
        fig.tight_layout()

        fig_path = model_output_path / f'{model_name}_{video_path.stem}_error.png'
        # needs to be before plt.show(), otherwise the saved image will be blank
        fig.savefig(str(fig_path))

        plt.show()
        plt.close(fig)

In [None]:
def plot_image(image_path: Path, subplot_index: int, *args, **kwargs):
    """Load an image, convert its channel order and resize it, then draw it with ``create_subplot``.

    :param image_path: image file to read with cv2.
    :param subplot_index: 1-based subplot position to draw into.
    :param args: forwarded to ``create_subplot`` (n_rows, n_columns, is_first_row, title).
    :param kwargs: forwarded to ``create_subplot``.
    :return: ``(next_subplot_index, displayed_image)`` as returned by ``create_subplot``.
    :raises FileNotFoundError: if the file is missing or cannot be decoded.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread returns None instead of raising; fail with a clear message.
        raise FileNotFoundError(f'Could not read image {image_path}')
    image = general_utils.BGR2RGB_numpy(image)
    image = transformations.Resize._f(image)
    return create_subplot(image, subplot_index, *args, **kwargs)


def plot_landmarks(landmarks_path: Path, subplot_index: int, frame_id: int, *args, **kwargs):
    """Render one frame's landmarks as a single grayscale image and draw it with ``create_subplot``.

    :param landmarks_path: ``.npy`` file holding the landmarks of every frame of a video.
    :param subplot_index: 1-based subplot position to draw into.
    :param frame_id: 1-based frame number (row ``frame_id - 1`` of the landmark array).
    :param args: forwarded to ``create_subplot`` (n_rows, n_columns, is_first_row, title).
    :param kwargs: forwarded to ``create_subplot``; ``is_landmarks=True`` is always set.
    :return: ``(next_subplot_index, displayed_image)`` as returned by ``create_subplot``.
    """
    frame_landmarks = np.load(landmarks_path)[frame_id - 1]
    landmark_maps = transformations.Resize._f(
        data_utils.single_to_multi_dim_landmarks(frame_landmarks, constants.DATASET_300VW_IMSIZE)
    )

    # Collapse the landmark channels into one image and flip its sign before denormalising.
    # The polarity convention belongs to denormalize_picture(binarised=True); TODO confirm.
    flattened = np.sum(landmark_maps, axis=-1)
    flattened *= -1
    flattened = general_utils.denormalize_picture(flattened, binarised=True)

    return create_subplot(flattened, subplot_index, *args, is_landmarks=True, **kwargs)


def create_subplot(image, subplot_index, n_rows, n_columns, is_first_row, title, is_landmarks = False, is_error = False):
    """Draw ``image`` into cell ``subplot_index`` of an ``n_rows`` x ``n_columns`` grid of the current figure.

    Tick marks are hidden. Landmark images use a gray colormap. Error maps use a fixed
    0-255 colour range. Anything else uses imshow defaults. ``title`` is only written when
    ``is_first_row`` is true (column headers).

    :return: ``(subplot_index + 1, image)`` — the next free subplot index and the image as passed in.
    """
    ax = plt.subplot(n_rows, n_columns, subplot_index)
    ax.set_xticks([])
    ax.set_yticks([])

    if is_landmarks:
        imshow_options = {'cmap': 'gray'}
    elif is_error:
        imshow_options = {'vmin': 0, 'vmax': 255}
    else:
        imshow_options = {}
    plt.imshow(image, **imshow_options)

    if is_first_row:
        ax.set_title(title)

    next_index = subplot_index + 1
    return next_index, image



# Comparison figure for the loss ablation. There is one row per (person, test frame). Columns:
# - source landmarks and input image (training frame `ground_truth_frame_id`),
# - input landmarks and target image (test frame),
# - one output column per model in `model_order`, then one error column per model.

# default: 10
matplotlib.rcParams.update({'font.size': 5.25})
# 390, 524
# image_ids = [('lady', 455), ('arnold', 342)]
# image_ids = [('lady', 455), ('arnold', 390), ('beard', 10), ('glasses', 500), ('borris', 206), ('cameron', 1164)]
image_ids = [('glasses', 1193), ('beard', 643), ('cameron', 1297), ('lady', 455)]# , ('borris', 206), ('cameron', 1164)]
model_order = ['ablated_cons', 'ablated_pix', 'ablated_pp_id', 'full_loss']
pretty_names = ['–Cons. Losses', '–Pixel Loss', '–Feature Losses', 'Full Loss']
# zip() below would silently drop models without a display name.
assert len(pretty_names) == len(model_order), 'every model in model_order needs a pretty name'
ground_truth_frame_id = 1

n_rows = len(image_ids)
n_columns = 4 + 2 * len(model_order)
subplot_index = 1
fig = plt.figure(figsize=(10, 4.5), dpi=300)
for current_row, (person, frame_id) in enumerate(image_ids):
    is_first_row = current_row == 0
    video_train_dir = f'{person}_train'
    video_test_dir = f'{person}_test'

    landmarks_path = DATASET_PERSON_OUTPUT_PATH / video_train_dir / landmarks_file_name
    subplot_index, _ = plot_landmarks(landmarks_path, subplot_index, ground_truth_frame_id, n_rows, n_columns, is_first_row, 'Source\nLandmarks')

    image_path = DATASET_PERSON_OUTPUT_PATH / video_train_dir / 'images' / f'{ground_truth_frame_id:06d}.jpg'
    subplot_index, _ = plot_image(image_path, subplot_index, n_rows, n_columns, is_first_row, 'Input\nImage')

    landmarks_path = DATASET_PERSON_OUTPUT_PATH / video_test_dir / landmarks_file_name
    subplot_index, _ = plot_landmarks(landmarks_path, subplot_index, frame_id, n_rows, n_columns, is_first_row, 'Input\nLandmarks')

    target_image_path = DATASET_PERSON_OUTPUT_PATH / video_test_dir / 'images' / f'{frame_id:06d}.jpg'
    subplot_index, target_image = plot_image(target_image_path, subplot_index, n_rows, n_columns, is_first_row, 'Target\nImage')

    model_images = []
    for model_name, pretty_name in zip(model_order, pretty_names):
        image_path = output_path / model_name / video_test_dir / 'images' / target_image_path.name
        subplot_index, image = plot_image(image_path, subplot_index, n_rows, n_columns, is_first_row, f'{pretty_name}\nOutput')
        model_images.append(image)

    # The images originate from cv2.imread (uint8). uint8 subtraction wraps around, so
    # np.abs(a - b) on the raw arrays is not |a - b|: cast to float first.
    target_float = target_image.astype(float)
    for model_image, pretty_name in zip(model_images, pretty_names):
        error = np.mean(np.abs(target_float - model_image.astype(float)), axis=-1)
        subplot_index, _ = create_subplot(error, subplot_index, n_rows, n_columns, is_first_row, f'{pretty_name}\nError', is_error=True)

# Negative hspace packs the rows tightly.
plt.subplots_adjust(wspace=0.25, hspace=-0.55)

fig_path = output_path / 'comparison.png'
plt.savefig(str(fig_path), bbox_inches='tight')

plt.show()
plt.close(fig)

In [None]:
# Assemble each model's generated frames for every test video into an MJPG-encoded AVI at
# `<output_path>/<model>/<person>.avi`. Frames are taken in the sorted order of the
# ground-truth frame names.
# Note: this cell must not touch `model_to_diffs` (a leftover reset here used to wipe the
# error maps computed earlier).
VIDEO_FPS = 30
model_bar = tqdm(model_name_to_instance_settings, desc='model')
for model_name in model_bar:
    model_bar.set_postfix(model=model_name)

    model_output_path = output_path / model_name
    video_bar = tqdm(all_videos, desc='video', leave=False)
    for video_path in video_bar:
        video_bar.set_postfix(model=model_name, video=video_path.stem)

        model_video_path = model_output_path / video_path.stem
        person = video_path.stem.split('_')[0]
        model_video_output_file = model_output_path / f'{person}.avi'

        all_images = sorted((video_path / 'images').glob('*.jpg'))
        image_bar = tqdm(all_images, desc='image', leave=False)
        image_bar.set_postfix(model=model_name, video=video_path.stem)

        video = None
        try:
            for ground_truth_image_path in image_bar:
                model_image_path = model_video_path / 'images' / ground_truth_image_path.name
                model_image = cv2.imread(str(model_image_path))
                if model_image is None:
                    raise FileNotFoundError(f'Missing generated frame {model_image_path}')

                if video is None:
                    # Size the writer from the real frames. OpenCV does not raise on a frame
                    # whose size differs from the writer's; such frames are not encoded.
                    frame_height, frame_width = model_image.shape[:2]
                    video = cv2.VideoWriter(
                        str(model_video_output_file),
                        cv2.VideoWriter_fourcc(*'MJPG'),
                        VIDEO_FPS,
                        (frame_width, frame_height),
                    )
                    if not video.isOpened():
                        raise OSError(f'Could not open video writer for {model_video_output_file}')

                video.write(model_image)
        finally:
            # Release even if a frame fails, so the file handle is not leaked.
            if video is not None:
                video.release()