In [None]:
%%capture
%%bash
# Install the notebook's dependencies (depthcharge comes from its "asms" branch).
pip install lightning vizta tensorboard git+https://github.com/wfondrie/depthcharge.git@asms

# Download each dataset from figshare unless it is already present.
# Every entry is "<local filename> <figshare file id>".
for FILE in "train.hdf5 24635459" "valid.hdf5 24635442" "test.hdf5 24635438"
do
    # Word-split the entry into $1 (local filename) and $2 (figshare file id).
    set -- $FILE
    if [ ! -f "$1" ]; then
        # Download to a temporary name and rename only on success, so that an
        # interrupted download is never mistaken for a complete dataset on the
        # next run.
        wget -O "$1.part" "https://figshare.com/ndownloader/files/$2" \
            && mv "$1.part" "$1"
    fi
done

In [None]:
import os

import einops
import h5py
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numba as nb
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import vizta
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from sklearn.preprocessing import StandardScaler

from depthcharge.data import PeptideDataset
from depthcharge.encoders import FloatEncoder
from depthcharge.feedforward import FeedForward
from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.transformers import PeptideTransformerEncoder

# Set our plotting theme. The return value is kept as `pal`
# (presumably the theme's color palette -- confirm against vizta).
pal = vizta.mpl.set_theme(context="notebook", style="wfondrie")

# Set random seeds through Lightning so that runs are repeatable.
# `workers=True` is forwarded to `seed_everything`; per the Lightning docs it
# also makes dataloader workers derive their seeds from this one.
pl.seed_everything(42, workers=True)

In [None]:
# Build the peptide tokenizer from a ProForma string that lists every residue
# we want a token for: the 20 canonical amino acids plus oxidized methionine.
# Isoleucine and leucine are kept as distinct tokens and the sequences are
# not reversed.
tokenizer = PeptideTokenizer.from_proforma(
    sequences="ACDEFGHIKLMNPQRSTVWYM[Oxidation]", 
    replace_isoleucine_with_leucine=False, 
    reverse=False,
)

# Display the tokenizer's residues as a (Token, Mass) table.
pd.DataFrame(tokenizer.residues.items(), columns=["Token", "Mass"])

In [None]:
@nb.njit
def vecs2seqs(vecs, alphabet):
    """Convert Prosit integer vectors to peptide sequences.

    Parameters
    ----------
    vecs : iterable of integer sequences
        One integer-encoded peptide per row. Entries equal to zero are
        skipped; a non-zero entry ``i`` refers to ``alphabet[i - 1]``.
    alphabet : indexable of str
        The residue for each (one-based) integer code.

    Yields
    ------
    str
        The peptide sequence for each row of ``vecs``, in order.
    """
    for encoded in vecs:
        residues = [alphabet[code - 1] for code in encoded if code]
        yield "".join(residues)

@nb.njit
def flip_and_shift_b_ions(ions, n_ions):
    """Reverse the b ions of every peptide and move them down one position.

    This lets them match the order of our Transformer model.

    Parameters
    ----------
    ions : array indexed as [peptide, position, ion series, charge]
        The intensities. Only ion series ``1`` is touched, and the array is
        modified in place.
    n_ions : sequence of int
        The number of ions to flip for each peptide.

    Returns
    -------
    The same ``ions`` array that was passed in.
    """
    for row in range(len(n_ions)):
        count = n_ions[row]
        # Positions 1..count receive positions count-1..0 (i.e. reversed).
        ions[row, 1:count + 1, 1, :] = ions[row, count - 1::-1, 1, :]

    return ions


