In [None]:
import os
import pandas as pd
import shutil
from sklearn.model_selection import train_test_split

# Machine-specific location of the dataset: this is the one line to edit
# when running elsewhere. Every path below is derived from it.
data_root = '/data/youlee/data_library/oxford_102_flowers'

# Inputs: the caption table and the directory holding the source images.
csv_path = os.path.join(data_root, 'image_caption.csv')
images_src_dir = os.path.join(data_root, 'images')

# Outputs: images and captions are written to parallel train/test trees.
train_img_dir = os.path.join(data_root, 'image', 'train')
test_img_dir = os.path.join(data_root, 'image', 'test')
train_caption_dir = os.path.join(data_root, 'caption', 'train')
test_caption_dir = os.path.join(data_root, 'caption', 'test')

# Six groups of four class labels each; the keys become folder names.
groups = {
    "group1": ["class_00051", "class_00072", "class_00082", "class_00060"],
    "group2": ["class_00077", "class_00065", "class_00058", "class_00056"],
    "group3": ["class_00046", "class_00074", "class_00037", "class_00080"],
    "group4": ["class_00073", "class_00041", "class_00083", "class_00095"],
    "group5": ["class_00089", "class_00081", "class_00075", "class_00076"],
    "group6": ["class_00094", "class_00088", "class_00078", "class_00043"]
}


# Caption table; the code in this cell reads its 'label', 'image' and
# 'caption' columns.
df = pd.read_csv(csv_path)

# Fixed random seed so that the train/test split is reproducible.
seed = 42

def _export_split(split_df, src_dir, img_dest_dir, caption_dest_dir):
    """Copy one split's images and write one caption text file per row.

    Args:
        split_df: DataFrame with an 'image' column (file name inside
            ``src_dir``, with or without an extension) and a 'caption'
            column (text written to the caption file).
        src_dir: Directory holding the source images.
        img_dest_dir: Directory the images are copied into (created if
            missing).
        caption_dest_dir: Directory receiving ``<image stem>.txt`` caption
            files (created if missing).
    """
    # exist_ok=True reuses an existing directory without clearing it, so
    # files left by an earlier run with a different split would remain.
    os.makedirs(img_dest_dir, exist_ok=True)
    os.makedirs(caption_dest_dir, exist_ok=True)

    for img_file, caption in zip(split_df['image'], split_df['caption']):
        base, ext = os.path.splitext(img_file)
        # Entries listed without an extension are treated as JPEG files.
        if ext == "":
            img_file = base + ".jpg"

        shutil.copy(
            os.path.join(src_dir, img_file),
            os.path.join(img_dest_dir, img_file),
        )

        # The caption file shares the image's stem: <stem>.txt
        caption_file_path = os.path.join(caption_dest_dir, base + '.txt')
        with open(caption_file_path, 'w', encoding='utf-8') as f:
            f.write(caption)


# Split every selected class 80/20 on its own and export both parts to
# <split root>/<group>/<class>/ for images and captions alike.
for group, classes in groups.items():
    for cls in classes:
        df_class = df[df['label'] == cls]

        # train_test_split cannot produce two non-empty parts from fewer
        # than two rows; fail with a message that names the class.
        if len(df_class) < 2:
            raise ValueError(
                f"Class {cls!r} ({group}) has {len(df_class)} row(s) in the "
                f"caption table; at least 2 are needed for a train/test split."
            )

        train_df, test_df = train_test_split(df_class, test_size=0.2, random_state=seed)

        _export_split(
            train_df,
            images_src_dir,
            os.path.join(train_img_dir, group, cls),
            os.path.join(train_caption_dir, group, cls),
        )
        _export_split(
            test_df,
            images_src_dir,
            os.path.join(test_img_dir, group, cls),
            os.path.join(test_caption_dir, group, cls),
        )


In [None]:
import os

# Sanity check: report how many files sit under each group folder of the
# four output trees. The paths are repeated here so that this cell can be
# run on its own; keep data_root in sync with the export cell.
data_root = '/data/youlee/data_library/oxford_102_flowers'
train_img_dir = os.path.join(data_root, 'image', 'train')
test_img_dir = os.path.join(data_root, 'image', 'test')
train_caption_dir = os.path.join(data_root, 'caption', 'train')
test_caption_dir = os.path.join(data_root, 'caption', 'test')

dirs = {
    "Train Images": train_img_dir,
    "Test Images": test_img_dir,
    "Train Captions": train_caption_dir,
    "Test Captions": test_caption_dir,
}

for dir_label, dir_path in dirs.items():
    print(f"Directory: {dir_label}")

    if not os.path.exists(dir_path):
        print(f"  {dir_path} does not exist!")
        continue

    # Named group_names / group_name (not groups / group) so the dict and
    # loop variable defined by the export cell are not overwritten.
    group_names = sorted(
        entry for entry in os.listdir(dir_path)
        if os.path.isdir(os.path.join(dir_path, entry))
    )
    for group_name in group_names:
        # Count files at every depth below the group folder.
        file_count = sum(
            len(files)
            for _, _, files in os.walk(os.path.join(dir_path, group_name))
        )
        print(f"  {group_name}: {file_count} files")
    print()


Directory: Train Images
  group1: 458 files
  group2: 459 files
  group3: 462 files
  group4: 462 files
  group5: 460 files
  group6: 465 files

Directory: Test Images
  group1: 117 files
  group2: 117 files
  group3: 118 files
  group4: 118 files
  group5: 117 files
  group6: 118 files

Directory: Train Captions
  group1: 458 files
  group2: 459 files
  group3: 462 files
  group4: 462 files
  group5: 460 files
  group6: 465 files

Directory: Test Captions
  group1: 117 files
  group2: 117 files
  group3: 118 files
  group4: 118 files
  group5: 117 files
  group6: 118 files

