# Face detection

Although Medusa's main focus is 3D/4D reconstruction, it also contains functionality for face detection, facial landmark prediction, and cropping, as these steps often need to be performed before feeding images into reconstruction models.

In this tutorial, we will demonstrate Medusa's face detection functionality.

## Detection on (single) images

The first step in many face analysis pipelines is _face detection_. Medusa contains two classes that perform face detection:

* `SCRFDetector`: a detection model based on InsightFace's SCRFD model `citep`{};
* `YunetDetector`: a detection model implemented in OpenCV

We recommend using `SCRFDetector`, because in our experience it is substantially more accurate than `YunetDetector` (albeit a bit slower when run on CPU). If you want to use `YunetDetector`, make sure to install OpenCV first (`pip install opencv-python`). We'll use `SCRFDetector` for the rest of this section.


In [None]:
from medusa.detect import SCRFDetector

Under the hood, `SCRFDetector` uses an ONNX model provided by InsightFace, but our implementation is quite a bit faster than the original InsightFace implementation as ours uses PyTorch throughout (rather than a mix of PyTorch and numpy).

The `SCRFDetector` takes the following inputs upon initialization:

* `det_size`: the size (width, height) to which images are resized before being passed to the detection model;
* `det_threshold`: the minimum confidence a detection needs in order to be kept (a float between 0 and 1);
* `nms_threshold`: the non-maximum suppression threshold (when two boxes overlap by more than this proportion, measured as intersection-over-union, the less confident one is removed);
* `device`: either `"cpu"` or `"cuda"` (determined automatically by default).
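To make `det_threshold` and `nms_threshold` more concrete, here is a minimal pure-Python sketch of the kind of post-processing they control: discard detections below `det_threshold`, then run greedy non-maximum suppression (NMS) with intersection-over-union (IoU) as the overlap measure. The function names and the default values (0.5 and 0.4) are illustrative assumptions, not Medusa's actual code or defaults:

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max) in pixels


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two boxes in (x_min, y_min, x_max, y_max) format."""
    ix_min, iy_min = max(a[0], b[0]), max(a[1], b[1])
    ix_max, iy_max = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix_max - ix_min) * max(0.0, iy_max - iy_min)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def filter_detections(boxes: Sequence[Box], scores: Sequence[float],
                      det_threshold: float = 0.5,
                      nms_threshold: float = 0.4) -> List[int]:
    """Return the indices of the kept detections, most confident first.

    Hypothetical helper; the default thresholds are illustrative only.
    """
    # 1. Drop detections whose confidence is below det_threshold
    candidates = [i for i, s in enumerate(scores) if s >= det_threshold]
    # 2. Greedy NMS: visit the remaining candidates from most to least confident
    candidates.sort(key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in candidates:
        # Keep box i only if it overlaps no already-kept box by more than nms_threshold
        if all(iou(boxes[i], boxes[j]) <= nms_threshold for j in keep):
            keep.append(i)
    return keep


boxes = [(10, 10, 110, 110), (12, 12, 112, 112), (200, 200, 260, 260), (300, 300, 340, 340)]
scores = [0.9, 0.8, 0.7, 0.3]
print(filter_detections(boxes, scores))  # [0, 2]
```

Here the second box overlaps the first with an IoU of about 0.92 and is suppressed, while the last box never reaches NMS because its confidence is below `det_threshold`. Raising `det_threshold` or lowering `nms_threshold` makes the output more conservative. In real code you would use a vectorized routine such as `torchvision.ops.nms` rather than a Python loop.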

The most important arguments are `det_size` and `det_threshold`. A larger `det_size` (a tuple with two integers, width x height) potentially leads to more accurate detections but slower processing; increasing `det_threshold` leads to more conservative detections (fewer false alarms, but more misses), and decreasing it has the opposite effect.
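Why does `det_size` trade accuracy for speed? A common preprocessing scheme (used, for example, in InsightFace's reference SCRFD code) resizes each image to fit inside `det_size` while preserving its aspect ratio (padding the remainder, not shown here) and afterwards divides the predicted coordinates by the resize factor to map them back to the original image. A larger `det_size` shrinks the image less, so small faces keep more pixels, but the network then has more pixels to process. The sketch below shows that arithmetic; whether Medusa's `SCRFDetector` follows exactly this scheme is an assumption, and the helper names and the (640, 640) default are hypothetical:

```python
from typing import Sequence, Tuple


def fit_to_det_size(img_w: int, img_h: int,
                    det_size: Tuple[int, int] = (640, 640)) -> Tuple[int, int, float]:
    """Return (new_w, new_h, scale) for an aspect-preserving resize into det_size."""
    det_w, det_h = det_size
    if img_h / img_w > det_h / det_w:
        # The image is relatively taller than the detector input: height is the limiting side
        new_h = det_h
        new_w = int(new_h * img_w / img_h)
    else:
        # Otherwise width is the limiting side
        new_w = det_w
        new_h = int(new_w * img_h / img_w)
    scale = new_h / img_h
    return new_w, new_h, scale


def box_to_original(box: Sequence[float], scale: float) -> Tuple[float, ...]:
    """Map an (x_min, y_min, x_max, y_max) box from detector space back to image pixels."""
    return tuple(v / scale for v in box)


# A 1920 x 1080 video frame is shrunk to 640 x 360 before detection
new_w, new_h, scale = fit_to_det_size(1920, 1080, det_size=(640, 640))
print(new_w, new_h)  # 640 360
print([round(v) for v in box_to_original((100, 50, 200, 150), scale)])  # [300, 150, 600, 450]
```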

In our experience, the defaults are fine for most images/videos:

In [None]:
detector = SCRFDetector()

Now let's apply it to some example data. We'll use a single frame from our example video:

In [None]:
from medusa.data import get_example_image
img = get_example_image(load=True)

Here, `img` is loaded as a PyTorch tensor, but Medusa's detectors also accept paths to images or numpy arrays. To process this image, we call the `detector` object as if it were a function (which internally triggers its `__call__` method):

In [None]:
det = detector(img)

The output of the detector call, `det`, is a dictionary with the detection results:

In [None]:
det.keys()

Notably, all values of the dictionary are PyTorch tensors. The most important keys are:

* `conf`: the confidence of each detection (0-1)
* `lms`: a set of five landmark coordinates per detection
* `bbox`: a bounding box per detection

Let's take a look at `conf`:

In [None]:
conf = det['conf']
# `.item()` only works on a single-element tensor, i.e., when there is exactly one detection
print(f"Conf: {conf.item():.3f}, shape: {tuple(conf.shape)}")

So for this image, there is only one detection, with a confidence of 0.884. Note that an image can yield more than one detection when it contains multiple faces!

Now, let's also take a look at the bounding box for the detection:

In [None]:
det['bbox']

The bounding box contains four values (in pixels): the box's minimum x-value, minimum y-value, maximum x-value, and maximum y-value, in that order. We can visualize this bounding box quite easily using `torchvision`:

In [None]:
import os
import torch
from IPython.display import Image
from torchvision.utils import draw_bounding_boxes, save_image

# `draw_bounding_boxes` expects a single uint8 image in C x H x W format,
# so squeeze out the batch dimension and convert from float32 to uint8
red = (255, 0, 0)
img = img.squeeze(0).to(torch.uint8)
img_draw = draw_bounding_boxes(img, det['bbox'], colors=red, width=2)

# Save the image to disk (creating the output directory if needed) and display it in the notebook
os.makedirs('./viz', exist_ok=True)
save_image(img_draw.float(), './viz/bbox.png', normalize=True)
Image('./viz/bbox.png')

Looks like a proper bounding box! Now, let's finally look at the predicted facial landmarks:

In [None]:
det['lms']  # N (detections) x 5 (landmarks) x 2 (x, y)

As you can see, each detection also comes with five landmarks, each consisting of two values (x and y) in pixels. As the visualization below (again using `torchvision`) shows, these landmarks correspond to the left eye, right eye, tip of the nose, left mouth corner, and right mouth corner:

In [None]:
from torchvision.utils import draw_keypoints

# Like `draw_bounding_boxes`, `draw_keypoints` expects a uint8 image in C x H x W format
img_draw = draw_keypoints(img, det['lms'], colors=red, radius=4)

# Save image to disk and display in notebook
save_image(img_draw.float(), './viz/lms.png', normalize=True)
Image('./viz/lms.png')
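A common downstream use of these five points is aligning or cropping the face. As a minimal sketch (not Medusa's cropping code), the two eye landmarks alone give the in-plane rotation (roll) of the face and the inter-ocular distance, a natural scale reference. The hypothetical `eye_geometry` helper below assumes the landmark ordering described above, with landmark 0 being the eye on the left side of the image:

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]  # (x, y) in pixels


def eye_geometry(lms: Sequence[Point]) -> Tuple[float, float]:
    """Return (roll angle in degrees, inter-ocular distance in pixels).

    Hypothetical helper; assumes five landmarks ordered as: left eye, right eye,
    nose tip, left mouth corner, right mouth corner.
    """
    (lx, ly), (rx, ry) = lms[0], lms[1]
    dx, dy = rx - lx, ry - ly
    # The image y-axis points down, so a positive angle means landmark 1
    # sits lower in the image than landmark 0
    roll = math.degrees(math.atan2(dy, dx))
    dist = math.hypot(dx, dy)
    return roll, dist


# A perfectly upright face: both eyes at the same height, 60 pixels apart
print(eye_geometry([(100, 100), (160, 100), (130, 130), (110, 160), (150, 160)]))  # (0.0, 60.0)
```

Rotating the image so that this angle becomes zero levels the eyes, which is the basic idea behind landmark-based face alignment.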

## Detection on batches of images

Thus far, we have only applied face detection to a single image, but Medusa's face detectors also work on batches of images, so they can easily be used to process video data. This also gives us a good excuse to showcase Medusa's powerful `BatchResults` class (explained below).

Let's try this out on our example video, which we load in batches using Medusa's `VideoLoader`:

In [None]:
from medusa.data import get_example_video
from medusa.io import VideoLoader

vid = get_example_video()
loader = VideoLoader(vid, batch_size=64)

# The loader can be used as an iterator (e.g. in a for loop), but here we only
# load in a single batch; note that we always need to move the data to the desired
# device (CPU or GPU)
batch = next(iter(loader))
batch = batch.to(loader.device)

# B (batch size) x C (channels) x H (height) x W (width)
print(batch.shape)
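Conceptually, iterating over a batched loader walks through the video in consecutive chunks of `batch_size` frames. The pure-Python sketch below (a hypothetical `iter_batches` helper, not `VideoLoader`'s implementation, which also decodes frames and returns tensors) shows that chunking logic; in this sketch, the final chunk is shorter when the number of frames is not a multiple of `batch_size`:

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def iter_batches(frames: Iterable[T], batch_size: int = 64) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most batch_size frames (hypothetical helper)."""
    batch: List[T] = []
    for frame in frames:
        batch.append(frame)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Yield the leftover frames, if any, as a final (smaller) batch
    if batch:
        yield batch


# 130 frames with batch_size=64 give ceil(130 / 64) = 3 batches
print([len(b) for b in iter_batches(range(130), batch_size=64)])  # [64, 64, 2]
```

As with `next(iter(loader))` above, `next(iter_batches(...))` returns only the first chunk.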

Initialize the detector as usual and call it on the batch of images like we did on a single image:

In [None]:
detector = SCRFDetector()
out = detector(batch)

print(out.keys())
print(out['bbox'].shape)

To visualize the detection results for this batch of images, we could write a for-loop and use `torchvision` to create, for each detection, an image with the bounding box and facial landmarks drawn on it. However, Medusa has a specialized class for this type of batch data that makes aggregation and visualization easier:

In [None]:
from medusa.containers import BatchResults
from IPython.display import Video

# `BatchResults` takes any output from a detector model (or crop model) ...
results = BatchResults(**out)

# ... which it'll then visualize as a video (or, if video=False, a set of images)
results.visualize('./viz/test.mp4', batch, video=True, fps=loader._metadata['fps'])

# Embed in notebook
Video('./viz/test.mp4', embed=True)

The `BatchResults` class is especially useful when dealing with multiple batches of images (which will be the case for most videos). When dealing with multiple batches, initialize an "empty" `BatchResults` object before any processing, and then in each iteration call its `add` method with the results from the detector.

Below, we show an example for three consecutive batches. Note that `BatchResults` will store anything you give it, so here we also pass it the raw images (via `images=batch`):

In [None]:
loader = VideoLoader(get_example_video(n_faces=2))
results = BatchResults()
for i, batch in enumerate(loader):
    batch = batch.to(loader.device)
    out = detector(batch)
    results.add(images=batch, **out)

    if i == 2:
        break

Right now, the `results` object holds, for each detection attribute (such as `lms`, `conf`, and `bbox`), a list with one entry per batch:

In [None]:
# List of length 3 (one entry per batch); each entry holds one confidence value per detection in that batch
results.conf

We can concatenate everything by calling the object's `concat` method:

In [None]:
results.concat()
results.conf.shape  # one value per detection, across all 3 batches
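To make the `add`/`concat` bookkeeping concrete, here is a deliberately simplified, hypothetical `MiniBatchResults` class in pure Python. It uses lists instead of tensors and omits everything else the real `BatchResults` offers (such as `visualize`), so it illustrates the accumulation pattern only, not Medusa's implementation:

```python
from collections import defaultdict
from typing import Any, Dict, List


class MiniBatchResults:
    """Hypothetical, simplified stand-in for the add/concat pattern of BatchResults."""

    def __init__(self) -> None:
        # One list per attribute name, with one entry per call to add()
        self._store: Dict[str, List[Any]] = defaultdict(list)

    def add(self, **kwargs: Any) -> None:
        # Append each keyword's value (e.g., one batch of confidences) to its list
        for key, value in kwargs.items():
            self._store[key].append(value)

    def concat(self) -> None:
        # Flatten each list of per-batch sequences into a single sequence (call once)
        for key, per_batch in list(self._store.items()):
            self._store[key] = [item for batch in per_batch for item in batch]

    def __getattr__(self, name: str) -> List[Any]:
        # Expose stored attributes as results.conf, results.bbox, etc.
        store = self.__dict__.get("_store", {})
        if name in store:
            return store[name]
        raise AttributeError(name)


results = MiniBatchResults()
for _ in range(3):
    results.add(conf=[0.9] * 64)
print(len(results.conf))  # 3
results.concat()
print(len(results.conf))  # 192
```

Storing per-batch values and concatenating once at the end avoids repeatedly copying a growing array inside the loop.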

Now, we can visualize the results as before (note that we give it the raw images as well):

In [None]:
results.visualize('./viz/test.mp4', results.images, video=True, fps=loader._metadata['fps'])
Video('./viz/test.mp4', embed=True)