# 将原始数据划分为 train / val / test 三个子集

In [10]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import os
import random
import shutil
from collections import Counter
from pathlib import Path
from typing import List, Tuple

# ================= Configuration =================
# Source data: images and their same-stem .txt labels live in this directory
# or in any of its sub-directories.
SOURCE_DIR: Path = Path("cropped_objects")

# Image suffixes to collect, and the suffix of the matching label file.
IMG_EXTS = set([".jpg", ".jpeg", ".png"])
LABEL_EXT: str = ".txt"

# Miscellaneous options
SEED: int = 42
USE_SYMLINK: bool = False     # True = symlink files into the split; False = copy them
FORCE_CLEAN: bool = False     # True = wipe the destination directory if it already exists
ALLOW_NEGATIVE: bool = False  # True = keep samples whose .txt label is empty; False = skip them

# Optional: also write a minimal dataset YAML for later training
WRITE_YAML: bool = True
YAML_NAME: str = "swd_detection.yaml"
NAMES: list = ["insect"]  # class names; list position is the class id

# ================ 工具函数 ================
def _list_image_files(root: Path) -> List[Path]:
    return [p for p in root.rglob("*") if p.is_file() and p.suffix.lower() in IMG_EXTS]

def _has_label(img_path: Path) -> bool:
    """Tell whether ``img_path`` has a sibling label file (same stem, LABEL_EXT suffix)."""
    label_path = img_path.with_suffix(LABEL_EXT)
    return label_path.exists()

def _copy_or_link(src: Path, dst: Path, symlink: bool = False):
    dst.parent.mkdir(parents=True, exist_ok=True)
    if symlink:
        if dst.exists() or dst.is_symlink():
            dst.unlink()
        os.symlink(src.resolve(), dst)
    else:
        shutil.copy2(src, dst)

def _ensure_empty_dir(d: Path):
    """Make sure ``d`` exists and holds nothing before the split is written.

    - If ``d`` does not exist, it is created.
    - If it exists and FORCE_CLEAN is True, it is deleted and recreated.
    - If it exists, FORCE_CLEAN is False and it has any content, the run
      aborts with SystemExit so an earlier split is never overwritten.
    """
    if not d.exists():
        d.mkdir(parents=True, exist_ok=True)
        return

    if FORCE_CLEAN:
        shutil.rmtree(d)
        d.mkdir(parents=True, exist_ok=True)
        return

    # Any entry at all (file or sub-directory) counts as "non-empty".
    first_entry = next(d.rglob("*"), None)
    if first_entry is not None:
        raise SystemExit(f"[ABORT] 目标目录已存在且非空：{d}\n"
                         f" - 修改 DEST_DIR 或设置 FORCE_CLEAN=True 再运行。")

def _write_yaml(dest_root: Path):
    yaml_path = dest_root / YAML_NAME
    content = [
        f"path: {dest_root.resolve()}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "",
        "names:"
    ]
    for i, n in enumerate(NAMES):
        content.append(f"  {i}: {n}")
    yaml_path.write_text("\n".join(content), encoding="utf-8")
    print(f"[OK] 写出 YAML: {yaml_path}")

# ================ 主流程 ================
def _collect_pairs(root: Path) -> Tuple[List[Tuple[Path, Path]], int, int]:
    """Pair every image under ``root`` with its same-stem label file.

    Returns:
        ``(pairs, n_unlabeled, n_empty)``, where:
        - ``pairs`` are the accepted (image, label) tuples, in sorted order.
        - ``n_unlabeled`` counts images with no label file.
        - ``n_empty`` counts images skipped because their label file is empty
          and ALLOW_NEGATIVE is False.
    """
    pairs: List[Tuple[Path, Path]] = []
    n_unlabeled = 0
    n_empty = 0
    # Sort so the seeded shuffle in main() depends only on SEED, not on the
    # filesystem-dependent order in which rglob yields files.
    for img in sorted(_list_image_files(root)):
        lab = img.with_suffix(LABEL_EXT)
        if not lab.exists():
            n_unlabeled += 1
        elif lab.stat().st_size == 0 and not ALLOW_NEGATIVE:
            # An empty txt is treated as a negative sample; skipped by default.
            n_empty += 1
        else:
            pairs.append((img, lab))
    return pairs, n_unlabeled, n_empty


def _check_unique_stems(pairs: List[Tuple[Path, Path]]) -> None:
    """Abort if two pairs share a file stem.

    Files are copied flat into ``images/{split}`` and ``labels/{split}``.
    Same-stem files from different sub-directories would therefore overwrite
    each other, or map to one label file, in the output.
    """
    counts = Counter(img.stem for img, _ in pairs)
    dupes = sorted(stem for stem, c in counts.items() if c > 1)
    if dupes:
        preview = ", ".join(dupes[:10])
        raise SystemExit(f"[ABORT] 发现 {len(dupes)} 个重名样本（展平到目标目录后会互相覆盖）：{preview}")


