In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from dataset import RadarCameraYoloDataset
from models import CameraYOLO
from utils import bbox_iou

In [None]:
# Load the trained model. The architecture must match the one used for
# training (same num_classes), otherwise load_state_dict rejects the weights.
MODEL_PATH = "trained_model.pth"
NUM_CLASSES = 7

model = CameraYOLO(num_classes=NUM_CLASSES)
# map_location="cpu": the rest of this notebook runs inference on CPU
# (`.numpy()` is called on the predictions), so a checkpoint saved from a
# GPU run must be remapped or torch.load fails on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # inference mode (affects layers such as Dropout / BatchNorm)

In [None]:
# Load a single camera + radar sample and its YOLO labels.
DATA_ROOT = "../../WaterScenes/sample_dataset/"
SAMPLE_IDX = 5

dataset = RadarCameraYoloDataset(data_root=DATA_ROOT)
image, radar, labels = dataset[SAMPLE_IDX]
print(image.shape)  # (C, H, W) — a single sample, no batch dimension yet
print(radar.shape)

# Add the batch dimension for the model: (1, C, H, W).
image = image.unsqueeze(0)
H, W = image.shape[2], image.shape[3]

In [None]:
# Run the model on the single image.
with torch.no_grad():
    predictions = model(image)  # (1, N, num_classes + 5) per the slicing below

# Split the raw output into boxes, objectness and class scores.
predictions = predictions.squeeze(0).numpy()  # (N, num_classes + 5)
# .copy(): a plain slice is a *view* of `predictions`; later cells scale
# `pred_boxes` to pixels, which must not silently rewrite the model output.
pred_boxes = predictions[:, :4].copy()  # (N, 4) -> x_center, y_center, width, height
obj_confidences = predictions[:, 4]     # (N,) objectness score
class_probs = predictions[:, 5:]        # (N, num_classes) class scores

# Predicted class = highest-scoring class for each box.
pred_class_ids = np.argmax(class_probs, axis=1)  # (N,)

# Optional objectness filter. None keeps every box; set e.g. 0.02 to drop
# low-confidence boxes. All per-box arrays are filtered together so their
# indices stay aligned.
# NOTE(review): whether column 4 is a probability or a raw logit depends on
# CameraYOLO's head — check before choosing a threshold.
CONFIDENCE_THRESHOLD = None
if CONFIDENCE_THRESHOLD is not None:
    keep = obj_confidences > CONFIDENCE_THRESHOLD
    pred_boxes = pred_boxes[keep]
    obj_confidences = obj_confidences[keep]
    class_probs = class_probs[keep]
    pred_class_ids = pred_class_ids[keep]
    print(pred_boxes.shape, pred_class_ids.shape)

In [None]:
# Convert boxes from normalized YOLO coordinates to pixels.
# Columns are (x_center, y_center, width, height): x/width scale with W,
# y/height scale with H.
xywh_scale = np.array([W, H, W, H], dtype=np.float32)

# Ground truth: each label row is (class_id, x_center, y_center, width, height).
# reshape(-1, 5) keeps the column slice valid when the sample has no labels.
gt_boxes = np.array(labels).reshape(-1, 5)[:, 1:5] * xywh_scale  # (M, 4)
print(gt_boxes[:1])

# Predictions (assumed normalized, like the labels): rebind to a scaled copy
# instead of using `*=`, so the model output that `pred_boxes` may be a view
# of is never modified in place.
# NOTE: re-running this cell alone rescales `pred_boxes` a second time —
# re-run the prediction cell first.
pred_boxes = pred_boxes * xywh_scale  # (N, 4)
print(pred_boxes[:1])


In [None]:
# Match each GT box to the prediction that overlaps it most.
# NOTE(review): assumes utils.bbox_iou returns a pairwise (N, M) IoU matrix
# for (x_center, y_center, width, height) inputs — many YOLO-style helpers
# default to (x1, y1, x2, y2) or compute element-wise IoU; confirm in utils.
iou_matrix = np.asarray(bbox_iou(pred_boxes, gt_boxes))  # (N, M)

# ndarray.max(axis=0) returns only the maximum values (not values + indices
# like torch.max(dim=...)), so take the argmax explicitly and gather the IoUs.
best_pred_idx = iou_matrix.argmax(axis=0)  # (M,) best prediction per GT box
best_iou = iou_matrix[best_pred_idx, np.arange(iou_matrix.shape[1])]  # (M,)
print(best_iou)

matched_preds = pred_boxes[best_pred_idx]
matched_classes = pred_class_ids[best_pred_idx]

In [21]:
# Visualize ground-truth boxes (red) against their IoU-matched predictions (blue).


def _draw_xywh_box(ax, box_xywh, text, color):
    """Draw one pixel-space (x_center, y_center, width, height) box with a tag above it."""
    x_center, y_center, width, height = box_xywh
    x_min = x_center - width / 2
    y_min = y_center - height / 2
    ax.add_patch(
        patches.Rectangle(
            (x_min, y_min), width, height, linewidth=2, edgecolor=color, facecolor="none"
        )
    )
    ax.text(x_min, y_min - 5, text, color=color, fontsize=10, weight="bold")


# Use a new name instead of overwriting `image`: otherwise re-running this cell
# fails because `image` would already be a NumPy array.
# NOTE(review): if the dataset normalizes pixels outside [0, 1], imshow clips them.
image_hwc = np.transpose(image.squeeze(0).numpy(), (1, 2, 0))  # (C, H, W) -> (H, W, C)
fig, ax = plt.subplots(1, figsize=(8, 4))
ax.imshow(image_hwc)

# 1) Ground-truth boxes: label rows are normalized (class_id, xc, yc, w, h).
for lbl in labels:
    class_id, x_center, y_center, width, height = lbl.numpy()
    _draw_xywh_box(
        ax,
        (x_center * W, y_center * H, width * W, height * H),
        f"GT {int(class_id)}",
        "red",
    )

# 2) Best-IoU prediction for each GT box; these are already in pixels.
for box, class_id in zip(matched_preds, matched_classes):
    _draw_xywh_box(ax, box, f"Pred {int(class_id)}", "blue")

ax.set_title("GT vs IoU-Matched Predictions")
plt.show()