# Utils

In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
def fix_notebook_widgets():
    """Patch `IPython.display.DisplayHandle.update` so in-place display updates render in VS Code.

    Taken from https://github.com/microsoft/vscode-jupyter/issues/13163

    Instead of updating the existing display in place, the patched `update` calls
    `clear_output(wait=True)` and then re-displays `obj` through the handle.

    The patch replaces the method on the `DisplayHandle` class itself, so it affects every
    handle in the running kernel. Calling this function again simply re-installs the same
    behaviour.
    """
    from IPython.display import clear_output, DisplayHandle

    def update_patch(self, obj, **kwargs):
        # Forward extra keyword arguments (e.g. `raw=`, `metadata=`) that the stock
        # `DisplayHandle.update` accepts, so existing callers don't hit a TypeError.
        clear_output(wait=True)
        self.display(obj, **kwargs)

    DisplayHandle.update = update_patch

In [None]:
#| export
from fastcore.all import *
from fastdownload import *
from fastai.vision.all import *

In [None]:
#| export
def data_path():
    """Return the path of the `data` folder as resolved by fastai's `fastai_path`."""
    data_dir = fastai_path('data')
    return data_dir

@delegates(FastDownload.download)
def fetch_file(*args, **kwargs):
    """Download through a `FastDownload` instance built from fastai's config (`fastai_cfg()`).

    Every positional and keyword argument is forwarded unchanged to
    `FastDownload.download`, and its result is returned.
    """
    downloader = FastDownload(fastai_cfg())
    return downloader.download(*args, **kwargs)

In [None]:
#| export
def return_list(f):
    """Decorator: call `f` and return whatever it yields/returns as a `list`.

    Useful for writing a function as a generator while exposing an eager, list-returning API.
    `functools.wraps` copies `f`'s name and docstring and sets `__wrapped__`, so
    `inspect.signature` (and nbdev's `show_doc`) report `f`'s real signature.
    """
    from functools import wraps

    @wraps(f)
    def wrapper(*args, **kwargs):
        return list(f(*args, **kwargs))
    return wrapper

In [None]:
@return_list
def foo(n):
    # Generator body; `return_list` turns the yielded values into a list.
    for i in range(n):
        yield i

test_eq(foo(10), [*range(10)])

In [None]:
#| export
"""
based on https://github.com/timesler/facenet-pytorch/blob/master/examples/lfw_evaluate.ipynb
"""
import os

import PIL
import torch
from facenet_pytorch import MTCNN, training
from fastprogress.fastprogress import *


def _open_rgb(fn):
    """Open the image file `fn`, return an RGB copy of it, and close the file handle."""
    # MTCNN expects 3-channel images, so grayscale/RGBA/palette files are converted here.
    # `convert` loads the pixel data, so the file can be closed right away instead of
    # staying open until the lazily-loaded image is garbage collected.
    with PIL.Image.open(fn) as img:
        return img.convert('RGB')


def mtcnn_aligned(path: Path,  # path to unaligned images
                  force=False,  # compute MTCNN alignment even if aligned images exist
                  batched=True  # run MTCNN on batches of 64 images instead of one at a time
                  ) -> Path:   # path to aligned images
    """Use MTCNN to detect, align and crop the faces in the images under `path`.

    Every image returned by `get_image_files(path)` is passed through MTCNN, and the result
    is saved in the sibling folder `<path.name>_mtcnn`, at the same relative location the
    source image had under `path`.

    If that folder already exists and `force` is False, it is returned without inspecting
    its contents. A folder left behind by an interrupted run is therefore reused as-is;
    pass `force=True` to recompute.

    NOTE(review): facenet-pytorch's batched detection appears to require all images in a
    batch to have the same size; use `batched=False` for mixed-size datasets.
    """
    path = Path(path)  # also accept a plain string path
    mtcnn_path = path.with_name(path.name + '_mtcnn')
    if not force and mtcnn_path.exists():
        return mtcnn_path

    mtcnn = MTCNN(
        image_size=160,
        margin=14,
        device=default_device(),
        selection_method='center_weighted_size'
    )

    loader = torch.utils.data.DataLoader(
        # Each item is (RGB image, source path); the path is needed to build the save path.
        Datasets(get_image_files(path), [_open_rgb, noop]),
        # No worker processes on Windows (presumably to avoid spawn/pickling issues).
        num_workers=0 if os.name == 'nt' else 8,
        batch_size=64 if batched else 1,
        # collate_pil keeps the PIL images and paths as plain lists instead of stacking them.
        collate_fn=training.collate_pil
    )

    for imgs, paths in progress_bar(loader, comment='MTCNN'):
        # Mirror the input folder layout under mtcnn_path.
        output_paths = [mtcnn_path/p.relative_to(path) for p in paths]
        mtcnn(imgs, save_path=output_paths)

    return mtcnn_path

In [None]:
#| hide
# Run nbdev's exporter so the `#| export` cells above are written into the library module.
import nbdev
nbdev.nbdev_export()