In [4]:
import os
os.getcwd()

'C:\\Users\\wjbok\\Desktop\\BigTech'

In [5]:
from collections import Counter
import pandas as pd

BASE = r"./object_detection_"
LBL_SRC = os.path.join(BASE,"labels","train")


def analyze_class_distribution(label_directory):
    """
    지정된 디렉토리의 라벨 파일을 분석하여 클래스 분포를 표로 출력합니다.
    """
    print(f"분석 대상 폴더: {os.path.abspath(label_directory)}")

    # 1. 경로 존재 여부 확인
    if not os.path.isdir(label_directory):
        print(f"❌ 오류: '{label_directory}' 폴더를 찾을 수 없습니다. 경로를 확인해주세요.")
        return

    # 2. 모든 클래스 ID 수집
    class_ids = []
    try:
        label_files = [f for f in os.listdir(label_directory) if f.endswith(".txt")]
        if not label_files:
            print(f"⚠️ 경고: '{label_directory}' 폴더에 라벨 파일(.txt)이 없습니다.")
            return

        for filename in label_files:
            filepath = os.path.join(label_directory, filename)
            with open(filepath, 'r') as f:
                for line in f:
                    parts = line.split()
                    if parts:
                        try:
                            class_id = int(parts[0])
                            class_ids.append(class_id)
                        except (ValueError, IndexError):
                            continue # 숫자가 아니거나 빈 줄인 경우 무시
    except Exception as e:
        print(f"❌ 오류: 파일을 읽는 중 문제가 발생했습니다: {e}")
        return

    if not class_ids:
        print("분석할 클래스 데이터를 찾지 못했습니다.")
        return

    # 3. 클래스별 개수 계산 및 표로 출력
    class_counts = Counter(class_ids)
    df = pd.DataFrame(class_counts.items(), columns=['Class', 'Count'])
    df_sorted = df.sort_values(by='Count', ascending=False).reset_index(drop=True)

    print("\n✅ 클래스별 객체 수 (내림차순)")
    print("-" * 30)
    print(df_sorted.to_string())
    print("-" * 30)


# 스크립트 실행
if __name__ == '__main__':
    analyze_class_distribution(LBL_SRC)


분석 대상 폴더: C:\Users\wjbok\Desktop\BigTech\object_detection_\labels\train

✅ 클래스별 객체 수 (내림차순)
------------------------------
   Class  Count
0      1    958
1      0    112
2      5     59
3      3     49
4      4     36
5      6     28
6      2     22
------------------------------


In [6]:
!pip install albumentations opencv-python

Collecting albumentations
  Downloading albumentations-2.0.8-py3-none-any.whl.metadata (43 kB)
Collecting pydantic>=2.9.2 (from albumentations)
  Downloading pydantic-2.11.7-py3-none-any.whl.metadata (67 kB)
Collecting albucore==0.0.24 (from albumentations)
  Downloading albucore-0.0.24-py3-none-any.whl.metadata (5.3 kB)
Collecting opencv-python-headless>=4.9.0.80 (from albumentations)
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-win_amd64.whl.metadata (20 kB)
Collecting stringzilla>=3.10.4 (from albucore==0.0.24->albumentations)
  Downloading stringzilla-3.12.6-cp311-cp311-win_amd64.whl.metadata (81 kB)
Collecting simsimd>=5.9.2 (from albucore==0.0.24->albumentations)
  Downloading simsimd-6.5.1-cp311-cp311-win_amd64.whl.metadata (72 kB)
Collecting numpy>=1.24.4 (from albumentations)
  Using cached numpy-2.2.6-cp311-cp311-win_amd64.whl.metadata (60 kB)
Collecting annotated-types>=0.6.0 (from pydantic>=2.9.2->albumentations)
  Using cached annotated_types-0.7.0-py3-none-any

  You can safely remove it manually.
  You can safely remove it manually.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datumaro 1.10.0 requires numpy<2,>=1.23.4, but you have numpy 2.2.6 which is incompatible.
