In [46]:
import numpy as np
import pandas as pd
import os.path as osp
import glob
import matplotlib.pyplot as plt
import random

from PIL import Image
import torch
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

In [25]:
# Seed every RNG this notebook relies on (PyTorch, NumPy's legacy global RNG,
# and the stdlib `random` module) from one constant so random operations repeat
# across runs. Note: this does not by itself make CUDA kernels deterministic.
SEED = 334
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [14]:
# Competition data directory, relative to the notebook's working directory
# (presumably the Kaggle `../input/<competition>` layout — confirm when running elsewhere).
data_root = osp.join('..', 'input', 'cassava-leaf-disease-classification')

# One sample training image, used below to try out the transforms.
im = Image.open(
    osp.join(data_root, 'train_images', '1000015157.jpg')
)

In [42]:
def make_datapath_list(phase='train',
                       rootpath=osp.join('..', 'input', 'cassava-leaf-disease-classification')):
    """Return the paths of all ``*.jpg`` images for one data split, sorted.

    Args:
        phase: split name; images are looked up in ``<rootpath>/<phase>_images/``
            (e.g. ``'train'`` or ``'test'``).
        rootpath: competition data directory. Defaults to the same relative
            path used elsewhere in this notebook.

    Returns:
        list[str]: matching file paths in sorted order. ``glob`` yields files in
        arbitrary, filesystem-dependent order; sorting makes the list (and any
        seeded shuffling or splitting done on it) reproducible.
    """
    target_path = osp.join(rootpath, phase + '_images', '*.jpg')
    return sorted(glob.glob(target_path))

In [43]:
# Paths of every training image; the test split can be listed the same way
# with phase='test' once it is needed.
train_list = make_datapath_list(phase='train')

In [27]:
class ImageTransform():
    """Per-phase torchvision preprocessing pipelines.

    ``'train'`` augments with a random resized crop (50-100% of the area) and a
    random horizontal flip. ``'val'`` resizes and center-crops deterministically.
    Both end with ``ToTensor`` followed by ``Normalize(mean, std)``.
    """

    def __init__(self, resize, mean, std):
        def tensor_tail():
            # Fresh ToTensor/Normalize instances for each pipeline.
            return [transforms.ToTensor(), transforms.Normalize(mean, std)]

        augmenting = [
            transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        deterministic = [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
        ]
        self.data_transform = dict(
            train=transforms.Compose(augmenting + tensor_tail()),
            val=transforms.Compose(deterministic + tensor_tail()),
        )

    def __call__(self, img, phase='train'):
        """Run the pipeline registered under ``phase`` on ``img``.

        Raises KeyError for a phase other than 'train' or 'val'.
        """
        pipeline = self.data_transform[phase]
        return pipeline(img)


In [None]:
size = 224
# Standard ImageNet channel mean/std, as used by torchvision's pretrained models.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

transform = ImageTransform(size, mean, std)
# Apply the training pipeline to the sample image `im` opened in the data_root
# cell (the name `img` is not defined anywhere in the notebook).
img_transformed = transform(im, 'train')

In [None]:
class CassavaDataset(data.Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for Cassava images.

    Args:
        file_list: sequence of image file paths.
        transform: callable invoked as ``transform(img, phase)`` (e.g. an
            ``ImageTransform``). If None, the RGB PIL image is returned as is.
        file_to_label_map: dict from image file name (the path's basename, e.g.
            ``'1000015157.jpg'``) to its label. Required.
        phase: phase key forwarded to ``transform`` (e.g. ``'train'``/``'val'``).

    Raises:
        ValueError: if ``file_to_label_map`` is not supplied.
    """

    def __init__(self, file_list, transform=None, file_to_label_map=None, phase='train'):
        # A required parameter cannot follow a defaulted one in Python, so the
        # map uses a None sentinel (keeping the positional order) and is checked here.
        if file_to_label_map is None:
            raise ValueError('file_to_label_map is required')
        self.file_list = file_list
        self.transform = transform
        self.file_to_label_map = file_to_label_map
        self.phase = phase

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        img_path = self.file_list[index]
        # Close the file handle promptly (important with many DataLoader reads)
        # and force 3 channels so a 3-element Normalize always applies.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img, self.phase)
        label = self.file_to_label_map[osp.basename(img_path)]

        return img, label
        

In [47]:
# train.csv supplies the `image_id` and `label` columns used to build the label map.
x = pd.read_csv(osp.join(data_root, 'train.csv'))

In [50]:
# File name -> label lookup, keyed by basename as CassavaDataset looks labels up.
d = {image_id: label for image_id, label in zip(x['image_id'], x['label'])}
d['1000015157.jpg']

0