# In [16]:
import os
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class TripletDataset(Dataset):
    """Dataset yielding random (anchor, positive, negative) image triplets.

    ``root_dir`` must contain one sub-directory per class, each holding that
    class's image files. Anchor and positive come from the same class, the
    negative from a different one.

    Sampling is random on every access: the ``idx`` passed to
    ``__getitem__`` is ignored, and ``__len__`` (the total number of indexed
    images) only determines how many triplets make up one epoch.

    Args:
        root_dir: Directory whose immediate sub-directories are the classes.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform

        # Map each class name to the paths of its image files.
        # - Listings are sorted because os.listdir() order is arbitrary.
        # - Non-file entries (nested folders) are skipped so they are never
        #   handed to Image.open().
        # - Classes without any file are dropped so random.choice() never
        #   sees an empty list.
        self.class_to_images = {}
        for cls in sorted(os.listdir(root_dir)):
            cls_dir = os.path.join(root_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            paths = [
                os.path.join(cls_dir, name)
                for name in sorted(os.listdir(cls_dir))
            ]
            paths = [path for path in paths if os.path.isfile(path)]
            if paths:
                self.class_to_images[cls] = paths

        self.classes = list(self.class_to_images)

    def __len__(self):
        """Return the total number of indexed images across all classes."""
        return sum(len(images) for images in self.class_to_images.values())

    def _sample_triplet_paths(self):
        """Pick (anchor, positive, negative) file paths at random.

        Raises:
            ValueError: If fewer than two non-empty classes are available,
                since no negative sample can be drawn.
        """
        if len(self.classes) < 2:
            raise ValueError(
                "TripletDataset needs at least two non-empty classes in "
                f"{self.root_dir!r}, found {len(self.classes)}"
            )

        anchor_class = random.choice(self.classes)
        same_class = self.class_to_images[anchor_class]

        anchor_pos = random.randrange(len(same_class))
        positive_pos = anchor_pos
        if len(same_class) > 1:
            # Draw from the other len-1 slots and shift past the anchor so
            # the positive is never the very same file as the anchor.
            positive_pos = random.randrange(len(same_class) - 1)
            if positive_pos >= anchor_pos:
                positive_pos += 1

        # The negative must come from any class other than the anchor's.
        negative_class = random.choice(
            [cls for cls in self.classes if cls != anchor_class]
        )
        negative_path = random.choice(self.class_to_images[negative_class])

        return same_class[anchor_pos], same_class[positive_pos], negative_path

    def _load_image(self, path):
        """Load ``path`` as an RGB image and apply the optional transform."""
        # convert() returns a fully loaded copy, so the file handle can be
        # closed right away; forcing RGB keeps grayscale/RGBA files
        # compatible with 3-channel transforms.
        with Image.open(path) as raw:
            image = raw.convert("RGB")

        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        """Return a randomly sampled ``(anchor, positive, negative)`` triplet.

        ``idx`` is accepted for the Dataset protocol but does not influence
        which images are returned.
        """
        anchor_path, positive_path, negative_path = self._sample_triplet_paths()
        return (
            self._load_image(anchor_path),
            self._load_image(positive_path),
            self._load_image(negative_path),
        )

# In [17]:
# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((160, 160)),  # match the model's input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize each channel
])

# Dataset locations
train_dataset_path = r"E:\data\train"
valid_dataset_path = r"E:\data\valid"
test_dataset_path = r"E:\data\test"

# Load the datasets
train_dataset = TripletDataset(root_dir=train_dataset_path, transform=transform)
valid_dataset = TripletDataset(root_dir=valid_dataset_path, transform=transform)
test_dataset = TripletDataset(root_dir=test_dataset_path, transform=transform)

# Number of DataLoader worker processes.
# - os.cpu_count() may return None, which DataLoader rejects, so fall back to 0.
# - On Windows, workers are spawned as fresh interpreters that cannot load a
#   dataset class defined in a notebook cell or in a script without an
#   ``if __name__ == "__main__":`` guard, so load in the main process there.
num_workers = 0 if os.name == "nt" else (os.cpu_count() or 0)

# Configure the data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=num_workers)