In [6]:
import os
import shutil
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [7]:
class CASIAWebFaceDataset(Dataset):
    """Map-style dataset over a list of ``(image_path, label)`` pairs.

    Each item is loaded from disk on access, converted to RGB and, if a
    transform was supplied, passed through it.

    Args:
        file_list: Sequence of ``(image_path, label)`` tuples.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        """Return the number of ``(image_path, label)`` entries."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the image at position ``idx``.

        Returns:
            ``(image, label)`` where ``image`` is the RGB image (after the
            optional transform) and ``label`` is returned unchanged from
            ``file_list``.
        """
        img_path, label = self.file_list[idx]

        # Image.open is lazy; the context manager guarantees the underlying
        # file is closed as soon as the pixel data has been copied by convert().
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # Explicit None check: a callable that defines __len__ (and is empty)
        # would otherwise be treated as "no transform".
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [8]:
def split_data(dataset_path, test_split=0.1, val_split=0.1, min_samples_per_class=10,
               random_state=None):
    """Split a class-per-folder dataset into train/val/test file lists.

    Every sub-directory of ``dataset_path`` is treated as one class and every
    regular file inside it as one sample. Entries of ``dataset_path`` that are
    not directories, and entries of a class folder that are not files, are
    ignored.

    Args:
        dataset_path: Root folder containing one sub-folder per class.
        test_split: Fraction of each class held out for the test set.
        val_split: Fraction of the *remaining* files (after the test split)
            held out for validation; with the defaults this is 0.1 * 0.9 of
            the class, not 0.1 of it.
        min_samples_per_class: Classes with fewer files than this are not
            split; all of their files go to the train list.
        random_state: Forwarded to ``train_test_split`` for both splits. Pass
            an int to get the same split on every run; ``None`` (default)
            keeps the split random.

    Returns:
        ``(train_files, val_files, test_files)``, each a list of
        ``(file_path, class_name)`` tuples.
    """
    train_files, val_files, test_files = [], [], []

    # Sorted listings make the result independent of the file system's
    # directory order, which a fixed random_state needs to be reproducible.
    for cls in sorted(os.listdir(dataset_path)):
        cls_folder = os.path.join(dataset_path, cls)
        if not os.path.isdir(cls_folder):
            # Stray files in the root (e.g. .DS_Store, list files) are not classes.
            continue

        with os.scandir(cls_folder) as entries:
            cls_files = sorted(entry.path for entry in entries if entry.is_file())

        if len(cls_files) < min_samples_per_class:
            # Too few samples to split meaningfully: keep them all for training.
            train_files += [(file, cls) for file in cls_files]
        else:
            cls_train_val, cls_test = train_test_split(
                cls_files, test_size=test_split, random_state=random_state)
            cls_train, cls_val = train_test_split(
                cls_train_val, test_size=val_split, random_state=random_state)

            train_files += [(file, cls) for file in cls_train]
            val_files += [(file, cls) for file in cls_val]
            test_files += [(file, cls) for file in cls_test]

    return train_files, val_files, test_files

In [9]:
dataset_path = 'dataset/CasiaAligned' # path to the dataset root

# Split the dataset into train / validation / test file lists
train_files, val_files, test_files = split_data(dataset_path)

In [10]:
# Number of samples in the train, validation and test lists, one per line
print(len(train_files), len(val_files), len(test_files), sep='\n')

389295
48567
53774


In [11]:
def save_split_data(train_files, val_files, test_files, dataset_path):
    """Copy the files of each split into a folder tree under ``dataset_path``.

    Layout written:
        ``<dataset_path>/train/<class>/<file name>``
        ``<dataset_path>/val/images/<class>_<file name>``
        ``<dataset_path>/test/images/<class>_<file name>``

    The val and test folders are flat, so the class name is prefixed onto the
    copied file name. Without it, files with the same base name coming from
    different class folders would overwrite each other and their label would
    be lost.

    Args:
        train_files: List of ``(file_path, class_name)`` tuples for training.
        val_files: List of ``(file_path, class_name)`` tuples for validation.
        test_files: List of ``(file_path, class_name)`` tuples for testing.
        dataset_path: Output root; created if it does not exist.
    """
    # Train split: one sub-folder per class.
    created_folders = set()
    for file_path, cls in train_files:
        dest_folder = os.path.join(dataset_path, 'train', cls)
        if dest_folder not in created_folders:
            # Create each class folder once instead of once per file.
            os.makedirs(dest_folder, exist_ok=True)
            created_folders.add(dest_folder)
        shutil.copy(file_path, dest_folder)

    # Val and test splits: a single flat 'images' folder each.
    for split_name, split_files in (('val', val_files), ('test', test_files)):
        dest_folder = os.path.join(dataset_path, split_name, 'images')
        os.makedirs(dest_folder, exist_ok=True)
        for file_path, cls in split_files:
            dest_name = '{}_{}'.format(cls, os.path.basename(file_path))
            shutil.copy(file_path, os.path.join(dest_folder, dest_name))

In [12]:
# Copy the split files into train/, val/ and test/ folders under 'dataset/CASIAWebFace'
save_split_data(train_files, val_files, test_files, 'dataset/CASIAWebFace')