Commit 2e6e254

quick start docs (Lightning-AI#2731)
* added tests
1 parent 0fe933e commit 2e6e254

File tree: 2 files changed (+170 −41 lines)

docs/source/introduction_guide.rst

+1 −1
@@ -11,7 +11,7 @@ you've organized it into a LightningModule, it automates most of the training fo
 
 To illustrate, here's the typical PyTorch project structure organized in a LightningModule.
 
-.. figure:: /_images/mnist_imgs/pt_to_pl.jpg
+.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pt_animation_gif.gif
    :alt: Convert from PyTorch to Lightning
 
 As your project grows in complexity with things like 16-bit precision, distributed training, etc... the part in blue

docs/source/new-project.rst

+169 −40
@@ -16,7 +16,7 @@ Once you've organized it into a LightningModule, it automates most of the traini
 
 To illustrate, here's the typical PyTorch project structure organized in a LightningModule.
 
-.. figure:: /_images/mnist_imgs/pt_to_pl.jpg
+.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pt_animation_gif.gif
    :alt: Convert from PyTorch to Lightning
 
 ----------
@@ -51,9 +51,7 @@ A lightningModule defines
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
-        result = pl.TrainResult(minimize=loss, checkpoint_on=loss)
-        result.log('train_loss', loss, prog_bar=True)
-        return result
+        return loss
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=0.0005)
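For context, a minimal self-contained sketch of the kind of module this snippet comes from (the `LitModel` name and the single-linear-layer MNIST architecture are illustrative assumptions):

.. code-block:: python

    import torch
    from torch.nn import functional as F
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            # illustrative architecture: one linear layer over flattened 28x28 MNIST images
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            # flatten the image batch and return class logits
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.0005)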
@@ -68,20 +66,33 @@ well across any accelerator.
 .. code-block:: python
 
     # dataloader
-    train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), shuffle=True)
+    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
+    train_loader = DataLoader(dataset)
 
     # init model
     model = LitModel()
 
     # most basic trainer, uses good defaults
-    trainer = pl.Trainer(gpus=8, num_nodes=1)
-    trainer.fit(
-        model,
-        train_loader,
-    )
+    trainer = pl.Trainer()
+    trainer.fit(model, train_loader)
 
-    # to use advanced features such as GPUs/TPUs/16 bit you have to change NO CODE
-    trainer = pl.Trainer(tpu_cores=8, precision=16)
+Using GPUs/TPUs
+^^^^^^^^^^^^^^^
+It's trivial to use GPUs or TPUs in Lightning. There's NO NEED to change your code; simply change the Trainer options.
+
+.. code-block:: python
+
+    # train on 1, 2, 4, n GPUs
+    Trainer(gpus=1)
+    Trainer(gpus=2)
+    Trainer(gpus=8, num_nodes=n)
+
+    # train on TPUs
+    Trainer(tpu_cores=8)
+    Trainer(tpu_cores=128)
+
+    # even half precision
+    Trainer(gpus=2, precision=16)
 
 The code above gives you the following for free:
 
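These flags compose with the rest of the Trainer options. A sketch of putting the quick-start pieces together on a single GPU (the flag values and `max_epochs` limit are illustrative; it assumes the `LitModel` and `train_loader` from the snippets above):

.. code-block:: python

    import pytorch_lightning as pl

    model = LitModel()

    trainer = pl.Trainer(
        gpus=1,         # set to 0 or omit to stay on CPU
        precision=16,   # half precision needs a GPU; use the default 32 on CPU
        max_epochs=3,   # keep the run short for a quick start
    )
    trainer.fit(model, train_loader)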
@@ -123,6 +134,14 @@ Under the hood, lightning does (in high-level pseudocode):
         optimizer.step()
         optimizer.zero_grad()
 
+Main takeaways:
+
+- Lightning sets .train() and enables gradients when entering the training loop.
+- Lightning iterates over the epochs automatically.
+- Lightning iterates the dataloaders automatically.
+- `training_step` gives you full control of the main loop.
+- .backward(), .step(), .zero_grad() are called for you. But you can override them if you need manual control.
+
 ----------
 
 Adding a Validation loop
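For comparison, the loop that the hunk above describes looks roughly like this when written out in plain PyTorch (a sketch assuming the `LitModel` and `train_loader` from the earlier snippets; the epoch count is illustrative):

.. code-block:: python

    import torch
    from torch.nn import functional as F

    model = LitModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

    # what Lightning handles for you: mode/grad state, the epoch loop,
    # the dataloader loop, and the backward/step/zero_grad calls
    model.train()
    torch.set_grad_enabled(True)

    for epoch in range(3):
        for batch in train_loader:
            x, y = batch
            loss = F.cross_entropy(model(x), y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()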
@@ -137,10 +156,7 @@ To add an (optional) validation loop add the following function
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
-        result = EvalResult(early_stop_on=loss, checkpoint_on=loss)
-        result.log('val_ce', loss)
-        result.log('val_acc', accuracy(y_hat, y))
-        return result
+        return {'val_loss': loss, 'log': {'val_loss': loss}}
 
 And now the trainer will call the validation loop automatically
 
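For the validation loop to actually run, pass a validation DataLoader to `fit` alongside the training one. A sketch, assuming the MNIST `dataset` from the earlier snippet (the split sizes and batch size are illustrative):

.. code-block:: python

    from torch.utils.data import DataLoader, random_split

    # MNIST's training set has 60,000 samples; hold some out for validation
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)

    trainer = pl.Trainer()
    trainer.fit(model, train_loader, val_loader)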
@@ -166,14 +182,17 @@ Under the hood in pseudocode, lightning does the following:
     # ...
 
     if validate_at_some_point:
+        # disable grads + batchnorm + dropout
         torch.set_grad_enabled(False)
         model.eval()
+
         val_outs = []
         for val_batch in model.val_dataloader:
             val_out = model.validation_step(val_batch)
             val_outs.append(val_out)
-
         model.validation_epoch_end(val_outs)
+
+        # enable grads + batchnorm + dropout
         torch.set_grad_enabled(True)
         model.train()
 
@@ -197,10 +216,8 @@ You might also need an optional test loop
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult(early_stop_on=loss, checkpoint_on=loss)
-        result.log('test_ce', loss)
-        result.log('test_acc', accuracy(y_hat, y), prog_bar=True)
-        return result
+        return {'test_loss': loss, 'log': {'test_loss': loss}}
+
 
 However, this time you need to specifically call test (this is done so you don't use the test set by mistake)
 
@@ -214,7 +231,7 @@ However, this time you need to specifically call test (this is done so you don't
     # OPTION 2:
     # test after loading weights
    model = LitModel.load_from_checkpoint(PATH)
-    trainer = Trainer(tpu_cores=1)
+    trainer = Trainer()
     trainer.test(test_dataloaders=test_dataloader)
 
 Test loop under the hood
@@ -223,15 +240,21 @@ Under the hood, lightning does the following in (pseudocode):
 
 .. code-block:: python
 
+    # disable grads + batchnorm + dropout
     torch.set_grad_enabled(False)
     model.eval()
+
     test_outs = []
     for test_batch in model.test_dataloader:
         test_out = model.test_step(val_batch)
         test_outs.append(test_out)
 
     model.test_epoch_end(test_outs)
 
+    # enable grads + batchnorm + dropout
+    torch.set_grad_enabled(True)
+    model.train()
+
 ---------------
 
 Data
@@ -380,7 +403,7 @@ Next, materialize the data and build your model
     dm.setup()
 
     # pass in the properties you want
-    model = LitModel(image_width=dm.train_dims[0], image_height=dm.train_dims[1], vocab_length=dm.vocab_size)
+    model = LitModel(image_width=dm.train_dims[0], vocab_length=dm.vocab_size)
 
     # train
     trainer.fit(model, dm)
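For context, a minimal sketch of what a DataModule like `dm` could look like (the MNIST example, the `MNISTDataModule` name, and the `train_dims` stand-in are illustrative assumptions):

.. code-block:: python

    import os
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST
    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):

        def prepare_data(self):
            # called once to download the data
            MNIST(os.getcwd(), train=True, download=True)

        def setup(self, stage=None):
            # assign datasets and expose the properties the model reads
            self.mnist_train = MNIST(os.getcwd(), train=True, transform=transforms.ToTensor())
            self.train_dims = self.mnist_train[0][0].shape  # illustrative stand-in for dm.train_dims

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=32)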
@@ -389,26 +412,53 @@ Next, materialize the data and build your model
 
 Logging/progress bar
 --------------------
+
+|
+
+.. image:: /_images/mnist_imgs/mnist_tb.png
+    :width: 300
+    :align: center
+    :alt: Example TB logs
+
+|
+
 Lightning has built-in logging to any of the supported loggers or progress bar.
 
 Log in train loop
 ^^^^^^^^^^^^^^^^^
-To log from the training loop use the `TrainResult` object
+To log from the training loop use the `log` reserved key.
 
 .. code-block:: python
 
     def training_step(self, batch, batch_idx):
         loss = ...
-        acc = ...
+        return {'loss': loss, 'log': {'train_loss': loss}}
 
-        # pick what to minimize
-        result = pl.TrainResult(minimize=loss)
 
-        # logs metric at the end of every training step (batch) to the tensorboard or user-specified logger
+However, for more fine-grained control use the `TrainResult` object.
+These are equivalent:
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        loss = ...
+        return {'loss': loss, 'log': {'train_loss': loss}}
+
+    # equivalent
+    def training_step(self, batch, batch_idx):
+        loss = ...
+
+        result = pl.TrainResult(minimize=loss)
         result.log('train_loss', loss)
+        return result
 
-        # log to the progress bar only
-        result.log('train_acc', acc, prog_bar=True, logger=False)
+But the TrainResult gives you error-checking and greater flexibility:
+
+.. code-block:: python
+
+    # equivalent
+    result.log('train_loss', loss)
+    result.log('train_loss', loss, prog_bar=False, logger=True, on_step=True, on_epoch=False)
 
 Then boot up your logger or tensorboard instance to view training logs
 
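By default Lightning logs to TensorBoard under `./lightning_logs`; you can also configure the logger explicitly (a sketch; the save directory and experiment name are illustrative):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger

    # write TensorBoard event files under ./lightning_logs/mnist
    logger = TensorBoardLogger(save_dir='lightning_logs', name='mnist')

    trainer = Trainer(logger=logger)
    trainer.fit(model, train_loader)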
@@ -417,8 +467,7 @@ Then boot up your logger or tensorboard instance to view training logs
     tensorboard --logdir ./lightning_logs
 
 .. warning:: Refreshing the progress bar too frequently in Jupyter notebooks or Colab may freeze your UI.
-
-.. note:: TrainResult defaults to logging on every step, set `on_epoch` to also log the metric for the full epoch
+   We recommend you set `Trainer(progress_bar_refresh_rate=10)`.
 
 Log in Val/Test loop
 ^^^^^^^^^^^^^^^^^^^^
@@ -429,25 +478,105 @@ To log from the validation or test loop use a similar approach
     def validation_step(self, batch, batch_idx):
         loss = ...
         acc = ...
+        val_output = {'loss': loss, 'acc': acc}
+        return val_output
 
-        # pick what to minimize
-        result = pl.EvalResult(checkpoint_on=acc, early_stop_on=loss)
+    def validation_epoch_end(self, validation_step_outputs):
+        # this step allows you to aggregate whatever you passed in from every val step
+        val_epoch_loss = torch.stack([x['loss'] for x in validation_step_outputs]).mean()
+        val_epoch_acc = torch.stack([x['acc'] for x in validation_step_outputs]).mean()
+        return {
+            'val_loss': val_epoch_loss,
+            'log': {'avg_val_loss': val_epoch_loss, 'avg_val_acc': val_epoch_acc}
+        }
 
-        # log the val loss averaged across the full epoch
+The recommended equivalent version, in case you don't need to do anything special
+with all the outputs of the validation step:
+
+.. code-block:: python
+
+    def validation_step(self, batch, batch_idx):
+        loss = ...
+        acc = ...
+
+        result = pl.EvalResult(checkpoint_on=loss)
         result.log('val_loss', loss)
+        result.log('val_acc', acc)
+        return result
+
+.. note:: Only use `validation_epoch_end` if you need fine-grained control over aggregating all step outputs
+
+
+Log to the progress bar
+^^^^^^^^^^^^^^^^^^^^^^^
+|
+
+.. image:: /_images/mnist_imgs/mnist_cpu_bar.png
+    :width: 500
+    :align: center
+    :alt: Example CPU bar logging
 
-        # log the val acc at each step AND for the full epoch (mean)
-        result.log('val_acc', acc, prog_bar=True, logger=True, on_epoch=True, on_step=True)
+|
+
+In addition to visual logging, you can log to the progress bar by using the keyword `progress_bar`:
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        loss = ...
+        return {'loss': loss, 'progress_bar': {'train_loss': loss}}
+
+Or simply set `prog_bar=True` in either the `EvalResult` or `TrainResult`
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        result = TrainResult(loss)
+        result.log('train_loss', loss, prog_bar=True)
+        return result
 
-.. note:: EvalResult defaults to logging for the full epoch, use `reduce_fx=torch.mean` to specify a different function.
 
 -----------------
 
 Why do you need Lightning?
 --------------------------
-Notice the code above has nothing about .cuda() or 16-bit or early stopping or logging, etc...
-This is where Lightning adds a ton of value.
+The MAIN takeaway points are:
+
+- Lightning is for professional AI researchers/production teams.
+- Lightning is organized PyTorch. It is not an abstraction.
 
+Lightning is for you if
+^^^^^^^^^^^^^^^^^^^^^^^
+
+- You're a professional researcher/ML engineer working on non-trivial deep learning.
+- You already know PyTorch and are not a beginner.
+- You want to put models into production much faster.
+- You need full control of all the details but don't need the boilerplate.
+- You want to leverage code written by hundreds of AI researchers, research engineers and PhDs from the world's top AI labs.
+- You need GPUs, multi-node training, half-precision and TPUs.
+- You want research code that is rigorously tested (500+ tests) across CPUs/multi-GPUs/multi-TPUs on every pull-request.
+
+Some more cool features
+^^^^^^^^^^^^^^^^^^^^^^^
+Here are (some) of the other things you can do with lightning:
+
+- Automatic checkpointing.
+- Automatic early stopping.
+- Automatically overfit your model for a sanity test.
+- Automatic truncated-back-propagation-through-time.
+- Automatically scale your batch size.
+- Automatically attempt to find a good learning rate.
+- Add arbitrary callbacks.
+- Hit every line of your code once to see if you have bugs (instead of waiting hours to crash on validation ;)
+- Load checkpoints directly from S3.
+- Move from CPUs to GPUs or TPUs without code changes.
+- Profile your code for speed/memory bottlenecks.
+- Scale to massive compute clusters.
+- Use multiple dataloaders per train/val/test loop.
+- Use multiple optimizers to do Reinforcement learning or even GANs.
+
+Example:
+^^^^^^^^
 Without changing a SINGLE line of your code, you can now do the following with the above code
 
 .. code-block:: python
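To tie the logging examples above together, a sketch of one module that logs from both the training and validation loops using the Result objects shown here (the architecture and logged metric names are illustrative):

.. code-block:: python

    import torch
    from torch.nn import functional as F
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)  # illustrative architecture

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.0005)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            result = pl.TrainResult(minimize=loss)
            # logged to the logger on every training step by default
            result.log('train_loss', loss)
            return result

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            result = pl.EvalResult(checkpoint_on=loss)
            # averaged over the validation epoch by default and shown in the progress bar
            result.log('val_loss', loss, prog_bar=True)
            return result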