In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
import tqdm
import tqdm.auto
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

- 사람별 폴더 안의 이미지 파일 이름 목록, 이미지 폴더 경로, CSV 파일 경로를 정의하고 학습용(train.csv)·평가용(info.csv) CSV를 불러온다.

In [2]:
# Photo file stems found in every training person folder; the position in
# this list serves as the mask-status label.
img_name = ['normal', 'mask1', 'mask2',
            'mask3', 'mask4', 'mask5', 'incorrect_mask']

# Root of the competition data. Defaults to the original location; set the
# MASK_DATA_DIR environment variable to run the notebook somewhere else.
DATA_ROOT = os.environ.get('MASK_DATA_DIR', '/opt/ml/input/data')

# Training split: metadata CSV + directory of per-person image folders.
csv_path = os.path.join(DATA_ROOT, 'train', 'train.csv')
data_path = os.path.join(DATA_ROOT, 'train', 'images')
mask_image_frame = pd.read_csv(csv_path)

# Eval split: info.csv lists the image files + directory holding them.
test_csv_path = os.path.join(DATA_ROOT, 'eval', 'info.csv')
test_data_path = os.path.join(DATA_ROOT, 'eval', 'images')
eval_image_frame = pd.read_csv(test_csv_path)

In [3]:
# Training metadata preview: id, gender, race, age and the image-folder name (`path`).
mask_image_frame.head(n=5)

Unnamed: 0,id,gender,race,age,path
0,1,female,Asian,45,000001_female_Asian_45
1,2,female,Asian,52,000002_female_Asian_52
2,4,male,Asian,54,000004_male_Asian_54
3,5,female,Asian,58,000005_female_Asian_58
4,6,female,Asian,59,000006_female_Asian_59


In [4]:
# Eval metadata preview: image file name (`ImageID`) and an `ans` column.
eval_image_frame.head(n=5)

Unnamed: 0,ImageID,ans
0,cbc5c6e168e63498590db46022617123f1fe1268.jpg,0
1,0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,0
2,b549040c49190cedc41327748aeb197c1670f14d.jpg,0
3,4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,0
4,248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,0


- 데이터셋 정의: 학습용 `MaskImageDataset`(CSV 한 행 = 한 사람, 이미지 7장)과 평가용 `ValidationSet`(CSV 한 행 = 이미지 1장)

In [5]:
class MaskImageDataset(Dataset):
    """Training set in which every CSV row expands into several photos.

    Each row of ``csv_file`` names a folder (``path`` column) under
    ``data_path`` that holds one photo per entry of ``IMG_NAMES``. The flat
    dataset index ``idx`` maps to CSV row ``idx // len(IMG_NAMES)`` and photo
    ``idx % len(IMG_NAMES)``; the photo position is returned as the
    ``status`` label.
    """

    # Photo file stems (no extension) expected in every person folder; the
    # position in this tuple is the ``status`` label.
    IMG_NAMES = ('normal', 'mask1', 'mask2',
                 'mask3', 'mask4', 'mask5', 'incorrect_mask')
    # Extensions tried, in this order, when looking a photo up on disk.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, csv_file, data_path, transform=None):
        """
        Args:
            csv_file (str): path to the training CSV (must have ``path``,
                ``gender`` and ``age`` columns).
            data_path (str): directory containing the per-person folders.
            transform (callable, optional): applied to each PIL image.
        """
        self.mask_image_frame = pd.read_csv(csv_file)
        self.data_path = data_path
        self.transform = transform

    def __len__(self):
        return len(self.mask_image_frame) * len(self.IMG_NAMES)

    def _locate(self, idx):
        """Split a flat dataset index into ``(csv_row, photo_position)``."""
        return divmod(idx, len(self.IMG_NAMES))

    def _find_image_file(self, person_dir, stem):
        """Return the first existing ``person_dir/<stem><ext>``; raise if none exists."""
        for ext in self.IMG_EXTENSIONS:
            candidate = os.path.join(person_dir, stem + ext)
            if os.path.isfile(candidate):
                return candidate
        # Previously this fell through and failed later with UnboundLocalError.
        raise FileNotFoundError(
            "no '{}' image with extension {} in {}".format(
                stem, self.IMG_EXTENSIONS, person_dir))

    def __getitem__(self, idx):
        """Return ``(image, label)``; label is a dict with status/gender/age.

        Raises:
            FileNotFoundError: if the photo is missing for every known extension.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # The photo position must come from the *flat* index, before it is
        # reduced to a CSV row; taking ``% 7`` afterwards mislabels images.
        row, photo = self._locate(idx)
        person_dir = os.path.join(self.data_path,
                                  self.mask_image_frame.loc[row, 'path'])
        img_file = self._find_image_file(person_dir, self.IMG_NAMES[photo])

        # Decode inside ``with`` so the file handle is released immediately
        # (many DataLoader workers), and force 3 channels so a 3-channel
        # Normalize also works for grayscale/RGBA files.
        with Image.open(img_file) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        label = {'status': photo,
                 'gender': self.mask_image_frame.loc[row, 'gender'],
                 'age': self.mask_image_frame.loc[row, 'age']}
        return image, label

In [6]:
class ValidationSet(Dataset):
    """Eval-split images listed in a CSV, one image per row.

    ``__getitem__`` returns the CSV row index instead of a label, presumably
    so predictions can be written back to the matching row of ``info.csv``.
    """

    def __init__(self, csv_file, data_path, transform=None):
        """
        Args:
            csv_file (str): path to the CSV listing file names in ``ImageID``.
            data_path (str): directory containing the image files.
            transform (callable, optional): applied to each PIL image.
        """
        self.mask_image_frame = pd.read_csv(csv_file)
        self.data_path = data_path
        self.transform = transform

    def __len__(self):
        return len(self.mask_image_frame)

    def __getitem__(self, idx):
        """Return ``(image, idx)`` for the image in CSV row ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = os.path.join(self.data_path,
                                self.mask_image_frame.loc[idx, 'ImageID'])
        # Decode inside ``with`` so the file handle is closed right away, and
        # force 3 channels so a 3-channel Normalize also works for
        # grayscale/RGBA files.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, idx

