In [None]:
import json
from pathlib import Path
import random
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# --- Configuration ---
# Relative paths are resolved against the current working directory at run
# time (for a notebook this is typically the notebook's folder), not against
# this file's location — keep them in sync with your project layout.
ANNOTATIONS_DIR = Path('..', 'data', 'processed', 'splits')
IMAGES_DIR = Path('..', 'data', 'raw', 'images')

# Per-split COCO annotation files living inside ANNOTATIONS_DIR.
TRAIN_ANNOTATIONS_FILE = ANNOTATIONS_DIR.joinpath('train_annotations.json')
VAL_ANNOTATIONS_FILE = ANNOTATIONS_DIR.joinpath('val_annotations.json')

# COCO category_id -> human-readable label. The IDs below are examples;
# replace them with the categories used by your dataset.
CATEGORY_MAPPING = {1: 'door', 2: 'window', 3: 'room'}

# Box/label color per category name, as (R, G, B) tuples — the visualizer
# converts images to RGB before drawing, so these are written as-is.
COLORS = dict(
    door=(255, 0, 0),      # red
    window=(0, 255, 0),    # green
    room=(0, 0, 255),      # blue
)

# --- Data Loading and Visualization Functions ---

def load_annotations(json_path: Path):
    """Load COCO-format annotations from a JSON file.

    Args:
        json_path: Path to the annotation JSON file (``str`` is also accepted).

    Returns:
        The parsed JSON document. For a COCO file this is a dict with keys
        such as ``'images'`` and ``'annotations'``.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"Annotations file not found at: {json_path}")
    # JSON is UTF-8 by specification. Without an explicit encoding, open()
    # uses the locale codec (e.g. GBK / cp1252 on Windows) and fails on
    # non-ASCII file names or labels. 'utf-8-sig' additionally strips a BOM,
    # which Windows editors often prepend and json.load would reject.
    with open(json_path, 'r', encoding='utf-8-sig') as f:
        return json.load(f)

def draw_bbox_on_image(image: np.ndarray, bboxes: list, categories: list):
    """Draw bounding boxes and category labels on a copy of an image.

    Args:
        image: Image array to draw on. It is copied; the input is not
            modified. Color tuples from ``COLORS`` are written into the
            channels as-is, so they follow the image's channel order.
        bboxes: COCO-style boxes ``[x, y, width, height]`` in pixels
            (floats allowed).
        categories: Category IDs, one per box, looked up in
            ``CATEGORY_MAPPING``; unknown IDs are labeled ``'unknown'``.

    Returns:
        A new array with the boxes and labels drawn.

    Raises:
        ValueError: If ``bboxes`` and ``categories`` differ in length.
    """
    if len(bboxes) != len(categories):
        # zip() would silently drop the unmatched tail; surface it instead.
        raise ValueError(
            "bboxes and categories must have the same length "
            f"({len(bboxes)} != {len(categories)})"
        )

    vis_image = image.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    label_gap = 5  # pixels between the box edge and the label baseline

    for bbox, category_id in zip(bboxes, categories):
        # COCO format: [x, y, width, height]
        x, y, w, h = bbox
        # Round both corners rather than truncating x and w separately, which
        # can shift the far edge by one pixel for float coordinates.
        x1, y1 = int(round(x)), int(round(y))
        x2, y2 = int(round(x + w)), int(round(y + h))

        # Get color and category name
        category_name = CATEGORY_MAPPING.get(category_id, 'unknown')
        color = COLORS.get(category_name, (255, 255, 255))

        # Draw bounding box
        cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, 2)

        # Draw category label above the box. putText's origin is the text's
        # bottom-left corner, so a box touching the top edge would push the
        # label off-image; in that case draw it just inside the box instead.
        text = f'{category_name}'
        (_, text_h), _ = cv2.getTextSize(text, font, font_scale, 1)
        text_y = y1 - label_gap
        if text_y - text_h < 0:
            text_y = y1 + text_h + label_gap
        cv2.putText(vis_image, text, (x1, text_y), font, font_scale, color, 1, cv2.LINE_AA)

    return vis_image

def visualize_random_samples(data: dict, num_samples: int = 5, seed=None):
    """Display a random selection of annotated images with matplotlib.

    Args:
        data: COCO-style dict with an ``'images'`` list (entries need ``'id'``
            and ``'file_name'``) and, optionally, an ``'annotations'`` list
            (entries need ``'image_id'``, ``'bbox'`` and ``'category_id'``).
        num_samples: Maximum number of images to show. If the split has
            fewer images, all of them are shown.
        seed: Optional seed for a private ``random.Random`` so the selection
            is reproducible. ``None`` (default) uses the global ``random``
            state, as before.

    Image files are resolved relative to ``IMAGES_DIR``. Missing files are
    reported with a warning and skipped.
    """
    images = data['images']

    # Index once up front instead of scanning every image/annotation for each
    # sampled image. setdefault keeps the first entry for a duplicated id,
    # matching a first-match linear search.
    images_by_id = {}
    for img in images:
        images_by_id.setdefault(img['id'], img)
    anns_by_image = {}
    for ann in data.get('annotations', []):  # image-only files have no annotations
        anns_by_image.setdefault(ann['image_id'], []).append(ann)

    rng = random if seed is None else random.Random(seed)
    image_ids = [img['id'] for img in images]
    # random.sample raises ValueError when k exceeds the population size
    # (e.g. a small validation split), so cap the request.
    random_image_ids = rng.sample(image_ids, min(num_samples, len(image_ids)))

    for img_id in random_image_ids:
        img_info = images_by_id[img_id]

        image_path = IMAGES_DIR / img_info['file_name']
        if not image_path.exists():
            print(f"Warning: Image file not found for ID {img_id}: {image_path}")
            continue

        # Use a context manager so the file handle is released promptly.
        with Image.open(image_path) as pil_img:
            img_np = np.array(pil_img.convert('RGB'))

        img_annotations = anns_by_image.get(img_id, [])
        bboxes = [ann['bbox'] for ann in img_annotations]
        categories = [ann['category_id'] for ann in img_annotations]

        vis_img = draw_bbox_on_image(img_np, bboxes, categories)

        plt.figure(figsize=(10, 10))
        plt.imshow(vis_img)
        plt.title(f"Image ID: {img_id}, Filename: {img_info['file_name']}")
        plt.axis('off')
        plt.show()

# --- Script Execution ---

if __name__ == '__main__':
    # Visualize each split inside its own try block, so that a missing
    # annotation file for one split is reported without stopping the other.
    # NOTE(review): this also runs when the cell executes in a notebook, and
    # train_data / val_data are left as module globals that later cells may
    # use — keep those names bound if this section is ever refactored.
    print("--- Visualizing Training Set Samples ---")
    try:
        train_data = load_annotations(TRAIN_ANNOTATIONS_FILE)
        visualize_random_samples(train_data, num_samples=3)
    except FileNotFoundError as e:
        # load_annotations raises FileNotFoundError for a missing annotations
        # file; any other exception propagates.
        print(e)
    
    print("\n--- Visualizing Validation Set Samples ---")
    try:
        val_data = load_annotations(VAL_ANNOTATIONS_FILE)
        visualize_random_samples(val_data, num_samples=3)
    except FileNotFoundError as e:
        print(e)

--- 可视化训练集样本 ---
Annotations file not found at: data\processed\splits\train_annotations.json

--- 可视化验证集样本 ---
Annotations file not found at: data\processed\splits\val_annotations.json
