In [None]:
# pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110

In [None]:
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
from torchvision.models import resnet18
import torch
import torch.nn as nn
import torch.optim as optim

# DALI pipeline for CIFAR-10
class CIFAR10Pipeline(Pipeline):
    """DALI pipeline that reads, decodes, resizes and normalizes images.

    The graph is built with DALI's functional (``fn``) API inside
    ``define_graph``. ``fn`` calls return data nodes (the *outputs* of an
    operator), not reusable operator objects, so they must be chained
    directly while the pipeline is being defined instead of being stored in
    ``__init__`` and called later.

    Args:
        batch_size: Number of samples per batch.
        num_threads: Number of CPU worker threads used by DALI.
        device_id: Index of the GPU the pipeline runs on.
        data_dir: Path handed to the Caffe2 reader.
        train: When true, samples are shuffled and randomly mirrored
            horizontally; when false, data is read in order and not mirrored.

    Outputs (per iteration): ``(images, labels)``. Images are float tensors
    produced on the GPU (the decoder uses the "mixed" backend); labels come
    straight from the reader.
    """

    # Per-channel mean/std (the widely used ImageNet values), scaled by 255
    # because normalization is applied to decoded 8-bit pixel values.
    MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    def __init__(self, batch_size, num_threads, device_id, data_dir, train=True):
        super().__init__(batch_size, num_threads, device_id)
        # Only configuration is stored here; the operators themselves are
        # instantiated in define_graph(), where this pipeline is the
        # current one.
        self.data_dir = data_dir
        self.train = train

    def define_graph(self):
        """Build the processing graph and return its ``(images, labels)`` outputs."""
        # The reader is named so that a framework iterator can look up the
        # dataset size through ``reader_name="Reader"``.
        encoded, labels = fn.readers.caffe2(
            path=self.data_dir, random_shuffle=self.train, name="Reader"
        )

        # "mixed" decoding takes CPU input and produces the image on the GPU,
        # so every following image operator runs on the GPU as well.
        images = fn.decoders.image(encoded, device="mixed", output_type=types.RGB)
        images = fn.resize(images, resize_x=224, resize_y=224)

        normalize_args = {}
        if self.train:
            # One independent coin flip per sample decides whether the image
            # is mirrored horizontally (training-time augmentation only).
            normalize_args["mirror"] = fn.random.coin_flip(probability=0.5)

        images = fn.crop_mirror_normalize(
            images,
            dtype=types.FLOAT,
            mean=self.MEAN,
            std=self.STD,
            **normalize_args,
        )
        return images, labels

# Initialize the DALI pipeline
batch_size = 32
cifar10_pipeline = CIFAR10Pipeline(batch_size=batch_size, num_threads=2, device_id=0,
                                   data_dir="/path/to/cifar10", train=True)
cifar10_pipeline.build()

# Create a DALI iterator for PyTorch.
# reader_name lets the iterator query the epoch size from the reader named
# "Reader" in the pipeline. Without it (or an explicit size) the iterator
# never signals the end of an epoch and last_batch_policy has no effect.
dali_iterator = DALIClassificationIterator(
    cifar10_pipeline,
    reader_name="Reader",
    last_batch_policy=LastBatchPolicy.PARTIAL,
)

# The pipeline decodes on the GPU (device_id=0), so batches arrive as CUDA
# tensors; the model has to live on that same device.
device = torch.device("cuda", 0)

# Initialize ResNet18 with random weights (calling it without arguments is
# equivalent to the deprecated pretrained=False).
net = resnet18()
net.fc = nn.Linear(net.fc.in_features, 10)  # Adjust for CIFAR-10
net = net.to(device)
net.train()

# Training essentials
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Number of mini-batches between loss reports. An epoch of CIFAR-10 at
# batch size 32 is roughly 1.5k batches, so the interval must be smaller
# than that for anything to be printed.
log_interval = 200

# Train the network
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(dali_iterator):
        # data is a list with one dict per pipeline; each dict has 'data' and 'label'
        batch = data[0]
        inputs = batch["data"]
        # reshape(-1) keeps a 1-D target even when the (partial) last batch
        # holds a single sample, where a bare squeeze() would produce a 0-d
        # tensor. Labels are moved to the device the images are already on.
        labels = batch["label"].reshape(-1).long().to(inputs.device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if (i + 1) % log_interval == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_interval:.3f}')
            running_loss = 0.0

    # Rewind the iterator so the next epoch starts from the beginning.
    dali_iterator.reset()

print('Finished Training')

# Cleanup
# NOTE(review): the DALI PyTorch iterators do not appear to expose a public
# shutdown() method; dropping the references lets DALI release the pipeline
# and its GPU resources.
del dali_iterator
del cifar10_pipeline
