<a href="https://colab.research.google.com/github/Deng-Xian-Sheng/Real-technology/blob/main/Yolo_%E4%B8%BB%E5%8A%A8%E5%AD%A6%E4%B9%A0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Yolo 主动学习

In [None]:
# @title 依赖安装
%pip install -q git+https://github.com/sunsmarterjie/yolov12.git supervision flash-attn ipywidgets huggingface_hub datasets

In [None]:
# @title 设置 初始数据集、待标注图片 的路径
# Project name. It is passed as `project=` to training and is the root directory searched for the newest best.pt.
PROJECT_NAME = ""
# Root directory of the initial dataset; training reads f"{DATASET}/data.yaml" and labeled samples are written below it.
DATASET = ""
# Recommended before this step: compress the images, resize them to 640x640 and de-duplicate them by hash,
# because the code below does none of that. Those steps are left out because on large datasets they need
# multi-process parallelism (and even progress checkpointing) to be fast enough, which this notebook does not implement.
# A separate guide is available here: https://github.com/Deng-Xian-Sheng/Real-technology/blob/main/yolo大规模数据集预处理.md
UNLABELED_IMG = ""

In [None]:
# @title 选取清晰度前30%的图片（可选）
args_workers = 0 # Usually left unchanged; 0 (or any value <= 0) means use the CPU core count
# Folder that is scanned recursively for candidate images
src_folder = UNLABELED_IMG
# Folder that receives the selected images; the value below is a placeholder and must be replaced before running
dst_folder = "填写你的文件夹"

import os
import argparse
import shutil
import torch
import piq
from skimage.io import imread
from multiprocessing import Pool, cpu_count
from tqdm import tqdm

def compute_brisque(img_path: str) -> dict:
    """
    Read an image and compute its BRISQUE score.

    A lower BRISQUE value means better perceived quality (a sharper image).

    Args:
        img_path: Path of the image file to score.

    Returns:
        A dict with the keys ``"path"`` (the input path) and ``"brisque"``
        (the score as a float). If the image cannot be read or scored, the
        score is ``float('inf')`` so the image sorts last.
    """
    try:
        # skimage.io.imread returns [H, W, C] for colour images and [H, W] for grayscale ones
        img_tensor = torch.tensor(imread(img_path))

        if img_tensor.ndim == 2:
            # Grayscale image: add the missing channel axis so permute() below works
            img_tensor = img_tensor.unsqueeze(-1)
        elif img_tensor.ndim == 3 and img_tensor.shape[-1] in (2, 4):
            # Gray+alpha / RGBA image: drop the alpha channel, it carries no sharpness information
            img_tensor = img_tensor[..., :-1]

        # Convert to [N, C, H, W] with N=1 (batch size 1) and scale to [0, 1]
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).float() / 255.0

        # Move to the GPU when one is available
        # NOTE(review): this runs inside multiprocessing workers; CUDA inside forked workers may fail — confirm.
        if torch.cuda.is_available():
            img_tensor = img_tensor.cuda()

        # reduction='none' yields one value per batch item; the batch holds a single image
        brisque_score = piq.brisque(img_tensor, data_range=1., reduction='none')
        score_value = brisque_score.item()

        return {
            "path": img_path,
            "brisque": score_value
        }
    except Exception:
        # Deliberate best-effort: a damaged or unsupported image gets the worst
        # possible score instead of aborting the whole run.
        return {
            "path": img_path,
            "brisque": float('inf')
        }

def get_image_files_recursive(folder: str):
    """
    Walk ``folder`` recursively and collect every .jpg / .jpeg / .png file.

    The extension check is case-insensitive. Each returned path is the
    directory reported by ``os.walk`` joined with the file name, so the
    paths are absolute only when ``folder`` itself is absolute.
    """
    allowed_suffixes = ('.jpg', '.jpeg', '.png')
    return [
        os.path.join(current_dir, name)
        for current_dir, _subdirs, names in os.walk(folder)
        for name in names
        if os.path.splitext(name)[1].lower() in allowed_suffixes
    ]

