In [None]:
import os

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import torch

from RadarCameraYOLO import RadarCameraYoloDataset, RadarCameraYOLO

# Number of object classes the detector predicts.
# NOTE(review): presumably this has to match the class count of the checkpoint
# that is loaded into `model` afterwards — confirm before changing it.
NUM_CLASSES = 7

model = RadarCameraYOLO(num_classes=NUM_CLASSES)

In [None]:
# Map the checkpoint onto the CPU so the notebook also runs on machines that do
# not have the device the weights were saved from (the model is built on CPU).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load("./trained_model.pth", map_location="cpu"))
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

In [None]:
from WaterScenes.radar_map_generate import RESOLUTION

# ✅ Load the dataset
# The dataset location can be overridden with the WATERSCENES_DATA_ROOT
# environment variable; when it is unset the default path below is used.
data_root = os.environ.get(
    "WATERSCENES_DATA_ROOT",
    "/workspaces/Radar-Camera-Fusion-Detection/WaterScenes/sample_dataset",
)  # data path
input_shape = (RESOLUTION, RESOLUTION)  # square input, RESOLUTION on each side

dataset = RadarCameraYoloDataset(data_root=data_root, input_shape=input_shape)

In [None]:
from torchvision import ops
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2

# YOLO style color
def color_list():
    """Return the fixed drawing palette as a list of 3-tuples of ints (0-255).

    Each entry is decoded from a six-digit hex colour code, two digits per
    component, in the order the code is written.
    """
    palette_codes = (
        '7FFF00',  # yellow green
        '808080',  # grey
        'FF4C4C',  # red
        '32CD32',  # green
        '1E90FF',  # blue
        '9370DB',  # purple
        'FF69B4',  # pink
    )
    # bytes.fromhex turns 'RRGGBB' into three bytes; iterating bytes yields ints.
    return [tuple(bytes.fromhex(code)) for code in palette_codes]


# Palette shared by the drawing helpers; indexed by class id modulo its length.
COLORS = color_list()

# YOLO style bbox drawing
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one bounding box, optionally with a text tag, on a copy of ``img``.

    Args:
        x: Box corners ``(x1, y1, x2, y2)`` in pixels; each value is truncated
            with ``int``.
        img: Image as a NumPy array. It is copied, never modified in place.
        color: 3-component colour for the box. When omitted (or empty), an entry
            of ``COLORS`` is used: ``label`` selects the entry if it parses as an
            integer class index, otherwise the first entry is used.
        label: Optional text drawn in a filled tag at the top-left box corner.
        line_thickness: Box line thickness in pixels. When omitted, a thickness
            proportional to the shorter image side (at least 1) is used.

    Returns:
        A new array with the box (and tag) drawn on it.

    Note:
        For 3-channel images the first and last channels are swapped before
        drawing and swapped back afterwards, so ``color`` is effectively applied
        in the reverse of the input's channel order (an RGB colour lands
        correctly on a BGR image and vice versa).
    """
    # copy
    img_copy = img.copy()

    # thickness: an explicit value wins, otherwise scale with the image size
    if line_thickness:
        tl = max(1, int(line_thickness))
    else:
        tl = max(1, int(min(img_copy.shape[:2]) / 400))

    # set color; text labels such as "GT: 3" cannot be parsed as a class index,
    # so fall back to the first palette entry instead of raising
    if color is None or len(color) == 0:
        try:
            class_index = int(label)
        except (TypeError, ValueError):
            class_index = 0
        color = COLORS[class_index % len(COLORS)]
    color = tuple(map(int, color))

    # RGB → BGR (only for 3-channel images; undone symmetrically before returning)
    swap_channels = img_copy.ndim == 3 and img_copy.shape[-1] == 3
    if swap_channels:
        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)

    # coord
    x1, y1, x2, y2 = map(int, x)
    # draw bbox
    cv2.rectangle(img_copy, (x1, y1), (x2, y2), color, thickness=tl, lineType=cv2.LINE_AA)

    text = "" if label is None else str(label)
    if text:
        # font size
        tf = max(tl - 1, 1)  # font thickness
        font_scale = tl / 5
        text_w, text_h = cv2.getTextSize(text, 0, fontScale=font_scale, thickness=tf)[0]

        # Put the tag above the box when it fits inside the image; otherwise
        # draw it just inside the top edge so it is not clipped away.
        if y1 - text_h - 3 >= 0:
            tag_corner = (x1 + text_w, y1 - text_h - 3)
            text_origin = (x1, y1 - 2)
        else:
            tag_corner = (x1 + text_w, y1 + text_h + 3)
            text_origin = (x1, y1 + text_h + 2)

        cv2.rectangle(img_copy, (x1, y1), tag_corner, color, -1, cv2.LINE_AA)  # Filled background

        # add text
        cv2.putText(img_copy, text, text_origin, 0, font_scale, (255, 255, 255), thickness=tf, lineType=cv2.LINE_AA)

    if swap_channels:
        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_BGR2RGB)
    return img_copy


# xywh → xyxy
def xywh2xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x1, y1, x2, y2).

    Args:
        x: ``torch.Tensor`` or array-like whose last dimension holds at least the
            four values ``cx, cy, w, h``. Any number of leading dimensions is
            accepted (a single box, an ``(N, 4)`` batch, ...).

    Returns:
        A new tensor/array of the same type, shape and dtype as the input with the
        first four entries of the last dimension replaced by the corner
        coordinates (top-left, bottom-right). The input is not modified.
    """
    if not isinstance(x, torch.Tensor):
        x = np.asarray(x)  # accept plain lists/tuples as well as ndarrays
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)

    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y


