In [None]:
# Model comparisons

In [None]:
import cv2
from face_alignment import FaceAlignment, LandmarksType
import os
import urllib.request as urlreq

from loreal_poc.dataloaders.loaders import DataLoader300W
from loreal_poc.dataloaders.wrappers import CroppedDataLoader

from loreal_poc.models.wrappers import OpenCVWrapper, FaceAlignmentWrapper
from loreal_poc.tests.performance import NMEMean
from loreal_poc.tests.base import Test, TestDiff

In [None]:
# Load the 300W sample set and unpack one example from it; the third
# element returned by the loader is not needed here and is discarded.
chosen_idx = 4
dl = DataLoader300W(dir_path="300W/sample")
image, ground_truth_landmarks, _ = dl[chosen_idx]

In [None]:
import torch

# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Instantiate the two landmark models used in this notebook.
# The face-alignment backend is configured for LandmarksType.TWO_D on the
# selected device, with input flipping disabled, and then wrapped.
fa_backend = FaceAlignment(LandmarksType.TWO_D, device=device, flip_input=False)
facealignment_model = FaceAlignmentWrapper(model=fa_backend)

# The OpenCV wrapper is built with its default arguments.
opencv_model = OpenCVWrapper()

In [None]:
# Run a single Test with the NMEMean metric (threshold 1) for the OpenCV
# model on the full dataloader, then display the result as a dict.
nme_test = Test(metric=NMEMean, threshold=1)
test = nme_test.run(model=opencv_model, dataloader=dl)

test.to_dict()

In [None]:
from loreal_poc.marks.facial_parts import FacialParts

# Run a TestDiff for the OpenCV model: the full dataloader is the main input
# and a loader cropped to the chosen facial part is passed as the reference.
facial_part = FacialParts.left_half

nme_diff_test = TestDiff(metric=NMEMean, threshold=1)
cropped_dl = CroppedDataLoader(dl, part=facial_part)
test_diff = nme_diff_test.run(
    model=opencv_model,
    dataloader=dl,
    dataloader_ref=cropped_dl,
    facial_part=facial_part,
)
test_diff.to_dict()

In [None]:
def report(models, dataloader, tests, facial_parts, metric=None, threshold=1):
    """Run every test class for every (model, facial part) pair and collect the results.

    For each combination, the test class is instantiated with ``metric`` and
    ``threshold`` and run with ``dataloader`` cropped to the facial part as
    its ``dataloader`` argument and the uncropped ``dataloader`` as its
    ``dataloader_ref`` argument.

    Args:
        models: Iterable of models handed to each test's ``run`` as ``model``.
        dataloader: Base dataloader; used both uncropped (as the reference)
            and wrapped in ``CroppedDataLoader`` for each facial part.
        tests: Iterable of test classes. Each must accept ``metric`` and
            ``threshold`` keyword arguments and expose
            ``run(model=..., dataloader=..., dataloader_ref=..., facial_part=...)``
            returning an object with ``to_dict()``.
        facial_parts: Iterable of facial parts to crop to.
        metric: Metric passed to each test class. Defaults to ``NMEMean``.
        threshold: Threshold passed to each test class. Defaults to ``1``.

    Returns:
        list: One ``to_dict()`` result per (model, facial part, test class),
        in that nesting order. Empty if any of the iterables is empty.
    """
    if metric is None:
        metric = NMEMean

    results = []
    for model in models:
        for facial_part in facial_parts:
            for test_cls in tests:
                # NOTE(review): the cropped loader is `dataloader` and the full
                # loader is `dataloader_ref` here, which is the opposite of the
                # standalone TestDiff cell earlier in this notebook -- confirm
                # which orientation is intended.
                test_result = test_cls(metric=metric, threshold=threshold).run(
                    model=model,
                    dataloader=CroppedDataLoader(dataloader, part=facial_part),
                    # Use the function argument, not the notebook-level `dl`,
                    # so the reference always matches the loader being cropped.
                    dataloader_ref=dataloader,
                    facial_part=facial_part,
                )
                results.append(test_result.to_dict())
    return results


report = report([opencv_model], dl, [TestDiff], [FacialParts.bottom_half, FacialParts.upper_half])

In [None]:
import pandas as pd

pd.DataFrame(report)