# Predicting Peptide Collisional Cross Section from timsTOF data

In this vignette, we build a Transformer model to predict the measured collisional cross section (CCS) of a peptide from its sequence and charge state, using the same training and test data as [Meier et al.](https://www.nature.com/articles/s41467-021-21352-8)
To accomplish this task, we'll create a Transformer encoder for peptide sequences and charge states, and add a feed-forward neural network to predict CCS from a single output token.

**Before proceeding with this notebook, make sure to switch to a GPU runtime on Google Colab.** To do this, click `Runtime` -> `Change runtime type`, and select `GPU` under `Hardware accelerator`. A new `GPU type` box should appear below. While any GPU type will work, we previously ran this notebook on a `T4` GPU.

## Setup

The following code installs the additional dependencies we'll need: Depthcharge, PyTorch Lightning, and TensorBoard.
It also downloads the data that we'll be using directly from the Meier et al. code repository, [here](https://github.com/theislab/DeepCollisionalCrossSection).
In the end, we are left with our data in the working directory: `combined_sm.csv`.

In [None]:
%%capture
%%bash
pip install lightning tensorboard depthcharge-ms
wget -nc https://github.com/theislab/DeepCollisionalCrossSection/raw/master/data/combined_sm.csv.tar.gz
tar -xzvf combined_sm.csv.tar.gz

## Import the libraries we'll need
To work with our data, we'll need a handful of standard data science tools (NumPy, Pandas, etc.).
For model building, we'll wrap Depthcharge modules in a PyTorch Lightning module, which makes the model easy to train.

From Depthcharge, we'll use the following classes:
- `PeptideDataset` - This is a PyTorch Dataset that is designed to hold peptide sequences, their charge states, and additional metadata.
- `FeedForward` - This is a utility PyTorch Module for quickly creating feed forward neural networks.
- `PeptideTokenizer` - This class defines the amino acid alphabet, including modifications, that are valid tokens for our model. 
  It also tells Depthcharge how to convert a peptide sequence into tokens and back. 
  First-class support for ProForma formatted peptide sequences is included out-of-the-box.
- `PeptideTransformerEncoder` - This is a PyTorch Module that embeds the peptide and its residues using a Transformer architecture.
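To make the tokenizer's job concrete, here is a minimal sketch of splitting a simple ProForma string into residue-level tokens. This is not Depthcharge's implementation: the regular expression `TOKEN_PATTERN` and the helper `split_proforma()` are hypothetical, and they only handle single-letter residues, bracketed modifications such as `M[Oxidation]`, and an N-terminal modification such as `[Acetyl]-`.

```python
import re

# Hypothetical pattern (illustration only, not Depthcharge's tokenizer):
# either an N-terminal modification like "[Acetyl]-", or one uppercase
# residue optionally followed by a bracketed modification like "[Oxidation]".
TOKEN_PATTERN = re.compile(r"\[[^\]]+\]-|[A-Z](?:\[[^\]]+\])?")


def split_proforma(sequence: str) -> list[str]:
    """Split a simple ProForma string into residue-level tokens."""
    tokens = TOKEN_PATTERN.findall(sequence)
    # findall() silently skips characters it cannot match, so check that
    # the tokens reassemble into the original string.
    if "".join(tokens) != sequence:
        raise ValueError(f"Unsupported ProForma syntax: {sequence}")
    return tokens


print(split_proforma("[Acetyl]-PEPM[Oxidation]K"))
# ['[Acetyl]-', 'P', 'E', 'P', 'M[Oxidation]', 'K']
```

In the notebook itself, we rely on `PeptideTokenizer` to handle ProForma parsing for us.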

After importing these libraries, the following code also sets a plotting theme and a random seed for reproducibility.

In [None]:
import os

import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from sklearn.preprocessing import StandardScaler

from depthcharge.data import PeptideDataset
from depthcharge.feedforward import FeedForward
from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.transformers import PeptideTransformerEncoder

# Set our plotting theme:
sns.set_style("ticks")

# Set random seeds
pl.seed_everything(42, workers=True)

## Parse the data
With our libraries loaded, we can now parse the CSV file from Meier et al.
The peptide sequences are provided in MaxQuant format, which we convert to be ProForma compliant.

We then split the data into training, validation, and test sets, matching the test set to the one described in the paper.
The paper states that the ProteomeTools subset was used as the test set; these peptides are flagged by the `PT` column. The remaining peptides are split 90/10 into training and validation sets.
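The string replacements in the following code cell can be tried on a few hand-written MaxQuant-style sequences first. This sketch uses made-up example peptides and assumes the only modifications present are N-terminal acetylation, written `_(ac)`, and methionine oxidation, written `M(ox)`. Note that `_(ac)` has to be replaced before the bare underscores are stripped, otherwise its leading underscore would be removed first and the pattern would no longer match.

```python
import pandas as pd

# Hypothetical MaxQuant-style strings: "_" marks the peptide termini,
# "(ac)" an N-terminal acetylation, and "(ox)" an oxidized methionine.
maxquant = pd.Series(["_(ac)AAM(ox)K_", "_PEPTIDEK_", "_M(ox)LSEM(ox)R_"])

proforma = (
    maxquant
    .str.replace("_(ac)", "[Acetyl]-", regex=False)     # N-terminal acetylation
    .str.replace("M(ox)", "M[Oxidation]", regex=False)  # oxidized methionine
    .str.replace("_", "", regex=False)                  # drop terminal underscores
)
print(proforma.tolist())
# ['[Acetyl]-AAM[Oxidation]K', 'PEPTIDEK', 'M[Oxidation]LSEM[Oxidation]R']
```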

In [None]:
data = (
    pd.read_csv("combined_sm.csv", index_col=0)
    .sample(frac=1)
    .reset_index()
    .rename(columns={"Modified sequence": "Seq"})
)

# Convert sequences to ProForma compliant:
data["Seq"] = (
    data["Seq"]
    .str.replace("_(ac)", "[Acetyl]-", regex=False)
    .str.replace("M(ox)", "M[Oxidation]", regex=False)
    .str.replace("_", "", regex=False)
)

# Verify we've accounted for all modifications:
assert not data["Seq"].str.contains("(", regex=False).sum()

# Split the data:
test_df = data.loc[data["PT"], :]
data_df = data.loc[~data["PT"], :]

n_train = int(0.9 * len(data_df))
train_df = data_df.iloc[:n_train, :].copy()
validation_df = data_df.iloc[n_train:, :].copy()

# Print the number in each set: 
print("Test peptides:                 ", len(test_df["Seq"].unique()))
print("Training + Validation peptides:", len(data_df["Seq"].unique()))

Awesome! The numbers of unique peptides precisely match those described by Meier et al.
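To see the splitting logic in isolation, here is the same boolean-mask and 90/10 slicing pattern applied to a hypothetical six-row table. Because the notebook shuffles the rows with `.sample(frac=1)` before slicing, taking the first 90% with `iloc` yields a random split. The split is done by row, so a sequence that appears in more than one row (for example, at different charge states) could end up in both the training and validation sets.

```python
import pandas as pd

# A toy stand-in for the real table; "PT" flags ProteomeTools peptides.
toy = pd.DataFrame({
    "Seq": ["AAK", "CCK", "DDK", "EEK", "FFK", "GGK"],
    "PT": [True, False, False, True, False, False],
})

test_part = toy.loc[toy["PT"], :]  # ProteomeTools subset -> test set
rest = toy.loc[~toy["PT"], :]      # everything else -> training + validation

n_train = int(0.9 * len(rest))     # first 90% of the (shuffled) rows
train_part = rest.iloc[:n_train, :]
validation_part = rest.iloc[n_train:, :]

print(len(test_part), len(train_part), len(validation_part))  # 2 3 1
```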

## Create a tokenizer
Now that we know the peptides that we want to consider, we need to create a tokenizer that accounts for all of the amino acids and modifications that may be present. Fortunately, the `PeptideTokenizer` class has a `from_proforma()` method that allows us to extract the amino acids and modifications from a collection of peptides.

In [None]:
tokenizer = PeptideTokenizer.from_proforma(
    # Use every peptide so that no residue or modification is missing:
    sequences=data["Seq"],
    replace_isoleucine_with_leucine=False,
    reverse=False,
)

pd.DataFrame(tokenizer.residues.items(), columns=["Token", "Mass"])

It looks like our tokenizer has captured all of the residues we expect.
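Conceptually, building a tokenizer from a collection of peptides means collecting the distinct tokens and assigning each one an integer id, so that sequences can be turned into padded integer tensors. The sketch below illustrates that idea in plain Python; `build_vocab()`, `encode()`, and the choice to reserve id 0 for padding are assumptions for illustration, and Depthcharge's actual ids and special tokens may differ.

```python
import re

# Same hypothetical ProForma token pattern as a simple illustration.
TOKEN_PATTERN = re.compile(r"\[[^\]]+\]-|[A-Z](?:\[[^\]]+\])?")


def build_vocab(sequences):
    """Map every distinct token to an integer id; id 0 is reserved for padding."""
    tokens = sorted({tok for seq in sequences for tok in TOKEN_PATTERN.findall(seq)})
    return {tok: idx for idx, tok in enumerate(tokens, start=1)}


def encode(sequences, vocab):
    """Encode sequences as rows of integer ids, right-padded with zeros."""
    encoded = [[vocab[tok] for tok in TOKEN_PATTERN.findall(seq)] for seq in sequences]
    width = max(len(row) for row in encoded)
    return [row + [0] * (width - len(row)) for row in encoded]


peptides = ["PEPK", "M[Oxidation]EK"]
vocab = build_vocab(peptides)
print(vocab)                      # {'E': 1, 'K': 2, 'M[Oxidation]': 3, 'P': 4}
print(encode(peptides, vocab))    # [[4, 1, 4, 2], [3, 1, 2, 0]]
```

A token that never appeared when the vocabulary was built raises a `KeyError` in `encode()`, which is why the vocabulary needs to be built from sequences covering every residue and modification in the data.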

## Preparing Datasets
Now we need to prepare our PyTorch `Dataset`s and their corresponding `DataLoader`s.
Here, we use Depthcharge's `PeptideDataset` class, which handles transforming the peptide strings into PyTorch tensors for modeling.
Because this dataset is fairly small from a memory perspective, we also load it onto the GPU to increase our training speed.
Finally, the `loader()` method simplifies the creation of a PyTorch `DataLoader` for each dataset.

We also transform the measured CCS using standard scaling (zero mean, unit variance), making it an easier value for the model to learn.
The scaler is fit on the training set only and then applied to the validation set, so no information from the validation data leaks into training.
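Standard scaling subtracts the training-set mean and divides by the training-set standard deviation; the inverse transform, which we apply to the model's predictions later, undoes this. Here is a minimal NumPy sketch of what `StandardScaler` does for a single column, using hypothetical CCS values and hypothetical helper names. Like scikit-learn, it uses the population standard deviation (`ddof=0`).

```python
import numpy as np


def fit_scaler(values):
    """Return the mean and population standard deviation of the training values."""
    values = np.asarray(values, dtype=float)
    return values.mean(), values.std()  # np.std defaults to ddof=0


def transform(values, mean, std):
    """Center and scale values using statistics from the training set."""
    return (np.asarray(values, dtype=float) - mean) / std


def inverse_transform(scaled, mean, std):
    """Map scaled values (e.g. model predictions) back to the original units."""
    return np.asarray(scaled, dtype=float) * std + mean


train_ccs = np.array([350.0, 400.0, 450.0])  # hypothetical training CCS values
mean, std = fit_scaler(train_ccs)
scaled = transform(train_ccs, mean, std)
print(scaled)
print(inverse_transform(scaled, mean, std))
```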

In [None]:
scaler = StandardScaler()
train_dataset = PeptideDataset(
    tokenizer,
    train_df["Seq"].to_numpy(), 
    torch.tensor(train_df["Charge"].to_numpy()),
    torch.tensor(scaler.fit_transform(train_df[["CCS"]]).flatten()),
)


validation_dataset = PeptideDataset(
    tokenizer,
    validation_df["Seq"].to_numpy(),
    torch.tensor(validation_df["Charge"].to_numpy()),
    torch.tensor(scaler.transform(validation_df[["CCS"]]).flatten()),
)

test_dataset = PeptideDataset(
    tokenizer,
    test_df["Seq"].to_numpy(),
    torch.tensor(test_df["Charge"].to_numpy()),
)

# This data is small, so all of the tensors can live on the GPU.
# (Naming the loop variable `tensor` avoids overwriting the `data` DataFrame.)
for dset in (train_dataset, validation_dataset, test_dataset):
    dset.tensors = tuple(tensor.to("cuda") for tensor in dset.tensors)

# Data loaders:
train_loader = train_dataset.loader(batch_size=128, shuffle=True)
validation_loader = validation_dataset.loader(batch_size=1024)
test_loader = test_dataset.loader(batch_size=1024)

## Create a model

Time to create a deep learning model using PyTorch Lightning and Depthcharge! 
Our model consists of a `PeptideTransformerEncoder` module to embed peptides and a `FeedForward` module to predict CCS from the latent representation. 
With PyTorch Lightning, we need to specify the modules that comprise our model, define the optimizer(s) we will use to train it, and tell Lightning how to run the model when training, validating, and predicting.

For this task, we're trying to minimize the mean squared error (MSE) loss function:
$$ L = \frac{1}{n}\sum^{n}_{i=1}(Y_i - \hat{Y}_i)^2$$

where $n$ is the number of examples (peptide and charge state pairs), $Y_i$ is the measured CCS of example $i$, and $\hat{Y}_i$ is its predicted CCS.
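As a quick worked example of the loss, take three toy examples with measured values $Y = (1, 2, 3)$ and predictions $\hat{Y} = (1.5, 2, 2)$. The squared errors are 0.25, 0, and 1, so $L = 1.25 / 3 \approx 0.417$. The short NumPy function below (a hypothetical helper) computes the same quantity that `torch.nn.functional.mse_loss` computes with its default mean reduction.

```python
import numpy as np


def mse(measured, predicted):
    """Mean squared error: L = (1/n) * sum_i (Y_i - Yhat_i)^2."""
    measured = np.asarray(measured, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    return np.mean((measured - predicted) ** 2)


# Squared errors are 0.25, 0.0, and 1.0, so the mean is 1.25 / 3.
print(mse([1.0, 2.0, 3.0], [1.5, 2.0, 2.0]))
```

Because the model is trained on standard-scaled CCS values, the logged loss is in scaled units; multiplying it by the scaler's variance (`scaler.var_`) converts it back to squared CCS units.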

In [None]:
class CCSPredictor(pl.LightningModule):
    """A Transformer model for CCS prediction"""
    def __init__(self, tokenizer, d_model, n_layers):
        """Initialize the CCSPredictor"""
        super().__init__()
        self.peptide_encoder = PeptideTransformerEncoder(
            n_tokens=tokenizer,
            d_model=d_model,
            n_layers=n_layers,
        )
        self.ccs_head = FeedForward(d_model, 1, 3)

    def step(self, batch, batch_idx):
        """A training/validation/inference step."""
        # Training and validation batches are (seqs, charges, ccs), but the
        # test set has no CCS, so its batches are only (seqs, charges).
        seqs, charges, *rest = batch
        ccs = rest[0] if rest else None

        embedded, _ = self.peptide_encoder(seqs, charges)

        # Predict CCS from the first position of the encoder output:
        pred = self.ccs_head(embedded[:, 0, :]).flatten()
        if ccs is not None:
            ccs = ccs.type_as(pred)
            loss = torch.nn.functional.mse_loss(pred, ccs)
        else:
            loss = None

        return pred, loss

    def training_step(self, batch, batch_idx):
        """A training step"""
        _, loss = self.step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """A validation step"""
        _, loss = self.step(batch, batch_idx)
        self.log("validation_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """An inference step"""
        pred, _ = self.step(batch, batch_idx)
        return pred

    def configure_optimizers(self):
        """Configure the optimizer for training."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer
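The `step()` method reads the prediction from position 0 of the encoder output with `embedded[:, 0, :]`, turning a `(batch, sequence length, d_model)` tensor into one `d_model`-dimensional vector per peptide. The NumPy sketch below illustrates only the shapes involved; the sizes are hypothetical, random numbers stand in for the encoder output, and a single random linear layer stands in for the `FeedForward` head.

```python
import numpy as np

batch_size, seq_len, d_model = 2, 5, 4  # hypothetical sizes
rng = np.random.default_rng(0)

# Stand-in for the encoder output: one d_model vector per position.
embedded = rng.normal(size=(batch_size, seq_len, d_model))

# Keep only position 0 for every peptide in the batch.
summary = embedded[:, 0, :]

# A single linear layer standing in for the FeedForward head (d_model -> 1).
weights = rng.normal(size=(d_model, 1))
pred = (summary @ weights).flatten()

print(summary.shape, pred.shape)  # (2, 4) (2,)
```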

## Optional: Set up TensorBoard

TensorBoard is a tool used to track the training progress of deep learning models in real time.
Running the cell below will start TensorBoard and tell it to listen to the logs that will be written by our model.

In [None]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

## Train the model

With our model defined and our data loaders ready to go, it's time to fit the model to the data.
The PyTorch Lightning `Trainer` will handle most of the training loop for us.
We enable an early stopping criterion here, so that the trainer stops once the MSE on our validation dataset has not improved for 5 epochs (training is also capped at 50 epochs).
This model should take less than 2 hours to train.
If you've enabled TensorBoard in the previous cell, scroll back up while the model trains to watch its progress.
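The patience logic behind early stopping can be sketched in a few lines of plain Python. This is a simplified, hypothetical helper, not Lightning's `EarlyStopping` implementation: it stops once the validation loss has failed to beat its best value (by more than `min_delta`) for `patience` consecutive epochs, using the `patience=5` from the cell below and a `min_delta` of 0.

```python
def stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Return the epoch index at which training would stop, or None if it never does."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # New best validation loss: reset the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical validation losses: the best value (0.7) is reached at epoch 2,
# and epochs 3 through 7 fail to beat it.
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.75, 0.73]
print(stopping_epoch(losses, patience=5))  # 7
```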

In [None]:
# Create a model:
model = CCSPredictor(tokenizer, d_model=64, n_layers=4)

early_stopping = EarlyStopping(monitor="validation_loss", patience=5)
trainer = pl.Trainer(callbacks=[early_stopping], max_epochs=50)
trainer.fit(
    model=model, 
    train_dataloaders=train_loader, 
    val_dataloaders=validation_loader,
)

## Predict on the Validation dataset

We now want to see how well we've done, beyond just looking at the MSE.
To get the predicted CCS for every peptide in our validation set, we call the trainer's `predict()` method on our validation data loader and then undo the standard scaling.
We then create a minimal scatter plot of the observed value against the predicted value.
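Beyond the scatter plot, a couple of summary numbers can help when comparing models. The sketch below uses a hypothetical helper, `summarize()`, to compute the Pearson correlation and the median absolute relative error between measured and predicted CCS; the example values are made up.

```python
import numpy as np


def summarize(measured, predicted):
    """Return the Pearson correlation and median absolute relative error (percent)."""
    measured = np.asarray(measured, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    pearson_r = np.corrcoef(measured, predicted)[0, 1]
    median_rel_err = np.median(np.abs(predicted - measured) / measured) * 100
    return pearson_r, median_rel_err


# Hypothetical values; the relative errors are 1%, 1%, and 2%.
r, mre = summarize([300.0, 400.0, 500.0], [303.0, 396.0, 510.0])
print(f"Pearson r = {r:.4f}, median relative error = {mre:.2f}%")
```

With the notebook's data, the same function could be applied as `summarize(validation_df["CCS"], validation_df["pred"])` once the `pred` column has been added.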

In [None]:
pred = trainer.predict(model, validation_loader)
validation_df = validation_df.copy()
validation_df["pred"] = scaler.inverse_transform(
    torch.cat(pred).detach().cpu().numpy()[:, None]
).flatten()

plt.figure()
plt.scatter(validation_df["CCS"], validation_df["pred"], s=1)
plt.xlabel("Measured CCS")
plt.ylabel("Predicted CCS")
plt.show()

This looks pretty good to me.
If we want to make further tweaks and optimizations, now is the time to do them.
If not, we're ready to get our predictions for the test set; after that, we should stop optimizing the model so that the test set remains an unbiased measure of its performance.

## Predict on the Test dataset

As with our validation data, we use the `predict()` method to get the predicted CCS for each peptide in our test dataset.
We make a similar scatter plot and save the data as a Parquet file, which we used to make the visualizations on our poster.

In [None]:
trainer = pl.Trainer()
pred = trainer.predict(model, test_loader)

test_df = test_df.copy()
test_df["pred"] = scaler.inverse_transform(
    torch.cat(pred).detach().cpu().numpy()[:, None]
).flatten()

test_df.to_parquet("predictions.parquet")

plt.figure()
plt.scatter(test_df["CCS"], test_df["pred"], s=1)
plt.xlabel("Measured CCS")
plt.ylabel("Predicted CCS")
plt.show()

Nice! This looks great.

If you want to fully reproduce our figures from the poster, you'll need to clone [our GitHub repo](https://github.com/wfondrie/2023_asms-depthcharge), follow the instructions in the README for setting up your environment, and execute the [ccs-figures.ipynb](https://github.com/wfondrie/2023_asms-depthcharge/blob/main/notebooks/ccs-figures.ipynb) Jupyter notebook.