In [1]:
# Working directory: the mask video is read from here and both aligned videos are written here.
project_root = 'temp/'

# Inputs: the raw recording and its mask video (read frame-by-frame in a later cell).
video_path = '../demo/case2-openfield/openfield-1min-raw.mp4'
masks_path = f'{project_root}mask.mp4'

# Outputs: the aligned (centred, rotated, cropped) versions of the two inputs.
video_align_path = f'{project_root}video-align.mp4'
mask_align_path = f'{project_root}mask-align.mp4'

# Colours searched for in the mask frames: body_rgb locates the body centroid,
# tail_rgb locates the tail contour (see the alignment loop).
# NOTE(review): named "rgb" — confirm the channel order matches the frames ReadArray returns.
body_rgb = [122, 228, 240]
tail_rgb = [255, 208, 236]
# Passed to both WriteArray writers; presumably the encoder's constant rate factor —
# confirm against castle.utils.video_io.
crf = 18

In [2]:
import cv2
import numpy as np
from tqdm import tqdm
from castle.utils.video_io import ReadArray, WriteArray

In [3]:
# Open the raw recording and its mask video; both are indexed per frame later on.
video = ReadArray(video_path)
masks = ReadArray(masks_path)
# The aligned outputs are written with the frame rate reported by the raw video.
fps = video.fps
# Writers for the aligned outputs. They are closed at the end of the alignment cell,
# so this cell must be re-run before the alignment cell is executed again.
video_align = WriteArray(video_align_path, fps, crf)
masks_align = WriteArray(mask_align_path, fps, crf)

In [4]:
# Report both frame counts, then process only the frames present in both videos.
frame_counts = (len(video), len(masks))
print(*frame_counts)
n = min(frame_counts)

1800 1800


In [5]:
def roi_connected_components(frame, roi, tolerance=20):
    """Label the connected regions of ``frame`` whose colour is close to ``roi``.

    Parameters
    ----------
    frame : 3-channel image array (as returned for one mask-video frame).
    roi : sequence of three integer channel values; only the first three
        entries are used. It must be given in the same channel order as
        ``frame``.
    tolerance : int, optional
        Half-width of the accepted range around each channel value
        (default 20, the value that used to be hard-coded).

    Returns
    -------
    (False, None) when no pixel matches, i.e. only the background label exists.
    (True, output) otherwise, where ``output`` is the tuple returned by
    ``cv2.connectedComponentsWithStats``: (num_labels, labels, stats, centroids).
    """
    target = np.asarray(roi[:3], dtype=int)
    lower_bound = target - tolerance
    upper_bound = target + tolerance

    # inRange already yields the single-channel 8-bit 0/255 mask that the
    # labelling step needs, so no extra copy into a zero image is required.
    mask = cv2.inRange(frame, lower_bound, upper_bound)

    output = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
    num_labels = output[0]
    if num_labels == 1:
        # Only one label was produced: nothing matched the colour range.
        return False, None

    return True, output



def get_contour(frame, roi):
    """Return the outer contour of the largest ``roi``-coloured region in ``frame``.

    Returns
    -------
    (False, None) when no region matches the colour or when the contour
    consists of a single point.
    (True, contour) otherwise, with ``contour`` an (N, 2) array of x, y points.
    """
    ok, connected_components = roi_connected_components(frame, roi)
    if not ok:
        return False, None
    num_labels, labels, stats, _ = connected_components

    # Row 0 of stats is skipped (OpenCV reports the background as label 0),
    # hence the +1 when converting the argmax back to a label id.
    largest_label = int(np.argmax(stats[1:num_labels, cv2.CC_STAT_AREA])) + 1
    binary_mask = (labels == largest_label).astype(np.uint8) * 255

    # findContours returns (contours, hierarchy) in OpenCV 4.x but
    # (image, contours, hierarchy) in 3.x; [-2] is the contour list in both.
    contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return False, None

    contour = contours[0].reshape(-1, 2)
    if len(contour) < 2:
        # A single-point contour is rejected, as before.
        return False, None
    return True, contour
    
    
def get_centroids(frame, roi):
    """Locate the centroid of the largest ``roi``-coloured region in ``frame``.

    Returns ``(True, x, y)`` on success and ``(False, None, None)`` when no
    region matches the colour.
    """
    found, components = roi_connected_components(frame, roi)
    if found:
        label_count, _, stats, centroids = components
        # Skip row 0 (OpenCV reports the background as label 0) and pick the
        # largest remaining area; +1 turns the argmax back into a label id.
        foreground_areas = stats[1:label_count, cv2.CC_STAT_AREA]
        winner = np.argmax(foreground_areas) + 1
        centre = centroids[winner]
        return True, centre[0], centre[1]
    return False, None, None


