In [1]:
import torch
from torchvision.io import read_video, write_video
from torchvision.transforms import v2 as T

# paths
import os
import sys

# set paths
dirpath = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(dirpath)

# my imports
from models.SoSi_detection import SoSiDetectionModel  # noqa: E402
from utils.plot_utils import unnormalize, voc_img_bbox_plot  # noqa: E402

# the lifesaver
%load_ext autoreload
%autoreload 2

# torch setup
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

#### Load the model:

In [2]:
# Checkpoint location: <project root>/models/model_savepoints/<model_file>.
# The components are joined portably instead of hard-coding Windows '\\' separators.
model_file = 'p02_model_Feb-25_16-26-11.pth'
model_path = os.path.join(dirpath, 'models', 'model_savepoints', model_file)

# Build the model and load the trained weights, mapping tensors onto `device`.
# weights_only=True limits unpickling to tensors and plain containers, which is all
# a state_dict needs, instead of allowing arbitrary pickled objects.
model = SoSiDetectionModel()
load_result = model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
print(load_result)
model.to(device).eval();

<All keys matched successfully>


#### Get video file:

In [3]:
# Input clip: <project root>/inference/cats_wild.mp4 (joined portably, no hard-coded '\\').
video_file = os.path.join(dirpath, 'inference', 'cats_wild.mp4')

# video params -- these are assumed to describe the source clip. video_fps must equal the
# file's real frame rate: it sets the chunk duration below and the output video's rate.
video_h, video_w = 360, 640
video_fps = 30

# Inference window [video_start_sec, video_end_sec] in seconds, processed in chunks of
# batch_size frames; video_jump is the duration of one chunk in seconds.
video_start_sec = 60*3
video_end_sec = 60*5
batch_size = 256
video_jump = batch_size / video_fps

#### Define transforms for inference

In [4]:
# get model transforms
backbone_transforms = model.backbone_transforms()

# pre-procesing transform
preprocess = T.Compose([
    # standard transforms - resizing and center cropping for 1:1 aspect ratio and 224 size
    T.Resize(size = backbone_transforms.resize_size, interpolation = backbone_transforms.interpolation),
    T.CenterCrop(size=backbone_transforms.crop_size),
    
    # standard transforms - normalizing
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean = backbone_transforms.mean, std = backbone_transforms.std)
])

# # post-processing transform
postprocess_bbox = T.Compose([
    # standard transforms - resize bbox to original
    T.Resize(size = min(video_h, video_w)),
])

#### Video Inference loop

In [5]:
# Run the detector over [video_start_sec, video_end_sec] one chunk of ~batch_size frames at
# a time. Collect uint8 (C, H, W) output frames, with boxes drawn on frames classified as "cat".
video_frames = []
half_frame_sec = 0.5 / video_fps
chunk_idx = 0
chunk_start = video_start_sec

while chunk_start < video_end_sec:
    # read_video treats end_pts as inclusive. Stop half a frame before the next chunk's
    # start, otherwise the frame sitting exactly on each boundary is decoded twice.
    # The final chunk keeps video_end_sec itself as its (inclusive) end.
    # (This relies on video_fps matching the file's real frame rate.)
    chunk_end = min(chunk_start + video_jump - half_frame_sec, video_end_sec)
    frames, _, _ = read_video(filename=video_file,
                              start_pts=chunk_start, end_pts=chunk_end,
                              output_format="TCHW", pts_unit='sec')

    # Derive the next start from the chunk index rather than summing floats,
    # so rounding error cannot drift over many chunks.
    chunk_idx += 1
    chunk_start = video_start_sec + chunk_idx * video_jump

    # if no frames were read (past the end of the file), stop
    if frames.numel() == 0:
        break

    # preprocess on the target device
    frames_preproc = preprocess(frames.to(device))

    # Inference only: skip building an autograd graph over the whole batch.
    with torch.no_grad():
        pred_boxes, pred_labels_logits = model(frames_preproc)

    # Sigmoid over the logits, thresholded at 0.5, gives one "cat" flag per frame.
    # reshape(-1) (rather than squeeze()) still yields a list when the chunk has a single frame.
    # NOTE(review): assumes logits are shaped (N,) or (N, 1) -- confirm against SoSiDetectionModel.
    pred_labels_bools = torch.sigmoid(pred_labels_logits).reshape(-1).gt(0.5).tolist()

    # unnormalize TODO Temporary! -- output frames are the cropped model inputs, not the originals
    frames_un_norm = unnormalize(frames_preproc).cpu()

    # Draw boxes only on frames classified as "cat". Other frames are converted to uint8
    # directly, clamped first so small float overshoot cannot wrap around in the cast.
    frames_selected = [
        voc_img_bbox_plot(frames_un_norm[idx], pred_boxes[idx], ["cat"])
        if is_cat
        else (frames_un_norm[idx].clamp(0, 1) * 255).to(torch.uint8)
        for idx, is_cat in enumerate(pred_labels_bools)
    ]
    # add to list of all frames
    video_frames.extend(frames_selected)

#### Video Output

In [6]:
# The output goes next to the input clip, tagged with the checkpoint name (extension stripped).
video_out_file = os.path.join(
    dirpath, 'inference', f'cats_wild_sosi_{os.path.splitext(model_file)[0]}.mp4'
)

# write_video expects a (T, H, W, C) uint8 tensor; the collected frames are (C, H, W).
if not video_frames:
    raise RuntimeError('No frames were collected; check video_file and the start/end window.')
video_output = torch.stack(video_frames).permute(0, 2, 3, 1).cpu()
write_video(video_out_file, video_output, fps=video_fps, video_codec='libx264')