- torchvision의 `transforms`로 전처리(Resize → ToTensor → Normalize)를 정의하고 DataLoader를 만든 뒤, 모든 배치를 한 번씩 읽어 이미지 로딩에 문제가 없는지 확인한다. (augmentation은 아직 적용하지 않았다.)

In [9]:
# Preprocessing shared by train and eval: fixed 384x384 resize, tensor
# conversion, and per-channel normalization with the ImageNet mean/std.
# No augmentation is applied yet.
data_transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

BATCH_SIZE = 10
NUM_WORKERS = 16


def make_loader(ds, shuffle):
    """Wrap ``ds`` in a DataLoader with the notebook-wide batch/worker settings."""
    return DataLoader(ds, shuffle=shuffle,
                      batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)


def check_loader(loader, desc):
    """Read every batch once so a missing or unreadable image fails here, early."""
    for _ in tqdm.auto.tqdm(loader, desc=desc):
        pass


train_dataset = MaskImageDataset(csv_path, data_path, data_transform)
train_loader = make_loader(train_dataset, shuffle=True)
check_loader(train_loader, 'train')

# Eval/inference data: keep the CSV order so batches line up with info.csv
# rows (ValidationSet also returns each sample's row index).
eval_dataset = ValidationSet(test_csv_path, test_data_path, data_transform)
eval_loader = make_loader(eval_dataset, shuffle=False)
check_loader(eval_loader, 'eval')

# This cell used to leave `dataset` / `data_loader` bound to the eval set;
# keep those names for any later cell that still refers to them.
dataset, data_loader = eval_dataset, eval_loader


(tensor([[[ 1.2899,  1.2899,  1.2899,  ...,  1.3755,  1.3755,  1.3927],
         [ 1.2899,  1.2899,  1.2899,  ...,  1.3755,  1.3755,  1.3927],
         [ 1.2899,  1.2899,  1.2899,  ...,  1.3755,  1.3755,  1.3927],
         ...,
         [ 0.2111,  0.1768,  0.1254,  ...,  0.3138,  0.1939, -0.0629],
         [ 0.2111,  0.1768,  0.1597,  ...,  0.2967,  0.3138,  0.3138],
         [ 0.2282,  0.1939,  0.1597,  ...,  0.2453,  0.2967,  0.3994]],

        [[ 1.4307,  1.4307,  1.4307,  ...,  1.5182,  1.5182,  1.5357],
         [ 1.4307,  1.4307,  1.4307,  ...,  1.5182,  1.5182,  1.5357],
         [ 1.4307,  1.4307,  1.4307,  ...,  1.5182,  1.5182,  1.5357],
         ...,
         [ 0.2752,  0.2402,  0.2402,  ..., -0.4426, -0.5826, -0.8102],
         [ 0.2927,  0.2402,  0.2752,  ..., -0.5126, -0.4776, -0.4776],
         [ 0.3102,  0.2577,  0.2752,  ..., -0.6001, -0.5301, -0.4251]],

        [[ 1.5768,  1.5768,  1.5768,  ...,  1.6988,  1.6988,  1.7163],
         [ 1.5768,  1.5768,  1.5768,  ...,  

  0%|          | 0/1890 [00:00<?, ?it/s]

  0%|          | 0/1260 [00:00<?, ?it/s]