In [1]:
# Mount Google Drive so the dataset archive stored under MyDrive can be read.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
import os
import json
import shutil
from pathlib import Path
from collections import defaultdict
import random
import zipfile

# ==============================
# 1. Paths
# ==============================
BASE_DIR = "/content/dataset"          # extraction target for the dataset zip
FLATTEN_DIR = "/content/flattened"     # every image/label merged into one folder
OUTPUT_DIR = "/content/split_dataset"  # final train/val split is written here

# Annotation class ids
CLASS_BELT = "01"    # seatbelt
CLASS_HELMET = "07"  # helmet

os.makedirs(FLATTEN_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Drive folder holding the dataset zip (kept with a trailing slash)
zip_path = "/content/drive/MyDrive/AI/"

# Archive extraction helper
def unzip_files(zip_path, target_dir):
    """Extract every member of the ZIP archive at *zip_path* into *target_dir*.

    Prints a one-line confirmation (source -> destination) once extraction finishes.
    """
    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=target_dir)
    done_msg = f"압축 해제 완료: {zip_path} -> {target_dir}"
    print(done_msg)

# Extract the dataset archive from Drive into BASE_DIR.
# os.path.join (rather than "+") keeps this working even if zip_path loses its
# trailing "/"; with the current value the resulting path is identical.
file_name = "completed.zip"
unzip_files(os.path.join(zip_path, file_name), BASE_DIR)

KeyboardInterrupt: 

In [None]:
import os
import json
import shutil
from pathlib import Path
import random

# ==============================
# 1. Paths
# ==============================
BASE_DIR = "/content/dataset"          # dataset root holding train/ and val/ splits
FLATTEN_DIR = "/content/flattened"     # every image/label merged into one folder
OUTPUT_DIR = "/content/split_dataset"  # final train/val split is written here

# Annotation class ids, compared against each annotation's "class" field
CLASS_BELT = "01"    # seatbelt
CLASS_HELMET = "07"  # helmet

