In [None]:
# Widen every element with the CSS class `container` to 80% of the page so
# that wide figures fit. `display` and `HTML` are taken from the public
# `IPython.display` module; importing them through `IPython.core.display`
# is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [None]:
from pathlib import Path
from typing import Tuple

import numpy as np
from tqdm import tqdm_notebook as tqdm
%pylab inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import patches
# Default size for every new figure: 10 x 10.
# NOTE: `rcParams` is not imported explicitly in this cell; it is only in
# scope because of the `%pylab inline` magic above.
rcParams['figure.figsize'] = 10, 10

In [None]:
# Root directory of the dataset; every sub-directory is treated as one video.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
base_path = Path('/media/klaus/Ondrag/dev/datasets/300VW_Dataset_2015_12_14/')
# Number of video directories expected under `base_path` (asserted before the loop).
n_videos = 114
# Number of landmarks every `.pts` annotation file must declare and contain.
n_points = 68
# Margin subtracted from the min / added to the max landmark coordinates
# when the bounding box is built (same units as the landmark coordinates).
padding = 0
# Fraction of the box width/height added on each side of the box;
# `None` skips the expansion step.
expand_ratio = 0.5

In [None]:
def plot(image, points: np.ndarray = None, box: Tuple[int, int, int, int] = None) -> None:
    """Show `image` in a new figure with optional landmark and box overlays.

    Args:
        image: Anything `plt.imshow` accepts.
        points: Optional 2-D array; column 0 is used as x and column 1 as y.
            Drawn with the `'gx'` format string.
        box: Optional `(x1, y1, x2, y2)` rectangle, drawn as an unfilled red
            outline.
    """
    plt.figure()
    plt.imshow(image)
    axes = plt.gca()
    axes.axis('off')

    if points is not None:
        xs, ys = points[:, 0], points[:, 1]
        axes.plot(xs, ys, 'gx')

    if box is not None:
        left, top, right, bottom = box
        outline = patches.Rectangle(
            (left, top), right - left, bottom - top,
            linewidth=1, edgecolor='r', facecolor='none',
        )
        axes.add_patch(outline)

    plt.show()


def load_pts_file(file_path: Path, expected_points: int = None) -> np.ndarray:
    """Parse a landmark `.pts` annotation file into an `(n, 2)` float array.

    The expected layout is::

        version: 1
        n_points: <n>
        {
        <x> <y>
        ...
        }

    Leading/trailing whitespace on each line, blank lines (for example a
    trailing empty line after the closing brace) and the amount of spacing
    after `n_points:` are ignored.

    Args:
        file_path: Location of the `.pts` file (`Path` or plain string).
        expected_points: Number of landmarks the file must declare and
            contain. When omitted, the notebook-level `n_points` is used.

    Returns:
        Array of shape `(expected_points, 2)` with one row per point, in file
        order and with the columns in the order they appear in the file.

    Raises:
        AssertionError: If the header, the braces or the number of points do
            not match the expected layout. The message is the file path.
        ValueError: If a coordinate cannot be parsed as a float.
    """
    if expected_points is None:
        expected_points = n_points

    with open(str(file_path), 'r') as file:
        stripped = [line.strip() for line in file]
    # Blank lines carry no information; dropping them keeps a trailing empty
    # line from being mistaken for the closing brace.
    lines = [line for line in stripped if line]

    # Three header lines plus the closing brace is the minimum valid file.
    assert len(lines) >= 4, str(file_path)
    assert lines[0].startswith('version: 1'), str(file_path)

    # Compare key and value separately so the spacing after ':' is irrelevant.
    key, _, value = lines[1].partition(':')
    assert key.strip() == 'n_points', str(file_path)
    assert value.strip() == str(expected_points), str(file_path)

    assert lines[2] == '{' and lines[-1] == '}', str(file_path)

    rows = lines[3:-1]
    points = np.asarray([[float(x) for x in row.split()] for row in rows], dtype=float)
    assert points.shape == (expected_points, 2), str(file_path)
    return points


