In [1]:
import os
import shutil
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [2]:
class CASIAWebFaceDataset(Dataset):
    """Map-style dataset over ``(image_path, label)`` pairs.

    Args:
        file_list: Indexable, sized collection of ``(img_path, label)``
            tuples.
        transform: Optional callable applied to the loaded RGB image
            before it is returned.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        """Return the number of ``(img_path, label)`` pairs."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``.

        The image is opened from disk, converted to RGB and, when a
        transform was supplied, passed through it. The label is returned
        exactly as stored in ``file_list``.
        """
        img_path, label = self.file_list[idx]

        # Open inside a context manager so the underlying file handle is
        # released deterministically instead of whenever the temporary
        # image object happens to be garbage collected. convert() returns
        # a new image object, which is what we keep after the file closes.
        with Image.open(img_path) as source:
            image = source.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:
def get_all_data(dataset_path):
    """Collect ``(file_path, class_name)`` pairs from a class-per-folder tree.

    Every immediate sub-directory of ``dataset_path`` is treated as one
    class, and every regular file directly inside it as one sample of that
    class.

    Args:
        dataset_path: Root directory that contains one folder per class.

    Returns:
        list[tuple[str, str]]: ``(file_path, class_name)`` pairs, ordered by
        class name and then by file name so repeated runs give the same
        list. ``file_path`` is ``dataset_path`` joined with the class and
        file names.

    Raises:
        OSError: If ``dataset_path`` (or a class folder) cannot be listed.
    """
    all_files = []

    # Directory listing order is arbitrary, so sort for reproducibility.
    # Non-directory entries in the root (e.g. stray metadata files) are
    # skipped instead of crashing the per-class listing below.
    with os.scandir(dataset_path) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())

    for cls in classes:
        cls_folder = os.path.join(dataset_path, cls)
        # Keep regular files only; a nested directory is not a sample.
        with os.scandir(cls_folder) as entries:
            file_names = sorted(entry.name for entry in entries if entry.is_file())
        all_files.extend((os.path.join(cls_folder, name), cls) for name in file_names)

    return all_files

In [4]:
dataset_path = 'dataset/CasiaAligned'  # Path to the dataset root

# Fetch the data: one (file_path, class_name) pair per file
all_files = get_all_data(dataset_path)

In [5]:
# Show how many (file_path, class_name) pairs were collected
print(len(all_files))

491636


In [6]:
def save_all_data(all_files, dataset_path, split='train'):
    """Copy every file into ``<dataset_path>/<split>/<class>/``.

    Args:
        all_files: Iterable of ``(file_path, class_name)`` pairs.
        dataset_path: Destination root directory.
        split: Name of the sub-folder under ``dataset_path`` that receives
            the per-class folders. Defaults to ``'train'``, i.e. the whole
            dataset is stored as training data.

    Returns:
        None. Destination folders are created as needed and each file is
        copied under its original base name.
    """
    # Remember which class folders already exist so makedirs runs once per
    # class instead of once per file.
    prepared_folders = set()

    for file_path, cls in all_files:
        dest_folder = os.path.join(dataset_path, split, cls)
        if dest_folder not in prepared_folders:
            os.makedirs(dest_folder, exist_ok=True)
            prepared_folders.add(dest_folder)
        shutil.copy(file_path, dest_folder)

In [7]:
# Copy every collected file into dataset/face/train/<class>/
save_all_data(all_files, 'dataset/face/')