mediapipe 0.10.14 requires protobuf<5,>=4.25.3, but you have protobuf 5.29.5 which is incompatible.
numba 0.61.0 requires numpy<2.2,>=1.24, but you have numpy 2.2.6 which is incompatible.
tensorflow-intel 2.14.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 5.29.5 which is incompatible.
tensorflow-intel 2.14.0 requires wrapt<1.15,>=1.11.0, but you have wrapt 1.17.2 which is incompatible.


In [9]:
import os
import cv2
from collections import Counter
import pandas as pd

# --- 설정 ---
BASE = r"./object_detection_"
# --- 설정 ---

SPLITS_TO_CHECK = ["train"]
all_resolutions = []

print("이미지 해상도를 분석합니다...")
for split in SPLITS_TO_CHECK:
    img_dir = os.path.join(BASE, "images", split)
    if not os.path.isdir(img_dir):
        continue

    for filename in os.listdir(img_dir):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                img_path = os.path.join(img_dir, filename)
                # OpenCV로 이미지 크기(너비, 높이) 읽기
                image = cv2.imread(img_path)
                height, width, _ = image.shape
                all_resolutions.append(f"{width}x{height}")
            except Exception as e:
                print(f"'{filename}' 파일 처리 중 오류 발생: {e}")

if not all_resolutions:
    print("분석할 이미지가 없습니다.")
else:
    resolution_counts = Counter(all_resolutions)
    df = pd.DataFrame(resolution_counts.items(), columns=['Resolution', 'Count'])
    df_sorted = df.sort_values(by='Count', ascending=False).reset_index(drop=True)

    print("\n✅ 전체 데이터셋의 해상도 분포")
    print("-" * 35)
    print(df_sorted.to_string())
    print("-" * 35)

이미지 해상도를 분석합니다...

✅ 전체 데이터셋의 해상도 분포
-----------------------------------
  Resolution  Count
0    640x360    650
-----------------------------------


In [11]:
import cv2
import albumentations as A
from collections import defaultdict
import random

TARGET_IMAGE_COUNT=100

IMG_DIR = os.path.join(BASE,"images","train")
LBL_DIR = os.path.join(BASE,"labels","train")

transform = A.Compose([
    A.RandomBrightnessContrast(p=0.3),
    A.ShiftScaleRotate(
        shift_limit=0.0625,
        scale_limit=0.1,
        rotate_limit=15,
        p=0.5,
        border_mode=cv2.BORDER_CONSTANT,
        value=0
    ),
    A.RandomResizedCrop(
        size=(352, 640),
        scale=(0.8, 1.0),
        p=0.3
    ),
], bbox_params=A.BboxParams(
        format='yolo',
        label_fields=['class_labels'],
        min_visibility=0.3
))

