### Installing necessary dependencies

In [15]:
# %pip install pytorch_lightning lightning_utilities torchmetrics tqdm pyyaml matplotlib
# %pip install nvidia-dali-cuda120

In [16]:
import glob 
import os
import numpy as np
import time

import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import (
    DALIGenericIterator as PyTorchIterator,
)

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.io import decode_image, decode_jpeg
from torch.utils.data import DataLoader, Dataset

from PIL import Image
from random import shuffle

from matplotlib import pyplot as plt

In [17]:
from nvidia.dali.pipeline import Pipeline

import numpy as np
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator

In [18]:
# Path to the WebDataset .tar shard read by the DALI pipeline. Set the WDS_DATA
# environment variable to point elsewhere instead of editing this absolute local path.
wds_data = os.environ.get("WDS_DATA", "/home/a.filatov/workdir/dataset_loading/tar_repo/0.tar")
# Number of samples per DALI batch (also the default DALI worker-thread count).
batch_size = 4

In [19]:
def sharded_pipeline(device_id, shard_id, num_shards, num_threads=None, reader_name="Reader"):
    """Build a DALI pipeline that decodes, resizes and normalizes images from ``wds_data``.

    Each sample's image component (any of the ``jpeg``/``png``/``jpg`` extensions) is read
    from one shard of the WebDataset tar. It is decoded with the "mixed" (CPU+GPU) backend,
    stretched to 1024x1024 (both sides are fixed, so aspect ratio is not preserved) and
    mapped from [0, 255] to [-1, 1] as FLOAT16 in CHW layout.

    Args:
        device_id: CUDA device the pipeline runs on.
        shard_id: Index of the shard this pipeline reads.
        num_shards: Total number of shards the dataset is split into.
        num_threads: CPU worker threads. Defaults to ``batch_size`` (the previous
            hard-wired value); the two are unrelated and can be tuned separately.
        reader_name: Name given to the reader operator. This lets a DALI iterator be built
            with ``reader_name=...`` and take the epoch size from the reader itself
            instead of a hard-coded ``size``.

    Returns:
        An unbuilt ``Pipeline`` with a single output: the batch of normalized images.
    """
    if num_threads is None:
        num_threads = batch_size
    pipe = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id)
    with pipe:
        # One extension set -> one output: whichever of jpeg/png/jpg a sample has.
        # Samples missing all three are skipped rather than raising.
        img_raw = fn.readers.webdataset(
            paths=wds_data,
            ext=["jpeg;png;jpg"],
            missing_component_behavior="skip",
            dtypes=types.UINT8,
            shard_id=shard_id,
            num_shards=num_shards,
            name=reader_name,
        )
        img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB)
        img = fn.resize(img, device="gpu", resize_x=1024, resize_y=1024)
        # out = scale * (in - mean) / std + shift = 2 * in / 255 - 1  ->  [-1, 1]
        img = fn.crop_mirror_normalize(
            img,
            dtype=types.FLOAT16,
            output_layout="CHW",  # explicit: the PyTorch model consumes NCHW batches
            mean=[0.0, 0.0, 0.0],
            std=[255.0, 255.0, 255.0],
            scale=2,
            shift=-1,
        )

        pipe.set_outputs(img)

    return pipe

In [20]:
# Single-GPU, unsharded run: GPU 0 reads shard 0 of 1 (i.e. the whole tar).
pipe = sharded_pipeline(device_id=0, shard_id=0, num_shards=1)

In [21]:
# Build now so the reader's metadata can be queried (the iterator would build it anyway).
pipe.build()
# Size the epoch from the reader itself rather than a hard-coded `size=1000`, which was
# unrelated to the number of samples actually in the shard. The pipeline has one reader;
# take its name whether it was named explicitly or auto-named by DALI.
reader_names = list(pipe.reader_meta())
assert len(reader_names) == 1, f"expected exactly one reader in the pipeline, got {reader_names}"
dali_iter = DALIGenericIterator([pipe], ["images"], reader_name=reader_names[0])

[/opt/dali/dali/operators/reader/loader/webdataset_loader.cc:380] Index file not provided, it may take some time to infer it from the tar file


