### The code runs in an [NGC](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) container, `nvcr.io/nvidia/pytorch:21.02-py3`, with:
 - Python 3.8
 - NVIDIA CUDA 11.2.0
 - DALI 0.29.0
 - PyTorch 1.8.0a0+52ea372
 - Catalyst 21.9

### To get the MNIST data, you need to use [DALI extra](https://github.com/NVIDIA/DALI_extra).

In [1]:
from nvidia.dali.pipeline import Pipeline
import nvidia.dali as dali
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from torch.utils.data import DataLoader

from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy

In [2]:
# Display the installed DALI version so it can be compared with the one
# listed in the environment description at the top of the notebook.
dali.__version__

'0.29.0'

### To get test data you need to use [DALI extra](https://github.com/NVIDIA/DALI_extra).

In [3]:
# Pipeline definition.
# Database directory for each split; the key is the `mode` passed to the
# pipeline and the value is handed to the reader as its `path`.
data_paths = dict(
    train='DALI_extra/db/MNIST/training/',
    valid='DALI_extra/db/MNIST/testing/',
)

class MNISTPipeline(Pipeline):
    """DALI pipeline that reads, decodes and normalizes MNIST samples.

    Samples come from the database directory ``data_paths[mode]``. They are
    read with a shuffling Caffe2 reader, decoded to grayscale by the
    mixed-device image decoder, and normalized on the GPU into NCHW float
    tensors.

    Args:
        mode: key into ``data_paths`` selecting the split to read.
        batch_size: batch size forwarded to the base ``Pipeline``.
        num_threads: thread count forwarded to the base ``Pipeline``.
        device_id: device id forwarded to the base ``Pipeline``.
    """

    def __init__(
        self,
        mode: str = 'train',
        batch_size: int = 16,
        num_threads: int = 4,
        device_id: int = 0,
    ):
        super().__init__(batch_size=batch_size, num_threads=num_threads, device_id=device_id)
        self.mode = mode

        # The reader is named so it can be referred to as 'Reader' later on.
        self.input = ops.Caffe2Reader(
            path=data_paths[mode],
            random_shuffle=True,
            name='Reader',
        )
        self.decode = ops.ImageDecoder(device='mixed', output_type=types.GRAY)
        # Mean/std are expressed on the 0-255 pixel scale (hence the * 255);
        # presumably the usual MNIST statistics -- TODO confirm.
        self.cmn = ops.CropMirrorNormalize(
            device="gpu",
            dtype=types.FLOAT,
            mean=[0.1307 * 255],
            std=[0.3081 * 255],
            output_layout=types.NCHW,
        )

    def define_graph(self):
        """Wire reader -> decoder -> normalizer; return (images, GPU labels)."""
        encoded, targets = self.input()
        decoded = self.decode(encoded)
        normalized = self.cmn(decoded)
        return normalized, targets.gpu()

    def __len__(self):
        """Return the hard-coded sample count: 60000 for 'train', else 10000."""
        if self.mode == 'train':
            return 60000
        return 10000

