Commit 54601bd

* updated docs
1 parent afd7f67 commit 54601bd

File tree

5 files changed: +720 -313 lines changed

docs/source/datamodules.rst

+197
@@ -0,0 +1,197 @@
LightningDataModule
===================
Data preparation in PyTorch follows 5 steps:

1. Download / tokenize / process.
2. Clean and (maybe) save to disk.
3. Load inside :class:`~torch.utils.data.Dataset`.
4. Apply transforms (rotate, tokenize, etc...).
5. Wrap inside a :class:`~torch.utils.data.DataLoader`.

A DataModule is simply a collection of a train_dataloader, val_dataloader(s) and test_dataloader(s), along with the
matching transforms and the data processing/download steps required.

>>> import os
>>> import pytorch_lightning as pl
>>> from torch.utils.data import DataLoader, random_split
>>> from torchvision import transforms
>>> from torchvision.datasets import MNIST
>>> class MNISTDataModule(pl.LightningDataModule):
...     def prepare_data(self):
...         # download
...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
...
...     def setup(self, stage):
...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
...         # train/val split
...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
...
...         # assign to use in dataloaders
...         self.train_dataset = mnist_train
...         self.val_dataset = mnist_val
...         self.test_dataset = mnist_test
...
...     def train_dataloader(self):
...         return DataLoader(self.train_dataset, batch_size=64)
...
...     def val_dataloader(self):
...         return DataLoader(self.val_dataset, batch_size=64)
...
...     def test_dataloader(self):
...         return DataLoader(self.test_dataset, batch_size=64)

---------------

Methods
-------
To define a DataModule, define 5 methods:

- prepare_data (how to download, tokenize, etc...)
- setup (how to split, etc...)
- train_dataloader
- val_dataloader(s)
- test_dataloader(s)

prepare_data
^^^^^^^^^^^^
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed
settings:

- download
- tokenize
- etc...

>>> class MNISTDataModule(pl.LightningDataModule):
...     def prepare_data(self):
...         # download
...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

.. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`).
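
To see why, here is a minimal sketch of the anti-pattern next to the corrected split (the `vocab` attribute and the `build_vocab`/`save_vocab`/`load_vocab` helpers are hypothetical):

.. code-block:: python

    class BadDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # WRONG: in multi-GPU training this runs on one process only,
            # so the other processes never get `self.vocab`
            self.vocab = build_vocab()

    class GoodDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # OK: write the artifact to disk exactly once
            save_vocab(build_vocab(), "vocab.txt")

        def setup(self, stage):
            # OK: every process loads the state it needs
            self.vocab = load_vocab("vocab.txt")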

setup
^^^^^
There are also data operations you might want to perform on every GPU. Use `setup` to do things like:

- count number of classes
- build vocabulary
- perform train/val/test splits
- etc...

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
...     def setup(self, stage):
...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
...         # train/val split
...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
...
...         # assign to use in dataloaders
...         self.train_dataset = mnist_train
...         self.val_dataset = mnist_val
...         self.test_dataset = mnist_test

.. warning:: `setup` is called from every GPU. Setting state here is okay.
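
Since `setup` runs on every GPU, it is also the natural place to compute attributes the model will need later (such as the `dm.num_classes` read in the usage section below). A minimal sketch, relying on torchvision's `classes` attribute:

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):
        def setup(self, stage):
            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
            # expose dataset facts the model can read after dm.setup()
            self.num_classes = len(mnist_train.classes)
            self.train_dataset, self.val_dataset = random_split(mnist_train, [55000, 5000])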

train_dataloader
^^^^^^^^^^^^^^^^
Use this method to generate the train dataloader. This is also a good place to apply the default transformations.

>>> import pytorch_lightning as pl
>>> from torchvision import transforms as transform_lib
>>> class MNISTDataModule(pl.LightningDataModule):
...     def train_dataloader(self):
...         transforms = transform_lib.Compose([
...             transform_lib.ToTensor(),
...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
...         ])
...         dataset = MNIST(os.getcwd(), train=True, download=False, transform=transforms)
...         return DataLoader(dataset, batch_size=64)

Note that the transforms are applied by the :class:`~torch.utils.data.Dataset`, not the :class:`~torch.utils.data.DataLoader`.

However, to decouple your data from the transforms, you can parametrize them via `__init__`:

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):
        def __init__(self, train_transforms, val_transforms, test_transforms):
            super().__init__()
            self.train_transforms = train_transforms
            self.val_transforms = val_transforms
            self.test_transforms = test_transforms
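
Each dataloader method can then read its transform from the matching attribute; a minimal sketch of the train side:

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):
        ...

        def train_dataloader(self):
            dataset = MNIST(os.getcwd(), train=True, download=False, transform=self.train_transforms)
            return DataLoader(dataset, batch_size=64)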

val_dataloader
^^^^^^^^^^^^^^
Use this method to generate the val dataloader. This is also a good place to apply the default transformations.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
...     def val_dataloader(self):
...         transforms = transform_lib.Compose([
...             transform_lib.ToTensor(),
...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
...         ])
...         dataset = MNIST(os.getcwd(), train=True, download=False, transform=transforms)
...         _, mnist_val = random_split(dataset, [55000, 5000])
...         return DataLoader(mnist_val, batch_size=64)

test_dataloader
^^^^^^^^^^^^^^^
Use this method to generate the test dataloader. This is also a good place to apply the default transformations.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
...     def test_dataloader(self):
...         transforms = transform_lib.Compose([
...             transform_lib.ToTensor(),
...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
...         ])
...         dataset = MNIST(os.getcwd(), train=False, download=False, transform=transforms)
...         return DataLoader(dataset, batch_size=64)

------------------

Using a DataModule
------------------
The recommended way to use a DataModule is simply:

.. code-block:: python

    dm = MNISTDataModule()
    model = Model()
    trainer.fit(model, dm)

    trainer.test(datamodule=dm)

If you need information from the dataset to build your model, run `prepare_data` and `setup` manually (Lightning
still ensures the methods run on the correct devices):

.. code-block:: python

    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup()

    model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
    trainer.fit(model, dm)

    trainer.test(model, datamodule=dm)

----------------

Why use datamodules?
--------------------
DataModules have a few key advantages:

- They decouple the data from the model.
- They contain all the details anyone needs to reproduce the exact same data setup.
- They can be shared across models.
- They can be used without Lightning by calling the methods directly:

.. code-block:: python

    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup()

    for batch in dm.train_dataloader():
        ...
    for batch in dm.val_dataloader():
        ...
    for batch in dm.test_dataloader():
        ...
But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a
unified structure.

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ PyTorch Lightning Documentation
2222

2323
callbacks
2424
lightning-module
25+
datamodules
2526
loggers
2627
metrics
2728
hooks
