<a href="https://colab.research.google.com/github/casblaauw/BertOGlyc/blob/main/ProtBert_NetOGlyc_classification_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Initial model architecture based on [Elnaggar et al. (2020)](https://www.biorxiv.org/content/10.1101/2020.07.12.199554v3.full) and [Heinzinger et al. (2019)](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-019-3220-8).  
Further inspiration came from the following sources:

- Data loader structure: [this PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)
- Model architecture: [this PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
- Training loop: [this CNN tutorial](https://chriskhanhtran.github.io/posts/cnn-sentence-classification/)
- Weights for the loss function: [this tutorial](https://towardsdatascience.com/handling-class-imbalanced-data-using-a-loss-specifically-made-for-it-6e58fd65ffab), based on [this paper](https://arxiv.org/abs/1901.05555)
- Hyperparameter tuning: [this tutorial](https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html) and the general [Ray docs](https://docs.ray.io/en/latest/tune/key-concepts.html)

<b>0. Import functions</b>

In [1]:
!pip install ray
!pip install -U hyperopt

Collecting ray
  Downloading ray-1.6.0-cp37-cp37m-manylinux2014_x86_64.whl (49.6 MB)
Collecting redis>=3.5.0
  Downloading redis-3.5.3-py2.py3-none-any.whl (72 kB)
Installing collected packages: redis, ray
Successfully installed ray-1.6.0 redis-3.5.3
Collecting hyperopt
  Downloading hyperopt-0.2.5-py2.py3-none-any.whl (965 kB)
Installing collected packages: hyperopt
  Attempting uninstall: hyperopt
    Found existing installation: hyperopt 0.1.2
    Uninstalling hyperopt-0.1.2:
      Successfully uninstalled hyperopt-0.1.2
Successfully installed hyperopt-0.2.5
Collecting tensorboardX
  Downloading tensorboardX-2.4-py2.py3-none-any.whl (124 kB)
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.4


In [2]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from ray import tune
from ray.tune import JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch
from functools import partial

import os
from google.colab import files, drive
import gc

In [30]:
drive.mount('/content/drive')

Mounted at /content/drive


<b>1. Read in and pre-split data </b>

The info data frame determines the contents of the dataset. Indexing the dataset (as done by the data loader) looks at the supplied info file, gets the gene/protein ID associated with that position, and retrieves that protein's embeddings from the zip file.   

Therefore, the data can be split by simply splitting the info data frame and building EmbeddingDatasets/DataLoaders with those. The actual construction of those happens within the training function (to split into random train/validation sets for each training run) and at post-training testing, but the info dataframes are already prepared here.
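
The lookup described above can be sketched with a toy archive. This is a minimal illustration, not the notebook's data: the gene IDs `P1_pos` and `P2_neg`, the array shapes, the temporary file, and the helper `load_embedding` are all made up, and a plain list of dicts stands in for the info data frame.

```python
import os
import tempfile

import numpy as np

# Hypothetical toy data: two proteins, each with a (length, features) embedding
toy_embeddings = {
    "embeddings_P1_pos": np.ones((5, 8)),
    "embeddings_P2_neg": np.zeros((3, 8)),
}
toy_info = [{"gene": "P1_pos"}, {"gene": "P2_neg"}]

# Store all arrays in one archive, keyed by name
tmp_dir = tempfile.mkdtemp()
archive_path = os.path.join(tmp_dir, "toy_embeddings.npz")
np.savez(archive_path, **toy_embeddings)

def load_embedding(path, info, idx):
    """Look up the gene ID at position idx, then read only that array."""
    gene = info[idx]["gene"]
    with np.load(path) as archive:
        return archive[f"embeddings_{gene}"]

# Splitting the info table is enough to split the data:
# each subset only ever asks the archive for its own genes.
info_a, info_b = toy_info[:1], toy_info[1:]
shape_a = load_embedding(archive_path, info_a, 0).shape
shape_b = load_embedding(archive_path, info_b, 0).shape
print(shape_a, shape_b)
```

Because the archive is only opened inside the lookup, each subset can be wrapped in its own dataset object without copying any embeddings.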

In [4]:
# Define the Dataset class for use with DataLoader, reading in files as needed
class EmbeddingDataset(Dataset):
    """Dataset of embeddings from ProtBert.
    Path is expected to be a path to a zip/npz file containing the .npy arrays for each gene.
    Indexing looks up the array stored as f"embeddings_{gene_id}" in that zip file.
    Info is expected to be a pandas dataframe with a 'gene' column of gene IDs, and 'sequence' and 'label' columns holding lists/iterables."""
    def __init__(self, path, info):
        self.path = path
        self.info = info.reset_index() 

    def __len__(self):
        return len(self.info)
    
    def __getitem__(self, idx):
        # Open the archive on demand (named `archive` to avoid shadowing the built-in zip)
        with np.load(self.path) as archive:
            embed = archive[f"embeddings_{self.info['gene'][idx]}"]
            embed = embed.T  # Transpose because Conv1d expects (channels, length)
        label = torch.tensor(self.info['label'][idx])
        return embed, label

In [5]:
# Define the paths to the info and zip file
zip_path = '/content/drive/MyDrive/NetOGlyc/embeddings_npy.zip'
info_path = '/content/drive/MyDrive/NetOGlyc/embeddings_info.txt'

In [6]:
# Read in the info file
info = pd.read_csv(info_path, sep = '\t')
info['sequence'] = info['sequence'].apply(list)
info['label'] = info['label'].apply(lambda x: list(map(int, list(x))))

# Split data into test and training files
trainval_idx, test_idx = train_test_split(range(len(info)), test_size = 0.2)

info_trainval = info.iloc[trainval_idx, :]
info_test = info.iloc[test_idx, :]
print(info_test)

           gene  ...                                              label
90   P05362_pos  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
230  Q14767_pos  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
845  Q9UKU7_neg  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
663  Q04721_neg  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
488  O43278_neg  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
..          ...  ...                                                ...
305  Q86TE4_pos  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
263  Q6E0U4_pos  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
560  P10619_neg  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
200  Q01974_pos  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
251  Q5JRA6_pos  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...

[173 rows x 4 columns]


<b>2. Define model and function for weights</b>

In [7]:
class Net(nn.Module):
    def __init__(self, kernel_size = 7):
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(in_channels = 1024, out_channels = 32, kernel_size = kernel_size, padding = kernel_size//2) 
        self.dropout = nn.Dropout(p=0.25)
        self.conv2 = nn.Conv1d(in_channels = 32, out_channels = 2, kernel_size = kernel_size, padding = kernel_size//2)

    def forward(self, x):
        # ---- Layer 1
        # conv1 needs (batch_size, in_channels/features, length/seq_len), so (64, 1024, 4000) 
        # and outputs (64, 32, 4000)
        x = self.conv1(x)

        # ---- Process first layer's output
        x = self.dropout(x)
        x = F.relu(x)

        # ---- Layer 2
        # conv2 takes (64, 32, 4000) and outputs (64, 2, 4000)
        x = self.conv2(x)
        
        return x
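
The shape comments in `forward` rely on `padding = kernel_size // 2` keeping the sequence length unchanged. A quick arithmetic check, using the output-length formula given in the PyTorch `Conv1d` documentation with stride and dilation of 1 (the helper name `conv1d_out_length` is made up for this sketch):

```python
def conv1d_out_length(length, kernel_size, padding, stride=1, dilation=1):
    """Output length of a 1D convolution, following the Conv1d formula."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

seq_len = 4000
for k in (5, 7, 9, 13):
    # Odd kernels with padding k // 2 preserve the length
    assert conv1d_out_length(seq_len, k, k // 2) == seq_len

# An even kernel with the same padding rule would add one position
even_len = conv1d_out_length(seq_len, 6, 6 // 2)
print(even_len)  # 4001
```

Odd kernel sizes therefore keep exactly one output position per residue, which the per-residue labels require.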

In [8]:
# Loss weights help prioritise correctly predicting glycosites over the far more numerous unglycosylated sites.
# Predicting sites matters more here than predicting non-sites, so beta should be very high to keep the two class weights far apart.

def cb_weights(labels, beta = None):
    """Compute the weights for Class Balanced Loss.
    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where n is the number of samples in a class and Loss is a loss function, here cross entropy loss.
    Args:
      labels: Pandas series of per-residue label lists. Assumes binary 0/1 labeling.
      beta: float hyperparameter for class-balanced loss, or a 2-length iterable of pre-set weights. If beta = None, returns standard (1, 30) weights.
    Returns:
      A pair of weights (non-site, site) to be supplied to the loss function.
    """
    
    if beta is None:
      return (1, 30)
    elif isinstance(beta, float):
      total_samples = labels.apply(len).sum()
      total_sites = labels.apply(sum).sum() # Assumes 0-1 labeling
      samples_dist = [total_samples - total_sites, total_sites]

      effective_num = 1.0 - np.power(beta, samples_dist)
      weights = (1.0 - beta) / np.array(effective_num)
      weights = weights / np.sum(weights) * len(samples_dist)
      return weights
    else:
      # Accept pre-set weights only if they form a 2-length iterable
      try:
        valid = len(beta) == 2
      except TypeError:
        valid = False
      if not valid:
        raise ValueError('Beta must be a float, a 2-length iterable, or None.')
      return beta
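
To see how `beta` shapes the weights, the same formula can be applied to made-up class counts. The counts below (1,000,000 non-sites and 1,000 sites) and the helper name `class_balanced_weights` are hypothetical and not taken from the dataset.

```python
import numpy as np

def class_balanced_weights(counts, beta):
    """Weights proportional to (1 - beta) / (1 - beta^n), rescaled to sum to the number of classes."""
    counts = np.asarray(counts, dtype=float)
    effective_num = 1.0 - np.power(beta, counts)
    weights = (1.0 - beta) / effective_num
    return weights / weights.sum() * len(counts)

counts = [1_000_000, 1_000]  # hypothetical (non-site, site) counts
for beta in (0.9, 0.9999, 0.99999):
    w = class_balanced_weights(counts, beta)
    # Print the weights and how much more a site counts than a non-site
    print(beta, w, w[1] / w[0])
```

With beta far from 1, both classes get nearly equal weight; as beta approaches 1, the ratio between the weights moves toward the inverse class frequency.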

**3. Define training and tuning functions**

In [9]:
def train(config, device, zip_path, info, epochs=10, tuning = False, checkpoint_dir = None):
    """Train the CNN model.
    Args: 
      config: a dictionary with hyperparameter values {'kernel_size', 'beta', 'lr'}. 
        If tuning = True, supports ray.tune search spaces.
      device: a pytorch device indicating whether the model should be loaded onto cpu or gpu.
      zip_path: a path to a zip/npz file containing the .npy arrays for each gene.
      info: a pandas dataframe with a 'gene' column of gene IDs, and 'sequence' and 'label' columns holding lists/iterables.
      epochs: optional. an integer indicating the number of epochs (passes over the training data) to train for.
        Default = 10.
      tuning: optional. a boolean indicating whether the model is run in the context of ray.tune tuning.
        In that case, it won't print training results, but will instead pass them to ray.tune.
        Default = False.
      checkpoint_dir: optional. only used when tuning = True. if supplied, the model and optimizer state are restored from the checkpoint in this directory.
        Default = None.
    Returns:
      Nothing. The model is created inside this function; when tuning = True, its state is saved
      as a ray.tune checkpoint after every epoch.
      """

    # Initialise model
    model = Net(kernel_size = config['kernel_size'])
    model.to(device)

    # Load info and paths into dataset objects and create loaders
    train_idx, val_idx = train_test_split(range(len(info)), test_size = 0.2)
    loader_params = {'batch_size': 64, 'shuffle': True}

    info_train = info.iloc[train_idx, :]
    data_train = EmbeddingDataset(zip_path, info_train)
    loader_train = DataLoader(data_train, **loader_params)

    info_val = info.iloc[val_idx, :]
    data_val = EmbeddingDataset(zip_path, info_val)
    loader_val = DataLoader(data_val, **loader_params)

    # Gather weights
    weights = cb_weights(info_train['label'], beta = config['beta'])

    # Initialise loss function and optimizer 
    loss_fn = nn.CrossEntropyLoss(weight = torch.FloatTensor(weights).to(device))
    optimizer = optim.Adam(model.parameters(), lr = config['lr'], amsgrad = True) 

    if tuning and checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # Start training loop
    if not tuning:
      print("Start training...\n")
      print(f"{'Epoch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Prec. @.5':^10} | {'Recall@.5':^10} | {'AUC':^10} | {'AP':^10}")
      print("-"*87)

    model = model.float()

    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================

        # Track the accumulated loss
        total_loss = 0

        # Put the model into training mode
        model.train()

        for step, batch in enumerate(loader_train):
            # Load batch to GPU
            b_input_ids, b_labels = tuple(t.to(device) for t in batch)

            # Zero out any previously calculated gradients
            model.zero_grad()

            # Perform a forward pass. (output shape: (batch, n_classes, length))
            logits = model(b_input_ids.float())

            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            total_loss += loss.item()

            # Perform a backward pass to calculate gradients
            loss.backward()

            # Update parameters
            optimizer.step()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(loader_train)

        # =======================================
        #               Evaluation
        # =======================================
        # After the completion of each training epoch, measure the model's
        # performance on our validation set.
        val_loss, val_accuracy, val_precision, val_recall, val_auc, val_ap = evaluate(model, loader_val, loss_fn)

        # Report validation performance: to ray.tune when tuning, otherwise as a printed table row
        if tuning:
          with tune.checkpoint_dir(epoch_i) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((model.state_dict(), optimizer.state_dict()), path)
          tune.report(loss = val_loss, precision = val_precision, recall = val_recall, auc = val_auc, ap = val_ap)
        else:
          # Column widths match the header printed before the loop
          print(f"{epoch_i + 1:^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_precision:^10.2f} | {val_recall:^10.2f} | {val_auc:^10.4f} | {val_ap:^10.4f}")
            
        
        # # =======================================
        # #               Checkpoint
        # # =======================================

        # torch.save(model.state_dict(), f"model_{time}_{epoch_i}.pth")

        gc.collect()

    print("Training complete!")

In [28]:
def evaluate(model, val_dataloader, loss_fn = nn.CrossEntropyLoss()):
    """Measure a model's performance on a validation set.
    Args:
      model: a model object to evaluate.
      val_dataloader: a dataloader with validation data.
      loss_fn: a loss function instance to calculate the validation loss with, or None to skip the loss. 
        Usually passed on within train() to be the same loss function as used for training.
        Default = an unweighted nn.CrossEntropyLoss(), but should be overwritten to match training loss_fn.
    Returns:
      val_loss: the mean of the loss across batches (an empty list if loss_fn is None).
      val_accuracy: the percentage of positions predicted correctly, based on cutoff 0.5.
      val_precision: the percentage of positive predictions that are correct, based on cutoff 0.5 (also known as positive predictive value). NaN if no position is predicted positive.
      val_recall: the percentage of actual positives that were predicted by the model, based on cutoff 0.5 (also known as sensitivity or true positive rate).
      val_auc: the area under the ROC curve, indicates our model's capability to distinguish between the two classes.
      val_ap: the average precision, aka the area under the precision-recall curve. indicates our model's capability to identify the positive values correctly.
    """
    # Put the model into the evaluation mode. The dropout layers are disabled
    # during the test time.
    model.eval()

    # Tracking variables
    val_loss = []
    true_labs_all = []
    probs_all = []
    preds_all = []

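    # Use the device the model lives on, rather than relying on a global `device` variable
    device = next(model.parameters()).device

    # For each batch in our validation set...
    for batch in val_dataloader:
        # Load batch onto the model's device
        b_input, b_labels = tuple(t.to(device) for t in batch)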

        # Compute scores (shape: (batch, n_classes, length))
        with torch.no_grad():
            scores = model(b_input.float())

        # Compute loss
        if loss_fn is not None:
          loss = loss_fn(scores, b_labels)
          val_loss.append(loss.item())

        # Get the probabilities and predictions
        true_labs = b_labels.cpu().numpy()
        probs = F.softmax(scores, dim=1).cpu().numpy()[:, 1, :] # keep only probabilities for label 1
        preds = torch.argmax(scores, dim=1).cpu().numpy()

        # Save to compute AUC and average precision (from precision-recall curve) later
        true_labs_all.append(true_labs.flatten())
        probs_all.append(probs.flatten())
        preds_all.append(preds.flatten())

    # Compute the performance statistics over the entire evaluation set
    true_labs_all = np.hstack(true_labs_all)
    probs_all = np.hstack(probs_all)
    preds_all = np.hstack(preds_all)

    if loss_fn is not None:
      val_loss = np.mean(val_loss)
    val_accuracy = (preds_all == true_labs_all).mean() * 100
    val_precision = (preds_all[preds_all == 1] == true_labs_all[preds_all == 1]).mean()*100       # Also known as positive predictive value
    val_recall = (preds_all[true_labs_all == 1] == true_labs_all[true_labs_all == 1]).mean()*100  # Also known as sensitivity or true positive rate
    val_auc = roc_auc_score(y_true = true_labs_all, y_score = probs_all)                          # Area under ROC curve
    val_ap = average_precision_score(y_true = true_labs_all, y_score = probs_all)                 # Area under precision-recall curve

    return val_loss, val_accuracy, val_precision, val_recall, val_auc, val_ap
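
The precision and recall lines in `evaluate` use boolean indexing; the same quantities can be written as counts. The sketch below uses made-up labels and predictions and a hypothetical helper name. It also shows that precision is undefined (NaN) when nothing is predicted positive, which is presumably why some tuning trials report no precision value.

```python
import numpy as np

def precision_recall(true_labels, predictions):
    """Precision and recall in percent for binary 0/1 arrays; NaN when undefined."""
    true_labels = np.asarray(true_labels)
    predictions = np.asarray(predictions)
    true_pos = np.sum((predictions == 1) & (true_labels == 1))
    predicted_pos = np.sum(predictions == 1)
    actual_pos = np.sum(true_labels == 1)
    precision = 100 * true_pos / predicted_pos if predicted_pos > 0 else float("nan")
    recall = 100 * true_pos / actual_pos if actual_pos > 0 else float("nan")
    return precision, recall

y_true = [0, 0, 1, 1, 0, 1]
# 2 of 3 predicted sites are correct, and 2 of 3 real sites are found
print(precision_recall(y_true, [0, 1, 1, 0, 0, 1]))
# No predicted sites at all: precision is NaN, recall is 0
print(precision_recall(y_true, [0, 0, 0, 0, 0, 0]))
```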

In [17]:
def tune_model(config, device, num_samples):
  """A function to tune models to find the best hyperparameters using ray.tune.
  Relies on the global variables zip_path and info_trainval defined earlier in the notebook.
  Args:
    config: a configuration dictionary with tune search space indicators.
    device: a pytorch device indicating whether the model should be loaded onto cpu or gpu.
    num_samples: the number of times to sample from the search space.
  Returns:
    A tune.ExperimentAnalysis object with information about the trials. 
    Its best trial can be used in build_best_model() to reconstitute the model."""

  # Start training/tuning
  scheduler = ASHAScheduler(
      metric = "ap", # alternative: loss, min
      mode = "max",
      max_t = 20,
      grace_period = 3,
      reduction_factor = 2)
  reporter = JupyterNotebookReporter(
      overwrite = True,
      metric_columns = ["loss", "auc", "ap", "precision", "recall", "training_iteration"])
  search_alg = HyperOptSearch(
      metric = "ap",
      mode = "max")

  result = tune.run(
      partial(train, device = device, zip_path = zip_path, info = info_trainval, epochs = 20, tuning = True),
      resources_per_trial = {"gpu": 1},
      config = config,
      num_samples = num_samples,
      search_alg = search_alg,
      scheduler = scheduler,
      progress_reporter = reporter)

  return result


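def build_best_model(best_trial, device):
  """Rebuild the model of a finished trial from its checkpoint.
  Args:
    best_trial: a ray.tune trial, e.g. selected with ExperimentAnalysis.get_best_trial().
    device: a pytorch device indicating whether the model should be loaded onto cpu or gpu.
  Returns:
    A Net model with the trial's checkpointed weights loaded."""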
  best_trained_model = Net(kernel_size = best_trial.config["kernel_size"])
  best_trained_model.to(device)

  best_checkpoint_dir = best_trial.checkpoint.value
  model_state, optimizer_state = torch.load(os.path.join(best_checkpoint_dir, "checkpoint"))
  best_trained_model.load_state_dict(model_state)
  return best_trained_model

<b>3 - option a. Tune the model</b>

This option will train the model `num_samples` times, trying different combinations of hyperparameters each time, and return the best one.  

`num_samples` interacts differently with random parameter selections (like `tune.choice()` or `tune.loguniform()`) and with grid search. One unit of `num_samples` draws only one sample from all random parameters, but runs one full grid search over all grid parameters (i.e. `num_samples` = 1 with `tune.grid_search(['A', 'B', 'C'])` makes three trials!). For an explanation of how exactly `num_samples` works, see [here](https://docs.ray.io/en/latest/tune/api_docs/search_space.html).  
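
The counting rule can be illustrated without Ray. The sketch below only imitates how the number of trials grows; it is not Ray's implementation, and the search space and the helper `count_trials` are made up.

```python
from itertools import product

def count_trials(grid_params, num_samples):
    """Every grid combination is run once per unit of num_samples."""
    grid_combinations = list(product(*grid_params.values()))
    return num_samples * len(grid_combinations)

# No grid parameters, only random ones: one trial per sample
print(count_trials({}, num_samples=10))  # 10

# One grid parameter with three values, as in the example above
print(count_trials({"letter": ["A", "B", "C"]}, num_samples=1))  # 3

# Two grid parameters: the full 3 x 2 grid is repeated for every sample
print(count_trials({"letter": ["A", "B", "C"], "kernel_size": [5, 7]}, num_samples=2))  # 12
```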


For a non-tuning option, see 3b.

In [13]:
# cb_weights(info['label'], 0.99999)

array([0.03512864, 1.96487136])

In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Select hyperparameters
config = {
    "kernel_size": tune.choice([5, 7, 9, 13]),
    # "kernel_size": tune.qrandint(5, 13, 2),
    # "beta": tune.choice([0.9999, 0.99999, (1, 30)]),
    "beta": tune.loguniform(0.9999, 0.99999),
    "lr": 0.001}

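# Tune model
# Set num_samples to the desired number of runs (like 10) for random choices, or to 1 for grid search (runs the entire grid once)
result = tune_model(config, device, num_samples = 10)

# Select the best trial from the returned ExperimentAnalysis (lowest validation loss over its last 5 reports)
best_trial = result.get_best_trial("loss", "min", "last-5-avg")

# Reconstitute best model
model = build_best_model(best_trial, device)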

# Show tensorboard output
# !tensorboard --logdir=~/ray_results/my_experiment

| Trial name | status | loc | beta | kernel_size | lr | loss | auc | ap | precision | recall | training_iteration |
|---|---|---|---|---|---|---|---|---|---|---|---|
| DEFAULT_3f9cc9ee | TERMINATED | | 0.999937 | 5 | 0.001 | 0.0339966 | 0.926077 | 0.133444 | | 0.0 | 20 |
| DEFAULT_40536f1e | TERMINATED | | 0.999943 | 5 | 0.001 | 0.0655836 | 0.903175 | 0.0705934 | | 0.0 | 12 |
| DEFAULT_1e759c2e | TERMINATED | | 0.999955 | 5 | 0.001 | 0.0947441 | 0.921972 | 0.0669419 | | 0.0 | 12 |
| DEFAULT_1cfe895a | TERMINATED | | 0.999968 | 13 | 0.001 | 0.023049 | 0.970192 | 0.204706 | 18.1924 | 55.0877 | 20 |
| DEFAULT_6542918a | TERMINATED | | 0.999959 | 9 | 0.001 | 0.0540958 | 0.927753 | 0.0818549 | | 0.0 | 12 |
| DEFAULT_0b863e38 | TERMINATED | | 0.999988 | 13 | 0.001 | 0.0825802 | 0.94304 | 0.191116 | 14.5583 | 64.375 | 20 |
| DEFAULT_a407e676 | TERMINATED | | 0.999934 | 13 | 0.001 | 0.492432 | 0.159278 | 0.000369411 | | 0.0 | 3 |
| DEFAULT_51bea394 | TERMINATED | | 0.99991 | 5 | 0.001 | 0.528582 | 0.150778 | 0.000337385 | | 0.0 | 3 |
| DEFAULT_5bec40f0 | TERMINATED | | 0.999954 | 13 | 0.001 | 0.0194309 | 0.942688 | 0.19278 | 16.9394 | 57.8947 | 20 |
| DEFAULT_63d39f38 | TERMINATED | | 0.999987 | 13 | 0.001 | 0.238874 | 0.826304 | 0.0543146 | 5.71319 | 64.2241 | 6 |


2021-09-11 21:05:22,064	INFO tune.py:561 -- Total run time: 20070.12 seconds (20069.31 seconds for the tuning loop).


Best trial config: {'kernel_size': 13, 'beta': 0.9999538920913715, 'lr': 0.001}
Best trial final validation loss: 0.019430930105348427
Best trial final validation average precision: 0.19278001602192688
Best trial final validation auc: 0.9426880381350121
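
The `training_iteration` values of 3, 6, 12 and 20 in the trial table above line up with the scheduler settings in `tune_model`. As a rough sketch, assuming ASHA places its early-stopping points at `grace_period * reduction_factor**k` below `max_t` (the helper name `asha_rungs` is made up):

```python
def asha_rungs(grace_period, reduction_factor, max_t):
    """Iterations at which under-performing trials may be stopped early."""
    rungs = []
    t = grace_period
    while t < max_t:
        rungs.append(t)
        t *= reduction_factor
    return rungs

rungs = asha_rungs(grace_period=3, reduction_factor=2, max_t=20)
print(rungs)  # [3, 6, 12]
```

Trials that survive every rung run for the full 20 epochs.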



**3 - option b: Train without tuning**

Run a simple 20-epoch training sequence with set parameters. 

In [15]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# config = {
#     "kernel_size": 7,
#     "beta": (1, 30), # Pass a 2-length list as beta to pre-set weights, or a float between 0.99 and 0.99999 to get class-balanced weights
#     "lr": 0.001}
# train(config, device, zip_path, info_trainval, epochs = 20, tuning = False)

KeyboardInterrupt: ignored

**4. Check performance on the test set**

In [27]:
# # Show best trial
# best_trial = result.get_best_trial("loss", "min", "last-5-avg")
print(f"Best trial config: {best_trial.config}")
print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
print(f"Best trial final validation average precision: {best_trial.last_result['ap']}")
print(f"Best trial final validation auc: {best_trial.last_result['auc']}")

Best trial config: {'kernel_size': 13, 'beta': 0.9999538920913715, 'lr': 0.001}
Best trial final validation loss: 0.019430930105348427
Best trial final validation average precision: 0.19278001602192688
Best trial final validation auc: 0.9426880381350121


In [32]:
# Build the test set
data_test = EmbeddingDataset(zip_path, info_test)
loader_test = DataLoader(data_test, batch_size = 16, shuffle = False) # Shuffling is not needed for evaluation

# Run the best model on the test set
test_loss, test_accuracy, test_precision, test_recall, test_auc, test_ap = evaluate(model, loader_test, loss_fn=None)
print(f"Test AUC: {test_auc:.4f} | Test AP:  {test_ap:.4f} | Test precision (cutoff=0.5): {test_precision:.2f} | Test recall (cutoff=0.5): {test_recall:.2f}")

Test AUC: 0.9754 | Test AP:  0.2151 | Test precision (cutoff=0.5): 17.53 | Test recall (cutoff=0.5): 57.56


<b>5. Save and export the model</b>

In [21]:
torch.save(model.state_dict(), "model_params.pth")
# torch.save(model, "model_full.pth")

In [22]:
!cp model_params.pth /content/drive/MyDrive/NetOGlyc/

In [23]:
drive.flush_and_unmount()