PyTorch Lightning is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training, 16-bit precision, or gradient accumulation.

Coupled with the [Weights & Biases integration](https://docs.wandb.com/library/integrations/lightning), you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:

In [1]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)

[34m[1mwandb[0m: Currently logged in as: [33mchristopher-marais[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


W&B integration with PyTorch Lightning can automatically:
* log your configuration parameters
* log your losses and metrics
* log your model
* keep track of your code
* log your system metrics (GPU, CPU, memory, temperature, etc.)

### 📚 Docs
You can find the PyTorch Lightning `WandbLogger` docs [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.loggers.WandbLogger.html?highlight=wandblogger) and the Weights & Biases docs [here](https://docs.wandb.com/library/integrations/lightning).

### 🛠️ Installation and set-up

In [2]:
# !pip install -q pytorch-lightning wandb torchvision torchmetrics scikit-image

We make sure we're logged into W&B so that our experiments can be associated with our account.

In [3]:
import wandb
wandb.login()



True

## 📊 Setting up the dataloader

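In this tutorial, we use vanilla PyTorch dataloaders on the MNIST dataset. A `LightningDataModule` alternative is also defined below, but the training run further down uses the plain dataloaders.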

In [4]:
from torchvision.datasets import MNIST
from torchvision import transforms, models
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

dataset = MNIST(root="./MNIST", download=True, transform=transform)
training_set, validation_set = random_split(dataset, [55000, 5000])

training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)
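`transforms.ToTensor()` scales raw 8-bit pixels to [0, 1], and `transforms.Normalize((0.1307,), (0.3081,))` then standardizes each value as (x - mean) / std, where 0.1307 and 0.3081 are the commonly used MNIST pixel mean and standard deviation. The pure-Python sketch below (no torchvision needed) reproduces that per-pixel arithmetic; `normalize_pixel` is a hypothetical helper for illustration only.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_pixel(raw: int, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Mimic ToTensor() followed by Normalize((mean,), (std,)) for one 8-bit pixel."""
    scaled = raw / 255.0          # ToTensor(): uint8 [0, 255] -> float [0.0, 1.0]
    return (scaled - mean) / std  # Normalize(): standardize with the dataset statistics

# black, mid-gray and white pixels
for raw in (0, 128, 255):
    print(raw, round(normalize_pixel(raw), 4))
```

Black background pixels therefore become slightly negative and white strokes become large positive values, which keeps the network inputs roughly zero-centered.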

In [5]:
from pytorch_lightning import LightningDataModule

class MNISTDataModule(LightningDataModule):

    def __init__(self, data_dir='./', batch_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        '''called only once, from a single process per node'''
        # download data
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        '''called on each GPU separately - stage defines if we are at fit or test step'''
        # we set up only relevant datasets when stage is specified (automatically set by PyTorch Lightning)
        if stage == 'fit' or stage is None:
            mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        '''returns training dataloader (reshuffled every epoch)'''
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        '''returns validation dataloader (no shuffling needed for evaluation)'''
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        '''returns test dataloader'''
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
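Lightning calls `setup` with `stage='fit'` before fitting and `stage='test'` before testing, so the datamodule above only builds the splits it needs. The plain-Python sketch below mirrors just the branching in `MNISTDataModule.setup` to show which attributes each stage creates; `splits_for_stage` is a hypothetical helper, and the sizes assume the 60,000-image MNIST training set.

```python
MNIST_TRAIN_SIZE = 60_000          # images in the MNIST training split
TRAIN_VAL_SPLIT = [55_000, 5_000]  # lengths passed to random_split in setup()

def splits_for_stage(stage=None):
    """Return the dataset attributes MNISTDataModule.setup() would create for a stage."""
    built = []
    if stage == "fit" or stage is None:
        # random_split needs lengths that add up to the dataset size
        assert sum(TRAIN_VAL_SPLIT) == MNIST_TRAIN_SIZE
        built += ["mnist_train", "mnist_val"]
    if stage == "test" or stage is None:
        built.append("mnist_test")
    return built

for stage in ("fit", "test", None):
    print(stage, splits_for_stage(stage))
```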

In [None]:
import os

# declaring the paths of the train and test folders
train_path = "DATASET/TRAIN"
test_path = "DATASET/TEST"

# each sub-folder of train_path is one class; sorted() keeps the label order
# stable, since os.listdir() returns entries in arbitrary order
classes_dir_data = sorted(os.listdir(train_path))
num_of_classes = len(classes_dir_data)
print("Total Number of Classes:", num_of_classes)

# classes_dict maps each class name to its numeric label,
# num_dict maps each numeric label back to its class name
classes_dict = {}
num_dict = {}
classes_lst = []
for num, c in enumerate(classes_dir_data):
    classes_dict[c] = num
    num_dict[num] = c
    classes_lst.append(c)

classes_dict
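The cell above assigns each class folder an integer label. The sketch below builds the same two lookup tables with dictionary comprehensions and runs against a throwaway directory tree made with `tempfile`; the folder names `cat`/`dog`/`bird` are placeholders, not the real contents of `DATASET/TRAIN`, and skipping non-directory entries is our own addition.

```python
import os
import tempfile

def build_class_maps(train_path):
    """Map class-folder names to integer labels and back."""
    # sorted() makes labels independent of os.listdir() ordering;
    # the isdir() filter skips stray files such as .DS_Store
    class_names = sorted(
        d for d in os.listdir(train_path) if os.path.isdir(os.path.join(train_path, d))
    )
    classes_dict = {name: idx for idx, name in enumerate(class_names)}
    num_dict = {idx: name for name, idx in classes_dict.items()}
    return classes_dict, num_dict

with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat", "bird"):  # placeholder class folders
        os.makedirs(os.path.join(root, name))
    classes_dict, num_dict = build_class_maps(root)

print(classes_dict)
print(num_dict)
```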

In [None]:
import os
from random import randint

import skimage.io
from torch.utils.data import Dataset

# creating the dataset
class Image_Dataset(Dataset):

    def __init__(self, classes, image_base_dir, transform=None, target_transform=None):
        """
        classes: the class names in the dataset (one sub-folder per class)
        image_base_dir: the directory containing the class folders
        transform: the transformations applied to the images
        target_transform: the transformations applied to the targets
        """
        self.img_labels = classes
        self.image_base_dir = image_base_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # one item per class
        return len(self.img_labels)

    def __getitem__(self, idx):
        class_dir = os.path.join(self.image_base_dir, self.img_labels[idx])
        # pick a random image from the folder of class `idx`
        img_dir_list = os.listdir(class_dir)
        image_path = os.path.join(class_dir, img_dir_list[randint(0, len(img_dir_list) - 1)])
        image = skimage.io.imread(image_path)

        if self.transform:
            image = self.transform(image)

        # fall back to the raw class name when no target transform is given
        label = self.img_labels[idx]
        if self.target_transform:
            label = self.target_transform(label)

        return image, label
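`Image_Dataset` has one item per class: `__len__` returns the number of classes and `__getitem__(idx)` draws a random image file from folder `idx`, so each epoch sees one randomly chosen image per class. The stdlib-only sketch below isolates that sampling rule on a temporary directory of empty placeholder files (no image decoding). `sample_path` is a hypothetical helper, and it uses `random.choice`, which picks uniformly just like the `randint` indexing above.

```python
import os
import random
import tempfile

def sample_path(base_dir, class_names, idx, rng=random):
    """Return a random file path from the folder of class `idx`."""
    class_dir = os.path.join(base_dir, class_names[idx])
    return os.path.join(class_dir, rng.choice(os.listdir(class_dir)))

with tempfile.TemporaryDirectory() as root:
    class_names = ["cat", "dog"]  # placeholder classes
    for name in class_names:
        os.makedirs(os.path.join(root, name))
        for i in range(3):  # three empty placeholder "images" per class
            open(os.path.join(root, name, f"{name}_{i}.png"), "w").close()

    rng = random.Random(0)  # seeded so repeated runs draw the same files
    picks = [sample_path(root, class_names, idx, rng) for idx in range(len(class_names))]
    dataset_len = len(class_names)  # what Image_Dataset.__len__ would return

print(dataset_len, [os.path.basename(p) for p in picks])
```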

In [None]:
import torch
from torchvision import transforms

size = 50  # must match the input size expected by layer_5 / the input layer of the CNN

basic_transformations = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((size, size)),
    transforms.Grayscale(1),
    transforms.ToTensor()])
training_transformations = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((size, size)),
    transforms.RandomRotation(degrees=45),
    transforms.RandomHorizontalFlip(p=0.005),
    transforms.Grayscale(1),
    transforms.ToTensor()
])

def target_transformations(x):
    '''maps a class name to its integer label (uses classes_dict from above)'''
    return torch.tensor(classes_dict.get(x))
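The comment on `size` matters because, after `Grayscale(1)` and `ToTensor()`, every image is a 1 x `size` x `size` tensor, so a fully connected layer that consumes the flattened image needs `channels * size * size` inputs. The sketch below is a hypothetical helper showing that arithmetic; in a CNN, convolution and pooling layers change this count further, and that architecture is not shown here.

```python
def flattened_features(size: int, channels: int = 1) -> int:
    """Inputs a fully connected layer needs for a flattened (channels, size, size) image."""
    return channels * size * size

print(flattened_features(50))              # grayscale images produced by the transforms above
print(flattened_features(28))              # MNIST images (1, 28, 28), as in MNIST_LitModule
print(flattened_features(50, channels=3))  # what skipping Grayscale(1) on RGB input would give
```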

## 🤓 Defining the Model

**Tips**:
* Call `self.save_hyperparameters()` in `__init__` to automatically log your hyperparameters to **W&B**
* Call `self.log` in `training_step` and `validation_step` to log the metrics

In [6]:
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule

In [7]:
class MNIST_LitModule(LightningModule): 

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        '''method used to define our model parameters'''
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

        # save hyper-parameters to self.hparams (auto-logged by W&B)
        self.save_hyperparameters() 

    def forward(self, x):
        '''method used for inference input -> output'''

        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # 2 x (linear + relu), then a final linear layer that outputs the logits
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        return x

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)

        return loss

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        # Let's return preds to use it in a custom callback
        return preds

    def test_step(self, batch, batch_idx):
        '''used for logging metrics'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)
    
    def configure_optimizers(self):
        '''defines model optimizer'''
        return Adam(self.parameters(), lr=self.lr)
    
    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, task='multiclass', num_classes=self.hparams.n_classes)
        return preds, loss, acc
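`_get_preds_loss_accuracy` turns raw logits into three things: argmax predictions, the mean cross-entropy loss (`CrossEntropyLoss` applies log-softmax to the logits internally and averages over the batch by default), and accuracy, which for torchmetrics' multiclass `accuracy` with its default micro averaging is the fraction of correct predictions. The pure-Python sketch below reproduces that computation on a tiny hand-made batch; `preds_loss_accuracy` is a hypothetical stand-in, not the torch implementation.

```python
import math

def preds_loss_accuracy(logits, targets):
    """Pure-Python argmax predictions, mean cross-entropy and accuracy."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]  # argmax per row
    losses = []
    for row, y in zip(logits, targets):
        # cross-entropy on raw logits = logsumexp(row) - row[y]; subtract max for stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[y])
    loss = sum(losses) / len(losses)  # mean over the batch
    acc = sum(p == y for p, y in zip(preds, targets)) / len(targets)
    return preds, loss, acc

logits = [[2.0, 0.5, -1.0],
          [0.1, 0.2, 3.0],
          [1.0, 2.0, 0.0]]
targets = [0, 2, 0]
preds, loss, acc = preds_loss_accuracy(logits, targets)
print(preds, round(loss, 4), acc)
```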

The model is now ready!

In [8]:
model = MNIST_LitModule(n_layer_1=128, n_layer_2=128)
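For this configuration, the parameter counts Lightning reports in its model summary follow directly from the layer shapes: a `Linear(in, out)` layer has `in * out` weights plus `out` biases. The arithmetic below reproduces the per-layer counts for `n_layer_1=n_layer_2=128` and 10 classes, and at 4 bytes per float32 parameter it gives the estimated model size in MB.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights (in * out) plus biases (out) of a torch.nn.Linear layer."""
    return in_features * out_features + out_features

layer_1 = linear_params(28 * 28, 128)
layer_2 = linear_params(128, 128)
layer_3 = linear_params(128, 10)
total = layer_1 + layer_2 + layer_3
size_mb = total * 4 / 1e6  # float32 = 4 bytes per parameter

print(layer_1, layer_2, layer_3, total, round(size_mb, 3))
```

These match the `100 K`, `16.5 K`, `1.3 K`, `118 K` and `0.473` MB figures in the training summary below.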

## 💾 Save Model Checkpoints

To log model checkpoints to W&B, add a `ModelCheckpoint` callback to the `Trainer` and set the `log_model` argument of `WandbLogger`.

In [9]:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')
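With `monitor='val_accuracy'` and `mode='max'`, `ModelCheckpoint` keeps the checkpoint with the highest logged `val_accuracy` (its default `save_top_k=1` keeps only the best one); `mode='min'` would suit a metric such as `val_loss`. The sketch below mimics only that selection rule in plain Python over made-up per-epoch values; it is not Lightning's implementation, and the numbers are illustrative.

```python
def best_epoch(history, monitor="val_accuracy", mode="max"):
    """Return (epoch, value) of the best monitored metric, like save_top_k=1 would keep."""
    pick = max if mode == "max" else min
    return pick(enumerate(h[monitor] for h in history), key=lambda item: item[1])

# hypothetical per-epoch validation metrics (not real results)
history = [{"val_accuracy": 0.91}, {"val_accuracy": 0.95}, {"val_accuracy": 0.94}]
print(best_epoch(history))
print(best_epoch([{"val_loss": 0.30}, {"val_loss": 0.21}], monitor="val_loss", mode="min"))
```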

## 💡 Tracking Experiments with WandbLogger

PyTorch Lightning has a `WandbLogger` to easily log your experiments with Weights & Biases. Just pass it to your `Trainer` to log to W&B. See the [WandbLogger docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) for all parameters. Note: to log the metrics to a specific W&B Team, pass your Team name to the `entity` argument of `WandbLogger`.

#### `pytorch_lightning.loggers.WandbLogger()`

| Functionality | Argument/Function | Notes |
| ------ | ------ | ------ |
| Logging models | `WandbLogger(..., log_model='all')` or `WandbLogger(..., log_model=True)` | Logs checkpoints during training if `log_model='all'`, and at the end of training if `log_model=True` |
| Set custom run names | `WandbLogger(..., name='my_run_name')` | |
| Organize runs by project | `WandbLogger(..., project='my_project')` | |
| Log histograms of gradients and parameters | `WandbLogger.watch(model)` | Logs gradients by default; use `WandbLogger.watch(model, log='all')` to also log parameter histograms |
| Log hyperparameters | Call `self.save_hyperparameters()` within `LightningModule.__init__()` | |
| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image` and `WandbLogger.log_table` | |


In [10]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger(project='MNIST', # group runs in "MNIST" project
                           log_model='all') # log all new checkpoints during training



## ⚙️ Using WandbLogger to log Images, Text and More
PyTorch Lightning is extensible through its callback system. We can create a custom callback to automatically log sample predictions during validation. `WandbLogger` provides convenient media logging functions:
* `WandbLogger.log_text` for text data
* `WandbLogger.log_image` for images
* `WandbLogger.log_table` for [W&B Tables](https://docs.wandb.ai/guides/data-vis).

An alternative to `self.log` in the model class is to call `wandb.log({dict})` or `trainer.logger.experiment.log({dict})` directly.

In this case, we log the first 20 images of the first validation batch along with their predicted and ground-truth labels.

In [11]:
from pytorch_lightning.callbacks import Callback

class LogPredictionsCallback(Callback):

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Called when the validation batch ends."""

        # `outputs` comes from `LightningModule.validation_step`
        # which corresponds to our model predictions in this case

        # Let's log 20 sample image predictions from the first batch
        if batch_idx == 0:
            n = 20
            x, y = batch
            # move tensors to the CPU so W&B can serialize them
            x, y, preds = x[:n].cpu(), y[:n].cpu(), outputs[:n].cpu()
            images = [img for img in x]
            # .item() turns 0-d tensors into plain ints for readable captions
            captions = [f'Ground Truth: {y_i.item()} - Prediction: {y_pred.item()}'
                        for y_i, y_pred in zip(y, preds)]

            # Option 1: log images with `WandbLogger.log_image`
            trainer.logger.log_image(key='sample_images', images=images, caption=captions)

            # Option 2: log predictions as a Table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), y_i.item(), y_pred.item()]
                    for x_i, y_i, y_pred in zip(x, y, preds)]
            trainer.logger.log_table(key='sample_table', columns=columns, data=data)

log_predictions_callback = LogPredictionsCallback()
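Inside the callback, captions and table rows come from zipping the first `n` inputs, labels and predictions. The plain-Python sketch below reproduces only that bookkeeping, with integer labels and string stand-ins for image tensors, so it runs without W&B; `build_captions_and_rows` is a hypothetical helper. With real tensors, convert each label with `.item()` before formatting, otherwise the caption shows the tensor repr.

```python
def build_captions_and_rows(images, labels, preds, n=20):
    """Pair the first n samples with their ground truth and prediction."""
    captions = [f"Ground Truth: {y} - Prediction: {p}" for y, p in zip(labels[:n], preds[:n])]
    rows = [[img, y, p] for img, y, p in zip(images[:n], labels[:n], preds[:n])]
    return captions, rows

images = ["img0", "img1", "img2"]  # stand-ins for image tensors
labels = [7, 2, 1]
preds = [7, 3, 1]
captions, rows = build_captions_and_rows(images, labels, preds, n=2)
print(captions)
print(rows)
```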

## 🏋️‍ Train Your Model

In [12]:
trainer = Trainer(
    logger=wandb_logger,                    # W&B integration
    callbacks=[log_predictions_callback,    # logging of sample predictions
               checkpoint_callback],        # our model checkpoint callback
    accelerator="gpu",                      # use GPU
    max_epochs=5)                           # number of epochs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
trainer.fit(model, training_loader, validation_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | layer_1 | Linear           | 100 K 
1 | layer_2 | Linear           | 16.5 K
2 | layer_3 | Linear           | 1.3 K 
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
118 K     Trainable params
0         Non-trainable params
118 K     Total params
0.473     Total estimated model params size (MB)



`Trainer.fit` stopped: `max_epochs=5` reached.
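As a quick check on how much work those 5 epochs represent: with `batch_size=64`, the 55,000-image training split and 5,000-image validation split yield the batch counts below. `DataLoader` keeps the final partial batch by default (`drop_last=False`), hence the ceiling; the step count assumes one optimizer step per training batch, i.e. no gradient accumulation.

```python
import math

TRAIN_SIZE, VAL_SIZE = 55_000, 5_000  # sizes passed to random_split above
BATCH_SIZE = 64
MAX_EPOCHS = 5

train_batches = math.ceil(TRAIN_SIZE / BATCH_SIZE)  # partial last batch is kept
val_batches = math.ceil(VAL_SIZE / BATCH_SIZE)
total_steps = train_batches * MAX_EPOCHS            # one optimizer step per training batch

print(train_batches, val_batches, total_steps)
```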


When we want to close our W&B run, we call `wandb.finish()` (mainly useful in notebooks, called automatically in scripts).

In [None]:
wandb.finish()

## 📚 Resources

* [PyTorch Lightning and W&B integration documentation](https://docs.wandb.ai/integrations/lightning) contains a few tips for getting the most out of W&B
* [PyTorch Lightning documentation](https://pytorch-lightning.readthedocs.io/en/stable/common/loggers.html#weights-and-biases) is extremely thorough and full of examples