In [None]:
# !pip install --use-pep517 "graphein[extras]" lightning torch torch-geometric tensorboard nbformat "jsonargparse[signatures]" ipywidgets tabulate

There are a variety of free and paid resources available for interactively tracking training performance of deep learning models. Here we'll use [TensorBoard](https://github.com/tensorflow/tensorboard) which is free and open-source. Another popular option with a free tier is [Weights & Biases](https://wandb.ai/site), which has some additional features and integrations. 



In [None]:
import lightning
import torch
from torch import nn
from torch_geometric import nn as graph_nn
from src import dataloader
from lightning.pytorch import callbacks
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn import metrics

We can actually visualise TensorBoard within the notebook with some cell magic to tell Jupyter to run the server in the background and display it in the notebook, by looking for files in the `lightning_logs` directory. As we run training, log files will be written to this directory by default and the TensorBoard display will update automatically (refreshes every 30 seconds).

In [None]:
# Load the TensorBoard notebook extension, then start a TensorBoard server that
# reads event files from the `lightning_logs/` directory and displays inline.
# NOTE(review): `--bind_all` presumably exposes the server on all network
# interfaces - confirm this is wanted before running on a shared machine.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/ --bind_all

Let's set up the LightningModule again and this time also log the loss value on our validation set. Note that Lightning automatically runs the `validation_step` function with gradients disabled and the model in evaluation mode (so the weights are not updated and layers such as dropout are switched off), so we don't need to worry about that. We'll also save the predicted values and true labels from each step in the `train_step_outputs` and `validation_step_outputs` lists.

