There are a variety of free and paid resources available for interactively tracking training performance of deep learning models. Here we'll use [TensorBoard](https://github.com/tensorflow/tensorboard) which is free and open-source. Another popular option with a free tier is [Weights & Biases](https://wandb.ai/site), which has some additional features and integrations. 

We can display TensorBoard directly inside the notebook using IPython magic commands. After the imports, we load the TensorBoard notebook extension:

In [11]:
import lightning as L
import torch
from torch import nn
from torch_geometric import nn as graph_nn
from src import dataloader
from lightning.pytorch import callbacks
from lightning.pytorch.loggers import TensorBoardLogger

In [5]:
from sklearn import metrics

In [1]:
%load_ext tensorboard

A second magic command starts TensorBoard as a background server and embeds its dashboard in the notebook, reading log files from the `lightning_logs` directory (`--bind_all` makes the server listen on all network interfaces, which helps when the notebook runs on a remote machine). As we train, log files are written to this directory and the dashboard updates automatically (by default it refreshes every 30 seconds).

In [3]:
%tensorboard --logdir lightning_logs/ --bind_all

Let's set up the LightningModule again, this time also logging the loss on the validation set. Lightning runs `validation_step` with the model in evaluation mode (so dropout is switched off) and with gradient computation disabled, so we don't need to call `model.eval()` or `torch.no_grad()` ourselves.

In [2]:


class GATModule(L.LightningModule):
    """
    LightningModule wrapping a GAT model.
    """
    def __init__(self):
        super().__init__()
        self.model = graph_nn.GAT(in_channels=20,
                         hidden_channels=32,
                         num_layers=2,
                         heads=2,
                         out_channels=1,
                         dropout=0.01,
                         jk="last", v2=True)
        self.loss_function = nn.BCEWithLogitsLoss()

    def forward(self, node_attributes, edge_index):
        return self.model(node_attributes, edge_index)

    def training_step(self, batch, batch_idx):
        out = self(batch.amino_acid_one_hot.float(), batch.edge_index)
        loss = self.loss_function(out, batch.y.view(-1, 1))
        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=batch.batch_size)
        return loss
    
    def validation_step(self, batch, batch_idx):
        out = self(batch.amino_acid_one_hot.float(), batch.edge_index)
        loss = self.loss_function(out, batch.y.view(-1, 1))
        self.log('val_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=batch.batch_size)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(params=self.model.parameters(), lr=0.001, weight_decay=0.0001)

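We add the TensorBoard logger to our Trainer and train the model as before. `TensorBoardLogger("lightning_logs", name="gat")` writes its event files under `lightning_logs/gat/`: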

In [2]:
model = GATModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
logger = TensorBoardLogger("lightning_logs", name="gat")

trainer = L.Trainer(logger=logger, log_every_n_steps=1, max_epochs=10, accelerator='cpu')
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Missing logger folder: lightning_logs/gat
Processing...
Done!
Processing...
Done!

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | GAT               | 3.6 K 
1 | loss_function | BCEWithLogitsLoss | 0     
----------------------------------------------------
3.6 K     Trainable params
0         Non-trainable params
3.6 K     Total params
0.014     Total estimated model params size (MB)


/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


`Trainer.fit` stopped: `max_epochs=10` reached.


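Each new `TensorBoardLogger` starts a fresh version directory (`lightning_logs/gat/version_0`, `version_1`, ...), so every run appears separately in TensorBoard and you can keep track of improvements over time.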
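A minimal sketch of how the next version number can be chosen, assuming (as Lightning's `TensorBoardLogger` does by default when no `version` is given) that it is one more than the highest existing `version_<n>` directory. `next_version` is a hypothetical helper for illustration, not part of Lightning's API:

```python
from pathlib import Path
import tempfile


def next_version(root: Path) -> int:
    """Hypothetical helper: return the next free run index under `root`.

    Looks for sub-directories named `version_<n>` and returns max(n) + 1,
    or 0 if none exist yet.
    """
    if not root.is_dir():
        return 0
    versions = []
    for child in root.iterdir():
        suffix = child.name[len("version_"):]
        if child.is_dir() and child.name.startswith("version_") and suffix.isdigit():
            versions.append(int(suffix))
    return max(versions) + 1 if versions else 0


# Usage: simulate two earlier runs of the "gat" experiment in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "lightning_logs" / "gat"
    print(next_version(root))  # 0: nothing logged yet
    (root / "version_0").mkdir(parents=True)
    (root / "version_1").mkdir()
    print(next_version(root))  # 2
```

The logger resolves its version once and then keeps it, so passing the same logger object to a second `Trainer` writes into the same `version_<n>` directory rather than a new one.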

## Adding callbacks

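PyTorch Lightning has a number of [built-in callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html) that perform actions at specific points during training. For example, the `ModelCheckpoint` callback saves model checkpoints during training (with `monitor="val_loss"` and `save_top_k=1`, it keeps only the checkpoint with the lowest validation loss), and the `EarlyStopping` callback stops training once the monitored metric (here the validation loss) has not improved for `patience` consecutive validation checks, which with the default settings means `patience` epochs.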
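To make the `patience` logic concrete, here is a plain-Python simulation over a made-up sequence of validation losses. It is a simplified sketch of what the two callbacks track (it ignores `min_delta`, `mode="max"` and the other options of the real callbacks), and `simulate_training` is a hypothetical name:

```python
def simulate_training(val_losses, patience=2):
    """Hypothetical sketch mimicking EarlyStopping(monitor="val_loss", patience=patience)
    together with ModelCheckpoint(monitor="val_loss", save_top_k=1).

    Returns (best_epoch, best_loss, last_epoch_run).
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0  # validation checks since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # Improvement: the checkpoint callback would replace the saved "best" model
            best_loss, best_epoch = loss, epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # Early stopping would end training after this epoch
                return best_epoch, best_loss, epoch
    return best_epoch, best_loss, len(val_losses) - 1


# Made-up losses: improvement stalls after epoch 2
print(simulate_training([0.70, 0.62, 0.60, 0.61, 0.63, 0.55]))  # (2, 0.6, 4)
```

Training stops at epoch 4, so the lower loss at epoch 5 is never reached: a small `patience` saves time at the risk of stopping before a late improvement.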

In [3]:
model = GATModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
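trainer = L.Trainer(
    max_epochs=20,
    # Note: reusing `logger` from the previous run writes into the same version directory
    # (hence the "Checkpoint directory ... exists" warning in the output);
    # create a new TensorBoardLogger to start a fresh version instead
    logger=logger,
    log_every_n_steps=1,
    callbacks=[callbacks.EarlyStopping(monitor="val_loss", patience=2), # stop training if validation loss does not improve for 2 epochs
               callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1)], # save the best model based on validation loss
    accelerator="cpu",
)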
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Processing...
Done!
Processing...
Done!
/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory lightning_logs/gat/version_0/checkpoints exists and is not empty.

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | GAT               | 3.6 K 
1 | loss_function | BCEWithLogitsLoss | 0     
----------------------------------------------------
3.6 K     Trainable params
0         Non-trainable params
3.6 K     Total params
0.014     Total estimated model params size (MB)


/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x11fd4dd30>
Traceback (most recent call last):
  File "/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/Users/jay/mambaforge/envs/geometric-le

We can also write custom callbacks to log and track anything we are interested in, such as the precision-recall (PR) curve or the receiver operating characteristic (ROC) curve, together with a summary score for each. Here's a function that computes either curve and returns the score and the figure:
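The area under the ROC curve has a useful interpretation: it is the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one, with ties counting as half. A dependency-free sketch of that definition, handy for sanity-checking the ROC AUC reported by `get_metrics_and_curves` on small inputs; `pairwise_roc_auc` is a hypothetical helper and runs in O(P·N) time:

```python
def pairwise_roc_auc(scores, labels):
    """Hypothetical helper: ROC AUC as P(score of a positive > score of a negative),
    with ties counted as 1/2. Labels must be 0 or 1."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Scores can be raw logits: only their ordering matters
print(pairwise_roc_auc([2.1, -0.3, 0.8, -1.5], [1, 0, 0, 0]))  # 1.0
print(pairwise_roc_auc([0.2, 0.9, 0.4, 0.1], [1, 0, 1, 0]))  # 0.5
```

Because only the ranking matters, it is fine to pass the model's logits straight in without applying a sigmoid first.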

In [6]:
def get_metrics_and_curves(metric_type, y_pred, y_true):
    """
    Calculate a summary metric and its curve for a given metric type
    ROC: Receiver Operating Characteristic curve, metric = area under the curve
    PR: Precision-Recall curve, metric = average precision

    Parameters
    ----------
    metric_type : str
        One of "ROC", "PR"
    y_pred : torch.Tensor
        Predicted scores (raw logits or probabilities; only their ranking matters)
    y_true : torch.Tensor
        True binary labels (0 or 1)

    Returns
    -------
    metric_value : float
        Value of the metric
    metric_disp : matplotlib.figure.Figure
        Figure of the curve
    """
    # Flatten to 1D NumPy arrays: the model outputs have shape (num_nodes, 1)
    y_true = y_true.detach().cpu().numpy().ravel()
    y_pred = y_pred.detach().cpu().numpy().ravel()
    if metric_type == "ROC":
        # Receiver Operating Characteristic curve
        fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=1)
        roc_auc = metrics.auc(fpr, tpr)
        roc_disp = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot()
        return roc_auc, roc_disp.figure_
    elif metric_type == "PR":
        # Precision-Recall curve, summarised by average precision
        # (trapezoidal metrics.auc(recall, precision) can be overly optimistic)
        precision, recall, _ = metrics.precision_recall_curve(y_true, y_pred, pos_label=1)
        average_precision = metrics.average_precision_score(y_true, y_pred, pos_label=1)
        pr_disp = metrics.PrecisionRecallDisplay(precision=precision, recall=recall,
                                                 average_precision=average_precision).plot()
        return average_precision, pr_disp.figure_
    raise ValueError(f"Unknown metric_type {metric_type!r}; expected 'ROC' or 'PR'")
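A PR curve can be summarised in two ways: trapezoidal integration of the curve (`metrics.auc(recall, precision)`), or average precision, AP = Σₙ (Rₙ − Rₙ₋₁) · Pₙ, where Pₙ and Rₙ are the precision and recall at the n-th threshold. AP is the definition used by scikit-learn's `average_precision_score`, whose documentation notes that the trapezoidal version interpolates linearly and can be too optimistic. A dependency-free sketch of AP, assuming binary 0/1 labels; `average_precision` is a hypothetical helper:

```python
def average_precision(scores, labels):
    """Hypothetical helper: AP = sum_n (R_n - R_{n-1}) * P_n over decreasing thresholds."""
    n_pos = sum(1 for y in labels if y == 1)
    if n_pos == 0:
        raise ValueError("need at least one positive label")
    # Rank examples by score, highest first
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    ap, tp, fp, prev_recall = 0.0, 0, 0, 0.0
    i = 0
    while i < len(ranked):
        # Tied scores form a single threshold: consume the whole tie group at once
        threshold = ranked[i][0]
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        precision = tp / (tp + fp)
        recall = tp / n_pos
        ap += (recall - prev_recall) * precision  # step-wise area, no interpolation
        prev_recall = recall
    return ap


# Approximately 0.833 (= 0.5 * 1 + 0.5 * 2/3)
print(average_precision([0.9, 0.8, 0.7, 0.6], [1, 0, 1, 0]))
```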

To add these to the logger, we need to save the predictions and targets for each batch during training and validation:

In [7]:
class GATModule(L.LightningModule):
    """
    LightningModule wrapping a GAT model.
    """
    def __init__(self):
        super().__init__()
        self.model = graph_nn.GAT(in_channels=20,
                         hidden_channels=32,
                         num_layers=2,
                         heads=2,
                         out_channels=1,
                         dropout=0.01,
                         jk="last", v2=True)
        self.loss_function = nn.BCEWithLogitsLoss()
        self.train_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, node_attributes, edge_index):
        return self.model(node_attributes, edge_index)

    def training_step(self, batch, batch_idx):
        out = self(batch.amino_acid_one_hot.float(), batch.edge_index)
        loss = self.loss_function(out, batch.y.view(-1, 1))
        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=batch.batch_size)
        self.train_step_outputs.append((out.detach().cpu(), batch.y.detach().cpu())) # store predictions and targets for the epoch-end callback
        return loss

    def validation_step(self, batch, batch_idx):
        out = self(batch.amino_acid_one_hot.float(), batch.edge_index)
        loss = self.loss_function(out, batch.y.view(-1, 1))
        self.log('val_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=batch.batch_size)
        self.validation_step_outputs.append((out.detach().cpu(), batch.y.detach().cpu())) # store predictions and targets for the epoch-end callback
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(params=self.model.parameters(), lr=0.001, weight_decay=0.0001)

Then we override the `on_validation_epoch_end` and `on_train_epoch_end` hooks of the `Callback` class to compute the curves at the end of each validation and training epoch and send them to the logger:

In [15]:
class LogMetrics(callbacks.Callback):
    """
    Log metrics and curves for validation and training

    Scalars: ROC/val_AUC, ROC/train_AUC, PR/val_AUC, PR/train_AUC
    Images: ROC/val, ROC/train, PR/val, PR/train
    """
    def on_validation_epoch_end(self, trainer, pl_module):
        outputs = torch.cat([x[0] for x in pl_module.validation_step_outputs], dim=0)
        labels = torch.cat([x[1] for x in pl_module.validation_step_outputs], dim=0)
        for metric in ["ROC", "PR"]:
            metric_auc, metric_disp = get_metrics_and_curves(metric, outputs, labels)
            pl_module.log(f"{metric}/val_AUC", metric_auc)
            # add_figure closes the matplotlib figure by default (close=True)
            trainer.logger.experiment.add_figure(f"{metric}/val", metric_disp, global_step=trainer.global_step)
        # Empty the buffer so the next epoch only sees its own predictions
        pl_module.validation_step_outputs.clear()

    def on_train_epoch_end(self, trainer, pl_module):
        outputs = torch.cat([x[0] for x in pl_module.train_step_outputs], dim=0)
        labels = torch.cat([x[1] for x in pl_module.train_step_outputs], dim=0)
        for metric in ["ROC", "PR"]:
            metric_auc, metric_disp = get_metrics_and_curves(metric, outputs, labels)
            pl_module.log(f"{metric}/train_AUC", metric_auc)
            trainer.logger.experiment.add_figure(f"{metric}/train", metric_disp, global_step=trainer.global_step)
        # Empty the buffer so the next epoch only sees its own predictions
        pl_module.train_step_outputs.clear()
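A simplified, dependency-free model of when these hooks fire, assuming Lightning's default behaviour of validating once at the end of every training epoch, before `on_train_epoch_end` (the real loop also has a sanity-check pass and many more hooks). `FakeModule`, `FakeCallback` and `run_epochs` are hypothetical names. It shows why the per-batch buffers need clearing once consumed: otherwise each epoch's curves would also include every earlier epoch's predictions.

```python
class FakeModule:
    """Hypothetical stand-in for the LightningModule's output buffers."""
    def __init__(self):
        self.train_step_outputs = []
        self.validation_step_outputs = []


class FakeCallback:
    """Hypothetical stand-in for LogMetrics: records how many batches each hook sees."""
    def __init__(self, clear_buffers):
        self.clear_buffers = clear_buffers
        self.sizes = []  # (hook, number of stored batches seen by that hook)

    def on_validation_epoch_end(self, module):
        self.sizes.append(("val", len(module.validation_step_outputs)))
        if self.clear_buffers:
            module.validation_step_outputs.clear()

    def on_train_epoch_end(self, module):
        self.sizes.append(("train", len(module.train_step_outputs)))
        if self.clear_buffers:
            module.train_step_outputs.clear()


def run_epochs(callback, n_epochs=2, n_train_batches=3, n_val_batches=2):
    """Hypothetical mini fit loop following the assumed hook order."""
    module = FakeModule()
    for epoch in range(n_epochs):
        for b in range(n_train_batches):
            module.train_step_outputs.append(("train", epoch, b))
        # Assumed order: validation runs at the end of the training epoch...
        for b in range(n_val_batches):
            module.validation_step_outputs.append(("val", epoch, b))
        callback.on_validation_epoch_end(module)
        # ...and only then does the training-epoch-end hook fire
        callback.on_train_epoch_end(module)
    return callback.sizes


print(run_epochs(FakeCallback(clear_buffers=True)))
# [('val', 2), ('train', 3), ('val', 2), ('train', 3)]
print(run_epochs(FakeCallback(clear_buffers=False)))
# [('val', 2), ('train', 3), ('val', 4), ('train', 6)]
```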

In [16]:
logger = TensorBoardLogger("lightning_logs", name="gat")
model = GATModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
trainer = L.Trainer(
    max_epochs=20,
    logger=logger,
    log_every_n_steps=1,

    callbacks=[callbacks.EarlyStopping(monitor="val_loss", patience=2),
               callbacks.ModelCheckpoint(monitor="val_loss"),
               LogMetrics()],
    accelerator="cpu",
)
trainer.fit(model=model, datamodule=datamodule)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Processing...
Done!
Processing...
Done!

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | GAT               | 3.6 K 
1 | loss_function | BCEWithLogitsLoss | 0     
----------------------------------------------------
3.6 K     Trainable params
0         Non-trainable params
3.6 K     Total params
0.014     Total estimated model params size (MB)


/Users/jay/mambaforge/envs/geometric-learning/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


## Bonus
- Check out Weights & Biases (`wandb`), which Lightning supports through `WandbLogger`
- Learning rate schedulers
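As a starting point for learning rate schedules: `torch.optim.lr_scheduler.StepLR` multiplies the learning rate by `gamma` every `step_size` epochs, so its closed form is lr = lr₀ · gamma^⌊epoch / step_size⌋. A dependency-free sketch of that formula (the helper name `step_lr` is hypothetical). In a LightningModule, you would instead create the real scheduler in `configure_optimizers` and return it alongside the optimizer, for example as a dict with `"optimizer"` and `"lr_scheduler"` keys.

```python
def step_lr(base_lr, epoch, step_size, gamma=0.1):
    """Hypothetical helper: closed form of StepLR, decaying base_lr by `gamma`
    every `step_size` epochs."""
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    return base_lr * gamma ** (epoch // step_size)


# With the module's lr=0.001, halving the learning rate every 5 epochs
for epoch in (0, 4, 5, 10):
    print(epoch, step_lr(0.001, epoch, step_size=5, gamma=0.5))
```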