In [3]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import Food101
from torch.utils.data import DataLoader
from tqdm import tqdm

# NVIDIA DALI imports 
try:
    from nvidia.dali.pipeline import Pipeline
    import nvidia.dali.ops as ops
    import nvidia.dali.types as types
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
except ImportError:
    print("="*80)
    print("!!! NVIDIA DALI not installed; skipping DALI parts. !!!")
    print("="*80)
    DALIClassificationIterator = None

# Configuration -
BATCH_SIZE   = 64
MAX_EPOCHS   = 5
NUM_CLASSES  = 101
DATA_DIR     = "./data"
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"

# Disable DALI on CPU
if DEVICE == "cpu":
    print("Warning: Running on CPU; DALI will be disabled.")
    DALIClassificationIterator = None

# Download Food-101 if needed
def download_food101(root_dir):
    marker = os.path.join(root_dir, "food-101")
    if os.path.isdir(marker):
        return
    print("Downloading Food-101...")
    Food101(root=root_dir, split='train', download=True)
    Food101(root=root_dir, split='test',  download=True)

# Create DALI file list of ABSOLUTE paths 
def create_dali_file_list(root_dir, output_file="train.txt"):
    ds = Food101(root=root_dir, split='train', download=False)
    out_path = os.path.join(root_dir, "food-101", output_file)
    print("Writing DALI file list to", out_path)
    with open(out_path, 'w') as f:
        for img_path, label in zip(ds._image_files, ds._labels):
            f.write(f"{os.path.abspath(img_path)} {label}\n")
    return out_path

