<a href="https://colab.research.google.com/github/CasBlaauw/BertOGlyc/blob/main/ProtBert_NetOGlyc_classification_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Initial model architecture based on [Elnaggar et al. (2020)](https://www.biorxiv.org/content/10.1101/2020.07.12.199554v3.full) and [Heinzinger et al. (2019)](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-019-3220-8).  
Data loader structure inspired by [this pytorch documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), 
model architecture inspired by [this pytorch documentation](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html), training loop inspired by [this CNN tutorial](https://chriskhanhtran.github.io/posts/cnn-sentence-classification/), weights for loss function inspired by [this tutorial](https://towardsdatascience.com/handling-class-imbalanced-data-using-a-loss-specifically-made-for-it-6e58fd65ffab), based on [this paper](https://arxiv.org/abs/1901.05555).

<b>0. Import functions</b>

In [7]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
import os
from google.colab import files, drive

In [26]:
drive.mount('/content/drive')

Mounted at /content/drive


<b>1. Read in and split data </b>

In [9]:
# Define the Dataset class for use with DataLoader, reading in files as needed

class EmbeddingDataset(Dataset):
    """Dataset of embeddings from ProtBert.
    Path is expected to be a path to an zip/npz file containing the .npy arrays for each gene.
    Then indexes into those with f"embed/embeddings_{gene_id}.txt"
    Info is expected to be a pandas dataframe with gene names as keys, 'sequence' and 'label' keys as lists/iterables."""
    def __init__(self, path, info):
        self.path = path
        self.info = info.reset_index() 

    def __len__(self):
        return len(self.info)
    
    def __getitem__(self, idx):
        with np.load(self.path) as zip:
          embed = zip[f"embed/embeddings_{self.info['gene'][idx]}.txt"] # Colab has a weird zip structure + I missed the .txt
          embed = embed.T # Need to return transposed because conv1d expects channels, then length
        label = torch.tensor(self.info['label'][idx])
        return embed, label

In [10]:
# Define the paths to the info and zip file
zip_path = '/content/drive/MyDrive/NetOGlyc/embeddings_npy.zip'
info_path = '/content/drive/MyDrive/NetOGlyc/embeddings_info.txt'

In [11]:
# Read in the info file
info = pd.read_csv(info_path, sep = '\t')
info['sequence'] = info['sequence'].apply(list)
info['label'] = info['label'].apply(lambda x: list(map(int, list(x))))

# Split data into test and training files
info_train_idx, info_test_idx = train_test_split(range(len(info)), test_size = 0.2)
info_train = info.iloc[info_train_idx, :]
info_test = info.iloc[info_test_idx, :]
print(info_test)

       gene  ...                                              label
102  P08572  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
186  P55058  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
338  Q8TBP5  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
201  Q02818  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
227  Q14696  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
..      ...  ...                                                ...
244  Q495W5  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
310  Q86VZ4  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
267  Q6L9W6  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
449  Q9UQ53  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
226  Q14667  ...  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...

[94 rows x 4 columns]


In [12]:
# Load info and paths into dataset objects and create loades
params = {'batch_size': 64, 'shuffle': True}
data_train = EmbeddingDataset(zip_path, info_train)
loader_train = DataLoader(data_train, **params)
data_test = EmbeddingDataset(zip_path, info_test)
loader_test = DataLoader(data_test, **params)

<b>2. Construct the model</b>

In [13]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(in_channels = 1024, out_channels = 32, kernel_size = 7, padding = 3) 
        self.dropout = nn.Dropout(p=0.25)
        self.conv2 = nn.Conv1d(in_channels = 32, out_channels = 2, kernel_size = 7, padding = 3)

    def forward(self, x):
        # ---- Layer 1
        # conv1 needs (batch_size, in_channels/features, length/seq_len), so (64, 1024, 4000) 
        # and outputs (64, 32, 4000)
        x = self.conv1(x)

        # ---- Process first layer's output
        x = self.dropout(x)
        x = F.relu(x)

        # ---- Layer 2
        # conv2 takes (64, 32, 4000) and outputs (64, 2, 4000)
        x = self.conv2(x)
        
        return x

model = Net()

In [14]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

cuda:0


Net(
  (conv1): Conv1d(1024, 32, kernel_size=(7,), stride=(1,), padding=(3,))
  (dropout): Dropout(p=0.25, inplace=False)
  (conv2): Conv1d(32, 2, kernel_size=(7,), stride=(1,), padding=(3,))
)

In [None]:
# Loss weights are one of the most important params due to our unbalanced data
# Priority here is predicting sites, not predicting non-sites, so we want a very high beta to distinguish between them still.

def cb_weights(samples_per_cls, no_of_classes, beta):
    """Compute the weights for Class Balanced Loss.
    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is a loss function, here cross entropy loss.
    Args:
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: int. total number of classes.
      beta: float. Hyperparameter for Class balanced loss.
    Returns:
      cb_weights: A float tensor of weights to be supplied to the loss function
    """
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * no_of_classes

    return weights

In [16]:
total_samples = info['label'].apply(len).sum()
total_sites = info['label'].apply(sum).sum() # Assumes 0-1 labeling
samples_dist = [total_samples - total_sites, total_sites]
weights = cb_weights(samples_dist, len(samples_dist), beta = 0.99999)
loss_fn = nn.CrossEntropyLoss(weight = torch.FloatTensor(weights).to(device))
optimizer = optim.Adam(model.parameters(), lr = 0.001, amsgrad = True)

<b>3. Train the model</b>

In [17]:
def train(model, optimizer, train_dataloader, val_dataloader=None, epochs=10, checkpoint_path = './'):
    """Train the CNN model.
    Args: 
      model: the pytorch model object.
      optimiser: the loss function, like nn.CrossEntropyLoss(). 
      train_dataloader: a pytorch DataLoader object that returns the training data, with shape (batch_size, features, length).
      val_dataloader: optional. a pytorch DataLoader object like train_dataloader, used to print performance on the test set during training.
      epochs: optional. an integer value, indicating the number of epochs (training loops) the training should last.
      checkpoint_path: optional. The location to save the model parameter files at every epoch, which are useful to get model values before overfitting sets in.
          Note that they currently don't include the optimizer state, so they cannot be used as full checkpoints to continue training later.
    Returns:
      Doesn't return anything, but has modified the weights of the supplied model object. 
      Also writes weights per epoch to {checkpoint_path}/model_{date}_{time}_{epoch_i}.pth.
      """
    
    # Tracking best validation accuracy
    best_accuracy = 0
    best_site_accuracy = 0

    # Start training loop
    print("Start training...\n")
    print(f"{'Epoch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Val Site Acc':^14} | {'True Pos':^10} | {'True Neg':^10}")
    print("-"*90)

    # Save current time 

    time = datetime.now().strftime("%m-%d_%H-%M")

    model = model.float()

    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================

        # Tracking time and loss
        total_loss = 0

        # Put the model into the training mode
        model.train()

        for step, batch in enumerate(train_dataloader):
            # Load batch to GPU
            b_input_ids, b_labels = tuple(t.to(device) for t in batch)

            # Zero out any previously calculated gradients
            model.zero_grad()

            # Perform a forward pass. (output shape: (batch, n_classes, length))
            logits = model(b_input_ids.float())

            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            total_loss += loss.item()

            # Perform a backward pass to calculate gradients
            loss.backward()

            # Update parameters
            optimizer.step()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)

        # =======================================
        #               Evaluation
        # =======================================
        if val_dataloader is not None:
            # After the completion of each training epoch, measure the model's
            # performance on our validation set.
            val_loss, val_accuracy, val_site_accuracy, val_tp, val_tn = evaluate(model, val_dataloader)

            # Print performance over the entire training data
            print(f"{epoch_i + 1:^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {val_site_accuracy:^14.2f} | {val_tp:^10.2f} | {val_tn:^10.2f}")
        
        # =======================================
        #               Checkpoint
        # =======================================

        torch.save(model.state_dict(), os.path.join(checkpoint_path, f"model_{time}_{epoch_i}.pth"))

    print(f"Training complete!")

