In [1]:
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import torch

# based on https://gist.github.com/srikarplus/8bdb5bedf0ca25e894e39ea78fce2f39
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
      mentioned in the paper. Only applied on the train split.
    - random_seed: fix seed for reproducibility.
    - valid_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices.
    - show_sample: plot 9x9 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # define transforms
    valid_transform = transforms.Compose([transforms.Resize(512),
                                      transforms.ToTensor(),
                                      normalize])

    train_transform = transforms.Compose([transforms.Resize(512),
                                      transforms.ToTensor(),
                                      normalize])

    # load the dataset
    train_dataset = datasets.ImageFolder(
        root=data_dir, transform=train_transform,
    )

    valid_dataset = datasets.ImageFolder(
        root=data_dir, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )


    return (train_loader, valid_loader, train_dataset.class_to_idx)



  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: <870081F6-12FD-3CEA-BC5C-30F4764F2A98> /Users/juniverse/opt/anaconda3/envs/veda/lib/python3.8/site-packages/torchvision/image.so
  warn(


In [2]:
if __name__=='__main__':
    batch_size = 64
    PATH = dir = '/Users/juniverse/Desktop/pointcloud/VectorUniverse/Data/thumbnails/images/'
    print(batch_size,PATH)
    train_loader, test_loader, class_idx = get_train_valid_loader(PATH, batch_size, False, 42, 0.1)
    print(train_loader)
    
    for i in train_loader:
        print(i)
        break

64 /Users/juniverse/Desktop/pointcloud/VectorUniverse/Data/thumbnails/images/
<torch.utils.data.dataloader.DataLoader object at 0x7f8d26c6a4c0>


  Referenced from: <870081F6-12FD-3CEA-BC5C-30F4764F2A98> /Users/juniverse/opt/anaconda3/envs/veda/lib/python3.8/site-packages/torchvision/image.so
  warn(
  Referenced from: <870081F6-12FD-3CEA-BC5C-30F4764F2A98> /Users/juniverse/opt/anaconda3/envs/veda/lib/python3.8/site-packages/torchvision/image.so
  warn(
  Referenced from: <870081F6-12FD-3CEA-BC5C-30F4764F2A98> /Users/juniverse/opt/anaconda3/envs/veda/lib/python3.8/site-packages/torchvision/image.so
  warn(
  Referenced from: <870081F6-12FD-3CEA-BC5C-30F4764F2A98> /Users/juniverse/opt/anaconda3/envs/veda/lib/python3.8/site-packages/torchvision/image.so
  warn(


[tensor([[[[-0.0287, -0.0458, -0.0458,  ...,  2.1804,  2.1804,  2.1804],
          [-0.0287, -0.0287, -0.0287,  ...,  2.1975,  2.1975,  2.1975],
          [-0.0287, -0.0287, -0.0287,  ...,  2.1975,  2.1975,  2.1975],
          ...,
          [ 2.0434,  2.0605,  2.0605,  ..., -1.1932, -1.2274, -1.2274],
          [ 2.0263,  2.0434,  2.0605,  ..., -1.1760, -1.1760, -1.1760],
          [ 2.0263,  2.0263,  2.0263,  ..., -1.1589, -1.1589, -1.1589]],

         [[-1.1779, -1.1954, -1.1954,  ...,  1.8859,  1.8859,  1.8859],
          [-1.1779, -1.1779, -1.1779,  ...,  1.9034,  1.9034,  1.9034],
          [-1.1779, -1.1779, -1.1779,  ...,  1.9034,  1.9034,  1.9034],
          ...,
          [ 1.4482,  1.4657,  1.4657,  ..., -1.8606, -1.8431, -1.8431],
          [ 1.4832,  1.4832,  1.5007,  ..., -1.8431, -1.8431, -1.8431],
          [ 1.5007,  1.5007,  1.5007,  ..., -1.8606, -1.8606, -1.8606]],

         [[-1.0724, -1.0898, -1.0898,  ...,  0.7402,  0.7228,  0.7054],
          [-1.0724, -1.0724, 