# 代码

In [None]:
# ====================核心模块=====================
from pathlib import Path
from typing import List, Dict, Any, Tuple
import logging
import cv2
import numpy as np
import orjson
import torch
import yaml
from ultralytics import YOLO
from uuid import uuid4
from datetime import datetime

# Configure logging: INFO-level, timestamped records emitted through a stream handler.
_console_handler = logging.StreamHandler()
logging.basicConfig(
    handlers=[_console_handler],
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)

# ============================================================
# 配置加载
# ============================================================

def load_config(config_path: Path) -> dict:
    """Load the pipeline configuration from a YAML file.

    Args:
        config_path: Location of the YAML file; it is read as UTF-8.

    Returns:
        The parsed top-level mapping of the document.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the document is empty or its top level is not a mapping.
    """
    with config_path.open("r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load yields None for an empty document and may yield a list or a
    # scalar. Callers index the result by section name, so reject those shapes
    # here with a clear message instead of failing later with an opaque
    # TypeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config

# ============================================================
# 工具函数
# ============================================================

def find_raw_data_directories(root_directory: Path, search_depth: int) -> List[Path]:
    """Find every directory that directly contains a ``raw_data`` sub-directory.

    The search starts at ``root_directory`` (depth 0) and descends at most
    ``search_depth`` levels below it.

    Args:
        root_directory: Directory where the search starts.
        search_depth: Maximum number of levels to descend; 0 inspects only the
            root itself.

    Returns:
        The parents of each ``raw_data`` directory found, sorted so the result
        does not depend on filesystem enumeration order.
    """
    log = logging.getLogger(__name__)
    directories = []
    to_visit = [(root_directory, 0)]

    while to_visit:
        current_dir, current_depth = to_visit.pop()

        if (current_dir / "raw_data").is_dir():
            directories.append(current_dir)

        if current_depth >= search_depth:
            continue

        # A missing root or an unreadable sub-directory must not abort the
        # whole search; report it and keep scanning the remaining branches.
        try:
            children = [child for child in current_dir.iterdir() if child.is_dir()]
        except OSError as error:
            log.warning("Cannot list %s: %s", current_dir, error)
            continue

        to_visit.extend((child, current_depth + 1) for child in children)

    return sorted(directories)

def collect_images(raw_data_directory: Path) -> List[Path]:
    """Collect the image files located directly inside a directory.

    Extensions are matched case-insensitively so that ``.JPG`` and ``.jpg``
    are treated alike on every platform. Only regular files are returned;
    sub-directories are not searched, and a directory whose name happens to
    end in an image extension is ignored.

    Args:
        raw_data_directory: Directory to scan.

    Returns:
        Sorted list of image paths; empty when the directory does not exist.
    """
    extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

    if not raw_data_directory.is_dir():
        return []

    return sorted(
        entry
        for entry in raw_data_directory.iterdir()
        if entry.suffix.lower() in extensions and entry.is_file()
    )

def generate_tile_windows(image_height: int, image_width: int, tile_height: int, tile_width: int, overlap_ratio: float) -> List[Tuple[int, int, int, int]]:
    """Compute the sliding-window tiles that cover an image.

    Args:
        image_height: Height of the image in pixels.
        image_width: Width of the image in pixels.
        tile_height: Requested tile height in pixels.
        tile_width: Requested tile width in pixels.
        overlap_ratio: Fraction of a tile shared with its neighbour; clamped
            to the range [0.0, 0.99].

    Returns:
        ``(x_start, y_start, x_end, y_end)`` tuples in row-major order. End
        coordinates are clipped to the image, so a tile is smaller than
        requested when the image itself is smaller than the tile.
    """
    stride_fraction = 1.0 - max(0.0, min(0.99, overlap_ratio))

    def axis_origins(extent: int, tile_extent: int) -> List[int]:
        # The stride never drops below one pixel, even at maximum overlap.
        stride = max(1, int(tile_extent * stride_fraction))
        final_origin = max(0, extent - tile_extent)
        origins = list(range(0, max(1, extent - tile_extent + 1), stride))
        # Add one last tile flush with the far edge so no border strip is missed.
        if origins[-1] != final_origin:
            origins.append(final_origin)
        return origins

    row_origins = axis_origins(image_height, tile_height)
    column_origins = axis_origins(image_width, tile_width)

    return [
        (left, top, min(left + tile_width, image_width), min(top + tile_height, image_height))
        for top in row_origins
        for left in column_origins
    ]

def run_yolo_on_tiles(model, tiles, offsets, yolo_config, image_stem):
    """Run YOLO inference on a batch of tiles and return polygons in full-image coordinates.

    Args:
        model: Object exposing ``predict(...)``; in this file it is an
            ultralytics ``YOLO`` instance.
        tiles: Tile images passed to ``model.predict``.
        offsets: ``(x, y)`` origin of each tile in the source image, in the
            same order as ``tiles``.
        yolo_config: Mapping providing ``image_size``, ``confidence_threshold``,
            ``iou_threshold``, ``device``, ``batch_size``, ``workers``,
            ``verbose``, ``save_results`` and ``retina_masks``.
        image_stem: Identifier stored on every detection.

    Returns:
        A list of dicts with keys ``image_stem``, ``label_name``, ``score``,
        ``polygon`` (list of ``[x, y]`` points) and ``object_id`` (a fresh
        UUID string). Records without segment coordinates, or with fewer than
        three polygon points, are dropped.
    """
    if not tiles:
        return []

    log = logging.getLogger(__name__)

    results = model.predict(
        tiles,
        imgsz=yolo_config["image_size"],
        conf=yolo_config["confidence_threshold"],
        iou=yolo_config["iou_threshold"],
        device=str(yolo_config["device"]),
        batch=yolo_config["batch_size"],
        workers=yolo_config["workers"],
        verbose=yolo_config["verbose"],
        save=yolo_config["save_results"],
        retina_masks=yolo_config["retina_masks"],
    )

    detections = []
    for tile_index, (result, (offset_x, offset_y)) in enumerate(zip(results, offsets)):
        try:
            records = orjson.loads(result.to_json())
        except Exception:
            # Best effort: one unparsable tile must not discard the rest of
            # the batch, but the failure is reported rather than hidden.
            log.warning(
                "Skipping tile %d of %s: could not parse prediction output",
                tile_index,
                image_stem,
                exc_info=True,
            )
            continue

        for record in records:
            # Tolerate both a missing "segments" entry and an explicit null.
            segments = record.get("segments") or {}
            x_vals = segments.get("x", [])
            y_vals = segments.get("y", [])

            if not x_vals or not y_vals:
                continue

            # Shift tile-local coordinates into the source image's frame.
            polygon = [
                [local_x + offset_x, local_y + offset_y]
                for local_x, local_y in zip(x_vals, y_vals)
            ]

            # A polygon needs at least three vertices.
            if len(polygon) < 3:
                continue

            detections.append({
                "image_stem": image_stem,
                "label_name": record.get("name", ""),
                "score": record.get("confidence", 0.0),
                "polygon": polygon,
                "object_id": str(uuid4())
            })

    # Hand cached GPU memory back to the allocator between batches.
    torch.cuda.empty_cache()
    return detections

def process_single_image(image_path, model, processing_config, yolo_config):
    """Tile one image, run inference on the tiles in batches, and gather the detections.

    Args:
        image_path: Path of the image to process.
        model: Model object forwarded to ``run_yolo_on_tiles``.
        processing_config: Mapping providing ``tile_height``, ``tile_width``
            and ``overlap_ratio``. An optional ``max_tiles_per_batch`` entry
            (default 64) caps how many tiles are sent to the model per call.
        yolo_config: Prediction settings forwarded to ``run_yolo_on_tiles``.

    Returns:
        All detections for the image in full-image coordinates; an empty list
        when the file is unreadable, empty, or cannot be decoded.
    """
    # The file is read as raw bytes and decoded in memory.
    # NOTE(review): presumably done so paths that cv2.imread cannot open still work - confirm.
    try:
        buffer = np.fromfile(str(image_path), dtype=np.uint8)
    except OSError as error:
        # Treat an unreadable file like an undecodable one: skip it, keep going.
        logger.warning(f"Failed to read: {image_path} ({error})")
        return []

    if buffer.size == 0:
        logger.warning(f"Empty file: {image_path}")
        return []

    image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if image is None:
        logger.warning(f"Failed to decode: {image_path}")
        return []

    height, width = image.shape[:2]

    windows = generate_tile_windows(
        height, width,
        processing_config["tile_height"],
        processing_config["tile_width"],
        processing_config["overlap_ratio"]
    )

    # Upper bound on tiles handed to the model per call; at least one so the
    # loop below always makes progress.
    max_tiles_per_batch = max(1, int(processing_config.get("max_tiles_per_batch", 64)))

    all_detections = []
    for batch_start in range(0, len(windows), max_tiles_per_batch):
        batch_windows = windows[batch_start:batch_start + max_tiles_per_batch]
        tiles = [image[y_start:y_end, x_start:x_end] for x_start, y_start, x_end, y_end in batch_windows]
        offsets = [(x_start, y_start) for x_start, y_start, _, _ in batch_windows]
        all_detections.extend(
            run_yolo_on_tiles(model, tiles, offsets, yolo_config, image_path.stem)
        )

    return all_detections

def process_directory(raw_data_dir, model, processing_config, yolo_config):
    """Run inference on every image in a directory.

    Args:
        raw_data_dir: Directory whose images are processed.
        model: Model object forwarded to ``process_single_image``.
        processing_config: Tiling settings forwarded to ``process_single_image``.
        yolo_config: Prediction settings forwarded to ``process_single_image``.

    Returns:
        Mapping from image stem to that image's detections. Images that yield
        no detections are omitted; the mapping is empty when the directory
        holds no images.
    """
    image_paths = collect_images(raw_data_dir)
    if not image_paths:
        logger.info(f"No images found in {raw_data_dir}")
        return {}

    total = len(image_paths)
    detections_by_image = {}

    for position, image_path in enumerate(image_paths, start=1):
        logger.info(f"[{position}/{total}] Processing: {image_path}")
        found = process_single_image(image_path, model, processing_config, yolo_config)
        if not found:
            continue
        # NOTE(review): results are keyed by stem, so two files that differ
        # only in extension (a.jpg / a.png) overwrite each other here.
        detections_by_image[image_path.stem] = found

    return detections_by_image

def _polygon_area(polygon):
    """Return the absolute area enclosed by a list of ``[x, y]`` vertices (shoelace formula)."""
    twice_area = 0.0
    for (x1, y1), (x2, y2) in zip(polygon, polygon[1:] + polygon[:1]):
        twice_area += x1 * y2 - x2 * y1
    return abs(twice_area) / 2.0


def create_coco_dataset(raw_data_dir, detections_by_image, categories):
    """Build a COCO-style dataset dict from per-image detections.

    Args:
        raw_data_dir: Directory containing the source images.
        detections_by_image: Mapping from image stem to a list of detection
            dicts carrying ``label_name``, ``polygon``, ``score`` and
            ``object_id``.
        categories: List of category dicts with at least ``name`` and ``id``;
            it is returned unchanged under the ``categories`` key.

    Returns:
        A dict with ``images``, ``annotations`` and ``categories``. Images
        without detections, and images that cannot be read or decoded, are
        left out. Detections whose label matches no category are dropped.
        Image and annotation ids start at 1.
    """
    category_map = {cat["name"]: cat for cat in categories}

    images, annotations = [], []
    image_id, annotation_id = 1, 1

    for image_path in collect_images(raw_data_dir):
        stem = image_path.stem
        if stem not in detections_by_image:
            continue

        # The image is decoded only to learn its dimensions. An unreadable or
        # empty file is skipped rather than aborting the whole export.
        try:
            buffer = np.fromfile(str(image_path), dtype=np.uint8)
        except OSError:
            continue
        image = cv2.imdecode(buffer, cv2.IMREAD_COLOR) if buffer.size else None
        if image is None:
            continue

        height, width = image.shape[:2]

        # File names are split on "_": the first two fields form the
        # timestamp and a trailing all-digit field is the focal length.
        parts = stem.split("_")
        timestamp = f"{parts[0]}_{parts[1]}" if len(parts) >= 2 else ""
        focal_length = int(parts[-1]) if parts and parts[-1].isdigit() else 0

        images.append({
            "id": image_id,
            "file_name": image_path.name,
            "width": width,
            "height": height,
            "timestamp": timestamp,
            "focal_length_parameter": focal_length
        })

        for detection in detections_by_image[stem]:
            category = category_map.get(detection["label_name"])
            if category is None:
                continue

            polygon = detection["polygon"]
            x_coords = [point[0] for point in polygon]
            y_coords = [point[1] for point in polygon]
            x_min, y_min = min(x_coords), min(y_coords)
            # COCO bounding boxes are [x, y, width, height].
            bbox = [x_min, y_min, max(x_coords) - x_min, max(y_coords) - y_min]

            annotations.append({
                "id": annotation_id,
                "image_id": image_id,
                "category_id": category["id"],
                # Flattened [x0, y0, x1, y1, ...] vertex list.
                "segmentation": [[coord for point in polygon for coord in point]],
                # Area of the segmented region itself, not of its bounding
                # box, which would overstate every non-rectangular shape.
                "area": _polygon_area(polygon),
                "bbox": bbox,
                "iscrowd": 0,
                "score": detection["score"],
                "object_id": detection["object_id"]
            })
            annotation_id += 1

        image_id += 1

    return {
        "images": images,
        "annotations": annotations,
        "categories": categories
    }

def save_coco_dataset(coco_data, output_path):
    """Write a COCO dataset dict to ``output_path`` as 2-space-indented JSON.

    Missing parent directories are created first, and the destination is
    logged once the file has been written.
    """
    destination_dir = output_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    payload = orjson.dumps(coco_data, option=orjson.OPT_INDENT_2)
    output_path.write_bytes(payload)

    logger.info(f"COCO annotations saved to: {output_path}")

# ============================================================
# 主流程
# ============================================================

def run_pipeline(config_path: Path):
    """Run the full pipeline: discover data folders, run inference, write COCO files.

    The configuration must provide the sections ``processing``, ``yolo``,
    ``pipeline`` (with ``input.root_directory``, ``input.search_depth``,
    ``output.directory_name`` and ``output.file_name``) and ``categories``.

    For every directory under the root that contains a ``raw_data`` folder
    with images, detections are computed and saved to
    ``<directory>/<output.directory_name>/<output.file_name>``. Directories
    with no images or no detections are skipped.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        KeyError: If a required configuration entry is missing.
    """
    config = load_config(config_path)
    processing_config = config["processing"]
    yolo_config = config["yolo"]
    pipeline_config = config["pipeline"]
    categories = config["categories"]

    root_dir = Path(pipeline_config["input"]["root_directory"])
    search_depth = pipeline_config["input"]["search_depth"]

    # Resolve the output settings before any heavy work so that a missing key
    # fails immediately instead of after inference on the first directory.
    output_dir_name = pipeline_config["output"]["directory_name"]
    output_file_name = pipeline_config["output"]["file_name"]

    print(f"Root directory: {root_dir}")
    print(f"Search depth: {search_depth}")

    # Locate every directory that contains a raw_data folder.
    parent_dirs = find_raw_data_directories(root_dir, search_depth)
    if not parent_dirs:
        print("No directories with 'raw_data' found.")
        return

    print("Found directories with 'raw_data':")
    for dir_path in parent_dirs:
        print(f"  - {dir_path}")

    # Load the model once and reuse it for every directory.
    model = YOLO(yolo_config["model_path"])

    for parent_dir in parent_dirs:
        raw_data_dir = parent_dir / "raw_data"

        if not collect_images(raw_data_dir):
            logger.info(f"Skipping {raw_data_dir}, no images found")
            continue

        logger.info(f"=== Processing: {parent_dir} ===")

        detections = process_directory(raw_data_dir, model, processing_config, yolo_config)
        if not detections:
            logger.info(f"No detections in {raw_data_dir}, skipping")
            continue

        coco_data = create_coco_dataset(raw_data_dir, detections, categories)
        save_coco_dataset(coco_data, parent_dir / output_dir_name / output_file_name)

# Example usage: run the pipeline when this file is executed as a script.
if __name__ == "__main__":
    config_path = Path("config.yaml")  # path to your configuration file
    run_pipeline(config_path)