In [106]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image

- 폴더마다 들어 있는 이미지 파일 이름 목록과 이미지 디렉터리 경로, csv 파일 경로를 정의하고, csv 파일을 DataFrame으로 불러온다.

In [94]:
# File names looked up inside each person's image folder; the list position
# is what the dataset cell below indexes into.
img_name = (['normal.jpg']
            + [f'mask{i}.jpg' for i in range(1, 6)]
            + ['incorrect_mask.jpg'])

# Locations of the metadata CSV and of the root directory holding the
# per-person image folders.
csv_path = '/opt/ml/input/data/train/train.csv'
data_path = '/opt/ml/input/data/train/images'

# One row per person; loaded here for interactive inspection.
mask_image_frame = pd.read_csv(filepath_or_buffer=csv_path)

In [17]:
# Preview the first five metadata rows (5 is the pandas default for head()).
mask_image_frame.head(n=5)

Unnamed: 0,id,gender,race,age,path
0,1,female,Asian,45,000001_female_Asian_45
1,2,female,Asian,52,000002_female_Asian_52
2,4,male,Asian,54,000004_male_Asian_54
3,5,female,Asian,58,000005_female_Asian_58
4,6,female,Asian,59,000006_female_Asian_59


- 데이터셋의 정의

In [153]:
class MaskImageDataset(Dataset):
    """Map-style dataset yielding one sample per (person, image file) pair.

    Every row of the CSV describes one person; its ``path`` column names the
    folder (relative to ``data_path``) that holds that person's images. The
    flat dataset index enumerates all image files of person 0 first, then all
    image files of person 1, and so on.
    """

    # Image files looked up inside every person's folder. The position of a
    # file name in this tuple is returned as the sample's 'status' label.
    IMAGE_NAMES = ('normal.jpg', 'mask1.jpg', 'mask2.jpg',
                   'mask3.jpg', 'mask4.jpg', 'mask5.jpg', 'incorrect_mask.jpg')

    def __init__(self, csv_file, data_path, transform=None, img_names=None):
        """
        Args:
            csv_file (str or file-like): CSV with at least the columns
                'path', 'gender' and 'age' (anything ``pd.read_csv`` accepts).
            data_path (str): root directory containing the per-person folders.
            transform (callable, optional): applied to the PIL image of each
                sample before it is returned.
            img_names (sequence of str, optional): image file names to read
                from each person's folder. Defaults to ``IMAGE_NAMES``.

        Raises:
            ValueError: if ``img_names`` is given but empty.
        """
        self.mask_image_frame = pd.read_csv(csv_file)
        self.data_path = data_path
        self.transform = transform
        # Keep a private copy so the dataset does not depend on (or get
        # changed through) any notebook-level global list.
        self.img_names = list(self.IMAGE_NAMES if img_names is None else img_names)
        if not self.img_names:
            raise ValueError('img_names must contain at least one file name')

    def __len__(self):
        """Return the number of samples: persons times images per person."""
        return len(self.mask_image_frame) * len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the flat sample index ``idx``.

        ``label`` is a dict with keys 'status' (position of the image file in
        ``img_names``), 'gender' and 'age' (taken from the person's CSV row).

        Raises:
            IndexError: if ``idx`` addresses a person row that does not exist.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Split the flat index in one step. The person row and the image
        # position must both come from the ORIGINAL idx; floor-dividing idx
        # first and taking the remainder afterwards would derive the image
        # position from the person index instead.
        person_idx, img_idx = divmod(idx, len(self.img_names))

        # Positional lookup: correct regardless of the frame's index labels.
        frame = self.mask_image_frame
        img_dir = os.path.join(self.data_path, frame['path'].iloc[person_idx])

        # The context manager releases the file handle; convert() loads the
        # pixel data and yields a 3-channel image, which per-channel
        # transforms with three mean/std values presumably expect.
        with Image.open(os.path.join(img_dir, self.img_names[img_idx])) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        label = {
            'status': img_idx,
            'gender': frame['gender'].iloc[person_idx],
            'age': frame['age'].iloc[person_idx],
        }

        return image, label

In [159]:
# Preprocessing applied to every image: resize to 384x384, convert to a
# tensor, then normalize each of the three channels with the given mean/std.
data_transform = transforms.Compose([
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Build the dataset from the CSV metadata and image root defined earlier.
dataset = MaskImageDataset(csv_file=csv_path,
                           data_path=data_path,
                           transform=data_transform)

# Shuffled mini-batches of 10 samples, loaded by 8 worker processes.
dataset_loader = DataLoader(
    dataset=dataset,
    batch_size=10,
    shuffle=True,
    num_workers=8,
)