In [3]:
# pytorch
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
# model
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes

# Device string shared by the rest of the notebook: 'cuda' when torch reports
# an available CUDA device, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

##dataset

# other
import IPython.display as display
import numpy as np
from typing import Tuple, List

In [14]:
## person model
# Faster R-CNN with its default pretrained weights, used unchanged to detect
# people. Only detections scoring at least 0.9 are returned.
person_weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
person_model = fasterrcnn_resnet50_fpn_v2(weights=person_weights, box_score_thresh=0.9)
# Frames are moved to `device` before inference, so the model has to live on
# the same device; otherwise the forward pass fails with a device mismatch.
person_model = person_model.to(device)
person_model.eval()

# Inference-time preprocessing that ships with the selected weights.
person_preprocess = person_weights.transforms()

In [8]:
## scooter model
# EfficientNetV2-M with a single-logit head, loaded from a local checkpoint.
scooter_weight_path = "./scooter_model.pth"
scooter_weights = EfficientNet_V2_M_Weights.DEFAULT
scooter_preprocess = scooter_weights.transforms()
scooter_model = efficientnet_v2_m(weights=scooter_weights)
# Replace the final classifier layer with a 1-output layer; read the input
# width from the existing layer instead of hard-coding it.
scooter_model.classifier[-1] = nn.Linear(scooter_model.classifier[-1].in_features, 1)
# The model is wrapped in DataParallel *before* loading the checkpoint, as in
# the original cell (presumably the checkpoint was saved from a DataParallel
# model and its keys carry the "module." prefix -- confirm).
# Leaving `device_ids` unset lets DataParallel use whatever GPUs are visible
# instead of requiring exactly four.
scooter_model = nn.DataParallel(scooter_model)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
scooter_model.load_state_dict(torch.load(scooter_weight_path, map_location=device))
# DataParallel requires the wrapped module's parameters on its first device.
scooter_model = scooter_model.to(device)
scooter_model.eval()


<PIL.Image.Image image mode=RGB size=1058x466 at 0x2B4D1FF74220>


In [None]:
## helmet model
# Faster R-CNN whose box predictor is replaced by an 8-class head, loaded from
# a local checkpoint.
helmet_weight_path = "./helmet_model.pth"
helmet_weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
helmet_preprocess = helmet_weights.transforms()
helmet_model = fasterrcnn_resnet50_fpn_v2(weights=helmet_weights, box_score_thresh=0.9)
in_features = helmet_model.roi_heads.box_predictor.cls_score.in_features
# The head must have the checkpoint's shape before load_state_dict is called.
helmet_model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 8)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
helmet_model.load_state_dict(torch.load(helmet_weight_path, map_location=device))
# Crops handed to this model are moved to `device`, so the model must be there too.
helmet_model = helmet_model.to(device)
helmet_model.eval()

In [None]:
def resize_box(box: Tensor, factor: float = 0.3) -> Tensor:
    """Grow a box symmetrically around its current extent, in place.

    Args:
        box: 1-D tensor ``[x_min, y_min, x_max, y_max]``. It is modified in
            place, so a view into a larger tensor updates that tensor too.
        factor: Fraction of the box width (height) added to *each* horizontal
            (vertical) side. The padding is truncated to a whole number.

    Returns:
        The same ``box`` object, after modification. Coordinates are not
        clamped and may fall outside the image.
    """
    # Measure the box once, before any coordinate changes. Re-reading the
    # width after moving x_min would make the right/bottom padding larger
    # than the left/top padding.
    pad_x = int((box[2] - box[0]) * factor)
    pad_y = int((box[3] - box[1]) * factor)
    box[0] -= pad_x
    box[1] -= pad_y
    box[2] += pad_x
    box[3] += pad_y
    return box


