# In [None]:

# %%
# Import necessary libraries
import os
import glob
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, Input, Conv2D, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from datetime import datetime
import json
import xml.etree.ElementTree as ET

# --- Configuration & Constants ---
MODEL_NAME = 'EfficientNetB0_ObjectDetection'

# Dataset layout: metadata CSVs live in the dataset root, JPEGs under jpg_img/.
BASE_DATASET_PATH = './k_CBIS-DDSM/'
CALC_METADATA_CSV_PATH = os.path.join(BASE_DATASET_PATH, 'calc_case(with_jpg_img).csv')
MASS_METADATA_CSV_PATH = os.path.join(BASE_DATASET_PATH, 'mass_case(with_jpg_img).csv')
IMAGE_ROOT_DIR = BASE_DATASET_PATH
ACTUAL_IMAGE_FILES_BASE_DIR = os.path.join(IMAGE_ROOT_DIR, 'jpg_img')

# Optional hand-made annotations (create or adapt to your annotation format).
ANNOTATIONS_DIR = os.path.join(BASE_DATASET_PATH, 'annotations')

# Metadata column names used throughout the notebook.
CONCEPTUAL_ROI_COLUMN_NAME = 'jpg_ROI_img_path'
PATHOLOGY_COLUMN_NAME = 'pathology'
CASE_TYPE_COLUMN_NAME = 'case_type'

# Input resolution and optimisation settings.
IMG_WIDTH = 416           # common detector input size
IMG_HEIGHT = 416
BATCH_SIZE = 16           # kept small for detection workloads
EPOCHS = 300
FINE_TUNE_EPOCHS = 100
LEARNING_RATE = 1e-4
RANDOM_STATE = 42

# Detection-head settings.
MAX_OBJECTS = 5           # number of box slots predicted per image
NUM_CLASSES = 2           # BENIGN, MALIGNANT (no explicit background class)
CONFIDENCE_THRESHOLD = 0.5
NMS_THRESHOLD = 0.4       # NOTE(review): not referenced by any code visible in this notebook

# Callback patience: frozen-backbone phase first, then fine-tuning (_FT).
PATIENCE_EARLY_STOPPING = 20
PATIENCE_REDUCE_LR = 8
PATIENCE_EARLY_STOPPING_FT = 15
PATIENCE_REDUCE_LR_FT = 5

# Each run writes into its own timestamped directory.
OUTPUT_DIR = os.path.join('./', f"run_{MODEL_NAME}_{IMG_WIDTH}_{BATCH_SIZE}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"All output will be saved to: {os.path.abspath(OUTPUT_DIR)}")

# Make sure the annotations directory exists as well.
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)

# --- End of Configuration & Constants ---

# %%
# --- Data Loading and Path Finding ---
print("--- Initial Path Configuration Debug ---")
print(f"Current working directory (CWD): {os.getcwd()}")
print(f"BASE_DATASET_PATH: {BASE_DATASET_PATH}")
print(f"ANNOTATIONS_DIR: {ANNOTATIONS_DIR}")

# Load and combine the calc and mass metadata. Either CSV may be absent,
# but at least one must load successfully.
loaded_dfs = []

# Load Calc cases
if os.path.exists(CALC_METADATA_CSV_PATH):
    try:
        calc_df = pd.read_csv(CALC_METADATA_CSV_PATH)
        calc_df[CASE_TYPE_COLUMN_NAME] = 'calc'
        loaded_dfs.append(calc_df)
        print(f"Successfully loaded {len(calc_df)} calc cases")
    except Exception as e:  # best effort: report and continue with the other CSV
        print(f"Error loading calc CSV: {e}")
else:
    # Report the missing file explicitly instead of skipping it silently.
    print(f"Calc CSV not found: {CALC_METADATA_CSV_PATH}")

# Load Mass cases
if os.path.exists(MASS_METADATA_CSV_PATH):
    try:
        mass_df = pd.read_csv(MASS_METADATA_CSV_PATH)
        mass_df[CASE_TYPE_COLUMN_NAME] = 'mass'
        loaded_dfs.append(mass_df)
        print(f"Successfully loaded {len(mass_df)} mass cases")
    except Exception as e:  # best effort: report and continue with the other CSV
        print(f"Error loading mass CSV: {e}")
