## Notebook for augmenting the images used by the first-stage YOLO model of the two-stage model pipeline

In [6]:
# Check that every image has a label file, and vice versa
import os

image_dir = "TwoStageYOLODataset/original/images"
label_dir = "TwoStageYOLODataset/original/labels"

# Map each file's stem (name without extension) to its actual file name.
# Matching is done on stems, but keeping the real name lets mismatches be
# reported with the correct extension (.jpg or .png) instead of assuming .jpg.
image_files = {
    os.path.splitext(f)[0]: f
    for f in os.listdir(image_dir)
    if f.lower().endswith((".jpg", ".png"))
}
label_files = {
    os.path.splitext(f)[0]: f
    for f in os.listdir(label_dir)
    if f.lower().endswith(".txt")
}

# Set differences over the stems (dict key views support set operations).
images_without_labels = image_files.keys() - label_files.keys()
labels_without_images = label_files.keys() - image_files.keys()

# Report totals, then list every unmatched file by its real file name.
print(f"🔍 图像总数: {len(image_files)}")
print(f"📝 标签总数: {len(label_files)}\n")

print(f"📸 有图像但没有标签的数量: {len(images_without_labels)}")
for name in sorted(images_without_labels):
    print(f" - {image_files[name]}")

print(f"\n🗂️ 有标签但没有图像的数量: {len(labels_without_images)}")
for name in sorted(labels_without_images):
    print(f" - {label_files[name]}")


🔍 图像总数: 138
📝 标签总数: 138

📸 有图像但没有标签的数量: 0

🗂️ 有标签但没有图像的数量: 0


In [7]:
import os
import itertools
from tqdm.notebook import tqdm
import augment_utils as aug
from concurrent.futures import ThreadPoolExecutor

aug.INPUT_DIR = 'TwoStageYOLODataset/original'
aug.OUTPUT_DIR = 'TwoStageYOLODataset/augmented'

# Augmentations from augment_utils. In this API each function is called with
# only an image stem, so it must load the image itself; it cannot receive the
# output of a previous augmentation.
AUG_FUNCTIONS = [
    aug.add_gaussian_noise,
    aug.adjust_random_brightness,
    aug.add_black_rect,
    aug.horizontal_flip,
    aug.random_rotate,
    aug.random_scale_with_padding,
]

# All 5-of-6 combinations (6 in total), computed once rather than per image.
# NOTE: because each function only gets the stem, a "combo" is NOT applied as
# a chain. Every call yields an independent single-augmentation result, and
# since each function appears in 5 of the 6 combos, it runs 5 times per image.
# The augment_utils_chain cell below performs true chained augmentation.
AUG_COMBOS = list(itertools.combinations(AUG_FUNCTIONS, 5))

# Stems (names without extension) of all .jpg images. The extension check is
# case-insensitive, matching the other cells in this notebook.
input_images = [
    os.path.splitext(f)[0]
    for f in os.listdir(os.path.join(aug.INPUT_DIR, "images"))
    if f.lower().endswith(".jpg")
]


def process_one_image(filename):
    """Save the original image and its augmented variants for one image stem.

    Args:
        filename: Image stem (file name without extension) under
            ``aug.INPUT_DIR/images``.

    Errors are printed instead of raised, so one bad image does not stop the
    thread pool from processing the remaining images.
    """
    try:
        # Save the unmodified original alongside the augmented copies.
        image, boxes, class_labels, _, _ = aug.load_image_and_boxes(filename)
        aug.save_augmented(image, boxes, class_labels, filename)

        for combo in AUG_COMBOS:
            for func in combo:
                # Each function is expected to save its own output under a
                # unique hash+timestamp name (author's note; not verifiable
                # from this notebook - confirm in augment_utils).
                func(filename)
    except Exception as e:
        print(f"[ERROR] {filename} failed: {e}")


# Process images concurrently with a thread pool.
with ThreadPoolExecutor(max_workers=20) as executor:
    list(tqdm(executor.map(process_one_image, input_images), total=len(input_images)))

  0%|          | 0/138 [00:00<?, ?it/s]

In [None]:
import os, itertools, copy
from concurrent.futures import ThreadPoolExecutor
from tqdm.notebook import tqdm
import augment_utils_chain as aug

# Data directories
aug.INPUT_DIR  = "TwoStageYOLODataset/original"
aug.OUTPUT_DIR = "TwoStageYOLODataset/augmented"

# Six chainable augmentations: each is called as
# fn(image, boxes, class_labels) and returns the transformed triple.
AUG_FUNCS = [
    aug.add_gaussian_noise,
    aug.adjust_random_brightness,
    aug.add_black_rect,
    aug.horizontal_flip,
    aug.random_rotate,
    aug.random_scale_with_padding,
]

# Don't apply these two operations together, as the combination might ruin
# the picture.
bad_pair = {aug.add_gaussian_noise, aug.random_scale_with_padding}

# All 4-of-6 combinations that do not contain the bad pair. C(6,4) = 15, of
# which C(4,2) = 6 contain both bad functions, leaving 9 chains. This is
# computed once instead of for every image.
VALID_COMBOS = [
    combo
    for combo in itertools.combinations(AUG_FUNCS, 4)
    if not bad_pair.issubset(combo)
]

# Image stems (file names without extension); the .jpg check is
# case-insensitive.
imgs = [
    os.path.splitext(f)[0]
    for f in os.listdir(os.path.join(aug.INPUT_DIR, "images"))
    if f.lower().endswith(".jpg")
]


def process_one(name):
    """Save the original image plus one chained augmentation per valid combo.

    Each image yields 1 + len(VALID_COMBOS) = 10 saved results.

    Args:
        name: Image stem (file name without extension) under
            ``aug.INPUT_DIR/images``.

    Errors are printed instead of raised, so one bad image does not stop the
    thread pool from processing the remaining images.
    """
    try:
        # 1. Save the unmodified original.
        img, boxes, cls, _, _ = aug.load_image_and_boxes(name)
        aug.save_augmented(img, boxes, cls, name)

        # 2. Apply each valid 4-function chain to a fresh copy of the original.
        for combo in VALID_COMBOS:
            # Private copies, so in-place edits made by one chain cannot leak
            # into the next. The class labels are copied too, because they are
            # parallel to the boxes and a transform might edit them in place.
            img_aug = copy.deepcopy(img)
            boxes_aug = copy.deepcopy(boxes)
            cls_aug = copy.deepcopy(cls)
            for fn in combo:
                img_aug, boxes_aug, cls_aug = fn(img_aug, boxes_aug, cls_aug)
            # NOTE(review): this is called with the same `name` for every
            # chain; assumes save_augmented makes each output file name unique
            # (e.g. hash/timestamp) - confirm, otherwise outputs overwrite.
            aug.save_augmented(img_aug, boxes_aug, cls_aug, name)
    except Exception as e:
        print(f"[ERROR] {name}: {e}")


# Process images concurrently with 20 threads.
with ThreadPoolExecutor(max_workers=20) as ex:
    list(tqdm(ex.map(process_one, imgs), total=len(imgs)))

  0%|          | 0/138 [00:00<?, ?it/s]