In [1]:
import torch
from torchvision.io import read_video, write_video
from torchvision.transforms import v2 as T

# paths
import os
import sys

# Make the project root (the parent of the notebook's working directory)
# importable, so the local `models` and `utils` packages resolve below.
dirpath = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard the append so re-running this cell does not keep growing sys.path.
if dirpath not in sys.path:
    sys.path.append(dirpath)

# my imports
from models.SoSi_detection import SoSiDetectionModel
from utils.plot_utils import unnormalize, voc_img_bbox_plot

# Reload edited modules (e.g. the local models/ and utils/ packages) automatically
# before each cell executes, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# torch setup: run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Load the model:

In [2]:
# Locate the saved weights under the project root. Build the path from separate
# components so it resolves on any OS (a literal backslash is not a path
# separator on Linux/macOS).
model_dir = os.path.join(dirpath, 'models', 'model_savepoints')
model_file = 'p2_model_Feb-19_18-09-20.pth'
model_path = os.path.join(model_dir, model_file)

# Build the architecture, then load the trained state_dict onto the selected device.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
model = SoSiDetectionModel()
load_status = model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
print(load_status)  # expected: <All keys matched successfully>

# Move to the device and switch to eval mode; `;` suppresses the model repr.
model.to(device).eval();

<All keys matched successfully>


#### Get video file:

In [3]:
# Input clip, resolved from the project root with OS-independent separators.
video_file = os.path.join(dirpath, 'inference', 'kittens_video.mp4')

# Frame size of the input clip (height, width).
# NOTE(review): hard-coded and not used by the cells shown; presumably meant for
# rescaling predicted boxes back to full-frame coordinates — confirm against the clip.
video_h, video_w = 360, 640

#### Define transforms for inference

In [4]:
# Reuse the backbone's own inference-transform parameters (sizes, interpolation,
# normalization stats) so preprocessing matches the backbone's expectations.
backbone_transforms = model.backbone_transforms()

# Geometry: resize, then center-crop to a square (1:1 aspect ratio, 224 size).
# NOTE: for a wide frame the crop drops the left/right edges, so predicted boxes are
# presumably in crop coordinates rather than original-frame coordinates.
geometry_steps = [
    T.Resize(size=backbone_transforms.resize_size,
             interpolation=backbone_transforms.interpolation),
    T.CenterCrop(size=backbone_transforms.crop_size),
]

# Values: wrap as an image, convert to float32 scaled to [0, 1], then normalize
# with the backbone's mean/std.
value_steps = [
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=backbone_transforms.mean, std=backbone_transforms.std),
]

# pre-processing transform: geometry first, then value conversion
preprocess = T.Compose(geometry_steps + value_steps)

# post-processing transform: not implemented yet.
# TODO: map predicted boxes from the crop back onto the original frame.


#### Video Inference loop

In [5]:
# Inference windowing: each step reads `batch_size` frames' worth of video,
# i.e. a window of batch_size / video_fps seconds, until `video_end_sec`.
video_fps = 30        # NOTE(review): assumed frame rate of the clip — TODO read from video metadata
batch_size = 128      # frames per inference batch
video_end_sec = 1     # stop after this many seconds of video
video_jump = batch_size / video_fps   # seconds of video covered by one batch

In [None]:
# Run the detector over the clip in windows of `video_jump` seconds
# (about `batch_size` frames each) and collect the annotated frames per batch.
video_start_sec = 0
video_frames = []
half_frame_sec = 0.5 / video_fps

# Inference only: don't build autograd graphs (saves memory for large batches).
with torch.no_grad():
    while video_start_sec < video_end_sec:
        window_end_sec = min(video_start_sec + video_jump, video_end_sec)
        # read_video keeps frames with start_pts <= pts <= end_pts (both inclusive),
        # so on intermediate windows stop half a frame early; otherwise the frame on
        # the boundary is decoded twice (end of this window and start of the next).
        if window_end_sec < video_end_sec:
            read_end_sec = window_end_sec - half_frame_sec
        else:
            read_end_sec = window_end_sec
        frames, _, _ = read_video(filename=video_file,
                                  start_pts=video_start_sec, end_pts=read_end_sec,
                                  output_format="TCHW", pts_unit='sec')
        video_start_sec += video_jump
        # if no frames were read (past the end of the clip), stop
        if frames.numel() == 0:
            break
        frames = frames.to(device)
        # preprocess, then run the model on the *preprocessed* batch
        frames_preprocessed = preprocess(frames)
        preds = model(frames_preprocessed)
        # TODO: rescale boxes from the center crop back to the original frame size
        # draw predicted boxes on the un-normalized frames
        frames_un_norm = unnormalize(frames_preprocessed)
        frames_boxes = voc_img_bbox_plot(frames_un_norm, preds['boxes'], preds['labels'])
        video_frames.append(frames_boxes)

Value(False)


In [None]:
# Write the annotated frames out as a video file.
assert video_frames, "no frames were processed; nothing to write"

# Each list entry holds one batch of frames, so join them along time (dim 0);
# torch.stack would add an extra dimension and fail on a shorter final batch.
# NOTE(review): assumes voc_img_bbox_plot returns a (T, C, H, W) tensor per batch — confirm.
video_output = torch.cat(video_frames, dim=0)
if video_output.is_floating_point():
    # write_video needs uint8; assumes float frames are in [0, 1] — TODO confirm
    video_output = (video_output.clamp(0, 1) * 255).to(torch.uint8)

# write_video expects a (T, H, W, C) tensor on the CPU.
# NOTE(review): output file name is a placeholder — adjust as needed.
output_file = os.path.join(dirpath, 'inference', 'kittens_video_boxes.mp4')
write_video(output_file, video_output.permute(0, 2, 3, 1).cpu(), fps=video_fps)