# YOLOv8 牙齿检测演示

本notebook用于演示单张图片的牙齿检测效果，包括真实标签和预测结果的对比可视化。

In [None]:
# 导入必要的库
import os
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib import rcParams
from ultralytics import YOLO
import yaml
from pathlib import Path

# Make the parent directory importable from this notebook
sys.path.append('..')

# Font fallbacks so CJK text renders in figures, and keep the minus sign
# drawable with those fonts
rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

print("✅ 库导入成功！")

In [None]:
# Configuration
MODEL_PATH = "../outputs/dentalai/train_yolov8n_5ep_2025_08_14_07_52_45/weights/best.pt"  # trained weights
DATA_YAML = "../preprocessed_datasets/dentalai/data.yaml"  # dataset configuration file
IMAGE_PATH = "../preprocessed_datasets/dentalai/test/images/10_jpg.rf.34526d096400eda6a09986228293a587.jpg"  # sample image

# Report whether each configured file is present on disk.
# Nothing is raised here: the cell only prints a status line per file.
if os.path.exists(MODEL_PATH):
    print(f"✅ 模型文件存在: {MODEL_PATH}")
else:
    print(f"❌ 模型文件不存在: {MODEL_PATH}")
    print("请先训练模型或修改 MODEL_PATH")

print(
    f"✅ 数据配置文件存在: {DATA_YAML}"
    if os.path.exists(DATA_YAML)
    else f"❌ 数据配置文件不存在: {DATA_YAML}"
)

if os.path.exists(IMAGE_PATH):
    print(f"✅ 示例图片存在: {IMAGE_PATH}")
else:
    print(f"❌ 示例图片不存在: {IMAGE_PATH}")
    print("请修改 IMAGE_PATH 为有效的图片路径")

In [None]:
# Load the trained model and the dataset configuration
print("🔄 正在加载模型...")
model = YOLO(MODEL_PATH)
print("✅ 模型加载成功!")

# Read the class names from the dataset YAML
with open(DATA_YAML, 'r', encoding='utf-8') as f:
    # safe_load returns None for an empty document; fall back to an empty
    # mapping so the default class names below still apply instead of
    # failing with AttributeError on `.get`.
    data_config = yaml.safe_load(f) or {}

# NOTE(review): `names` is used as-is; later cells index it by integer class id.
class_names = data_config.get('names', ['Caries', 'Cavity', 'Crack', 'Tooth'])

print(f"📋 类别名称: {class_names}")

In [None]:
def _polygon_to_box(values):
    """Collapse a `class x1 y1 x2 y2 ...` row into `[class, xc, yc, w, h]`."""
    coords = values[1:]
    ys = coords[1::2]
    # Ignore an unpaired trailing x value so xs and ys stay aligned.
    xs = coords[0::2][:len(ys)]
    x_lo, x_hi = min(xs), max(xs)
    y_lo, y_hi = min(ys), max(ys)
    return [values[0], (x_lo + x_hi) / 2, (y_lo + y_hi) / 2, x_hi - x_lo, y_hi - y_lo]


def load_ground_truth_labels(label_path):
    """
    Load ground-truth labels from a YOLO-format text file.

    Every returned row has exactly five numbers, which is what the callers
    in this notebook unpack. Rows are normalized as follows:

    * fewer than 5 tokens: the line is skipped (this covers blank lines);
    * 5 tokens: returned unchanged as floats;
    * 6 tokens: the extra trailing column is dropped;
    * more than 6 tokens: treated as a polygon row (`class x1 y1 x2 y2 ...`)
      and replaced by the axis-aligned box enclosing its points.

    Args:
        label_path (str): path of the label file.

    Returns:
        list: one `[class_id, x_center, y_center, width, height]` list of
        floats per kept line; an empty list when the file does not exist.

    Raises:
        ValueError: if a line with at least 5 tokens contains a token that
            cannot be parsed as a float.
    """
    labels = []
    if not os.path.exists(label_path):
        return labels

    with open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue
            values = [float(token) for token in parts]
            if len(values) > 6:
                labels.append(_polygon_to_box(values))
            else:
                labels.append(values[:5])
    return labels

def denormalize_bbox(bbox, img_width, img_height):
    """
    Convert a normalized YOLO box into pixel corner coordinates.

    Args:
        bbox (list): [x_center, y_center, width, height], each expressed as a
            fraction of the image size.
        img_width (int): image width in pixels.
        img_height (int): image height in pixels.

    Returns:
        tuple: (x_min, y_min, x_max, y_max) in pixels, each truncated with int().
    """
    cx, cy, box_w, box_h = bbox

    # Scale the center and the half extents to pixels
    cx_px = cx * img_width
    cy_px = cy * img_height
    half_w = box_w * img_width / 2
    half_h = box_h * img_height / 2

    corners = (cx_px - half_w, cy_px - half_h, cx_px + half_w, cy_px + half_h)
    return tuple(int(value) for value in corners)

