# Lab 3: Datasets

### Overview

This notebook introduces you to the MONAI dataset APIs:
- Recapping the base dataset API
- Understanding the caching mechanisms
- Exploring dataset utilities

## Install MONAI and import dependencies
This section installs MONAI (pinned to version 0.3.0rc2) and validates the install by printing out the configuration.

We'll then import our dependencies and MONAI.  

In [2]:
!pip install -qU "monai[nibabel]==0.3.0rc2"

import time
import torch

import monai
monai.config.print_config()

MONAI version: 0.3.0rc2
Python version: 3.8.3 (default, Jul  2 2020, 16:21:59)  [GCC 7.3.0]
Numpy version: 1.18.5
Pytorch version: 1.6.0

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.1.1
scikit-image version: 0.16.2
Pillow version: 7.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## MONAI Dataset 

A MONAI [Dataset](https://docs.monai.io/en/latest/data.html?highlight=dataset#dataset) is a generic dataset that implements `__len__` and `__getitem__` and applies an optional callable transform to each data sample when it is fetched.

We'll start by creating some generic data, passing it to the `Dataset` class, and specifying `None` for the transform.

In [4]:
items = [{"data": 4}, 
         {"data": 9}, 
         {"data": 3}, 
         {"data": 7}, 
         {"data": 1},
         {"data": 2},
         {"data": 5}]
dataset = monai.data.Dataset(items, transform=None)

print(f"Length of dataset is {len(dataset)}")
for item in dataset:
  print(item)

Length of dataset is 7
{'data': 4}
{'data': 9}
{'data': 3}
{'data': 7}
{'data': 1}
{'data': 2}
{'data': 5}
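
To make the interface concrete, here is a minimal pure-Python sketch of what such a dataset does. `MinimalDataset` is a hypothetical stand-in, not MONAI's implementation: it stores a sequence, reports its length, and applies the optional transform lazily inside `__getitem__`.

```python
from typing import Any, Callable, Optional, Sequence


class MinimalDataset:
    """Hypothetical stand-in for monai.data.Dataset: a sequence of items
    plus an optional transform applied on every fetch."""

    def __init__(self, data: Sequence[Any], transform: Optional[Callable[[Any], Any]] = None):
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Any:
        # Raises IndexError past the end, which is what stops a plain for-loop
        item = self.data[index]
        # The transform runs at fetch time, not at construction time
        return self.transform(item) if self.transform is not None else item


items = [{"data": 4}, {"data": 9}, {"data": 3}]
plain = MinimalDataset(items)
doubled = MinimalDataset(items, transform=lambda x: {"data": x["data"] * 2})
print(len(plain))            # 3
print([d for d in doubled])  # [{'data': 8}, {'data': 18}, {'data': 6}]
```

Because the class defines no `__iter__`, the `for` loop falls back to calling `__getitem__` with 0, 1, 2, ... until an `IndexError` is raised.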


### Compatible with the PyTorch DataLoader

MONAI datasets are designed to be compatible with the standard PyTorch DataLoader. MONAI may still subclass the DataLoader where key functionality cannot be realized with the standard class.

In [5]:
for item in torch.utils.data.DataLoader(dataset, batch_size=2):
  print(item)

{'data': tensor([4, 9])}
{'data': tensor([3, 7])}
{'data': tensor([1, 2])}
{'data': tensor([5])}
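
The batches above come from PyTorch's default collation, which groups the samples of a batch key by key and turns each group of numbers into a tensor. The pure-Python sketch below mimics that grouping with lists instead of tensors; `batches` and `collate_dicts` are hypothetical helpers, assuming sequential sampling and the DataLoader default of keeping a short final batch.

```python
from typing import Any, Dict, Iterator, List, Sequence


def batches(dataset: Sequence[Dict[str, Any]], batch_size: int) -> Iterator[List[Dict[str, Any]]]:
    """Yield consecutive chunks of `batch_size` samples; the last chunk may be shorter."""
    for start in range(0, len(dataset), batch_size):
        yield [dataset[i] for i in range(start, min(start + batch_size, len(dataset)))]


def collate_dicts(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Turn a list of dicts into a dict of lists, key by key."""
    return {key: [sample[key] for sample in batch] for key in batch[0]}


items = [{"data": v} for v in [4, 9, 3, 7, 1, 2, 5]]
collated = [collate_dicts(b) for b in batches(items, batch_size=2)]
for c in collated:
    print(c)
# {'data': [4, 9]}
# {'data': [3, 7]}
# {'data': [1, 2]}
# {'data': [5]}
```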


### Load items with a customized transform

We'll create a custom transform called `SquareIt`, which replaces the value stored under its key with the squared value. In our case, `SquareIt(keys='data')` squares the value of `x['data']`. For simplicity, `SquareIt` only handles the first key in `keys`.

In [6]:
class SquareIt(monai.transforms.MapTransform):
  """a simple transform to return a squared number"""

  def __init__(self, keys):
    monai.transforms.MapTransform.__init__(self, keys)
    print(f"keys to square it: {self.keys}")

  def __call__(self, x):
    key = self.keys[0]
    data = x[key]
    output = {key: data ** 2}
    return output

square_dataset = monai.data.Dataset(items, transform=SquareIt(keys='data'))
for item in square_dataset:
  print(item)

keys to square it: ('data',)
{'data': 16}
{'data': 81}
{'data': 9}
{'data': 49}
{'data': 1}
{'data': 4}
{'data': 25}


Keep in mind:
- `SquareIt` creates and returns a new dictionary, `output`, instead of overwriting the contents of `x` directly. This lets us apply the transforms repeatedly, for example, over multiple epochs of training.
- `SquareIt.__call__` reads the key information from `self.keys` but does not write any properties to `self`, because writing properties does not work with a multi-processing data loader: each worker process operates on its own copy of the transform.
- In most MONAI preprocessing transforms, we assume `x[key]` has the shape `(num_channels, spatial_dim_1, spatial_dim_2, ...)`. The channel dimension is kept even if `num_channels` equals 1, but the spatial dimensions could be omitted.
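
The first point can be checked in plain Python. The two functions below are hypothetical illustrations: `square_copy` follows the `SquareIt` pattern of returning a new dict, while `square_in_place` overwrites its input, so applying it over two "epochs" silently corrupts the source data.

```python
def square_in_place(x):
    """Anti-pattern: overwrites the caller's dict."""
    x["data"] = x["data"] ** 2
    return x


def square_copy(x):
    """Pattern used by SquareIt: build and return a new dict."""
    return {"data": x["data"] ** 2}


items = [{"data": 3}]

# Two "epochs" with the copying transform: the source item is untouched
epochs_copy = [square_copy(items[0])["data"] for _ in range(2)]
print(epochs_copy, items)  # [9, 9] [{'data': 3}]

# Two "epochs" with the mutating transform: the source item drifts
epochs_mut = [square_in_place(items[0])["data"] for _ in range(2)]
print(epochs_mut, items)  # [9, 81] [{'data': 81}]
```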

## MONAI dataset caching

To demonstrate the benefit of dataset caching, we'll construct a dataset with a slow transform. To do that, we'll call `time.sleep` inside the transform's `__call__` method.

In [7]:
class SlowSquare(monai.transforms.MapTransform):
  """a simple transform to slowly return a squared number"""
  
  def __init__(self, keys):
    monai.transforms.MapTransform.__init__(self, keys)
    print(f"keys to square it: {self.keys}")

  def __call__(self, x):
    time.sleep(1.0)
    output = {key: x[key] ** 2 for key in self.keys}
    return output

square_dataset = monai.data.Dataset(items, transform=SlowSquare(keys='data'))

keys to square it: ('data',)


As expected, going through all the items takes about 7 seconds: one second of sleep for each of the 7 items.

In [8]:
%time for item in square_dataset: print(item)

{'data': 16}
{'data': 81}
{'data': 9}
{'data': 49}
{'data': 1}
{'data': 4}
{'data': 25}
CPU times: user 0 ns, sys: 15.6 ms, total: 15.6 ms
Wall time: 7.01 s


### Cache Dataset

When using [CacheDataset](https://docs.monai.io/en/latest/data.html?highlight=dataset#cachedataset), caching is done when the object is initialized, so initialization is slower than for a regular dataset.

By caching the results of non-random preprocessing transforms, CacheDataset accelerates the training data pipeline. If the requested data is not in the cache, all transforms run normally.

In [9]:
square_cached = monai.data.CacheDataset(items, transform=SlowSquare(keys='data'))

keys to square it: ('data',)


However, repeatedly fetching the items from an initialized CacheDataset is fast.

In [10]:
%timeit list(item for item in square_cached)

17.9 µs ± 2.2 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
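
The cache-on-init behavior can be sketched in a few lines of plain Python. `TinyCacheDataset` is hypothetical; its `cache_num` parameter mirrors the name of a real CacheDataset argument, but unlike MONAI, this sketch caches the output of the whole transform rather than only its non-random part. Items beyond `cache_num` fall back to running the transform on every fetch.

```python
import time
from typing import Any, Callable, Sequence


class TinyCacheDataset:
    """Hypothetical sketch of cache-on-init: the first `cache_num` items are
    transformed once in __init__; any others are transformed on every fetch."""

    def __init__(self, data: Sequence[Any], transform: Callable[[Any], Any], cache_num: int):
        self.data = data
        self.transform = transform
        # The expensive work happens here, once, so construction is slow
        self._cache = [transform(item) for item in data[:cache_num]]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Any:
        if 0 <= index < len(self._cache):
            return self._cache[index]             # cache hit: no transform call
        return self.transform(self.data[index])   # cache miss: run the transform normally


calls = []


def slow_square(x):
    calls.append(x)  # record every real transform call
    time.sleep(0.01)
    return {"data": x["data"] ** 2}


ds = TinyCacheDataset([{"data": v} for v in [4, 9, 3]], slow_square, cache_num=2)
print(len(calls))                 # 2 (items 0 and 1 were cached at init)
print([ds[i] for i in range(3)])  # [{'data': 16}, {'data': 81}, {'data': 9}]
print(len(calls))                 # 3 (only the uncached item 2 ran again)
```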


To improve caching efficiency, put as many non-random transforms as possible before the randomized ones when composing the chain of transforms.
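
Ordering matters because, to our understanding, CacheDataset only caches the transform chain up to the first randomized transform (one that is `Randomizable` in MONAI); everything from that point on runs on every fetch. The sketch below models this with a hypothetical `RandomTransform` marker and a `split_chain` helper, and shows that moving a deterministic step after the random one shrinks the cacheable prefix.

```python
import random
from typing import Any, Callable, List, Tuple


class RandomTransform:
    """Hypothetical marker for transforms that must run on every fetch."""

    def __init__(self, fn: Callable[[Any], Any]):
        self.fn = fn

    def __call__(self, x: Any) -> Any:
        return self.fn(x)


def split_chain(transforms: List[Callable[[Any], Any]]) -> Tuple[list, list]:
    """Split at the first random transform: the prefix is cacheable, the rest is not."""
    for i, t in enumerate(transforms):
        if isinstance(t, RandomTransform):
            return transforms[:i], transforms[i:]
    return transforms, []


def run(chain: List[Callable[[Any], Any]], x: Any) -> Any:
    for t in chain:
        x = t(x)
    return x


square = lambda x: x ** 2
add_one = lambda x: x + 1
jitter = RandomTransform(lambda x: x + random.choice([-1, 1]))

good = [square, add_one, jitter]  # both deterministic steps can be cached
bad = [square, jitter, add_one]   # add_one now runs on every fetch

for chain in (good, bad):
    cacheable, per_fetch = split_chain(chain)
    print(len(cacheable), len(per_fetch))
# 2 1
# 1 2
```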

### Persistent Caching

[PersistentDataset](https://docs.monai.io/en/latest/data.html?highlight=dataset#persistentdataset) stores pre-computed values persistently on disk, so it can efficiently manage dictionary-format data that is larger than memory.

The non-random transform components are computed when first used and stored in `cache_dir` for rapid retrieval on subsequent uses.

In [11]:
square_persist = monai.data.PersistentDataset(items, transform=SlowSquare(keys='data'), cache_dir="my_cache")

keys to square it: ('data',)


The caching happens during the first epoch, so iterating over the dataset for the first time should take about 7 seconds.

In [12]:
%time for item in square_persist: print(item)

{'data': 16, 'cached': True}
{'data': 81, 'cached': True}
{'data': 9, 'cached': True}
{'data': 49, 'cached': True}
{'data': 1, 'cached': True}
{'data': 4, 'cached': True}
{'data': 25, 'cached': True}
CPU times: user 31.2 ms, sys: 46.9 ms, total: 78.1 ms
Wall time: 7.04 s


When initializing the `PersistentDataset`, we passed `cache_dir="my_cache"` as the location to store the intermediate data. Let's look at that directory.

In [13]:
!ls my_cache

4778b171cb1049abbcf1032d03ff0afa.pt  b4f755104d6a0dbcb613830c6843e20a.pt
4c9197730c3e18666577f071056e22aa.pt  c21e0cfa7480c1552432f9970c278b2f.pt
98de00671e255e94c2f34ce3bee56982.pt  cf2dbfadfc25b1d7be23db09c39200ef.pt
aa9229f61411705e25ed1d31ed0b7f98.pt


On subsequent epochs, the dataset does not call the slow transform; it uses the cached data instead.

In [14]:
%timeit [item for item in square_persist]

5.55 ms ± 277 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Fresh dataset instances can also make use of the cached data:

In [15]:
square_persist_1 = monai.data.PersistentDataset(items, transform=SlowSquare(keys='data'), cache_dir="my_cache")
%timeit [item for item in square_persist_1]

keys to square it: ('data',)
5.23 ms ± 192 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
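
The directory listing above shows one file per item, each named with a 32-character hex string. A plausible way to get such names is a content hash of the item, which is also what lets a fresh instance find the files. The sketch below assumes an md5 digest of the JSON-serialized item and uses `pickle` instead of MONAI's `.pt` files; `item_key` and `TinyPersistentDataset` are hypothetical, and MONAI's exact hashing may differ.

```python
import hashlib
import json
import os
import pickle
import tempfile
from typing import Any, Callable, Dict, Sequence


def item_key(item: Dict[str, Any]) -> str:
    """Deterministic file name for an item: md5 of its sorted-key JSON form."""
    return hashlib.md5(json.dumps(item, sort_keys=True).encode("utf-8")).hexdigest()


class TinyPersistentDataset:
    """Hypothetical sketch: results are written to `cache_dir` on first fetch
    and read back on later fetches, including from new instances."""

    def __init__(self, data: Sequence[Dict[str, Any]], transform: Callable, cache_dir: str):
        self.data = data
        self.transform = transform
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Any:
        item = self.data[index]
        path = os.path.join(self.cache_dir, item_key(item) + ".pkl")
        if os.path.exists(path):  # cache hit: skip the transform
            with open(path, "rb") as f:
                return pickle.load(f)
        result = self.transform(item)  # cache miss: compute and persist
        with open(path, "wb") as f:
            pickle.dump(result, f)
        return result


calls = []


def square(x):
    calls.append(x)  # record every real transform call
    return {"data": x["data"] ** 2}


items = [{"data": v} for v in [4, 9, 3]]
cache_dir = tempfile.mkdtemp()
first = [x for x in TinyPersistentDataset(items, square, cache_dir)]   # computes and writes 3 files
second = [x for x in TinyPersistentDataset(items, square, cache_dir)]  # fresh instance reads from disk
print(first == second, len(calls), len(os.listdir(cache_dir)))  # True 3 3
```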


### Caching in action
- There's also a [SmartCacheDataset](https://docs.monai.io/en/latest/data.html#monai.data.SmartCacheDataset) that hides transform latency with less memory consumption.
- The dataset tutorial notebook has a working example and a comparison of the different caching mechanisms in MONAI: https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb
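
As we understand it, SmartCacheDataset keeps only a subset of the transformed items in memory and replaces a fraction of them each epoch (MONAI exposes this through parameters such as `cache_num` and `replace_rate`, and computes replacements in background threads). The single-threaded sketch below models only the sliding-window idea; `TinySmartCache`, `replace_num`, and `update` are hypothetical names, not MONAI's API.

```python
from collections import deque
from typing import Any, Callable, List, Sequence


class TinySmartCache:
    """Hypothetical, single-threaded sketch: only `cache_num` transformed items
    are held at once, and update() swaps the oldest `replace_num` of them
    for the next items in the dataset, wrapping around at the end."""

    def __init__(self, data: Sequence[Any], transform: Callable, cache_num: int, replace_num: int):
        self.data = data
        self.transform = transform
        self.replace_num = replace_num
        self._next = cache_num % len(data)  # index of the next item to load
        self._window = deque((i, transform(data[i])) for i in range(cache_num))

    def update(self) -> None:
        for _ in range(self.replace_num):
            self._window.popleft()  # evict the oldest cached item
            i = self._next
            self._window.append((i, self.transform(self.data[i])))
            self._next = (i + 1) % len(self.data)

    def cached_indices(self) -> List[int]:
        return [i for i, _ in self._window]

    def __len__(self) -> int:
        return len(self._window)

    def __getitem__(self, k: int) -> Any:
        return self._window[k][1]  # transformed value at window position k


cache = TinySmartCache(list(range(10)), lambda x: x * x, cache_num=4, replace_num=2)
print(cache.cached_indices())  # [0, 1, 2, 3]
cache.update()
print(cache.cached_indices())  # [2, 3, 4, 5]
cache.update()
print(cache.cached_indices())  # [4, 5, 6, 7]
```

Memory stays bounded by `cache_num` items, while over several epochs every item eventually passes through the cache.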


## Other dataset utilities

### ZipDataset

[ZipDataset](https://docs.monai.io/en/latest/data.html?highlight=dataset#zipdataset) zips several PyTorch datasets and outputs the data with the same index together in a tuple. If a single dataset's output is already a tuple, it is flattened and extended into the result. An optional transform can be applied to the combined element.

In [None]:
items = [4, 9, 3]
dataset_1 = monai.data.Dataset(items)

items = [7, 1, 2, 5]
dataset_2 = monai.data.Dataset(items)

def concat(data):
  # data[0] is an element from dataset_1
  # data[1] is an element from dataset_2
  return (f"{data[0]} + {data[1]} = {data[0] + data[1]}",)

zipped_data = monai.data.ZipDataset([dataset_1, dataset_2], transform=concat)
for item in zipped_data:
  print(item)
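
The per-index behavior described above can be sketched in plain Python. `zip_item` is a hypothetical helper, and we assume, as we understand ZipDataset to do, that the zipped length is that of the shortest input, so `dataset_2`'s fourth item is never used in the cell above.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


def zip_item(datasets: Sequence[Sequence[Any]], index: int,
             transform: Optional[Callable[[Tuple], Any]] = None) -> Any:
    """Hypothetical sketch of ZipDataset's behavior for a single index."""
    combined: Tuple = ()
    for ds in datasets:
        value = ds[index]
        # Tuple outputs are flattened into the result; anything else is appended
        combined += value if isinstance(value, tuple) else (value,)
    return transform(combined) if transform is not None else combined


dataset_1 = [4, 9, 3]
dataset_2 = [7, 1, 2, 5]
length = min(len(dataset_1), len(dataset_2))  # the shortest input sets the length
print([zip_item([dataset_1, dataset_2], i) for i in range(length)])
# [(4, 7), (9, 1), (3, 2)]
print(zip_item([dataset_1, [("a", "b")]], 0))
# (4, 'a', 'b')
```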

### Common Datasets

MONAI provides access to some commonly used medical imaging datasets through [DecathlonDataset](https://docs.monai.io/en/latest/data.html?highlight=dataset#decathlon-datalist). This class leverages the features described throughout this notebook.

In [None]:
dataset = monai.apps.DecathlonDataset(root_dir="./", task="Task04_Hippocampus", section="training", download=True)

In [None]:
print(dataset.get_properties("numTraining"))
print(dataset.get_properties("description"))

In [None]:
print(dataset[0]['image'].shape)
print(dataset[0]['label'].shape)

`DecathlonDataset` is an extension of `CacheDataset`.
More details of this API are covered in the other labs.
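
The `section="training"` argument above selects one part of the downloaded data. To our understanding, DecathlonDataset derives its training and validation sections from a seeded random split controlled by `val_frac` and `seed` parameters. The sketch below shows a seeded split of that kind; `split_indices` is hypothetical and not necessarily MONAI's exact procedure.

```python
import random
from typing import List


def split_indices(n: int, section: str, val_frac: float = 0.2, seed: int = 0) -> List[int]:
    """Hypothetical seeded split: shuffle indices once, hold out val_frac for validation."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed gives the same order on every run
    n_val = int(n * val_frac)
    if section == "validation":
        return indices[:n_val]
    if section == "training":
        return indices[n_val:]
    raise ValueError(f"unknown section: {section!r}")


train = split_indices(10, "training")
val = split_indices(10, "validation")
print(len(train), len(val))                    # 8 2
print(sorted(train + val) == list(range(10)))  # True
```

Fixing the seed matters: separately constructed training and validation datasets must agree on the shuffle, or the two sections could overlap.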

## Summary

In this notebook, we recapped the base dataset API and learned about:
- Caching with `CacheDataset` and `PersistentDataset`
- Combining datasets with `ZipDataset`
- Using `DecathlonDataset`

For full API documentation, please visit https://docs.monai.io.