def get_mask(frame, roi):
    """Build a boolean mask of the largest ``roi``-coloured region in ``frame``.

    Returns ``(True, mask)`` where ``mask`` is True on the pixels of that
    region, or ``(False, None)`` when no region matches the colour.
    """
    found, components = roi_connected_components(frame, roi)
    if found:
        label_count, label_map, stats, _ = components
        # Row 0 is skipped (OpenCV reports the background as label 0).
        foreground_areas = stats[1:label_count, cv2.CC_STAT_AREA]
        winner = np.argmax(foreground_areas) + 1
        return True, (label_map == winner)
    return False, None
    
    
def find_closest_points(ref, contour):
    """Find the contour point nearest to ``ref`` (Euclidean distance).

    Parameters
    ----------
    ref : point as a sequence of coordinates, e.g. ``(x, y)``.
    contour : array of shape (N, len(ref)) holding the candidate points.

    Returns
    -------
    (point, distance) : the nearest row of ``contour`` and its distance to
    ``ref`` as a float. On ties the first such row wins. For an empty contour
    the result is ``(None, inf)``.
    """
    if len(contour) == 0:
        return None, float('inf')

    # Work in float64 so integer contour coordinates cannot overflow or wrap.
    ref_point = np.asarray(ref, dtype=np.float64)
    points = np.asarray(contour, dtype=np.float64)

    # One vectorised pass replaces the per-point loop and needs no magic
    # "large" starting distance, so arbitrarily distant points still resolve.
    distances = np.linalg.norm(points - ref_point, axis=1)
    nearest = int(np.argmin(distances))  # argmin keeps the first of equal minima
    return contour[nearest], float(distances[nearest])

In [6]:
# Align every frame: move the body centroid to the frame centre, rotate so the
# body->tail direction is fixed, and keep a centred crop_w x crop_h window.
# Frames where the body or the tail cannot be found are skipped, so the outputs
# may contain fewer frames than the inputs.
thetas = []

crop_h, crop_w = 640, 640
raw_h, raw_w = video[0].shape[:2]
u, d = int(raw_h // 2 - crop_h // 2), int(raw_h // 2 + crop_h // 2)
l, r = int(raw_w // 2 - crop_w // 2), int(raw_w // 2 + crop_w // 2)
center = (raw_w // 2, raw_h // 2)
# Size (width, height) of the centred crop window [l, r) x [u, d).
out_size = (r - l, d - u)
crop_origin = np.array([l, u], dtype=np.float64)

try:
    for i in tqdm(range(n)):

        m = masks[i]
        ok, body_x, body_y = get_centroids(m, body_rgb)
        if not ok:
            # tqdm.write keeps the progress bar intact, unlike a bare print.
            tqdm.write(f'No mask at the frame id = {i}. And skip it.')
            continue
        ok, tail_contour = get_contour(m, tail_rgb)
        if not ok:
            tqdm.write(f'No mask at the frame id = {i}. And skip it.')
            continue
        (tail_root_x, tail_root_y), _ = find_closest_points((body_x, body_y), tail_contour)
        # Direction from the body centroid to the nearest tail point, in degrees.
        theta = np.arctan2(tail_root_y - body_y, tail_root_x - body_x) * 180 / np.pi
        thetas.append(theta)

        # Compose translate -> rotate -> crop into one affine map:
        #   p_out = R @ (p_src + shift) + r0 - (l, u)
        # A single warp interpolates once, does not clip content at an
        # intermediate full-size canvas, and still works when the raw frame
        # is smaller than the crop window (missing pixels stay black).
        # With OpenCV's rotation-matrix convention, rotating by theta - 90
        # maps the body->tail direction onto +y (downwards in the image).
        rotation = cv2.getRotationMatrix2D(center, theta - 90, 1.0)
        shift = np.array([center[0] - body_x, center[1] - body_y], dtype=np.float64)
        matrix = rotation.copy()
        matrix[:, 2] += rotation[:, :2] @ shift - crop_origin

        # NOTE(review): the mask is warped with the default (linear)
        # interpolation, as before; consider nearest-neighbour if exact
        # label colours matter downstream.
        f = cv2.warpAffine(video[i], matrix, out_size)
        m = cv2.warpAffine(m, matrix, out_size)
        video_align.append(f)
        masks_align.append(m)
finally:
    # Always finalise the output files, even if a frame raises midway.
    masks_align.close()
    video_align.close()

 82%|████████▏ | 1476/1800 [00:31<00:06, 51.06it/s]

No mask at the frame id = 1470. And skip it.
No mask at the frame id = 1471. And skip it.


100%|██████████| 1800/1800 [00:38<00:00, 46.47it/s]