# Confirm that this cell ran and the helper functions are defined
print("✅ 辅助函数定义完成!")

## 修改图片路径

在下面的代码框中修改 `IMAGE_PATH` 变量来指定你想要检测的图片：

In [None]:
# Change this path to run detection on a different image
IMAGE_PATH = "../preprocessed_datasets/dentalai/test/images/10_jpg.rf.34526d096400eda6a09986228293a587.jpg"

# Alternatives, for example:
# IMAGE_PATH = "../preprocessed_datasets/dentalai/test/images/103_jpg.rf.8ce6d1e50ce178132fc20124134d4afd.jpg"
# IMAGE_PATH = "your_custom_image_path.jpg"  # use a custom image

print(f"📸 当前图片路径: {IMAGE_PATH}")
# Report whether the selected image is present on disk
print("✅ 图片文件存在" if os.path.exists(IMAGE_PATH) else "❌ 图片文件不存在，请修改路径")

In [None]:
# Read the image from disk
image = cv2.imread(IMAGE_PATH)
if image is not None:
    # Convert from BGR to RGB before handing the array to matplotlib
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_height, img_width = image.shape[:2]
    print(f"✅ 图像读取成功! 尺寸: {img_width}x{img_height}")

    # Show the unannotated image
    plt.figure(figsize=(12, 8))
    plt.imshow(image)
    plt.title(f'原始图像: {os.path.basename(IMAGE_PATH)}', fontsize=14, fontweight='bold')
    plt.axis('off')
    plt.show()
else:
    # Only a message is printed here; img_width / img_height stay undefined
    print(f"❌ 无法读取图像: {IMAGE_PATH}")

In [None]:
def _labels_dir_for(image_dir):
    """Return image_dir with its last `images` path component replaced by `labels`.

    Only a whole directory component named exactly `images` is swapped, so
    other parts of the path that merely contain the substring (for example
    `my_images`) are left untouched. If no component is named `images`, the
    directory is returned unchanged.
    """
    trailing = []  # components found below the `images` directory, innermost first
    head = image_dir
    while True:
        head, tail = os.path.split(head)
        if tail == 'images':
            return os.path.join(head, 'labels', *reversed(trailing))
        if not tail:
            # Reached the root (or an empty path) without finding `images`
            return image_dir
        trailing.append(tail)


# Locate the label file that belongs to the selected image:
# same base name with a .txt extension, inside the sibling `labels` directory
image_name = os.path.splitext(os.path.basename(IMAGE_PATH))[0]
label_dir = _labels_dir_for(os.path.dirname(IMAGE_PATH))
label_path = os.path.join(label_dir, f"{image_name}.txt")

print(f"📋 标签文件路径: {label_path}")

gt_labels = load_ground_truth_labels(label_path)
if gt_labels:
    print(f"✅ 找到 {len(gt_labels)} 个真实标签")
    for i, label in enumerate(gt_labels):
        class_id = int(label[0])
        # Fall back to a generic name when the id is beyond the known classes
        class_name = class_names[class_id] if class_id < len(class_names) else f'Class{class_id}'
        print(f"   标签 {i+1}: {class_name} (类别ID: {class_id})")
else:
    # Also reached when the label file exists but yields no rows
    print("⚠️ 未找到真实标签文件")

In [None]:
# Run the model on the selected image
print("🔄 正在进行预测...")
results = model.predict(IMAGE_PATH, verbose=False)
print("✅ 预测完成!")

# Flatten the first result into a list of plain dicts
predictions = []
if not results or len(results) == 0 or results[0].boxes is None:
    print("⚠️ 未检测到任何对象")
else:
    boxes = results[0].boxes
    for box in boxes:
        # Move each value to the CPU and convert it to numpy
        x_min, y_min, x_max, y_max = box.xyxy[0].cpu().numpy()
        confidence = box.conf[0].cpu().numpy()
        class_id = int(box.cls[0].cpu().numpy())

        # Fall back to a generic name when the id is beyond the known classes
        if class_id < len(class_names):
            class_name = class_names[class_id]
        else:
            class_name = f'Class{class_id}'

        predictions.append({
            'bbox': [x_min, y_min, x_max, y_max],
            'confidence': confidence,
            'class_id': class_id,
            'class_name': class_name,
        })

    print(f"📊 预测结果: {len(predictions)} 个检测框")
    for number, pred in enumerate(predictions, start=1):
        print(f"   预测 {number}: {pred['class_name']} (置信度: {pred['confidence']:.3f})")

In [None]:
# Side-by-side comparison: original image, ground truth, predictions
fig, axes = plt.subplots(1, 3, figsize=(20, 8))

# Every panel shows the same image; only the title and the overlays differ
for panel, panel_title in zip(axes, ('原始图像', '真实标签 (绿色)', '预测结果 (红色)')):
    panel.imshow(image)
    panel.set_title(panel_title, fontsize=14, fontweight='bold')
    panel.axis('off')

