In [38]:
import torch
from torchvision.transforms import v2 as T
import torchvision.io

# paths
import os
import sys

# set paths: make the parent of the notebook's working directory importable so the
# project packages imported below (`models`, `utils`) resolve
dirpath = os.path.abspath(os.path.join(os.getcwd(), '..'))
# guard so re-running this cell does not keep appending duplicate sys.path entries
if dirpath not in sys.path:
    sys.path.append(dirpath)

# my imports
from models.SoSi_detection import SoSiDetectionModel
from utils.plot_utils import unnormalize, voc_img_bbox_plot

# the lifesaver
%load_ext autoreload
%autoreload 2

# torch setup: use the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Load the model:

In [39]:
# load model
# build the checkpoint path component-by-component so it resolves on any OS
model_dir = os.path.join(dirpath, 'models', 'model_savepoints')
model_file = 'p2_model_Feb-19_18-09-20.pth'
model_path = os.path.join(model_dir, model_file)

# build and load model
model = SoSiDetectionModel()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading
state_dict = torch.load(model_path, map_location=device, weights_only=True)
# strict loading (the default) raises if the checkpoint keys do not match the current
# model definition. A mismatch in `bbox_head.*` keys means the head architecture changed
# after this checkpoint was saved: load a checkpoint trained with the current head, or
# restore the matching head definition, rather than silencing it with strict=False.
load_result = model.load_state_dict(state_dict)
print(load_result)
model.to(device).eval();

RuntimeError: Error(s) in loading state_dict for SoSiDetectionModel:
	Missing key(s) in state_dict: "bbox_head.1.weight", "bbox_head.1.bias", "bbox_head.3.weight", "bbox_head.3.bias", "bbox_head.5.weight", "bbox_head.5.bias", "bbox_head.7.weight", "bbox_head.7.bias". 
	Unexpected key(s) in state_dict: "bbox_head.0.weight", "bbox_head.0.bias", "bbox_head.4.weight", "bbox_head.4.bias". 

#### Define transforms for inference

In [30]:
def get_transform(backbone_transforms):
    """Build the inference preprocessing pipeline from a backbone transform preset.

    Args:
        backbone_transforms: an object exposing ``resize_size``, ``interpolation``,
            ``crop_size``, ``mean`` and ``std`` attributes (presumably a torchvision
            weights preset — confirm against ``SoSiDetectionModel.backbone_transforms``).

    Returns:
        A ``T.Compose`` that resizes, center-crops, converts to float32 scaled to
        [0, 1], and normalizes with the preset's mean/std.
    """
    # geometry first: resize, then center-crop to the preset's square crop size
    geometry = [
        T.Resize(size=backbone_transforms.resize_size, interpolation=backbone_transforms.interpolation),
        T.CenterCrop(size=backbone_transforms.crop_size),
    ]
    # then convert to an image tensor, rescale uint8 -> [0, 1] float, and normalize
    intensity = [
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(mean=backbone_transforms.mean, std=backbone_transforms.std),
    ]
    return T.Compose(geometry + intensity)

In [31]:
# Build the inference preprocessing from the transform preset the model reports for its
# backbone, so resize/crop/normalization match what the backbone expects
# (NOTE(review): assumes backbone_transforms() returns a preset exposing resize_size,
# crop_size, interpolation, mean and std — confirm in SoSiDetectionModel).
preprocess = get_transform(model.backbone_transforms())

#### Get video file:

In [None]:
# input video; path built component-by-component so it resolves on any OS
video_file = os.path.join(dirpath, 'inference', 'kittens_video.mp4')

# calculate end time and time jump
# only the first `video_end_sec` seconds of the video are processed
video_end_sec = 1
# take the frame rate from the container instead of assuming one; fall back to 30 fps
# when the container does not report a rate (read_video_timestamps returns (pts, fps))
detected_fps = torchvision.io.read_video_timestamps(video_file, pts_unit='sec')[1]
video_fps = detected_fps if detected_fps else 30
# frames per model call; `video_jump` is the number of seconds one batch spans
batch_size = 128
video_jump = batch_size / video_fps

In [None]:
video_start_sec = 0
video_frames = []

# loop over the video in windows of `video_jump` seconds (about `batch_size` frames each)
while video_start_sec < video_end_sec:
    window_end_sec = video_start_sec + video_jump
    if window_end_sec < video_end_sec:
        # read_video's end_pts is inclusive: stop half a frame early so the frame lying
        # exactly on the window boundary is not decoded again by the next window
        end_pts = window_end_sec - 0.5 / video_fps
    else:
        end_pts = video_end_sec
    # read_video returns (video_frames, audio_frames, info); keep only the video tensor
    frames = torchvision.io.read_video(filename=video_file,
                                       start_pts=video_start_sec, end_pts=end_pts,
                                       output_format="TCHW", pts_unit='sec')[0]
    video_start_sec += video_jump
    # if no frames read, break the loop
    if frames.numel() == 0:
        break
    # preprocess (resize / crop / normalize); the CPU copy is kept for plotting
    frames_preprocessed = preprocess(frames)
    # infer on the preprocessed batch (not the raw uint8 frames) on the model's device,
    # without recording an autograd graph
    with torch.inference_mode():
        preds = model(frames_preprocessed.to(device))
    # TODO Rescale video etc
    # plot bbox
    # NOTE(review): preds may live on `device` (GPU) while the frames are on the CPU —
    # confirm voc_img_bbox_plot handles that, or move the predictions to the CPU first
    frames_un_norm = unnormalize(frames_preprocessed)
    frames_boxes = voc_img_bbox_plot(frames_un_norm, preds['boxes'], preds['labels'])
    video_frames.append(frames_boxes)

In [None]:
# write video
# NOTE(review): assumes voc_img_bbox_plot returns one (T, C, H, W) tensor per batch — confirm.
# Batches can differ in length (the last one is usually shorter), so concatenate along
# the time axis instead of stacking into a new dimension.
video_output = torch.cat(video_frames, dim=0)
# write_video expects frames laid out as (T, H, W, C) with dtype uint8
video_output = video_output.permute(0, 2, 3, 1)
if video_output.dtype != torch.uint8:
    # assumes float frames are in [0, 1] after unnormalize — TODO confirm
    video_output = (video_output.clamp(0, 1) * 255).round().to(torch.uint8)
video_output_file = os.path.join(dirpath, 'inference', 'kittens_video_boxes.mp4')
torchvision.io.write_video(video_output_file, video_output.cpu(), fps=video_fps)