In [None]:
import torch
import torchvision
import cv2
import numpy as np

In [None]:
# Vehicle detector: custom YOLOv5 weights hosted on the Hugging Face Hub, loaded via torch.hub.
# force_reload=True makes torch.hub fetch the ultralytics/yolov5 repo afresh on every run
# (slow and requires network access).
hf_path = 'https://huggingface.co/jspark2000/yolov5-vehicle/resolve/main/best.pt'
yolov5_model = torch.hub.load(
    'ultralytics/yolov5',
    'custom',
    path=hf_path,
    force_reload=True,
)

# Instance segmenter: torchvision Mask R-CNN (ResNet-50 FPN) with pretrained weights, put into
# inference mode. nn.Module.eval() returns the module itself, so the call can be chained.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in favour of
# `weights=...`; kept as-is because the installed torchvision version is unknown.
mask_rcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()


def load_image(image_path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        image_path: path of the image file to read.

    Returns:
        The image as returned by ``cv2.imread``, converted from BGR to RGB.

    Raises:
        FileNotFoundError: when ``cv2.imread`` returns None (file missing or unreadable).
    """
    bgr_pixels = cv2.imread(image_path)
    if bgr_pixels is not None:
        # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
        return cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Image file not found at {image_path}")


def detect_objects(image):
    """Run the YOLOv5 detector on a single image and return its detections as numpy.

    Args:
        image: the image to run detection on (``main`` passes the RGB array from ``load_image``).

    Returns:
        A numpy array with one row per detection, taken from ``Detections.xyxy`` of the first
        (and only) image. This file reads columns 0-3 as box corners and column 5 as the class
        id; presumably column 4 is the confidence, per YOLOv5's xyxy layout.
    """
    detections = yolov5_model(image)
    # ``xyxy`` holds one tensor per input image; only one image was passed in.
    first_image_rows = detections.xyxy[0]
    return first_image_rows.cpu().numpy()


def segment_objects(image, boxes, target_size=(224, 224)):
    """Segment each detected box with Mask R-CNN and return one mask per box.

    Each box is cropped from ``image``, resized to ``target_size`` and fed to
    ``mask_rcnn_model`` on its own.

    Args:
        image: HxWx3 numpy image; assumed to be the uint8 RGB array from ``load_image``.
        boxes: iterable of detection rows; columns 0-3 are read as x1, y1, x2, y2 pixel coords.
        target_size: (height, width) pair every crop is resized to before segmentation.

    Returns:
        A list with exactly one ``(x1, y1, x2, y2, mask)`` tuple per box, in input order.
        The coordinates are the integer box corners clamped to the image. ``mask`` is a
        uint8 array of shape ``target_size`` (soft mask scaled to 0-255). It lives in the
        resized crop's frame, not the box's. It is all zeros when the crop is empty or
        Mask R-CNN found no instance in it.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(target_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        torchvision.transforms.ToTensor()
    ])

    img_h, img_w = image.shape[:2]
    masks = []

    for box in boxes:
        x1, y1, x2, y2 = map(int, box[:4])
        # Clamp to the image: negative indices would otherwise wrap the slice around.
        x1, x2 = max(0, min(x1, img_w)), max(0, min(x2, img_w))
        y1, y2 = max(0, min(y1, img_h)), max(0, min(y2, img_h))

        if x2 <= x1 or y2 <= y1:
            # Nothing to segment, but keep the result aligned one-to-one with ``boxes``.
            masks.append((x1, y1, x2, y2, np.zeros(target_size, dtype=np.uint8)))
            continue

        # Crop the uint8 HWC array directly. ToPILImage maps uint8 HWC to an RGB image,
        # whereas a float32 HWC ndarray is rejected by some torchvision versions.
        crop = np.ascontiguousarray(image[y1:y2, x1:x2])
        batch = transform(crop).unsqueeze(0)

        with torch.no_grad():
            output = mask_rcnn_model(batch)

        predicted_masks = output[0]['masks']
        if predicted_masks.shape[0] == 0:
            # Mask R-CNN found no instance in this crop; indexing [0] would raise.
            mask = np.zeros(target_size, dtype=np.uint8)
        else:
            # First predicted instance (presumably the highest-scoring one); any class.
            mask = predicted_masks[0, 0].mul(255).byte().cpu().numpy()
        masks.append((x1, y1, x2, y2, mask))

    return masks


