In [1]:
import os
import shutil
import pandas as pd
from sklearn.model_selection import train_test_split

def split_dataset(meta_file, img_dir, output_dir, test_size=200, random_state=None):
    """Split an image dataset into class-balanced train and test sets.

    Reads the metadata CSV, holds out (as near as possible) the same number
    of rows from every class in the ``obj`` column, copies the files named in
    the ``image_names`` column into ``train_coco1600/images_cocogray`` and
    ``test_coco1600/images_cocogray`` under ``output_dir``, and writes one
    metadata CSV next to each image folder.

    Args:
        meta_file: Path to the metadata CSV. It must contain the columns
            ``obj`` (used as the class label) and ``image_names`` (file name
            relative to ``img_dir``).
        img_dir: Directory holding the source images.
        output_dir: Root directory under which the split is created.
        test_size: Total number of rows to hold out for testing. The total is
            divided evenly across classes; when it does not divide evenly,
            the first classes (in order of appearance) get one extra row each.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            split is reproducible. ``None`` gives a different split per run.

    Raises:
        ValueError: If the metadata has no rows, ``test_size`` is negative,
            or a class has fewer rows than its share of the test set.

    Note:
        The output metadata files are always named ``coco1400_meta.csv``
        (train) and ``coco200_meta.csv`` (test), whatever the actual sizes.
    """
    df = pd.read_csv(meta_file)

    # Validate before touching the filesystem so a bad call leaves no
    # half-created output folders behind.
    classes = df['obj'].unique()
    if len(classes) == 0:
        raise ValueError(f"No rows found in metadata file: {meta_file}")
    if test_size < 0:
        raise ValueError(f"test_size must be non-negative, got {test_size}")

    # Spread test_size over the classes; `extra` classes receive one more row.
    base_quota, extra = divmod(test_size, len(classes))

    # Shuffle once up front. Seeding a separate sample() per class with the
    # same seed would pick the same positions inside every class.
    shuffled = df.sample(frac=1, random_state=random_state)

    train_parts = []
    test_parts = []
    for position, cls in enumerate(classes):
        class_rows = shuffled[shuffled['obj'] == cls]
        quota = base_quota + (1 if position < extra else 0)
        if quota > len(class_rows):
            raise ValueError(
                f"Class {cls!r} has {len(class_rows)} rows but {quota} "
                f"are required for the test set"
            )
        test_parts.append(class_rows.iloc[:quota])
        # Restore source order for the remaining rows of this class.
        train_parts.append(class_rows.iloc[quota:].sort_index())

    train_df = pd.concat(train_parts)
    test_df = pd.concat(test_parts)

    # makedirs creates the intermediate train/test directories as well.
    train_img_dir = os.path.join(output_dir, 'train_coco1600', 'images_cocogray')
    test_img_dir = os.path.join(output_dir, 'test_coco1600', 'images_cocogray')
    os.makedirs(train_img_dir, exist_ok=True)
    os.makedirs(test_img_dir, exist_ok=True)

    def process_set(data, img_src_dir, img_dest_dir, meta_file_name):
        """Copy the images listed in `data` and write its metadata CSV."""
        for img_name in data['image_names']:
            shutil.copy2(
                os.path.join(img_src_dir, img_name),
                os.path.join(img_dest_dir, img_name),
            )

        # The metadata file lives beside the image folder, not inside it.
        meta_path = os.path.join(os.path.dirname(img_dest_dir), meta_file_name)
        data.to_csv(meta_path, index=False)

    process_set(train_df, img_dir, train_img_dir, 'coco1400_meta.csv')
    process_set(test_df, img_dir, test_img_dir, 'coco200_meta.csv')

    print(f"Dataset split complete. Train set: {len(train_df)}, Test set: {len(test_df)}")

# Example run. Paths are relative to the notebook's working directory:
# the metadata CSV, the folder of source images, and the root folder
# that will receive the train/test split.
meta_file = 'data/coco1600/coco1600_meta.csv'
img_dir = 'data/coco1600/images_cocogray'
output_dir = 'data/split_dataset'

split_dataset(
    meta_file=meta_file,
    img_dir=img_dir,
    output_dir=output_dir,
)

Dataset split complete. Train set: 1400, Test set: 200
