A LightningModule organizes your PyTorch code into five sections:
- Computations (init)
- Train loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Optimizers (configure_optimizers)
Notice a few things.
- It's the SAME code.
- The PyTorch code IS NOT abstracted - just organized.
- All the other code that's not in the LightningModule has been automated for you by the trainer.
net = Net()
trainer = Trainer()
trainer.fit(net)
- There are no .cuda() or .to() calls... Lightning does these for you.
# don't do in lightning
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)

# do this instead
x = x  # leave it alone!

# or to init a new tensor
new_x = torch.Tensor(2, 3)
new_x = new_x.type_as(x)
- There are no samplers for distributed, Lightning also does this for you.
# don't do in Lightning...
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)

# do this instead
data = MNIST(...)
DataLoader(data)
- A LightningModule is a torch.nn.Module but with added functionality. Use it as such!
net = Net.load_from_checkpoint(PATH)
net.freeze()
out = net(x)
Thus, to use Lightning you just need to organize your code, which takes about 30 minutes (and let's be real, you probably should do that anyhow).
Here are the only required methods.
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
... def __init__(self):
... super().__init__()
... self.l1 = nn.Linear(28 * 28, 10)
...
... def forward(self, x):
... return torch.relu(self.l1(x.view(x.size(0), -1)))
...
... def training_step(self, batch, batch_idx):
... x, y = batch
... y_hat = self(x)
... loss = F.cross_entropy(y_hat, y)
... return loss
...
... def configure_optimizers(self):
... return torch.optim.Adam(self.parameters(), lr=0.02)
Which you can train by doing:
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()
trainer.fit(model, train_loader)
The LightningModule has many convenience methods, but the core ones you need to know about are:
Name | Description |
---|---|
init | Define computations here |
forward | Use for inference only (separate from training_step) |
training_step | the full training loop |
validation_step | the full validation loop |
test_step | the full test loop |
configure_optimizers | define optimizers and LR schedulers |
To add a training loop, use the training_step method:
class LitClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
Under the hood, Lightning does the following (pseudocode):
# put model in train mode
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # forward
    loss = training_step(batch)
    losses.append(loss.detach())

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()
If you want to calculate epoch-level metrics and log them, use the .log method
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
The .log method automatically reduces the requested metrics across the full epoch. Here's the pseudocode of what it does under the hood:
outs = []
for batch in train_dataloader:
    # forward
    out = training_step(batch)
    outs.append(out)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()

epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
If you need to do something with all the outputs of each training_step, override training_epoch_end yourself.
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {'loss': loss, 'other_stuff': preds}

def training_epoch_end(self, training_step_outputs):
    for pred in training_step_outputs:
        # do something
        ...
The matching pseudocode is:
outs = []
for batch in train_dataloader:
    # forward
    out = training_step(batch)
    outs.append(out)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()

training_epoch_end(outs)
When training using an accelerator that splits data from each batch across GPUs (dp or ddp2), sometimes you might need to aggregate them on the main GPU for processing.
In this case, implement the training_step_end method:
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def training_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        # do something with preds
        ...
The full pseudocode that Lightning does under the hood is:
outs = []
for train_batch in train_dataloader:
    batches = split_batch(train_batch)
    dp_outs = []
    for sub_batch in batches:
        # 1
        dp_out = training_step(sub_batch)
        dp_outs.append(dp_out)

    # 2
    out = training_step_end(dp_outs)
    outs.append(out)

# do something with the outputs for all batches
# 3
training_epoch_end(outs)
To add a validation loop, override the validation_step method of the LightningModule:
class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)
Under the hood, Lightning does the following:
# ...
for batch in train_dataloader:
    loss = model.training_step()
    loss.backward()
    # ...

    if validate_at_some_point:
        # disable grads + batchnorm + dropout
        torch.set_grad_enabled(False)
        model.eval()

        # ----------------- VAL LOOP ---------------
        for val_batch in model.val_dataloader:
            val_out = model.validation_step(val_batch)
        # ----------------- VAL LOOP ---------------

        # enable grads + batchnorm + dropout
        torch.set_grad_enabled(True)
        model.train()
If you need to do something with all the outputs of each validation_step, override validation_epoch_end.
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return pred

def validation_epoch_end(self, validation_step_outputs):
    for pred in validation_step_outputs:
        # do something with a pred
        ...
When training using an accelerator that splits data from each batch across GPUs (dp or ddp2), sometimes you might need to aggregate them on the main GPU for processing.
In this case, implement the validation_step_end method:
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def validation_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def validation_epoch_end(self, validation_step_outputs):
    for out in validation_step_outputs:
        # do something with preds
        ...
The full pseudocode that Lightning does under the hood is:
outs = []
for batch in dataloader:
    batches = split_batch(batch)
    dp_outs = []
    for sub_batch in batches:
        # 1
        dp_out = validation_step(sub_batch)
        dp_outs.append(dp_out)

    # 2
    out = validation_step_end(dp_outs)
    outs.append(out)

# do something with the outputs for all batches
# 3
validation_epoch_end(outs)
The process for adding a test loop is the same as the process for adding a validation loop. Please refer to the section above for details.
The only difference is that the test loop is only called when .test() is used:
model = Model()
trainer = Trainer()
trainer.fit(model)

# automatically loads the best weights for you
trainer.test(model)
There are two ways to call `test()`:
# call after training
trainer = Trainer()
trainer.fit(model)
# automatically loads the best weights
trainer.test(test_dataloaders=test_dataloader)
# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)
For research, LightningModules are best structured as systems.
import pytorch_lightning as pl
import torch
from torch import nn


class Autoencoder(pl.LightningModule):
    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # reconstruction
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        return reconstruction_loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        recons = self.decoder(z)
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        self.log('val_reconstruction', reconstruction_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
Which can be trained like this:
autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)
This simple model generates blurry reconstructions (the encoder and decoder are too weak).
The methods above are part of the Lightning interface:
- training_step
- validation_step
- test_step
- configure_optimizers
Note that in this case, the train loop and val loop are exactly the same. We can of course reuse this code.
class Autoencoder(pl.LightningModule):
    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log('val_loss', loss)

    def shared_step(self, batch):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # loss
        return nn.functional.mse_loss(recons, x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
We create a new method called shared_step that all loops can use. This method name is arbitrary and NOT reserved.
In the case where we want to perform inference with the system we can add a forward method to the LightningModule.
class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)
The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure, such as text generation:
class Seq2Seq(pl.LightningModule):
    def forward(self, x):
        embeddings = self.embedding(x)
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded
For cases like production, you might want to iterate over different models inside a LightningModule.
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM


class ClassificationTask(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        metrics = self.validation_step(batch, batch_idx)
        metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)
Then pass in any arbitrary model to be fit with this task
for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)
    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)
Tasks can be arbitrarily complex, such as implementing GAN training, self-supervised learning, or even RL.
class GANTask(pl.LightningModule):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        ...
When used like this, the model can be separated from the Task and thus used in production without needing to keep it in a LightningModule.
- You can export to ONNX.
- Or trace using TorchScript (JIT).
- Or run in the Python runtime.
task = ClassificationTask(model)
trainer = Trainer(gpus=2)
trainer.fit(task, train_dataloader, val_dataloader)
# use model after training or load weights and drop into the production system
model.eval()
y_hat = model(x)
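As a sketch of those export paths, here the model is traced with TorchScript so it can run outside Python (the layer sizes are placeholders, not taken from the Task above):

```python
import torch
from torch import nn

# stand-in for a trained model pulled out of its Task (hypothetical sizes)
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
model.eval()

example = torch.randn(1, 784)

# trace with TorchScript so the model can be saved and run without the Python class
scripted = torch.jit.trace(model, example)
out = scripted(example)

# exporting to ONNX is analogous:
# torch.onnx.export(model, example, "model.onnx")
```

The traced module produces the same outputs as the original and can be saved with `scripted.save(...)` for a Python-free runtime.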
pytorch_lightning.core.lightning.LightningModule.configure_callbacks
pytorch_lightning.core.lightning.LightningModule.configure_optimizers
pytorch_lightning.core.lightning.LightningModule.forward
pytorch_lightning.core.lightning.LightningModule.freeze
pytorch_lightning.core.lightning.LightningModule.log
pytorch_lightning.core.lightning.LightningModule.log_dict
pytorch_lightning.core.lightning.LightningModule.manual_backward
pytorch_lightning.core.lightning.LightningModule.print
pytorch_lightning.core.lightning.LightningModule.predict_step
pytorch_lightning.core.lightning.LightningModule.save_hyperparameters
pytorch_lightning.core.lightning.LightningModule.test_step
pytorch_lightning.core.lightning.LightningModule.test_step_end
pytorch_lightning.core.lightning.LightningModule.test_epoch_end
pytorch_lightning.core.lightning.LightningModule.to_onnx
pytorch_lightning.core.lightning.LightningModule.to_torchscript
pytorch_lightning.core.lightning.LightningModule.training_step
pytorch_lightning.core.lightning.LightningModule.training_step_end
pytorch_lightning.core.lightning.LightningModule.training_epoch_end
pytorch_lightning.core.lightning.LightningModule.unfreeze
pytorch_lightning.core.lightning.LightningModule.validation_step
pytorch_lightning.core.lightning.LightningModule.validation_step_end
pytorch_lightning.core.lightning.LightningModule.validation_epoch_end
pytorch_lightning.core.lightning.LightningModule.write_prediction
pytorch_lightning.core.lightning.LightningModule.write_prediction_dict
These are properties available in a LightningModule.
current_epoch: The current epoch

def training_step(...):
    if self.current_epoch == 0:
        ...
device: The device the module is on. Use it to keep your code device agnostic.

def training_step(...):
    z = torch.rand(2, 3, device=self.device)
global_rank: The global_rank of this LightningModule. Lightning saves logs, weights, etc. only from global_rank = 0. You normally do not need to use this property.
Global rank refers to the index of that GPU across ALL GPUs. For example, if using 10 machines, each with 4 GPUs, the 4th GPU on the 10th machine has global_rank = 39.
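The arithmetic behind that example can be made concrete with a small helper (illustrative only, not a Lightning API):

```python
def global_rank(node_idx: int, local_rank: int, gpus_per_node: int) -> int:
    # ranks are laid out node by node: node 0 holds ranks 0..gpus_per_node-1,
    # node 1 holds the next gpus_per_node ranks, and so on
    return node_idx * gpus_per_node + local_rank

# 10 machines with 4 GPUs each: the 4th GPU (local_rank=3)
# on the 10th machine (node_idx=9) has global_rank 39
rank = global_rank(node_idx=9, local_rank=3, gpus_per_node=4)
```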
global_step: The current step (does not reset each epoch)

def training_step(...):
    self.logger.experiment.log_image(..., step=self.global_step)
hparams: The arguments passed to __init__() and saved by calling save_hyperparameters can be accessed via the hparams attribute.
def __init__(self, learning_rate):
    super().__init__()
    self.save_hyperparameters()

def configure_optimizers(self):
    return Adam(self.parameters(), lr=self.hparams.learning_rate)
logger: The current logger being used (tensorboard or other supported logger)

def training_step(...):
    # the generic logger (same no matter if tensorboard or other supported logger)
    self.logger

    # the particular logger
    tensorboard_logger = self.logger.experiment
local_rank: The local_rank of this LightningModule. Lightning saves logs, weights, etc. only from global_rank = 0. You normally do not need to use this property.
Local rank refers to the rank on that machine. For example, if using 10 machines, the GPU at index 0 on each machine has local_rank = 0.
precision: The type of precision used:

def training_step(...):
    if self.precision == 16:
        ...
trainer: Pointer to the trainer

def training_step(...):
    max_steps = self.trainer.max_steps
    any_flag = self.trainer.any_flag
use_amp: True if using Automatic Mixed Precision (AMP)
automatic_optimization: When set to False, Lightning does not automate the optimization process. This means you are responsible for handling your optimizers. However, we do take care of precision and any accelerators used. See manual optimization for details.
def __init__(self):
    super().__init__()
    self.automatic_optimization = False

def training_step(self, batch, batch_idx):
    opt = self.optimizers(use_pl_optimizer=True)

    loss = ...
    opt.zero_grad()
    self.manual_backward(loss)
    opt.step()
This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the optimizer_idx parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.
def __init__(self):
    super().__init__()
    self.automatic_optimization = False

def training_step(self, batch, batch_idx):
    # access your optimizers (use_pl_optimizer=True is the default)
    opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

    gen_loss = ...
    opt_a.zero_grad()
    self.manual_backward(gen_loss)
    opt_a.step()

    disc_loss = ...
    opt_b.zero_grad()
    self.manual_backward(disc_loss)
    opt_b.step()
Set and access example_input_array which is basically a single batch.
def __init__(self):
    self.example_input_array = ...
    self.generator = ...

def on_train_epoch_end(...):
    # generate some images using the example_input_array
    gen_images = self.generator(self.example_input_array)
Set or access your datamodule.
def configure_optimizers(self):
    num_training_samples = len(self.trainer.datamodule.train_dataloader())
    ...
Get the model file size (in megabytes) using self.model_size inside the LightningModule.
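As a rough mental model of what that number means, the size scales with the parameter count times the bytes per parameter (a hand-rolled estimate for illustration, not Lightning's exact implementation):

```python
def approx_size_mb(num_params: int, bytes_per_param: int = 4) -> float:
    # float32 weights take 4 bytes each; convert total bytes to megabytes
    return num_params * bytes_per_param / 1e6

# e.g. a single 1000x1000 linear layer plus its bias is ~4 MB in float32
size = approx_size_mb(1000 * 1000 + 1000)
```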
Truncated backpropagation through time performs backward passes every k steps of a much longer sequence.
If this is enabled, your batches will automatically get truncated and the trainer will apply truncated backprop to each split.
from pytorch_lightning import LightningModule

class MyModel(LightningModule):

    def __init__(self):
        super().__init__()
        # Important: This property activates truncated backpropagation through time
        # Setting this value to 2 splits the batch into sequences of size 2
        self.truncated_bptt_steps = 2

    # Truncated back-propagation through time
    def training_step(self, batch, batch_idx, hiddens):
        # the training step must be updated to accept a `hiddens` argument
        # hiddens are the hiddens from the previous truncated backprop step
        out, hiddens = self.lstm(data, hiddens)
        return {"loss": ..., "hiddens": hiddens}
Lightning takes care to split your batch along the time-dimension.
# we use the second dimension as the time dimension
# (batch, time, ...)
sub_batch = batch[:, t:t + truncated_bptt_steps, ...]
To modify how the batch is split, override pytorch_lightning.core.LightningModule.tbptt_split_batch:
class LitMNIST(LightningModule):
    def tbptt_split_batch(self, batch, split_size):
        # do your own splitting on the batch
        return splits
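For example, a splitter that cuts along the time dimension could look like this: a sketch that assumes every element of the batch is a tensor laid out as (batch, time, ...), which is the default layout described above:

```python
import torch

def tbptt_split_batch(batch, split_size):
    # cut every tensor in the batch into chunks of `split_size` time steps,
    # returning a list of sub-batches (one per chunk)
    total_time = batch[0].size(1)
    return [
        [x[:, t:t + split_size, ...] for x in batch]
        for t in range(0, total_time, split_size)
    ]

x = torch.randn(4, 6, 8)  # (batch=4, time=6, features=8)
y = torch.randn(4, 6)     # (batch=4, time=6)
splits = tbptt_split_batch([x, y], split_size=2)  # 6 steps / 2 -> 3 sub-batches
```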
This is the pseudocode to describe how all the hooks are called during a call to .fit().
def fit(...):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()
    on_fit_start()

    for gpu/tpu in gpu/tpus:
        train_on_device(model.copy())

    on_fit_end()


def train_on_device(model):
    # setup is called PER DEVICE
    setup()
    configure_optimizers()
    on_pretrain_routine_start()

    for epoch in epochs:
        train_loop()

    teardown()


def train_loop():
    on_epoch_start()
    on_train_epoch_start()
    train_outs = []
    for train_batch in train_dataloader():
        on_train_batch_start()

        # ----- train_step methods -------
        out = training_step(train_batch)
        train_outs.append(out)

        loss = out.loss

        on_before_zero_grad()
        optimizer_zero_grad()

        backward()
        on_after_backward()

        optimizer_step()

        on_train_batch_end(out)

        if should_check_val:
            val_loop()

    # end training epoch
    training_epoch_end(train_outs)
    on_train_epoch_end(train_outs)
    on_epoch_end()


def val_loop():
    model.eval()
    torch.set_grad_enabled(False)

    on_epoch_start()
    on_validation_epoch_start()
    val_outs = []
    for val_batch in val_dataloader():
        on_validation_batch_start()

        # -------- val step methods -------
        out = validation_step(val_batch)
        val_outs.append(out)

        on_validation_batch_end(out)

    validation_epoch_end(val_outs)
    on_validation_epoch_end()
    on_epoch_end()

    # set up for train
    model.train()
    torch.set_grad_enabled(True)
pytorch_lightning.core.lightning.LightningModule.backward
pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
pytorch_lightning.core.hooks.ModelHooks.on_after_backward
pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad
pytorch_lightning.core.hooks.ModelHooks.on_fit_start
pytorch_lightning.core.hooks.ModelHooks.on_fit_end
pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint
pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint
pytorch_lightning.core.hooks.ModelHooks.on_train_start
pytorch_lightning.core.hooks.ModelHooks.on_train_end
pytorch_lightning.core.hooks.ModelHooks.on_validation_start
pytorch_lightning.core.hooks.ModelHooks.on_validation_end
pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_start
pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_end
pytorch_lightning.core.hooks.ModelHooks.on_test_batch_start
pytorch_lightning.core.hooks.ModelHooks.on_test_batch_end
pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_start
pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
pytorch_lightning.core.hooks.ModelHooks.on_test_end
pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start
pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
pytorch_lightning.core.hooks.ModelHooks.on_epoch_start
pytorch_lightning.core.hooks.ModelHooks.on_epoch_end
pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_start
pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_end
pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_start
pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_end
pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_start
pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device
pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval
pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train
pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval
pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
pytorch_lightning.core.lightning.LightningModule.optimizer_step
pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad
pytorch_lightning.core.lightning.LightningModule.prepare_data
pytorch_lightning.core.hooks.DataHooks.setup
pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch
pytorch_lightning.core.hooks.DataHooks.teardown
pytorch_lightning.core.hooks.DataHooks.train_dataloader
pytorch_lightning.core.hooks.DataHooks.val_dataloader
pytorch_lightning.core.hooks.DataHooks.test_dataloader
pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device
pytorch_lightning.core.hooks.DataHooks.on_before_batch_transfer
pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer