In [1]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from alphabase.spectral_library.base import SpecLibBase
from alphadia.transferlearning.train import *


In [2]:
# Display the fine-tuning hyper-parameters. `settings` is not defined in this
# notebook; presumably it comes from the star import of
# alphadia.transferlearning.train above — TODO confirm.
settings

{'batch_size': 1000,
 'max_lr': 0.0005,
 'train_ratio': 0.8,
 'test_interval': 1,
 'lr_patience': 3,
 'minimum_psms': 1200,
 'epochs': 51,
 'warmup_epochs': 5,
 'nce': 25,
 'instrument': 'Lumos'}

In [2]:
import torch

# Number of CPU threads PyTorch may use for intra-op parallelism.
# Adjust to the cores available on the machine running this notebook.
TORCH_NUM_THREADS = 10
torch.set_num_threads(TORCH_NUM_THREADS)

In [3]:
import os

# Path to the transfer-learning spectral library. Set the
# ALPHADIA_TRANSFER_LIB environment variable to point at a local copy instead
# of editing this machine-specific absolute path.
TRANSFER_LIB_PATH = os.environ.get(
    'ALPHADIA_TRANSFER_LIB',
    '/Users/georgwallmann/Documents/data/alphadia_manuscript/2024_04_25_Dimethyl_GPF/transfer_learning_asms/speclib.transfer.hdf',
)

transfer_lib = SpecLibBase()
transfer_lib.load_hdf(TRANSFER_LIB_PATH, load_mod_seq=True)

In [4]:
# Drop every precursor whose modification string contains 'Dimethyl@C'.
# regex=False makes this a literal substring match. na=False treats a missing
# 'mods' entry as "no match" (row kept); without it the mask would contain NaN
# and the `~` inversion would fail.
transfer_lib.precursor_df = transfer_lib.precursor_df[
    ~transfer_lib.precursor_df['mods'].str.contains('Dimethyl@C', regex=False, na=False)
]

# Utility function to plot the fine-tuning metrics

In [5]:
def plot_stats(stats_df: pd.DataFrame, loss_name: str, property: str, pre_train_dataset: str = 'all'):
    """
    Plot the metrics of the fine-tuning process.

    The first row shows the train and validation loss (the loss before
    fine-tuning is drawn as a red dot at epoch -1) next to the learning rate.
    Every other validation metric of ``property`` gets its own panel below.

    Parameters
    ----------
    stats_df : pd.DataFrame
        The metrics of the fine-tuning process in long format, with the columns
        'epoch', 'dataset', 'property', 'metric_name' and 'value'. Anything
        ``pd.DataFrame`` accepts (e.g. a list of records) works as well.
    loss_name : str
        The name of the loss function used, e.g. 'l1_loss'.
    property : str
        The property being predicted, e.g. 'rt'.
    pre_train_dataset : str
        The dataset whose earliest recorded loss is shown as the value before
        fine-tuning, e.g. 'all' for rt and charge, 'validation' for ms2.
    """
    df = pd.DataFrame(stats_df)

    # Wide layout: one row per epoch, one column per (dataset, property, metric_name).
    pivot_df = df.pivot_table(index='epoch', columns=['dataset', 'property', 'metric_name'], values='value')
    train_df = pivot_df['train'][property]
    validation_df = pivot_df['validation'][property]

    # Two fixed panels (loss, learning rate) plus one per remaining validation metric.
    # Only these panels are drawn, so the grid is sized from them alone.
    columns_to_plot = validation_df.columns.drop(loss_name)
    fig_col = 2
    num_panels = 2 + len(columns_to_plot)
    fig_row = int(np.ceil(num_panels / fig_col))
    # squeeze=False keeps `ax` two-dimensional even when a single row suffices.
    fig, ax = plt.subplots(fig_row, fig_col, figsize=(15, 5 * fig_row), squeeze=False)

    x_axis = pivot_df.index.values

    # Train and validation loss
    ax[0, 0].plot(x_axis, train_df[loss_name], label="Train")
    ax[0, 0].plot(x_axis, validation_df[loss_name], label="Validation")
    # After pivoting, epochs without a value for this dataset are NaN; use the
    # earliest recorded loss instead of blindly taking the first row.
    pre_train_loss = pivot_df[pre_train_dataset][property][loss_name].dropna()
    if not pre_train_loss.empty:
        ax[0, 0].scatter(-1, pre_train_loss.iloc[0], label="Before fine-tuning", color="red")
    ax[0, 0].set_title("Loss")
    ax[0, 0].set_xlabel("Epoch")
    ax[0, 0].set_ylabel("Loss")
    ax[0, 0].legend()

    # Learning rate schedule
    ax[0, 1].plot(x_axis, train_df["lr"])
    ax[0, 1].set_title("Learning rate")
    ax[0, 1].set_xlabel("Epoch")
    ax[0, 1].set_ylabel("Learning rate")

    # Remaining validation metrics, filling the grid row by row after the two fixed panels.
    for panel, column_name in enumerate(columns_to_plot, start=2):
        row, col = divmod(panel, fig_col)
        ax[row, col].plot(x_axis, validation_df[column_name])
        ax[row, col].set_title(column_name + " (Validation)")
        ax[row, col].set_xlabel("Epoch")
        ax[row, col].set_ylabel(column_name)

    # Switch off the spare axis left in the last row when the panel count is odd.
    for panel in range(num_panels, fig_row * fig_col):
        row, col = divmod(panel, fig_col)
        ax[row, col].axis('off')

    plt.tight_layout()
    plt.show()

