In [None]:
import numpy as np
from tqdm import tqdm
from castle import generate_aot
from castle.utils.plot import generate_mix_image, generate_mask_image
from castle.utils.video_io import ReadArray, WriteArray

In [None]:
# Init DeAOT: seed the tracker with the first video frame and its pre-computed mask.

# Input video; the inference cell re-opens this same path to read every frame.
video_path = '../demo/case2-openfield/openfield-1min-raw.mp4'
frame0_mask_path = 'temp/frame0_mask.npy'

# Only the first frame is needed to register the reference.
frame0 = ReadArray(video_path)[0]

# Mask for frame 0, saved earlier under temp/.
# NOTE(review): presumably integer object ids with 0 as background — confirm.
mask_full = np.load(frame0_mask_path)
# The largest id present in the mask is passed to the tracker as the object count.
num_object = mask_full.max()

tracker = generate_aot(model_type='r50_deaotl')
tracker.add_reference_frame(frame0, mask_full, num_object)

In [None]:
# DeAOT inference: track every frame and write two videos to temp/ —
# an overlay of frame + mask ("mix") and a rendering of the mask alone.

crf = 18  # Output video quality, passed to WriteArray (presumably an x264-style CRF — confirm)
mix_video_path = 'temp/mix.mp4'
mask_video_path = 'temp/mask.mp4'

video = ReadArray(video_path)
mix_video = WriteArray(mix_video_path, video.fps, crf)
# Nested try/finally: each writer is closed even if tracking/writing raises part-way
# (or the other writer fails to open/close), so the output files are always finalized.
try:
    mask_video = WriteArray(mask_video_path, video.fps, crf)
    try:
        n = len(video)
        for i in tqdm(range(n)):
            frame = video[i]
            # track() returns a tensor-like object (it exposes .detach()/.cpu()/.numpy(),
            # presumably a torch tensor); flatten singleton dims and move it to a uint8 array.
            mask = tracker.track(frame)
            mask = mask.squeeze().detach().cpu().numpy().astype(np.uint8)
            mix_img = generate_mix_image(frame, mask)
            mix_video.append(mix_img)
            mask_img = generate_mask_image(mask)
            mask_video.append(mask_img)
    finally:
        mask_video.close()
finally:
    mix_video.close()