def augment_minority_classes():
    """소수 클래스를 분석하고 목표치에 도달할 때까지 증강합니다."""
    if not os.path.isdir(LBL_DIR):
        print(f"❌ 오류: 라벨 폴더 '{LBL_DIR}'를 찾을 수 없습니다.")
        return

    class_to_images = defaultdict(list)
    for label_file in os.listdir(LBL_DIR):
        if not label_file.endswith(".txt"): continue
        image_name = os.path.splitext(label_file)[0]
        with open(os.path.join(LBL_DIR, label_file), 'r') as f:
            classes_in_image = set(int(line.split()[0]) for line in f if line.strip())
            for class_id in classes_in_image:
                class_to_images[class_id].append(image_name)

    class_image_counts = {k: len(v) for k, v in class_to_images.items()}
    print("--- 증강 전 클래스별 이미지 수 ---")
    for cid, count in sorted(class_image_counts.items()):
        print(f"  - 클래스 {cid}: {count}개 이미지")
    print("-" * 35)

    for class_id, image_count in class_image_counts.items():
        if image_count < TARGET_IMAGE_COUNT:
            num_to_generate = TARGET_IMAGE_COUNT - image_count
            source_images = class_to_images[class_id]
            print(f"🔧 클래스 {class_id} 증강 시작 (목표: {TARGET_IMAGE_COUNT}개, 생성: {num_to_generate}개)")

            for i in range(num_to_generate):
                base_name = random.choice(source_images)
                img_path, lbl_path = None, os.path.join(LBL_DIR, base_name + ".txt")
                for ext in ['.jpg', '.jpeg', '.png']:
                    if os.path.exists(os.path.join(IMG_DIR, base_name + ext)):
                        img_path = os.path.join(IMG_DIR, base_name + ext)
                        break
                if not img_path: continue

                image = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
                bboxes, class_labels = [], []
                with open(lbl_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        class_labels.append(int(parts[0]))
                        bboxes.append([float(p) for p in parts[1:]])

                try:
                    transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)
                    if transformed['bboxes']:
                        new_name = f"{base_name}_aug_{class_id}_{i+1}"
                        new_img_path = os.path.join(IMG_DIR, new_name + os.path.splitext(img_path)[1])
                        new_lbl_path = os.path.join(LBL_DIR, new_name + ".txt")
                        # transformed 딕셔너리에서 'image' 키로 값을 가져와야 합니다.
                        cv2.imwrite(new_img_path, cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR))
                        with open(new_lbl_path, 'w') as f:
                            for j, bbox in enumerate(transformed['bboxes']):
                                f.write(f"{transformed['class_labels'][j]} {' '.join(map(str, bbox))}\n")
                except Exception as e:
                    print(f"  - 오류: {base_name} 증강 중 문제 발생 ({e})")

    print("\n🎉 모든 소수 클래스 증강 작업이 완료되었습니다!")

