@@ -16,7 +16,7 @@ Once you've organized it into a LightningModule, it automates most of the traini
To illustrate, here's the typical PyTorch project structure organized in a LightningModule.

- .. figure:: /_images/mnist_imgs/pt_to_pl.jpg
+ .. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pt_animation_gif.gif
:alt: Convert from PyTorch to Lightning

----------
@@ -51,9 +51,7 @@ A lightningModule defines
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
- result = pl.TrainResult(minimize=loss, checkpoint_on=loss)
- result.log('train_loss', loss, prog_bar=True)
- return result
+ return loss

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.0005)
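For context, the pieces in this hunk drop into a minimal `LightningModule` like the sketch below (the single linear layer and the `forward` body are illustrative assumptions, not part of this diff):

.. code-block:: python

    import torch
    from torch.nn import functional as F
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # assumed one-layer MNIST classifier, purely for illustration
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.0005)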
@@ -68,20 +66,33 @@ well across any accelerator.
.. code-block:: python

# dataloader
- train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), shuffle=True)
+ dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
+ train_loader = DataLoader(dataset)

# init model
model = LitModel()

# most basic trainer, uses good defaults
- trainer = pl.Trainer(gpus=8, num_nodes=1)
- trainer.fit(
- model,
- train_loader,
- )
+ trainer = pl.Trainer()
+ trainer.fit(model, train_loader)

- # to use advanced features such as GPUs/TPUs/16 bit you have to change NO CODE
- trainer = pl.Trainer(tpu_cores=8, precision=16)
+ Using GPUs/TPUs
+ ^^^^^^^^^^^^^^^
+ It's trivial to use GPUs or TPUs in Lightning. There's NO NEED to change your code; simply change the Trainer options.
+
+ .. code-block:: python
+
+ # train on 1, 2, 4, n GPUs
+ Trainer(gpus=1)
+ Trainer(gpus=2)
+ Trainer(gpus=8, num_nodes=n)
+
+ # train on TPUs
+ Trainer(tpu_cores=8)
+ Trainer(tpu_cores=128)
+
+ # even half precision
+ Trainer(gpus=2, precision=16)

The code above gives you the following for free:
@@ -123,6 +134,14 @@ Under the hood, lightning does (in high-level pseudocode):
optimizer.step()
optimizer.zero_grad()

+ Main take-aways (expanded in the sketch below):
+
+ - Lightning sets .train() and enables gradients when entering the training loop.
+ - Lightning iterates over the epochs automatically.
+ - Lightning iterates over the dataloaders automatically.
+ - `training_step` gives you full control of the main loop.
+ - .backward(), .step(), .zero_grad() are called for you. BUT you can override this if you need manual control.
+
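For reference, the loop those take-aways describe looks roughly like the sketch below. This is a simplification of what the Trainer automates, not Lightning's actual source, and it assumes `model`, `optimizer`, `train_dataloader` and `max_epochs` are defined as in the snippets above:

.. code-block:: python

    # rough sketch of the automated loop
    torch.set_grad_enabled(True)
    model.train()

    for epoch in range(max_epochs):
        for batch_idx, batch in enumerate(train_dataloader):
            # training_step gives you full control of what happens per batch
            loss = model.training_step(batch, batch_idx)

            # these calls are made for you unless you take manual control
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
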
----------
Adding a Validation loop
@@ -137,10 +156,7 @@ To add an (optional) validation loop add the following function
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
- result = EvalResult(early_stop_on=loss, checkpoint_on=loss)
- result.log('val_ce', loss)
- result.log('val_acc', accuracy(y_hat, y))
- return result
+ return {'val_loss': loss, 'log': {'val_loss': loss}}

And now the trainer will call the validation loop automatically
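For example, a minimal sketch of wiring it up, assuming the imports and objects from the snippets above and a validation set built the same way as the train set:

.. code-block:: python

    # validation dataloader, mirroring the train loader above
    val_dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
    val_loader = DataLoader(val_dataset)

    trainer = pl.Trainer()
    trainer.fit(model, train_loader, val_loader)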
@@ -166,14 +182,17 @@ Under the hood in pseudocode, lightning does the following:
# ...

if validate_at_some_point:
+ # disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
model.eval()
+
val_outs = []
for val_batch in model.val_dataloader:
val_out = model.validation_step(val_batch)
val_outs.append(val_out)
-
model.validation_epoch_end(val_outs)
+
+ # enable grads + batchnorm + dropout
torch.set_grad_enabled(True)
model.train()
@@ -197,10 +216,8 @@ You might also need an optional test loop
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
- result = pl.EvalResult(early_stop_on=loss, checkpoint_on=loss)
- result.log('test_ce', loss)
- result.log('test_acc', accuracy(y_hat, y), prog_bar=True)
- return result
+ return {'test_loss': loss, 'log': {'test_loss': loss}}
+
However, this time you need to specifically call test (this is done so you don't use the test set by mistake)
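For instance, a minimal sketch of calling it after training (the `test_dataloader` here is an assumed dataloader built like the train/val loaders above):

.. code-block:: python

    # assumed test dataloader, built like the other loaders above
    test_dataloader = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()))

    # test after fit
    trainer = pl.Trainer()
    trainer.fit(model, train_loader)
    trainer.test(test_dataloaders=test_dataloader)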
@@ -214,7 +231,7 @@ However, this time you need to specifically call test (this is done so you don't
# OPTION 2:
# test after loading weights
model = LitModel.load_from_checkpoint(PATH)
- trainer = Trainer(tpu_cores=1)
+ trainer = Trainer()
trainer.test(test_dataloaders=test_dataloader)
Test loop under the hood
@@ -223,15 +240,21 @@ Under the hood, lightning does the following in (pseudocode):
.. code-block:: python

+ # disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
model.eval()
+
test_outs = []
for test_batch in model.test_dataloader:
test_out = model.test_step(test_batch)
test_outs.append(test_out)

model.test_epoch_end(test_outs)

+ # enable grads + batchnorm + dropout
+ torch.set_grad_enabled(True)
+ model.train()
+
---------------
Data
@@ -380,7 +403,7 @@ Next, materialize the data and build your model
dm.setup()

# pass in the properties you want
- model = LitModel(image_width=dm.train_dims[0], image_height=dm.train_dims[1], vocab_length=dm.vocab_size)
+ model = LitModel(image_width=dm.train_dims[0], vocab_length=dm.vocab_size)

# train
trainer.fit(model, dm)
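Here `dm` is a LightningDataModule. A minimal sketch of one (the MNIST dataset, the 55k/5k split and the exposed `train_dims` property are illustrative assumptions; something like `vocab_size` would come from a text dataset instead):

.. code-block:: python

    import os
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST

    class MNISTDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # download only (called once, on a single process)
            MNIST(os.getcwd(), train=True, download=True)

        def setup(self, stage=None):
            # transforms, splits and exposed properties (called on every process)
            mnist = MNIST(os.getcwd(), train=True, transform=transforms.ToTensor())
            self.train_dims = mnist[0][0].shape
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=32)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=32)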
@@ -389,26 +412,53 @@ Next, materialize the data and build your model
Logging/progress bar
--------------------
+
+ |
+
+ .. image:: /_images/mnist_imgs/mnist_tb.png
+ :width: 300
+ :align: center
+ :alt: Example TB logs
+
+ |
+
Lightning has built-in logging to any of the supported loggers or progress bar.

Log in train loop
^^^^^^^^^^^^^^^^^
- To log from the training loop use the `TrainResult` object
+ To log from the training loop use the `log` reserved key.

.. code-block:: python

def training_step(self, batch, batch_idx):
loss = ...
- acc = ...
+ return {'loss': loss, 'log': {'train_loss': loss}}

- # pick what to minimize
- result = pl.TrainResult(minimize=loss)

- # logs metric at the end of every training step (batch) to the tensorboard or user-specified logger
+ However, for more fine-grained control use the `TrainResult` object.
+ These are equivalent:
+
+ .. code-block:: python
+
+ def training_step(self, batch, batch_idx):
+ loss = ...
+ return {'loss': loss, 'log': {'train_loss': loss}}
+
+ # equivalent
+ def training_step(self, batch, batch_idx):
+ loss = ...
+
+ result = pl.TrainResult(minimize=loss)
result.log('train_loss', loss)
+ return result

- # log to the progress bar only
- result.log('train_acc', acc, prog_bar=True, logger=False)
+ But the TrainResult gives you error-checking and greater flexibility:
+
+ .. code-block:: python
+
+ # equivalent
+ result.log('train_loss', loss)
+ result.log('train_loss', loss, prog_bar=False, logger=True, on_step=True, on_epoch=False)

Then boot up your logger or tensorboard instance to view training logs
@@ -417,8 +467,7 @@ Then boot up your logger or tensorboard instance to view training logs
tensorboard --logdir ./lightning_logs

.. warning:: Refreshing the progress bar too frequently in Jupyter notebooks or Colab may freeze your UI.
-
- .. note:: TrainResult defaults to logging on every step, set `on_epoch` to also log the metric for the full epoch
+ We recommend you set `Trainer(progress_bar_refresh_rate=10)`.
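For example, a minimal sketch:

.. code-block:: python

    import pytorch_lightning as pl

    # refresh the progress bar every 10 batches so notebook/Colab UIs stay responsive
    trainer = pl.Trainer(progress_bar_refresh_rate=10)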
Log in Val/Test loop
^^^^^^^^^^^^^^^^^^^^
@@ -429,25 +478,105 @@ To log from the validation or test loop use a similar approach
def validation_step(self, batch, batch_idx):
loss = ...
acc = ...
+ val_output = {'loss': loss, 'acc': acc}
+ return val_output

- # pick what to minimize
- result = pl.EvalResult(checkpoint_on=acc, early_stop_on=loss)
+ def validation_epoch_end(self, validation_step_outputs):
+ # this step allows you to aggregate whatever you passed in from every val step
+ val_epoch_loss = torch.stack([x['loss'] for x in validation_step_outputs]).mean()
+ val_epoch_acc = torch.stack([x['acc'] for x in validation_step_outputs]).mean()
+ return {
+ 'val_loss': val_epoch_loss,
+ 'log': {'avg_val_loss': val_epoch_loss, 'avg_val_acc': val_epoch_acc}
+ }
- # log the val loss averaged across the full epoch
+ The recommended equivalent version, in case you don't need to do anything special
+ with all the outputs of the validation step:
+
+ .. code-block:: python
+
+ def validation_step(self, batch, batch_idx):
+ loss = ...
+ acc = ...
+
+ result = pl.EvalResult(checkpoint_on=loss)
result.log('val_loss', loss)
+ result.log('val_acc', acc)
+ return result
+
+ .. note:: Only use `validation_epoch_end` if you need fine-grained control over aggregating all the step outputs
+
+
+ Log to the progress bar
+ ^^^^^^^^^^^^^^^^^^^^^^^
+ |
+
+ .. image:: /_images/mnist_imgs/mnist_cpu_bar.png
+ :width: 500
+ :align: center
+ :alt: Example CPU bar logging

- # log the val acc at each step AND for the full epoch (mean)
- result.log('val_acc', acc, prog_bar=True, logger=True, on_epoch=True, on_step=True)
+ |
+
+ In addition to visual logging, you can log to the progress bar by using the keyword `progress_bar`:
+
+ .. code-block:: python
+
+ def training_step(self, batch, batch_idx):
+ loss = ...
+ return {'loss': loss, 'progress_bar': {'train_loss': loss}}
+
+ Or simply set `prog_bar=True` in either the `EvalResult` or the `TrainResult`
+
+ .. code-block:: python
+
+ def training_step(self, batch, batch_idx):
+ result = TrainResult(loss)
+ result.log('train_loss', loss, prog_bar=True)
+ return result

- .. note:: EvalResult defaults to logging for the full epoch, use `reduce_fx=torch.mean` to specify a different function.

-----------------
Why do you need Lightning?
--------------------------
- Notice the code above has nothing about .cuda() or 16-bit or early stopping or logging, etc...
- This is where Lightning adds a ton of value.
+ The MAIN takeaway points are:
+
+ - Lightning is for professional AI researchers/production teams.
+ - Lightning is organized PyTorch. It is not an abstraction.

+ Lightning is for you if
+ ^^^^^^^^^^^^^^^^^^^^^^^
+
+ - You're a professional researcher/ML engineer working on non-trivial deep learning.
+ - You already know PyTorch and are not a beginner.
+ - You want to put models into production much faster.
+ - You need full control of all the details but don't need the boilerplate.
+ - You want to leverage code written by hundreds of AI researchers, research engineers and PhDs from the world's top AI labs.
+ - You need GPUs, multi-node training, half-precision and TPUs.
+ - You want research code that is rigorously tested (500+ tests) across CPUs/multi-GPUs/multi-TPUs on every pull request.
+
+ Some more cool features
+ ^^^^^^^^^^^^^^^^^^^^^^^
+ Here are some of the other things you can do with Lightning (a few of the corresponding Trainer flags are sketched after this list):
+
+ - Automatic checkpointing.
+ - Automatic early stopping.
+ - Automatically overfit your model for a sanity test.
+ - Automatic truncated back-propagation through time.
+ - Automatically scale your batch size.
+ - Automatically attempt to find a good learning rate.
+ - Add arbitrary callbacks.
+ - Hit every line of your code once to see if you have bugs (instead of waiting hours to crash on validation).
+ - Load checkpoints directly from S3.
+ - Move from CPUs to GPUs or TPUs without code changes.
+ - Profile your code for speed/memory bottlenecks.
+ - Scale to massive compute clusters.
+ - Use multiple dataloaders per train/val/test loop.
+ - Use multiple optimizers to do reinforcement learning or even GANs.
+
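As referenced above, several of these features map onto single Trainer flags. A rough sketch (the flag names below reflect this era of Lightning and may differ in your installed version, so treat them as assumptions to verify against the Trainer docs):

.. code-block:: python

    import pytorch_lightning as pl

    trainer = pl.Trainer(
        fast_dev_run=True,              # run every line of train/val/test code once to surface bugs early
        overfit_batches=0.01,           # overfit on a tiny subset as a sanity test
        auto_lr_find=True,              # attempt to find a good learning rate
        auto_scale_batch_size='power',  # automatically scale the batch size
        profiler=True,                  # profile for speed/memory bottlenecks
    )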
+ Example:
+ ^^^^^^^^
Without changing a SINGLE line of your code, you can now do the following with the above code

.. code-block:: python
582
.. code-block :: python
0 commit comments