In [None]:
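# Install dlomix from https://github.com/MarkusHaak/dlomix/ with pip, e.g.
#   pip install git+https://github.com/MarkusHaak/dlomix/
# OR uncomment the two lines below to put a local checkout on sys.path instead: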
#import os, sys
#sys.path.insert(0, os.path.abspath('../../dlomix/'))

In [None]:
# Set the global random seed once up front so that results are reproducible.
# The training loop further down calls set_global_seed(42) again before every run,
# so each cross-validation split starts from the same seed.
# NOTE(review): which RNGs this covers is decided inside dlomix.utils; confirm there.
from dlomix.utils import set_global_seed
set_global_seed(42)

In [None]:
import os
import tensorflow as tf
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import dlomix
from dlomix import constants, data, eval, layers, models, pipelines, reports, utils
from dlomix.data import RetentionTimeDataset
from dlomix.models import PrositRetentionTimePredictor
from dlomix.eval import TimeDeltaMetric
from dlomix.reports import RetentionTimeReport

In [None]:
# Vocabulary mapping each single-letter sequence token to its integer index.
# The 20 amino-acid letters come first, in alphabetical order, followed by the
# two PTM identifiers ("^" and "}") that are used in the created datasets.
# Indices are consecutive and start at 1.
ALPHABET_MOD = {
    token: index
    for index, token in enumerate("ACDEFGHIKLMNPQRSTVWY" + "^}", start=1)
}

# Train Baseline models

Train an individual PROSIT model for each cross-validation split, training for a fixed number of epochs and keeping the checkpoint with the best validation metric. These models serve as baselines for the subsequent uncertainty prediction.

In [None]:
# Train one baseline PROSIT retention-time model per (dataset, cross-validation split).
#
# For every run this cell:
#   1. loads the training, validation and holdout datasets,
#   2. trains a fresh model for a fixed number of epochs, checkpointing the weights
#      with the lowest `val_timedelta`,
#   3. writes training-curve plots, the holdout R2 score, density/residual plots and
#      the pickled training history into `output_<ds>/cv<cv>/`.
# Runs whose `history.pkl` already exists are skipped, so the cell can be re-executed
# to resume after an interruption.
import pickle  # hoisted: previously re-imported inside the innermost loop

BATCH_SIZE = 256
SEQ_LENGTH = 30  # sequence length passed to both the datasets and the models
SEQUENCE_COL = 'modified_sequence_single_letter'

# The holdout file is shared by all datasets and all splits.
TEST_DATAPATH = '../data/PROSPECT_median_holdout_cv.csv'

# Each entry: (dataset name, number of epochs, target column of the training file).
for ds, EPOCHS, tc in [('median', 60, 'median'), ('sel10', 10, 'indexed_retention_time')]:
    for cv in range(6):
        print("\n#####", ds, cv, '#####\n')

        out_dir = f'./output_{ds}/cv{cv}'
        history_path = os.path.join(out_dir, 'history.pkl')
        best_weights_path = os.path.join(out_dir, 'best')

        # history.pkl is the last artefact written, so its presence marks a finished run.
        if os.path.exists(history_path):
            continue
        set_global_seed(42)

        # Keyword arguments shared by the three dataset constructors below.
        dataset_kwargs = dict(seq_length=SEQ_LENGTH,
                              batch_size=BATCH_SIZE,
                              sequence_col=SEQUENCE_COL)

        TRAIN_DATAPATH = f'../data/PROSPECT_{ds}_training{cv}.csv'
        rtdata = RetentionTimeDataset(data_source=TRAIN_DATAPATH,
                                      val_ratio=0., test=False,
                                      target_col=tc,
                                      **dataset_kwargs)

        # Validation and holdout always use the "median" files and target,
        # regardless of which dataset is used for training.
        VALIDATION_DATAPATH = f'../data/PROSPECT_median_validation{cv}.csv'
        validation_rtdata = RetentionTimeDataset(data_source=VALIDATION_DATAPATH,
                                                 val_ratio=1., test=False,
                                                 target_col='median',
                                                 **dataset_kwargs)
        test_rtdata = RetentionTimeDataset(data_source=TEST_DATAPATH,
                                           test=True,
                                           target_col='median',
                                           **dataset_kwargs)
        test_targets = test_rtdata.get_split_targets(split="test")

        model = PrositRetentionTimePredictor(seq_length=SEQ_LENGTH, vocab_dict=ALPHABET_MOD)

        # Learning-rate schedule: quadratic decay from 1e-3 down to 1e-6.
        # NOTE(review): len(rtdata.train_data) is presumably the number of *batches*;
        # if so, multiplying by BATCH_SIZE counts samples rather than optimizer steps
        # and the rate decays far more slowly than EPOCHS suggests. Left unchanged
        # because altering it would change the baseline results; confirm the intent.
        train_steps = BATCH_SIZE * len(rtdata.train_data) * EPOCHS
        lr_fn = tf.optimizers.schedules.PolynomialDecay(initial_learning_rate=1e-3,
                                                        decay_steps=train_steps,
                                                        end_learning_rate=1e-6,
                                                        power=2)
        opt = tf.optimizers.Adam(lr_fn)

        # Fixed: this call was over-indented in the original, which made the whole
        # cell fail with an IndentationError.
        model.compile(optimizer=opt,
                      loss='mse',
                      metrics=['mean_absolute_error', TimeDeltaMetric()])

        # Keep only the weights of the epoch with the lowest validation time delta.
        model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=best_weights_path,
            save_weights_only=True,
            monitor='val_timedelta',
            mode='min',
            save_best_only=True)
        history = model.fit(rtdata.train_data,
                            validation_data=validation_rtdata.val_data,
                            callbacks=[model_checkpoint_callback],
                            epochs=EPOCHS,
                            verbose=2)

        report = RetentionTimeReport(output_path=out_dir, history=history)
        for metric_name in ("loss", "mean_absolute_error", "timedelta"):
            report.plot_keras_metric(metric_name)

        # Evaluate the best checkpoint (not the final epoch) on the holdout set.
        best_model = PrositRetentionTimePredictor(seq_length=SEQ_LENGTH, vocab_dict=ALPHABET_MOD)
        best_model.load_weights(best_weights_path)
        predictions = best_model.predict(test_rtdata.test_data)
        predictions = predictions.ravel()
        with open(os.path.join(out_dir, 'r2.txt'), 'w') as f:
            print(report.calculate_r2(test_targets, predictions), file=f)
        report.plot_density(test_targets, predictions)
        report.plot_residuals(test_targets, predictions, xrange=(-30, 30))

        # Written last on purpose: it doubles as the "run finished" marker checked above.
        # NOTE(review): this pickles the whole Keras History object; `history.history`
        # (a plain dict) would be lighter, but readers of this file may expect the object.
        with open(history_path, 'wb') as f:
            pickle.dump(history, f)