def copy_top_images(image_info_list, target_folder):
    """
    Copy the best-scoring 30% of the images into ``target_folder``.

    The entries are sorted by ascending BRISQUE score (lower is better) and
    the first 30% (at least one entry) are selected. Entries whose score is
    ``float('inf')`` (images that could not be read or scored) are never
    copied, even when they fall inside the selected range.

    Args:
        image_info_list: Dicts with the keys ``"path"`` and ``"brisque"``.
        target_folder: Destination directory; created when missing.
    """
    sorted_list = sorted(image_info_list, key=lambda x: x["brisque"])
    top_count = max(1, int(len(sorted_list) * 0.3))  # first 30%, at least one
    top_images = [
        item for item in sorted_list[:top_count]
        if item["brisque"] != float('inf')
    ]

    os.makedirs(target_folder, exist_ok=True)

    # Images are collected recursively, so two sources may share a base name.
    # Track the names used in this run so one copy never overwrites another.
    used_names = set()
    for item in tqdm(top_images, desc="复制图片"):
        src_path = item["path"]
        filename = os.path.basename(src_path)
        stem, ext = os.path.splitext(filename)
        candidate = filename
        suffix = 1
        while candidate in used_names:
            candidate = f"{stem}_{suffix}{ext}"
            suffix += 1
        used_names.add(candidate)
        shutil.copy2(src_path, os.path.join(target_folder, candidate))

def main():
    """
    Score every image under ``src_folder`` and copy the sharpest 30% to ``dst_folder``.

    The BRISQUE scores are computed in a process pool whose size is
    ``args_workers``, or the CPU core count when that value is not positive.
    """
    # Gather all candidate image paths
    candidates = get_image_files_recursive(src_folder)
    if len(candidates) == 0:
        print("在指定的源文件夹中未找到任何 jpg/jpeg/png 图片。")
        return

    # Fall back to the CPU core count unless a positive worker count was configured
    pool_size = args_workers if args_workers > 0 else cpu_count()

    print(f"共找到 {len(candidates)} 张图片，开始计算 BRISQUE 分数...")

    # Score the images in parallel; imap keeps the results in input order
    with Pool(processes=pool_size) as pool:
        scored = pool.imap(compute_brisque, candidates)
        progress = tqdm(scored, total=len(candidates), desc="计算BRISQUE")
        image_info_list = list(progress)

    print("BRISQUE 计算完毕，开始筛选并复制前 30% 的图片...")

    # Copy the best-scoring (sharpest) images to the destination folder
    copy_top_images(image_info_list, dst_folder)

    print(f"处理完成！前 30% 的清晰图片已复制到: {dst_folder}")

# Run the sharpness filter as soon as this cell executes
main()

# From here on the filtered folder is the pool of images to annotate.
# NOTE: this reassignment also happens when main() returned early because no images were found.
UNLABELED_IMG = dst_folder

In [None]:
# @title 导入库
import os
from pathlib import Path
from typing import List, Dict, Any
from ultralytics.engine.results import Results
from ultralytics import YOLO
import ipywidgets as widgets
from IPython.display import display, Javascript, clear_output
import uuid
from PIL import Image, ImageDraw, ImageFont
import io
import shutil
from google.colab import output
output.enable_custom_widget_manager()
from huggingface_hub import HfApi, login
import re
import hashlib
import colorsys

In [None]:
# @title 工具函数