def main(SPLITS: dict, DEST_DIR: Path):
    """Split the labelled images under SOURCE_DIR into YOLO train/val/test folders.

    Args:
        SPLITS: Ratios keyed by "train", "val" and "test", each in (0, 1).
            The train and val sizes are ``int(ratio * N)``; test takes the
            remaining samples, so the ratios need not sum exactly to 1.
        DEST_DIR: Output root. ``images/{split}`` and ``labels/{split}`` are
            created below it, plus the dataset YAML when WRITE_YAML is True.

    Raises:
        ValueError: A ratio is outside (0, 1), or train + val would leave a
            negative number of samples for test.
        SystemExit: No image/label pairs were found, two pairs share a file
            stem, or DEST_DIR is non-empty while FORCE_CLEAN is False.
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    for key in ("train", "val", "test"):
        if not 0 < SPLITS[key] < 1:
            raise ValueError(f"SPLITS 每项需在 (0,1) 内：{key}={SPLITS[key]}")
    total = sum(SPLITS.values())
    if abs(total - 1.0) > 1e-6:
        print(f"[WARN] 划分比例之和={total:.6f} ≠ 1.0，将以 test = N - train - val 兜底。")

    random.seed(SEED)

    # 1) Collect samples (image + same-stem txt).
    pairs, n_unlabeled, n_empty = _collect_pairs(SOURCE_DIR)
    if not pairs:
        raise SystemExit("[ABORT] 没有找到成对的 图片+同名.txt；请检查 SOURCE_DIR 或标签生成。")
    _check_unique_stems(pairs)
    print(f"[INFO] 收集到样本对：{len(pairs)}（跳过无标签图片：{n_unlabeled}，跳过空标签负样本：{n_empty}）")

    # 2) Shuffle and split; test takes whatever train and val leave over.
    random.shuffle(pairs)
    n = len(pairs)
    n_train = int(SPLITS["train"] * n)
    n_val = int(SPLITS["val"] * n)
    if n - n_train - n_val < 0:
        raise ValueError(f"train + val 比例之和超过 1，无法为 test 留出样本：{SPLITS}")

    splits = {
        "train": pairs[:n_train],
        "val": pairs[n_train:n_train + n_val],
        "test": pairs[n_train + n_val:],
    }

    print("[INFO] 划分结果：",
          f"train={len(splits['train'])}, val={len(splits['val'])}, test={len(splits['test'])}")

    # 3) Create the destination layout.
    _ensure_empty_dir(DEST_DIR)
    for split in splits:
        (DEST_DIR / "images" / split).mkdir(parents=True, exist_ok=True)
        (DEST_DIR / "labels" / split).mkdir(parents=True, exist_ok=True)

    # 4) Copy or link the files; the label is renamed to match the image stem.
    for split, items in splits.items():
        for img, lab in items:
            dst_img = DEST_DIR / "images" / split / img.name
            dst_lab = DEST_DIR / "labels" / split / (img.stem + LABEL_EXT)
            _copy_or_link(img, dst_img, symlink=USE_SYMLINK)
            _copy_or_link(lab, dst_lab, symlink=USE_SYMLINK)

    # 5) Report what actually landed on disk, then write the YAML.
    for split in splits:
        ni = len(list((DEST_DIR / "images" / split).glob("*")))
        nl = len(list((DEST_DIR / "labels" / split).glob("*")))
        print(f"[OK] {split:<5} images={ni}  labels={nl}")

    if WRITE_YAML:
        _write_yaml(DEST_DIR)

    print("✅ Done. 数据已整理为 YOLO Detection 标准结构。")
    print(f"   训练命令示例：\n"
          f"   yolo mode=checks data={DEST_DIR / YAML_NAME}\n"
          f"   yolo detect train model=yolo11s.pt data={DEST_DIR / YAML_NAME} "
          f"imgsz=640 epochs=150 batch=16")


In [11]:
if __name__ == "__main__":
    # Split ratios should sum to ~1.0; test always receives N - train - val.
    SPLITS = {"train": 0.7, "val": 0.2, "test": 0.1}
    # Output root (images/{split} and labels/{split} are created below it).
    DEST_DIR = Path(f"datasets/insect_split_{SPLITS['train']}_{SPLITS['val']}_{SPLITS['test']}")
    main(SPLITS, DEST_DIR)

    print("✅ Split done.")

[INFO] 收集到样本对：74（跳过负样本：0）
[INFO] 划分结果： train=51, val=14, test=9
[OK] train images=51  labels=51
[OK] val   images=14  labels=14
[OK] test  images=9  labels=9
[OK] 写出 YAML: datasets/insect_split_0.7_0.2_0.1/swd_detection.yaml
✅ Done. 数据已整理为 YOLO Detection 标准结构。
   训练命令示例：
   yolo mode=checks data=datasets/insect_split_0.7_0.2_0.1/swd_detection.yaml
   yolo detect train model=yolo11s.pt data=datasets/insect_split_0.7_0.2_0.1/swd_detection.yaml imgsz=640 epochs=150 batch=16
✅ Upscale done.


In [12]:
if __name__ == "__main__":
    # Split ratios should sum to ~1.0; test always receives N - train - val.
    SPLITS = {"train": 0.4, "val": 0.3, "test": 0.3}
    DEST_DIR = Path(f"datasets/insect_split_{SPLITS['train']}_{SPLITS['val']}_{SPLITS['test']}")
    main(SPLITS, DEST_DIR)

    print("✅ Split done.")

[INFO] 收集到样本对：74（跳过负样本：0）
[INFO] 划分结果： train=29, val=22, test=23
[OK] train images=29  labels=29
[OK] val   images=22  labels=22
[OK] test  images=23  labels=23
[OK] 写出 YAML: datasets/insect_split_0.4_0.3_0.3/swd_detection.yaml
✅ Done. 数据已整理为 YOLO Detection 标准结构。
   训练命令示例：
   yolo mode=checks data=datasets/insect_split_0.4_0.3_0.3/swd_detection.yaml
   yolo detect train model=yolo11s.pt data=datasets/insect_split_0.4_0.3_0.3/swd_detection.yaml imgsz=640 epochs=150 batch=16
✅ Upscale done.


In [13]:
if __name__ == "__main__":
    # Split ratios should sum to ~1.0; test always receives N - train - val.
    SPLITS = {"train": 0.5, "val": 0.3, "test": 0.2}
    DEST_DIR = Path(f"datasets/insect_split_{SPLITS['train']}_{SPLITS['val']}_{SPLITS['test']}")
    main(SPLITS, DEST_DIR)

    print("✅ Split done.")

[INFO] 收集到样本对：74（跳过负样本：0）
[INFO] 划分结果： train=37, val=22, test=15
[OK] train images=37  labels=37
[OK] val   images=22  labels=22
[OK] test  images=15  labels=15
[OK] 写出 YAML: datasets/insect_split_0.5_0.3_0.2/swd_detection.yaml
✅ Done. 数据已整理为 YOLO Detection 标准结构。
   训练命令示例：
   yolo mode=checks data=datasets/insect_split_0.5_0.3_0.2/swd_detection.yaml
   yolo detect train model=yolo11s.pt data=datasets/insect_split_0.5_0.3_0.2/swd_detection.yaml imgsz=640 epochs=150 batch=16
✅ Upscale done.


In [None]:
if __name__ == "__main__":
    # Split ratios should sum to ~1.0; test always receives N - train - val.
    SPLITS = {"train": 0.5, "val": 0.2, "test": 0.3}
    DEST_DIR = Path(f"datasets/insect_split_{SPLITS['train']}_{SPLITS['val']}_{SPLITS['test']}")
    main(SPLITS, DEST_DIR)

    print("✅ Split done.")

[INFO] 收集到样本对：849（跳过负样本：0）
[INFO] 划分结果： train=424, val=169, test=256
[OK] train images=424  labels=424
[OK] val   images=169  labels=169
[OK] test  images=256  labels=256
[OK] 写出 YAML: datasets/swd_pose_split_0.5_0.2_0.3/swd_detection.yaml
✅ Done. 数据已整理为 YOLO Detection 标准结构。
   训练命令示例：
   yolo mode=checks data=datasets/swd_pose_split_0.5_0.2_0.3/swd_detection.yaml
   yolo detect train model=yolo11s.pt data=datasets/swd_pose_split_0.5_0.2_0.3/swd_detection.yaml imgsz=640 epochs=150 batch=16
✅ Upscale done.


In [None]:
if __name__ == "__main__":
    # Split ratios should sum to ~1.0; test always receives N - train - val.
    SPLITS = {"train": 0.6, "val": 0.2, "test": 0.2}
    DEST_DIR = Path(f"datasets/insect_split_{SPLITS['train']}_{SPLITS['val']}_{SPLITS['test']}")
    main(SPLITS, DEST_DIR)

    print("✅ Split done.")

[INFO] 收集到样本对：849（跳过负样本：0）
[INFO] 划分结果： train=509, val=169, test=171
[OK] train images=509  labels=509
[OK] val   images=169  labels=169
[OK] test  images=171  labels=171
[OK] 写出 YAML: datasets/swd_pose_split_0.6_0.2_0.2/swd_detection.yaml
✅ Done. 数据已整理为 YOLO Detection 标准结构。
   训练命令示例：
   yolo mode=checks data=datasets/swd_pose_split_0.6_0.2_0.2/swd_detection.yaml
   yolo detect train model=yolo11s.pt data=datasets/swd_pose_split_0.6_0.2_0.2/swd_detection.yaml imgsz=640 epochs=150 batch=16
✅ Upscale done.