In [None]:
# Visualization
# Per-channel mean / std used to undo the input normalisation before display.
# NOTE(review): these look like the usual ImageNet statistics — confirm they
# match the normalisation applied by the dataset.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def visualize_predictions(model, dataset, num_samples=5, confidence_threshold=0.5, iou_threshold=0.4, verbose=True):
    """Plot ground-truth and predicted boxes side by side for the first samples.

    For each sample the image is de-normalised, the model is run on the
    (image, radar) pair, predicted boxes go through NMS and a confidence filter,
    and a two-panel matplotlib figure ("Ground Truth" / "Prediction") is shown.

    Args:
        model: Callable as ``model(image_batch, radar_batch)`` returning
            ``(class_output, bbox_output, _)``; it is put into eval mode.
        dataset: Indexable object yielding ``(image, radar, labels)``.
            Assumes ``labels`` is a CPU tensor with rows
            ``(class_id, cx, cy, w, h)`` in normalised coordinates and that
            ``bbox_output`` uses the same normalised ``cx, cy, w, h`` layout —
            TODO confirm against the dataset/model definitions.
        num_samples: Number of leading samples to show (capped at the dataset
            length when the dataset defines ``__len__``).
        confidence_threshold: Predictions at or below this score are not drawn.
        iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.
        verbose: When true, print the raw predictions and the boxes kept by NMS.

    Returns:
        None. Figures are displayed with ``plt.show()``.
    """
    model.eval()

    # Never index past the end of the dataset.
    if hasattr(dataset, "__len__"):
        num_samples = min(num_samples, len(dataset))

    for i in range(num_samples):
        # ========= Tensor → numpy =========
        image, radar, labels = dataset[i]
        if isinstance(image, torch.Tensor):
            image_np = image.permute(1, 2, 0).cpu().numpy()  # (C, H, W) → (H, W, C)
            image_np = (image_np * np.array(STD, dtype=np.float32)) + np.array(MEAN, dtype=np.float32)
            # Clip before the cast: de-normalised values slightly outside [0, 1]
            # would otherwise wrap around in uint8.
            image_np = (np.clip(image_np, 0.0, 1.0) * 255).astype(np.uint8)  # float → uint8
        else:
            image_np = image.copy()

        # plot_one_box swaps the first/last channel around its drawing calls, so
        # the image is handed over in BGR order and converted back before imshow;
        # that way the RGB palette in COLORS is displayed with the intended colours.
        if image_np.shape[-1] == 3:
            image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        # ========= Prediction =========
        with torch.no_grad():
            class_output, bbox_output, _ = model(image.unsqueeze(0), radar.unsqueeze(0))

        # NOTE(review): assumes dim 1 of class_output is the class axis — confirm.
        class_prob = torch.softmax(class_output, dim=1)
        confidence, pred_classes = torch.max(class_prob, dim=1)

        # Bring everything to NumPy; boxes are converted to corner format for NMS.
        if isinstance(bbox_output, torch.Tensor):
            bbox_output = bbox_output.squeeze(0).cpu().numpy()
        pred_boxes = xywh2xyxy(np.asarray(bbox_output).reshape(-1, 4))
        conf_scores = confidence.squeeze(0).cpu().numpy().flatten()
        pred_classes = pred_classes.squeeze(0).cpu().numpy().flatten()

        # ✅ Debug Info
        if verbose:
            print(f"pred_boxes[:5]: {pred_boxes[:5]}")
            print(f"conf_scores[:5]: {conf_scores[:5]}")
            print(f"pred_classes[:5]: {pred_classes[:5]}")

        # ========= NMS =========
        if len(pred_boxes) > 0:
            pred_boxes_tensor = torch.as_tensor(pred_boxes, dtype=torch.float32)
            conf_scores_tensor = torch.as_tensor(conf_scores, dtype=torch.float32)
            # Convert the kept indices to NumPy so that all three arrays are indexed
            # the same way (indexing an ndarray with a torch index tensor is fragile).
            keep = ops.nms(pred_boxes_tensor, conf_scores_tensor, iou_threshold).cpu().numpy()

            pred_boxes = pred_boxes_tensor.numpy()[keep]
            conf_scores = conf_scores_tensor.numpy()[keep]
            pred_classes = pred_classes[keep]
            if verbose:
                print(f"pred_boxes: {pred_boxes}")

        # ========= Ground Truth =========
        gt_img = image_np.copy()
        pred_img = image_np.copy()

        # Factors that map normalised (x1, y1, x2, y2) to pixel coordinates.
        img_h, img_w = image_np.shape[:2]
        to_pixels = np.array([img_w, img_h, img_w, img_h])

        if labels.size(0) > 0:
            gt_boxes = xywh2xyxy(labels[:, 1:5])
            for j, box in enumerate(gt_boxes):
                class_id = int(labels[j][0])
                label = f"GT: {class_id}"
                gt_box = box.numpy() * to_pixels
                gt_img = plot_one_box(gt_box, gt_img, color=COLORS[class_id % len(COLORS)], label=label)

        # ========= Predictions above the confidence threshold =========
        for box, score, cls in zip(pred_boxes, conf_scores, pred_classes):
            if score > confidence_threshold:
                pred_box = box * to_pixels  # adjust size
                class_id = int(cls)
                label = f"Pred: {class_id} ({score:.2f})"
                pred_img = plot_one_box(pred_box, pred_img, color=COLORS[class_id % len(COLORS)], label=label)

        # ========= Visualization =========
        fig, axs = plt.subplots(1, 2, figsize=(6, 3))
        panels = ((axs[0], gt_img, "Ground Truth"), (axs[1], pred_img, "Prediction"))
        for ax, panel_img, title in panels:
            ax.imshow(cv2.cvtColor(panel_img, cv2.COLOR_BGR2RGB))
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure; several are created in this loop

In [None]:
# ✅ Example run: show the first five samples with stricter-than-default thresholds
visualize_predictions(
    model,
    dataset,
    num_samples=5,
    confidence_threshold=0.7,
    iou_threshold=0.8,
)