In [None]:
%%capture
%%bash
# Install the notebook's Python dependencies. depthcharge is installed
# straight from the `asms` branch of its GitHub repository.
pip install lightning vizta tensorboard git+https://github.com/wfondrie/depthcharge.git@asms 
# Download the dataset archive; `-nc` skips the download when the file is
# already present, so re-running this cell does not fetch it again.
wget -nc https://github.com/theislab/DeepCollisionalCrossSection/raw/master/data/combined_sm.csv.tar.gz
# Unpack the archive into the working directory (the load cell below reads
# `combined_sm.csv` from here).
tar -xzvf combined_sm.csv.tar.gz

In [None]:
import os

import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from sklearn.preprocessing import StandardScaler

from depthcharge.data import PeptideDataset
from depthcharge.feedforward import FeedForward
from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.transformers import PeptideTransformerEncoder

# Set our plotting theme: use seaborn's "ticks" style for every figure below.
sns.set_style("ticks")

# Set random seeds through Lightning's helper so repeated runs are comparable.
# NOTE(review): `workers=True` presumably extends the seeding to dataloader
# worker processes — confirm against the Lightning documentation.
pl.seed_everything(42, workers=True)

In [None]:
# Load the measurements and shuffle the rows. The explicit `random_state`
# makes the shuffle -- and therefore the train/validation split below --
# identical every time this cell runs, independent of how much of the global
# random state earlier cells have already consumed.
data = (
    pd.read_csv("combined_sm.csv", index_col=0)
    .sample(frac=1, random_state=42)
    .reset_index()
    .rename(columns={"Modified sequence": "Seq"})
)

# Convert sequences to ProForma compliant. The order matters: the "_(ac)"
# prefix has to be rewritten before the remaining "_" delimiters are stripped.
data["Seq"] = (
    data["Seq"]
    .str.replace("_(ac)", "[Acetyl]-", regex=False)
    .str.replace("M(ox)", "M[Oxidation]", regex=False)
    .str.replace("_", "", regex=False)
)

# Verify we've accounted for all modifications: any "(" still present means a
# modification notation was not converted above.
assert not data["Seq"].str.contains("(", regex=False).sum(), (
    "Some sequences still contain unconverted '(...)' modifications."
)

# Split the data: rows whose `PT` flag is true form the test set, the rest are
# used for training and validation.
test_df = data.loc[data["PT"], :]
data_df = data.loc[~data["PT"], :]

# The rows are already shuffled, so a positional 90/10 cut is a random split.
n_train = int(0.9 * len(data_df))
train_df = data_df.iloc[:n_train, :].copy()
validation_df = data_df.iloc[n_train:, :].copy()

# Print the number in each set:
print("Test peptides:                 ", len(test_df["Seq"].unique()))
print("Training + Validation peptides:", len(data_df["Seq"].unique()))

In [None]:
# Build the token vocabulary from the *training* sequences. The vocabulary
# must cover everything the model is fit on; deriving it from the much smaller
# validation split could leave a token that only occurs in the training data
# unknown to the tokenizer.
tokenizer = PeptideTokenizer.from_proforma(
    sequences=train_df["Seq"],
    replace_isoleucine_with_leucine=False,
    reverse=False,
)

# Display the resulting vocabulary as a table of tokens and their masses:
pd.DataFrame(tokenizer.residues.items(), columns=["Token", "Mass"])

In [None]:
# Standardize the CCS targets. The scaler is fit on the training set only and
# then re-used (transform, not fit_transform) for the validation set, so no
# validation statistics leak into training.
scaler = StandardScaler()
train_dataset = PeptideDataset(
    tokenizer,
    train_df["Seq"].to_numpy(),
    torch.tensor(train_df["Charge"].to_numpy()),
    torch.tensor(scaler.fit_transform(train_df[["CCS"]]).flatten()),
)

validation_dataset = PeptideDataset(
    tokenizer,
    validation_df["Seq"].to_numpy(),
    torch.tensor(validation_df["Charge"].to_numpy()),
    torch.tensor(scaler.transform(validation_df[["CCS"]]).flatten()),
)

# The test set carries no CCS targets: its batches hold only the sequences
# and the charges.
test_dataset = PeptideDataset(
    tokenizer,
    test_df["Seq"].to_numpy(),
    torch.tensor(test_df["Charge"].to_numpy()),
)

# This data is small so it can all live on the GPU. Fall back to the CPU when
# CUDA is not available so the notebook still runs on machines without one.
device = "cuda" if torch.cuda.is_available() else "cpu"
for dset in (train_dataset, validation_dataset, test_dataset):
    # A generator expression keeps the loop variable local, so it cannot
    # overwrite notebook-level names such as the `data` DataFrame.
    dset.tensors = tuple(tensor.to(device) for tensor in dset.tensors)

# Data loaders: only the training data is shuffled; validation and test keep
# their row order so predictions can be written back onto the DataFrames.
train_loader = train_dataset.loader(batch_size=128, shuffle=True)
validation_loader = validation_dataset.loader(batch_size=1024)
test_loader = test_dataset.loader(batch_size=1024)