gt_axis, pred_axis = axes[1], axes[2]

# Ground-truth boxes (green) on the middle panel
for label in gt_labels:
    class_id, x_center, y_center, width, height = label
    x_min, y_min, x_max, y_max = denormalize_bbox(
        [x_center, y_center, width, height], img_width, img_height
    )

    gt_axis.add_patch(patches.Rectangle(
        (x_min, y_min), x_max - x_min, y_max - y_min,
        linewidth=3, edgecolor='green', facecolor='none'
    ))

    # Caption just above the box; unknown ids get a generic name
    cls_idx = int(class_id)
    if cls_idx < len(class_names):
        class_name = class_names[cls_idx]
    else:
        class_name = f'Class{cls_idx}'
    gt_axis.text(
        x_min, y_min - 10, f'GT: {class_name}',
        fontsize=12, color='green', fontweight='bold',
        bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8),
    )

# Predicted boxes (red) on the right panel
for pred in predictions:
    # Draw only predictions whose confidence is above 0.1
    if not (pred['confidence'] > 0.1):
        continue

    x_min, y_min, x_max, y_max = pred['bbox']
    pred_axis.add_patch(patches.Rectangle(
        (x_min, y_min), x_max - x_min, y_max - y_min,
        linewidth=3, edgecolor='red', facecolor='none'
    ))

    # Caption with class name and confidence just above the box
    pred_axis.text(
        x_min, y_min - 10, f"Pred: {pred['class_name']} ({pred['confidence']:.2f})",
        fontsize=12, color='red', fontweight='bold',
        bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8),
    )

plt.tight_layout()
plt.show()

In [None]:
# Overlay: ground truth and predictions drawn on the same image
plt.figure(figsize=(15, 10))
plt.imshow(image)
plt.title('预测结果对比 (绿色: 真实标签, 红色: 预测结果)', fontsize=16, fontweight='bold')
overlay_axis = plt.gca()

# Ground-truth boxes (green), captioned above the box
for label in gt_labels:
    class_id, x_center, y_center, width, height = label
    x_min, y_min, x_max, y_max = denormalize_bbox(
        [x_center, y_center, width, height], img_width, img_height
    )

    overlay_axis.add_patch(patches.Rectangle(
        (x_min, y_min), x_max - x_min, y_max - y_min,
        linewidth=3, edgecolor='green', facecolor='none', alpha=0.8
    ))

    # Unknown ids get a generic name
    cls_idx = int(class_id)
    if cls_idx < len(class_names):
        class_name = class_names[cls_idx]
    else:
        class_name = f'Class{cls_idx}'
    overlay_axis.text(
        x_min, y_min - 10, f'GT: {class_name}',
        fontsize=12, color='green', fontweight='bold',
        bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8),
    )

# Predicted boxes (red), captioned below the box so the two captions
# do not sit on top of each other
for pred in predictions:
    # Draw only predictions whose confidence is above 0.1
    if not (pred['confidence'] > 0.1):
        continue

    x_min, y_min, x_max, y_max = pred['bbox']
    overlay_axis.add_patch(patches.Rectangle(
        (x_min, y_min), x_max - x_min, y_max - y_min,
        linewidth=3, edgecolor='red', facecolor='none', alpha=0.8
    ))

    overlay_axis.text(
        x_min, y_max + 25, f"Pred: {pred['class_name']} ({pred['confidence']:.2f})",
        fontsize=12, color='red', fontweight='bold',
        bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8),
    )

plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Summary counts
print("📊 检测统计:")
print(f"   真实标签数量: {len(gt_labels)}")
print(f"   预测结果数量: {len(predictions)}")

# Ground-truth details, printed in normalized coordinates as stored in the label rows
if gt_labels:
    print("\n🏷️  真实标签详情:")
    for number, label in enumerate(gt_labels, start=1):
        class_id = int(label[0])
        # Unknown ids get a generic name
        if class_id < len(class_names):
            class_name = class_names[class_id]
        else:
            class_name = f'Class{class_id}'
        x_center, y_center, width, height = label[1:]
        print(f"   {number}. {class_name} - 中心:({x_center:.3f}, {y_center:.3f}), 大小:({width:.3f}, {height:.3f})")

# Prediction details, printed as the stored corner coordinates
if predictions:
    print("\n🎯 预测结果详情:")
    for number, pred in enumerate(predictions, start=1):
        x_min, y_min, x_max, y_max = pred['bbox']
        print(f"   {number}. {pred['class_name']} - 置信度:{pred['confidence']:.3f}, 位置:({x_min:.1f}, {y_min:.1f}, {x_max:.1f}, {y_max:.1f})")

print("\n✅ 演示完成!")
print("\n💡 提示: 修改上面的 IMAGE_PATH 变量来检测不同的图片")