else:
    print(f"Mass CSV not found: {MASS_METADATA_CSV_PATH}")

if not loaded_dfs:
    raise FileNotFoundError("No CSV files could be loaded")

source_df = pd.concat(loaded_dfs, ignore_index=True)
print(f"Combined DataFrame: {len(source_df)} total rows")

# Clean dataframe: keep rows with an ROI path and a MALIGNANT/BENIGN label.
# NOTE(review): any other pathology values (e.g. BENIGN_WITHOUT_CALLBACK, if
# present in the CSVs) are dropped here.
# .copy() makes source_df an independent frame, so the column assignment in the
# next cell does not operate on a slice (pandas SettingWithCopyWarning).
source_df = source_df.dropna(subset=[CONCEPTUAL_ROI_COLUMN_NAME, PATHOLOGY_COLUMN_NAME])
source_df = source_df[source_df[PATHOLOGY_COLUMN_NAME].isin(['MALIGNANT', 'BENIGN'])].copy()
print(f"After cleaning: {len(source_df)} rows")

# %%
# --- Object Detection Data Preparation Functions ---

def create_synthetic_bounding_boxes(roi_image_path, full_image_path, mask_threshold=127):
    """Derive one normalised bounding box from an ROI mask image.

    The box is the bounding rectangle of the largest external contour in the
    binarised mask. In a real scenario you would use actual bounding-box
    annotations instead.

    Args:
        roi_image_path: path to the ROI mask image (read as grayscale).
        full_image_path: currently unused. NOTE(review): coordinates are
            normalised by the *mask's* size and later applied to the full
            image, which assumes both have the same dimensions. Confirm this.
        mask_threshold: grayscale level above which a mask pixel counts as
            foreground. Assumes the lesion is bright on a dark background.

    Returns:
        A dict with ``x_min``, ``y_min``, ``x_max``, ``y_max``, ``width`` and
        ``height`` in [0, 1], or ``None`` if the mask cannot be read, has no
        foreground, or processing fails.
    """
    try:
        roi_img = cv2.imread(roi_image_path, cv2.IMREAD_GRAYSCALE)
        if roi_img is None:
            return None

        # findContours treats every non-zero pixel as foreground. Lossy JPEG
        # masks can contain faint non-zero noise, so binarise first.
        _, binary_mask = cv2.threshold(roi_img, mask_threshold, 255, cv2.THRESH_BINARY)

        # [-2] selects the contour list under both OpenCV 3.x (3 return
        # values) and OpenCV 4.x (2 return values).
        contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if not contours:
            return None

        # Use the largest contour, assuming it is the main abnormality.
        largest_contour = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(largest_contour)

        # Normalise pixel coordinates to the [0, 1] range.
        img_h, img_w = roi_img.shape[:2]
        x_norm = x / img_w
        y_norm = y / img_h
        w_norm = w / img_w
        h_norm = h / img_h

        return {
            'x_min': x_norm,
            'y_min': y_norm,
            'x_max': x_norm + w_norm,
            'y_max': y_norm + h_norm,
            'width': w_norm,
            'height': h_norm
        }
    except Exception as e:  # best effort per image: log and skip this row
        print(f"Error creating bounding box for {roi_image_path}: {e}")
        return None

