In [None]:
import torch
import torchvision.transforms as T
import torchvision.io

# paths
import os
import sys

# make the parent of the current working directory importable
# (presumably the project root that holds `models/` and `utils/` - confirm)
dirpath = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(dirpath)

# my imports
from models.SoSi_detection import SoSiDetectionModel
from utils.plot_utils import voc_img_bbox_plot

# the lifesaver: autoreload mode 2 re-imports edited modules before each cell
# runs, so changes to the project's .py files show up without a kernel restart
%load_ext autoreload
%autoreload 2

# torch setup: use the GPU when CUDA is available, otherwise stay on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Load the model:

In [15]:
# locate the checkpoint; the path is assembled from separate components so it
# is valid on every OS (a literal 'models\\model_savepoints\\' only works on Windows)
model_dir = os.path.join(dirpath, 'models', 'model_savepoints')
model_file = 'p1_model_a_1.pth'
model_path = os.path.join(model_dir, model_file)

# build the network and restore its weights
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source
model = SoSiDetectionModel()
success = model.load_state_dict(torch.load(model_path, map_location=device))
print(success)  # reports missing / unexpected keys, if any

# move to the selected device and switch to inference mode
model.to(device).eval();

<All keys matched successfully>


In [None]:
# TODO: unfinished scratch cell - it only contained "torchvision.io." (no reader
# function was chosen yet), which is a SyntaxError and stops "Run All"

In [None]:
# TODO: define `infer_transform` - the assignment was left without a right-hand
# side ("infer_transform = "), which is a SyntaxError and stops "Run All"

In [None]:
def get_transform(train, backbone_transforms):
    """Build the preprocessing pipeline from a backbone's transform settings.

    Args:
        train: when truthy, a random horizontal flip and a colour jitter are
            inserted after the resize / centre-crop steps.
        backbone_transforms: object exposing ``resize_size``, ``interpolation``,
            ``crop_size``, ``mean`` and ``std`` attributes (presumably the
            transforms preset of the pretrained backbone weights - confirm).

    Returns:
        A composed transform: resize, centre crop, optional augmentation,
        conversion to a float32 image scaled to [0, 1], then normalization.
    """
    # ToImage and ToDtype only exist in the v2 transforms API, while this
    # notebook binds ``T`` to the classic ``torchvision.transforms`` module.
    # Import v2 locally so the function does not fail with AttributeError.
    from torchvision.transforms import v2

    # standard transforms - resizing and center cropping for 1:1 aspect ratio
    steps = [
        v2.Resize(
            size=backbone_transforms.resize_size,
            interpolation=backbone_transforms.interpolation,
        ),
        v2.CenterCrop(size=backbone_transforms.crop_size),
    ]

    # if training mode, add flips and jitters
    if train:
        steps.append(v2.RandomHorizontalFlip(0.5))
        steps.append(v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1))

    # standard transforms - tensor conversion and normalizing
    steps.append(v2.ToImage())
    steps.append(v2.ToDtype(torch.float32, scale=True))  # scale to 0-1
    # TODO normalization makes everything weird
    steps.append(v2.Normalize(mean=backbone_transforms.mean, std=backbone_transforms.std))

    return v2.Compose(steps)

In [None]:


# Define preprocessing
def preprocess_frame(frame, transform):
    """Turn one OpenCV frame into a batched input for the model.

    Args:
        frame: image in OpenCV's BGR channel order.
        transform: callable applied to the RGB image; its result must provide
            ``unsqueeze`` (e.g. a torch tensor).

    Returns:
        The transformed image with a new leading batch dimension.
    """
    # OpenCV delivers BGR; swap to RGB before handing the frame to the transform
    return transform(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).unsqueeze(0)