## RT Fine-tuning


In [None]:

# Build the fine-tuning manager. The prediction conditions (nce, instrument)
# are read from `settings` so they are not duplicated as literals here.
tune_mgr = FinetuneManager(
    device="gpu",
    settings=settings)
tune_mgr.nce = settings['nce']
tune_mgr.instrument = settings['instrument']

# Baseline: RT prediction of the model before RT fine-tuning.
transfer_lib.precursor_df = tune_mgr.predict_rt(transfer_lib.precursor_df)
plt.scatter(transfer_lib.precursor_df['rt_norm'], transfer_lib.precursor_df['rt_norm_pred'], s=1, alpha=0.1)
plt.xlabel('RT observed')
plt.ylabel('RT predicted')
plt.title('Observed vs. predicted RT before fine-tuning')
plt.show()

In [None]:
rt_stats = tune_mgr.finetune_rt(transfer_lib.precursor_df)

# Re-predict with the fine-tuned model; presumably this replaces the baseline
# 'rt_norm_pred' values plotted in the previous cell.
transfer_lib.precursor_df = tune_mgr.predict_rt(transfer_lib.precursor_df)

plt.scatter(transfer_lib.precursor_df['rt_norm'], transfer_lib.precursor_df['rt_norm_pred'], s=0.1, alpha=0.1)
plt.xlabel('RT observed')
plt.ylabel('RT predicted')
plt.title('Observed vs. predicted RT after fine-tuning')
plt.show()


In [None]:
# Loss, learning rate and validation metrics of the RT fine-tuning run.
plot_stats(rt_stats, loss_name='l1_loss', property='rt')

## Charge Fine-tuning

In [None]:

# Fine-tune the charge model on the transfer library; the returned statistics
# are passed to plot_stats in the next cell.
charge_stats = tune_mgr.finetune_charge(
    psm_df=transfer_lib.precursor_df,
)

In [None]:
# Loss, learning rate and validation metrics of the charge fine-tuning run.
plot_stats(charge_stats, loss_name='ce_loss', property='charge')

## MS2 Fine-tuning

In [11]:
# Set to True to fine-tune (and evaluate) the MS2 model only on precursors
# flagged in the 'use_for_ms2' column, i.e. high-quality spectra.
MS2_HIGH_QUALITY_ONLY = False

if MS2_HIGH_QUALITY_ONLY:
    transfer_lib.precursor_df = transfer_lib.precursor_df[transfer_lib.precursor_df['use_for_ms2']]