def parse_annotation_file(annotation_path):
    """Parse an annotation file into a list of normalised box dicts.

    Supported formats (matched case-insensitively by extension):

    * PASCAL VOC ``.xml``: one dict per ``<object>`` with keys ``class``,
      ``x_min``, ``y_min``, ``x_max``, ``y_max``, ``width`` and ``height``.
      Pixel coordinates are divided by the file's ``<size>`` width/height.
    * COCO ``.json``: the file is loaded, but COCO parsing is not
      implemented yet, so an empty list is returned.

    Adapt this function to your annotation format.

    Args:
        annotation_path: path (str or path-like) to the annotation file.

    Returns:
        A list of dicts, possibly empty. Unreadable or malformed files and
        unsupported extensions print a message and return ``[]`` (never
        ``None``).
    """
    lowered_path = str(annotation_path).lower()
    try:
        if lowered_path.endswith('.xml'):
            root = ET.parse(annotation_path).getroot()

            # The image size is per file, so read it once, not once per object.
            size = root.find('size')
            img_width = float(size.find('width').text)
            img_height = float(size.find('height').text)

            annotations = []
            for obj in root.findall('object'):
                name = obj.find('name').text
                bbox = obj.find('bndbox')
                x_min = float(bbox.find('xmin').text)
                y_min = float(bbox.find('ymin').text)
                x_max = float(bbox.find('xmax').text)
                y_max = float(bbox.find('ymax').text)

                annotations.append({
                    'class': name,
                    'x_min': x_min / img_width,
                    'y_min': y_min / img_height,
                    'x_max': x_max / img_width,
                    'y_max': y_max / img_height,
                    'width': (x_max - x_min) / img_width,
                    'height': (y_max - y_min) / img_height
                })
            return annotations

        if lowered_path.endswith('.json'):
            with open(annotation_path, 'r') as f:
                json.load(f)  # loaded so malformed JSON is reported
            # TODO: add COCO parsing logic here if needed.
            return []

    # AttributeError/TypeError: missing XML elements or empty text;
    # ValueError: non-numeric coordinates or bad JSON;
    # ZeroDivisionError: zero image size.
    except (OSError, ET.ParseError, AttributeError, TypeError, ValueError, ZeroDivisionError) as e:
        print(f"Error parsing annotation file {annotation_path}: {e}")
        return []

    print(f"Unsupported annotation format: {annotation_path}")
    return []

def heuristic_find_image_path(row, actual_images_root_dir_abs):
    """Enhanced version of the original function"""
    try:
        patient_id = row['patient_id']
        breast_side = row['left or right breast']
        image_view = row['image view']
        abnormality_id = str(row['abnormality id'])

        csv_conceptual_roi_path = str(row.get(CONCEPTUAL_ROI_COLUMN_NAME, "")).strip()

        case_type_folder_prefix = ""
        if csv_conceptual_roi_path.startswith("jpg_img/"):
            path_part = csv_conceptual_roi_path.split('/')[1]
            if path_part.startswith("Calc_Training_"): case_type_folder_prefix = "Calc_Training"
            elif path_part.startswith("Calc_Test_"): case_type_folder_prefix = "Calc_Test"
            elif path_part.startswith("Mass_Training_"): case_type_folder_prefix = "Mass_Training"
            elif path_part.startswith("Mass_Test_"): case_type_folder_prefix = "Mass_Test"

        if not case_type_folder_prefix:
            return None, None

        # Search for series directory
        dir_search_prefix = f"{case_type_folder_prefix}_{patient_id}_{breast_side}_{image_view}_{abnormality_id}"
        full_dir_search_pattern = os.path.join(actual_images_root_dir_abs, f"{dir_search_prefix}-*")
        potential_series_dirs = glob.glob(full_dir_search_pattern)

        if not potential_series_dirs:
            return None, None

        # Look for both ROI and full images
        roi_filename_patterns = [
            "ROI-mask-images-img_0-*.jpg", 
            "ROI-mask-images-img_1-*.jpg", 
            "ROI-mask-images-img_*-*.jpg"
        ]
        
        full_image_patterns = [
            "full-mammogram-images-img_0-*.jpg",
            "full-mammogram-images-img_1-*.jpg", 
            "full-mammogram-images-img_*-*.jpg"
        ]

        for series_dir_on_disk in sorted(potential_series_dirs):
            if os.path.isdir(series_dir_on_disk):
                roi_path = None
                full_path = None
                
                # Find ROI image
                for pattern in roi_filename_patterns:
                    roi_files = glob.glob(os.path.join(series_dir_on_disk, pattern))
                    if roi_files:
                        roi_path = sorted(roi_files)[0]
                        break
                
                # Find full mammogram image
                for pattern in full_image_patterns:
                    full_files = glob.glob(os.path.join(series_dir_on_disk, pattern))
                    if full_files:
                        full_path = sorted(full_files)[0]
                        break
                
                # If we have both ROI and full image, return them
                if roi_path and full_path:
                    return roi_path, full_path
                # If we only have ROI, use it as both
                elif roi_path:
                    return roi_path, roi_path
                    
        return None, None
    except Exception as e:
        return None, None