def find_best_model_path(root_dir=None):
    """
    Locate the ``best.pt`` weights of the most recent training run.

    Run directories are the sub-directories of ``root_dir`` whose name starts
    with "train". They are ranked by the number following "train"; a name
    without a number ranks below every numbered one.

    Args:
        root_dir: Root directory that holds the training results. Defaults to
            the current value of ``PROJECT_NAME``, read when the function is
            called, so a later change to ``PROJECT_NAME`` is honoured.

    Returns:
        The path of ``<root_dir>/<latest run>/weights/best.pt`` as a string.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist, contains no run
            directory, or the latest run has no ``weights/best.pt``.
    """
    if root_dir is None:
        root_dir = f"{PROJECT_NAME}"

    # The results directory must exist
    if not os.path.exists(root_dir):
        raise FileNotFoundError(f"目录 {root_dir} 不存在")

    # Only directories count as runs; a stray file such as "train99.txt" must not win
    train_dirs = [
        f for f in os.listdir(root_dir)
        if f.startswith("train") and os.path.isdir(os.path.join(root_dir, f))
    ]
    if not train_dirs:
        raise FileNotFoundError(f"在 {root_dir} 中未找到以 'train' 开头的目录")

    def extract_number(dirname):
        match = re.search(r'train(\d+)', dirname)
        # A plain "train" gets -1 so it ranks before every numbered run
        return int(match.group(1)) if match else -1

    # Rank by run number and take the highest
    latest_train_dir = sorted(train_dirs, key=extract_number)[-1]
    latest_model_path = Path(root_dir) / latest_train_dir / "weights" / "best.pt"

    # The weights file must exist
    if not latest_model_path.exists():
        raise FileNotFoundError(f"模型文件 {latest_model_path} 不存在")

    return str(latest_model_path)

def annotate_single_image(
    model: "YOLO",
    image: Any,
    conf=0.25,
    iou=0.45,
) -> List[List[Dict[str, Any]]]:
    """
    Run a trained YOLO model on one or more images and return plain-Python detections.

    Args:
        model: The trained model; its ``predict`` method and ``names`` mapping are used.
        image: The prediction source, forwarded unchanged to ``model.predict``
            (for example an image path or a list of image paths).
        conf: Minimum confidence threshold; detections below it are ignored.
            Raising it helps reduce false positives.
        iou: IoU threshold for non-maximum suppression. Lower values remove
            more overlapping boxes and help reduce duplicates.

    Returns:
        One list per prediction result, in the order ``model.predict`` returned
        them. Each inner list holds one dict per detected box:
        - bbox: box corners ``[xmin, ymin, xmax, ymax]`` as a list of floats
        - confidence: confidence as a float
        - object_class_id: class id as an int
        - object_class_name: class name looked up in ``model.names``
    """
    batch_results = model.predict(
        image,
        verbose=False,
        conf=conf,
        iou=iou
    )

    batch_detections = []
    for results in batch_results:
        detections = []
        boxes = results.boxes
        if boxes is not None:
            # Convert once to plain Python values so callers get hashable ints and
            # ordinary floats instead of (possibly GPU-resident) tensors.
            corners = boxes.xyxy.tolist()
            confidences = boxes.conf.tolist()
            class_ids = boxes.cls.tolist()
            for bbox, confidence, class_id in zip(corners, confidences, class_ids):
                class_id = int(class_id)
                detections.append({
                    "bbox": [float(value) for value in bbox],
                    "confidence": float(confidence),
                    "object_class_id": class_id,
                    "object_class_name": model.names[class_id]
                })
        batch_detections.append(detections)

    return batch_detections

def display_interface(
    image: Image.Image,
    title: str = "xxx",
    button_labels: List[str] = ["Accept", "Correct", "Skip"]
) -> str:
    """
    Show an image with a title and ask the user to pick one of the labels by number.

    Args:
        image: The PIL image to display.
        title: Heading shown above the image.
        button_labels: The selectable labels; the user types the 1-based
            position of the wanted label. The default list is only read,
            never modified.

    Returns:
        The label the user selected.

    Raises:
        RuntimeError: If the input is empty, not a number, or not between 1
            and the number of labels.
    """
    # Serialise the image to bytes; modes with alpha or a palette go to PNG, everything else to JPEG
    with io.BytesIO() as buffer:
        if image.mode in ('RGBA', 'LA', 'P'):
            image.save(buffer, format='PNG')
            img_format = 'png'
        else:
            image.save(buffer, format='JPEG', quality=90)
            img_format = 'jpeg'
        image_data = buffer.getvalue()  # must be read before the buffer is closed

    # Build and show the widgets
    title_widget = widgets.HTML(value=f"<h2>{title}</h2>")
    image_widget = widgets.Image(value=image_data, format=img_format)
    vbox = widgets.VBox([title_widget, image_widget])

    display(vbox)

    options_text = ','.join([f'{i}:{label}' for i, label in enumerate(button_labels, 1)])
    result = input(f"请输入({options_text}):")

    # Anything that is not a valid 1-based index is rejected with the same RuntimeError:
    # non-numeric text must not leak a ValueError, and 0 / negative numbers must not
    # silently index from the end of the list.
    try:
        choice = int(result)
    except ValueError:
        choice = 0
    if not 1 <= choice <= len(button_labels):
        raise RuntimeError(f"期望为'{options_text}'中的一个，得到'{result}'")

    return button_labels[choice - 1]