In [None]:
class CCSPredictor(pl.LightningModule):
    """A Transformer model for CCS prediction.

    Peptide sequences and their charges are embedded by a
    ``PeptideTransformerEncoder``. The embedding at the first position of
    every sequence is then fed to a feed-forward head that produces a single
    value per peptide.

    Parameters
    ----------
    tokenizer : PeptideTokenizer
        The tokenizer defining the vocabulary; it is handed to the encoder as
        its ``n_tokens`` argument.
    d_model : int
        The width of the encoder embeddings and the input size of the head.
    n_layers : int
        The number of encoder layers.
    """

    def __init__(self, tokenizer, d_model, n_layers):
        """Initialize the CCSPredictor."""
        super().__init__()
        self.peptide_encoder = PeptideTransformerEncoder(
            n_tokens=tokenizer,
            d_model=d_model,
            n_layers=n_layers,
        )
        # NOTE(review): presumably (in_features, out_features, n_layers) --
        # confirm against the depthcharge FeedForward signature.
        self.ccs_head = FeedForward(d_model, 1, 3)

    def step(self, batch, batch_idx):
        """A training/validation/inference step.

        Parameters
        ----------
        batch : sequence of torch.Tensor
            ``(seqs, charges, ccs)`` for labelled data, or just
            ``(seqs, charges)`` for unlabelled data such as a test set.
        batch_idx : int
            The index of the batch (unused).

        Returns
        -------
        pred : torch.Tensor
            The flattened predictions, one per peptide.
        loss : torch.Tensor or None
            The mean squared error against ``ccs``, or ``None`` when the
            batch carries no targets.
        """
        # Batches without targets only have two elements; a fixed three-way
        # unpacking would raise ValueError for them during prediction.
        seqs, charges, *targets = batch
        ccs = targets[0] if targets else None

        try:
            embedded, _ = self.peptide_encoder(seqs, charges)
        except IndexError:
            # Show the offending batch before propagating the error.
            print(batch)
            raise

        # Only the first position of each embedded sequence feeds the head.
        pred = self.ccs_head(embedded[:, 0, :]).flatten()
        if ccs is not None:
            # Match the dtype/device of the predictions before the loss.
            ccs = ccs.type_as(pred)
            loss = torch.nn.functional.mse_loss(pred, ccs)
        else:
            loss = None

        return pred, loss

    def training_step(self, batch, batch_idx):
        """A training step: log and return the loss for the batch."""
        _, loss = self.step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """A validation step: log the epoch-level loss and return it."""
        _, loss = self.step(batch, batch_idx)
        self.log("validation_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """An inference step: return the predictions only."""
        pred, _ = self.step(batch, batch_idx)
        return pred

    def configure_optimizers(self):
        """Configure the optimizer for training (Adam, learning rate 1e-4)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [None]:
# Load (or reload) the TensorBoard notebook extension and open it on the
# `lightning_logs/` directory.
# NOTE(review): presumably this is where the Trainer below writes its logs --
# confirm if a custom logger or `default_root_dir` is ever configured.
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [None]:
# Create a model. It is trained as-is: a `torch.compile` wrapper used to be
# created here but was never handed to the trainer, so it has been removed.
model = CCSPredictor(tokenizer, d_model=64, n_layers=4)

# Stop once `validation_loss` has gone 5 validation checks without improving,
# and never train for more than 50 epochs.
early_stopping = EarlyStopping(monitor="validation_loss", patience=5)
trainer = pl.Trainer(callbacks=[early_stopping], max_epochs=50)
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=validation_loader,
)

In [None]:
# Run inference on the validation set; the result is one tensor per batch.
pred = trainer.predict(model, validation_loader)

# Join the batches and undo the target standardization so that predictions
# are on the same scale as the measured CCS values.
scaled_pred = torch.cat(pred).detach().cpu().numpy()
validation_df = validation_df.copy()
validation_df["pred"] = scaler.inverse_transform(scaled_pred.reshape(-1, 1)).ravel()

# Measured vs. predicted CCS for the validation peptides.
fig, ax = plt.subplots()
ax.scatter(validation_df["CCS"], validation_df["pred"], s=1)
ax.set_xlabel("Measured CCS")
ax.set_ylabel("Predicted CCS")
plt.show()

In [None]:
# Signed prediction error on the validation set, both in CCS units and as a
# fraction of the measured value.
err = validation_df["CCS"] - validation_df["pred"]
rel_err = err / validation_df["CCS"]

fig, axs = plt.subplots(1, 2, figsize=(9, 4))

# One histogram per panel: relative error on the left, absolute on the right.
panels = ((rel_err, "Relative Error"), (err, "Absolute Error"))
for ax, (values, label) in zip(axs, panels):
    sns.histplot(values, ax=ax)
    ax.set_xlabel(label)
    ax.set_ylabel("Number of Peptides")

plt.tight_layout()
plt.show()

In [None]:
# Predict the held-out test set with a fresh, default-configured Trainer.
trainer = pl.Trainer()
pred = trainer.predict(model, test_loader)

# Join the batches and map the standardized outputs back to CCS units.
scaled_pred = torch.cat(pred).detach().cpu().numpy()
test_df = test_df.copy()
test_df["pred"] = scaler.inverse_transform(scaled_pred.reshape(-1, 1)).ravel()

# Persist the test set together with its predictions.
test_df.to_parquet("predictions.parquet")

# Measured vs. predicted CCS for the test peptides.
fig, ax = plt.subplots()
ax.scatter(test_df["CCS"], test_df["pred"], s=1)
ax.set_xlabel("Measured CCS")
ax.set_ylabel("Predicted CCS")
plt.show()

In [None]:
# Signed prediction error on the test set, both in CCS units and as a
# fraction of the measured value.
err = test_df["CCS"] - test_df["pred"]
rel_err = err / test_df["CCS"]

fig, axs = plt.subplots(1, 2, figsize=(9, 4))

# One histogram per panel: relative error on the left, absolute on the right.
panels = ((rel_err, "Relative Error"), (err, "Absolute Error"))
for ax, (values, label) in zip(axs, panels):
    sns.histplot(values, ax=ax)
    ax.set_xlabel(label)
    ax.set_ylabel("Number of Peptides")

plt.tight_layout()
plt.show()