# %%
# --- Enhanced Data Processing ---
# DataFrame.apply on an empty frame returns no columns, and the two-column
# assignment below would then fail with an unrelated error. Fail clearly.
if source_df.empty:
    raise ValueError("No metadata rows remain after cleaning; nothing to search for")

print("Searching for image paths...")
images_root_abs = os.path.abspath(ACTUAL_IMAGE_FILES_BASE_DIR)  # resolve once, not once per row
source_df[['roi_path', 'full_image_path']] = source_df.apply(
    lambda r: pd.Series(heuristic_find_image_path(r, images_root_abs)),
    axis=1
)

# Filter rows with valid image paths
metadata_df = source_df.dropna(subset=['roi_path', 'full_image_path']).copy()
print(f"Found {len(metadata_df)} valid image pairs")

if len(metadata_df) == 0:
    raise ValueError("No valid image pairs found")

# Create one bounding box per row from its ROI mask; rows whose mask yields
# no box are dropped.
print("Creating bounding box annotations...")
bounding_boxes = []
valid_indices = []

for idx, row in metadata_df.iterrows():
    bbox = create_synthetic_bounding_boxes(row['roi_path'], row['full_image_path'])
    if bbox is not None:
        bbox['class'] = row[PATHOLOGY_COLUMN_NAME]
        bounding_boxes.append(bbox)
        valid_indices.append(idx)

if not valid_indices:
    # Without this, the label encoder and train_test_split fail later with
    # confusing errors on an empty frame.
    raise ValueError("Could not derive a bounding box from any ROI mask")

metadata_df = metadata_df.loc[valid_indices].copy()
metadata_df['bounding_boxes'] = bounding_boxes  # positional: same order as valid_indices

print(f"Created bounding boxes for {len(metadata_df)} images")

# Encode labels. LabelEncoder sorts the classes, and class_names keeps that order.
label_encoder = LabelEncoder()
metadata_df['class_id'] = label_encoder.fit_transform(metadata_df[PATHOLOGY_COLUMN_NAME])
class_names = list(label_encoder.classes_)
print(f"Class names: {class_names}")

# %%
# --- Object Detection Data Pipeline ---

def load_and_preprocess_image(image_path, target_size):
    """Load an image as RGB, resize it, and scale pixels to [0, 1].

    Args:
        image_path: path of the image file to read with OpenCV.
        target_size: ``(width, height)``, the order ``cv2.resize`` expects.

    Returns:
        A float32 array of shape ``(height, width, 3)``. If the file cannot
        be read, an all-zero array of the same shape is returned and a warning
        is printed. The fallback previously used ``(width, height, 3)``, which
        disagreed with the normal path whenever width != height.
    """
    width, height = target_size
    try:
        img = cv2.imread(image_path)
        if img is None:
            print(f"Warning: could not read image {image_path}; using a blank image")
            return np.zeros((height, width, 3), dtype=np.float32)

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (width, height))
        return img.astype(np.float32) / 255.0
    except Exception as e:  # best effort inside the data generator: never stop training
        print(f"Warning: failed to load image {image_path} ({e}); using a blank image")
        return np.zeros((height, width, 3), dtype=np.float32)

