In [1]:
import glob
import random
import os
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader

In [3]:
def make_directory(train_dir):
    """Create the per-split, per-label folder layout under ``train_dir``.

    After the call the following directories exist::

        <train_dir>/train/cat   <train_dir>/train/dog
        <train_dir>/valid/cat   <train_dir>/valid/dog

    The function is idempotent: directories that already exist are left
    untouched, so it is safe to call repeatedly.

    Args:
        train_dir: Root directory that will contain the ``train`` and
            ``valid`` split folders. Missing intermediate directories are
            created as well.
    """
    splits = ['train', 'valid']
    labels = ['cat', 'dog']

    for split in splits:
        for label in labels:
            # makedirs builds the split folder and the label folder in one
            # call; exist_ok avoids the check-then-create race and keeps the
            # function re-runnable. The 'valid' split lives directly under
            # train_dir (not nested inside 'train').
            os.makedirs(os.path.join(train_dir, split, label), exist_ok=True)
            

def build_file_structure(train_dir, train_ratio):
    """Randomly split the ``*.jpg`` files in ``train_dir`` into train/valid folders.

    Files directly inside ``train_dir`` are shuffled, the first
    ``int(len(files) * train_ratio)`` of them go to the ``train`` split and
    the remainder to the ``valid`` split. Each file is moved (renamed) to
    ``<train_dir>/<split>/<label>/<rest>``, where ``<label>`` is the part of
    the file name before the first dot and ``<rest>`` is everything after it
    (e.g. ``cat.0.jpg`` -> ``train/cat/0.jpg``).

    Args:
        train_dir: Directory holding the flat set of ``<label>.<id>.jpg`` files.
        train_ratio: Fraction of the files assigned to the ``train`` split.

    Note:
        The split is random and the files are moved in place, so the call is
        not reproducible unless the ``random`` module is seeded beforehand.
    """
    files = glob.glob(os.path.join(train_dir, '*.jpg'))
    # shuffle the whole training data so the split is random
    random.shuffle(files)

    boundary = int(len(files) * train_ratio)

    make_directory(train_dir)

    split_files = (
        ('train', files[:boundary]),
        ('valid', files[boundary:]),
    )
    for split, members in split_files:
        for file in members:
            # os.path.basename is separator-agnostic for the current OS;
            # splitting the path on '\\' only worked on Windows.
            label, rest = os.path.basename(file).split('.', 1)
            os.rename(file, os.path.join(train_dir, split, label, rest))


def dataset_load(dataset_dir, batch_size, shuffled, num_workers):
    """Build a ``DataLoader`` over the image folder at ``dataset_dir``.

    Every image is resized to 224x224, converted to a tensor and normalized
    per channel before batching.

    Args:
        dataset_dir: Directory passed to ``ImageFolder``.
        batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
        shuffled: Forwarded to ``DataLoader`` as ``shuffle``.
        num_workers: Forwarded to ``DataLoader`` as ``num_workers``.

    Returns:
        The ``DataLoader`` wrapping the transformed ``ImageFolder`` dataset.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics -- confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = ImageFolder(dataset_dir, preprocess)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=shuffled,
                        num_workers=num_workers)
    return loader

In [21]:
# Batched loaders for the two splits; both use batch size 64, shuffling and 3 workers.
train_data_gen = dataset_load(
    dataset_dir='./dogs-vs-cats-redux-kernels-edition/train/train',
    batch_size=64,
    shuffled=True,
    num_workers=3,
)
valid_data_gen = dataset_load(
    dataset_dir='./dogs-vs-cats-redux-kernels-edition/train/valid',
    batch_size=64,
    shuffled=True,
    num_workers=3,
)