def evaluate(model, val_dataloader):
    """After the completion of each training epoch, measure the model's
    performance on our validation set.
    Returns:
      val_loss: the mean of the loss across batches.
      val_accuracy: the mean of the accuracy (correct predictions) across batches.
      val_site_accuracy: the mean of the site accuracy ('1' labels correctly detected) across batches.
      val_tp: the mean of the true positive percentage ('1' preditions actually '1' labels) across batches.
      val_tn: the mean of the true negative percentage ('0' preditions actually '0' labels) across batches.
    """
    # Put the model into the evaluation mode. The dropout layers are disabled
    # during the test time.
    model.eval()

    # Tracking variables
    val_accuracy = []
    val_site_accuracy = []
    val_loss = []
    val_tp = []
    val_tn = []

    # For each batch in our validation set...
    for batch in val_dataloader:
        # Load batch to GPU
        b_input, b_labels = tuple(t.to(device) for t in batch)

        # Compute scores (shape: (batch, n_classes, length))
        with torch.no_grad():
            scores = model(b_input.float())

        # Compute loss
        loss = loss_fn(scores, b_labels)
        val_loss.append(loss.item())

        # Get the predictions
        preds = torch.argmax(scores, dim=1)

        # Calculate the accuracy rate
        accuracy = (preds == b_labels).cpu().numpy().mean() * 100
        site_accuracy = (preds[b_labels == 1] == b_labels[b_labels == 1]).cpu().numpy().mean()*100
        true_pos = (preds[preds == 1] == b_labels[preds == 1]).cpu().numpy().mean()*100
        true_neg = (preds[preds == 0] == b_labels[preds == 0]).cpu().numpy().mean()*100
        val_accuracy.append(accuracy)
        val_site_accuracy.append(site_accuracy)
        val_tp.append(true_pos)
        val_tn.append(true_neg)

    # Compute the average accuracy and loss over the validation set.
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)
    val_site_accuracy = np.mean(val_site_accuracy)
    val_tp = np.mean(val_tp)
    val_tn = np.mean(val_tn)

    return val_loss, val_accuracy, val_site_accuracy, val_tp, val_tn