def create_target_tensor(bboxes_list, num_classes, max_objects, class_order=('BENIGN', 'MALIGNANT')):
    """Encode up to ``max_objects`` boxes as one flat training-target vector.

    Each slot occupies ``5 + num_classes`` entries laid out as
    ``[x_center, y_center, width, height, confidence, one-hot class...]``.
    Unused slots stay all-zero (confidence 0).

    Args:
        bboxes_list: box dicts with normalised ``x_min``, ``x_max``,
            ``y_min``, ``y_max``, ``width``, ``height`` and a ``class`` label.
            Boxes beyond ``max_objects`` are ignored.
        num_classes: length of the one-hot class part of each slot.
        max_objects: number of slots.
        class_order: the class label for each one-hot position. It must match
            ``class_names`` (LabelEncoder's sorted classes), because
            ``decode_predictions`` maps ``argmax`` back through
            ``class_names``. Putting MALIGNANT at position 0, as a hard-coded
            mapping did, inverts every decoded label.

    Returns:
        A float32 array of shape ``(max_objects * (5 + num_classes),)``.

    Raises:
        ValueError: if a box's class is not in ``class_order`` or its
            position does not fit within ``num_classes``.
    """
    obj_size = 5 + num_classes
    targets = np.zeros(max_objects * obj_size, dtype=np.float32)
    class_positions = {label: pos for pos, label in enumerate(class_order)}

    for slot, bbox in enumerate(bboxes_list[:max_objects]):
        base_idx = slot * obj_size

        # Convert corner coordinates to a centre point.
        x_center = (bbox['x_min'] + bbox['x_max']) / 2
        y_center = (bbox['y_min'] + bbox['y_max']) / 2
        targets[base_idx:base_idx + 5] = (x_center, y_center, bbox['width'], bbox['height'], 1.0)

        class_pos = class_positions.get(bbox['class'])
        if class_pos is None or class_pos >= num_classes:
            raise ValueError(
                f"Unknown class {bbox['class']!r}; expected one of {list(class_order)[:num_classes]}"
            )
        targets[base_idx + 5 + class_pos] = 1.0

    return targets

def data_generator(df, batch_size, augment=False):
    """Yield ``(images, targets)`` NumPy batches from *df*, cycling forever.

    Each pass walks the rows of *df* in chunks of *batch_size*, and the final
    chunk may be smaller. When *augment* is true, the row order is reshuffled
    at the start of every pass. No other augmentation is applied.

    Rows must provide ``full_image_path``, ``bounding_boxes`` (a dict of
    normalised box coordinates) and the pathology label column.
    """
    order = np.arange(len(df))
    box_keys = ('x_min', 'y_min', 'x_max', 'y_max', 'width', 'height')

    while True:
        if augment:
            np.random.shuffle(order)

        for batch_start in range(0, len(order), batch_size):
            batch_images = []
            batch_targets = []

            for position in order[batch_start:batch_start + batch_size]:
                record = df.iloc[position]

                image = load_and_preprocess_image(record['full_image_path'], (IMG_WIDTH, IMG_HEIGHT))

                # Rebuild a single box dict, with its label, for the target encoder.
                stored_box = record['bounding_boxes']
                box = {key: stored_box[key] for key in box_keys}
                box['class'] = record[PATHOLOGY_COLUMN_NAME]
                target = create_target_tensor([box], NUM_CLASSES, MAX_OBJECTS)

                batch_images.append(image)
                batch_targets.append(target)

            yield np.array(batch_images), np.array(batch_targets)

# %%
# --- Train/Validation/Test Split ---
split_columns = ['full_image_path', 'bounding_boxes', PATHOLOGY_COLUMN_NAME]
X, y = metadata_df[split_columns], metadata_df['class_id']

# Stage 1: hold out 15% of all samples as the test set (stratified on class_id).
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, stratify=y, test_size=0.15, random_state=RANDOM_STATE,
)
# Stage 2: carve 15% of the remaining samples out as the validation set.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, stratify=y_train_val, test_size=0.15, random_state=RANDOM_STATE,
)

print(f"Training samples: {len(X_train)}")
print(f"Validation samples: {len(X_val)}")
print(f"Test samples: {len(X_test)}")

# %%
# --- Object Detection Model Architecture ---