In [12]:
def calculate_similarity(precursor_df_a, precursor_df_b, intensity_df_a, intensity_df_b, n_columns=4):
    """
    Compute the cosine similarity of fragment intensities between two libraries.

    Precursors are matched on ``precursor_idx``. For every match, the rows
    ``frag_start_idx:frag_stop_idx`` of the first ``n_columns`` columns of each
    intensity frame are flattened into a vector, and the two vectors are compared.

    Parameters
    ----------
    precursor_df_a, precursor_df_b : pd.DataFrame
        Precursor tables with the columns 'precursor_idx', 'frag_start_idx'
        and 'frag_stop_idx'. If a precursor_idx matches more than once, only
        the first match is kept.
    intensity_df_a, intensity_df_b : pd.DataFrame
        Fragment intensity tables, addressed positionally by the fragment
        start/stop indices of the corresponding precursor table.
    n_columns : int, default 4
        Number of leading intensity columns to compare. NOTE(review): assumes
        both intensity frames list the same fragment types in the same column
        order — TODO confirm.

    Returns
    -------
    pd.DataFrame
        One row per matched precursor with the columns 'similarity' (NaN when
        either vector has zero or undefined norm), 'index' (running position
        of the match) and 'precursor_idx'. When no precursor is shared, the
        frame is empty but still has these columns.
    """
    index_columns = ['precursor_idx', 'frag_start_idx', 'frag_stop_idx']
    merged_df = pd.merge(
        precursor_df_a[index_columns],
        precursor_df_b[index_columns],
        on='precursor_idx',
        suffixes=('_a', '_b'),
    ).drop_duplicates(subset='precursor_idx', keep='first')

    # Convert once instead of slicing a DataFrame with iloc for every precursor.
    intensity_a = intensity_df_a.iloc[:, :n_columns].to_numpy()
    intensity_b = intensity_df_b.iloc[:, :n_columns].to_numpy()

    records = []
    matches = zip(
        merged_df['precursor_idx'],
        merged_df['frag_start_idx_a'],
        merged_df['frag_stop_idx_a'],
        merged_df['frag_start_idx_b'],
        merged_df['frag_stop_idx_b'],
    )
    for i, (precursor_idx, start_a, stop_a, start_b, stop_b) in enumerate(matches):
        vector_a = intensity_a[start_a:stop_a].ravel()
        vector_b = intensity_b[start_b:stop_b].ravel()

        denominator = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
        # Cosine similarity is undefined for all-zero vectors; report NaN
        # explicitly rather than dividing 0 by 0.
        similarity = np.dot(vector_a, vector_b) / denominator if denominator > 0 else np.nan
        records.append({'similarity': similarity, 'index': i, 'precursor_idx': precursor_idx})

    # Explicit columns keep downstream `df['similarity']` working when nothing matched.
    return pd.DataFrame(records, columns=['similarity', 'index', 'precursor_idx'])

In [None]:
# Baseline: MS2 prediction before MS2 fine-tuning, compared with the observed
# fragment intensities of the transfer library. Names carry a `_before` suffix
# so they are not silently overwritten by the post-fine-tuning cell.
res_before = tune_mgr.predict_all(transfer_lib.precursor_df.copy(), predict_items=['ms2'])

precursor_before_df = res_before['precursor_df']
fragment_intensity_before_df = res_before['fragment_intensity_df']
similarity_before_df = calculate_similarity(precursor_before_df, transfer_lib.precursor_df, fragment_intensity_before_df, transfer_lib.fragment_intensity_df)
print(f"Median similarity before fine-tuning: {similarity_before_df['similarity'].median():.4f}")
plt.scatter(similarity_before_df['index'], similarity_before_df['similarity'], s=0.1)
plt.xlabel('Index')
plt.ylabel('Similarity')
plt.title('Similarity between observed and predicted MS2 spectra before fine-tuning')
plt.show()

In [None]:

# Fine-tune the MS2 model on the transfer library. Copies are passed, so any
# in-place change made by the call cannot reach transfer_lib's own tables.
ms2_stats = tune_mgr.finetune_ms2(
    psm_df=transfer_lib.precursor_df.copy(),
    matched_intensity_df=transfer_lib.fragment_intensity_df.copy(),
)

In [None]:
# MS2 prediction with the fine-tuned model, compared with the observed
# fragment intensities of the transfer library.
res = tune_mgr.predict_all(transfer_lib.precursor_df.copy(), predict_items=["ms2"])

precursor_after_df = res["precursor_df"]
fragment_mz_after_df = res["fragment_mz_df"]
fragment_intensity_after_df = res["fragment_intensity_df"]
similarity_after_df = calculate_similarity(
    precursor_after_df,
    transfer_lib.precursor_df,
    fragment_intensity_after_df,
    transfer_lib.fragment_intensity_df,
)
print(f"Median similarity after fine-tuning: {similarity_after_df['similarity'].median():.4f}")
plt.scatter(similarity_after_df["index"], similarity_after_df["similarity"], s=0.1)
plt.xlabel("Index")
plt.ylabel("Similarity")
plt.title("Similarity between observed and predicted MS2 spectra after fine-tuning")
plt.show()

In [None]:
# Loss, learning rate and validation metrics of the MS2 fine-tuning run; the
# pre-fine-tuning loss for MS2 comes from the validation set.
plot_stats(ms2_stats, loss_name='l1_loss', property='ms2', pre_train_dataset='validation')