def points_to_box(points: np.ndarray, image_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
    """Compute an integer, inclusive bounding box around `points`.

    The tight box around the landmarks is first widened by the notebook-level
    `padding`, then (unless `expand_ratio` is `None`) grown on every side by
    `expand_ratio` times its width/height, rounded outwards to whole pixels
    and finally clamped to the image.

    Args:
        points: 2-D array; column 0 is read as x and column 1 as y.
        image_size: Image shape whose first two entries are
            `(height, width)`; any further entries (e.g. channels) are
            ignored, so both 2-D and 3-D image shapes are accepted.

    Returns:
        `(x1, y1, x2, y2)` as Python ints. Both corners are inclusive, so the
        result can be used directly for slicing as `[y1:y2 + 1, x1:x2 + 1]`.

    Raises:
        AssertionError: If the landmarks span no area or the clamped box is
            empty (e.g. all landmarks lie outside the image).
    """
    # Local import: the notebook never imports `math` explicitly and must not
    # depend on `%pylab` leaking it into the namespace.
    import math

    x1, y1 = points[:, 0].min() - padding, points[:, 1].min() - padding
    x2, y2 = points[:, 0].max() + padding, points[:, 1].max() + padding
    box_height, box_width = y2 - y1 + 1, x2 - x1 + 1
    assert box_height > 1 and box_width > 1

    margin_x = margin_y = 0
    if expand_ratio is not None:
        margin_x = box_width * expand_ratio
        margin_y = box_height * expand_ratio

    # Always round outwards so the result is integral (usable as slice
    # bounds) even when no expansion is requested.
    x1, y1 = math.floor(x1 - margin_x), math.floor(y1 - margin_y)
    x2, y2 = math.ceil(x2 + margin_x), math.ceil(y2 + margin_y)

    image_height, image_width = image_size[:2]
    x1, y1 = max(x1, 0), max(y1, 0)
    # The box is inclusive, so the largest valid coordinate is size - 1.
    x2, y2 = min(x2, image_width - 1), min(y2, image_height - 1)
    assert x1 <= x2 and y1 <= y2

    return x1, y1, x2, y2


def extract(image: np.ndarray, box: Tuple[float, float, float, float]) -> np.ndarray:
    """Crop the region described by `box` out of `image`.

    Args:
        image: Array indexed as `[row, column, ...]`.
        box: `(x1, y1, x2, y2)` with both corners inclusive; y selects rows
            and x selects columns.

    Returns:
        The selected rows and columns of `image`, with any remaining axes
        kept whole.
    """
    left, top, right, bottom = box
    row_span = slice(top, bottom + 1)
    column_span = slice(left, right + 1)
    return image[row_span, column_span, ...]


def offset_points(points: np.ndarray, box: Tuple[float, float, float, float]) -> np.ndarray:
    """Translate `points` so they are relative to the top-left corner of `box`.

    Args:
        points: 2-D array; column 0 is x and column 1 is y. Left unmodified.
        box: `(x1, y1, x2, y2)`; only `x1` and `y1` are used.

    Returns:
        A new array with `x1` subtracted from column 0 and `y1` subtracted
        from column 1.
    """
    x1, y1, _, _ = box
    # Explicit numpy copy: the bare name `copy` is not imported anywhere and
    # only exists when `%pylab` star-imports numpy into the namespace.
    shifted = np.array(points, copy=True)
    shifted[:, 0] -= x1
    shifted[:, 1] -= y1
    return shifted


# Preview run: walk the dataset, but stop after the first frame of the first
# video (see the two `break` statements below). Only that one frame is loaded,
# cropped and plotted.
all_videos = sorted([p for p in base_path.iterdir() if p.is_dir()])
# Sanity check: the dataset directory holds exactly the expected number of videos.
assert len(all_videos) == n_videos
for video_path in tqdm(all_videos, desc='video'):
    # Frames are read from `<video>/extraction/*.png`, annotations from
    # `<video>/annot/*.pts`; both lists are sorted so they can be paired by position.
    all_frames_paths = sorted(list((video_path / 'extraction').glob('*.png')))
    all_annotations_paths = sorted(list((video_path / 'annot').glob('*.pts')))
    assert len(all_frames_paths) == len(all_annotations_paths)
    is_first_frame = True
    for frame_path, annotation_path in tqdm(list(zip(all_frames_paths, all_annotations_paths)), desc='frame', 
                                            leave=False):
        # A frame and its annotation must share the same file name (without extension).
        assert frame_path.stem == annotation_path.stem
        
        image = mpimg.imread(str(frame_path))
        image_points = load_pts_file(str(annotation_path))
        image_box = points_to_box(image_points, image.shape)
        face = extract(image, image_box)
        # Landmarks expressed relative to the cropped face instead of the full frame.
        face_points = offset_points(image_points, image_box)
        
        if is_first_frame:
            is_first_frame = False
            # plot(image)
            # plot(image, image_points)
            plot(image, image_points, image_box)
            # plot(face)
            plot(face, face_points)
            # Stop after the first frame of this video.
            break
        
    # Stop after the first video.
    break