def create_object_detection_model(input_shape, num_classes, max_objects):
    """Build a single-shot box-regression model on an EfficientNetB0 backbone.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_classes: number of class scores per box slot.
        max_objects: number of box slots predicted per image.

    Returns:
        ``(model, backbone)``: the full Keras model and the ImageNet-initialised
        EfficientNetB0 sub-model. The backbone starts frozen and is returned so
        the caller can unfreeze it for fine-tuning. The model outputs
        ``max_objects * (5 + num_classes)`` sigmoid values, one
        ``[x, y, w, h, confidence, class scores...]`` group per slot.
    """
    inputs = Input(shape=input_shape)

    # Backbone (EfficientNet), frozen for the first training phase.
    backbone = EfficientNetB0(include_top=False, weights='imagenet', input_shape=input_shape)
    backbone.trainable = False

    # load_and_preprocess_image feeds pixels scaled to [0, 1]. Keras'
    # EfficientNet models contain their own input normalisation and expect
    # raw [0, 255] values, so undo the /255 here.
    x = tf.keras.layers.Rescaling(255.0)(inputs)

    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, even after the backbone is unfrozen for fine-tuning.
    features = backbone(x, training=False)

    # Detection head
    x = GlobalAveragePooling2D()(features)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.3)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.3)(x)

    # Output layer: [x_center, y_center, width, height, confidence, class_probs] * max_objects
    output_size = max_objects * (5 + num_classes)
    outputs = Dense(output_size, activation='sigmoid')(x)

    model = Model(inputs, outputs)
    return model, backbone

def custom_object_detection_loss(y_true, y_pred, num_classes=NUM_CLASSES, max_objects=MAX_OBJECTS):
    """Per-sample detection loss: box regression, objectness and classification.

    Both tensors hold ``max_objects`` consecutive slots per sample, each laid
    out as ``[x_center, y_center, width, height, confidence, class scores...]``
    (the layout written by ``create_target_tensor``). All terms are squared
    errors:

    * coordinate loss: only for slots whose target confidence is 1;
    * confidence loss: for every slot, so empty slots learn confidence 0;
    * class loss: only for slots whose target confidence is 1.

    Returns:
        A tensor of shape ``(batch,)``. Keras reduces it by averaging over the
        batch, so the loss value no longer scales with the number of samples
        in a batch. The previous version summed the whole batch into a scalar.
    """
    obj_size = 5 + num_classes

    # View the flat vectors as (batch, slot, field).
    y_pred = tf.reshape(y_pred, (-1, max_objects, obj_size))
    y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), (-1, max_objects, obj_size))

    true_conf = y_true[..., 4:5]  # 1 for slots holding an object, else 0

    coord_loss = tf.reduce_sum(true_conf * tf.square(y_pred[..., :4] - y_true[..., :4]), axis=[1, 2])
    conf_loss = tf.reduce_sum(tf.square(y_pred[..., 4:5] - true_conf), axis=[1, 2])
    class_loss = tf.reduce_sum(true_conf * tf.square(y_pred[..., 5:] - y_true[..., 5:]), axis=[1, 2])

    return coord_loss + conf_loss + class_loss

# %%
# --- Model Creation and Compilation ---
print("\nCreating object detection model...")
model_input_shape = (IMG_HEIGHT, IMG_WIDTH, 3)  # (rows, cols, channels)
model, backbone = create_object_detection_model(model_input_shape, NUM_CLASSES, MAX_OBJECTS)

# Frozen-backbone phase: Adam at the base learning rate. 'mse' is reported
# next to the custom loss as an extra, whole-vector regression metric.
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=optimizer, loss=custom_object_detection_loss, metrics=['mse'])

model.summary()

# %%
# --- Training ---
print("\nStarting training...")

# Create data generators (the training order is reshuffled on every pass).
train_gen = data_generator(X_train, BATCH_SIZE, augment=True)
val_gen = data_generator(X_val, BATCH_SIZE, augment=False)

# data_generator yields ceil(len / BATCH_SIZE) batches per pass, and the last
# one may be partial. Steps must match that count. With floor division, epoch
# boundaries drifted across generator passes, so each epoch validated on a
# different subset, and val_steps was 0 whenever len(X_val) < BATCH_SIZE.
train_steps = int(np.ceil(len(X_train) / BATCH_SIZE))
val_steps = int(np.ceil(len(X_val) / BATCH_SIZE))

# Callbacks: keep the best model on disk, stop early and decay the LR on plateaus.
checkpoint_filepath = os.path.join(OUTPUT_DIR, 'best_object_detection_model.keras')
callbacks = [
    ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=False,
        monitor='val_loss',
        mode='min',
        save_best_only=True
    ),
    EarlyStopping(
        monitor='val_loss',
        patience=PATIENCE_EARLY_STOPPING,
        mode='min',
        restore_best_weights=True
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=PATIENCE_REDUCE_LR,
        min_lr=1e-7,
        mode='min'
    )
]

