<a href="https://colab.research.google.com/github/casblaauw/BertOGlyc/blob/main/BertOGlyc_train_per_protein_NN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Initial model architecture based on [Elnaggar et al. (2020)](https://www.biorxiv.org/content/10.1101/2020.07.12.199554v3.full) and [Heinzinger et al. (2019)](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-019-3220-8). Other parts of the notebook draw on:

- Data loader structure: [this PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)
- Model architecture: [this PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
- Training loop: [this CNN tutorial](https://chriskhanhtran.github.io/posts/cnn-sentence-classification/)
- Loss function weights: [this tutorial](https://towardsdatascience.com/handling-class-imbalanced-data-using-a-loss-specifically-made-for-it-6e58fd65ffab), based on [this paper](https://arxiv.org/abs/1901.05555)
- Hyperparameter tuning: [this PyTorch tutorial](https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html) and the general [Ray docs](https://docs.ray.io/en/latest/tune/key-concepts.html)

<b>0. Import libraries</b>

In [1]:
!pip install ray
!pip install -U hyperopt

Collecting ray
  Downloading ray-1.9.0-cp37-cp37m-manylinux2014_x86_64.whl (57.6 MB)
Collecting redis>=3.5.0
  Downloading redis-4.0.2-py3-none-any.whl (119 kB)
Collecting deprecated
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Installing collected packages: deprecated, redis, ray
Successfully installed deprecated-1.2.13 ray-1.9.0 redis-4.0.2
Collecting hyperopt
  Downloading hyperopt-0.2.7-py2.py3-none-any.whl (1.6 MB)
Collecting py4j
  Downloading py4j-0.10.9.3-py2.py3-none-any.whl (198 kB)
Installing collected packages: py4j, hyperopt
  Attempting uninstall: hyperopt
    Found existing installation: hyperopt 0.1.2
    Uninstalling hyperopt-0.1.2:
      Successfully uninstalled hyperopt-0.1.2
Successfully installed hyperopt-0.2

In [2]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from ray import tune
from ray.tune import JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch
from functools import partial

import os
from google.colab import files, drive
import gc

In [3]:
drive.mount('/content/drive')

Mounted at /content/drive


<b>1. Read in and pre-split data </b>

The info data frame determines the contents of the dataset. Indexing the dataset (as the data loader does) looks up the row at that position in the supplied info data frame, takes its UniProt ID, and loads that protein's embeddings from the zip (.npz) file.
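
A minimal, torch-free sketch of what a single dataset lookup does, assuming (as `EmbeddingDataset` below does) that the archive stores one `(sequence length, 1024)` array per protein under its UniProt ID. The IDs `PROT_A`/`PROT_B`, the sequence lengths and the random values are hypothetical placeholders, and the archive is written to a temporary directory instead of Google Drive.

```python
import os
import tempfile

import numpy as np

# Hypothetical toy archive: two made-up IDs with per-residue embeddings of different lengths
rng = np.random.default_rng(0)
toy_embeddings = {
    "PROT_A": rng.normal(size=(50, 1024)).astype(np.float32),
    "PROT_B": rng.normal(size=(120, 1024)).astype(np.float32),
}
tmp_dir = tempfile.mkdtemp()
archive_path = os.path.join(tmp_dir, "toy_embeddings.npz")
np.savez(archive_path, **toy_embeddings)  # Each keyword becomes one .npy member, keyed by its name

# Hypothetical stand-in for the info data frame: position -> (uniprot ID, label)
toy_info = [("PROT_B", 1), ("PROT_A", 0)]

def lookup(idx):
    """Mimic EmbeddingDataset.__getitem__: find the ID at position idx, then load and mean-pool its array."""
    protein_id, label = toy_info[idx]
    with np.load(archive_path) as archive:  # NpzFile can be used as a context manager
        per_residue = archive[protein_id]   # Shape (sequence length, 1024), read fully into memory
    return per_residue.mean(axis=0), label  # Shape (1024,), independent of sequence length

embed, label = lookup(0)
print(embed.shape, label)  # (1024,) 1
```

Mean-pooling over the residue axis is what lets proteins of different lengths share one fixed-size input for the linear model.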

Because the info data frame defines the dataset, the data can be split by simply splitting the info data frame and building an `EmbeddingDataset`/`DataLoader` from each part. The datasets themselves are built inside the training function (which draws a fresh random train/validation split for every training run) and at post-training testing; only the info data frames are prepared here.
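
The nested split (a test set held out once, then a new train/validation split on every training run) can be sketched with plain NumPy index arrays. `split_indices` is a hypothetical stand-in for `sklearn.model_selection.train_test_split`, which the notebook actually uses; the 100-row size and the seed are arbitrary choices for illustration.

```python
import numpy as np

def split_indices(n, test_fraction, rng):
    """Randomly split positions 0..n-1 into two disjoint index arrays (sketch of train_test_split)."""
    order = rng.permutation(n)
    n_test = int(np.ceil(test_fraction * n))  # train_test_split also rounds a fractional test size up
    return order[n_test:], order[:n_test]

rng = np.random.default_rng(42)
n_proteins = 100  # Hypothetical number of rows in the info data frame

# Section 1: hold out a test set once
trainval_idx, test_idx = split_indices(n_proteins, 0.2, rng)

# Inside train(): a fresh train/validation split of the remaining rows on every call
train_pos, val_pos = split_indices(len(trainval_idx), 0.2, rng)
train_idx, val_idx = trainval_idx[train_pos], trainval_idx[val_pos]

print(len(train_idx), len(val_idx), len(test_idx))  # 64 16 20
```

Because only row positions are split, the embeddings archive is never copied or rewritten; each subset just points at different rows of the info data frame.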

In [26]:
# Define the Dataset class for use with DataLoader, reading in files as needed
class EmbeddingDataset(Dataset):
    """Dataset of mean-pooled per-protein embeddings from ProtBert.
    Path is expected to be a path to a zip/npz file containing one .npy array of per-residue embeddings
    per protein, keyed by UniProt ID.
    Info is expected to be a pandas dataframe containing a 'uniprot' column matching the zip keys and a 'label' column with numeric labels."""
    def __init__(self, path, info):
        self.path = path
        self.info = info.reset_index(drop=True)  # Positions 0..n-1, regardless of how info was split

    def __len__(self):
        return len(self.info)

    def __getitem__(self, idx):
        # The archive is reopened for every item: simple, but can be slow for large datasets
        with np.load(self.path) as archive:
            embed = archive[self.info['uniprot'][idx]]  # Per-residue embeddings, shape (sequence length, 1024)
            embed = embed.mean(axis=0)  # Mean-pool into a single 1024-length embedding for the protein
        # Integer label as a tensor (DataLoader's default collate function would also convert it)
        label = torch.tensor(self.info['label'][idx])
        return embed, label

In [5]:
# Define the paths to the info and zip file
zip_path = '/content/drive/MyDrive/data/yeastdata_sequence_embeddings.npz'
info_path = '/content/drive/MyDrive/data/yeastdata_03122021.csv'

In [25]:
# Sanity check: list the archive keys, check one pooled embedding's shape, and inspect the info file
with np.load(zip_path) as archive:
  print(archive.files)
  print(archive[archive.files[2]].mean(axis=0, keepdims=True).shape)
info = pd.read_csv(info_path, index_col = 0)
print(info.head())
print(info['localisation'].value_counts())

['P38631', 'P06776', 'P06367', 'P07280', 'P35997', 'P0C0X0', 'P26786', 'P32904', 'O14455', 'P53319', 'A6ZRZ6', 'A6ZKU1', 'P15891', 'P53946', 'P21147', 'P47143', 'P40529', 'P35197', 'P00330', 'P43594', 'P48015', 'P24813', 'P38749', 'B3LI56', 'P21306', 'A6ZWD3', 'P32486', 'P33336', 'P23292', 'P32457', 'Q3E823', 'P04037', 'P00431', 'P28272', 'P33894', 'P40087', 'P40564', 'P20048', 'Q03063', 'P39940', 'A6ZM93', 'P02994', 'P25358', 'P00924', 'P39935', 'Q12522', 'P19097', 'P32599', 'B3LRA0', 'P14540', 'P32614', 'B3LHR7', 'P14922', 'P00360', 'P00358', 'P39726', 'P06738', 'Q06681', 'P39960', 'Q00246', 'A6ZSH6', 'P23585', 'P50276', 'P02293', 'Q12276', 'P53901', 'Q03281', 'Q03707', 'P25297', 'P27895', 'P42839', 'Q04958', 'P22133', 'P53599', 'P08018', 'P43638', 'P40356', 'Q03104', 'P21192', 'Q12019', 'P32839', 'P53220', 'Q12328', 'P23644', 'A6ZTE8', 'P39952', 'Q03667', 'B3LHE1', 'O14467', 'P42939', 'P47068', 'A6ZMG6', 'P23797', 'A6ZZH2', 'P22211', 'Q06616', 'P14907', 'A6ZTA1', 'P39928', 'P21375',



In [27]:
# Read in the info file
info = pd.read_csv(info_path, index_col = 0)
info = info.loc[:, ('uniprot', 'localisation')].drop_duplicates().dropna().reset_index(drop=True) # The file has one row per site; dropping duplicate (uniprot, localisation) pairs leaves one row per protein
info['label'] = info['localisation'].map({'cytosolic': 0, 'extracellular': 1})

# Split data into test and training files
trainval_idx, test_idx = train_test_split(range(len(info)), test_size = 0.2)

info_trainval = info.iloc[trainval_idx, :]
info_test = info.iloc[test_idx, :]
print(info_test)

    uniprot   localisation  label
303  P38797  extracellular      1
469  Q04947  extracellular      1
734  O94399      cytosolic      0
83   P23644      cytosolic      0
358  P32319  extracellular      1
..      ...            ...    ...
593  P53304  extracellular      1
350  Q12025  extracellular      1
779  Q9UR09  extracellular      1
576  P43497  extracellular      1
663  P40092  extracellular      1

[162 rows x 3 columns]




<b>2. Define model and function for weights</b>

In [16]:
# Earlier per-residue convolutional variant (unused; kept for reference)
# class Net(nn.Module):
#     def __init__(self, kernel_size = 7):
#         super(Net, self).__init__()
#         self.conv1 = nn.Conv1d(in_channels = 1024, out_channels = 32, kernel_size = kernel_size, padding = kernel_size//2) 
#         self.dropout = nn.Dropout(p=0.25)
#         self.conv2 = nn.Conv1d(in_channels = 32, out_channels = 2, kernel_size = kernel_size, padding = kernel_size//2)

#     def forward(self, x):
#         # ---- Layer 1
#         # conv1 needs (batch_size, in_channels/features, length/seq_len), so (64, 1024, 4000) 
#         # and outputs (64, 32, 4000)
#         x = self.conv1(x)

#         # ---- Process first layer's output
#         x = self.dropout(x)
#         x = F.relu(x)

#         # ---- Layer 2
#         # conv2 takes (64, 32, 4000) and outputs (64, 2, 4000)
#         x = self.conv2(x)
        
#         return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lin1 = nn.Linear(in_features = 1024, out_features = 128) 
        self.lin2 = nn.Linear(in_features = 128, out_features = 10)
        self.lin3 = nn.Linear(in_features = 10, out_features = 2)
        self.dropout = nn.Dropout(p=0.25)

    def forward(self, x):
        # ---- Layer 1
        # lin1 needs (batch_size, in_features), so (64, 1024) 
        # and outputs (64, 128)
        x = self.lin1(x)

        # ---- Process first layer's output
        x = self.dropout(x)
        x = F.relu(x)

        # ---- Layer 2
        # lin2 takes (64, 128) and outputs (64, 10)
        x = self.lin2(x)

        # ---- Process second layer's output
        x = self.dropout(x)
        x = F.relu(x)

        # ---- Layer 3
        # lin3 takes (64, 10) and outputs (64, 2)
        x = self.lin3(x)
        
        return x
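
To make the shape comments above concrete, here is a NumPy re-implementation of the same three linear layers in evaluation mode (dropout is an identity at evaluation time, so it is omitted). The random weights and the batch of 64 are placeholders: this only checks the shapes and the parameter count, not the behaviour of a trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
layer_sizes = [(1024, 128), (128, 10), (10, 2)]  # (in_features, out_features) of lin1, lin2, lin3

# nn.Linear computes x @ W.T + b, with W of shape (out_features, in_features)
weights = [rng.normal(scale=0.02, size=(out_f, in_f)) for in_f, out_f in layer_sizes]
biases = [np.zeros(out_f) for _, out_f in layer_sizes]

def forward(x):
    """Eval-mode forward pass: Linear -> ReLU -> Linear -> ReLU -> Linear."""
    shapes = [x.shape]
    for i, (W, b) in enumerate(zip(weights, biases)):
        x = x @ W.T + b
        if i < len(weights) - 1:  # No ReLU after the last layer: raw logits go to CrossEntropyLoss
            x = np.maximum(x, 0.0)
        shapes.append(x.shape)
    return x, shapes

logits, shapes = forward(rng.normal(size=(64, 1024)))
print(shapes)  # [(64, 1024), (64, 128), (64, 10), (64, 2)]

n_params = sum(W.size + b.size for W, b in zip(weights, biases))
print(n_params)  # 132512
```

Almost all of the 132,512 parameters sit in the first layer (1024 × 128 weights plus 128 biases), which is where the 1024-length ProtBert embedding is compressed.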

In [24]:
# Loss weights rebalance the two classes so that the loss is not dominated by the more common class

def weighted_weights(labels, balance_factor = 1):
  """Compute the weights for the loss function, weighted by class proportion in the data.
  Args:
    labels: Pandas series of labels. Assumes binary 0/1 labeling.
    balance_factor: float. Hyperparameter for loss balancing. Default is 1.
      With 1, both classes contribute the same total weight to the loss.
      Values above 1 make 1-labels count more than 0-labels overall, and values below 1 make them count less.
      If set to None, will not balance weights and return equal [1, 1] weights.
  Returns:
    A list [weight_0, weight_1] to be passed (as a tensor) to the loss function's weight argument.
  """
  if balance_factor is None:
    return [1, 1]

  total_positives = labels.sum()                   # Number of 1-labels (assumes 0/1 labeling)
  total_negatives = len(labels) - total_positives  # Number of 0-labels
  weights = [1, balance_factor * (total_negatives / total_positives)]
  return weights
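
A worked example of the weighting, assuming a hypothetical training set of 100 cytosolic (0) and 160 extracellular (1) proteins. With `balance_factor = 1`, label 1 gets weight 100/160 = 0.625, so both classes carry the same total weight (100 × 1 = 160 × 0.625). The second function reimplements weighted cross-entropy in NumPy following PyTorch's documented `reduction='mean'` behaviour (dividing by the summed weights of the targets); `class_weights` and `weighted_cross_entropy` are illustrative helpers, not replacements for `weighted_weights()` or `nn.CrossEntropyLoss`.

```python
import numpy as np

def class_weights(labels, balance_factor=1.0):
    """NumPy version of weighted_weights(): [1, balance_factor * n_negatives / n_positives]."""
    if balance_factor is None:
        return np.array([1.0, 1.0])
    labels = np.asarray(labels)
    n_pos = labels.sum()
    n_neg = len(labels) - n_pos
    return np.array([1.0, balance_factor * n_neg / n_pos])

labels = np.array([0] * 100 + [1] * 160)  # Hypothetical class counts
w = class_weights(labels)
print(w.tolist())  # [1.0, 0.625]

def weighted_cross_entropy(logits, targets, weight):
    """Mean-reduced weighted cross-entropy: sum_i w[y_i] * nll_i / sum_i w[y_i]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # Subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]     # Per-sample negative log-likelihood
    w_i = weight[targets]
    return (w_i * nll).sum() / w_i.sum()

logits = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
targets = np.array([0, 1, 1])
print(weighted_cross_entropy(logits, targets, w))                      # Class-weighted loss
print(weighted_cross_entropy(logits, targets, np.array([1.0, 1.0])))   # Plain mean loss
```

Note that the weighted mean is normalised by the summed weights rather than the batch size, so down-weighting a class changes each sample's share of the loss without shrinking the loss scale as a whole.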

**3. Define training and tuning functions**

In [31]:
def train(config, device, zip_path, info, epochs=10, tuning = False, checkpoint_dir = None):
    """Train the NN model.
    Args: 
      config: a dictionary with hyperparameter values {'loss_balance_factor', 'lr'}. 
        If tuning = True, supports ray.tune search spaces.
      device: a pytorch device indicating whether the model should be loaded into cpu or gpu.
      zip_path: a path to an zip/npz file containing the .npy arrays for each gene.
      info: a pandas dataframe with gene names as keys, 'sequence' and 'label' keys as lists/iterables.
      epochs: optional. an integer value, indicating the number of epochs (training loops) the training should last.
      tuning: optional. a boolean indicating whether the model is ran in the context of ray.tune tuning.
        In that case, it won't print training results, but will instead pass them to ray.tune.
        Default = False.
      checkpoint_dir: optional. only used when tuning = true. used to retrieve the best model's model_state after tuning.
        Default = None.
    Returns:
      Doesn't return anything, but has modified the weights of the supplied model object. 
      """

    # Initialise model
    model = Net()
    model.to(device)

    # Load info and paths into dataset objects and create loaders
    train_idx, val_idx = train_test_split(range(len(info)), test_size = 0.2)
    loader_params = {'batch_size': 64, 'shuffle': True}

    info_train = info.iloc[train_idx, :]
    data_train = EmbeddingDataset(zip_path, info_train)
    loader_train = DataLoader(data_train, **loader_params)

    info_val = info.iloc[val_idx, :]
    data_val = EmbeddingDataset(zip_path, info_val)
    loader_val = DataLoader(data_val, **loader_params)

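    # Gather weights
    # Class weighting is currently disabled (the weighted_weights() call below is commented out),
    # so config['loss_balance_factor'] has no effect and both classes get weight 1.
    # weights = weighted_weights(info_train['label'], balance_factor = config['loss_balance_factor'])
    weights = [1,1]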

    # Initialise loss function and optimizer 
    loss_fn = nn.CrossEntropyLoss(weight = torch.FloatTensor(weights).to(device))
    optimizer = optim.Adam(model.parameters(), lr = config['lr'], amsgrad = True) 

    if tuning and checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # Start training loop
    if not tuning:
      print("Start training...\n")
      print(f"{'Epoch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Prec. @.5':^10} | {'Recall@.5':^10} | {'AUC':^10} | {'AP':^10}")
      print("-"*87)

    model = model.float()

    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================

        # Track loss
        total_loss = 0

        # Put the model into the training mode
        model.train()

        for step, batch in enumerate(loader_train):
            # Move batch to the selected device (CPU or GPU)
            b_embeds, b_labels = tuple(t.to(device) for t in batch)

            # Zero out any previously calculated gradients
            model.zero_grad()

            # Perform a forward pass (output shape: (batch, n_classes))
            logits = model(b_embeds.float())

            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            total_loss += loss.item()

            # Perform a backward pass to calculate gradients
            loss.backward()

            # Update parameters
            optimizer.step()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(loader_train)

        # =======================================
        #               Evaluation
        # =======================================
        # After the completion of each training epoch, measure the model's
        # performance on our validation set.
        val_loss, val_accuracy, val_precision, val_recall, val_auc, val_ap = evaluate(model, loader_val, loss_fn)

        # Report validation performance (to ray.tune when tuning, otherwise print it)
        if tuning:
          with tune.checkpoint_dir(epoch_i) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((model.state_dict(), optimizer.state_dict()), path)
          tune.report(loss = val_loss, precision = val_precision, recall = val_recall, auc = val_auc, ap = val_ap)
        else:
          print(f"{epoch_i + 1:^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_precision:^10.2f} | {val_recall:^10.2f} | {val_auc:^10.4f} | {val_ap:^10.4f}")
            
        
        # # =======================================
        # #               Checkpoint
        # # =======================================

        # torch.save(model.state_dict(), f"model_{time}_{epoch_i}.pth")

        gc.collect()

    print(f"Training complete!")
    if not tuning:
      return model

In [29]:
def evaluate(model, val_dataloader, loss_fn = nn.CrossEntropyLoss()):
    """Measure a model's performance on a validation set.
    Args:
      model: a model object to evaluate.
      val_dataloader: a dataloader with validation data.
      loss_fn: a loss function instance to calculate the validation loss with, or None to skip computing the loss.
        Usually passed on within train() to be the same loss function as used for training.
        Default = nn.CrossEntropyLoss(), but should be overwritten to match training loss_fn.
    Returns:
      val_loss: the mean of the loss across batches (None if loss_fn is None).
      val_accuracy: the percentage of correct predictions (cutoff 0.5) over the whole set.
      val_precision: the percentage of positive predictions (cutoff 0.5) that are correct (also known as positive predictive value)
      val_recall: the percentage of actual positives that were predicted positive (cutoff 0.5) (also known as sensitivity or true positive rate)
      val_auc: the area under the ROC curve; indicates the model's ability to distinguish between the two classes.
      val_ap: the average precision, a summary of the precision-recall curve; indicates how well the model ranks positives above negatives.
    """
    # Put the model into the evaluation mode. The dropout layers are disabled
    # during the test time.
    model.eval()

    # Tracking variables
    val_loss = []
    true_labs_all = []
    probs_all = []
    preds_all = []

    # Evaluate on whichever device the model's parameters live on
    device = next(model.parameters()).device

    # For each batch in our validation set...
    for batch in val_dataloader:
        # Move batch to the model's device
        b_input, b_labels = tuple(t.to(device) for t in batch)

        # Compute scores (shape: (batch, n_classes))
        with torch.no_grad():
            scores = model(b_input.float())

        # Compute loss
        if loss_fn is not None:
          loss = loss_fn(scores, b_labels)
          val_loss.append(loss.item())

        # Get the probabilities and predictions
        true_labs = b_labels.cpu().numpy()
        probs = F.softmax(scores, dim=1).cpu().numpy()[:, 1] # keep only probabilities for label 1
        preds = torch.argmax(scores, dim=1).cpu().numpy()

        # Save to compute AUC and average precision (from precision-recall curve) later
        true_labs_all.append(true_labs.flatten())
        probs_all.append(probs.flatten())
        preds_all.append(preds.flatten())

    # Compute the performance statistics over the entire test set
    true_labs_all = np.hstack(true_labs_all)
    probs_all = np.hstack(probs_all)
    preds_all = np.hstack(preds_all)

    val_loss = np.mean(val_loss) if loss_fn is not None else None
    val_accuracy = (preds_all == true_labs_all).mean() * 100
    val_precision = (preds_all[preds_all == 1] == true_labs_all[preds_all == 1]).mean()*100       # Also known as positive predictive value
    val_recall = (preds_all[true_labs_all == 1] == true_labs_all[true_labs_all == 1]).mean()*100  # Also known as sensitivity or true positive rate
    val_auc = roc_auc_score(y_true = true_labs_all, y_score = probs_all)                          # Area under ROC curve
    val_ap = average_precision_score(y_true = true_labs_all, y_score = probs_all)                 # Average precision (summarises the precision-recall curve)

    return val_loss, val_accuracy, val_precision, val_recall, val_auc, val_ap
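
A small NumPy check of how `evaluate()` turns scores into metrics, using hypothetical logits for six proteins. For two classes, `argmax` over the logits picks class 1 exactly when its softmax probability exceeds 0.5 (ties go to class 0), which is why precision and recall are described as using a 0.5 cutoff. The precision and recall lines reuse the same boolean-mask indexing as the function above.

```python
import numpy as np

def softmax(x):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical logits (columns: class 0, class 1) and true labels for six proteins
scores = np.array([[0.0, 2.0], [1.0, -1.0], [1.0, -1.0], [1.0, -1.0], [0.0, 2.0], [1.0, -1.0]])
true_labs_all = np.array([1, 1, 1, 1, 0, 0])

probs_all = softmax(scores)[:, 1]      # Probability of label 1, as in evaluate()
preds_all = np.argmax(scores, axis=1)  # Same as probs_all > 0.5 when there are no ties

accuracy = (preds_all == true_labs_all).mean() * 100
precision = (preds_all[preds_all == 1] == true_labs_all[preds_all == 1]).mean() * 100
recall = (preds_all[true_labs_all == 1] == true_labs_all[true_labs_all == 1]).mean() * 100
print(preds_all.tolist())                     # [1, 0, 0, 0, 1, 0]
print(round(accuracy, 2), precision, recall)  # 33.33 50.0 25.0
```

If the model predicts no positives at all, `preds_all[preds_all == 1]` is empty and the precision becomes NaN (with a NumPy warning); AUC and AP are unaffected because they use `probs_all` rather than the thresholded predictions.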

In [14]:
def tune_model(config, device, num_samples):
  """A function to tune models to find the best hyperparameters using ray.tune.
  Trains on the global zip_path and info_trainval defined in section 1.
  Args:
    config: a configuration dictionary with tune search space indicators.
    device: a pytorch device indicating whether the model should be loaded into cpu or gpu.
    num_samples: the number of times to sample from the search space (see section 3a).
  Returns:
    The best ray.tune Trial (highest validation average precision at its last reported epoch).
    Its config, last_result and checkpoint can be inspected, and it can be passed to build_best_model() to reconstitute the model."""

  # Start training/tuning
  scheduler = ASHAScheduler(
      metric = "ap", # alternative: loss, min
      mode = "max",
      max_t = 20,
      grace_period = 3,
      reduction_factor = 2)
  reporter = JupyterNotebookReporter(
      overwrite = True,
      metric_columns = ["loss", "auc", "ap", "precision", "recall", "training_iteration"])
  search_alg = HyperOptSearch(
      metric = "ap",
      mode = "max")

  result = tune.run(
      partial(train, device = device, zip_path = zip_path, info = info_trainval, epochs = 20, tuning = True),
      resources_per_trial = {"gpu": 1},
      config = config,
      num_samples = num_samples,
      search_alg = search_alg,
      scheduler = scheduler,
      progress_reporter = reporter)

  # tune.run returns an ExperimentAnalysis; select the trial with the best final validation AP
  return result.get_best_trial("ap", "max", "last")


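def build_best_model(best_trial, device):
  """Rebuild a Net from the checkpoint saved by a ray.tune trial.
  Args:
    best_trial: a ray.tune Trial with a saved checkpoint.
    device: a pytorch device to load the model onto.
  Returns:
    A Net with the trial's saved model weights (the saved optimizer state is not used)."""
  best_trained_model = Net()
  best_trained_model.to(device)

  best_checkpoint_dir = best_trial.checkpoint.value
  model_state, _ = torch.load(os.path.join(best_checkpoint_dir, "checkpoint"), map_location = device)
  best_trained_model.load_state_dict(model_state)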
  return best_trained_model
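
The scheduler settings in `tune_model()` (`max_t = 20`, `grace_period = 3`, `reduction_factor = 2`) decide at which epochs ASHA may stop a trial early. The sketch below is a simplified illustration, assuming geometric milestones at `grace_period * reduction_factor**k` up to `max_t`, and a stopping rule where a trial reaching a milestone continues only if its `ap` is at least the `(1 - 1/reduction_factor)` quantile of the values other trials already recorded there. `asha_milestones` and `keeps_running` are hypothetical helpers, the AP values are made up, and Ray's implementation details may differ.

```python
import numpy as np

def asha_milestones(max_t, grace_period, reduction_factor):
    """Epochs at which a trial is compared against other trials (geometric spacing)."""
    milestones, t = [], grace_period
    while t <= max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones

def keeps_running(value, recorded_values, reduction_factor):
    """Simplified rung check for mode='max': continue only if value reaches the top 1/reduction_factor."""
    if not recorded_values:  # The first trial to reach a milestone has nothing to be compared against
        return True
    cutoff = np.percentile(recorded_values, (1 - 1 / reduction_factor) * 100)
    return value >= cutoff

print(asha_milestones(max_t=20, grace_period=3, reduction_factor=2))  # [3, 6, 12]

# Hypothetical validation AP values already recorded by other trials at epoch 3
recorded_at_epoch_3 = [0.70, 0.74, 0.78, 0.80]
print(keeps_running(0.79, recorded_at_epoch_3, 2), keeps_running(0.72, recorded_at_epoch_3, 2))  # True False
```

With `reduction_factor = 2`, roughly half of the trials reaching each milestone are stopped, so only the most promising configurations run the full 20 epochs.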

<b>3 - option a. Tune the model</b>

This option will train the model `num_samples` times, trying different combinations of hyperparameters each time, and return the best one.  

`num_samples` interacts differently with random parameter sampling (like `tune.choice()` or `tune.loguniform()`) and with grid search. One unit of `num_samples` draws only one sample from all random parameters, but runs one full grid search over all grid parameters (i.e. `num_samples` = 1 with `tune.grid_search(['A', 'B', 'C'])` makes three trials!). For an explanation of how exactly `num_samples` works, see [here](https://docs.ray.io/en/latest/tune/api_docs/search_space.html).
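
The counting rule above as plain arithmetic: the number of trials is `num_samples` times the size of the full grid (the product of the lengths of all `tune.grid_search` lists, or 1 if there are none). `count_trials` is a hypothetical helper that works on a list of grid sizes and does not call Ray.

```python
from math import prod

def count_trials(num_samples, grid_sizes):
    """Trials launched: each of the num_samples repetitions runs the full grid once.
    grid_sizes: lengths of the tune.grid_search() lists in the config (empty if there are none).
    Random parameters (tune.choice, tune.loguniform, ...) add no trials: each trial draws one value."""
    return num_samples * prod(grid_sizes)  # prod([]) == 1, so no grid means num_samples trials

print(count_trials(1, [3]))     # tune.grid_search(['A', 'B', 'C']) with num_samples = 1 -> 3
print(count_trials(10, []))     # Only random choices, as in the config below -> 10
print(count_trials(2, [3, 2]))  # Two grids (3 x 2 = 6 combinations), repeated twice -> 12
```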


For a non-tuning option, see 3b.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Select hyperparameters
config = {
    "loss_balance_factor": tune.choice([0.75, 1, 1.25]),
    "lr": 0.001}

# Tune model
best_trial = tune_model(config, device, num_samples = 10) # Set num_samples to the desired number of runs (e.g. 10) for random choices, or to 1 for a grid search (runs the entire grid once)

# Reconstitute best model
model = build_best_model(best_trial, device)

**3 - option b: Train without tuning**

Run a single 20-epoch training run with fixed hyperparameters (no class weighting, learning rate 0.001).

In [30]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config = {
    "loss_balance_factor": None, 
    "lr": 0.001}
model = train(config, device, zip_path, info_trainval, epochs = 20, tuning = False)

Start training...

 Epoch  |  Train Loss  |  Val Loss  | Prec. @.5  | Recall@.5  |    AUC     |     AP    
---------------------------------------------------------------------------------------
   1    |   0.677379   |  0.662747  |   62.02    |   100.00   |   0.5950   |   0.7311  
   2    |   0.624037   |  0.580640  |   62.02    |   100.00   |   0.5948   |   0.7320  
   3    |   0.537027   |  0.595621  |   62.02    |   100.00   |   0.6139   |   0.7463  
   4    |   0.530222   |  0.866820  |   62.02    |   100.00   |   0.6338   |   0.7609  
   5    |   0.561994   |  0.842624  |   62.50    |   100.00   |   0.6557   |   0.7781  
   6    |   0.487188   |  0.761770  |   66.07    |   92.50    |   0.6733   |   0.7880  
   7    |   0.509214   |  0.539709  |   66.97    |   91.25    |   0.6792   |   0.7954  
   8    |   0.457603   |  0.452706  |   66.36    |   91.25    |   0.6864   |   0.8011  
   9    |   0.495160   |  0.434785  |   66.36    |   88.75    |   0.6874   |   0.8024  
  10    |   0

**4. Check performance on the test set**

In [None]:
# Show the best trial's configuration and final validation metrics
print(f"Best trial config: {best_trial.config}")
print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
print(f"Best trial final validation average precision: {best_trial.last_result['ap']}")
print(f"Best trial final validation auc: {best_trial.last_result['auc']}")

Best trial config: {'kernel_size': 13, 'beta': 0.9999538920913715, 'lr': 0.001}
Best trial final validation loss: 0.019430930105348427
Best trial final validation average precision: 0.19278001602192688
Best trial final validation auc: 0.9426880381350121


In [None]:
# Build the test set
data_test = EmbeddingDataset(zip_path, info_test)
loader_test = DataLoader(data_test, batch_size = 16, shuffle = False)  # Order does not affect the metrics

# Run the best model on the test set
test_loss, test_accuracy, test_precision, test_recall, test_auc, test_ap = evaluate(model, loader_test, loss_fn=None)
print(f"Test AUC: {test_auc:.4f} | Test AP:  {test_ap:.4f} | Test precision (cutoff=0.5): {test_precision:.2f} | Test recall (cutoff=0.5): {test_recall:.2f}")

Test AUC: 0.9754 | Test AP:  0.2151 | Test precision (cutoff=0.5): 17.53 | Test recall (cutoff=0.5): 57.56


<b>5. Save and export the model</b>

In [None]:
torch.save(model.state_dict(), "model_params.pth")
# torch.save(model, "model_full.pth")

In [None]:
!cp model_params.pth /content/drive/MyDrive/NetOGlyc/

In [None]:
drive.flush_and_unmount()