if __name__ == '__main__':
    augment_minority_classes()

  original_init(self, **validated_kwargs)
  A.ShiftScaleRotate(


--- 증강 전 클래스별 이미지 수 ---
  - 클래스 0: 111개 이미지
  - 클래스 1: 366개 이미지
  - 클래스 2: 22개 이미지
  - 클래스 3: 49개 이미지
  - 클래스 4: 36개 이미지
  - 클래스 5: 59개 이미지
  - 클래스 6: 28개 이미지
-----------------------------------
🔧 클래스 3 증강 시작 (목표: 100개, 생성: 51개)
🔧 클래스 5 증강 시작 (목표: 100개, 생성: 41개)
🔧 클래스 6 증강 시작 (목표: 100개, 생성: 72개)
🔧 클래스 2 증강 시작 (목표: 100개, 생성: 78개)
🔧 클래스 4 증강 시작 (목표: 100개, 생성: 64개)

🎉 모든 소수 클래스 증강 작업이 완료되었습니다!


In [None]:
import os
import random
from collections import defaultdict

# --- 설정 ---
BASE = r"./object_detection_bok"
TARGET_IMAGE_COUNT = 100
# --- 설정 ---

IMG_DIR = os.path.join(BASE, "images", "train")
LBL_DIR = os.path.join(BASE, "labels", "train")

def undersample_majority_classes():
    """다수 클래스를 분석하고 목표치에 도달할 때까지 관련 파일을 삭제합니다."""
    print("⚠️ 경고: 파일이 영구적으로 삭제됩니다. 5초 후에 작업을 시작합니다...")
    # import time; time.sleep(5) # 필요시 주석 해제

    if not os.path.isdir(LBL_DIR):
        print(f"❌ 오류: 라벨 폴더 '{LBL_DIR}'를 찾을 수 없습니다.")
        return

    class_to_images = defaultdict(set)
    for label_file in os.listdir(LBL_DIR):
        if not label_file.endswith(".txt"): continue
        image_name = os.path.splitext(label_file)[0]
        with open(os.path.join(LBL_DIR, label_file), 'r') as f:
            # [수정된 부분] float을 거쳐 int로 변환하도록 수정
            classes_in_image = set(int(float(line.split()[0])) for line in f if line.strip())
            for class_id in classes_in_image:
                class_to_images[class_id].add(image_name)

    class_image_counts = {k: len(v) for k, v in class_to_images.items()}
    files_to_delete = set()

    for class_id, image_count in class_image_counts.items():
        if image_count > TARGET_IMAGE_COUNT:
            num_to_delete = image_count - TARGET_IMAGE_COUNT
            source_images = class_to_images[class_id]
            images_marked_for_deletion = random.sample(list(source_images), num_to_delete)
            files_to_delete.update(images_marked_for_deletion)
            print(f"🗑️ 클래스 {class_id}: {num_to_delete}개 이미지 삭제 예정...")

    print(f"\n총 {len(files_to_delete)}개의 이미지/라벨 쌍을 삭제합니다.")
    for base_name in files_to_delete:
        lbl_path = os.path.join(LBL_DIR, base_name + ".txt")
        if os.path.exists(lbl_path): os.remove(lbl_path)

        for ext in ['.jpg', '.jpeg', '.png']:
            img_path = os.path.join(IMG_DIR, base_name + ext)
            if os.path.exists(img_path):
                os.remove(img_path)
                break

    print("\n🎉 모든 다수 클래스 언더샘플링 작업이 완료되었습니다!")

if __name__ == '__main__':
    undersample_majority_classes()

In [12]:
import os
import random
from collections import defaultdict

# --- 설정 ---
BASE = r"./object_detection_"
TARGET_IMAGE_COUNT = 100
# --- 설정 ---

IMG_DIR = os.path.join(BASE, "images", "train")
LBL_DIR = os.path.join(BASE, "labels", "train")

def undersample_majority_classes():
    """다수 클래스를 분석하고 목표치에 도달할 때까지 관련 파일을 삭제합니다."""
    print("⚠️ 경고: 파일이 영구적으로 삭제됩니다. 5초 후에 작업을 시작합니다...")
    # import time; time.sleep(5) # 필요시 주석 해제

    if not os.path.isdir(LBL_DIR):
        print(f"❌ 오류: 라벨 폴더 '{LBL_DIR}'를 찾을 수 없습니다.")
        return

    class_to_images = defaultdict(set)
    for label_file in os.listdir(LBL_DIR):
        if not label_file.endswith(".txt"): continue
        image_name = os.path.splitext(label_file)[0]
        with open(os.path.join(LBL_DIR, label_file), 'r') as f:
            # [수정된 부분] float을 거쳐 int로 변환하도록 수정
            classes_in_image = set(int(float(line.split()[0])) for line in f if line.strip())
            for class_id in classes_in_image:
                class_to_images[class_id].add(image_name)

    class_image_counts = {k: len(v) for k, v in class_to_images.items()}
    files_to_delete = set()

    for class_id, image_count in class_image_counts.items():
        if image_count > TARGET_IMAGE_COUNT:
            num_to_delete = image_count - TARGET_IMAGE_COUNT
            source_images = class_to_images[class_id]
            images_marked_for_deletion = random.sample(list(source_images), num_to_delete)
            files_to_delete.update(images_marked_for_deletion)
            print(f"🗑️ 클래스 {class_id}: {num_to_delete}개 이미지 삭제 예정...")

    print(f"\n총 {len(files_to_delete)}개의 이미지/라벨 쌍을 삭제합니다.")
    for base_name in files_to_delete:
        lbl_path = os.path.join(LBL_DIR, base_name + ".txt")
        if os.path.exists(lbl_path): os.remove(lbl_path)

        for ext in ['.jpg', '.jpeg', '.png']:
            img_path = os.path.join(IMG_DIR, base_name + ext)
            if os.path.exists(img_path):
                os.remove(img_path)
                break

    print("\n🎉 모든 다수 클래스 언더샘플링 작업이 완료되었습니다!")

if __name__ == '__main__':
    undersample_majority_classes()

⚠️ 경고: 파일이 영구적으로 삭제됩니다. 5초 후에 작업을 시작합니다...
🗑️ 클래스 1: 266개 이미지 삭제 예정...
🗑️ 클래스 0: 11개 이미지 삭제 예정...

총 273개의 이미지/라벨 쌍을 삭제합니다.

🎉 모든 다수 클래스 언더샘플링 작업이 완료되었습니다!


In [13]:
import os
from collections import Counter
import pandas as pd

# --- 설정 ---
BASE = r"./object_detection_"
# --- 설정 ---

LBL_DIR = os.path.join(BASE, "labels", "train")

def final_check_balance():
    """최종 클래스 분포(객체 수 기준)를 확인합니다."""
    print(f"분석 대상 폴더: {os.path.abspath(LBL_DIR)}")

    if not os.path.isdir(LBL_DIR):
        print(f"❌ 오류: 라벨 폴더 '{LBL_DIR}'를 찾을 수 없습니다.")
        return

    class_ids = []
    try:
        for filename in os.listdir(LBL_DIR):
            if filename.endswith(".txt"):
                with open(os.path.join(LBL_DIR, filename), 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if parts:
                            # '2.0' 같은 실수 형태의 클래스 ID도 처리 가능하도록 수정
                            class_ids.append(int(float(parts[0])))
    except Exception as e:
        print(f"❌ 오류: 파일을 읽는 중 문제가 발생했습니다: {e}")
        return

    if not class_ids:
        print("분석할 라벨 데이터가 없습니다.")
        return

    class_counts = Counter(class_ids)
    df = pd.DataFrame(class_counts.items(), columns=['Class', 'Count'])
    df_sorted = df.sort_values(by='Count', ascending=False).reset_index(drop=True)

    print("\n📊 최종 클래스별 객체 수 (내림차순)")
    print("-" * 30)
    print(df_sorted.to_string())
    print("-" * 30)

if __name__ == '__main__':
    final_check_balance()

분석 대상 폴더: C:\Users\wjbok\Desktop\BigTech\object_detection_\labels\train

📊 최종 클래스별 객체 수 (내림차순)
------------------------------
   Class  Count
0      1    255
1      5    100
2      6    100
3      2     98
4      3     97
5      4     96
6      0     92
------------------------------


In [14]:
!pip install scikit-multilearn

Collecting scikit-multilearn
  Downloading scikit_multilearn-0.2.0-py3-none-any.whl.metadata (6.0 kB)
Downloading scikit_multilearn-0.2.0-py3-none-any.whl (89 kB)
Installing collected packages: scikit-multilearn
Successfully installed scikit-multilearn-0.2.0


In [16]:
import os
import numpy as np
import shutil
from collections import defaultdict
from skmultilearn.model_selection import IterativeStratification

# --- 설정 ---
BASE = r"./object_detection_"
# --- 설정 ---

IMG_SRC_DIR = os.path.join(BASE, "images", "train")
LBL_SRC_DIR = os.path.join(BASE, "labels", "train")

# 1. 이미지와 클래스 라벨 매핑
image_to_classes = defaultdict(set)
all_labels = set()
image_files = [os.path.splitext(f)[0] for f in os.listdir(LBL_SRC_DIR) if f.endswith(".txt")]

for img_name in image_files:
    with open(os.path.join(LBL_SRC_DIR, img_name + ".txt"), 'r') as f:
        classes_in_image = set(int(float(line.split()[0])) for line in f if line.strip())
        image_to_classes[img_name] = classes_in_image
        all_labels.update(classes_in_image)

# 2. Iterative Stratification을 위한 데이터 형식 준비
sorted_labels = sorted(list(all_labels))
label_map = {label: i for i, label in enumerate(sorted_labels)}
X = np.array(image_files).reshape(-1, 1)
y = np.zeros((len(image_files), len(all_labels)), dtype=int)

for i, img_name in enumerate(image_files):
    for cls in image_to_classes[img_name]:
        y[i, label_map[cls]] = 1

# 3. 데이터 분할 (train: 70%, val: 20%, test: 10%)
stratifier = IterativeStratification(n_splits=2, order=1, sample_distribution_per_fold=[0.3, 0.7])
train_indices, temp_indices = next(stratifier.split(X, y))
X_train, y_train = X[train_indices], y[train_indices]
X_temp, y_temp = X[temp_indices], y[temp_indices]

stratifier_val_test = IterativeStratification(n_splits=2, order=1, sample_distribution_per_fold=[1/3, 2/3])
test_indices, val_indices = next(stratifier_val_test.split(X_temp, y_temp)) # 1/3을 test로, 2/3를 val로
X_val, y_val = X_temp[val_indices], y_temp[val_indices]
X_test, y_test = X_temp[test_indices], y_temp[test_indices]

train_files = set(X_train.flatten())
val_files = set(X_val.flatten())
test_files = set(X_test.flatten())

# 4. 폴더 생성 및 파일 이동
for split_name, file_set in [("val", val_files), ("test", test_files)]:
    os.makedirs(os.path.join(BASE, "images", split_name), exist_ok=True)
    os.makedirs(os.path.join(BASE, "labels", split_name), exist_ok=True)

    print(f"\n'{split_name}' 세트로 파일 이동 중... ({len(file_set)}개)")
    for base_name in file_set:
        # --- 여기에 파일 이동 로직 추가 ---
        lbl_src_path = os.path.join(LBL_SRC_DIR, base_name + ".txt")
        lbl_dest_path = os.path.join(BASE, "labels", split_name, base_name + ".txt")
        if os.path.exists(lbl_src_path):
            shutil.move(lbl_src_path, lbl_dest_path)

        for ext in ['.jpg', '.jpeg', '.png']:
            img_src_path = os.path.join(IMG_SRC_DIR, base_name + ext)
            if os.path.exists(img_src_path):
                img_dest_path = os.path.join(BASE, "images", split_name, base_name + ext)
                shutil.move(img_src_path, img_dest_path)
                break
        # --- 로직 추가 끝 ---

print("\n🎉 Iterative Stratification을 사용하여 파일 분할이 완료되었습니다!")
print("-" * 30)
# train 폴더에 남은 파일 수를 세도록 수정
print(f"  - 훈련(train) 세트: {len(os.listdir(IMG_SRC_DIR))}개")
print(f"  - 검증(val) 세트: {len(val_files)}개")
print(f"  - 테스트(test) 세트: {len(test_files)}개")
print("-" * 30)


'val' 세트로 파일 이동 중... (69개)

'test' 세트로 파일 이동 중... (134개)

🎉 Iterative Stratification을 사용하여 파일 분할이 완료되었습니다!
------------------------------
  - 훈련(train) 세트: 471개
  - 검증(val) 세트: 69개
  - 테스트(test) 세트: 134개
------------------------------


In [17]:
import os
from collections import Counter
import pandas as pd

# --- 설정 ---
BASE = r"./object_detection_"
# --- 설정 ---

SPLITS_TO_CHECK = ["train", "val", "test"]

def analyze_split_distribution(split_name):
    """지정된 세트(train, val, test)의 클래스 분포를 분석합니다."""

    lbl_dir = os.path.join(BASE, "labels", split_name)

    if not os.path.isdir(lbl_dir):
        print(f"\n--- [{split_name.upper()} 세트] ---")
        print(f"❌ 오류: '{lbl_dir}' 폴더를 찾을 수 없습니다.")
        return

    class_ids = []
    for filename in os.listdir(lbl_dir):
        if filename.endswith(".txt"):
            with open(os.path.join(lbl_dir, filename), 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if parts:
                        class_ids.append(int(float(parts[0])))

    print(f"\n--- [{split_name.upper()} 세트] ---")
    if not class_ids:
        print("분석할 라벨 데이터가 없습니다.")
        return

    class_counts = Counter(class_ids)
    df = pd.DataFrame(class_counts.items(), columns=['Class', 'Count'])
    df_sorted = df.sort_values(by='Class').reset_index(drop=True) # 클래스 ID 순으로 정렬

    print(df_sorted.to_string())


if __name__ == '__main__':
    print("각 세트의 클래스 분포를 확인합니다...")
    for split in SPLITS_TO_CHECK:
        analyze_split_distribution(split)

각 세트의 클래스 분포를 확인합니다...

--- [TRAIN 세트] ---
   Class  Count
0      0     65
1      1    192
2      2     69
3      3     68
4      4     67
5      5     70
6      6     70

--- [VAL 세트] ---
   Class  Count
0      0      9
1      1     19
2      2     10
3      3     10
4      4     10
5      5     10
6      6     10

--- [TEST 세트] ---
   Class  Count
0      0     18
1      1     44
2      2     19
3      3     19
4      4     19
5      5     20
6      6     20


In [None]:
# train_yolo.py
# -----------------------------
# Ultralytics YOLO 학습 및 PTQ 적용 스크립트
# -----------------------------

import os
import random
import shutil
from datetime import datetime
from ultralytics import YOLO

# ========= 사용자 설정 =========
DATASET_DIR = r"./object_detection_"
PROJECT_DIR = r"./object_detection_/runs_yolo_bok"
EXP_NAME    = "exp_car_yolov8n"

CLASS_NAMES = [
    "animal", "person", "traffic_red","traffic_yellow", "traffic_green", "right", "left"
]

MODEL_NAME = "./yolov8n.pt"
EPOCHS     = 100
IMGSZ      = 320
BATCH      = 16
LR0        = 0.005
PATIENCE   = 10
DEVICE     = "cpu"

DO_PREDICT_SAMPLES = True
DO_EXPORT_OPENVINO = True # PTQ를 적용할 것이므로 True로 유지
DO_EXPORT_ONNX     = False

PREDICT_SOURCE = os.path.join(DATASET_DIR, "images", "val")
PREDICT_CONF   = 0.25

# ========= 유틸 =========
def ensure_yaml(dataset_dir, class_names):
    """data.yaml 자동 생성 (이미 있으면 덮어쓰지 않음)"""
    yaml_path = os.path.join(dataset_dir, "data.yaml")
    if os.path.exists(yaml_path):
        print(f"[INFO] data.yaml 이미 존재: {yaml_path}")
        return yaml_path

    # 절대 경로로 변환하여 yaml 파일에 기록
    abs_dataset_dir = os.path.abspath(dataset_dir)
    content = [
        f"path: {abs_dataset_dir}", # 절대 경로 사용
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "names:"
    ]
    for i, name in enumerate(class_names):
        content.append(f"  {i}: {name}")

    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write("\n".join(content) + "\n")

    print(f"[OK] data.yaml 생성: {yaml_path}")
    return yaml_path


def sanity_check(dataset_dir):
    """간단 무결성 체크"""
    img_train_dir = os.path.join(dataset_dir, "images", "train")
    lbl_train_dir = os.path.join(dataset_dir, "labels", "train")
    if not os.path.isdir(img_train_dir) or not os.path.isdir(lbl_train_dir):
        print(f"[WARN] train 폴더를 찾을 수 없어 sanity check를 건너뜁니다.")
        return

    missing = []
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}

    for f in os.listdir(img_train_dir):
        ext = os.path.splitext(f)[1].lower()
        if ext not in exts: continue
        stem = os.path.splitext(f)[0]
        if not os.path.exists(os.path.join(lbl_train_dir, stem + ".txt")):
            missing.append(f)

    if missing:
        print(f"[WARN] 라벨 누락 이미지 {len(missing)}개 예시: {missing[:10]}")
    else:
        print("[OK] 라벨 매칭 이상 없음")


def train():
    # 1) data.yaml 보장
    data_yaml = ensure_yaml(DATASET_DIR, CLASS_NAMES)

    # 2) 간단 체크
    sanity_check(DATASET_DIR)

    # 3) 모델 로드
    print(f"[INFO] Loading model: {MODEL_NAME}")
    model = YOLO(MODEL_NAME)

    # 4) 학습
    print("[INFO] Start training...")
    results = model.train(
        data=data_yaml,
        epochs=EPOCHS,
        imgsz=IMGSZ,
        batch=BATCH,
        device=DEVICE,
        project=PROJECT_DIR,
        name=EXP_NAME,
        lr0=LR0,
        patience=PATIENCE,
        optimizer="auto",
        hsv_h=0.015, hsv_s=0.7, hsv_v=0.4,
        fliplr=0.5,
        mosaic=1.0, mixup=0.15,
        degrees=5, translate=0.05, scale=0.1, shear=0.0, perspective=0.0,
    )

    run_dir = results.save_dir
    best_pt = os.path.join(run_dir, "weights", "best.pt")
    print(f"[OK] Training done. best: {best_pt}")

    # 5) 검증(mAP, PR커브)
    print("[INFO] Validate best weights...")
    model = YOLO(best_pt)
    model.val(data=data_yaml, project=PROJECT_DIR, name=f"{EXP_NAME}_val")

    # 6) 예측 샘플 저장
    if DO_PREDICT_SAMPLES and os.path.exists(PREDICT_SOURCE):
        print(f"[INFO] Predict & save samples from: {PREDICT_SOURCE}")
        model.predict(
            source=PREDICT_SOURCE,
            conf=PREDICT_CONF,
            save=True,
            project=PROJECT_DIR,
            name=f"{EXP_NAME}_pred_val"
        )

     # 7) 내보내기 (PTQ 적용)
    if DO_EXPORT_OPENVINO:
        print("[INFO] Export OpenVINO IR with Data-aware INT8 Quantization (PTQ)...")

        # ✅ 'data'와 'imgsz'를 명시하여 우리 데이터셋에 맞게 양자화를 수행합니다.
        model.export(
            format="openvino",    # OpenVINO 형식으로 내보내기
            int8=True,            # INT8 양자화 활성화
            data=data_yaml,       # 교정 데이터로 우리 val 세트를 사용하도록 지정
            imgsz=IMGSZ,          # 훈련 시와 동일한 이미지 크기로 교정
            half=False,           # FP16 대신 INT8을 목표로 하므로 False
            simplify=True         # ONNX 모델을 단순화하여 호환성 및 속도 향상
        )


    print("[DONE] All finished.")


if __name__ == "__main__":
    train()

[INFO] data.yaml 이미 존재: ./object_detection_\data.yaml
[OK] 라벨 매칭 이상 없음
[INFO] Loading model: ./yolov8n.pt
[INFO] Start training...
New https://pypi.org/project/ultralytics/8.3.186 available  Update with 'pip install -U ultralytics'
Ultralytics 8.3.176  Python-3.11.2 torch-2.7.1+cpu CPU (AMD Ryzen 5 5600U with Radeon Graphics)
[34m[1mengine\trainer: [0magnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=16, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=0.5, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=./object_detection_\data.yaml, degrees=5, deterministic=True, device=cpu, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=100, erasing=0.4, exist_ok=False, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=640, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.005, lrf=0.01, mask_ratio=4, m

[34m[1mtrain: [0mScanning C:\Users\wjbok\Desktop\BigTech\object_detection_\labels\train... 471 images, 0 backgrounds, 0 corrupt: 100%|██████████| 471/471 [00:00<00:00, 814.87it/s]


[34m[1mtrain: [0mNew cache created: C:\Users\wjbok\Desktop\BigTech\object_detection_\labels\train.cache
[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, method='weighted_average', num_output_channels=3), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))
[34m[1mval: [0mFast image access  (ping: 0.10.0 ms, read: 791.6264.0 MB/s, size: 217.6 KB)


[34m[1mval: [0mScanning C:\Users\wjbok\Desktop\BigTech\object_detection_\labels\val... 69 images, 0 backgrounds, 0 corrupt: 100%|██████████| 69/69 [00:00<00:00, 1024.64it/s]

[34m[1mval: [0mNew cache created: C:\Users\wjbok\Desktop\BigTech\object_detection_\labels\val.cache





Plotting labels to object_detection_\runs_yolo_bok\exp_car_yolov8n\labels.jpg... 
[34m[1moptimizer:[0m 'optimizer=auto' found, ignoring 'lr0=0.005' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
[34m[1moptimizer:[0m AdamW(lr=0.000909, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
Image sizes 640 train, 640 val
Using 0 dataloader workers
Logging results to [1mobject_detection_\runs_yolo_bok\exp_car_yolov8n[0m
Starting training for 100 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      1/100         0G      1.489      4.698      1.258         43        640:  57%|█████▋    | 17/30 [01:35<01:13,  5.67s/it]