In [None]:
#@title Config
!pip install -q datasets pillow

import os
import gc
import shutil
from itertools import islice
from PIL import Image as PILImage
from datasets import Dataset, Features, Value, Sequence, Image, concatenate_datasets, load_from_disk
from IPython.display import display

# Mount Google Drive (Colab only) unless it is already mounted.
# Check for MyDrive rather than the bare mount point: every path this notebook
# uses lives under MyDrive, and '/content/drive' may be left behind as an empty
# directory (e.g. after drive.flush_and_unmount()), which would wrongly skip
# the remount.
if not os.path.isdir('/content/drive/MyDrive'):
    from google.colab import drive
    drive.mount('/content/drive')

# Root of the raw dataset on Google Drive; images are read from
# <DATASET_ROOT>/<category>/image/<folder>/<label>/.
DATASET_ROOT = '/content/drive/MyDrive/project/AEGIS/Dataset'
# Category folder names to scan; each is stored as the entry's 'class2'.
CATEGORIES = ['swoon', 'vandalism', 'assault', 'burglary', 'dump']
# Maps the Korean label sub-folder names to the English 'class1' values
# ('이상' -> abnormal, '의심' -> suspicious, '정상' -> normal).
CLASS1_MAP = {'이상': 'abnormal', '의심': 'suspicious', '정상': 'normal'}
# Output root: per-batch datasets go to <SAVE_PATH>/batches, the final merged
# dataset to <SAVE_PATH>/merged.
SAVE_PATH = os.path.join(DATASET_ROOT, 'hf_dataset')

MAX_ENTRY = 0  #@param {type:"integer"}
# 0 = no limit (build everything); a positive value stops after that many entries.
# (Negative values are also treated as "no limit".)

BATCH_SIZE = 120  #@param {type:"integer"}
# Number of entries per batch; each batch is saved to disk and then freed from memory.

PREVIEW_COUNT = 3  #@param {type:"integer"}
# Number of entries to preview (all 8 images of each entry are displayed).

In [None]:
#@title Run: Listing → Batch Save → Merge → Preview

# Schema of one dataset row: the 8 image file names and decoded images of a
# label folder, plus its labels and a (currently empty) text summary.
features = Features(
    image_names=Sequence(Value('string')),  # file names, sorted
    images=Sequence(Image()),               # images in the same order as image_names
    class1=Value('string'),                 # abnormal / suspicious / normal
    class2=Value('string'),                 # category folder name
    summary=Value('string'),                # placeholder text
)

def _load_rgb(path):
    """Open the image at ``path`` and return an RGB copy held in memory.

    The ``with`` block closes the file handle once the conversion is done;
    ``convert`` produces a new in-memory image, so the result stays usable
    after the file is closed.
    """
    with PILImage.open(path) as img:
        return img.convert('RGB')


def entry_generator(expected_count=8):
    """Walk the dataset folders and yield one entry per complete label folder.

    Expected layout: ``<DATASET_ROOT>/<category>/image/<folder>/<label>/``,
    where ``<category>`` comes from CATEGORIES and ``<label>`` is a CLASS1_MAP
    key. Categories are visited in CATEGORIES order, folders in sorted order
    and labels in CLASS1_MAP order, so the output order is deterministic.

    A label folder yields an entry only if it contains exactly
    ``expected_count`` .jpg/.jpeg/.png files (extension match is
    case-insensitive). These folders are reported with a ``[SKIP]`` line and
    skipped:
      - non-empty folders with a different image count;
      - folders containing an image that cannot be opened.

    Args:
        expected_count: Required number of images per label folder
            (default 8).

    Yields:
        dict with keys
            'image_names': sorted image file names,
            'images': the matching images as RGB PIL images,
            'class1': English label from CLASS1_MAP,
            'class2': the category name,
            'summary': empty-string placeholder.
    """
    # '전체' means "all" (presumably an aggregate folder); the others are
    # notebook checkpoint folders. None of them are samples.
    exclude = {'전체', 'ipynb_checkpoints', '.ipynb_checkpoints'}
    image_exts = ('.jpg', '.jpeg', '.png')

    for category in CATEGORIES:
        image_base = os.path.join(DATASET_ROOT, category, 'image')
        if not os.path.isdir(image_base):
            print(f'[SKIP] Path not found: {image_base}')
            continue

        folders = sorted(
            name for name in os.listdir(image_base)
            if name not in exclude and os.path.isdir(os.path.join(image_base, name))
        )

        for folder_name in folders:
            folder_path = os.path.join(image_base, folder_name)

            for sub_name, class1_value in CLASS1_MAP.items():
                sub_path = os.path.join(folder_path, sub_name)
                # isdir (not exists): listdir on a stray file would raise.
                if not os.path.isdir(sub_path):
                    continue

                image_files = sorted(
                    name for name in os.listdir(sub_path)
                    if name.lower().endswith(image_exts)
                )

                if len(image_files) != expected_count:
                    # Empty label folders are skipped silently; partial ones are reported.
                    if image_files:
                        print(f'[SKIP] {category}/{folder_name}/{sub_name} ({len(image_files)}장)')
                    continue

                try:
                    pil_images = [
                        _load_rgb(os.path.join(sub_path, name))
                        for name in image_files
                    ]
                except OSError as err:
                    # One corrupt/unreadable file must not abort the whole run
                    # (Pillow's UnidentifiedImageError is an OSError subclass).
                    print(f'[SKIP] {category}/{folder_name}/{sub_name} (unreadable image: {err})')
                    continue

                yield {
                    'image_names': image_files,
                    'images': pil_images,
                    'class1': class1_value,
                    'class2': category,
                    'summary': '',
                }

