In [1]:
# Copy-pasted from transfer_learning.ipynb


import glob
import os
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms


# Seed every RNG the notebook uses (torch, numpy, stdlib random) from one
# constant so a Restart & Run All reproduces the same augmentations/shuffles.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


class ImageTransform:
    """Phase-dependent preprocessing for PIL images.

    The ``'train'`` pipeline augments the image (random resized crop followed by
    a random horizontal flip). The ``'val'`` pipeline is deterministic (resize
    followed by a center crop). Both pipelines finish by converting to a tensor
    and normalizing with ``mean`` / ``std``.

    Args:
        resize: Size passed to the crop/resize transforms.
        mean: Per-channel means for ``transforms.Normalize``.
        std: Per-channel standard deviations for ``transforms.Normalize``.
    """

    def __init__(self, resize, mean, std):
        def to_normalized_tensor():
            # Fresh ToTensor/Normalize instances for each pipeline; identical tail.
            return [transforms.ToTensor(), transforms.Normalize(mean, std)]

        augment = [
            # Random crop whose area is 50-100% of the original, resized to `resize`.
            transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
            # Mirror left-right with probability 1/2.
            transforms.RandomHorizontalFlip(),
        ]
        deterministic = [transforms.Resize(resize), transforms.CenterCrop(resize)]

        self.data_transform = {
            'train': transforms.Compose(augment + to_normalized_tensor()),
            'val': transforms.Compose(deterministic + to_normalized_tensor()),
        }

    def __call__(self, img, phase='train'):
        """Run the pipeline registered for ``phase`` ('train' or 'val') on ``img``.

        An unregistered ``phase`` raises ``KeyError``.
        """
        pipeline = self.data_transform[phase]
        return pipeline(img)
    
    
    
def make_datapath_list(phase, rootpath='pytorch_advanced/1_image_classification/data/hymenoptera_data/'):
    """Return the sorted list of ``.jpg`` paths for one data split.

    The pattern is ``<rootpath>/<phase>/**/*.jpg``. Because ``glob`` is called
    without ``recursive=True``, ``**`` behaves like ``*`` and matches exactly
    one directory level (the class directory).

    Args:
        phase: Split sub-directory name, e.g. ``'train'`` or ``'val'``.
        rootpath: Dataset root directory. It is resolved against the current
            working directory unless absolute.

    Returns:
        Matching paths sorted lexicographically, or an empty list if nothing
        matches. ``glob`` itself returns paths in file-system-dependent order;
        sorting makes the list identical across machines.
    """
    target_path = os.path.join(rootpath, phase, '**', '*.jpg')
    return sorted(glob.glob(target_path))


class HymenopteraDataset(data.Dataset):
    """Ant/bee image dataset laid out as ``<root>/<phase>/<class_dir>/<image>.jpg``.

    Each item is ``(transformed_image, label)``. The label is taken from the
    name of the image's parent directory (``'ants'`` -> 0, ``'bees'`` -> 1), so
    it does not depend on the length or location of the dataset root.

    Args:
        file_list: Image file paths, e.g. as returned by ``make_datapath_list``.
        phase: ``'train'`` or ``'val'``; forwarded to ``transformer``.
        transformer: Callable invoked as ``transformer(img, phase)``. If
            ``None``, the RGB PIL image is returned untransformed.
    """

    # Class-directory name -> integer label.
    LABELS = {'ants': 0, 'bees': 1}
    # Phases accepted by __getitem__.
    PHASES = ('train', 'val')

    def __init__(self, file_list, phase, transformer=None):
        self.file_list = file_list
        self.transformer = transformer
        self.phase = phase

    def __len__(self):
        return len(self.file_list)

    @classmethod
    def _label_from_path(cls, img_path):
        """Map an image path to its integer label via its parent directory name.

        Raises:
            ValueError: If the parent directory is not a known class.
        """
        class_name = os.path.basename(os.path.dirname(img_path))
        try:
            return cls.LABELS[class_name]
        except KeyError:
            raise ValueError(
                f"cannot infer label from {img_path!r}: parent directory "
                f"{class_name!r} is not one of {sorted(cls.LABELS)}"
            ) from None

    def __getitem__(self, index):
        # Validate the phase and label before touching the file system.
        if self.phase not in self.PHASES:
            raise Exception("unknown phase")
        img_path = self.file_list[index]
        label = self._label_from_path(img_path)

        # `with` closes the file handle. convert('RGB') yields a fully loaded
        # 3-channel copy, so grayscale/CMYK/RGBA files match 3-channel Normalize stats.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        if self.transformer is None:
            img_transformed = img
        else:
            img_transformed = self.transformer(img, self.phase)
        return img_transformed, label

In [2]:
train_list = make_datapath_list(phase='train')
val_list = make_datapath_list(phase='val')

# glob silently returns [] when the notebook runs from a different working
# directory; fail here instead of later with an empty dataset.
assert train_list, "no training images found -- check the working directory / data path"
assert val_list, "no validation images found -- check the working directory / data path"
len(train_list), len(val_list)