def crop_and_transform(image: Tensor, box: Tensor, device: str) -> Tensor:
    """Cut the region described by ``box`` out of ``image``.

    Args:
        image: Image to crop.
        box: ``[x_min, y_min, x_max, y_max]`` coordinates. Each value is
            truncated to an ``int`` because ``crop`` takes integer pixel
            positions, not 0-dim float tensors.
        device: Currently unused; kept so existing callers keep working.

    Returns:
        The cropped image as returned by ``torchvision.transforms.functional.crop``.
    """
    left, top, right, bottom = (int(coord) for coord in box)
    return torchvision.transforms.functional.crop(image, top, left, bottom - top, right - left)


def change_image_for_scooter(image_tensor: Tensor) -> Tensor:
    """Apply the scooter model's preprocessing and move the result to ``device``.

    A leading dimension of size 1, if present after preprocessing, is removed.
    """
    transformed = scooter_preprocess(image_tensor)
    without_leading_singleton = transformed.squeeze(0)
    return without_leading_singleton.to(device)


def is_person_valid(label: int, person_label: int = 1) -> bool:
    """Tell whether a detection label is the person class.

    Args:
        label: Class id as an ``int`` or a 0-dim integer tensor.
        person_label: Class id that counts as a person (defaults to 1).

    Returns:
        A plain Python ``bool``, also when ``label`` is a tensor element.
    """
    # int() unwraps 0-dim tensors so the comparison yields a real bool
    # rather than a tensor.
    return int(label) == person_label


def change_image_for_helmet(image_tensor: Tensor) -> Tensor:
    """Apply the helmet model's preprocessing and move the result to ``device``.

    A leading dimension of size 1, if present after preprocessing, is removed.
    """
    transformed = helmet_preprocess(image_tensor)
    without_leading_singleton = transformed.squeeze(0)
    return without_leading_singleton.to(device)


def is_helmet_present(model, image_tensor: Tensor, helmet_class_ids=(1, 2, 3)) -> bool:
    """Check if the helmet detector finds a helmet class in the image.

    Args:
        model: Detection model returning, per image, a dict with a
            ``"labels"`` tensor.
        image_tensor: Image (typically a person crop) to inspect.
        helmet_class_ids: Class ids counted as "helmet". The default keeps
            the original ids 1, 2 and 3 (presumably the helmet classes of the
            fine-tuned detector -- confirm against its label map).

    Returns:
        True if any detection in the first image has one of the helmet ids.
    """
    prepared = change_image_for_helmet(image_tensor)
    # Detection models take a sequence of CHW images, not a bare CHW tensor
    # (iterating a CHW tensor would yield single channels).
    images = [prepared] if prepared.dim() == 3 else list(prepared)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        result = model(images)
    helmet_labels = result[0]["labels"].tolist()
    return any(label in helmet_class_ids for label in helmet_labels)


def is_scooter_present(model, image_tensor: Tensor, threshold: float = 0.2) -> bool:
    """Check if the scooter classifier scores the image above ``threshold``.

    Args:
        model: Classifier producing a single logit for the image.
        image_tensor: Image (typically a person crop) to classify.
        threshold: Minimum sigmoid probability counted as "scooter present"
            (defaults to the original 0.2).

    Returns:
        True when ``sigmoid(logit) > threshold``.
    """
    batch = change_image_for_scooter(image_tensor)
    if batch.dim() == 3:
        # The classifier works on batches; add the batch dimension for a
        # single CHW image.
        batch = batch.unsqueeze(0)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logit = model(batch)
    return bool(torch.sigmoid(logit).item() > threshold)