In [None]:
class InterfaceModule(lightning.LightningModule):
    """
    LightningModule wrapping a GAT model trained with a binary
    cross-entropy-with-logits loss.

    Every training/validation step logs its loss and stores the detached
    ``(predictions, labels)`` pair (moved to CPU) in ``train_step_outputs`` /
    ``validation_step_outputs`` so that callbacks can compute epoch-level
    metrics. Both lists are emptied at the end of the corresponding epoch.
    """
    def __init__(self):
        super().__init__()
        # in_channels=20 presumably matches the width of `amino_acid_one_hot`
        # (one channel per amino acid) - confirm against the dataloader.
        self.model = graph_nn.GAT(in_channels=20,
                         hidden_channels=32,
                         num_layers=2,
                         heads=2,
                         out_channels=1,
                         dropout=0.01,
                         jk="last", v2=True)
        # The model output is passed to the loss as raw logits (no sigmoid here).
        self.loss_function = nn.BCEWithLogitsLoss()
        # Per-epoch buffers of (y_pred, y_true) tuples, one entry per step.
        self.train_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, node_attributes, edge_index):
        """Run the wrapped GAT on the given node features and edge index."""
        return self.model(node_attributes, edge_index)

    def _shared_step(self, batch, stage, step_outputs):
        """
        Compute and log the loss for one batch and record its outputs.

        Parameters
        ----------
        batch
            Batch providing ``amino_acid_one_hot``, ``edge_index``,
            ``interface_label`` and ``batch_size``.
        stage : str
            Prefix of the logged metric name (``"train"`` or ``"val"``).
        step_outputs : list
            Buffer that receives the detached ``(y_pred, y_true)`` pair.

        Returns
        -------
        torch.Tensor
            The loss for this batch.
        """
        y_pred = self(batch.amino_acid_one_hot.float(), batch.edge_index)
        # Labels become a column so they match the single-output predictions.
        y_true = batch.interface_label.float().view(-1, 1)
        loss = self.loss_function(y_pred, y_true)
        self.log(f'{stage}_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=batch.batch_size)
        step_outputs.append((y_pred.detach().cpu(), y_true.detach().cpu())) # SAVE OUTPUTS
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` and store this step's outputs."""
        return self._shared_step(batch, 'train', self.train_step_outputs)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and store this step's outputs."""
        return self._shared_step(batch, 'val', self.validation_step_outputs)

    def on_train_epoch_end(self):
        # Without this the buffer grows for the whole run whenever no callback
        # clears it. Lightning runs the callbacks' epoch-end hooks before this
        # module hook, so a callback reading the buffer still sees the full epoch.
        self.train_step_outputs.clear()

    def on_validation_epoch_end(self):
        # Same reasoning as on_train_epoch_end, for the validation buffer.
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer over the GAT parameters."""
        return torch.optim.Adam(params=self.model.parameters(), lr=0.001, weight_decay=0.0001)

In [None]:
# CPU-only trainer: 10 epochs, logging at every training step.
trainer = lightning.Trainer(
    accelerator='cpu',
    max_epochs=10,
    log_every_n_steps=1,
)

We add the TensorBoard logger to our Trainer and train the model as before

In [None]:
model = InterfaceModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
logger = TensorBoardLogger("lightning_logs", name="gat")
# The logger only takes effect when it is passed to the Trainer, so build the
# trainer here with the same settings as before plus `logger=logger`.
trainer = lightning.Trainer(logger=logger, log_every_n_steps=1, max_epochs=10, accelerator='cpu')
trainer.fit(model=model, datamodule=datamodule)

Every new run you train will be added to the logger so you can keep track of the improvements made over time.

## Adding callbacks

PyTorch Lightning has a number of [built-in callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html) which are used to perform actions at various points during training. For example, the `ModelCheckpoint` callback saves the model after each epoch, and the `EarlyStopping` callback stops training if the validation loss has not improved for a certain number of epochs.

In [None]:
# Fresh model and data for this run; `logger` is the TensorBoardLogger
# created in the earlier cell.
model = InterfaceModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
trainer = lightning.Trainer(
    accelerator="cpu",
    max_epochs=50,
    log_every_n_steps=1,
    logger=logger,
    callbacks=[
        # stop training if validation loss does not improve for 2 epochs
        callbacks.EarlyStopping(monitor="val_loss", patience=2),
        # save the best model based on validation loss
        callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1),
    ],
)
trainer.fit(model=model, datamodule=datamodule)

We can also add custom callbacks to log and track things that we are interested in, such as the precision-recall curve or the ROC curve.

Here's a function to calculate such curves and return the images and AUC values:

In [None]:
def get_metrics_and_curves(metric_type, y_pred, y_true, threshold=0.5):
    """
    Calculate a summary metric and plot the matching curve.

    ROC: Receiver Operating Characteristic curve, metric = Area under the curve
    PR: Precision-Recall curve, metric = Area under the curve

    Parameters
    ----------
    metric_type : str
        One of "ROC", "PR"
    y_pred : torch.Tensor
        Predicted scores for the positive class
    y_true : torch.Tensor
        True binary labels (positive class = 1)
    threshold : float, optional
        Currently unused; kept so existing callers that pass it keep working.

    Returns
    -------
    metric_value : float
        Value of the metric
    metric_disp : matplotlib.figure.Figure
        Figure of the curve

    Raises
    ------
    ValueError
        If `metric_type` is not one of the supported values.
    """
    # Fail loudly up front: silently returning None would only surface later
    # as a confusing tuple-unpacking error in the caller.
    if metric_type not in ("ROC", "PR"):
        raise ValueError(
            f"Unsupported metric_type {metric_type!r}; expected 'ROC' or 'PR'"
        )

    y_true = y_true.cpu().detach().numpy()
    y_pred = y_pred.cpu().detach().numpy()
    if metric_type == "ROC":
        # Receiver Operating Characteristic Curve
        fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=1)
        roc_auc = metrics.auc(fpr, tpr)
        roc_disp = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot()
        return roc_auc, roc_disp.figure_

    # Precision-Recall Curve
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_pred, pos_label=1)
    pr_auc = metrics.auc(recall, precision)
    pr_disp = metrics.PrecisionRecallDisplay(precision=precision, recall=recall, average_precision=pr_auc).plot()
    return pr_auc, pr_disp.figure_

To add these to the logger, we can use the built-in hooks of the `Callback` class to compute the figures at the end of each train/validation epoch and send them to the logger. Don't forget to clear the `train_step_outputs` and `validation_step_outputs` lists at the end of each epoch so that they start empty for the next epoch.

In [None]:
class LogMetrics(callbacks.Callback):
    """
    Log metrics and curves for validation and training

    Scalars: ROC/val_AUC, ROC/train_AUC, PR/val_AUC, PR/train_AUC
    Images: ROC/val, ROC/train, PR/val, PR/train

    Reads the ``(prediction, label)`` tuples that the module stores in
    ``train_step_outputs`` / ``validation_step_outputs`` and clears the list
    once it has been logged.
    """
    def _log_epoch(self, trainer, pl_module, step_outputs, stage):
        """
        Log the ROC and PR scalars and figures for one finished epoch.

        Parameters
        ----------
        trainer
            Trainer whose logger receives the figures.
        pl_module
            Module used to log the scalar metrics.
        step_outputs : list
            ``(y_pred, y_true)`` tuples collected during the epoch; emptied here.
        stage : str
            ``"train"`` or ``"val"``, used in the logged names.
        """
        # Nothing was collected (e.g. no batches ran): torch.cat would raise
        # on an empty list, so skip logging for this epoch.
        if not step_outputs:
            return
        outputs = torch.cat([pred for pred, _ in step_outputs], dim=0)
        labels = torch.cat([label for _, label in step_outputs], dim=0)
        for metric in ["ROC", "PR"]:
            metric_auc, metric_disp = get_metrics_and_curves(metric, outputs, labels)
            pl_module.log(f"{metric}/{stage}_AUC", metric_auc)
            trainer.logger.experiment.add_figure(f"{metric}/{stage}", metric_disp, global_step=trainer.global_step)
        # Start the next epoch with an empty buffer.
        step_outputs.clear()

    def on_train_epoch_end(self, trainer, pl_module):
        self._log_epoch(trainer, pl_module, pl_module.train_step_outputs, "train")

    def on_validation_epoch_end(self, trainer, pl_module):
        self._log_epoch(trainer, pl_module, pl_module.validation_step_outputs, "val")

In [None]:
# Final run: named TensorBoard logger, early stopping, checkpointing and the
# custom LogMetrics callback for ROC/PR curves.
logger = TensorBoardLogger("lightning_logs", name="gat")
model = InterfaceModule()
datamodule = dataloader.ProteinGraphDataModule("./test_data", "dataset.txt")
trainer = lightning.Trainer(
    accelerator="cpu",
    max_epochs=20,
    log_every_n_steps=1,
    logger=logger,
    callbacks=[
        callbacks.EarlyStopping(monitor="val_loss", patience=2),
        callbacks.ModelCheckpoint(monitor="val_loss"),
        LogMetrics(),
    ],
)
trainer.fit(model=model, datamodule=datamodule)

## Moving things to scripts and config files

```sh
src/
    __init__.py
    dataloader.py
        load_graph
        ProteinDataset
        ProteinGraphDataModule
    models.py
        InterfaceModel
        InterfaceModule
    callbacks.py
        get_metrics_and_curves
        LogMetrics
train.py
config.yaml
```

In `config.yaml`:

```yaml
seed_everything: true
model:
  class_path: src.models.InterfaceModule
  init_args:
    in_channels: 20
    num_layers: 2
    hidden_channels: 32
    heads: 2
    out_channels: 1
    dropout: 0.01
    jk: last
    v2: true
data:
  class_path: src.dataloader.ProteinGraphDataModule
  init_args:
    root: ./
    columns:
      - chain_id
      - coords
      - edge_index
      - kind
      - node_id
      - residue_number
      - meiler
      - amino_acid_one_hot
      - interface_label
    batch_size: 32
    num_workers: 4
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
    weight_decay: 0.0001
trainer:
  logger:
    - class_path: lightning.pytorch.loggers.TensorBoardLogger
      init_args:
        save_dir: lightning_logs
        name: interface
        log_graph: true
  enable_checkpointing: true
  callbacks:
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        patience: 5
        monitor: val_loss
        mode: min
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: 3
        monitor: val_loss
        mode: min
        filename: "{epoch:02d}-{val_loss:.2f}"
    - class_path: src.callbacks.LogMetrics
  enable_progress_bar: true
  max_epochs: -1
  log_every_n_steps: 1
  accelerator: cpu
  strategy: auto
  precision: 32
```

And in `train.py`:

```python
from lightning.pytorch.cli import LightningCLI
import torch
import warnings
warnings.filterwarnings('ignore')

def main():
    """
    Entry point: configure matmul precision and hand over to LightningCLI.

    Run with python train.py fit -c config.yaml
    Or in an sbatch script with srun python train.py fit -c config.yaml
    """
    torch.set_float32_matmul_precision('medium')
    # Constructing LightningCLI parses the command line and runs the requested
    # subcommand, so the instance itself does not need to be kept.
    LightningCLI(save_config_kwargs={"overwrite": True})

if __name__ == '__main__':
    main()
```