def merge_datasets(new_dataset_path: str):
    """
    Merge a new dataset into the existing ``DATASET`` directory.

    Every file under ``<new>/<subset>/images`` and ``<new>/<subset>/labels``
    (subset being train, valid or test) is copied into the matching directory
    of ``DATASET``; a file with the same name is overwritten. The existing
    ``data.yaml`` is not modified. When the ``names`` entries of the two
    ``data.yaml`` files differ, a warning is printed but the merge continues,
    because the class list may be extended on purpose.

    Args:
        new_dataset_path: Path of the new dataset. It must contain a
            ``data.yaml`` and may contain train / valid / test sub-directories,
            each with ``images`` and ``labels`` sub-directories.

    Raises:
        FileNotFoundError: If the new dataset, the existing dataset, or one of
            the two ``data.yaml`` files does not exist.
    """
    import shutil
    import yaml
    from pathlib import Path

    new_dataset = Path(new_dataset_path)
    existing_dataset = Path(DATASET)

    # Validate the paths
    if not new_dataset.exists():
        raise FileNotFoundError(f"新数据集路径不存在: {new_dataset}")
    if not existing_dataset.exists():
        raise FileNotFoundError(f"现有数据集路径不存在: {existing_dataset}")

    # The new dataset must bring its own data.yaml
    new_data_yaml = new_dataset / "data.yaml"
    if not new_data_yaml.exists():
        raise FileNotFoundError(f"新数据集缺少data.yaml文件: {new_data_yaml}")

    # Load both configurations; read as UTF-8 so non-ASCII class names survive on any platform
    with open(existing_dataset / "data.yaml", "r", encoding="utf-8") as f:
        existing_data = yaml.safe_load(f) or {}
    with open(new_data_yaml, "r", encoding="utf-8") as f:
        new_data = yaml.safe_load(f) or {}

    # Class lists may legitimately differ (the dataset may be extended), so only warn:
    # label files refer to classes by id, and differing lists can make ids mean different things.
    existing_names = existing_data.get("names", [])
    new_names = new_data.get("names", [])
    if existing_names != new_names:
        print(f"⚠️ 类别不一致。现有数据集: {existing_names}，新数据集: {new_names}；请确认标签中的类别ID与现有 data.yaml 一致")

    # Walk every subset directory
    for subset in ["train", "valid", "test"]:
        subset_dir = new_dataset / subset

        if not subset_dir.exists():
            continue

        # Copy image files and label files
        for data_type in ["images", "labels"]:
            src_dir = subset_dir / data_type
            dst_dir = existing_dataset / subset / data_type

            if not src_dir.exists():
                continue

            dst_dir.mkdir(parents=True, exist_ok=True)

            # copy2 keeps the file metadata
            for file in src_dir.glob("*"):
                if file.is_file():
                    shutil.copy2(file, dst_dir / file.name)

    print(f"✅ 成功合并数据集到 {existing_dataset}")
    print(f"   合并来源: {new_dataset}")

def get_contrast_color(class_name):
    """
    Derive a deterministic, strongly saturated RGB colour from a class name.

    The MD5 digest of the UTF-8 encoded name selects a hue between 0 and 359
    degrees; saturation and value are fixed at 90%. The result is a tuple of
    three ints ``(r, g, b)``.
    """
    digest = hashlib.md5(class_name.encode('utf-8')).digest()
    hue_degrees = int.from_bytes(digest, 'big') % 360

    # HSV -> RGB with 90% saturation and 90% brightness
    rgb_fractions = colorsys.hsv_to_rgb(hue_degrees / 360, 0.9, 0.9)
    return tuple(int(channel * 255) for channel in rgb_fractions)

