# Training a HACTNet model

In this notebook, we will train the HACTNet graph neural network (GNN) model on input cell and tissue graphs using the new `pathml.graph` API.

To run the notebook and train the model, you will have to first download the BRACS ROI set from the [BRACS dataset](https://www.bracs.icar.cnr.it/download/). To do so, you will have to sign up and create an account. Next, you will have to construct the cell and tissue graphs using the tutorial in `examples/construct_graphs.ipynb`. Use the output directory specified there as the input to the functions in this tutorial. 

In [1]:
import os
from glob import glob
import argparse
from PIL import Image
import numpy as np
from tqdm import tqdm
import torch 
import h5py
import warnings
import math
from skimage.measure import regionprops, label
import networkx as nx
import traceback
from glob import glob
import torch
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.data import Data
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import f1_score

from pathml.datasets import EntityDataset
from pathml.ml.utils import get_degree_histogram, get_class_weights
from pathml.ml import HACTNet

## Model Training

Here we define the main training function, which loads the constructed graphs, initializes the model, and trains it.

In [None]:
def train_hactnet(root_dir, load_histogram=True, histogram_dir=None, calc_class_weights=True,
                  n_epochs=60, batch_size=4, lr=0.0005, device=None):
    """Train a HACTNet model on pre-built cell/tissue graphs and checkpoint the best epoch.

    Args:
        root_dir (str): Directory containing the ``cell_graphs/``, ``tissue_graphs/`` and
            ``assignment_matrices/`` folders, each with ``train/``, ``val/`` and ``test/`` splits.
        load_histogram (bool): If True, load the node degree histograms from
            ``cell_degree_norm.pt`` and ``tissue_degree_norm.pt`` in ``histogram_dir``.
            If False, compute them from the training split and save them there.
        histogram_dir (str, optional): Directory used to read/write the degree histograms.
            Defaults to the current working directory.
        calc_class_weights (bool): If True, compute class weights from the training loader,
            save them to ``loss_weights_norm.pt`` and pass them to the cross-entropy loss.
        n_epochs (int): Number of training epochs.
        batch_size (int): Batch size used for the train, validation and test loaders.
        lr (float): Initial learning rate of the Adam optimizer.
        device (str or torch.device, optional): Device to train on. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        None. The weights of the epoch with the best validation weighted F1 are written to
        ``hact_net_norm.pt`` in the current working directory.
    """
    # Honour the caller-supplied histogram directory; fall back to the working directory.
    if histogram_dir is None:
        histogram_dir = "./"
    cell_deg_path = os.path.join(histogram_dir, 'cell_degree_norm.pt')
    tissue_deg_path = os.path.join(histogram_dir, 'tissue_degree_norm.pt')

    # Resolve the compute device, falling back to CPU when no GPU is present
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    def load_split(split):
        """Build the EntityDataset for one of the 'train', 'val' or 'test' splits."""
        return EntityDataset(os.path.join(root_dir, f'cell_graphs/{split}/'),
                             os.path.join(root_dir, f'tissue_graphs/{split}/'),
                             os.path.join(root_dir, f'assignment_matrices/{split}/'))

    # Read the train, validation and test dataset into the pathml.datasets.EntityDataset class
    train_dataset = load_split('train')
    val_dataset = load_split('val')
    test_dataset = load_split('test')

    # Print the lengths of each dataset split
    print(f"Length of training dataset: {len(train_dataset)}")
    print(f"Length of validation dataset: {len(val_dataset)}")
    print(f"Length of test dataset: {len(test_dataset)}")

    # Define the torch_geometric DataLoader object for each dataset split.
    # Training batches are reshuffled every epoch. Evaluation loaders keep a fixed order and
    # keep the last partial batch so that every sample contributes to the reported metrics.
    follow_batch = ['x_cell', 'x_tissue']
    train_batch = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                             follow_batch=follow_batch, drop_last=True)
    val_batch = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                           follow_batch=follow_batch, drop_last=False)
    test_batch = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                            follow_batch=follow_batch, drop_last=False)

    # The GNN layer we use in this model, PNAConv, requires the computation of a node degree histogram of the
    # train dataset. We only need to compute it once. If it is precomputed already, set load_histogram=True.
    # Else, the degree histogram is calculated and stored in histogram_dir for later runs.
    if load_histogram:
        cell_deg = torch.load(cell_deg_path)
        tissue_deg = torch.load(tissue_deg_path)
    else:
        train_batch_hist = DataLoader(train_dataset, batch_size=20, shuffle=False,
                                      follow_batch=follow_batch)
        print('Calculating degree histogram for cell graph')
        cell_deg = get_degree_histogram(train_batch_hist, 'edge_index_cell', 'x_cell')
        print('Calculating degree histogram for tissue graph')
        tissue_deg = get_degree_histogram(train_batch_hist, 'edge_index_tissue', 'x_tissue')
        os.makedirs(histogram_dir, exist_ok=True)
        torch.save(cell_deg, cell_deg_path)
        torch.save(tissue_deg, tissue_deg_path)

    # Since the BRACS dataset has unbalanced data, it is important to calculate the class weights in the training set
    # and provide that as an argument to our loss function. The weights are also saved to disk for inspection.
    loss_weights = None
    if calc_class_weights:
        train_w = get_class_weights(train_batch)
        weights_tensor = torch.tensor(train_w)
        torch.save(weights_tensor, 'loss_weights_norm.pt')
        loss_weights = weights_tensor.float().to(device)

    # Here we define the keyword arguments for the PNAConv layer in the model for both cell and tissue processing
    # layers.
    kwargs_pna_cell = {'aggregators': ["mean", "max", "min", "std"],
                       "scalers": ["identity", "amplification", "attenuation"],
                       "deg": cell_deg}
    kwargs_pna_tissue = {'aggregators': ["mean", "max", "min", "std"],
                         "scalers": ["identity", "amplification", "attenuation"],
                         "deg": tissue_deg}

    cell_params = {'layer': 'PNAConv', 'in_channels': 514, 'hidden_channels': 64,
                   'num_layers': 3, 'out_channels': 64, 'readout_op': 'lstm',
                   'readout_type': 'mean', 'kwargs': kwargs_pna_cell}

    tissue_params = {'layer': 'PNAConv', 'in_channels': 514, 'hidden_channels': 64,
                     'num_layers': 3, 'out_channels': 64, 'readout_op': 'lstm',
                     'readout_type': 'mean', 'kwargs': kwargs_pna_tissue}

    classifier_params = {'in_channels': 128, 'hidden_channels': 128,
                         'out_channels': 7, 'num_layers': 2}

    # Initialize the pathml.ml.HACTNet model and send it to the selected device
    model = HACTNet(cell_params, tissue_params, classifier_params)
    model = model.to(device)

    # Set up optimizer
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Learning rate scheduler to reduce LR by factor of 10 each 25 epochs
    scheduler = StepLR(opt, step_size=25, gamma=0.1)

    # Keep a track of best epoch and metric for saving only the best models
    best_epoch = 0
    best_metric = 0

    # Define the loss function (class-weighted only when calc_class_weights=True)
    loss_fn = nn.CrossEntropyLoss(weight=loss_weights)

    def evaluate(data_loader):
        """Compute per-class and weighted F1 of the current model on a data loader.

        Returns a numpy array holding the per-class F1 scores followed by the weighted F1.
        """
        model.eval()
        y_true = []
        y_pred = []
        with torch.no_grad():
            for data in tqdm(data_loader):
                data = data.to(device)
                outputs = model(data)
                # argmax over the logits gives the predicted class; the targets are the ground truth
                y_pred.append(torch.argmax(outputs, dim=-1).cpu().numpy().ravel())
                y_true.append(data.target.cpu().numpy().ravel())
        # Concatenate instead of stacking so a smaller final batch is handled correctly
        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)
        per_class = f1_score(y_true, y_pred, average=None)
        weighted = f1_score(y_true, y_pred, average='weighted')
        print(f'Per class F1: {per_class}')
        print(f'Weighted F1: {weighted}')
        return np.append(per_class, weighted)

    # Start the training loop
    for i in range(n_epochs):
        print(f'\n>>>>>>>>>>>>>>>>Epoch number {i}>>>>>>>>>>>>>>>>')
        minibatch_train_losses = []

        # Put model in training mode
        model.train()

        print('Training')

        for data in tqdm(train_batch):

            # Send the data to the training device
            data = data.to(device)

            # Zero out gradient
            opt.zero_grad()

            # Forward pass
            outputs = model(data)

            # Compute loss
            loss = loss_fn(outputs, data.target)

            # Compute gradients
            loss.backward()

            # Step optimizer (the LR scheduler is stepped once per epoch, below)
            opt.step()

            # Track loss
            minibatch_train_losses.append(loss.item())

        print(f'Loss: {np.mean(minibatch_train_losses)}')

        # Print performance metrics on validation set
        print('\nEvaluating on validation')
        val_metrics = evaluate(val_batch)

        # Save the model only if it is better than previous checkpoint in validation metrics
        if val_metrics[-1] > best_metric:
            print('Saving checkpoint')
            torch.save(model.state_dict(), "hact_net_norm.pt")
            best_metric = val_metrics[-1]
            best_epoch = i

        # Print performance metrics on test set
        print('\nEvaluating on test')
        _ = evaluate(test_batch)

        # Step LR scheduler
        scheduler.step()

    print(f'\nBest validation weighted F1: {best_metric} (epoch {best_epoch})')

In [None]:
# Directory holding the cell graphs, tissue graphs and assignment matrices to train on
root_dir = '../../../../mnt/disks/data/varun/BRACS_RoI/latest_version/pathml_graph_data_norm/'

# Load the precomputed degree histograms and train without class-weighted loss
train_hactnet(
    root_dir=root_dir,
    load_histogram=True,
    calc_class_weights=False,
)

## References

*  Pati, Pushpak, Guillaume Jaume, Antonio Foncubierta-Rodriguez, Florinda Feroce, Anna Maria Anniciello, Giosue Scognamiglio, Nadia Brancati et al. "Hierarchical graph representations in digital pathology." Medical image analysis 75 (2022): 102264.
*  Brancati, Nadia, Anna Maria Anniciello, Pushpak Pati, Daniel Riccio, Giosuè Scognamiglio, Guillaume Jaume, Giuseppe De Pietro et al. "Bracs: A dataset for breast carcinoma subtyping in h&e histology images." Database 2022 (2022): baac093.

## Session info

In [None]:
# Report the IPython/system environment followed by the torch version used for this run
import IPython

print(IPython.sys_info(), f"torch version: {torch.__version__}", sep="\n")