## Check GPU Access

In [None]:
import locale

# Force the reported preferred encoding to UTF-8 (presumably a workaround for
# notebook runtimes whose locale is not UTF-8, so that `!` shell commands work
# -- confirm against the target runtime).
# The replacement mirrors the stdlib signature `getpreferredencoding(
# do_setlocale=True)`, so it works whether the caller passes the argument or
# not; a lambda with one *required* parameter raises TypeError on a
# no-argument call.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

!nvidia-smi

## Install Libraries

In [None]:
!pip install torch torchvision opencv-python matplotlib ultralytics

In [None]:
import torch
import torchvision
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from ultralytics import YOLO
from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FasterRCNN_ResNet50_FPN_Weights


def load_image(image_path):
    """Read an image file and return it as a batched tensor.

    The file is opened with PIL, converted to RGB, turned into a tensor with
    torchvision's ``ToTensor`` transform, and given a leading batch dimension
    via ``unsqueeze(0)``.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The image tensor with an extra dimension of size 1 at position 0.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    to_tensor = transforms.ToTensor()
    image_tensor = to_tensor(rgb_image)
    return image_tensor.unsqueeze(0)


def load_yolov9_model():
    """Build and return an Ultralytics ``YOLO`` model from ``yolov9c.pt``."""
    return YOLO('yolov9c.pt')


def load_fasterrcnn_model():
    """Return torchvision's Faster R-CNN (ResNet-50 FPN) in evaluation mode.

    The model is created with the ``DEFAULT`` pretrained weights.
    """
    pretrained_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=pretrained_weights
    )
    # ``eval()`` switches the module to inference mode and returns the module.
    return detector.eval()


def load_retinanet_model():
    """Return torchvision's RetinaNet (ResNet-50 FPN) in evaluation mode.

    The model is created with the ``DEFAULT`` pretrained weights.
    """
    pretrained_weights = RetinaNet_ResNet50_FPN_Weights.DEFAULT
    detector = torchvision.models.detection.retinanet_resnet50_fpn(
        weights=pretrained_weights
    )
    # ``eval()`` switches the module to inference mode and returns the module.
    return detector.eval()


def run_yolov9(model, image):
    """Run ``model.predict`` on ``image`` and return the first result's boxes.

    Args:
        model: Object exposing ``predict(image)`` that returns a sequence of
            results, each with a ``boxes`` attribute.
        image: Input forwarded unchanged to ``model.predict``.

    Returns:
        ``boxes.xyxy`` of the first result, or ``None`` when that result's
        ``boxes`` attribute is ``None``.
    """
    first_result = model.predict(image)[0]
    detections = first_result.boxes
    return None if detections is None else detections.xyxy


def _refine_boxes_with_detector(model, image, boxes):
    """Run a detector on each box-shaped crop of ``image`` and merge the results.

    Args:
        model: Callable taking a list of 3-D ``(C, H, W)`` tensors and
            returning one dict per image with ``"boxes"``, ``"labels"`` and
            ``"scores"`` tensors (the torchvision detection-model interface).
        image: Image tensor, either ``(1, C, H, W)`` or ``(C, H, W)``.
        boxes: Iterable of ``(x1, y1, x2, y2)`` tensors in pixel coordinates
            of ``image``, or ``None``.

    Returns:
        Tuple ``(boxes, labels, scores)`` concatenated over all crops. Boxes
        are expressed in the coordinate frame of the full ``image``. When
        there is nothing to process, empty tensors of shape ``(0, 4)``,
        ``(0,)`` and ``(0,)`` are returned instead of raising.
    """
    empty_result = (
        torch.empty((0, 4)),
        torch.empty((0,), dtype=torch.int64),
        torch.empty((0,)),
    )
    if boxes is None or len(boxes) == 0:
        return empty_result

    image = image.squeeze(0)
    height, width = image.shape[-2:]

    refined_boxes = []
    refined_labels = []
    refined_scores = []

    # Inference only: avoid building autograd graphs for every crop.
    with torch.no_grad():
        for box in boxes:
            x1, y1, x2, y2 = map(int, box.tolist())

            # Keep the crop inside the image and skip degenerate regions,
            # which would otherwise produce an empty tensor for the detector.
            x1, x2 = max(x1, 0), min(x2, width)
            y1, y2 = max(y1, 0), min(y2, height)
            if x2 <= x1 or y2 <= y1:
                continue

            # Detection models take a list of 3-D (C, H, W) tensors, so the
            # crop must NOT get an extra batch dimension.
            crop = image[:, y1:y2, x1:x2]
            detections = model([crop])[0]

            # Detections are relative to the crop; shift them back so they
            # line up with the full image.
            crop_boxes = detections["boxes"]
            offset = crop_boxes.new_tensor([x1, y1, x1, y1])
            refined_boxes.append(crop_boxes + offset)
            refined_labels.append(detections["labels"])
            refined_scores.append(detections["scores"])

    if not refined_boxes:
        return empty_result

    return torch.cat(refined_boxes), torch.cat(refined_labels), torch.cat(refined_scores)


def run_fasterrcnn(model, image, boxes):
    """Refine ``boxes`` by running the Faster R-CNN ``model`` on each crop.

    Args:
        model: Detection model following the torchvision interface.
        image: Image tensor, ``(1, C, H, W)`` or ``(C, H, W)``.
        boxes: ``(x1, y1, x2, y2)`` regions of ``image`` to re-detect in;
            ``None`` or empty yields empty results.

    Returns:
        Tuple ``(boxes, labels, scores)`` with boxes in full-image coordinates.
    """
    return _refine_boxes_with_detector(model, image, boxes)


def run_retinanet(model, image, boxes):
    """Refine ``boxes`` by running the RetinaNet ``model`` on each crop.

    Args:
        model: Detection model following the torchvision interface.
        image: Image tensor, ``(1, C, H, W)`` or ``(C, H, W)``.
        boxes: ``(x1, y1, x2, y2)`` regions of ``image`` to re-detect in;
            ``None`` or empty yields empty results.

    Returns:
        Tuple ``(boxes, labels, scores)`` with boxes in full-image coordinates.
    """
    return _refine_boxes_with_detector(model, image, boxes)


def plot_image_with_boxes(image, boxes, title="Detected Objects"):
    """Draw bounding boxes on an image and display it with matplotlib.

    Args:
        image: Tensor laid out as ``(C, H, W)``. Floating-point input is
            assumed to hold RGB values in ``[0, 1]`` (as ``ToTensor``
            produces) -- confirm for other callers.
        boxes: Iterable of ``(x1, y1, x2, y2)`` pixel boxes, or ``None`` to
            show the image without boxes.
        title: Title shown above the figure.
    """
    pixels = image.detach().permute(1, 2, 0).cpu().numpy()

    # Draw on a uint8 canvas so the 0-255 box colour is in range for the image
    # and matplotlib does not have to clip out-of-range float values.
    if np.issubdtype(pixels.dtype, np.floating):
        pixels = np.rint(pixels * 255).clip(0, 255)
    # OpenCV drawing needs a contiguous array; the permuted view is not.
    canvas = np.ascontiguousarray(pixels.astype(np.uint8))

    # Pure green is (0, 255, 0) in both RGB and BGR, so boxes can be drawn
    # directly on the RGB array without a colour-space round trip.
    for box in (boxes if boxes is not None else ()):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)

    plt.figure(figsize=(10, 10))
    plt.title(title)
    plt.imshow(canvas)
    plt.show()


def ensemble_detection(image_path):
    """Run the YOLOv9 -> Faster R-CNN -> RetinaNet cascade on one image.

    Each stage's boxes are printed and plotted. The boxes from one stage are
    passed to the next stage as the regions to process. If a stage yields no
    boxes, the remaining stages are skipped instead of failing on empty input.

    Args:
        image_path: Path of the image file to analyse.
    """
    image = load_image(image_path)
    # Unbatched view used for every plot.
    display_image = image.squeeze(0)

    # YOLO for initial detection
    yolo_model = load_yolov9_model()
    yolo_boxes = run_yolov9(yolo_model, image)
    print("YOLOv9 detected boxes:", yolo_boxes)
    if yolo_boxes is None or len(yolo_boxes) == 0:
        print("YOLOv9 found no objects; skipping refinement stages.")
        return
    plot_image_with_boxes(display_image, yolo_boxes, "YOLOv9 Detections")

    # Faster R-CNN for bounding box refinement
    fasterrcnn_model = load_fasterrcnn_model()
    faster_boxes, faster_labels, faster_scores = run_fasterrcnn(fasterrcnn_model, image, yolo_boxes)
    print("Faster R-CNN refined boxes:", faster_boxes)
    if len(faster_boxes) == 0:
        print("Faster R-CNN found no objects; skipping RetinaNet stage.")
        return
    plot_image_with_boxes(display_image, faster_boxes, "Faster R-CNN Detections")

    # RetinaNet for final refinement and detection of smaller objects
    retinanet_model = load_retinanet_model()
    retinanet_boxes, retinanet_labels, retinanet_scores = run_retinanet(retinanet_model, image, faster_boxes)
    print("RetinaNet refined boxes:", retinanet_boxes)
    plot_image_with_boxes(display_image, retinanet_boxes, "RetinaNet Detections")


# Notebook driver: run the full detection cascade on a single hard-coded
# sample image. The path is environment-specific; change it when running
# elsewhere.
image_path = "/kaggle/input/vehicle-detection-8-classes-object-detection/train/images/Highway_1007_2020-07-30_jpg.rf.b97b7d182ed136840b68dc1680a76610.jpg"
ensemble_detection(image_path)