Commit 24c0cd7 (parent: f2caa01)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Showing 4 changed files with 632 additions and 1 deletion.
README.md:
## Build Your Own Trainer (BYOT)

This example demonstrates how easy it is to build a fully customizable trainer for your `LightningModule` using `Fabric`.
It is built upon `lightning.fabric` for hardware and training orchestration and consists of two files:

- `trainer.py` contains the actual `MyCustomTrainer` implementation (see the sketch after this list)
- `run.py` contains a script utilizing this trainer for training a very simple MNIST module
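For orientation, here is a minimal sketch of the kind of `Fabric`-driven loop that `trainer.py` builds on. The `Fabric` calls (`launch`, `setup`, `setup_dataloaders`, `backward`) are the real `lightning.fabric` API; everything else (the class name, the exact structure, scheduler handling, validation, checkpointing) is an illustrative assumption, not the actual `MyCustomTrainer` implementation:

```python
import lightning as L


class MinimalTrainerSketch:
    """Illustrative only: the real MyCustomTrainer in trainer.py is more featureful."""

    def __init__(self, accelerator="auto", devices="auto", max_epochs=3):
        # Fabric handles device placement, precision, and (optionally) multi-process launch
        self.fabric = L.Fabric(accelerator=accelerator, devices=devices)
        self.max_epochs = max_epochs

    def fit(self, model, train_loader):
        self.fabric.launch()
        # configure_optimizers() in this example returns (optimizer, scheduler_config);
        # scheduler handling is omitted in this sketch
        optimizer, _scheduler_cfg = model.configure_optimizers()
        model, optimizer = self.fabric.setup(model, optimizer)
        train_loader = self.fabric.setup_dataloaders(train_loader)
        for _epoch in range(self.max_epochs):
            model.train()
            for batch_idx, batch in enumerate(train_loader):
                optimizer.zero_grad()
                output = model.training_step(batch, batch_idx)
                self.fabric.backward(output["loss"])  # replaces loss.backward()
                optimizer.step()
```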
### Run

To run this example, call `python run.py`.

### Requirements

This example has the following requirements, which need to be installed in your Python environment:

- `lightning`
- `torchmetrics`
- `torch`
- `torchvision`
- `tqdm`

To install them with the appropriate versions, run:

```bash
pip install "lightning>=2.0" "torchmetrics>=0.11" "torchvision>=0.14" "torch>=1.13" tqdm
```
run.py:
import torch
from torchmetrics.functional.classification.accuracy import accuracy
from trainer import MyCustomTrainer

import lightning as L


class MNISTModule(L.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.Conv2d(16, 32, 5, 1, 2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Flatten(),
            # fully connected layer, output 10 classes
            torch.nn.Linear(32 * 7 * 7, 10),
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor):
        return self.model(x)

    def training_step(self, batch, batch_idx: int):
        x, y = batch

        logits = self(x)

        loss = self.loss_fn(logits, y)
        accuracy_train = accuracy(logits.argmax(-1), y, num_classes=10, task="multiclass", top_k=1)

        return {"loss": loss, "accuracy": accuracy_train}

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=1e-4)
        # MyCustomTrainer consumes this (optimizer, scheduler-config) pair;
        # ReduceLROnPlateau steps on the monitored metric once per epoch.
        return optim, {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode="max", verbose=True),
            "monitor": "val_accuracy",
            "interval": "epoch",
            "frequency": 1,
        }

    def validation_step(self, *args, **kwargs):
        # validation reuses the training step (loss + accuracy)
        return self.training_step(*args, **kwargs)


def train(model):
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor

    train_set = MNIST(root="/tmp/data/MNIST", train=True, transform=ToTensor(), download=True)
    val_set = MNIST(root="/tmp/data/MNIST", train=False, transform=ToTensor(), download=False)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=64, shuffle=True, pin_memory=torch.cuda.is_available(), num_workers=4
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=64, shuffle=False, pin_memory=torch.cuda.is_available(), num_workers=4
    )

    # The MPS backend currently does not support all operations used in this example.
    # To use MPS anyway, set accelerator="auto" and PYTORCH_ENABLE_MPS_FALLBACK=1.
    accelerator = "cpu" if torch.backends.mps.is_available() else "auto"

    trainer = MyCustomTrainer(
        accelerator=accelerator, devices="auto", limit_train_batches=10, limit_val_batches=20, max_epochs=3
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    train(MNISTModule())
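Note the non-standard contract above: unlike the stock Lightning `Trainer`, `configure_optimizers` here returns a plain `(optimizer, scheduler_config)` tuple for `MyCustomTrainer` to unpack. A minimal sketch of how a trainer might consume it, assuming a `val_accuracy` value aggregated from the validation steps (the variable names and the placeholder value are illustrative, not the actual trainer code):

```python
from run import MNISTModule

model = MNISTModule()
optimizer, scheduler_cfg = model.configure_optimizers()

# After each validation epoch (interval="epoch", frequency=1), step the
# ReduceLROnPlateau scheduler on the monitored metric ("val_accuracy", mode="max").
val_accuracy = 0.93  # placeholder: would come from aggregating validation_step outputs
scheduler_cfg["scheduler"].step(val_accuracy)
```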