In [None]:
# @title 使用数据集训练模型

# Start from the 'yolov12x.pt' weights and train on the dataset described by DATASET/data.yaml.
model = YOLO('yolov12x.pt')
results = model.train(
    data=f"{DATASET}/data.yaml",
    epochs=100,
    batch=-1,  # NOTE(review): presumably enables automatic batch-size selection — confirm against the library docs
    imgsz=640,
    project=PROJECT_NAME,  # results go under this directory; find_best_model_path() searches it later
)

In [None]:
# @title 标注数据集

# Number of samples to collect per object class in this labeling round (default 10).
# When changing it, try to pick a value that splits cleanly into 70% / 20% / 10%.
LABEL_NUM = 10
# Directory that receives the original images once they have been accepted or rejected
LABELED_IMG = "./labeled_img/"
# Directory for images that need manual annotation
PEOPLE_LABEL = "./people_label/"
# Whether images answered with 'deny but exist other object' are moved to PEOPLE_LABEL for manual annotation
IS_PEOPLE_LABEL = False
# Minimum confidence threshold for detections. Detections below it are ignored; raising it helps reduce false positives.
CONF=0.25
# IoU threshold for non-maximum suppression (NMS). Lower values remove more overlapping boxes and help reduce duplicates.
IOU=0.45
# Batch size
BATCH = 32
# Predict all images at once and sort by confidence, highest first, before the manual check.
# (The manual check is required; True improves the active-learning procedure at the cost of more compute.)
# NOTE(review): the original comment also promises that only the top 30% are reviewed; the code below
# only overrides BATCH and sorts — no 30% cut is visible.
IS_SORT = True

# Per-class sample bookkeeping for this labeling round.
# It limits the number of samples per class and tracks the train / valid / test split (intended ratio 70% / 20% / 10%).
label_num_manage = []
# Data structure:
# [
#     {
#         "object_class_id":1,
#         "object_class_name":"xxx",
#         "train_num":1,
#         "valid_num":1,
#         "test_num":1,
#     }
# ]

# Create the directory for labeled images
if not os.path.exists(LABELED_IMG):
    os.makedirs(LABELED_IMG)

# Create the directory for images that need manual annotation
if not os.path.exists(PEOPLE_LABEL):
    os.makedirs(PEOPLE_LABEL)

# Walk the unlabeled-image folder recursively and collect every image.
# Files without a UUID prefix are renamed in place so that every image name is unique.
# UUID layout: 8-4-4-4-12 lowercase hex characters (36 characters) followed by an underscore
uuid_pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}_'
unlabeled_imgs = []
for root, dirs, files in os.walk(UNLABELED_IMG):
    for file in files:
        if not file.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue

        image_path = os.path.join(root, file)
        if re.match(uuid_pattern, file) is None:
            # No UUID prefix yet: add one and rename the file on disk
            prefixed_path = os.path.join(root, f"{uuid.uuid4()}_{file}")
            os.rename(image_path, prefixed_path)
            image_path = prefixed_path

        # Record the (possibly renamed) path
        unlabeled_imgs.append(image_path)

# Load the newest trained model so that all class names are known
model_path = find_best_model_path()
model = YOLO(model_path)
class_id_to_name = model.names.copy()
all_class_names = list(model.names.values())
all_classes_sorted = sorted(all_class_names)

# Start every class of the model with empty train / valid / test counters
for class_id, class_name in class_id_to_name.items():
    label_num_manage.append(dict(
        object_class_id=class_id,
        object_class_name=class_name,
        train_num=0,
        valid_num=0,
        test_num=0,
    ))

# Interactive labeling loop: predict the unlabeled images batch by batch, show each
# prediction to the user and file the image according to the answer.