# Train the model
history = model.fit(
    train_gen,
    steps_per_epoch=train_steps,
    epochs=EPOCHS,
    validation_data=val_gen,
    validation_steps=val_steps,
    callbacks=callbacks,
    verbose=1
)

# %%
# --- Fine-tuning ---
print("\nStarting fine-tuning...")

# Unfreeze the backbone as a whole, then set every BatchNormalization layer
# back to trainable=False so it stays frozen.
backbone.trainable = True
bn_layers = [lyr for lyr in backbone.layers if isinstance(lyr, tf.keras.layers.BatchNormalization)]
for layer in bn_layers:
    layer.trainable = False

# Recompile so the trainable changes take effect, using a 10x smaller learning rate.
fine_tune_learning_rate = LEARNING_RATE / 10
model.compile(
    optimizer=Adam(learning_rate=fine_tune_learning_rate),
    loss=custom_object_detection_loss,
    metrics=['mse'],
)

# Same callback set as the first phase, with fine-tuning patience and LR floor.
finetune_checkpoint = os.path.join(OUTPUT_DIR, 'best_finetuned_model.keras')
monitor_kwargs = dict(monitor='val_loss', mode='min')
finetune_callbacks = [
    ModelCheckpoint(filepath=finetune_checkpoint, save_weights_only=False, save_best_only=True, **monitor_kwargs),
    EarlyStopping(patience=PATIENCE_EARLY_STOPPING_FT, restore_best_weights=True, **monitor_kwargs),
    ReduceLROnPlateau(factor=0.2, patience=PATIENCE_REDUCE_LR_FT, min_lr=1e-8, **monitor_kwargs),
]

# Continue the epoch count from where the frozen-backbone phase stopped.
initial_epoch = len(history.history['loss'])
finetune_history = model.fit(
    train_gen,
    validation_data=val_gen,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    initial_epoch=initial_epoch,
    epochs=initial_epoch + FINE_TUNE_EPOCHS,
    callbacks=finetune_callbacks,
    verbose=1,
)

# %%
# --- Evaluation and Visualization ---

def decode_predictions(predictions, confidence_threshold=0.5, num_classes=NUM_CLASSES, max_objects=MAX_OBJECTS):
    """Turn raw model output vectors into per-image lists of detected boxes.

    Each prediction vector is read as ``max_objects`` consecutive slots of
    ``[x_center, y_center, width, height, confidence, class scores...]``.
    A slot is kept when its confidence is strictly greater than
    *confidence_threshold*. No non-maximum suppression is applied.

    Returns:
        One list per prediction. Each list holds dicts with ``x_min``,
        ``y_min``, ``x_max`` and ``y_max`` (normalised, clipped to [0, 1]),
        plus ``confidence``, ``class_id`` (int), ``class_confidence`` and
        ``class_name``, which is looked up in the global ``class_names``.
    """
    obj_size = 5 + num_classes
    detections = []

    for pred in predictions:
        # One row per slot: (max_objects, 5 + num_classes).
        slots = np.asarray(pred)[:max_objects * obj_size].reshape(max_objects, obj_size)
        boxes = []
        for slot in slots:
            x_center, y_center, width, height, confidence = slot[:5]
            if not confidence > confidence_threshold:
                continue

            # Centre and size lie in [0, 1] (sigmoid), but centre ± size/2 can
            # still fall outside the image, so clip the corners.
            x_min, x_max = np.clip([x_center - width / 2, x_center + width / 2], 0.0, 1.0)
            y_min, y_max = np.clip([y_center - height / 2, y_center + height / 2], 0.0, 1.0)

            class_probs = slot[5:5 + num_classes]
            class_id = int(np.argmax(class_probs))

            boxes.append({
                'x_min': x_min,
                'y_min': y_min,
                'x_max': x_max,
                'y_max': y_max,
                'confidence': confidence,
                'class_id': class_id,
                'class_confidence': class_probs[class_id],
                'class_name': class_names[class_id]
            })
        detections.append(boxes)

    return detections