# PyTorch DataLoader 
def get_pytorch_dataloader(root_dir, batch_size, split='train'):
    is_train = (split=='train')
    transforms_list = []
    if is_train:
        transforms_list += [
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        transforms_list += [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    transforms_list += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ]
    ds = Food101(root=root_dir, split=split, transform=transforms.Compose(transforms_list), download=False)
    return DataLoader(ds, batch_size=batch_size, shuffle=is_train,
                      num_workers=(4 if DEVICE=='cuda' else 0), pin_memory=(DEVICE=='cuda'))

# dALI DataLoader (with resizing
if DALIClassificationIterator:
    class DaliTrainPipeline(Pipeline):
        def __init__(self, batch_size, num_threads, device_id, file_list_path):
            super().__init__(batch_size, num_threads, device_id, seed=12)
            self.input  = ops.readers.File(file_root="", file_list=file_list_path,
                                           random_shuffle=True, name="Reader")
            self.decode = ops.decoders.Image(device="mixed")
            self.resize = ops.Resize(device="gpu", resize_shorter=256)
            self.cmn    = ops.CropMirrorNormalize(
                device="gpu", output_dtype=types.FLOAT, output_layout=types.NCHW,
                crop=(224,224), mean=[0.485*255,0.456*255,0.406*255],
                std=[0.229*255,0.224*255,0.225*255], image_type=types.RGB)
            self.coin = ops.random.CoinFlip(probability=0.5)

        def define_graph(self):
            rng = self.coin()
            jpegs, labels = self.input()
            images = self.decode(jpegs).gpu()
            images = self.resize(images)
            out    = self.cmn(images, mirror=rng)
            return out, labels.gpu()

    def get_dali_dataloader(data_dir, batch_size):
        file_list = os.path.join(data_dir, "food-101", "train.txt")
        pipeline = DaliTrainPipeline(batch_size=batch_size, num_threads=4,
                                     device_id=0, file_list_path=file_list)
        return DALIClassificationIterator(
            pipelines=[pipeline], last_batch_policy=LastBatchPolicy.PARTIAL,
            reader_name="Reader")

# Model & Evaluatio
def get_model():
    model = torchvision.models.resnet50(weights='IMAGENET1K_V1')
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    return model.to(DEVICE)

def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            preds = model(inputs).argmax(dim=1)
            correct += preds.eq(labels).sum().item()
            total   += labels.size(0)
    return correct / total

# Training Loop with Robust Batch Detection 
from torch.cuda.amp import GradScaler

def train(model, loader, optimizer, criterion, num_epochs,
          desc, grad_accum_steps=1, use_amp=False):
    scaler = GradScaler() if use_amp else None
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        loop = tqdm(loader, desc=f"{desc} Epoch {epoch+1}/{num_epochs}")
        for step, batch in enumerate(loop):
            # Detect DALI batch by element type
            if isinstance(batch, (list, tuple)) and len(batch) and isinstance(batch[0], dict):
                data_dict = batch[0]
                inputs = data_dict['data']
                labels = data_dict['label'].squeeze(-1).long()
            else:
                inputs, labels = batch
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            if use_amp:
                with torch.cuda.amp.autocast():
                    loss = criterion(model(inputs), labels) / grad_accum_steps
                scaler.scale(loss).backward()
            else:
                loss = criterion(model(inputs), labels) / grad_accum_steps
                loss.backward()

            if (step + 1) % grad_accum_steps == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

            loop.set_postfix(loss=(loss.item() * grad_accum_steps))

        # Reset DALI iterator if in use
        if isinstance(loader, DALIClassificationIterator):
            loader.reset()

    return time.time() - start_time

# Main Execution & 5-Experiment series
if __name__ == "__main__":
    download_food101(DATA_DIR)
    test_loader = get_pytorch_dataloader(DATA_DIR, BATCH_SIZE, split='test')
    criterion = nn.CrossEntropyLoss()
    results = []

    # 1) PyTorch Baseline
    pt_model = get_model()
    pt_loader= get_pytorch_dataloader(DATA_DIR, BATCH_SIZE, split='train')
    pt_opt   = optim.SGD(pt_model.parameters(), lr=1e-3, momentum=0.9)
    if DEVICE=='cuda': torch.cuda.reset_peak_memory_stats()
    t1 = train(pt_model, pt_loader, pt_opt, criterion, MAX_EPOCHS,
               "PyTorch Baseline", grad_accum_steps=1, use_amp=False)
    a1 = evaluate(pt_model, test_loader)
    m1 = torch.cuda.max_memory_allocated()/1e6 if DEVICE=='cuda' else None
    results.append(("PyTorch Baseline", t1, a1, m1))

    if DALIClassificationIterator:
        create_dali_file_list(DATA_DIR)

        # 2) DALI Baseline
        db_model = get_model()
        db_loader= get_dali_dataloader(DATA_DIR, BATCH_SIZE)
        db_opt   = optim.SGD(db_model.parameters(), lr=1e-3, momentum=0.9)
        if DEVICE=='cuda': torch.cuda.reset_peak_memory_stats()
        t2 = train(db_model, db_loader, db_opt, criterion, MAX_EPOCHS,
                   "DALI Baseline", grad_accum_steps=1, use_amp=False)
        a2 = evaluate(db_model, test_loader)
        m2 = torch.cuda.max_memory_allocated()/1e6 if DEVICE=='cuda' else None
        results.append(("DALI Baseline", t2, a2, m2))


        # 4) DALI +AMP
        daa_model= get_model()
        daa_loader= get_dali_dataloader(DATA_DIR, BATCH_SIZE)
        daa_opt   = optim.SGD(daa_model.parameters(), lr=1e-3, momentum=0.9)
        if DEVICE=='cuda': torch.cuda.reset_peak_memory_stats()
        t4 = train(daa_model, daa_loader, daa_opt, criterion, MAX_EPOCHS,
                   "DALI +  AMP", grad_accum_steps=1, use_amp=True)
        a4 = evaluate(daa_model, test_loader)
        m4 = torch.cuda.max_memory_allocated()/1e6 if DEVICE=='cuda' else None
        results.append(("DALI +  AMP", t4, a4, m4))

        # 5) DALI +  AMP + Compile
        dac_model= get_model()
        dac_loader= get_dali_dataloader(DATA_DIR, BATCH_SIZE)
        dac_opt   = optim.SGD(dac_model.parameters(), lr=1e-3, momentum=0.9)
        if DEVICE=='cuda': torch.cuda.reset_peak_memory_stats()
        dac_model= torch.compile(dac_model)
        t5 = train(dac_model, dac_loader, dac_opt, criterion, MAX_EPOCHS,
                   "DALI +  AMP + Compile", grad_accum_steps=1, use_amp=True)
        a5 = evaluate(dac_model, test_loader)
        m5 = torch.cuda.max_memory_allocated()/1e6 if DEVICE=='cuda' else None
        results.append(("DALI +  AMP + Compile", t5, a5, m5))




PyTorch Baseline Epoch 1/5: 100%|██████████| 1184/1184 [02:33<00:00,  7.71it/s, loss=1.96]
PyTorch Baseline Epoch 2/5: 100%|██████████| 1184/1184 [02:01<00:00,  9.75it/s, loss=1.64]
PyTorch Baseline Epoch 3/5: 100%|██████████| 1184/1184 [02:22<00:00,  8.33it/s, loss=1.18] 
PyTorch Baseline Epoch 4/5: 100%|██████████| 1184/1184 [02:30<00:00,  7.86it/s, loss=1.4]  
PyTorch Baseline Epoch 5/5: 100%|██████████| 1184/1184 [02:32<00:00,  7.78it/s, loss=0.955]


Writing DALI file list to ./data/food-101/train.txt


  kwargs = _handle_arg_deprecations(schema, kwargs, operator_name)
  kwargs = _handle_arg_deprecations(schema, kwargs, operator_name)
DALI Baseline Epoch 1/5: 100%|██████████| 1184/1184 [01:59<00:00,  9.91it/s, loss=1.27]
DALI Baseline Epoch 2/5: 100%|██████████| 1184/1184 [01:59<00:00,  9.92it/s, loss=1.03] 
DALI Baseline Epoch 3/5: 100%|██████████| 1184/1184 [01:59<00:00,  9.93it/s, loss=1.03] 
DALI Baseline Epoch 4/5: 100%|██████████| 1184/1184 [01:59<00:00,  9.91it/s, loss=0.446]
DALI Baseline Epoch 5/5: 100%|██████████| 1184/1184 [01:59<00:00,  9.91it/s, loss=0.654]
  kwargs = _handle_arg_deprecations(schema, kwargs, operator_name)
  kwargs = _handle_arg_deprecations(schema, kwargs, operator_name)
  return F.conv2d(input, weight, bias, self.stride,
DALI +  AMP Epoch 1/5: 100%|██████████| 1184/1184 [01:01<00:00, 19.29it/s, loss=1.33]
DALI +  AMP Epoch 2/5: 100%|██████████| 1184/1184 [01:00<00:00, 19.52it/s, loss=1.05] 
DALI +  AMP Epoch 3/5: 100%|██████████| 1184/1184 [01:00<00:00,

In [4]:
# Summary
print("\n" + "="*60)
print(" Experiment                       Time(s)   Accuracy(%)   Peak GPU Mem(MB)")
print("="*60)
for name, t, acc, mem in results:
    mem_str = f"{mem:.1f}" if mem is not None else "N/A"
    print(f"{name:30s} {t:8.1f}   {acc*100:10.2f}   {mem_str:>15s}")
print("="*60)
if len(results) >= 2:
    speedup = results[0][1]/results[1][1]
    print(f"\nSpeedup (PyTorch / DALI Baseline): {speedup:.2f}×")


 Experiment                       Time(s)   Accuracy(%)   Peak GPU Mem(MB)
PyTorch Baseline                  720.2        80.28            6251.2
DALI Baseline                     597.0        81.12            6445.0
DALI +  AMP                       303.4        81.24            4056.0
DALI +  AMP + Compile             382.0        81.16            4342.4

Speedup (PyTorch / DALI Baseline): 1.21×


### 1:30 min from torch.compile should subtracted from the comparison , because of one time graph creation