In [7]:
import torch
import torchvision
import cv2
import numpy as np
from pathlib import Path
from boxmot import BoTSORT
from sahi.models.base import DetectionModel
from sahi.predict import get_sliced_prediction

# Define a custom detection model that is compatible with SAHI
class TorchvisionDetectionModel(DetectionModel):
    def __init__(self, model, confidence_threshold=0.5, device='cpu'):
        super().__init__(confidence_threshold=confidence_threshold, device=device)
        self.model = model.to(device)
        self.device = device
        self.confidence_threshold = confidence_threshold
        self.model.eval()

    def load_model(self):
        # The model is already loaded during initialization, so we just set it
        self.set_model(self.model)

    def set_model(self, model):
        # Set the model directly
        self.model = model

    def perform_inference(self, image):
        # Convert the image to a tensor and move to the specified device
        image_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(self.device)

        # Perform detection using the model
        with torch.no_grad():
            outputs = self.model(image_tensor)[0]

        return outputs

    def convert_original_predictions(self, outputs, original_image, shift_amount=None, full_shape=None):
        # Convert the model output to the format expected by SAHI
        detection_list = []
        for i, score in enumerate(outputs['scores']):
            if score >= self.confidence_threshold:
                x1, y1, x2, y2 = outputs['boxes'][i].cpu().numpy()  # Bounding box coordinates

                # Apply shift if shift_amount is provided
                if shift_amount is not None:
                    x1 += shift_amount[1]
                    y1 += shift_amount[0]
                    x2 += shift_amount[1]
                    y2 += shift_amount[0]

                conf = score.item()  # Confidence score
                label = outputs['labels'][i].item()  # Class label

                detection_list.append({
                    'bbox': [x1, y1, x2, y2],
                    'score': conf,
                    'category_id': label
                })

        return detection_list

# Load a pre-trained Faster R-CNN model from torchvision
device = torch.device('cpu')  # Use 'cuda' if you have a GPU
detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Wrap the torchvision model in the custom SAHI-compatible detection model
detection_model = TorchvisionDetectionModel(model=detector, device=device, confidence_threshold=0.5)

# Initialize BoTSORT Tracker
tracker = BoTSORT(
    reid_weights=Path('osnet_x0_25_msmt17.pt'),  # Path to ReID model
    device=device,  # Use CPU for inference
    half=False
)

# Open the video file (use 0 for webcam or provide a video file path)
vid = cv2.VideoCapture(0)

# Function to generate a unique color for each track ID
def get_color(track_id):
    np.random.seed(int(track_id))
    return tuple(np.random.randint(0, 255, 3).tolist())

while True:
    # Capture frame-by-frame
    ret, frame = vid.read()

    # If ret is False, it means we have reached the end of the video or there's an error
    if not ret:
        break

    # Get sliced predictions using SAHI's get_sliced_prediction
    result = get_sliced_prediction(
        frame,
        detection_model,
        slice_height=256,
        slice_width=256,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2
    )

    # Extract detections from result
    num_predictions = len(result.object_prediction_list)
    dets = np.zeros([num_predictions, 6], dtype=np.float32)
    for ind, object_prediction in enumerate(result.object_prediction_list):
        bbox = object_prediction.bbox.to_xyxy()
        dets[ind, :4] = np.array(bbox, dtype=np.float32)
        dets[ind, 4] = object_prediction.score.value
        dets[ind, 5] = object_prediction.category.id

    # Update the tracker with the detections
    tracks = tracker.update(dets, frame)  # (M x (x1, y1, x2, y2, id, conf, cls, ind))

    # Draw the tracking results on the image
    for track in tracks:
        x1, y1, x2, y2, track_id, conf, cls = track[:7].astype('int')
        color = get_color(track_id)

        # Draw bounding box with unique color
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        # Add text with ID, confidence, and class
        cv2.putText(frame, f'ID: {track_id}, Conf: {conf:.2f}, Class: {cls}', 
                    (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # Display the frame with tracking results
    cv2.imshow('BoXMOT + Torchvision with Tiled Inference', frame)

    # Simulate wait for key press to continue, press 'q' to exit
    key = cv2.waitKey(1) & 0xFF
    if key == ord(' ') or key == ord('q'):
        break

# Release the video capture and close all OpenCV windows
vid.release()
cv2.destroyAllWindows()


[32m2024-09-30 23:12:25.295[0m | [1mINFO    [0m | [36mboxmot.utils.torch_utils[0m:[36mselect_device[0m:[36m52[0m - [1mYolo Tracking v11.0.2 🚀 Python-3.11.5 torch-2.2.2CPU[0m
[32m2024-09-30 23:12:25.316[0m | [32m[1mSUCCESS [0m | [36mboxmot.appearance.reid_model_factory[0m:[36mload_pretrained_weights[0m:[36m183[0m - [32m[1mLoaded pretrained weights from osnet_x0_25_msmt17.pt[0m


Performing prediction on 60 slices.


TypeError: TorchvisionDetectionModel.convert_original_predictions() got an unexpected keyword argument 'full_shape'