In [18]:
!mkdir checkpoint_dir

In [19]:
train(model, optimizer, loader_train, loader_test, epochs = 20, checkpoint_path = 'checkpoint_dir')

Start training...

 Epoch  |  Train Loss  |  Val Loss  |  Val Acc  |  Val Site Acc  |  True Pos  |  True Neg 
------------------------------------------------------------------------------------------
   1    |   0.631109   |  0.590762  |   99.82   |     10.15      |    9.63    |   99.91   
   2    |   0.574027   |  0.562861  |   98.58   |     75.95      |    5.60    |   99.98   
   3    |   0.543790   |  0.541366  |   99.30   |     68.79      |    9.67    |   99.97   
   4    |   0.513176   |  0.497464  |   99.11   |     74.56      |    7.66    |   99.98   
   5    |   0.481214   |  0.464758  |   99.11   |     76.13      |    8.14    |   99.98   
   6    |   0.446107   |  0.422800  |   99.14   |     82.48      |    9.23    |   99.98   
   7    |   0.406868   |  0.382631  |   99.19   |     81.82      |    9.55    |   99.98   
   8    |   0.361281   |  0.337564  |   99.18   |     82.93      |    8.99    |   99.98   
   9    |   0.312608   |  0.289396  |   99.17   |     84.12      |    9

<b>4. Save and export the model</b>

In [20]:
torch.save(model.state_dict(), "model_params.pth")
torch.save(model, "model_full.pth")

In [25]:
!zip  model_checkpoints.zip checkpoint_dir/model_06-14*.*

  adding: checkpoint_dir/model_06-14_12-05_0.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_10.pth (deflated 8%)
  adding: checkpoint_dir/model_06-14_12-05_11.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_12.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_13.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_14.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_15.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_16.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_17.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_18.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_19.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_1.pth (deflated 7%)
  adding: checkpoint_dir/model_06-14_12-05_2.pth (deflated 8%)
  adding: checkpoint_dir/model_06-14_12-05_3.pth (deflated 8%)
  adding: checkpoint_dir/model_06-14_12-05_4.pth (deflated 8%)
  adding: checkpoint_dir/model_06-14_12-05_5.

In [27]:
!cp model_params.pth /content/drive/MyDrive/NetOGlyc/
!cp model_full.pth /content/drive/MyDrive/NetOGlyc/
!cp model_checkpoints.zip /content/drive/MyDrive/NetOGlyc/

In [28]:
drive.flush_and_unmount()