# Define postprocessing
def postprocess_predictions(pred_bboxes, pred_labels, threshold=0.5):
    """Convert raw model outputs into corner-format boxes and hard labels.

    Args:
        pred_bboxes: tensor of boxes whose last dimension is
            ``(cx, cy, w, h)`` (assumed to be the model's output format - confirm).
        pred_labels: tensor of logits; they are passed through a sigmoid.
        threshold: probability above which a label becomes 1 (otherwise 0).

    Returns:
        Tuple ``(boxes, labels)`` of numpy arrays: boxes in ``(x1, y1, x2, y2)``
        format and integer 0/1 labels.
    """
    # box_convert is not imported anywhere in this notebook; import it locally
    # so the function does not fail with NameError.
    from torchvision.ops import box_convert

    # logits -> probabilities -> thresholded 0/1 labels
    probabilities = torch.sigmoid(pred_labels)
    labels = (probabilities > threshold).long()

    # centre/size boxes -> corner boxes, which is what the drawing code needs
    boxes = box_convert(pred_bboxes, in_fmt="cxcywh", out_fmt="xyxy")

    # detach so the numpy conversion also works outside torch.no_grad()
    return boxes.detach().cpu().numpy(), labels.detach().cpu().numpy()

# Draw bounding boxes on a frame
def draw_predictions(frame, boxes, labels):
    """Draw every predicted box and its class label onto ``frame`` in place.

    Args:
        frame: image the OpenCV drawing calls write into.
        boxes: iterable of ``(x1, y1, x2, y2)`` boxes; values are cast to int.
        labels: iterable of labels, paired with ``boxes`` by position.

    Returns:
        The same ``frame`` object, now annotated.
    """
    green, red = (0, 255, 0), (0, 0, 255)  # BGR colour tuples

    for corners, cls in zip(boxes, labels):
        left, top, right, bottom = (int(value) for value in corners)

        # class 1 is drawn in green, everything else in red
        if cls == 1:
            colour = green
        else:
            colour = red

        cv2.rectangle(frame, (left, top), (right, bottom), colour, 2)
        # caption sits 10 px above the top-left corner of the box
        cv2.putText(frame, f"Class: {cls}", (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, colour, 2)

    return frame

# Run inference on video
def infer_on_video(model_path, video_path, output_path, device="cuda"):
    """Run the detector on every frame of a video and save an annotated copy.

    Each frame is preprocessed, passed through the model, annotated with the
    predicted boxes / labels, written to ``output_path`` and shown in an
    OpenCV window. Pressing ``q`` in that window stops processing early.

    Args:
        model_path: path to a saved ``state_dict`` for ``SoSiDetectionModel``.
        video_path: path of the input video (opened with ``cv2.VideoCapture``).
        output_path: path of the XVID-encoded output video.
        device: torch device (or device string) used for inference.

    Raises:
        IOError: if ``video_path`` cannot be opened.
    """
    # Same loading recipe as the "Load the model" cell of this notebook; the
    # ``load_model`` helper that used to be called here is not defined anywhere.
    model = SoSiDetectionModel()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()

    # Define preprocessing transform
    transform = T.Compose([
        T.ToPILImage(),
        T.Resize((224, 224)),  # Adjust to match model input size
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Adjust for your model
    ])

    # Open video
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video: {video_path}")

        # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would
        # change the playback speed of the written video.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Define video writer
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # End of video

            # Preprocess frame
            input_tensor = preprocess_frame(frame, transform).to(device)

            # Run inference
            with torch.no_grad():
                pred_bboxes, pred_labels = model(input_tensor)

            # Postprocess predictions
            boxes, labels = postprocess_predictions(pred_bboxes, pred_labels)

            # NOTE(review): boxes are drawn on the original-size frame exactly as
            # predicted; if the model predicts in the resized 224x224 space (or
            # in normalized units) they must be rescaled first - confirm.
            output_frame = draw_predictions(frame, boxes, labels)

            # Write to output video
            out.write(output_frame)

            # Display frame (optional)
            cv2.imshow("Inference", output_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):  # Press 'q' to stop
                break
    finally:
        # Always release the capture, the writer and the preview window, even
        # when preprocessing or inference raises part-way through the video.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

# Example usage (the file names look like placeholders - point them at real
# files before running this cell)
infer_on_video(
    model_path="model_weights.pth",
    video_path="input_video.mp4",
    output_path="output_video.avi",
)
