[Blog post](https://www.basicml.com)

In [18]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243


In [19]:
!pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/cuda/10.0 nvidia-dali

Looking in indexes: https://pypi.org/simple, https://developer.download.nvidia.com/compute/redist/cuda/10.0


In [20]:
import types
import numpy as np
import collections
import pandas as pd
from torch.utils import data
from random import shuffle

import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline

In [21]:
class ExternalInputIterator(object):
    def __init__(self, batch_size, data_file, image_dir, shuffle_files=True):
        self.images_dir = image_dir
        self.batch_size = batch_size
        self.data_file = data_file
        self.shuffle_files = shuffle_files
        with open(self.data_file, 'r') as f:
            self.files = [line.rstrip() for line in f if line is not '']
        self.idxs = []

    def __iter__(self):
        self.n = len(self.files)
        return self

    def __next__(self):
        batch = []
        labels = []
        for _ in range(self.batch_size):
            jpeg_filename, text, *label = self.files[self.get_idx()].split(' ')
            f = open(image_dir + jpeg_filename, 'rb')
            batch.append(np.frombuffer(f.read(), dtype = np.uint8))
            labels.append(np.array(label, dtype = np.uint8))
        return (batch, labels)
    
    def get_idx(self):
        if len(self.idxs) == 0:
            print("Shuffling")
            self.idxs = list(range(self.n))
            if self.shuffle_files:
                shuffle(self.idxs)
        return self.idxs.pop()
      
    next = __next__

In [22]:
class ExternalInputDataset(data.Dataset):
    def __init__(self, batch_size, data_file, image_dir, shuffle_files=True):
        self.images_dir = image_dir
        self.batch_size = batch_size
        self.data_file = data_file
        self.shuffle_files = shuffle_files
        with open(self.data_file, 'r') as f:
            self.files = [line.rstrip() for line in f if line is not '']
        self.idxs = []

    def __len__(self):
        return len(self.files)
        
    def __getitem__(self, index):
        jpeg_filename, *label = self.files[index].split(' ')
        f = open(image_dir + jpeg_filename, 'rb')
        image = np.frombuffer(f.read(), dtype = np.uint8)
        label = np.array(label, dtype = np.uint8)
        return image, label

In [None]:
!wget -cq https://s3.amazonaws.com/content.udacity-data.com/courses/nd188/flower_data.zip
!unzip -qq flower_data.zip
!mkdir -p ./flower_data/flower_data_flat
!find ./flower_data/train -mindepth 2 -type f -exec mv -t ./flower_data/flower_data_flat -i '{}' +

replace flower_data/valid/61/image_06296.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
image_dir = "./flower_data/flower_data_flat/"

In [None]:
from os import listdir
from os.path import isfile, join
image_files = [f for f in listdir(image_dir) if isfile(join(image_dir, f))]

In [None]:
image_files

In [None]:
data_frame = pd.DataFrame(list(zip(image_files, 
                               list(range(len(image_files))), 
                               list(range(len(image_files))))), 
                          columns=['image_filename', 'label_1', 'label_2'])

In [None]:
data_frame.head()

In [None]:
import PIL.Image

In [None]:
for file in data_frame['image_filename']:
  PIL.Image.open(image_dir + file).convert('RGB')

In [None]:
data_frame.to_csv('dummy_data.csv', index=False, header=False, sep=" ")

In [None]:
dataset = ExternalInputDataset(batch_size=16, 
                            data_file='dummy_data.csv', 
                            image_dir=image_dir)

In [None]:
dataset[0]

In [None]:
def collate_fn(batch):
  images = [item[0] for item in batch]
  labels = [item[1] for item in batch]
  return images, labels

In [None]:
dataset_loader = data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=1, pin_memory=False, collate_fn=collate_fn)

In [None]:
iter(dataset_loader).next()

In [None]:
eii = ExternalInputIterator(batch_size=16, 
                            data_file='dummy_data.csv', 
                            image_dir=image_dir)
iterator = iter(eii)

In [None]:
im, lab = next(iterator)

In [None]:
im

In [None]:
lab

In [None]:
class ExternalSourcePipeline(Pipeline):
    def __init__(self, source, batch_size, num_threads, device_id):
        super(ExternalSourcePipeline, self).__init__(batch_size,
                                      num_threads,
                                      device_id,
                                      seed=12)
        self.source = source
        self.source_iter = iter(source)
        self.input = ops.ExternalSource()
        self.input_label = ops.ExternalSource()
        self.decode = ops.ImageDecoder(device = "mixed", output_type = types.RGB)
        self.res = ops.Resize(device="gpu", resize_x=224, resize_y=224, interp_type=types.INTERP_TRIANGULAR)
        self.cast = ops.Cast(device = "gpu", dtype = types.INT32)

    def define_graph(self):
        self.jpegs = self.input()
        self.labels = self.input_label()
        images = self.decode(self.jpegs)
        output = self.res(images)
        return (output, self.labels)

    def iter_setup(self):
        try:
          p = self.source_iter.next()
        except:
          print("Exception occured")
          self.source_iter = iter(self.source)
          p = self.source_iter.next()
        images, labels = p
        self.feed_input(self.jpegs, images)
        self.feed_input(self.labels, labels)

In [None]:
pipe = ExternalSourcePipeline(source=dataset_loader, batch_size=16, num_threads=4, device_id = 0)
pipe.build()

In [None]:
from nvidia.dali.plugin.pytorch import DALIGenericIterator

In [None]:
len(dataset)

In [None]:
len(dataset_loader)

In [None]:
dali_iter = DALIGenericIterator([pipe], ['images', 'labels'], 400*16)

In [None]:
dali_iter

In [None]:
#%%time
import time
start = time.time()
for epoch in range(16):
  for i, it in enumerate(dali_iter, 5):
    batch_data = it[0]
    images, labels = batch_data["images"], batch_data["labels"]
    #print(len(images))
  dali_iter.reset()
print(time.time() - start)