In [None]:
import set_path
import torch
import pandas as pd
import cv2
import numpy as np
import matplotlib.pyplot as plt
import yaml
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, predict
from typing import List, Tuple

In [None]:
# Select the training run whose weights and saved arguments are used below.
model_num = 8
model_dir = f"runs/detect/train{model_num}"
model_path = f"{model_dir}/weights/best.pt"
print(model_path)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",
    model_path=model_path,
    confidence_threshold=0.3,
    device=device,
)

# Read the run's args.yaml; config["imgsz"] is used later to size the slices.
# NOTE(review): FullLoader accepts more than plain YAML data types; if this
# file could ever come from an untrusted source, consider yaml.safe_load.
with open(f"{model_dir}/args.yaml", "r") as args_file:
    config = yaml.load(args_file, Loader=yaml.FullLoader)

In [None]:
# Image to inspect and the CSV file holding the ground-truth boxes.
image_name = "test_33.jpg"
image_path = f"data/filtered_images/{image_name}"
annotations_path = "data/annotations/annotations_test.csv"

# NOTE(review): read_csv treats the first row of the file as a header before
# the names below replace it; if the CSV has no header row, that first
# annotation is silently dropped — confirm against the file.
annotation_columns = ["image_name", "x1", "y1", "x2", "y2", "class", "image_width", "image_height"]
annotations = pd.read_csv(annotations_path).set_axis(annotation_columns, axis="columns")

In [None]:
# Keep only the ground-truth rows that belong to the selected image.
is_selected_image = annotations["image_name"] == image_name
annotations_for_image = annotations[is_selected_image]
annotations_for_image

In [None]:
def load_image(img_path:str) -> np.ndarray:
    """Read an image from disk and return it in RGB channel order.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The array produced by ``cv2.imread`` after a BGR-to-RGB conversion.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None``, which it does
            instead of raising when the file is missing or cannot be decoded.
    """
    img = cv2.imread(img_path)
    # Fail here with a clear message rather than inside cvtColor with an
    # opaque OpenCV assertion error.
    if img is None:
        raise FileNotFoundError(
            f"Could not read image {img_path!r} (missing file or unsupported format)"
        )
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def draw_image_with_annotations(image:np.ndarray, annotations:List[Tuple[int, int, int, int, str]]) -> None:
    """Show an image with labelled bounding boxes drawn on top of it.

    The boxes are drawn on a copy, so the caller's ``image`` array is left
    untouched and can be reused for another overlay.

    Args:
        image: Image array to display; it is passed to ``plt.imshow`` as is.
        annotations: Iterable of ``(x1, y1, x2, y2, class_name)`` entries,
            where ``(x1, y1)`` and ``(x2, y2)`` are opposite box corners.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    # cv2 drawing calls modify their target in place, so work on a copy to
    # keep boxes from one call out of the next call's picture.
    canvas = image.copy()
    for x1, y1, x2, y2, class_name in annotations:
        # OpenCV requires integer pixel coordinates and a string label.
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        cv2.rectangle(canvas, top_left, bottom_right, (255, 0, 0), 2)
        cv2.putText(canvas, str(class_name), top_left, cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    plt.figure(figsize=(10, 10))
    plt.imshow(canvas)
    plt.axis("off")
    plt.show()

Inference on a single image

In [None]:
# Load the selected image as an RGB array for plotting.
image = load_image(img_path=image_path)

In [None]:

# Run SAHI sliced prediction on the image file. Slices are square: both the
# width and the height use the `imgsz` value read from the run's args.yaml,
# and neighbouring slices overlap by 40% in each direction.
slice_size = config["imgsz"]
result = get_sliced_prediction(
    image=image_path,
    detection_model=detection_model,
    slice_width=slice_size,
    slice_height=slice_size,
    overlap_height_ratio=0.4,
    overlap_width_ratio=0.4,
)

def extract_annotation(annotation_dict):
    bbox = annotation_dict["bbox"]
    x1, y1, w, h = bbox
    predicted_class = annotation_dict["category_name"]
    x2 = x1 + w
    y2 = y1 + h
    return [int(x1), int(y1), int(x2), int(y2), predicted_class]

# Convert every prediction dict from the result into [x1, y1, x2, y2, class_name].
predictions = list(map(extract_annotation, result.to_coco_predictions()))
predictions

In [None]:
# Ground-truth boxes for the selected image as [x1, y1, x2, y2, class] lists,
# the same layout the drawing helper expects.
gt_columns = ["x1", "y1", "x2", "y2", "class"]
gt_annotations = annotations_for_image[gt_columns].values.tolist()
gt_annotations

In [None]:
# Plot the ground-truth boxes. A copy is passed because OpenCV drawing calls
# write into the array they are given; this keeps `image` clean for later cells.
draw_image_with_annotations(image=image.copy(), annotations=gt_annotations)

In [None]:
# Plot the predicted boxes on a freshly loaded image so the figure cannot
# inherit boxes that an earlier cell may have drawn onto `image`.
draw_image_with_annotations(image=load_image(image_path), annotations=predictions)

In [None]:
# Number of ground-truth boxes followed by the number of predicted boxes.
n_ground_truth, n_predicted = len(gt_annotations), len(predictions)
print(n_ground_truth, n_predicted)