Add a notebook example to reach a quick baseline of ~94% accuracy on CIFAR (#4818)

* Add a notebook example to reach a quick baseline of ~94% accuracy on CIFAR10 using Resnet in Lightning
* Remove outputs
* PR Feedback
* some changes
* some more changes

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
1 parent 4ebce38 · commit 820d5c7
Showing 2 changed files with 402 additions and 7 deletions.
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "06_cifar10_baseline.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "qMDj0BYNECU8"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-cifar10-pytorch-lightning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ECu0zDh8UXU8"
},
"source": [
"# PyTorch Lightning CIFAR10 ~94% Baseline Tutorial ⚡\n", | ||
"\n", | ||
"Train a Resnet to 94% accuracy on Cifar10!\n", | ||
"\n", | ||
"Main takeaways:\n", | ||
"1. Experiment with different Learning Rate schedules and frequencies in the configure_optimizers method in pl.LightningModule\n", | ||
"2. Use an existing Resnet architecture with modifications directly with Lightning\n", | ||
"\n", | ||
"---\n", | ||
"\n", | ||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", | ||
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", | ||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" | ||
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HYpMlx7apuHq"
},
"source": [
"### Setup\n",
"Lightning is easy to install. Simply `pip install pytorch-lightning`.\n",
"Also check out [bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts/) for pre-existing data modules and models."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ziAQCrE-TYWG"
},
"source": [
"! pip install pytorch-lightning pytorch-lightning-bolts -qU"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "L-W_Gq2FORoU"
},
"source": [
"# Run this if you intend to use TPUs\n",
"# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
"# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wjov-2N_TgeS"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.optim.lr_scheduler import OneCycleLR\n",
"from torch.optim.swa_utils import AveragedModel, update_bn\n",
"import torchvision\n",
"\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.callbacks import LearningRateMonitor\n",
"from pytorch_lightning.metrics.functional import accuracy\n",
"from pl_bolts.datamodules import CIFAR10DataModule\n",
"from pl_bolts.transforms.dataset_normalizations import cifar10_normalization"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "54JMU1N-0y0g"
},
"source": [
"pl.seed_everything(7);"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "FA90qwFcqIXR"
},
"source": [
"### CIFAR10 Data Module\n",
"\n",
"Import the existing data module from `bolts` and modify the train and test transforms."
]
},
{
"cell_type": "code",
"metadata": {
"id": "S9e-W8CSa8nH"
},
"source": [
"batch_size = 32\n",
"\n",
"train_transforms = torchvision.transforms.Compose([\n",
"    torchvision.transforms.RandomCrop(32, padding=4),\n",
"    torchvision.transforms.RandomHorizontalFlip(),\n",
"    torchvision.transforms.ToTensor(),\n",
"    cifar10_normalization(),\n",
"])\n",
"\n",
"test_transforms = torchvision.transforms.Compose([\n",
"    torchvision.transforms.ToTensor(),\n",
"    cifar10_normalization(),\n",
"])\n",
"\n",
"cifar10_dm = CIFAR10DataModule(\n",
"    batch_size=batch_size,\n",
"    train_transforms=train_transforms,\n",
"    test_transforms=test_transforms,\n",
"    val_transforms=test_transforms,\n",
")"
],
"execution_count": null,
"outputs": []
},
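{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, we can optionally sanity-check the data module. This is a minimal sketch using the standard `LightningDataModule` hooks (`prepare_data`, `setup`, `train_dataloader`); it is not required, since `Trainer.fit` calls these hooks itself."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optional sanity check: download the data, build the splits, and inspect one batch.\n",
"cifar10_dm.prepare_data()\n",
"cifar10_dm.setup()\n",
"images, labels = next(iter(cifar10_dm.train_dataloader()))\n",
"print(images.shape, labels.shape)  # torch.Size([32, 3, 32, 32]) torch.Size([32])"
],
"execution_count": null,
"outputs": []
},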
{
"cell_type": "markdown",
"metadata": {
"id": "SfCsutp3qUMc"
},
"source": [
"### ResNet\n",
"We use the pre-existing ResNet architecture from TorchVision, which is designed for ImageNet-sized inputs (224x224). To adapt it to CIFAR10 images (32x32), we replace the first convolution with a smaller 3x3 kernel and remove the initial max-pooling layer."
]
},
{
"cell_type": "code",
"metadata": {
"id": "GNSeJgwvhHp-"
},
"source": [
"def create_model():\n",
"    model = torchvision.models.resnet18(pretrained=False, num_classes=10)\n",
"    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
"    model.maxpool = nn.Identity()\n",
"    return model"
],
"execution_count": null,
"outputs": []
},
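{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, optional shape check: with the modified stem, a 32x32 input should come out as 10 class logits. The variable names below are just for illustration."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"net = create_model()\n",
"out = net(torch.randn(1, 3, 32, 32))\n",
"print(out.shape)  # torch.Size([1, 10])"
],
"execution_count": null,
"outputs": []
},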
{
"cell_type": "markdown",
"metadata": {
"id": "HUCj5TKsqty1"
},
"source": [
"### Lightning Module\n",
"Check out the [`configure_optimizers`](https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html#configure-optimizers) method to use custom Learning Rate schedulers. The OneCycleLR with SGD will get you to around 92-93% accuracy in 20-30 epochs and 93-94% accuracy in 40-50 epochs. Feel free to experiment with different LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate" | ||
]
},
{
"cell_type": "code",
"metadata": {
"id": "03OMrBa5iGtT"
},
"source": [
"class LitResnet(pl.LightningModule):\n",
"    def __init__(self, lr=0.05):\n",
"        super().__init__()\n",
"\n",
"        self.save_hyperparameters()\n",
"        self.model = create_model()\n",
"\n",
"    def forward(self, x):\n",
"        out = self.model(x)\n",
"        return F.log_softmax(out, dim=1)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = F.log_softmax(self.model(x), dim=1)\n",
"        loss = F.nll_loss(logits, y)\n",
"        self.log('train_loss', loss)\n",
"        return loss\n",
"\n",
"    def evaluate(self, batch, stage=None):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = F.nll_loss(logits, y)\n",
"        preds = torch.argmax(logits, dim=1)\n",
"        acc = accuracy(preds, y)\n",
"\n",
"        if stage:\n",
"            self.log(f'{stage}_loss', loss, prog_bar=True)\n",
"            self.log(f'{stage}_acc', acc, prog_bar=True)\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        self.evaluate(batch, 'val')\n",
"\n",
"    def test_step(self, batch, batch_idx):\n",
"        self.evaluate(batch, 'test')\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n",
" steps_per_epoch = 45000 // batch_size\n", | ||
" scheduler_dict = {\n", | ||
" 'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),\n", | ||
" 'interval': 'step',\n", | ||
" }\n", | ||
" return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
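{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of how to swap in a different schedule, here is an illustrative variant of `configure_optimizers` using `MultiStepLR`. The class name, milestones, and `gamma` below are assumptions for demonstration, not tuned values."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from torch.optim.lr_scheduler import MultiStepLR\n",
"\n",
"class LitResnetMultiStep(LitResnet):\n",
"    # Illustrative only: decay the LR by 10x at the (assumed) milestone epochs.\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n",
"        scheduler = MultiStepLR(optimizer, milestones=[20, 30], gamma=0.1)\n",
"        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'epoch'}}"
],
"execution_count": null,
"outputs": []
},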
{
"cell_type": "code",
"metadata": {
"id": "3FFPgpAFi9KU"
},
"source": [
"model = LitResnet(lr=0.05)\n",
"model.datamodule = cifar10_dm\n",
"\n",
"trainer = pl.Trainer(\n",
"    progress_bar_refresh_rate=20,\n",
"    max_epochs=40,\n",
"    gpus=1,\n",
"    logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'),\n",
"    callbacks=[LearningRateMonitor(logging_interval='step')],\n",
")\n",
"\n",
"trainer.fit(model, cifar10_dm)\n",
"trainer.test(model, datamodule=cifar10_dm);"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "lWL_WpeVIXWQ"
},
"source": [
"### Bonus: Use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) to get a boost on performance\n", | ||
"\n", | ||
"Use SWA from torch.optim to get a quick performance boost. Also shows a couple of cool features from Lightning:\n", | ||
"- Use `training_epoch_end` to run code after the end of every epoch\n", | ||
"- Use a pretrained model directly with this wrapper for SWA" | ||
]
},
{
"cell_type": "code",
"metadata": {
"id": "bsSwqKv0t9uY"
},
"source": [
"class SWAResnet(LitResnet):\n", | ||
" def __init__(self, trained_model, lr=0.01):\n", | ||
" super().__init__()\n", | ||
"\n", | ||
" self.save_hyperparameters('lr')\n", | ||
" self.model = trained_model\n", | ||
" self.swa_model = AveragedModel(self.model)\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" out = self.swa_model(x)\n", | ||
" return F.log_softmax(out, dim=1)\n", | ||
"\n", | ||
" def training_epoch_end(self, training_step_outputs):\n", | ||
" self.swa_model.update_parameters(self.model)\n", | ||
"\n", | ||
" def validation_step(self, batch, batch_idx, stage=None):\n", | ||
" x, y = batch\n", | ||
" logits = F.log_softmax(self.model(x), dim=1)\n", | ||
" loss = F.nll_loss(logits, y)\n", | ||
" preds = torch.argmax(logits, dim=1)\n", | ||
" acc = accuracy(preds, y)\n", | ||
"\n", | ||
" self.log(f'val_loss', loss, prog_bar=True)\n", | ||
" self.log(f'val_acc', acc, prog_bar=True)\n", | ||
"\n", | ||
" def configure_optimizers(self):\n", | ||
" optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n", | ||
" return optimizer\n", | ||
"\n", | ||
" def on_train_end(self):\n", | ||
" update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)" | ||
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "cA6ZG7C74rjL"
},
"source": [
"swa_model = SWAResnet(model.model, lr=0.01)\n",
"swa_model.datamodule = cifar10_dm\n",
"\n",
"swa_trainer = pl.Trainer(\n",
"    progress_bar_refresh_rate=20,\n",
"    max_epochs=20,\n",
"    gpus=1,\n",
"    logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'),\n",
")\n",
"\n",
"swa_trainer.fit(swa_model, cifar10_dm)\n",
"swa_trainer.test(swa_model, datamodule=cifar10_dm);"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RRHMfGiDpZ2M"
},
"source": [
"# Start TensorBoard.\n",
"%reload_ext tensorboard\n",
"%tensorboard --logdir lightning_logs/"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "RltpFGS-s0M1"
},
"source": [
"<code style=\"color:#792ee5;\">\n", | ||
" <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n", | ||
"</code>\n", | ||
"\n", | ||
"Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", | ||
"\n", | ||
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", | ||
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", | ||
"\n", | ||
"* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", | ||
"\n", | ||
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", | ||
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", | ||
"\n", | ||
"### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", | ||
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", | ||
"\n", | ||
"* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", | ||
"\n", | ||
"### Contributions !\n", | ||
"The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", | ||
"\n", | ||
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", | ||
"* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", | ||
"* You can also contribute your own notebooks with useful examples !\n", | ||
"\n", | ||
"### Great thanks from the entire Pytorch Lightning Team for your interest !\n", | ||
"\n", | ||
"<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />" | ||
]
}
]
}