In [51]:
import json
import os
import re
import shutil

from PIL import Image
from torchvision import transforms

In [None]:
# Emotion classes to process. Each one names a sub-folder under the image roots
# and a label file train_crop_<emotion>.json under the label root.
emotions = ['anger', 'happy', 'sadness', 'panic']  # example emotions

# `__file__` is not defined inside a Jupyter kernel, so fall back to the current
# working directory (typically the notebook's own folder) in that case.
try:
    base_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    base_dir = os.getcwd()
input_img_root = os.path.join(base_dir, 'CropData2', 'img', 'train')
input_json_root = os.path.join(base_dir, 'CropData2', 'label', 'train')
output_img_root = os.path.join(base_dir, 'augment')

In [53]:
# Random augmentation pipeline applied to a PIL image. Each step draws fresh
# random parameters on every call, so repeated calls on the same image give
# different results.
augment_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.3,
            contrast=0.3,
            saturation=0.2,
            hue=0.04,
        ),
        transforms.RandomRotation(degrees=10),
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2.0)),
    ]
)

In [54]:
# Safety net ("just in case"): turns a raw age into the decade label used in the
# augment_counts keys.
def age_to_group(age):
    """Map a numeric age to a Korean decade label.

    Ages in [10, 70) return the label of the decade they fall in
    ('10대', '20대', ..., '60대'). Every other value, including NaN, returns
    '기타' ("other").
    """
    if not (10 <= age < 70):
        return '기타'
    decade = int(age // 10) * 10
    return f'{decade}대'

In [None]:
# Number of extra augmented copies to generate per image, keyed by
# '<age group>_<gender>' (age group as produced by age_to_group). Keys missing
# here (e.g. '60대_남', any '기타_*' or 'None_*' key) yield 0 copies because the
# loop looks them up with .get(key, 0).
# NOTE(review): presumably '여'/'남' (female/male) match the JSON 'gender' values,
# and the larger counts rebalance under-represented groups -- confirm.
augment_counts = dict([
    ('20대_여', 0),
    ('20대_남', 0),
    ('30대_여', 1),
    ('30대_남', 1),
    ('40대_여', 2),
    ('40대_남', 2),
    ('50대_여', 3),
    ('50대_남', 5),
    ('10대_여', 4),
    ('10대_남', 7),
    ('60대_여', 0),
])

In [56]:
def strip_ext(filename):
    """Return *filename* without its final extension ('a.b.jpg' -> 'a.b')."""
    stem, _extension = os.path.splitext(filename)
    return stem

In [57]:
def load_metadata_from_json(json_path):
    """Index a label file's records by file stem.

    The JSON top level is iterated as a sequence of objects, each of which must
    have a 'filename' key. Returns a dict mapping strip_ext(filename) to the
    whole record. If several records share a stem, the last one wins.
    """
    with open(json_path, encoding='utf-8') as fp:
        records = json.load(fp)
    by_stem = {}
    for record in records:
        by_stem[strip_ext(record['filename'])] = record
    return by_stem


In [58]:
# For each emotion: copy every training image into augment/<emotion>/ and add
# `count` randomly augmented JPEG variants. `count` comes from augment_counts
# for the image's '<age group>_<gender>' key (0 when the key is absent).
for emotion in emotions:
    print(f'Processing emotion: {emotion}')
    img_dir = os.path.join(input_img_root, emotion)
    json_path = os.path.join(input_json_root, f'train_crop_{emotion}.json')
    save_dir = os.path.join(output_img_root, emotion)
    os.makedirs(save_dir, exist_ok=True)

    metadata = load_metadata_from_json(json_path)

    # os.listdir order is arbitrary; sort it so runs are deterministic.
    for fname in sorted(os.listdir(img_dir)):
        if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue

        base_fname = strip_ext(fname)
        img_path = os.path.join(img_dir, fname)

        # Keep the original byte-for-byte. Re-saving through PIL would re-encode
        # JPEGs lossily (Pillow's default quality is 75).
        shutil.copy2(img_path, os.path.join(save_dir, fname))

        info = metadata.get(base_fname)
        if info is None:
            print(f'Warning: {fname} 메타데이터 없음 → 원본만 저장.')
            continue

        age = info.get('age')
        try:
            age_group = age_to_group(int(age)) if age is not None else None
        except (TypeError, ValueError):
            # One malformed label should not abort the whole run. Treat it like
            # a missing age: no augment_counts key matches, so 0 copies.
            print(f'Warning: {fname} has invalid age {age!r}; original only.')
            age_group = None

        gender = info.get('gender')
        count = augment_counts.get(f'{age_group}_{gender}', 0)
        if count <= 0:
            continue

        # `with` releases the file handle as soon as this image is done.
        with Image.open(img_path) as img:
            # JPEG cannot store alpha/palette modes (RGBA, P, ...), so convert
            # anything other than RGB/L before augmenting and saving as .jpg.
            src = img if img.mode in ('RGB', 'L') else img.convert('RGB')
            for i in range(1, count + 1):
                aug_img = augment_transform(src)
                aug_img.save(os.path.join(save_dir, f'{base_fname}_aug{i}.jpg'))


Processing emotion: anger
Processing emotion: happy
Processing emotion: sadness
Processing emotion: panic


In [59]:
# Reload one emotion's label file for inspection. Build the path explicitly
# (the last emotion, the same file the loop processed last) instead of relying
# on `json_path` leaking out of the augmentation loop.
inspect_emotion = emotions[-1]
metadata = load_metadata_from_json(
    os.path.join(input_json_root, f'train_crop_{inspect_emotion}.json')
)