@torch.no_grad()
def inference(
        helmet_model,
        person_model,
        scooter_model,
        image: Tensor,
        device: str = "cpu",
) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
    """Detect people in ``image`` and classify each one as a scooter rider or not.

    For every detection labelled as a person, the box is enlarged with
    ``resize_box`` and the corresponding crop is checked for a scooter. If a
    scooter is found, the label is rewritten:

    * 2 -- scooter present and a helmet was detected,
    * 3 -- scooter present and no helmet was detected.

    Detections without a scooter keep the person label.

    Args:
        helmet_model: Model passed to ``is_helmet_present``.
        person_model: Detection model run on the whole image.
        scooter_model: Model passed to ``is_scooter_present``.
        image: Frame to analyse.
        device: Forwarded to ``crop_and_transform``.

    Returns:
        ``(boxes, labels, scores)`` lists with one entry per person
        detection. ``boxes`` holds the enlarged boxes.
    """
    batch = [person_preprocess(image)]
    # The model returns one dict per input image; there is exactly one here.
    person_result = person_model(batch)[0]

    boxes: Tensor = person_result["boxes"]
    labels: Tensor = person_result["labels"]
    scores: Tensor = person_result["scores"]

    valid_indices: List[int] = [
        index for index, label in enumerate(labels) if is_person_valid(label)
    ]

    for index in valid_indices:
        # boxes[index] is a view, so the enlargement is also written to `boxes`.
        box = resize_box(boxes[index])
        box_image_tensor = crop_and_transform(image, box, device)

        if is_scooter_present(scooter_model, box_image_tensor):
            if is_helmet_present(helmet_model, box_image_tensor):
                labels[index] = 2
            else:
                labels[index] = 3
    return (
        [boxes[i] for i in valid_indices],
        [labels[i] for i in valid_indices],
        [scores[i] for i in valid_indices],
    )


def draw_box(
        img: Tensor,
        boxes: List[Tensor],
        labels: List[str],
):
    """Draw labelled boxes on ``img``, display the result and return it.

    Args:
        img: Image tensor to annotate.
        boxes: Either a single ``(N, 4)`` tensor or a list of 4-element box
            tensors (a list is stacked before drawing).
        labels: One label per box. Non-string labels such as ints or 0-dim
            tensors are converted to their integer text.

    Returns:
        The annotated image tensor produced by ``draw_bounding_boxes``, so
        callers can collect frames (previously nothing was returned).
    """
    if not isinstance(boxes, Tensor):
        boxes = torch.stack(list(boxes))
    # draw_bounding_boxes expects text labels.
    text_labels = [
        label if isinstance(label, str) else str(int(label))
        for label in labels
    ]
    drawn = draw_bounding_boxes(img, boxes=boxes,
                                labels=text_labels,
                                colors="red",
                                fill=True,
                                width=4, font_size=40)
    display.display(to_pil_image(drawn.detach()))
    return drawn
num_frames = 100
video_path = "./test_video.mp4"
output_path = "./result_video.mp4"
# read_video returns (video_frames, audio_frames, info). Its start/end
# arguments are presentation timestamps, not frame counts, so the frame limit
# is applied by slicing instead of via `end_pts`.
# NOTE(review): this decodes the whole file; for long videos pass an `end_pts`
# in seconds to bound memory.
video_frames, audio_frames, info = torchvision.io.read_video(video_path, pts_unit="sec")
# `video` keeps the original 3-tuple layout, limited to the first `num_frames` frames.
video = (video_frames[:num_frames], audio_frames, info)

In [None]:
result_video = []
# `video` is the (video_frames, audio_frames, info) tuple from read_video.
# Frames are laid out (T, H, W, C).
frames, _, info = video
for frame_idx in range(min(num_frames, frames.shape[0])):
    # The models and drawing utilities expect channel-first (C, H, W) images.
    frame = frames[frame_idx].permute(2, 0, 1).to(device)
    result = inference(helmet_model,
                       person_model=person_model,
                       scooter_model=scooter_model,
                       image=frame,
                       device=device)
    if len(result[0]) == 0:
        # NOTE(review): frames without a person are dropped from the output
        # video, as before -- confirm this is intended.
        continue
    d_box = draw_box(frame,
                     boxes=torch.stack(result[0]),
                     labels=[str(int(label)) for label in result[1]])
    # Back to (H, W, C) on the CPU, the layout write_video expects.
    result_video.append(d_box.permute(1, 2, 0).cpu())

if result_video:
    # write_video takes one (T, H, W, C) uint8 tensor rather than a list.
    torchvision.io.write_video(output_path, torch.stack(result_video), fps=info["video_fps"])