# PyTorch Lightning Training Pipeline

![lightning-logo](https://upload.wikimedia.org/wikipedia/commons/e/e6/Lightning_Logo_v2.png)

In this section, we will create a deep learning model using [PyTorch Lightning](https://www.pytorchlightning.ai/index.html), a lightweight and flexible Python library that simplifies training and organizing deep learning models built with PyTorch. It provides a high-level interface and a set of abstractions that make it easier to structure, debug, and scale complex deep learning projects.

For this project, we will develop the popular [ResNet](https://arxiv.org/abs/1512.03385) variants: ResNet-18 and ResNet-50. Our goal is to compare these models and determine which one is most effective for solving the challenge at hand. We learned about the [Food101-tiny](https://www.kaggle.com/datasets/msarmi9/food101tiny) dataset in the previous section, and now we will develop a simple dataset parser that reads the data and applies augmentations to it.

Configurations:
- Model: ResNet-18, ResNet-50
- Dataset: Food101-tiny
- Input Size: 384

In [1]:
!pip install -q lightning


In [2]:
import os

ROOT_DIR = os.path.dirname(os.path.abspath(''))
DATA_DIR = os.path.join(ROOT_DIR, 'data/food-101-tiny')

TRAIN_DATA_PATH = os.path.join(DATA_DIR, 'train')
VAL_DATA_PATH = os.path.join(DATA_DIR, 'valid')

## Dataset Pipeline

To build the dataset pipeline, we must first understand the dataset itself. Researchers, developers, and organizations each store and read data in their own way, so it is important to understand the data types, data structures, and storage layout.

In this example, the images are organized into subfolders named after their classes. Each split contains 10 subfolders, one per class.

```
data/
----food-101-tiny/
    ----train/
        ----apple_pie/
            bibimbap/
            cannoli/
            ...
            ...
            tiramisu/
    ----valid/
```

There are two ways to load the dataset: the easy way is to use `torchvision.datasets.ImageFolder`; the hard way is to search for the images and read them manually, which amounts to reimplementing `ImageFolder`.

### Augmentation Policy

For training, we're using the following augmentation steps:
- random rotation
- random flipping
- center crop
- normalize the data

For validation, we only need to center-crop and normalize the data.
It's important to apply the random augmentations only to the training split; the validation split gets just the deterministic crop and normalization.

In [29]:
from torchvision import transforms

RGB_MEAN = [0.51442681, 0.43435301, 0.33421855]
RGB_STD = [0.24099932, 0.246478, 0.23652802]
INPUT_SIZE = (384, 384)

TRAIN_TRANSFORMATION = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(45),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.CenterCrop(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=RGB_MEAN, std=RGB_STD)
])

VAL_TRANSFORMATION = transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=RGB_MEAN, std=RGB_STD)
])
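The notebook does not show how `RGB_MEAN` and `RGB_STD` were obtained. A common approach, assumed here, is to compute the per-channel mean and standard deviation over the training images after scaling pixel values to [0, 1]. The helper name `channel_stats` and the random stand-in images are hypothetical.

```python
import numpy as np


def channel_stats(images):
    """Per-channel mean and std over a list of H x W x 3 uint8 RGB arrays, scaled to [0, 1]."""
    # Flatten every image to a (num_pixels, 3) table and stack them all together
    pixels = np.concatenate([img.reshape(-1, 3) for img in images])
    pixels = pixels.astype(np.float64) / 255.0
    return pixels.mean(axis=0), pixels.std(axis=0)


# Random stand-in images; in practice these would be the training images
rng = np.random.default_rng(0)
fake_images = [rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8) for _ in range(4)]
mean, std = channel_stats(fake_images)
print(mean, std)
```

For a large dataset, accumulating running sums per channel avoids holding every pixel in memory at once.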

### Dataset Parser

Build a custom dataset parser to read a structured folder from before. This function will read the images from the subsequent folders from the *root_dir*.

In [19]:
import glob
import os
from typing import Optional, Tuple

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision

class ImageFolderDataset(Dataset):
    """Structured Image Folder Dataset

    Parameters
    ----------
    root_dir : str
        Path to image root directory
    stage : Optional[str], optional
        Split subfolder under root_dir (e.g. 'train' or 'valid'), by default None
    transform : Optional[torchvision.transforms.Compose], optional
        Data augmentation pipeline, by default None
    """

    def __init__(self, root_dir: str, stage: Optional[str] = None, transform: Optional[torchvision.transforms.Compose] = None):
        if stage is not None:
            self.root_dir = os.path.join(root_dir, stage)
        else:
            self.root_dir = root_dir
        # Check the resolved split directory, not just the root
        if not os.path.exists(self.root_dir):
            raise RuntimeError(f'Path to dataset is not valid: {self.root_dir}')
        self.labels_name = os.listdir(self.root_dir)
        self.labels_name.sort()
        self.list_images =  glob.glob(f'{self.root_dir}/**/*.jpg')
        self.transform = transform

    def __len__(self) -> int:
        """Get number of images."""
        return len(self.list_images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get sample for the current idx."""
        img_path = self.list_images[idx]

        # NOTE: cv2 read image as BGR.
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Convert class name to label index 0-9
        # NOTE: root_path/class_name/image_file.jpg
        class_name = os.path.basename(os.path.dirname(img_path))
        label_index = self.labels_name.index(class_name)
        label_index = [label_index]

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        # If no transformations, convert H,W,C to C,H,W
        # and normalize to [0,1]
        if isinstance(image, np.ndarray):
            image = np.transpose(image, (2, 0, 1))
            image = torch.from_numpy(np.ascontiguousarray(image)).float()
            image /= 255.0

        return image, torch.tensor(label_index, dtype=torch.int64)
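The fallback path for untransformed images relies on a layout conversion that is easy to get wrong. As a standalone sketch (the helper name `to_chw_float` is hypothetical), converting an OpenCV-style H x W x C `uint8` array into a C x H x W float tensor in [0, 1] could look like this:

```python
import numpy as np
import torch


def to_chw_float(image: np.ndarray) -> torch.Tensor:
    """Convert an H x W x C uint8 array to a C x H x W float32 tensor in [0, 1]."""
    # Move the channel axis (index 2) to the front: (H, W, C) -> (C, H, W)
    chw = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
    # Cast to float before dividing; in-place division on a uint8 tensor would fail
    return torch.from_numpy(chw).float() / 255.0


dummy = np.zeros((4, 6, 3), dtype=np.uint8)
dummy[..., 0] = 255  # fill the first (red) channel only
tensor = to_chw_float(dummy)
print(tensor.shape)  # torch.Size([3, 4, 6])
```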

#### Test and Validate

To ensure the parser works properly, we test that it returns the expected number of samples and data types.

Note: **Don't forget to clean up the variables**, since they consume memory.

In [20]:
train_dataset = ImageFolderDataset(root_dir=DATA_DIR, stage = 'train', transform=TRAIN_TRANSFORMATION)
image, label = train_dataset[10]

assert len(train_dataset) == 1500
assert isinstance(image, torch.Tensor)
assert isinstance(label, torch.Tensor)

val_dataset = ImageFolderDataset(root_dir=VAL_DATA_PATH, transform=VAL_TRANSFORMATION)
image, label = val_dataset[10]

assert len(val_dataset) == 500
assert isinstance(image, torch.Tensor)
assert isinstance(label, torch.Tensor)

# Cleanup to reduce memory
train_dataset = None
val_dataset = None
image, label = None, None

### LightningDataModule Pipeline

To use the dataset with PyTorch Lightning, we have to wrap it in a custom [LightningDataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html), which the pipeline uses to load and parse the data.

In [30]:
from typing import Optional, Tuple

from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms


class Food101LitDatamodule(LightningDataModule):
    """LightningDataModule for Food101 Data Pipeline.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html

    Parameters
    ----------
    data_dir : str, optional
        Food101-tiny dataset directory, by default 'data/'
    input_size : Tuple[int, int], optional
        Input model size, by default (600, 500)
    batch_size : int, optional
        Training batch size, by default 64
    num_workers : int, optional
        Number of workers to process data, by default 0
    pin_memory : bool, optional
        Enable memory pinning, by default False
    """

    def __init__(
        self,
        data_dir: str = 'data/',
        input_size: Tuple[int, int] = (600, 500),
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    @property
    def num_classes(self):
        """Get number of classes."""
        return 10

    def setup(self, stage: Optional[str] = None):
        """Load the data with specified stage."""
        if stage in ['train', 'fit', None] and self.data_train is None:
            self.data_train = ImageFolderDataset(
                root_dir=self.hparams.data_dir, stage='train', transform=TRAIN_TRANSFORMATION)
            if len(self.data_train) == 0:
                raise ValueError('Train dataset is empty.')
        # NOTE: Lightning passes 'validate' (not 'validation') as the stage name
        if stage in ['validate', 'validation', 'test', 'fit', None]:
            if self.data_val is None:
                self.data_val = ImageFolderDataset(
                    root_dir=self.hparams.data_dir, stage='valid', transform=VAL_TRANSFORMATION)
                if len(self.data_val) == 0:
                    raise ValueError('Validation dataset is empty.')
            if self.data_test is None:
                self.data_test = ImageFolderDataset(
                    root_dir=self.hparams.data_dir, stage='valid', transform=VAL_TRANSFORMATION)
                if len(self.data_test) == 0:
                    raise ValueError('Test dataset is empty.')
        if stage == 'predict':
            # data_predict is not created in __init__, so look it up defensively
            if getattr(self, 'data_predict', None) is None:
                self.data_predict = ImageFolderDataset(
                    root_dir=self.hparams.data_dir, transform=VAL_TRANSFORMATION)
                if len(self.data_predict) == 0:
                    raise ValueError('Predict dataset is empty.')

    def train_dataloader(self):
        """Get train dataloader."""
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        """Get validation dataloader."""
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self):
        """Get test dataloader."""
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

#### Test and Validate

To ensure the datamodule works properly, we check that it builds every split and that a training batch has the expected size and data types.

Note: **Don't forget to clean up the variables**, since they consume memory.

In [22]:
dm = Food101LitDatamodule(data_dir=DATA_DIR, batch_size=4)

assert not dm.data_train and not dm.data_val and not dm.data_test

dm.setup()
assert dm.data_train and dm.data_val and dm.data_test
train_dataloader = dm.train_dataloader()
val_dataloader = dm.val_dataloader()
test_dataloader = dm.test_dataloader()
assert train_dataloader and val_dataloader and test_dataloader

assert len(dm.data_train) + len(dm.data_val) + len(dm.data_test) == 2500

batch = next(iter(train_dataloader))
x, y = batch
assert len(x) == 4
assert len(y) == 4
assert x.dtype == torch.float32
assert y.dtype == torch.int64

# Cleanup
dm = None
batch = None
train_dataloader, val_dataloader, test_dataloader = None, None, None
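To make the batch shapes concrete without touching the real data, the sketch below feeds a stand-in `TensorDataset` (random tensors, not Food101-tiny) through a `DataLoader`. Because each sample's label is a one-element tensor, as in the parser above, the default collation stacks labels into shape `[B, 1]`, and the last batch is smaller when the dataset size is not a multiple of the batch size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 10 random "images" with labels of shape [1] each
images = torch.rand(10, 3, 32, 32)
labels = torch.randint(0, 10, (10, 1))
loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=False)

# Collect the shape of every batch: sizes are 4, 4, and a final partial batch of 2
shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]
for x_shape, y_shape in shapes:
    print(x_shape, y_shape)
```

Passing `drop_last=True` to the `DataLoader` would discard the final partial batch instead.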

## Model

The most important part of any ResNet architecture is its basic block. It stacks a few convolutional, batch normalization, and ReLU activation layers, a pattern common to all the ResNet models.

![resnet-model](https://debuggercafe.com/wp-content/uploads/2022/08/resnet18-basic-blocks-1.png)

Define a basic ResNet-18 building block. Each block consists of two convolutional layers and supports a skip connection.

![resnet-block](https://neurohive.io/wp-content/uploads/2019/01/resnet-e1548261477164.png)

In [23]:
from typing import Optional

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """ResNet Basic Block.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    stride : int, optional
        Convolution stride size, by default 1
    identity_downsample : Optional[torch.nn.Module], optional
        Downsampling layer, by default None
    """

    def __init__(self,
                in_channels: int,
                out_channels: int,
                stride: int = 1,
                identity_downsample: Optional[torch.nn.Module] = None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels,
                              out_channels, 
                              kernel_size = 3,
                              stride = stride,
                              padding = 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels,
                              out_channels,
                              kernel_size = 3,
                              stride = 1,
                              padding = 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        # Apply an operation to the identity output.
        # Useful to reduce the layer size and match from conv2 output
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        x += identity
        x = self.relu(x)
        return x
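The `identity_downsample` argument exists because the skip connection adds two tensors that must have identical shapes. The standalone sketch below shows the mismatch a stride-2 convolution creates and how a projection on the shortcut resolves it. It uses a 1x1 projection as in the original ResNet paper, whereas the `identity_downsample` helper defined with the full model uses a 3x3 convolution; the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 64, 56, 56)

# Main path: a stride-2 convolution halves the spatial size and doubles the channels
main_path = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
# Shortcut: a 1x1 stride-2 projection brings the input to the same shape
projection = nn.Conv2d(64, 128, kernel_size=1, stride=2)

out = main_path(x)
shortcut = projection(x)
print(x.shape, out.shape, shortcut.shape)

# The residual addition only works because both shapes now match
residual = out + shortcut
```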

In [24]:
class ResNet18(nn.Module):
    """Construct ResNet-18 Model.

    Parameters
    ----------
    input_channels : int
        Number of input channels
    num_classes : int
        Number of class outputs
    """

    def __init__(self, input_channels, num_classes):
        
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(input_channels,
                               64, kernel_size = 7,
                              stride = 2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size = 3,
                                   stride = 2,
                                   padding = 1)
        
        self.layer1 = self._make_layer(64, 64, stride = 1)
        self.layer2 = self._make_layer(64, 128, stride = 2)
        self.layer3 = self._make_layer(128, 256, stride = 2)
        self.layer4 = self._make_layer(256, 512, stride = 2)
        
        # Last layers
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int) -> nn.Module:
        """Downsampling block to reduce the feature sizes."""
        return nn.Sequential(
             nn.Conv2d(in_channels, 
                       out_channels, 
                       kernel_size = 3, 
                       stride = 2,
                       padding = 1),
            nn.BatchNorm2d(out_channels)
        )

    def _make_layer(self, in_channels: int, out_channels: int, stride: int) -> nn.Module:
        """Create sequential basic block."""
        identity_downsample = None

        # Add downsampling function
        if stride != 1:
            identity_downsample = self.identity_downsample(in_channels, out_channels)
            
        return nn.Sequential(
                    BasicBlock(in_channels, out_channels, identity_downsample=identity_downsample, stride=stride),
                    BasicBlock(out_channels, out_channels)
                    )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
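To see how the feature map shrinks through the network, the sketch below applies the standard convolution output-size formula, `floor((size + 2 * padding - kernel) / stride) + 1`, to the layer settings used above for a 384-pixel input. The helper name `conv_out` is hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 384
size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
trace = [size]

# layer1..layer4: only the first 3x3 convolution of each stage uses the stage stride
for stride in (1, 2, 2, 2):
    size = conv_out(size, kernel=3, stride=stride, padding=1)
    trace.append(size)

print(trace)  # [96, 96, 48, 24, 12]
```

The adaptive average pooling layer then reduces the final 12x12 map to 1x1, which is why the model accepts other input sizes as well.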


#### Test and Validate

To ensure the model works properly, we pass a random input through it and check the output shape.

Note: **Don't forget to clean up the variables**, since they consume memory.

In [25]:
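model = ResNet18(3, 10)

# Avoid shadowing the built-in `input`
dummy_input = torch.rand(1, 3, 256, 256)
output = model(dummy_input)

assert output.shape == torch.Size([1, 10])

# Cleanup to reduce memory
model, dummy_input, output = None, None, None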

## Training Pipeline

In PyTorch Lightning, **LightningModule** is a key component that serves as the core building block for defining and organizing deep learning models. It is an abstract class provided by the PyTorch Lightning library that extends PyTorch's nn.Module.

A LightningModule **encapsulates all the necessary components of a deep learning model**, including the model architecture, forward pass logic, loss functions, and optimization methods. It provides a standardized interface and a set of predefined hooks to handle various aspects of the training process, such as data loading, training/validation loops, and testing.

For more details, read the full [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html) documentation.

In [33]:
from typing import Any, Optional

import torch
import torch.nn.functional as F
import torchmetrics as tm
from lightning import LightningModule


class ClassificationLightningModule(LightningModule):
    """Model training pipeline for Food101 classification.

    Parameters
    ----------
    net : torch.nn.Module
        The model module or configuration
    num_classes : int, optional
        Number of output classes, by default 10
    lr : float, optional
        Optimizer learning rate, by default 0.00001
    """

    def __init__(
        self,
        net: torch.nn.Module,
        num_classes: int = 10,
        lr: float = 0.00001
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # metric objects for calculating and averaging accuracy across batches
        self.train_acc = tm.Accuracy(task='multiclass', num_classes=num_classes)

        self.val_metrics = tm.MetricCollection({
            'acc': tm.Accuracy(task='multiclass', num_classes=num_classes),
            'prec': tm.Precision(task='multiclass', num_classes=num_classes),
            'rec': tm.Recall(task='multiclass', num_classes=num_classes),
            'auroc': tm.AUROC(task='multiclass', num_classes=num_classes),
            'f1': tm.F1Score(task='multiclass', num_classes=num_classes),
        })  # type: ignore

        self.test_metrics = self.val_metrics.clone()

    def reset_metrics(self):
        self.train_acc.reset()
        self.val_metrics.reset()

    def forward(self, x: torch.Tensor):
        return self.net(x)

    def on_train_start(self):
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.reset_metrics()

    def model_step(self, batch: Any):
        images, targets = batch
        targets = targets.reshape(-1).long()  # flatten [B, 1] to [B]; also safe when B == 1
        logits = self.forward(images)
        loss = self.criterion(logits, targets)
        preds = F.softmax(logits, dim=1)
        return loss, preds, targets

    def training_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        acc = self.train_acc(preds, targets)
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_start(self):
        self.reset_metrics()

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        metrics = self.val_metrics(preds, targets)
        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val/acc', metrics['acc'], on_step=False, on_epoch=True, prog_bar=True)
        self.log('val/prec', metrics['prec'], on_step=False, on_epoch=True, prog_bar=False)
        self.log('val/rec', metrics['rec'], on_step=False, on_epoch=True, prog_bar=False)
        self.log('val/auroc', metrics['auroc'], on_step=False, on_epoch=True, prog_bar=False)
        self.log('val/f1', metrics['f1'], on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int):
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        metrics = self.test_metrics(preds, targets)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test/acc', metrics['acc'], on_step=False, on_epoch=True, prog_bar=True)
        self.log('test/prec', metrics['prec'], on_step=False, on_epoch=True, prog_bar=False)
        self.log('test/rec', metrics['rec'], on_step=False, on_epoch=True, prog_bar=False)
        self.log('test/auroc', metrics['auroc'], on_step=False, on_epoch=True, prog_bar=False)
        self.log('test/f1', metrics['f1'], on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Configure the optimizer and scheduler to use.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.lr,
        )

        return optimizer
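A detail of `model_step` worth spelling out: `CrossEntropyLoss` expects raw logits and 1-D integer targets, and applies log-softmax internally. The softmax probabilities are only needed for the metrics. The sketch below checks this on random stand-in logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                   # raw model outputs for 4 samples
targets = torch.tensor([[3], [1], [7], [0]])  # labels of shape [B, 1]
flat_targets = targets.reshape(-1)            # cross-entropy needs shape [B]

# Cross-entropy on logits equals negative log-likelihood on log-softmax outputs
loss = F.cross_entropy(logits, flat_targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), flat_targets)

# Softmax turns logits into probabilities without changing the predicted class
probs = F.softmax(logits, dim=1)
print(loss.item(), manual.item())
```

Feeding softmax outputs into `CrossEntropyLoss` would apply the normalization twice, so the loss is computed from `logits`, not `preds`.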

#### Test and Validate

To ensure the training pipeline works properly, we run a quick smoke test with `fast_dev_run`, which runs a single batch of training and validation.

Note: **Don't forget to clean up the variables**, since they consume memory.

In [34]:
from lightning.pytorch.trainer import Trainer

# Construct the model
model = ResNet18(3, 10)

# Construct training pipeline
lit_model = ClassificationLightningModule(
    net = model,
    num_classes = 10,
    lr = 0.001,
)

# Construct the datamodule
datamodule = Food101LitDatamodule(
    data_dir=DATA_DIR,
    batch_size=4,
    num_workers=4
)

training_config = {
    "fast_dev_run": True,
    "max_epochs": 1,
}

trainer = Trainer(**training_config)
trainer.fit(model=lit_model, datamodule=datamodule)

trainer = None
datamodule = None
lit_model = None
model = None

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params
----------------------------------------------------
0 | net          | ResNet18           | 12.6 M
1 | criterion    | CrossEntropyLoss   | 0     
2 | train_acc    | MulticlassAccuracy | 0     
3 | val_metrics  | MetricCollection   | 0     
4 | test_metrics | MetricCollection   | 0     
----------------------------------------------------
12.6 M    Trainable params
0         Non-trainable params
12.6 M    Total params
50.251    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


## Training Process

The **Trainer** class in PyTorch Lightning provides a high-level interface for running and managing the training process of deep learning models. It acts as a central orchestrator, handling crucial aspects such as training loops, validation loops, logging, checkpointing, and distributed training.

Additionally, the **Trainer** integrates with other PyTorch Lightning components, such as LightningModule and DataLoader, streamlining the training pipeline. It automatically applies the specified callbacks and provides options for logging and visualizing training progress and metrics.

Please read the full [documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html) to understand the full functionality.
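As a rough mental model (a simplification, not Lightning's actual implementation), the training loop that the Trainer automates resembles the plain PyTorch loop below. The tiny linear model and random data are stand-ins chosen so the sketch runs in a fraction of a second.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-ins for the real network and datamodule
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
toy_data = TensorDataset(torch.rand(16, 3, 8, 8), torch.randint(0, 10, (16,)))
loader = DataLoader(toy_data, batch_size=4)

losses = []
for epoch in range(2):                 # corresponds to max_epochs
    toy_model.train()
    for images, targets in loader:     # corresponds to train_dataloader()
        optimizer.zero_grad()
        loss = criterion(toy_model(images), targets)  # corresponds to training_step()
        loss.backward()
        optimizer.step()               # optimizer from configure_optimizers()
        losses.append(loss.item())

print(len(losses))  # 8 optimizer steps: 2 epochs x 4 batches
```

On top of this loop, the Trainer also takes care of device placement, validation, logging, and checkpointing.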

In [35]:
from lightning.pytorch.trainer import Trainer

# Global Variables
NUM_CLASSES = 10
LEARNING_RATE = 0.0001
BATCH_SIZE = 16

# Construct the model
model = ResNet18(3, NUM_CLASSES)

# Construct training pipeline
lit_model = ClassificationLightningModule(
    net = model,
    num_classes = NUM_CLASSES,
    lr = LEARNING_RATE,
)

# Construct the datamodule
datamodule = Food101LitDatamodule(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    num_workers=4
)

training_config = {
    "accelerator": 'auto',
    "devices": 'auto',
    "precision": 32,
    "max_epochs": 100,
}

trainer = Trainer(**training_config)
trainer.fit(model=lit_model, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params
----------------------------------------------------
0 | net          | ResNet18           | 12.6 M
1 | criterion    | CrossEntropyLoss   | 0     
2 | train_acc    | MulticlassAccuracy | 0     
3 | val_metrics  | MetricCollection   | 0     
4 | test_metrics | MetricCollection   | 0     
----------------------------------------------------
12.6 M    Trainable params
0         Non-trainable params
12.6 M    Total params
50.251    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


## Testing

Note that this is just an example; you can improve the performance by trying different augmentations, models, hyperparameter tuning, etc.

In [36]:
# disable randomness, dropout, etc...
lit_model.eval()

trainer.test(model=lit_model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test/loss': 1.44989013671875,
  'test/acc': 0.6100000143051147,
  'test/prec': 0.6100000143051147,
  'test/rec': 0.6100000143051147,
  'test/auroc': 0.04938095062971115,
  'test/f1': 0.6100000143051147}]
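Accuracy, precision, recall, and F1 are identical in the results above. This is consistent with micro-averaging, which is assumed here to be what the metric objects compute when no `average` argument is given. With one label per sample, every wrong prediction counts as exactly one false positive and one false negative, so the micro-averaged scores all collapse to accuracy. The sketch below shows this with NumPy on made-up labels and contrasts it with macro-averaging.

```python
import numpy as np

# Made-up labels for a 3-class problem
y_true = np.array([0, 0, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2, 0, 2])
num_classes = 3

# Per-class true positives, false positives, and false negatives
tp = np.array([np.sum((y_pred == c) & (y_true == c)) for c in range(num_classes)])
fp = np.array([np.sum((y_pred == c) & (y_true != c)) for c in range(num_classes)])
fn = np.array([np.sum((y_pred != c) & (y_true == c)) for c in range(num_classes)])

accuracy = np.mean(y_true == y_pred)
micro_precision = tp.sum() / (tp.sum() + fp.sum())  # pool counts over classes
micro_recall = tp.sum() / (tp.sum() + fn.sum())
macro_precision = np.mean(tp / (tp + fp))           # average the per-class scores

print(accuracy, micro_precision, micro_recall)  # 0.75 0.75 0.75
print(round(macro_precision, 4))                # 0.7222
```

Passing `average='macro'` to the metrics would give per-class averages, which are usually more informative when classes are imbalanced.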

This notebook shows a simple use of PyTorch Lightning to develop classification models. In the next section, we will integrate experiment tracking into our pipeline to help us better understand the training process and the results.

PyTorch Lightning offers many practical [hands-on tutorials](https://lightning.ai/docs/pytorch/stable/tutorials.html) for developing deep learning models.