def _draw_normalised_box(ax, box, img_w, img_h, **rect_kwargs):
    """Draw one box given in normalised coordinates; return its top-left corner in pixels."""
    x_min = box['x_min'] * img_w
    y_min = box['y_min'] * img_h
    width = (box['x_max'] - box['x_min']) * img_w
    height = (box['y_max'] - box['y_min']) * img_h
    ax.add_patch(patches.Rectangle(
        (x_min, y_min), width, height,
        linewidth=2, facecolor='none', **rect_kwargs
    ))
    return x_min, y_min


def visualize_detections(images, detections, ground_truth=None, save_path=None):
    """Plot up to six images with predicted (red) and ground-truth (green, dashed) boxes.

    Args:
        images: sequence of ``H x W x C`` image arrays.
        detections: per-image lists of dicts as returned by
            ``decode_predictions`` (normalised corners, ``class_name``,
            ``confidence``).
        ground_truth: optional per-image lists of dicts with normalised
            ``x_min``, ``y_min``, ``x_max`` and ``y_max``.
        save_path: if given, the figure is also saved to this path.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    num_shown = min(6, len(images))

    for i in range(num_shown):
        ax = axes[i]
        image = np.asarray(images[i])
        # Scale boxes by this image's own size rather than the global
        # IMG_WIDTH/IMG_HEIGHT.
        img_h, img_w = image.shape[:2]

        ax.imshow(image)
        ax.set_title(f'Image {i+1}')
        ax.axis('off')

        # Predicted boxes with class/confidence labels.
        for detection in (detections[i] if i < len(detections) else []):
            x_min, y_min = _draw_normalised_box(ax, detection, img_w, img_h, edgecolor='red')
            label = f"{detection['class_name']}: {detection['confidence']:.2f}"
            ax.text(x_min, y_min-5, label, color='red', fontsize=8,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7))

        # Ground-truth boxes, if provided.
        if ground_truth and i < len(ground_truth):
            for gt_box in ground_truth[i]:
                _draw_normalised_box(ax, gt_box, img_w, img_h, edgecolor='green', linestyle='--')

    # Hide grid cells that received no image.
    for ax in axes[num_shown:]:
        ax.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        print(f"Detection visualization saved to {save_path}")
    plt.show()
    plt.close(fig)  # release the figure; repeated notebook calls otherwise accumulate figures

# %%
# --- Test Set Evaluation ---
print("\nEvaluating on test set...")

# Create the test generator (fixed order, no shuffling).
test_gen = data_generator(X_test, BATCH_SIZE, augment=False)
# data_generator yields ceil(len / BATCH_SIZE) batches per pass, so ceil covers
# every test sample exactly once. Floor division skipped the tail, or the whole
# set when len(X_test) < BATCH_SIZE.
test_steps = int(np.ceil(len(X_test) / BATCH_SIZE))

# Get test predictions
test_images = []
test_predictions = []
test_ground_truth = []

# Pull exactly test_steps batches. Breaking out of an enumerate() loop
# loaded one extra batch of images that was then thrown away.
for _ in range(test_steps):
    batch_images, batch_targets = next(test_gen)

    batch_preds = model.predict(batch_images, verbose=0)

    test_images.extend(batch_images)
    test_predictions.extend(batch_preds)

    # Convert targets back into box dicts for visualisation. Only the first
    # slot is decoded, because data_generator writes one box per image.
    for target in batch_targets:
        gt_box = {
            'x_min': target[0] - target[2]/2,
            'y_min': target[1] - target[3]/2,
            'x_max': target[0] + target[2]/2,
            'y_max': target[1] + target[3]/2,
            # Same argmax-over-class-slots lookup that decode_predictions uses.
            'class_name': class_names[int(np.argmax(target[5:5 + NUM_CLASSES]))]
        }
        test_ground_truth.append([gt_box] if target[4] > 0 else [])

# Decode predictions
decoded_predictions = decode_predictions(test_predictions, CONFIDENCE_THRESHOLD)

# Visualize results
viz_save_path = os.path.join(OUTPUT_DIR, "detection_results.png")
visualize_detections(
    test_images[:6],
    decoded_predictions[:6],
    test_ground_truth[:6],
    save_path=viz_save_path
)