In [5]:
from dataclasses import dataclass, field
from typing import Optional, TypedDict, List
from functools import partial
import numpy as np
from PIL import Image, ImageDraw
import mediapipe as mp
import cv2

@dataclass
class PredictOutput:
    """
    Container for the result of one detector run.

    Attributes
    ----------
        bboxes: list[list[int | float]]
            detected boxes; mediapipe_face_detection fills these as
            [x1, y1, x2, y2] pixel coordinates
        masks: list[Image.Image]
            one mask per box, built in the same order as ``bboxes``
        preview: Image.Image | None
            annotated copy of the input, or None (the default, also used
            when a detector finds nothing)
    """

    bboxes: list[list[int | float]] = field(default_factory=list)
    masks: list[Image.Image] = field(default_factory=list)
    preview: Image.Image | None = None


def mediapipe_predict(
    model_type: str, image: Image.Image, confidence: float = 0.3
) -> PredictOutput:
    """
    Run the mediapipe detector selected by ``model_type`` on ``image``.

    Parameters
    ----------
        model_type: str
            one of "mediapipe_face_short", "mediapipe_face_full",
            "mediapipe_face_mesh", "mediapipe_face_mesh_eyes_only"
        image: Image.Image
            image to run detection on
        confidence: float
            forwarded to the selected detector as its second positional
            argument (a minimum detection confidence)

    Returns
    -------
        PredictOutput
            whatever the selected detector returns

    Raises
    ------
        RuntimeError
            if ``model_type`` is not one of the names listed above
    """
    # The two face-detection variants share one implementation and differ
    # only in the model index that is bound as its first argument.
    dispatch = {
        "mediapipe_face_short": partial(mediapipe_face_detection, 0),
        "mediapipe_face_full": partial(mediapipe_face_detection, 1),
        "mediapipe_face_mesh": mediapipe_face_mesh,
        "mediapipe_face_mesh_eyes_only": mediapipe_face_mesh_eyes_only,
    }

    predictor = dispatch.get(model_type)
    if predictor is None:
        available = list(dispatch)
        raise RuntimeError(
            f"[-] ADetailer: Invalid mediapipe model type: {model_type}, Available: {available!r}"
        )

    return predictor(image, confidence)


def mediapipe_face_detection(
    model_type: int, image: Image.Image, confidence: float = 0.3
) -> PredictOutput:
    """
    Detect faces with mediapipe's FaceDetection solution.

    Parameters
    ----------
        model_type: int
            passed to mediapipe as ``model_selection``
            (mediapipe_predict uses 0 for "short" and 1 for "full")
        image: Image.Image
            input image; it is converted to RGB before detection, so RGBA,
            grayscale and palette images are accepted as well
        confidence: float
            passed to mediapipe as ``min_detection_confidence``

    Returns
    -------
        PredictOutput
            ``bboxes`` holds one [x1, y1, x2, y2] float box per face in pixel
            coordinates, ``masks`` one filled-rectangle mask per box, and
            ``preview`` an RGB copy of the image with the detections drawn.
            An empty PredictOutput is returned when mediapipe reports no
            detections.
    """
    # mediapipe's process() expects a 3-channel RGB array; a 4-channel RGBA
    # or 2-D grayscale/palette array from np.array(image) is rejected, so
    # normalise the mode first. RGB inputs are used as-is.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    img_width, img_height = rgb_image.size

    mp_face_detection = mp.solutions.face_detection
    draw_util = mp.solutions.drawing_utils

    img_array = np.array(rgb_image)

    with mp_face_detection.FaceDetection(
        model_selection=model_type, min_detection_confidence=confidence
    ) as face_detector:
        pred = face_detector.process(img_array)

    if pred.detections is None:
        return PredictOutput()

    # Draw on a copy so img_array itself is not modified by the annotations.
    preview_array = img_array.copy()

    bboxes = []
    for detection in pred.detections:
        draw_util.draw_detection(preview_array, detection)

        # mediapipe reports the box relative to the image size; scale it to
        # pixels. The result is not clipped here — presumably it can extend
        # past the image edges for faces cut off at the border (verify).
        bbox = detection.location_data.relative_bounding_box
        x1 = bbox.xmin * img_width
        y1 = bbox.ymin * img_height
        w = bbox.width * img_width
        h = bbox.height * img_height
        x2 = x1 + w
        y2 = y1 + h

        bboxes.append([x1, y1, x2, y2])

    masks = create_mask_from_bbox(bboxes, rgb_image.size)
    preview = Image.fromarray(preview_array)

    return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)


