A LightningModule organizes your PyTorch code into 5 sections:
- Computations (init).
- Train loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Optimizers (configure_optimizers)
Notice a few things.
- It's the SAME code.
- The PyTorch code IS NOT abstracted - just organized.
- All the other code that's not in the LightningModule has been automated for you by the trainer.
net = Net()
trainer = Trainer()
trainer.fit(net)
- There are no .cuda() or .to() calls... Lightning does these for you.
# don't do in Lightning
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)

# do this instead
x = x  # leave it alone!

# or to init a new tensor
new_x = torch.Tensor(2, 3)
new_x = new_x.type_as(x)
- There are no samplers for distributed training; Lightning also does this for you.
# don't do in Lightning
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)

# do this instead
data = MNIST(...)
DataLoader(data)
- A LightningModule is a torch.nn.Module but with added functionality. Use it as such!
net = Net.load_from_checkpoint(PATH)
net.freeze()
out = net(x)
Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes (and let's be real, you probably should do it anyhow).
Here are the only required methods.
>>> import torch
>>> from torch.nn import functional as F
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(28 * 28, 10)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
...     def training_step(self, batch, batch_idx):
...         x, y = batch
...         y_hat = self(x)
...         loss = F.cross_entropy(y_hat, y)
...         return loss
...
...     def configure_optimizers(self):
...         return torch.optim.Adam(self.parameters(), lr=0.02)
Which you can train by doing:
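import os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()
trainer.fit(model, train_loader)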
The LightningModule has many convenience methods, but the core ones you need to know about are:
Name | Description |
---|---|
init | Define computations here |
forward | Use for inference only (separate from training_step) |
training_step | the full training loop |
validation_step | the full validation loop |
test_step | the full test loop |
configure_optimizers | define optimizers and LR schedulers |
To add a training loop, use the training_step method:
class LitClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
Under the hood, Lightning does the following (pseudocode):
# put model in train mode
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # forward
    loss = training_step(batch)
    losses.append(loss.detach())

    # backward
    loss.backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()
If you want to calculate epoch-level metrics and log them, use the .log method:
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
The .log method automatically reduces the requested metrics across the full epoch. Here's the pseudocode of what it does under the hood:
outs = []
for batch in train_dataloader:
    # forward
    # out holds the values logged in training_step, e.g. {'train_loss': ...}
    out = training_step(batch)
    outs.append(out)

    # backward
    out['train_loss'].backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()

epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
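To make the reduction concrete, here is a minimal plain-Python sketch of the bookkeeping that on_step=True together with on_epoch=True implies. It is not Lightning's implementation: the MetricTracker class and its method names are hypothetical, the per-batch losses are made up, and plain floats stand in for tensors.

```python
from collections import defaultdict
from statistics import mean


class MetricTracker:
    """Hypothetical stand-in for the bookkeeping behind self.log (not Lightning code)."""

    def __init__(self):
        self.step_values = defaultdict(list)  # metric name -> one value per step

    def log(self, name, value):
        # on_step: the raw value is recorded as soon as it is logged
        self.step_values[name].append(float(value))

    def epoch_end(self):
        # on_epoch: each metric is reduced with a mean over all steps of the epoch
        reduced = {name: mean(values) for name, values in self.step_values.items()}
        self.step_values.clear()  # start the next epoch with an empty history
        return reduced


tracker = MetricTracker()
for step_loss in [0.9, 0.7, 0.5]:  # made-up per-batch losses
    tracker.log('train_loss', step_loss)

epoch_metrics = tracker.epoch_end()
print(epoch_metrics)
```

The sketch hardcodes a mean because that matches the torch.mean call in the pseudocode above.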
If you need to do something with all the outputs of each training_step, override training_epoch_end yourself.
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {'loss': loss, 'other_stuff': preds}

def training_epoch_end(self, training_step_outputs):
    for pred in training_step_outputs:
        # do something
        ...
The matching pseudocode is:
outs = []
for batch in train_dataloader:
    # forward
    out = training_step(batch)
    outs.append(out)

    # backward
    out['loss'].backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()

training_epoch_end(outs)
When training using an accelerator that splits data from each batch across GPUs (dp or ddp2), you might sometimes need to aggregate the per-GPU outputs on the master GPU for processing.
In this case, implement the training_step_end method:
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def training_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        # do something with preds
        ...
The full pseudocode of what Lightning does under the hood is:
outs = []
for train_batch in train_dataloader:
    batches = split_batch(train_batch)
    dp_outs = []
    for sub_batch in batches:
        # 1
        dp_out = training_step(sub_batch)
        dp_outs.append(dp_out)

    # 2
    out = training_step_end(dp_outs)
    outs.append(out)

# do something with the outputs for all batches
# 3
training_epoch_end(outs)
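The three stages can be traced end to end with a small plain-Python sketch. Everything in it is illustrative: split_batch is a hypothetical helper, lists of floats stand in for batches, and the "loss" is just the mean of the values in a sub-batch. Unlike the two-GPU example above, training_step_end here weights each part by its size, so the result stays correct when the parts are unequal.

```python
def split_batch(batch, num_parts):
    """Split a list into num_parts contiguous chunks (hypothetical helper)."""
    size = (len(batch) + num_parts - 1) // num_parts
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def training_step(sub_batch):
    # 1: runs once per device on that device's slice of the batch
    return {'loss': sum(sub_batch) / len(sub_batch), 'n': len(sub_batch)}


def training_step_end(batch_parts):
    # 2: combines the per-device outputs into a single output for the batch
    total = sum(part['n'] for part in batch_parts)
    loss = sum(part['loss'] * part['n'] for part in batch_parts) / total
    return {'loss': loss}


def training_epoch_end(outs):
    # 3: sees one aggregated output per batch
    return sum(out['loss'] for out in outs) / len(outs)


train_dataloader = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # toy "batches"

outs = []
for train_batch in train_dataloader:
    dp_outs = [training_step(sub) for sub in split_batch(train_batch, num_parts=2)]
    outs.append(training_step_end(dp_outs))

epoch_loss = training_epoch_end(outs)
print(outs, epoch_loss)
```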
To add a validation loop, override the validation_step method of the LightningModule:
class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)
Under the hood, Lightning does the following:
# ...
for batch in train_dataloader:
    loss = model.training_step(batch)
    loss.backward()
    # ...

    if validate_at_some_point:
        # disable grads + batchnorm + dropout
        torch.set_grad_enabled(False)
        model.eval()

        # ----------------- VAL LOOP ---------------
        for val_batch in model.val_dataloader():
            val_out = model.validation_step(val_batch)
        # ----------------- VAL LOOP ---------------

        # enable grads + batchnorm + dropout
        torch.set_grad_enabled(True)
        model.train()
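The interleaving of training and validation, and the mode switch around the validation loop, can be illustrated with a plain-Python sketch. ToyModel, fit and the val_every_n_batches parameter are hypothetical names used only for illustration; the toy model tracks nothing but its train/eval flag, and gradient toggling is left out.

```python
class ToyModel:
    """Hypothetical stand-in for a module that only tracks its train/eval mode."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


def fit(model, train_batches, val_batches, val_every_n_batches):
    events = []
    for i, batch in enumerate(train_batches, start=1):
        events.append(('train', batch, model.training))
        if i % val_every_n_batches == 0:  # plays the role of validate_at_some_point
            model.eval()
            try:
                for val_batch in val_batches:
                    events.append(('val', val_batch, model.training))
            finally:
                model.train()  # training mode is restored even if validation raises
    return events


events = fit(ToyModel(), train_batches=[0, 1, 2, 3], val_batches=['a', 'b'], val_every_n_batches=2)
for event in events:
    print(event)
```

Each recorded event carries the mode flag at that moment, so the log shows training batches running in train mode and validation batches running in eval mode.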
If you need to do something with all the outputs of each validation_step, override validation_epoch_end.
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return pred

def validation_epoch_end(self, validation_step_outputs):
    for pred in validation_step_outputs:
        # do something with a pred
        ...
When training using an accelerator that splits data from each batch across GPUs (dp or ddp2), you might sometimes need to aggregate the per-GPU outputs on the master GPU for processing.
In this case, implement the validation_step_end method:
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def validation_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def validation_epoch_end(self, validation_step_outputs):
    for out in validation_step_outputs:
        # do something with preds
        ...
The full pseudocode of what Lightning does under the hood is:
outs = []
for batch in dataloader:
    batches = split_batch(batch)
    dp_outs = []
    for sub_batch in batches:
        # 1
        dp_out = validation_step(sub_batch)
        dp_outs.append(dp_out)

    # 2
    out = validation_step_end(dp_outs)
    outs.append(out)

# do something with the outputs for all batches
# 3
validation_epoch_end(outs)
The process for adding a test loop is the same as the process for adding a validation loop. Please refer to the section above for details.
The only difference is that the test loop is only called when .test() is used:
model = Model()
trainer = Trainer()
trainer.fit(model)
# automatically loads the best weights for you
trainer.test(model)
There are two ways to call `test()`:
# call after training
trainer = Trainer()
trainer.fit(model)
# automatically loads the best weights
trainer.test(test_dataloaders=test_dataloader)
# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)
For research, LightningModules are best structured as systems.
import pytorch_lightning as pl
import torch
from torch import nn

class Autoencoder(pl.LightningModule):
    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # reconstruction
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        return reconstruction_loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        recons = self.decoder(z)
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        self.log('val_reconstruction', reconstruction_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
Which can be trained like this:
autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)
This simple model generates examples that look like this (the encoders and decoders are too weak)
The methods above are part of the lightning interface:
- training_step
- validation_step
- test_step
- configure_optimizers
Note that in this case, the train loop and val loop are exactly the same. We can of course reuse this code.
class Autoencoder(pl.LightningModule):
    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log('val_loss', loss)

    def shared_step(self, batch):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # loss
        return nn.functional.mse_loss(recons, x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
We create a new method called shared_step that all loops can use. This method name is arbitrary and NOT reserved.
In the case where we want to perform inference with the system we can add a forward method to the LightningModule.
class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)
The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure, such as text generation:
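class Seq2Seq(pl.LightningModule):
    def forward(self, x):
        # self.embedding is an assumed embedding layer; calling self(x) here would recurse into forward
        embeddings = self.embedding(x)
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded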
For cases like production, you might want to iterate different models inside a LightningModule.
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

class ClassificationTask(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)

        # loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        metrics = self.validation_step(batch, batch_idx)
        metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)
Then pass in any arbitrary model to be fit with this task
for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)

    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)
Tasks can be arbitrarily complex such as implementing GAN training, self-supervised or even RL.
class GANTask(pl.LightningModule):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
    ...
When used like this, the model can be separated from the Task and thus used in production without needing to keep it in a LightningModule.
- You can export to ONNX.
- Or trace using JIT.
- Or run in the Python runtime.
task = ClassificationTask(model)
trainer = Trainer(gpus=2)
trainer.fit(task, train_dataloader, val_dataloader)
# use model after training or load weights and drop into the production system
model.eval()
y_hat = model(x)
pytorch_lightning.core.lightning.LightningModule.configure_optimizers
pytorch_lightning.core.lightning.LightningModule.forward
pytorch_lightning.core.lightning.LightningModule.freeze
pytorch_lightning.core.lightning.LightningModule.log
pytorch_lightning.core.lightning.LightningModule.log_dict
pytorch_lightning.core.lightning.LightningModule.print
pytorch_lightning.core.lightning.LightningModule.save_hyperparameters
pytorch_lightning.core.lightning.LightningModule.test_step
pytorch_lightning.core.lightning.LightningModule.test_step_end
pytorch_lightning.core.lightning.LightningModule.test_epoch_end
pytorch_lightning.core.lightning.LightningModule.to_onnx
pytorch_lightning.core.lightning.LightningModule.to_torchscript
pytorch_lightning.core.lightning.LightningModule.training_step
pytorch_lightning.core.lightning.LightningModule.training_step_end
pytorch_lightning.core.lightning.LightningModule.training_epoch_end
pytorch_lightning.core.lightning.LightningModule.unfreeze
pytorch_lightning.core.lightning.LightningModule.validation_step
pytorch_lightning.core.lightning.LightningModule.validation_step_end
pytorch_lightning.core.lightning.LightningModule.validation_epoch_end
These are properties available in a LightningModule.
The current epoch
def training_step(...):
    if self.current_epoch == 0:
        ...
The device the module is on. Use it to keep your code device agnostic.
def training_step(...):
    z = torch.rand(2, 3, device=self.device)
The global_rank of this LightningModule. Lightning saves logs, weights, etc. only from global_rank = 0. You normally do not need to use this property.
Global rank refers to the index of that GPU across ALL GPUs. For example, if using 10 machines, each with 4 GPUs, the 4th GPU on the 10th machine has global_rank = 39.
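The arithmetic behind that example can be written out as a short sketch. The global_rank helper below is hypothetical (in Lightning the value is simply read from self.global_rank). It assumes every machine has the same number of GPUs and that ranks are assigned machine by machine, with local_rank being the zero-based index of a GPU on its own machine.

```python
def global_rank(node_index, local_rank, gpus_per_node):
    """Global index of a GPU, assuming equal GPU counts per machine and
    ranks assigned machine by machine (assumptions of this sketch)."""
    if not 0 <= local_rank < gpus_per_node:
        raise ValueError('local_rank must be in [0, gpus_per_node)')
    return node_index * gpus_per_node + local_rank


# 10 machines with 4 GPUs each; indices are zero-based, so the
# 4th GPU on the 10th machine is node_index=9, local_rank=3.
rank = global_rank(node_index=9, local_rank=3, gpus_per_node=4)
print(rank)  # 39
```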
The current step (does not reset each epoch)
def training_step(...):
    self.logger.experiment.log_image(..., step=self.global_step)
After calling save_hyperparameters, anything passed to init() is available via hparams.
def __init__(self, learning_rate):
    super().__init__()
    self.save_hyperparameters()

def configure_optimizers(self):
    return Adam(self.parameters(), lr=self.hparams.learning_rate)
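One way such a method could capture the init arguments is by inspecting the caller's stack frame. The sketch below is a plain-Python illustration of that idea only: HParamsMixin and ToyModule are hypothetical names, and Lightning's actual save_hyperparameters is more involved than this.

```python
import inspect
from types import SimpleNamespace


class HParamsMixin:
    """Hypothetical mixin; Lightning's real save_hyperparameters is more involved."""

    def save_hyperparameters(self):
        # Look at the caller's frame, which is expected to be __init__
        frame = inspect.currentframe().f_back
        arg_info = inspect.getargvalues(frame)
        # Keep every named argument of __init__ except self
        init_args = {name: arg_info.locals[name] for name in arg_info.args if name != 'self'}
        self.hparams = SimpleNamespace(**init_args)


class ToyModule(HParamsMixin):
    def __init__(self, learning_rate, hidden_dim=128):
        self.save_hyperparameters()


module = ToyModule(learning_rate=0.02)
print(module.hparams)
```

Default values are captured as well, so module.hparams.hidden_dim is available even though it was not passed explicitly.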
The current logger being used (tensorboard or other supported logger)
def training_step(...):
    # the generic logger (same no matter if tensorboard or other supported logger)
    self.logger

    # the particular logger
    tensorboard_logger = self.logger.experiment
The local_rank of this LightningModule. Lightning saves logs, weights, etc. only from global_rank = 0. You normally do not need to use this property.
Local rank refers to the rank on that machine. For example, if using 10 machines, the GPU at index 0 on each machine has local_rank = 0.
The type of precision used:
def training_step(...):
    if self.precision == 16:
        ...
Pointer to the trainer
def training_step(...):
    max_steps = self.trainer.max_steps
    any_flag = self.trainer.any_flag
True if using Automatic Mixed Precision (AMP)
True if using ddp
True if using ddp2
True if using dp
True if using TPUs
This is the pseudocode to describe how all the hooks are called during a call to .fit():
def fit(...):
    on_fit_start()

    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    for gpu/tpu in gpu/tpus:
        train_on_device(model.copy())

    on_fit_end()

def train_on_device(model):
    # setup is called PER DEVICE
    setup()
    configure_optimizers()
    on_pretrain_routine_start()

    for epoch in epochs:
        train_loop()

    teardown()

def train_loop():
    on_train_epoch_start()
    train_outs = []
    for train_batch in train_dataloader():
        on_train_batch_start()

        # ----- train_step methods -------
        out = training_step(train_batch)
        train_outs.append(out)

        loss = out.loss

        backward()
        on_after_backward()
        optimizer_step()
        on_before_zero_grad()
        optimizer_zero_grad()

        on_train_batch_end(out)

        if should_check_val:
            val_loop()

    # end training epoch
    logs = training_epoch_end(train_outs)

def val_loop():
    model.eval()
    torch.set_grad_enabled(False)

    on_validation_epoch_start()
    val_outs = []
    for val_batch in val_dataloader():
        on_validation_batch_start()

        # -------- val step methods -------
        out = validation_step(val_batch)
        val_outs.append(out)

        on_validation_batch_end(out)

    validation_epoch_end(val_outs)
    on_validation_epoch_end()

    # set up for train
    model.train()
    torch.set_grad_enabled(True)
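To see the ordering play out, the train and validation loops from the pseudocode can be replayed against an object that merely records which hook was called. This is a plain-Python sketch, not Lightning code: HookRecorder and the check_val_every parameter are hypothetical, and no real training happens.

```python
class HookRecorder:
    """Records the name of every hook that is called (hypothetical helper)."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Any attribute not found normally is treated as a hook: calling it logs its name
        def hook(*args, **kwargs):
            self.calls.append(name)
        return hook


def train_loop(hooks, train_batches, val_batches, check_val_every):
    hooks.on_train_epoch_start()
    for i, batch in enumerate(train_batches, start=1):
        hooks.on_train_batch_start()
        hooks.training_step(batch)
        hooks.backward()
        hooks.on_after_backward()
        hooks.optimizer_step()
        hooks.on_before_zero_grad()
        hooks.optimizer_zero_grad()
        hooks.on_train_batch_end()
        if i % check_val_every == 0:  # plays the role of should_check_val
            val_loop(hooks, val_batches)
    hooks.training_epoch_end()


def val_loop(hooks, val_batches):
    hooks.on_validation_epoch_start()
    for batch in val_batches:
        hooks.on_validation_batch_start()
        hooks.validation_step(batch)
        hooks.on_validation_batch_end()
    hooks.validation_epoch_end()
    hooks.on_validation_epoch_end()


recorder = HookRecorder()
train_loop(recorder, train_batches=[0], val_batches=[0], check_val_every=1)
print('\n'.join(recorder.calls))
```

With one training batch and one validation batch, the printed list shows the validation hooks nested inside the training epoch, before training_epoch_end.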
pytorch_lightning.core.lightning.LightningModule.backward
pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
pytorch_lightning.core.lightning.LightningModule.manual_backward
pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
pytorch_lightning.core.lightning.LightningModule.on_after_backward
pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
pytorch_lightning.core.hooks.ModelHooks.on_fit_start
pytorch_lightning.core.hooks.ModelHooks.on_fit_end
pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_start
pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_end
pytorch_lightning.core.hooks.ModelHooks.on_test_batch_start
pytorch_lightning.core.hooks.ModelHooks.on_test_batch_end
pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_start
pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start
pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_start
pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_end
pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_start
pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_end
pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_start
pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
pytorch_lightning.core.lightning.LightningModule.optimizer_step
pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad
pytorch_lightning.core.lightning.LightningModule.prepare_data
pytorch_lightning.core.hooks.ModelHooks.setup
pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch
pytorch_lightning.core.hooks.ModelHooks.teardown
pytorch_lightning.core.lightning.LightningModule.train_dataloader
pytorch_lightning.core.lightning.LightningModule.val_dataloader
pytorch_lightning.core.lightning.LightningModule.test_dataloader
pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device