In [None]:
'''
划分数据集
'''
import os
import shutil
import random

# Randomly split the .jpg images in `images_path` (and their same-named YOLO
# .txt labels in `labels_path`) into train / val / test subsets and copy them
# under `output_path`.

# Source and output locations
images_path = "data/images"  # folder containing the source images
labels_path = "data/labels"      # folder containing the source label files
output_path = "data"            # root folder the split is written under

# Split ratios; the test subset receives whatever is left after train and val
train_ratio = 0.8
val_ratio = 0.1

# Create the target folders
for folder in ["images/train", "images/val", "images/test",
               "labels/train", "labels/val", "labels/test"]:
    os.makedirs(os.path.join(output_path, folder), exist_ok=True)

# Collect every .jpg file name (sub-folders and other extensions are ignored)
# and shuffle in place so the split is random.
images = [f for f in os.listdir(images_path) if f.endswith(".jpg")]
random.shuffle(images)

# Slice boundaries: [0, train_cutoff) -> train, [train_cutoff, val_cutoff) -> val,
# [val_cutoff, end) -> test. int() truncates, so rounding leftovers go to test.
train_cutoff = int(len(images) * train_ratio)
val_cutoff = train_cutoff + int(len(images) * val_ratio)

train_images = images[:train_cutoff]
val_images = images[train_cutoff:val_cutoff]
test_images = images[val_cutoff:]

# Copy each image and its label file into the folder of its subset
for dataset, dataset_images in zip(["train", "val", "test"], [train_images, val_images, test_images]):
    # The copies go to dividedJPEGImages/<subset> and dividedLabels/<subset>,
    # which are not among the folders created above. shutil.copy does not
    # create missing parent directories, so make them here, once per subset.
    # NOTE(review): the folders created above (images/*, labels/*) differ from
    # these copy destinations - confirm which layout is the intended one.
    dest_image_dir = os.path.join(output_path, f"dividedJPEGImages/{dataset}")
    dest_label_dir = os.path.join(output_path, f"dividedLabels/{dataset}")
    os.makedirs(dest_image_dir, exist_ok=True)
    os.makedirs(dest_label_dir, exist_ok=True)

    for image in dataset_images:
        base_name = os.path.splitext(image)[0]

        # Image source and destination
        src_image_path = os.path.join(images_path, image)
        dest_image_path = os.path.join(dest_image_dir, image)

        # Label source and destination (same base name, .txt extension)
        label_file = f"{base_name}.txt"
        src_label_path = os.path.join(labels_path, label_file)
        dest_label_path = os.path.join(dest_label_dir, label_file)

        # Copy the image if it is still present
        if os.path.exists(src_image_path):
            shutil.copy(src_image_path, dest_image_path)
        else:
            print(f"Warning: Image not found - {src_image_path}")

        # Copy the label if one exists; an image without a label is still
        # copied above and only produces a warning here.
        if os.path.exists(src_label_path):
            shutil.copy(src_label_path, dest_label_path)
        else:
            print(f"Warning: Label not found - {src_label_path}")


In [None]:
'''
检查yolo文件与图片是否一一对应
'''
import os

# Check that images and YOLO label files correspond one-to-one by base name
# (file name without extension), in both directions.
image_folder = "/Users/sowingg/CNN-food-detection/data/images/train"
# NOTE(review): labels are read from the same folder as the images - confirm
# this matches the on-disk layout.
label_folder = "/Users/sowingg/CNN-food-detection/data/images/train"

image_files = [os.path.splitext(f)[0] for f in os.listdir(image_folder) if f.endswith(".jpg")]
label_files = [os.path.splitext(f)[0] for f in os.listdir(label_folder) if f.endswith(".txt")]

# Build each set once and reuse it for both directions of the comparison.
image_names = set(image_files)
label_names = set(label_files)

# Images that have no label file with the same base name.
missing_labels = image_names - label_names
if missing_labels:
    print("Missing labels for images:", missing_labels)
else:
    print("All images have corresponding labels.")

# Reverse direction: label files that have no image with the same base name.
# Nothing is printed when there are none, so the output for a fully
# consistent folder is unchanged. Any other .txt file in the folder is
# reported here as well.
missing_images = label_names - image_names
if missing_images:
    print("Missing images for labels:", missing_images)

All images have corresponding labels.


In [None]:
'''
检查不符合规范的yolo文件
'''
import os

# Scan every .txt file in `labels_dir` and collect the names of files that
# contain at least one line violating the expected YOLO layout:
# "<class_id> <x> <y> <w> <h>", with the four coordinates in [0, 1].
labels_dir = "/Users/sowingg/CNN-food-detection/data/images/train"
invalid_labels = []


def _is_valid_yolo_line(line):
    """Return True if *line* has five fields and four coordinates in [0, 1].

    The first field (class id) is only counted, not validated. A blank line
    has zero fields and is therefore invalid. A coordinate that cannot be
    parsed as a float makes the line invalid instead of raising ValueError.
    """
    parts = line.split()
    if len(parts) != 5:  # each line must have exactly five values
        return False
    try:
        coords = [float(value) for value in parts[1:]]
    except ValueError:
        return False
    # NaN fails both comparisons, so it is rejected as well.
    return all(0 <= value <= 1 for value in coords)


for label_file in os.listdir(labels_dir):
    if not label_file.endswith(".txt"):
        continue
    path = os.path.join(labels_dir, label_file)
    with open(path, "r") as f:
        # all() stops at the first bad line, so each file is recorded at most
        # once no matter how many of its lines are invalid.
        if not all(_is_valid_yolo_line(line) for line in f):
            invalid_labels.append(label_file)

if invalid_labels:
    print("Invalid label files:", invalid_labels)
else:
    print("All label files are valid.")

Invalid label files: ['people(2086).txt']
