## Implementing YOLOv11

In [None]:
import torch
import cv2
import os
from ultralytics import YOLO
import matplotlib.pyplot as plt

In [None]:
# Build the Ultralytics YOLO model from the "yolo11n.pt" weights file.
# NOTE(review): presumably the pretrained YOLO11-nano checkpoint, fetched by
# Ultralytics if it is not present locally — confirm for offline runs.
model = YOLO("yolo11n.pt")

# Pick the compute backend in order of preference: CUDA, then MPS, then CPU.
# `torch.backends.mps` only exists in newer PyTorch releases, so it is looked
# up defensively; on older installs the attribute is missing and accessing it
# directly would raise AttributeError instead of falling back to the CPU.
mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = "cuda"
elif mps_backend is not None and mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

In [None]:
# Training configuration, kept in one place so it is easy to inspect or tweak:
# dataset description file, number of epochs, input image size and the device
# selected earlier in the notebook.
train_settings = {
    "data": "cavity_data/data.yaml",
    "epochs": 150,
    "imgsz": 640,
    "device": device,
}

# Train the loaded model with those settings and keep the returned results.
train_results = model.train(**train_settings)

In [None]:
# Run validation on the "test" split and keep the returned metrics object.
# NOTE(review): assumes the dataset YAML used for training defines a test
# split — confirm against cavity_data/data.yaml.
metrics = model.val(split="test")

In [None]:
# Collect the headline box-detection metrics from the validation run.
# `mp` / `mr` are the library's mean precision / mean recall over classes;
# they equal `p.mean()` / `r.mean()` when per-class values exist, but fall back
# to 0.0 when nothing was matched, whereas calling `.mean()` on the still-empty
# per-class containers would raise.
results = {
    "mAP50": float(metrics.box.map50),
    "Precision": float(metrics.box.mp),
    "Recall": float(metrics.box.mr),
}

# Print each metric right-aligned in a 10-character column, 4 decimal places.
print("\nFINAL RESULTS: \n")
for k, v in results.items():
    print(f"{k:>10}: {v:.4f}")

In [None]:
# Load one test image for a side-by-side ground-truth / prediction comparison.
img_path = "cavity_data/images/test/test_000000.jpg"

# cv2.imread does not raise on a missing or unreadable file; it returns None,
# which would otherwise surface as an opaque error inside cvtColor.
img_bgr = cv2.imread(img_path)
if img_bgr is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")

# OpenCV loads images as BGR; convert to RGB so matplotlib shows true colours.
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

# The label file mirrors the image path: "images" -> "labels", ".jpg" -> ".txt".
label_path = img_path.replace("images", "labels").replace(".jpg", ".txt")
h, w, _ = img.shape
# Separate copy so ground-truth boxes are not drawn onto the source image.
img_gt = img.copy()

# Draw the ground-truth boxes (green) if a label file exists for this image.
# Each non-empty line holds five numbers: class id, then box centre x/y and
# box width/height, all scaled to pixels below by the image width and height.
if os.path.exists(label_path):
    with open(label_path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank or trailing empty lines would otherwise break the
                # five-value unpacking with a ValueError.
                continue
            cls, xc, yc, bw, bh = map(float, fields)
            # Convert centre/size form to top-left and bottom-right corners.
            x1 = int((xc - bw/2) * w)
            y1 = int((yc - bh/2) * h)
            x2 = int((xc + bw/2) * w)
            y2 = int((yc + bh/2) * h)
            cv2.rectangle(img_gt, (x1, y1), (x2, y2), (0,255,0), 2)
            # Class id is written just above the box's top-left corner.
            cv2.putText(img_gt, str(int(cls)), (x1, y1-5), 
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)

# Run the model on the same image and move the first result's box corner
# coordinates and class ids to NumPy arrays on the CPU.
results = model(img_path)
detections = results[0].boxes
boxes = detections.xyxy.cpu().numpy()
classes = detections.cls.cpu().numpy()

# Overlay every predicted box and its class id in red on a fresh copy of the
# RGB image, leaving the source image untouched.
img_pred = img.copy()
pred_color = (255, 0, 0)
for box, cls in zip(boxes, classes):
    x1, y1, x2, y2 = (int(coord) for coord in box)
    cv2.rectangle(img_pred, (x1, y1), (x2, y2), pred_color, 2)
    cv2.putText(
        img_pred,
        str(int(cls)),
        (x1, y1 - 5),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        pred_color,
        2,
    )

# Show ground truth (left) and predictions (right) side by side, axes hidden.
plt.figure(figsize=(12, 6))

panels = (("Ground Truth", img_gt), ("Predictions", img_pred))
for position, (heading, image) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(heading)
    plt.imshow(image)
    plt.axis("off")

plt.show()