### Model test: forward pass of a small CNN over DALI-loaded batches

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Three conv -> ReLU -> 2x2 max-pool stages followed by a two-layer classifier head.

    Args:
        num_classes: Number of output logits.
        input_size: Height/width of the square input images. Each 2x2/stride-2 max-pool
            floors the spatial size by 2, so the final feature map is ``input_size // 8``
            on a side and ``fc1`` is sized from that. Defaults to 1024 (the previously
            hard-coded resolution). At 1024, ``fc1`` alone has 128*128*128*256 ~= 537M
            weights.
        in_channels: Channels of the input images (3 for RGB).

    Raises:
        ValueError: if ``input_size < 8`` (the pooled feature map would be empty).
    """

    def __init__(self, num_classes=10, input_size=1024, in_channels=3):
        super().__init__()
        if input_size < 8:
            raise ValueError(
                f"input_size must be >= 8 so three 2x2 poolings leave a non-empty map, got {input_size}"
            )
        # Conv blocks: in_channels -> 32 -> 64 -> 128; padding=1 keeps H, W unchanged.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        # Shared 2x2 max-pool; each application halves H and W (floored).
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # floor(floor(floor(s/2)/2)/2) == s // 8 for non-negative integer s.
        side = input_size // 8
        self.fc1 = nn.Linear(128 * side * side, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, input_size, input_size) tensor to (batch, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))  # -> (batch, 32, S/2, S/2)
        x = self.pool(F.relu(self.conv2(x)))  # -> (batch, 64, S/4, S/4)
        x = self.pool(F.relu(self.conv3(x)))  # -> (batch, 128, S/8, S/8)
        # flatten (unlike .view) also works on non-contiguous inputs.
        x = torch.flatten(x, 1)  # -> (batch, 128 * (S/8)**2)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [23]:
# Create an instance of the network
# Put the network on GPU 0 in half precision so its weights share the FLOAT16 dtype
# of the image batches produced by the DALI pipeline.
gpu0 = torch.device("cuda", 0)
model = SimpleCNN(num_classes=10).to(device=gpu0, dtype=torch.float16)

In [24]:
# Forward-only smoke test over one epoch of DALI batches. No gradients are needed, so
# run under inference_mode: the autograd graph (and the large saved activations
# for 1024x1024 inputs) is never built.
model.eval()
with torch.inference_mode():
    for i, data in enumerate(dali_iter):
        # The iterator yields one dict per pipeline, keyed by the output names given to it.
        res = model(data[0]['images'])
        print(f"ENDED {i}")

ENDED 0
ENDED 1
ENDED 2
ENDED 3
ENDED 4
ENDED 5
ENDED 6
ENDED 7
ENDED 8
ENDED 9
ENDED 10
ENDED 11
ENDED 12
ENDED 13
ENDED 14
ENDED 15
ENDED 16
ENDED 17
ENDED 18
ENDED 19
ENDED 20
ENDED 21
ENDED 22
ENDED 23
ENDED 24
ENDED 25
ENDED 26
ENDED 27
ENDED 28
ENDED 29
ENDED 30
ENDED 31
ENDED 32
ENDED 33
ENDED 34
ENDED 35
ENDED 36
ENDED 37
ENDED 38
ENDED 39
ENDED 40
ENDED 41
ENDED 42
ENDED 43
ENDED 44
ENDED 45
ENDED 46
ENDED 47
ENDED 48
ENDED 49
ENDED 50
ENDED 51
ENDED 52
ENDED 53
ENDED 54
ENDED 55
ENDED 56
ENDED 57
ENDED 58
ENDED 59
ENDED 60
ENDED 61
ENDED 62
ENDED 63
ENDED 64
ENDED 65
ENDED 66
ENDED 67
ENDED 68
ENDED 69
ENDED 70
ENDED 71
ENDED 72
ENDED 73
ENDED 74
ENDED 75
ENDED 76
ENDED 77
ENDED 78
ENDED 79
ENDED 80
ENDED 81
ENDED 82
ENDED 83
ENDED 84
ENDED 85
ENDED 86
ENDED 87
ENDED 88
ENDED 89
ENDED 90
ENDED 91
ENDED 92
ENDED 93
ENDED 94
ENDED 95
ENDED 96
ENDED 97
ENDED 98
ENDED 99
ENDED 100
ENDED 101
ENDED 102
ENDED 103
ENDED 104
ENDED 105
ENDED 106
ENDED 107
ENDED 108
ENDED 109
ENDED 110
