In [None]:
import os

import cv2
import facer
import numpy as np
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from preprocessing.extract_faces import get_video_clip, save_video_lossless
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
video_path = "/stock/FaceForensicC23/cropped_faces/F2F/010_005.avi"
# The landmark file mirrors the video path under "cropped_faces(landmark)".
# Only the file extension is swapped: a plain str.replace("avi", "npy") would
# also rewrite any "avi" that happens to appear in a directory or file name.
landmark_path = os.path.splitext(
    video_path.replace("cropped_faces", "cropped_faces(landmark)")
)[0] + ".npy"

# Decode the clip with stride=1; returns the clip's fps and its frames.
fps, frames = get_video_clip(video_path, stride=1)

In [None]:
# Number of frames decoded from the clip (frames is still the value returned by get_video_clip here).
len(frames)

In [None]:
# NOTE: this cell rebinds `frames` and `landmarks` in place, so it must be run
# exactly once after the loading cell; re-running it would permute the axes again.

# Stack the decoded frames into one array and move the last axis to position 1
# (assumes get_video_clip yields equally sized H x W x C arrays — TODO confirm).
frames = np.stack(frames)
frames = frames.transpose((0, 3, 1, 2))
frames = torch.from_numpy(frames).to(device)
# One id per frame, created directly on the target device.
image_ids = torch.arange(frames.shape[0], device=device)

landmarks = np.load(landmark_path)
# Landmark groups, each averaged into a single point per frame.
# presumably 1-based 68-point numbering (two eyes, nose tip, two mouth corners)
# with the stored array offset by 16 — TODO confirm against the landmark extractor.
point_groups = [
    np.arange(37, 43),
    np.arange(43, 49),
    np.array([34]),
    np.array([49]),
    np.array([55]),
]
# Vectorised over frames: index all frames at once and average within each
# group, then place the five group means along axis 1 (frames x 5 x ...).
landmarks = torch.from_numpy(
    np.stack(
        [landmarks[:, group - 16].mean(axis=1) for group in point_groups],
        axis=1,
    )
).float().to(device)

In [None]:
# Sanity check: shape of the stacked, axis-permuted frame tensor.
frames.shape

In [None]:
# Sanity check: shape of the reduced landmark tensor (five averaged groups per frame along dim 1).
landmarks.shape

In [None]:
# Instantiate the facer face-parsing model on the selected device.
# Alternative model name: "farl/celebm/448"
face_parser = facer.face_parser('farl/lapa/448', device=device)

In [None]:
# Frame tensor shape, shown again before batching.
# NOTE(review): duplicates an earlier display cell and could be dropped.
frames.shape

In [None]:
import math

# Number of 35-frame batches needed to cover the clip (the last one may be partial).
# Plain ceil(n / 35): adding 1 to n first overcounts by one whenever n is an
# exact multiple of 35, and reports one batch for an empty clip.
math.ceil(frames.shape[0] / 35)

In [None]:
# Run the face parser over the clip in fixed-size batches and keep the
# segmentation logits of every batch on the CPU.
result = []
bsize = 35
n_frames = frames.shape[0]
# Frames and landmarks are sliced in lockstep below, so they must line up 1:1.
assert n_frames == landmarks.shape[0], "frames and landmarks differ in length"
with torch.inference_mode():
    # Stepping by bsize from 0 yields ceil(n_frames / bsize) batches; the
    # `n // bsize + 1` form handed the parser an empty trailing batch whenever
    # n_frames was an exact multiple of bsize.
    for start in tqdm(range(0, n_frames, bsize)):
        batch_frames = frames[start:start + bsize]
        batch_landmarks = landmarks[start:start + bsize]
        assert batch_frames.shape[0] == batch_landmarks.shape[0]
        faces = face_parser(
            batch_frames,
            {
                "points": batch_landmarks,
                # One landmark set per frame; ids are local to the batch.
                "image_ids": torch.arange(
                    batch_landmarks.shape[0], device=device
                ),
            },
        )
        # Move logits off the device right away so only one batch stays there.
        result.append(faces["seg"]["logits"].cpu())

In [None]:
# Join the per-batch logits back into one tensor along the leading (frame) axis.
seg_logits = torch.cat(result)

In [None]:
# Class probabilities per pixel: nfaces x nclasses x h x w
seg_probs = torch.softmax(seg_logits, dim=1)
n_classes = seg_probs.shape[1]
# Most likely class id for every pixel.
seg_label_img = torch.argmax(seg_probs, dim=1)

In [None]:
# Largest class id that actually occurs in the predicted label maps.
seg_label_img.max()

In [None]:
# Rescale class ids by 255 / n_classes so they can be viewed as gray levels.
vis_seg_probs = seg_label_img.to(torch.float32).div(n_classes).mul(255)

In [None]:
# Largest gray level produced by the rescaling.
vis_seg_probs.max()

In [None]:
# Count the pixels at index 2 whose rescaled value equals the raw class id.
# The comparison already yields a boolean mask, so no `== True` is needed.
(seg_label_img[2] == vis_seg_probs[2]).sum()

In [None]:
# Display the rescaled label map of the first frame; pyplot is already
# imported at the top of the notebook, so it is not re-imported here.
label_map_u8 = vis_seg_probs[0].numpy().astype(np.uint8)
plt.imshow(label_map_u8)
plt.title("Face-parsing labels (frame 0)")
plt.show()