os.makedirs(FLATTEN_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ==============================
# 2. Copy both train and val data into FLATTEN_DIR
# ==============================
def flatten_dataset(base_dir=None, flatten_dir=None, splits=("train", "val")):
    """Copy every labelled image and its JSON label from each split into one folder.

    For each split, images are read from ``<base_dir>/<split>/images`` and labels
    from ``<base_dir>/<split>/labels_json``. A label is matched to an image through
    the basename of its ``data["image"]["filename"]``, compared case-insensitively.
    Each matched pair is written to *flatten_dir* as ``<name>`` and ``<name>.json``,
    where ``<name>`` is the lower-cased image file name.

    Args:
        base_dir: Dataset root; defaults to the module-level ``BASE_DIR``.
        flatten_dir: Destination folder; defaults to the module-level ``FLATTEN_DIR``.
        splits: Sub-folders of *base_dir* to merge.

    Returns:
        Number of distinct image/label pairs written.
    """
    base = Path(BASE_DIR if base_dir is None else base_dir)
    dest = Path(FLATTEN_DIR if flatten_dir is None else flatten_dir)
    copied = set()  # lower-cased names already written, to spot cross-split collisions

    for split in splits:
        img_root = base / split / "images"
        label_root = base / split / "labels_json"

        # No sub-folders, so a flat listing is enough. The suffix is compared
        # case-insensitively because glob("*.jpg") is case-sensitive on Linux and
        # would silently skip files such as "IMG_001.JPG".
        image_map = {}
        if img_root.is_dir():
            for img_path in img_root.iterdir():
                if img_path.is_file() and img_path.suffix.lower() == ".jpg":
                    image_map[img_path.name.lower()] = img_path
        print(f"{split} 이미지 개수:", len(image_map))

        # Sorted so that overwrite order (and any warnings) is deterministic.
        for label_file in sorted(label_root.glob("*.json")):
            with open(label_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            pure_filename = Path(data["image"]["filename"]).name.lower()

            if pure_filename not in image_map:
                print(f"⚠️ 매칭 실패 (이미지 없음): {pure_filename}")
                continue
            if pure_filename in copied:
                # Same name in another split/label: the earlier copy is overwritten.
                print(f"⚠️ 중복 파일명 (덮어씀): {pure_filename}")

            shutil.copy2(image_map[pure_filename], dest / pure_filename)
            shutil.copy2(label_file, dest / f"{pure_filename}.json")
            copied.add(pure_filename)

    return len(copied)

# Merge train/val into FLATTEN_DIR before counting.
flatten_dataset()

# ==============================
# 3. Count seatbelt / helmet instances per image
# ==============================
image_stats = {}  # lower-cased image file name -> {"belt": n, "helmet": n}

for label_file in Path(FLATTEN_DIR).glob("*.json"):
    with open(label_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    anns = data.get("annotations", [])
    belt_count = sum(1 for ann in anns if ann.get("class") == CLASS_BELT)
    helmet_count = sum(1 for ann in anns if ann.get("class") == CLASS_HELMET)
    pure_filename = Path(data["image"]["filename"]).name.lower()
    image_stats[pure_filename] = {"belt": belt_count, "helmet": helmet_count}

# Build the groups as sorted lists instead of list(set(...)): the iteration order of
# a set of str depends on PYTHONHASHSEED, so unsorted lists fed into the seeded
# split would still produce a different split in every new Python process.
belt_set = {f for f, v in image_stats.items() if v["belt"] > 0}
helmet_set = {f for f, v in image_stats.items() if v["helmet"] > 0}
belt_images = sorted(belt_set)
helmet_images = sorted(helmet_set)
belt_only = sorted(belt_set - helmet_set)
helmet_only = sorted(helmet_set - belt_set)
both = sorted(belt_set & helmet_set)
# Images with neither class belong to no group, so they never reach the split.
neither_count = len(image_stats) - len(belt_set | helmet_set)

print("총 이미지 수:", len(image_stats))
print("안전벨트 이미지 수:", len(belt_images))
print("안전모 이미지 수:", len(helmet_images))
print("둘 다 포함된 이미지 수:", len(both))
print("둘 다 없는 이미지 수 (분할 제외):", neither_count)

# ==============================
# 5. Split function
# ==============================
def _trim_excess(train_images, stats, major, minor, diff):
    """Remove ``major``-only images from *train_images* (in place) to close an instance gap.

    Candidates are visited from most to fewest ``major`` instances, with ties broken by
    name so the result does not depend on set iteration order. Removal stops once at
    least *diff* ``major`` instances are gone; the last removal may overshoot.
    """
    removed = 0
    for f in sorted(train_images, key=lambda name: (-stats[name][major], name)):
        if removed >= diff:
            break
        if stats[f][major] > 0 and stats[f][minor] == 0:
            train_images.remove(f)
            removed += stats[f][major]


def split_dataset(belt_only, helmet_only, both, val_ratio=0.3, seed=42, stats=None):
    """Split image names into train/val, then trim train toward seatbelt/helmet balance.

    Val is filled first from *both*, then alternately from *belt_only* and
    *helmet_only*, until it holds ``int(total * val_ratio)`` images. All remaining
    images go to train. If train then has more instances of one class, images that
    contain only that class are dropped from train. They are NOT moved to val, so they
    end up in neither split.

    Args:
        belt_only, helmet_only, both: Lists of image names, expected to be disjoint.
            They are not modified.
        val_ratio: Fraction of all images to place in val.
        seed: Seed for a private ``random.Random``; the global ``random`` state is
            left untouched.
        stats: Mapping ``name -> {"belt": n, "helmet": n}``; defaults to the
            module-level ``image_stats``.

    Returns:
        ``(train_list, val_list)``, each sorted by file name.
    """
    if stats is None:
        stats = image_stats
    rng = random.Random(seed)

    all_images = set(belt_only) | set(helmet_only) | set(both)
    val_count = int(len(all_images) * val_ratio)

    # Shuffle sorted local copies. Sorting removes any dependence on the incoming
    # (possibly hash-ordered) list order, so a given seed always gives the same split.
    # Copying keeps the caller's lists intact.
    both_pool = sorted(both)
    belt_pool = sorted(belt_only)
    helmet_pool = sorted(helmet_only)
    rng.shuffle(both_pool)
    rng.shuffle(belt_pool)
    rng.shuffle(helmet_pool)

    val_images = set()
    while len(val_images) < val_count and both_pool:
        val_images.add(both_pool.pop())

    # Alternate between the single-class pools so val stays roughly balanced.
    toggle = True
    while len(val_images) < val_count and (belt_pool or helmet_pool):
        if toggle and belt_pool:
            val_images.add(belt_pool.pop())
        elif not toggle and helmet_pool:
            val_images.add(helmet_pool.pop())
        toggle = not toggle

    train_images = all_images - val_images
    train_belt = sum(stats[f]["belt"] for f in train_images)
    train_helmet = sum(stats[f]["helmet"] for f in train_images)

    if train_belt > train_helmet:
        _trim_excess(train_images, stats, "belt", "helmet", train_belt - train_helmet)
    elif train_helmet > train_belt:
        _trim_excess(train_images, stats, "helmet", "belt", train_helmet - train_belt)

    return sorted(train_images), sorted(val_images)

# Split the per-class image groups into train / val name lists
# (default val_ratio of split_dataset).
train_list, val_list = split_dataset(
    belt_only=belt_only,
    helmet_only=helmet_only,
    both=both,
)

print("최종 train 이미지 수:", len(train_list))
print("최종 val 이미지 수:", len(val_list))

# ==============================
# 6. Save results
# ==============================
def save_split(images, split_name, source_dir=None, output_dir=None):
    """Copy the given images and their ``<name>.json`` labels into one output split.

    Files are read from *source_dir* (default: module-level ``FLATTEN_DIR``). They are
    written to ``<output_dir>/<split_name>/images`` and
    ``<output_dir>/<split_name>/labels_json`` (default *output_dir*: ``OUTPUT_DIR``).
    Files with the same name are overwritten. Images already in the output folder that
    are not in *images* (e.g. left by an earlier run with a different split) are NOT
    deleted, but a warning is printed because they would leak into this split.
    """
    src = Path(FLATTEN_DIR if source_dir is None else source_dir)
    split_root = Path(OUTPUT_DIR if output_dir is None else output_dir) / split_name
    img_out = split_root / "images"
    label_out = split_root / "labels_json"
    img_out.mkdir(parents=True, exist_ok=True)
    label_out.mkdir(parents=True, exist_ok=True)

    wanted = set(images)
    stale = [p.name for p in img_out.iterdir() if p.name not in wanted]
    if stale:
        print(f"⚠️ {split_name}: 이전 실행에서 남은 이미지 {len(stale)}개 (삭제 필요)")

    for fname in images:
        shutil.copy2(src / fname, img_out / fname)
        shutil.copy2(src / f"{fname}.json", label_out / f"{fname}.json")

# Write both splits to disk, then report completion.
save_split(images=train_list, split_name="train")
save_split(images=val_list, split_name="val")

print("데이터셋 분할 완료!")

# ==============================
# 7. Check instance counts after the split
# ==============================
def count_instances(image_list):
    """Return ``(belt_total, helmet_total)`` summed over *image_list*.

    Per-image counts are looked up in the module-level ``image_stats`` mapping.
    """
    def _total(key):
        # One pass over image_list per class, belt first, then helmet.
        return sum(image_stats[name][key] for name in image_list)

    return _total("belt"), _total("helmet")

# Per-split instance totals, to inspect the class balance the split achieved.
train_belt, train_helmet = count_instances(train_list)
val_belt, val_helmet = count_instances(val_list)

print("\n".join([
    "\n=== 인스턴스 개수 통계 ===",
    f"Train - 안전벨트: {train_belt}, 안전모: {train_helmet}",
    f"Val   - 안전벨트: {val_belt}, 안전모: {val_helmet}",
]))


train 이미지 개수: 1540
val 이미지 개수: 180
총 이미지 수: 1720
안전벨트 이미지 수: 1369
안전모 이미지 수: 1504
둘 다 포함된 이미지 수: 1153
최종 train 이미지 수: 1135
최종 val 이미지 수: 516
데이터셋 분할 완료!