class PrositDataset(PeptideDataset):
    """Peptides and fragment ion intensities read from a Prosit HDF5 file.

    Parameters
    ----------
    tokenizer : PeptideTokenizer
        The tokenizer used for the peptide sequences. The order of
        ``tokenizer.residues`` is also used as the alphabet when decoding
        the integer-encoded sequences stored in the file.
    hdf5_file : str or path-like
        The Prosit HDF5 file to read.
    max_examples : int, optional
        The approximate maximum number of examples to keep. Files with more
        rows are subsampled by keeping every ``n_rows // max_examples``-th
        row, so slightly more than ``max_examples`` rows may be retained.
    device : str or torch.device, optional
        The device that holds the charge, collision energy, and intensity
        tensors. By default ``"cuda"`` is used when a GPU is available and
        ``"cpu"`` otherwise.

    Raises
    ------
    ValueError
        If ``max_examples`` is less than one.
    """
    def __init__(self, tokenizer, hdf5_file, max_examples=1_000_000, device=None):
        """Initialize the Prosit Dataset"""
        if max_examples < 1:
            raise ValueError("max_examples must be a positive integer.")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        alphabet = list(tokenizer.residues.keys())

        # Open the file once: look up its size, then read the strided subset.
        with h5py.File(hdf5_file) as data:
            n_rows = data["scan_number"].shape[0]
            if n_rows > max_examples:
                # The peptides are lexicographically sorted, so we'll take a
                # diverse subset with creative indexing.
                print(f"  -> Found {n_rows} peptides. Subsetting to ~{max_examples}...")
                step = int(n_rows // max_examples)
            else:
                step = 1

            print("  -> Reading from HDF5 file....")
            # The charge is stored one-hot; argmax + 1 recovers the charge state.
            charge = np.argmax(data["precursor_charge_onehot"][::step], axis=1) + 1
            charge = torch.tensor(charge).to(device)
            nce = torch.tensor(data["collision_energy_aligned_normed"][::step]).to(device)
            seq = vecs2seqs(data["sequence_integer"][::step], np.array(alphabet))
            intensities = data["intensities_raw"][::step]
            n_rows = len(charge)

        print("  -> Preprocessing intensities...")
        # Transform the intensities for our Transformer.
        # Per peptide, the intensities are shape (L, I, Z) where:
        # L = The peptide length - 1, ordered from lowest mass to highest.
        # I = The ion series, (y, b)
        # Z = The charge state (increasing)
        intensities = intensities.reshape([n_rows, 29, 2, 3])

        # Count the non-masked (>= 0) singly charged y ions of each peptide;
        # this is the number of b ions to flip below.
        n_ions = (intensities[:, :, 0, 0] >= 0).sum(axis=1)

        # Need an extra space because we want to shift b ions.
        intensities = np.pad(
            intensities,
            ((0, 0), (0, 1), (0, 0), (0, 0)),
            "constant",
            constant_values=-1,
        )

        intensities = flip_and_shift_b_ions(intensities, n_ions)
        # After the shift nothing belongs in the first b ion position: mask it.
        intensities[:, 0, 1, :] = -1
        intensities = torch.tensor(intensities).to(device)

        print("  -> Tokenizing peptides...")
        super().__init__(tokenizer, seq, charge, nce, intensities)


print("Loading the training dataset...")
train_dataset = PrositDataset(tokenizer, "train.hdf5", 200_000)
print("Loading the validation dataset...")
validation_dataset = PrositDataset(tokenizer, "valid.hdf5", 100_000)
print("Loading the test dataset...")
test_dataset = PrositDataset(tokenizer, "test.hdf5", 100_000)

# The GPU memory on this instance is larger than the host, so
# we put data on the gpu to run fast. When no GPU is available the
# tensors are kept on the CPU instead of failing at this point.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
for dset in (train_dataset, validation_dataset, test_dataset):
    dset.tensors = tuple(data.to(DEVICE) for data in dset.tensors)

# Shuffle only the training data; evaluation uses larger, ordered batches.
train_loader = train_dataset.loader(
    batch_size=128, shuffle=True,
)
validation_loader = validation_dataset.loader(
    batch_size=1024, shuffle=False,
)

test_loader = test_dataset.loader(
    batch_size=1024, shuffle=False,
)

In [None]:
def masked_spectral_angle(true, pred):
    """Calculate the masked spectral angle between two sets of intensities.

    This is a PyTorch adaptation of the Prosit implementation here:
    https://github.com/kusterlab/prosit/blob/dd16c47f8c3f4cfbd7ae84a1ca92a4d117e5e9ec/prosit/losses.py#L4-L16

    Unlike the original, the cosine similarity is clamped to slightly inside
    [-1, 1] before taking the arccosine. Rounding can otherwise push it
    beyond 1 (a NaN result) or to exactly 1 (an infinite gradient).

    Parameters
    ----------
    true : torch.Tensor
        The observed intensities, with the batch as the first dimension.
        Entries equal to -1 mark fragments to ignore.
    pred : torch.Tensor
        The predicted intensities, with the same shape as ``true``.

    Returns
    -------
    torch.Tensor of shape (batch,)
        The normalized spectral angle for each example: close to 0 for
        spectra that point in the same direction and 1 for orthogonal ones.
    """
    true = true.flatten(start_dim=1)
    pred = pred.flatten(start_dim=1)
    epsilon = torch.finfo(torch.float32).eps

    # (true + 1) is exactly zero where true == -1, which zeroes both vectors
    # at masked positions. For true >= 0 the factor is approximately one.
    pred_masked = ((true + 1) * pred) / (true + 1 + epsilon)
    true_masked = ((true + 1) * true) / (true + 1 + epsilon)
    true_norm = F.normalize(true_masked, p=2, dim=-1)
    pred_norm = F.normalize(pred_masked, p=2, dim=-1)

    # Cosine similarity, kept strictly inside the domain of acos.
    product = torch.sum(pred_norm * true_norm, dim=1)
    product = product.clamp(-(1 - epsilon), 1 - epsilon)
    arccos = torch.acos(product)
    return 2 * arccos / np.pi


class FragmentPredictor(pl.LightningModule):
    """A Transformer model for fragment ion intensity prediction.

    The peptide and its precursor charge are embedded by a
    ``PeptideTransformerEncoder``, an encoding of the collision energy is
    added, and a feed-forward head produces the values that are reshaped
    into (ion series, charge) predictions for every position.
    """
    def __init__(self, tokenizer, d_model, n_layers):
        """Initialize the FragmentPredictor.

        Parameters
        ----------
        tokenizer : PeptideTokenizer
            Passed to the peptide encoder as ``n_tokens``.
        d_model : int
            Passed to the peptide encoder, the collision energy encoder,
            and the fragment head.
        n_layers : int
            Passed to the peptide encoder as its number of layers.
        """
        super().__init__()
        self.peptide_encoder = PeptideTransformerEncoder(
            n_tokens=tokenizer,
            d_model=d_model,
            n_layers=n_layers,
            max_charge=10,
        )
        self.nce_encoder = FloatEncoder(d_model, max_wavelength=1)
        # NOTE(review): presumably 6 is the output width (split below into
        # I=2 ion series x 3 charge states) and 3 the number of layers --
        # confirm against depthcharge's FeedForward signature.
        self.fragment_head = FeedForward(d_model, 6, 3)

    def step(self, batch, batch_idx):
        """A training/validation/inference step.

        Parameters
        ----------
        batch : tuple
            The ``(seqs, charges, nce, intensities)`` tensors of one batch.
        batch_idx : int
            The index of the batch; not used here.

        Returns
        -------
        pred : torch.Tensor
            The predicted intensities, arranged as (B, L, I, Z) with L
            padded to 30.
        loss : torch.Tensor or None
            The mean masked spectral angle, or ``None`` when the batch has
            no intensities.
        """
        seqs, charges, nce, intensities = batch
        # `mask` is returned by the encoder but is not used here.
        embedded, mask = self.peptide_encoder(seqs, charges)
        # A new axis is inserted so the collision energy encoding can be
        # added to every position of the peptide embedding.
        # (assumes the resulting shapes broadcast -- TODO confirm)
        emb_nce = self.nce_encoder(nce[:, None])
        pred = self.fragment_head(embedded + emb_nce) 

        # Reshape for the Prosit data:
        # Move the length axis last, zero-pad it at the end up to 30
        # positions, then restore the (B, L, I, Z) layout.
        pred = einops.rearrange(pred, "B L (I Z) -> B I Z L", I=2)
        pred = F.pad(pred, (0, 30 - pred.shape[-1]), "constant", 0)
        pred = einops.rearrange(pred, "B I Z L -> B L I Z")

        # Calculate the loss
        if intensities is not None:
            intensities = intensities.type_as(pred)
            loss = masked_spectral_angle(intensities, pred).mean()
        else:
            loss = None

        return pred, loss

    def training_step(self, batch, batch_idx):
        """A training step.

        Logs ``train_loss`` for every step and every epoch, and returns the
        loss.
        """
        _, loss = self.step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """A validation step.

        Logs ``validation_loss`` once per epoch and returns the loss.
        """
        _, loss = self.step(batch, batch_idx)
        self.log("validation_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """An inference step that returns only the predicted intensities."""
        pred, _ = self.step(batch, batch_idx)
        return pred

    def configure_optimizers(self):
        """Configure the optimizer for training: Adam with a 1e-4 learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [None]:
# Load (or reload) the TensorBoard notebook extension and open the dashboard
# on `lightning_logs/`, the directory name Lightning uses for logs by default.
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [None]:
# Create a model:
model = FragmentPredictor(tokenizer, d_model=64, n_layers=6)

# Stop training early once the logged "validation_loss" has failed to improve
# for 3 consecutive validation checks; otherwise train for at most 10 epochs.
early_stopping = EarlyStopping(monitor="validation_loss", patience=3)
trainer = pl.Trainer(
    # Uncomment to force training on the CPU: accelerator="cpu",
    callbacks=[early_stopping],
    max_epochs=10, 
)


# Fit the model, validating on the held-out validation loader.
trainer.fit(
    model=model, 
    train_dataloaders=train_loader, 
    val_dataloaders=validation_loader,
)