def draw_boxes_and_masks(image, boxes, masks, overlay=False):
    """Compute each vehicle's pixel height and optionally overlay its mask and box.

    Args:
        image: HxWx3 uint8 image array.
        boxes: detection rows paired positionally with ``masks``. They only bound the
            iteration, because ``zip`` stops at the shorter of the two sequences.
        masks: ``(x1, y1, x2, y2, mask)`` tuples as produced by ``segment_objects``.
        overlay: when True, blend each mask in green over its box and draw the box outline.
            This uses OpenCV and returns a new array. Defaults to False, which leaves
            ``image`` untouched.

    Returns:
        ``(image_out, heights)``. ``heights`` lists ``y2 - y1`` (pixels) for each pair.
        With ``overlay=False``, ``image_out`` is the very same ``image`` object.
    """
    heights = []
    regions = []

    for _box, (x1, y1, x2, y2, mask) in zip(boxes, masks):
        heights.append(y2 - y1)
        regions.append((x1, y1, x2, y2, mask))

    if not overlay:
        return image, heights

    colored_mask = np.zeros_like(image, dtype=np.uint8)
    for x1, y1, x2, y2, mask in regions:
        if x2 <= x1 or y2 <= y1:
            continue
        # The mask is in the resized-crop frame; stretch it back onto the box.
        # cv2.resize takes dsize as (width, height); nearest keeps the threshold crisp.
        mask_in_box = cv2.resize(mask, (x2 - x1, y2 - y1), interpolation=cv2.INTER_NEAREST)
        colored_mask[y1:y2, x1:x2][mask_in_box > 127] = (0, 255, 0)

    blended = cv2.addWeighted(image, 1, colored_mask, 0.5, 0)
    for x1, y1, x2, y2, _mask in regions:
        cv2.rectangle(blended, (x1, y1), (x2, y2), (0, 255, 0), 2)
    return blended, heights


def main(image_path,
         output_path='../images/results/output_image_with_masks.jpg',
         vehicle_class_ids=(1, 10)):
    """Detect vehicles, segment them, save the output image and print their pixel heights.

    Args:
        image_path: path of the input image.
        output_path: where the output image is written (converted back to BGR for OpenCV).
            Missing parent directories are created.
        vehicle_class_ids: detector class ids kept as vehicles. The default ``(1, 10)`` is the
            previously hard-coded filter. NOTE(review): confirm which labels these ids are in
            the jspark2000/yolov5-vehicle model.

    Raises:
        FileNotFoundError: from ``load_image`` when the input image cannot be read.
        OSError: when OpenCV fails to write ``output_path``.
    """
    from pathlib import Path

    wanted_ids = {int(class_id) for class_id in vehicle_class_ids}

    image_rgb = load_image(image_path)
    boxes = detect_objects(image_rgb)
    vehicle_boxes = [box for box in boxes if int(box[5]) in wanted_ids]
    masks = segment_objects(image_rgb, vehicle_boxes)
    image_with_masks, heights = draw_boxes_and_masks(image_rgb, vehicle_boxes, masks)

    output_image = cv2.cvtColor(image_with_masks, cv2.COLOR_RGB2BGR)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite signals failure (bad path, unsupported extension) by returning False.
    if not cv2.imwrite(str(output_path), output_image):
        raise OSError(f"Failed to write output image to {output_path}")

    print(heights)

In [None]:
# Run the full detect → segment → measure pipeline on the bundled sample image.
main(image_path="../images/sample.jpg")