def create_mask_from_bbox(
    bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Rasterise each bounding box into its own single-channel mask.

    Parameters
    ----------
        bboxes: list[list[float]]
            boxes given as [x1, y1, x2, y2] corner coordinates
        shape: tuple[int, int]
            (width, height) of the source image; every mask has this size

    Returns
    -------
        masks: list[Image.Image]
            one "L"-mode image per box, in input order, with 255 inside
            the rectangle and 0 everywhere else
    """

    def _render(box: list[float]) -> Image.Image:
        # All-black canvas with the box filled in white.
        canvas = Image.new("L", shape, 0)
        ImageDraw.Draw(canvas).rectangle(box, fill=255)
        return canvas

    return [_render(box) for box in bboxes]



def mediapipe_face_mesh(
    image: "Image.Image | None" = None, confidence: float = 0.3
) -> "PredictOutput":
    """
    Placeholder for a face-mesh based detector; not implemented yet.

    The signature matches how mediapipe_predict dispatches to it
    (``func(image, confidence)``). That way selecting "mediapipe_face_mesh"
    reaches this explicit error instead of failing with an argument-count
    TypeError.

    Raises
    ------
        NotImplementedError
            always, so callers never receive a silent None where a
            PredictOutput is expected
    """
    raise NotImplementedError("mediapipe_face_mesh is not implemented yet")
        
        
def mediapipe_face_mesh_eyes_only(
    image: "Image.Image | None" = None, confidence: float = 0.3
) -> "PredictOutput":
    """
    Placeholder for an eyes-only face-mesh detector; not implemented yet.

    The signature matches how mediapipe_predict dispatches to it
    (``func(image, confidence)``). That way selecting
    "mediapipe_face_mesh_eyes_only" reaches this explicit error instead of
    failing with an argument-count TypeError.

    Raises
    ------
        NotImplementedError
            always, so callers never receive a silent None where a
            PredictOutput is expected
    """
    raise NotImplementedError(
        "mediapipe_face_mesh_eyes_only is not implemented yet"
    )




# Run the full-range face detector on a sample image. The `with` block closes
# the file handle once detection is done; detection turns the image into a
# numpy array inside the call, so the pixels are read before the file closes.
with Image.open("images/1.jpeg") as source_image:
    output_object = mediapipe_predict("mediapipe_face_full", source_image)

masks = output_object.masks

# Mask of the first detected face, or None when nothing was detected
# (handy for e.g. display(mask_1) in the notebook).
mask_1 = masks[0] if masks else None


@dataclass
class BBox:
    """Axis-aligned box as top-left corner plus width/height, in pixels."""

    x: int
    y: int
    w: int
    h: int


# The detector already reports each face as [x1, y1, x2, y2] in (float)
# pixels, so convert those directly instead of re-deriving the box from the
# rasterised mask with cv2.findContours. This covers every detected face, not
# only the first, and gives [] instead of an IndexError when none is found.
# Values are rounded and not clipped to the image bounds. They may differ by
# about a pixel from the old mask-based measurement, which counted the pixels
# of the drawn rectangle.
bboxes: List[BBox] = [
    BBox(x=round(x1), y=round(y1), w=round(x2 - x1), h=round(y2 - y1))
    for x1, y1, x2, y2 in output_object.bboxes
]

print(bboxes)

[BBox(x=588, y=326, w=333, h=333)]