def _find_label_entry(cls_id):
    """Return the label_num_manage entry for a class id, or None when there is none."""
    return next((item for item in label_num_manage if item['object_class_id'] == cls_id), None)


def _pick_subset(detections):
    """Return the first of train/valid/test with free quota for every detected class, else None."""
    quotas = {
        'train': int(LABEL_NUM * 0.7),
        'valid': int(LABEL_NUM * 0.2),
        # The test split receives whatever remains after the other two are rounded down
        'test': LABEL_NUM - int(LABEL_NUM * 0.7) - int(LABEL_NUM * 0.2),
    }
    for candidate_subset, quota in quotas.items():
        can_allocate = True
        for d in detections:
            entry = _find_label_entry(int(d['object_class_id']))
            if not entry:
                continue  # unmanaged class (should not happen in practice)
            if entry[f"{candidate_subset}_num"] >= quota:
                can_allocate = False
                break
        if can_allocate:
            return candidate_subset
    return None


# The font and the per-class colours do not depend on the image, so prepare them once
try:
    label_font = ImageFont.truetype("arial.ttf", 20)
except IOError:
    label_font = ImageFont.load_default()
color_cache = {}

if IS_SORT:
    # Predict everything in one pass so the sort covers all images.
    # max(1, ...) keeps the range() step valid when there are no images at all.
    BATCH = max(1, len(unlabeled_imgs))