In [4]:
# Customized DALI loader for use in Catalyst.
class DALILoader(DataLoader):
    """Expose an ``MNISTPipeline`` through the ``DataLoader`` type.

    ``DataLoader.__init__`` is not called: length and iteration are delegated
    to a ``DALIGenericIterator`` wrapped around a freshly built pipeline.

    Args:
        mode: split name forwarded to ``MNISTPipeline``.
        batch_size: number of samples per batch.
        num_workers: passed to the pipeline as its ``num_threads``.
    """

    def __init__(
        self,
        mode: str = 'train',
        batch_size: int = 32,
        num_workers: int = 4,
    ):
        self.batch_size = batch_size

        self.pipeline = MNISTPipeline(mode=mode, batch_size=batch_size, num_threads=num_workers)
        self.pipeline.build()

        self.loader = DALIGenericIterator(
            pipelines=self.pipeline,
            output_map=['features', 'targets'],
            size=len(self.pipeline),
            auto_reset=True,
            # Keep the last, possibly shorter, batch instead of padding it.
            last_batch_policy=LastBatchPolicy.PARTIAL,
        )

    def __len__(self):
        """Return the length reported by the wrapped DALI iterator."""
        return len(self.loader)

    def __iter__(self):
        """Iterate over ``{'features', 'targets'}`` batch dictionaries.

        Each item produced by the DALI iterator is a list with one dict per
        pipeline; only the first entry is used. Targets are flattened to a 1-D
        int64 tensor. ``reshape(-1)`` is used rather than a bare ``squeeze()``
        so that a batch holding a single sample keeps its batch dimension
        (shape ``(1,)``) instead of collapsing to a 0-dim tensor.
        """
        return (
            {
                'features': outputs[0]["features"],
                'targets': outputs[0]["targets"].reshape(-1).long(),
            }
            for outputs in self.loader
        )

    # NOTE(review): ``DataLoader`` normally exposes ``sampler`` and
    # ``batch_sampler`` as attributes; here they are methods returning None.
    # Kept as methods so existing callers are unaffected.
    def sampler(self):
        """Return None (no sampler is used by this loader)."""
        return None

    def batch_sampler(self):
        """Return None (no batch sampler is used by this loader)."""
        return None

In [5]:
import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl

In [6]:
# Settings shared by both loaders.
BATCH_SIZE = 32
NUM_WORKERS = 8

# Single linear layer over the flattened 28x28 input, with 10 outputs.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

# One DALI-backed loader per split, keyed by the split name.
loaders = {
    split: DALILoader(mode=split, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    for split in ('train', 'valid')
}



In [7]:
runner = dl.SupervisedRunner()

# Further metric callbacks that can be appended to `callbacks` below, each
# taking input_key="logits" and target_key="targets":
#   - dl.PrecisionRecallF1SupportCallback(..., num_classes=10)
#   - dl.AUCCallback(...)
#   - dl.ConfusionMatrixCallback(..., num_classes=10)
#     (catalyst[ml] required: ``pip install catalyst[ml]``)
runner.train(
    # What is trained and on which data.
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    # Run length and logging.
    num_epochs=1,
    logdir="./logs",
    verbose=True,
    # Validation loader and metric to track; lower loss is better.
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    callbacks=[
        dl.AccuracyCallback(
            input_key="logits",
            target_key="targets",
            num_classes=10,
        ),
    ],
)

HBox(children=(HTML(value='1/1 * Epoch (train)'), FloatProgress(value=0.0, max=1875.0), HTML(value='')))


train (1/1) accuracy: 0.8583166666666664 | accuracy/std: 0.07269646003630853 | accuracy01: 0.8583166666666664 | accuracy01/std: 0.07269646003630853 | accuracy03: 0.9703833333333339 | accuracy03/std: 0.03967923633655473 | accuracy05: 0.9909833333333333 | accuracy05/std: 0.023284034401133934 | loss: 1.2825655004053673 | loss/mean: 1.2825655004053673 | loss/std: 0.8768621207409731 | lr: 0.02 | momentum: 0.9


HBox(children=(HTML(value='1/1 * Epoch (valid)'), FloatProgress(value=0.0, max=313.0), HTML(value='')))


valid (1/1) accuracy: 0.8774 | accuracy/std: 0.058236192744843304 | accuracy01: 0.8774 | accuracy01/std: 0.058236192744843304 | accuracy03: 0.9776999999999999 | accuracy03/std: 0.024867586124993495 | accuracy05: 0.9943999999999998 | accuracy05/std: 0.012243569352890922 | loss: 1.3684241206318133 | loss/mean: 1.3684241206318133 | loss/std: 0.8965518830115483 | lr: 0.02 | momentum: 0.9
* Epoch (1/1) 
Top best models:
logs/checkpoints/train.1.pth	1.3684
