These notes follow the [official LightningDataModule documentation](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html).

### LightningDataModule

A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data.

<img src="./assets/datamodule_overview.png" width="1000"/>

<br/>

<img src="./assets/datamodule_use.png" width="200"/>

A datamodule encapsulates the five steps involved in data processing in PyTorch:

* Download / tokenize / process.
* Clean and (maybe) save to disk.
* Load inside `Dataset`.
* Apply transforms (rotate, tokenize, etc…).
* Wrap inside a `DataLoader`.
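The five steps above can be sketched in plain PyTorch before any DataModule is involved. This is a minimal sketch that assumes synthetic data in place of a real download; the file name `features.pt`, the `SyntheticDataset` class and the `standardize` transform are hypothetical.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, Dataset

# 1. "Download" / process: synthetic data stands in for a real download (assumption)
raw_x = torch.randn(100, 8)
raw_y = torch.randint(0, 2, (100,))

# 2. Clean and save to disk: drop rows containing NaNs, then persist the result
keep = ~torch.isnan(raw_x).any(dim=1)
path = os.path.join(tempfile.mkdtemp(), "features.pt")
torch.save({"x": raw_x[keep], "y": raw_y[keep]}, path)


# 3. Load inside a Dataset (and 4. apply the transform in __getitem__)
class SyntheticDataset(Dataset):
    def __init__(self, path, transform=None):
        data = torch.load(path, weights_only=True)
        self.x, self.y = data["x"], data["y"]
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x = self.x[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.y[idx]


# Hypothetical transform: rescale each sample to zero mean and unit variance
def standardize(x):
    return (x - x.mean()) / (x.std() + 1e-8)


# 5. Wrap inside a DataLoader
dataset = SyntheticDataset(path, transform=standardize)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)
```

A DataModule takes exactly these pieces and gives each one a fixed place (`prepare_data`, `setup`, `*_dataloader`).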

This class can then be shared and used anywhere:
```python
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule

model = LitClassifier()
trainer = Trainer()

imagenet = ImagenetDataModule()
trainer.fit(model, datamodule=imagenet)

cifar10 = CIFAR10DataModule()
trainer.fit(model, datamodule=cifar10)
```

### Why do I need a DataModule?

In plain PyTorch code, data cleaning and preparation are often scattered across many files, which makes it hard to share and reuse the exact splits and transforms across projects.

DataModules are for you if you have ever been asked questions like:
* what splits did you use?
* what transforms did you use?
* what normalization did you use?
* how did you prepare/tokenize the data?
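One way to pin down the answer to "what splits did you use?" is to fix the randomness of the split itself. The sketch below uses `random_split` with a seeded `torch.Generator` (the seed 42 and the synthetic 100-sample dataset are arbitrary assumptions) and checks that two independent calls select identical indices.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Synthetic dataset of 100 samples (assumption: stands in for real data)
data = TensorDataset(torch.arange(100))


def make_split(seed: int):
    # A dedicated, seeded generator makes the split independent of the global RNG state
    generator = torch.Generator().manual_seed(seed)
    return random_split(data, [80, 20], generator=generator)


train_a, val_a = make_split(42)
train_b, val_b = make_split(42)

# Each Subset exposes the indices it selected from the full dataset
print(train_a.indices == train_b.indices)  # True
```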


### What is a DataModule?

[❗] A DataModule is simply a collection of:
* `train_dataloader`(s),
* `val_dataloader`(s),
* `test_dataloader`(s),
* `predict_dataloader`(s),
* the matching transforms,
* and the required data processing/download steps.

Here’s a simple PyTorch example:

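```python
# regular PyTorch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

my_path = "path/to/dir"
# MNIST yields PIL images by default; DataLoader's default collate needs tensors
to_tensor = transforms.ToTensor()

test_data = MNIST(my_path, train=False, download=True, transform=to_tensor)
predict_data = MNIST(my_path, train=False, download=True, transform=to_tensor)
train_data = MNIST(my_path, train=True, download=True, transform=to_tensor)
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)
```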

The equivalent DataModule organizes exactly the same code, but makes it reusable across projects.

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def setup(self, stage: str):
        # Assumes MNIST is already in data_dir; downloading belongs in prepare_data
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean up when the run is finished
        ...
```

As your processing grows more complex (transforms, multi-GPU training),
you can let Lightning handle those details for you while keeping the dataset reusable,
so you can share it with colleagues or use it in other projects.

```python
from pytorch_lightning import Trainer

mnist = MNISTDataModule(my_path)
model = LitClassifier()  # your own LightningModule

trainer = Trainer()
trainer.fit(model, mnist)
```

Here’s a more realistic DataModule that downloads the data, applies transforms, and builds only the datasets each stage needs.

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download (called on a single process, so don't assign state like self.x here)
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        # ("validate" is included so that trainer.validate() also finds mnist_val)
        if stage in ("fit", "validate"):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # a seeded generator keeps the train/val split identical across runs
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        # Assign predict dataset for use in dataloader(s)
        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)


dm = MNISTDataModule()
```

### LightningDataModule API

See full details of the API online:

https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html#lightningdatamodule-api

To define a DataModule, implement the following methods, which prepare the data and create the train/val/test/predict dataloaders:

* `prepare_data` (how to download, tokenize, etc…; called on a single process per node)
* `setup` (how to split, define datasets, etc…; called on every process)
* `train_dataloader`
* `val_dataloader`
* `test_dataloader`
* `predict_dataloader`


### Using a DataModule

The recommended way to use a DataModule is simply:

```python
from pytorch_lightning import Trainer

dm = MNISTDataModule()
model = Model()  # your own LightningModule
trainer = Trainer()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)
```

⚠️ If you need information from the dataset to build your model, run `prepare_data` and `setup` manually
(Lightning ensures the methods run on the correct devices).

```python
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")

# num_classes, width and vocab are placeholders: your DataModule must define them
model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)

dm.setup(stage="test")
trainer.test(datamodule=dm)
```

You can access:
* the datamodule currently used by a trainer via `trainer.datamodule`,
* and the dataloaders it currently uses via `trainer.train_dataloader`, `trainer.val_dataloaders` and `trainer.test_dataloaders`.

### DataModules without Lightning

You can of course use `DataModule`s in plain PyTorch code as well.

```python
# download, etc...
dm = MNISTDataModule()
dm.prepare_data()

# splits/transforms
dm.setup(stage="fit")

# use data
for batch in dm.train_dataloader():
    ...

for batch in dm.val_dataloader():
    ...

dm.teardown(stage="fit")

# lazy load test data
dm.setup(stage="test")
for batch in dm.test_dataloader():
    ...

dm.teardown(stage="test")
```

Either way, `DataModule`s encourage reproducibility by keeping all the details of a dataset in a single, unified structure.

### Hyperparameters in DataModules

Like `LightningModule`s, `DataModule`s support hyperparameters with the same API.

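```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32, num_workers: int = 0):
        super().__init__()
        # store the __init__ arguments in self.hparams
        self.save_hyperparameters()

    def train_dataloader(self):
        # access the saved hyperparameters
        # (self.train_set is assumed to be created in setup(), omitted here)
        return DataLoader(
            self.train_set,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
        )
```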

Refer to the `save_hyperparameters` documentation for `LightningModule` for more details.

### Save DataModule state

When a checkpoint is created, Lightning asks the DataModule for its state.

[❗] If your `DataModule` defines the `state_dict` and `load_state_dict` methods,
the checkpoint will automatically save that state and restore it when training resumes from the checkpoint.

```python
import pytorch_lightning as pl


class LitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.current_train_batch_index = 0

    def state_dict(self):
        # track whatever you want here
        state = {"current_train_batch_index": self.current_train_batch_index}
        return state

    def load_state_dict(self, state_dict):
        # restore the state based on what you tracked in state_dict()
        self.current_train_batch_index = state_dict["current_train_batch_index"]
```