for i in range(0, len(unlabeled_imgs), BATCH):
    batch_imgs = unlabeled_imgs[i:i + BATCH]
    # Predict with the trained model; one detection list is returned per image
    batch_detections = annotate_single_image(
        model,
        batch_imgs,
        conf=CONF,
        iou=IOU
    )

    # Pair every image with its own detections BEFORE sorting, so the two never get out of step
    image_detection_pairs = list(zip(batch_imgs, batch_detections))
    if IS_SORT:
        # Highest top-confidence first; images without detections sort last
        image_detection_pairs.sort(
            key=lambda pair: max((float(d["confidence"]) for d in pair[1]), default=-1.0),
            reverse=True,
        )

    # Review each image
    for img, detections in image_detection_pairs:
        if not detections:
            print(f"图片 {img} 没有检测到任何对象，跳过")
            continue

        # Skip the image when any detected class already has enough samples.
        # int(...) normalises ids so they compare by value even if they arrive as tensors.
        skip_image = False
        for cls_id in {int(d['object_class_id']) for d in detections}:
            entry = _find_label_entry(cls_id)
            if entry is None:
                entry = {
                    "object_class_id": cls_id,
                    "object_class_name": class_id_to_name[cls_id],
                    "train_num": 0,
                    "valid_num": 0,
                    "test_num": 0
                }
                label_num_manage.append(entry)
            total = entry['train_num'] + entry['valid_num'] + entry['test_num']
            if total >= LABEL_NUM:
                skip_image = True
                break
        if skip_image:
            continue

        # Title: "<all classes> | <classes detected in this image>"
        detected_classes_sorted = sorted({d['object_class_name'] for d in detections})
        title_all = '/'.join(all_classes_sorted)
        title_detected = '/'.join(detected_classes_sorted)
        title = f"{title_all} | {title_detected}"

        # Load the image. Drawing happens on an RGB copy because RGB colour tuples
        # cannot be drawn onto single-channel (grayscale) images.
        with Image.open(img) as source_image:
            img_width, img_height = source_image.size
            pil_image = source_image.convert("RGB")

        draw = ImageDraw.Draw(pil_image)
        for d in detections:
            xmin, ymin, xmax, ymax = (int(value) for value in d['bbox'])
            class_name = d['object_class_name']
            confidence = float(d['confidence'])

            # One colour per class name
            if class_name not in color_cache:
                color_cache[class_name] = get_contrast_color(class_name)
            color = color_cache[class_name]

            # Bounding box
            draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)

            # Label text and its size
            label = f"{class_name} {confidence:.2f}"
            text_bbox = draw.textbbox((xmin, ymin), label, font=label_font)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            # Filled background behind the label, placed just above the box
            draw.rectangle(
                [xmin, ymin - text_height, xmin + text_width, ymin],
                fill=color
            )

            # Label text; white on dark colours, black on bright ones
            draw.text(
                (xmin, ymin - text_height),
                label,
                fill='white' if sum(color) < 400 else 'black',
                font=label_font
            )

        user_choice = display_interface(
            pil_image,
            title=title,
            button_labels=['allow', 'deny', 'deny but exist other object']
        )

        if user_choice in ('allow', 'deny'):
            # Both answers add the image to the dataset; only 'allow' keeps the predicted boxes
            subset = _pick_subset(detections)
            if not subset:
                print(f"无法为图片 {img} 分配子集，跳过")
                continue

            if user_choice == 'allow':
                # Count every accepted detection against its class
                for d in detections:
                    entry = _find_label_entry(int(d['object_class_id']))
                    if entry is not None:
                        entry[f"{subset}_num"] += 1

            # Make sure the dataset directories exist
            img_name = os.path.basename(img)
            base_name = os.path.splitext(img_name)[0]
            images_dir = Path(DATASET) / subset / 'images'
            labels_dir = Path(DATASET) / subset / 'labels'
            images_dir.mkdir(parents=True, exist_ok=True)
            labels_dir.mkdir(parents=True, exist_ok=True)

            # Copy the original image into the dataset
            shutil.copy2(img, images_dir / img_name)

            label_path = labels_dir / f"{base_name}.txt"
            if user_choice == 'allow':
                # YOLO label format: class id, then box centre and size normalised by the image size
                with open(label_path, 'w') as f:
                    for d in detections:
                        xmin, ymin, xmax, ymax = (float(value) for value in d['bbox'])
                        x_center = (xmin + xmax) / 2 / img_width
                        y_center = (ymin + ymax) / 2 / img_height
                        width = (xmax - xmin) / img_width
                        height = (ymax - ymin) / img_height
                        f.write(f"{int(d['object_class_id'])} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
            else:
                # Rejected prediction: the image is stored with an empty label file
                label_path.touch()

            # Move the original image to the labeled directory
            shutil.move(img, Path(LABELED_IMG) / img_name)

        elif user_choice == 'deny but exist other object':
            if IS_PEOPLE_LABEL:
                # Set the image aside for manual annotation
                img_name = os.path.basename(img)
                shutil.move(img, Path(PEOPLE_LABEL) / img_name)

    # Stop once every class has reached LABEL_NUM samples
    all_met = all(
        entry['train_num'] + entry['valid_num'] + entry['test_num'] >= LABEL_NUM
        for entry in label_num_manage
    )
    if all_met:
        print("所有对象的数据集数量已满足要求，停止标注。")
        break

In [None]:
# @title 合并人工标注的数据集 -> 主动学习数据集
# Merge a manually annotated dataset into DATASET.
# The string below is a placeholder: replace it with the real path of the new dataset before running.
merge_datasets("this is new dataset path")

In [None]:
# @title 推送到Hugging Face

# ---- Configuration ----
# NOTE(review): the token is a placeholder; avoid saving or sharing the notebook with a real token in it.
HF_TOKEN = "your_token_here"
DATASET_PATH = DATASET
# Weights of the most recent training run
MODEL_PATH = find_best_model_path()
DATASET_REPO = "your_username/yolo_dataset"
MODEL_REPO = "your_username/yolo_model"
# ------------------

# Log in
login(token=HF_TOKEN)
api = HfApi()

# Upload the dataset folder (the repository is created when it does not exist yet)
api.create_repo(repo_id=DATASET_REPO, repo_type="dataset", exist_ok=True)
api.upload_folder(
    folder_path=DATASET_PATH,
    repo_id=DATASET_REPO,
    repo_type="dataset",
    path_in_repo=".",
)

# Upload the model weights as best.pt (the repository is created when it does not exist yet)
api.create_repo(repo_id=MODEL_REPO, exist_ok=True)
api.upload_file(
    path_or_fileobj=MODEL_PATH,
    repo_id=MODEL_REPO,
    path_in_repo="best.pt",
)

print("Upload completed!")