# --- Save entries to disk in batches ---
# Each batch of up to BATCH_SIZE entries is held in memory, written with
# save_to_disk to <SAVE_PATH>/batches/batch_<i>, then released. This keeps
# peak memory bounded by one batch of decoded images.

# Validate before touching the disk:
#   - islice(gen, 0) silently yields nothing, so BATCH_SIZE == 0 would
#     produce zero batches;
#   - a negative value raises an obscure islice ValueError.
if BATCH_SIZE <= 0:
    raise ValueError(f'BATCH_SIZE must be a positive integer, got {BATCH_SIZE}')

# Start from an empty staging directory so stale batches from a previous run
# can never be merged into this one.
batch_dir = os.path.join(SAVE_PATH, 'batches')
if os.path.exists(batch_dir):
    shutil.rmtree(batch_dir)
os.makedirs(batch_dir, exist_ok=True)

gen = entry_generator()
total_count = 0
batch_idx = 0
# MAX_ENTRY <= 0 means "no limit".
limit = MAX_ENTRY if MAX_ENTRY > 0 else float('inf')

while total_count < limit:
    # Never request more entries than remain under the MAX_ENTRY cap.
    remaining = int(min(BATCH_SIZE, limit - total_count))
    batch = list(islice(gen, remaining))

    if not batch:
        break  # generator exhausted

    total_count += len(batch)
    batch_path = os.path.join(batch_dir, f'batch_{batch_idx}')

    ds_batch = Dataset.from_list(batch, features=features)
    ds_batch.save_to_disk(batch_path)

    print(f'  Batch {batch_idx}: {len(batch)}건 저장 → {batch_path}')

    # Drop references to the decoded images before building the next batch.
    del ds_batch, batch
    gc.collect()

    batch_idx += 1

# If MAX_ENTRY stopped the loop early, the generator is still suspended and
# its frame keeps the last entry's decoded images alive; close it explicitly.
gen.close()

print(f'\n총 {total_count}건, {batch_idx}개 배치 저장 완료')

# --- Merge batches ---
print('\n배치 병합 중...')
merged_path = os.path.join(SAVE_PATH, 'merged')

if batch_idx == 0:
    # concatenate_datasets([]) fails with an unhelpful error; say what happened.
    raise RuntimeError(
        'No entries were generated, so there is nothing to merge. '
        'Check DATASET_ROOT, CATEGORIES and the [SKIP] messages above.'
    )

batch_datasets = [
    load_from_disk(os.path.join(batch_dir, f'batch_{i}'))
    for i in range(batch_idx)
]
merged = concatenate_datasets(batch_datasets)
merged.save_to_disk(merged_path)
del merged, batch_datasets

# The concatenated dataset is backed by the batch Arrow files, so reload from
# the merged copy *before* deleting the batch folders. Otherwise the report and
# preview below would read from files that have just been removed.
ds = load_from_disk(merged_path)

# Clean up the staging folders.
shutil.rmtree(batch_dir)

print(f'병합 완료: {ds}')
print(f'저장 경로: {merged_path}')

# --- Report ---
# Count class1 labels by reading only the 'class1' column. Indexing ds[i]
# returns the whole row, which decodes all of its images, and the old code
# did that three times per row.
class1_column = list(ds['class1'])
print(f'\n  - abnormal: {class1_column.count("abnormal")}')
print(f'  - suspicious: {class1_column.count("suspicious")}')
print(f'  - normal: {class1_column.count("normal")}')

# --- Preview ---
# For each of the first PREVIEW_COUNT entries, print its labels and file names,
# then show all of its images side by side in one horizontal strip.
preview_n = min(PREVIEW_COUNT, len(ds))
print(f'\n--- Preview ({preview_n} entries) ---')

for idx in range(preview_n):
    row = ds[idx]
    print(f'\n[Entry {idx}] class1={row["class1"]}, class2={row["class2"]}')
    print(f'  image_names: {row["image_names"]}')

    frames = row['images']
    strip_width = sum(frame.width for frame in frames)
    strip_height = max(frame.height for frame in frames)
    # Shorter images are top-aligned; any uncovered area stays black.
    strip = PILImage.new('RGB', (strip_width, strip_height))

    left = 0
    for frame in frames:
        strip.paste(frame, (left, 0))
        left += frame.width

    display(strip)