There are a variety of free and paid resources available for interactively tracking training performance of deep learning models. Here we'll use [TensorBoard](https://github.com/tensorflow/tensorboard) which is free and open-source. Another popular option with a free tier is [Weights & Biases](https://wandb.ai/site), which has some additional features and integrations. 

We can actually visualise TensorBoard within the notebook with a couple of Jupyter magic commands. The `%load_ext tensorboard` magic below loads the TensorBoard notebook extension; the `%tensorboard` magic (used further down, after training) then runs the TensorBoard server in the background and displays it in the notebook, reading log files from the `lightning_logs` directory. As we run training, log files will be written to this directory by default and the TensorBoard display will update automatically (refreshes every 30 seconds).

In [1]:
# Load the TensorBoard notebook extension so that the `%tensorboard` magic is available later on.
%load_ext tensorboard

Let's set up the LightningModule again and this time also log the loss value on our validation set. Note that Lightning automatically runs `validation_step` with gradients disabled and with the model in evaluation mode (so no weights are updated and layers such as dropout are switched off), so we don't need to worry about that.

In [2]:
import lightning as L
import torch
from torch import nn
from torch_geometric import nn as graph_nn
from src import dataloader

class GATModule(L.LightningModule):
    """
    LightningModule wrapping a GAT model.

    The wrapped network is a two-layer GATv2 (``v2=True``) that takes 20 input
    features per node and emits a single logit per node; it is trained with
    ``BCEWithLogitsLoss``.

    Besides logging ``train_loss`` and ``val_loss``, every step stores its
    predictions and targets so that callbacks (e.g. ``LogMetrics``) can compute
    epoch-level metrics:

    Attributes
    ----------
    train_step_outputs : list of dict
        One ``{'out': probabilities, 'y': targets}`` entry per training batch
        of the current epoch.
    validation_step_outputs : list of dict
        Same, for the validation batches of the current validation run.
    """
    def __init__(self):
        super().__init__()
        self.model = graph_nn.GAT(in_channels=20,
                         hidden_channels=32,
                         num_layers=2,
                         heads=2,
                         out_channels=1,
                         dropout=0.01,
                         jk="last", v2=True)
        self.loss_function = nn.BCEWithLogitsLoss()
        # Per-batch predictions/targets, consumed by metric callbacks at epoch end
        self.train_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, node_attributes, edge_index):
        """Return the raw (pre-sigmoid) model output for the given graph."""
        return self.model(node_attributes, edge_index)

    def _shared_step(self, batch, stage, step_outputs):
        """
        Run the model on one batch, log ``{stage}_loss`` and store the predictions.

        Parameters
        ----------
        batch
            Batch providing ``amino_acid_one_hot``, ``edge_index``, ``y`` and ``batch_size``
        stage : str
            Prefix of the logged loss name ("train" or "val")
        step_outputs : list
            List to which this batch's ``{'out': ..., 'y': ...}`` dict is appended

        Returns
        -------
        loss : torch.Tensor
            Binary cross-entropy loss of the batch
        """
        logits = self(batch.amino_acid_one_hot.float(), batch.edge_index)
        targets = batch.y.view(-1, 1)
        loss = self.loss_function(logits, targets)
        self.log(f'{stage}_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=batch.batch_size)
        # The model outputs logits; store probabilities so that downstream metrics
        # can threshold at 0.5. Detach and move to CPU so that neither the autograd
        # graph nor accelerator memory is kept alive for the whole epoch.
        step_outputs.append({'out': torch.sigmoid(logits).detach().cpu(),
                             'y': targets.detach().cpu()})
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        return self._shared_step(batch, 'train', self.train_step_outputs)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        return self._shared_step(batch, 'val', self.validation_step_outputs)

    def on_train_epoch_start(self):
        """Drop the predictions collected during the previous training epoch."""
        # Cleared at epoch *start* so the lists are still populated when
        # callbacks read them in their ``on_train_epoch_end`` hook.
        self.train_step_outputs.clear()

    def on_validation_epoch_start(self):
        """Drop the predictions collected during the previous validation run."""
        # Also discards whatever the sanity-check validation pass collected.
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3 and a weight decay of 1e-4."""
        return torch.optim.Adam(params=self.model.parameters(), lr=0.001, weight_decay=0.0001)

In [3]:
# Fresh model and data, trained for two epochs with Lightning's default
# settings; logs and checkpoints are written below `lightning_logs/`.
model = GATModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")

trainer = L.Trainer(max_epochs=2)
trainer.fit(model, datamodule=datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Processing...
Done!
Processing...
Done!
/scicore/home/schwede/durair0000/mambaforge/envs/leuven/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /scicore/home/schwede/durair0000/projects/leuven_course/lightning_logs/version_44909931/checkpoints exists and is not empty.

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | GAT               | 3.6 K 
1 | loss_function | BCEWithLogitsLoss | 0     
----------------------------------------------------
3.6 K     Trainable params
0         Non-trainable params
3.6 K     Total params
0.014     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/scicore/home/schwede/durair0000/mambaforge/envs/leuven/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [2]:
# Start TensorBoard on the event files under `lightning_logs/` and embed it in the notebook.
# `--bind_all` serves on all network interfaces, which is handy when the notebook runs on a
# remote machine (NOTE: this also makes the server reachable by others on that network).
%tensorboard --logdir lightning_logs/ --bind_all

## Adding callbacks

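Callbacks let us hook extra behaviour into the training loop without touching the LightningModule. Lightning ships with a number of [built-in callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html); below we use `EarlyStopping` to stop training once `val_loss` stops improving and `ModelCheckpoint` to keep the checkpoint with the best `val_loss`.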

In [None]:
# Import from the same `lightning` package that provides `L.Trainer` and
# `L.LightningModule`; objects from the standalone `pytorch_lightning`
# package must not be mixed with it.
from lightning.pytorch import callbacks

In [None]:
from lightning.pytorch.loggers import TensorBoardLogger

# Log explicitly to ./lightning_logs so the TensorBoard view above picks up this run.
logger = TensorBoardLogger(save_dir=".", name="lightning_logs")

# NOTE: with patience=5 and max_epochs=5 early stopping has no chance to
# trigger in this short demo; increase max_epochs for a real run.
trainer = L.Trainer(
    devices="auto",
    accelerator="auto",
    enable_progress_bar=True,
    max_epochs=5,
    logger=logger,
    log_every_n_steps=1,
    callbacks=[callbacks.EarlyStopping(monitor="val_loss", patience=5),
               callbacks.ModelCheckpoint(monitor="val_loss")],
)
# Pass the datamodule *instance* created earlier, not the imported `dataloader` module.
trainer.fit(model=model, datamodule=datamodule)

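### Custom callbacks

We can also write our own callback by subclassing `Callback` and overriding the hooks we need. The callback below computes ROC, precision-recall and confusion-matrix metrics at the end of every training and validation epoch and logs them (together with the corresponding figures) to TensorBoard.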

In [None]:
from sklearn import metrics
import pandas as pnd
import seaborn as sns
import matplotlib.pyplot as plt

def get_metrics_and_curves(metric_type, y_pred, y_true, invert=False, threshold=0.5):
    """
    Calculate a summary metric and the matching figure for binary predictions.

    ROC: Receiver Operating Characteristic curve, metric = area under the curve
    PR: Precision-Recall curve, metric = area under the curve (trapezoidal rule)
    CM: Confusion Matrix, metric = F1 score

    Parameters
    ----------
    metric_type : str
        One of "ROC", "PR", "CM"
    y_pred : torch.Tensor
        Predicted scores (higher = more likely positive)
    y_true : torch.Tensor
        True labels
    invert : bool
        If True, do 1 - y_pred, use if y_pred is distance instead of probability
    threshold : float
        Scores strictly greater than this value count as positive predictions;
        only used for "CM"

    Returns
    -------
    metric_value : float
        Value of the metric
    metric_disp : matplotlib.figure.Figure
        Figure of the curve/matrix. It is already closed in pyplot (so it is
        not kept alive or shown inline) but can still be rendered or saved.

    Raises
    ------
    ValueError
        If metric_type is not one of "ROC", "PR", "CM"
    """
    # Fail early and explicitly instead of silently returning None
    if metric_type not in ("ROC", "PR", "CM"):
        raise ValueError(
            f'Unknown metric_type {metric_type!r}, expected one of "ROC", "PR", "CM"'
        )
    if invert:
        y_pred = 1 - y_pred
    y_true = y_true.cpu().detach().numpy()
    y_pred = y_pred.cpu().detach().numpy()
    if metric_type == "ROC":
        fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=1)
        metric_value = metrics.auc(fpr, tpr)
        metric_disp = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=metric_value).plot().figure_
    elif metric_type == "PR":
        # Precision-Recall Curve
        precision, recall, _ = metrics.precision_recall_curve(y_true, y_pred, pos_label=1)
        # Trapezoidal area under the PR curve; the display shows it in its "AP" label
        metric_value = metrics.auc(recall, precision)
        metric_disp = metrics.PrecisionRecallDisplay(precision=precision, recall=recall,
                                                     average_precision=metric_value).plot().figure_
    else:
        # Binarise the scores once and reuse them for the matrix and the F1 score
        y_label = y_pred > threshold
        df_cm = pnd.DataFrame(metrics.confusion_matrix(y_true, y_label))
        plt.figure(figsize=(10, 7))
        metric_disp = sns.heatmap(df_cm, annot=True, cmap='Blues').get_figure()
        metric_value = metrics.f1_score(y_true, y_label)
    # Deregister the figure from pyplot for every metric type; otherwise one
    # figure per call stays open when this runs every epoch.
    plt.close(metric_disp)
    return metric_value, metric_disp



class LogMetrics(callbacks.Callback):
    """
    Log metrics and curves for validation and training

    Scalars: ROC/val_AUC, ROC/train_AUC, PR/val_AUC, PR/train_AUC, CM/val_F1, CM/train_F1
    Images: ROC/val, ROC/train, PR/val, PR/train, CM/val, CM/train

    The LightningModule must expose ``validation_step_outputs`` and
    ``train_step_outputs``: lists with one dict per batch holding the
    predictions under ``'out'`` and the labels under ``'y'``.
    Figures are written with ``trainer.logger.experiment.add_figure`` every
    10th epoch, so a TensorBoard logger is expected.
    """
    def _log_epoch_metrics(self, trainer, pl_module, step_outputs, stage):
        """
        Compute and log all metrics for the batches collected in one epoch.

        Parameters
        ----------
        trainer : Trainer
            Trainer running the loop; provides the epoch, global step and logger
        pl_module : LightningModule
            Module whose ``log`` method receives the scalar values
        step_outputs : list of dict
            Per-batch dicts with keys 'out' and 'y'
        stage : str
            "val" or "train", used in the logged names
        """
        if not step_outputs:
            # Nothing was collected; torch.cat would fail on an empty list
            return
        outputs = torch.cat([x['out'] for x in step_outputs], dim=0)
        labels = torch.cat([x['y'] for x in step_outputs], dim=0)
        log_figures = trainer.current_epoch % 10 == 0
        for metric, value in zip(["ROC", "PR", "CM"], ["AUC", "AUC", "F1"]):
            metric_value, metric_disp = get_metrics_and_curves(metric, outputs, labels)
            pl_module.log(f"{metric}/{stage}_{value}", metric_value)
            if log_figures:
                trainer.logger.experiment.add_figure(f"{metric}/{stage}", metric_disp,
                                                     global_step=trainer.global_step)
            else:
                # Figures that are not logged must be closed here, otherwise
                # one open matplotlib figure per metric piles up every epoch.
                plt.close(metric_disp)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log validation metrics, skipping the pre-training sanity check."""
        if trainer.sanity_checking:
            # Only a couple of batches are seen here; metrics would be meaningless
            return
        self._log_epoch_metrics(trainer, pl_module, pl_module.validation_step_outputs, "val")

    def on_train_epoch_end(self, trainer, pl_module):
        """Log training metrics for the epoch that just finished."""
        self._log_epoch_metrics(trainer, pl_module, pl_module.train_step_outputs, "train")



In [None]:
from lightning.pytorch.loggers import TensorBoardLogger

# LogMetrics writes figures through the TensorBoard writer, so use an explicit
# TensorBoard logger pointing at the directory the TensorBoard view reads from.
logger = TensorBoardLogger(save_dir=".", name="lightning_logs")

trainer = L.Trainer(
    devices="auto",
    accelerator="auto",
    enable_progress_bar=True,
    max_epochs=5,
    logger=logger,
    log_every_n_steps=1,
    callbacks=[callbacks.EarlyStopping(monitor="val_loss", patience=5),
               callbacks.ModelCheckpoint(monitor="val_loss"),
               LogMetrics()],
)
# Pass the datamodule *instance* created earlier, not the imported `dataloader` module.
trainer.fit(model=model, datamodule=datamodule)

## Bonus:
- Check out wandb
- Learning rate schedulers