In [1]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

from peptdeep.pretrained_models import ModelManager

from alphabase.spectral_library.base import SpecLibBase
from alphadia import data, planning
from alphadia.workflow import manager, peptidecentric
from alphadia.tunning import settings, FinetuneManager


In [2]:
import torch

# Cap the number of CPU threads PyTorch may use; the value is a notebook-level
# tuning knob, named here so it is easy to find and change.
NUM_TORCH_THREADS = 10
torch.set_num_threads(NUM_TORCH_THREADS)

In [4]:
# Load a transfer spectral library from the `second_pass_score` output directory.
# NOTE(review): the following cell assigns `transfer_lib` again from a different
# path, so on a top-to-bottom run this load is superseded — confirm which
# library is the intended input and drop the other cell.
second_pass_lib_path = 'alphaDia/second_pass_score/second_pass_score/output/speclib.transfer.hdf'
transfer_lib = SpecLibBase()
transfer_lib.load_hdf(second_pass_lib_path, load_mod_seq=True)

In [3]:
# Load a transfer spectral library from the `d0_search` output directory.
# NOTE(review): this rebinds `transfer_lib`, replacing the library loaded in the
# previous cell; every later cell therefore sees this one when the notebook is
# run top to bottom — confirm that is intended.
d0_search_lib_path = 'alphaDia/d0_search/output/speclib.transfer.hdf'
transfer_lib = SpecLibBase()
transfer_lib.load_hdf(d0_search_lib_path, load_mod_seq=True)

## RT Fine-tuning


In [None]:

# Build the fine-tuning manager used by every section below.
# NOTE(review): device is hard-coded to "mps"; confirm how FinetuneManager
# behaves on machines without that backend.
tune_mgr = FinetuneManager(device="mps", settings=settings)
tune_mgr.nce = 25  # presumably the collision-energy setting — TODO confirm
tune_mgr.instrument = 'Lumos'

# Predict RT before any fine-tuning; the result replaces the precursor table
# on the library (the scatter below reads its `rt_norm_pred` column).
transfer_lib.precursor_df = tune_mgr.predict_rt(transfer_lib.precursor_df)

# Observed vs. predicted RT as a baseline for the fine-tuning step.
fig, ax = plt.subplots()
ax.scatter(
    transfer_lib.precursor_df['rt_norm'],
    transfer_lib.precursor_df['rt_norm_pred'],
    s=1,
    alpha=0.1,
)
ax.set_xlabel('RT observed')
ax.set_ylabel('RT predicted')

In [None]:
# Fine-tune the RT model on the library precursors; `stats` keeps whatever
# `finetune_rt` returns so the next cell can plot it per epoch.
stats = tune_mgr.finetune_rt(transfer_lib.precursor_df)

# Predict RT again with the fine-tuned model and store the result back on
# the library.
transfer_lib.precursor_df = tune_mgr.predict_rt(transfer_lib.precursor_df)

# Observed vs. predicted RT after fine-tuning.
fig, ax = plt.subplots()
ax.scatter(
    transfer_lib.precursor_df['rt_norm'],
    transfer_lib.precursor_df['rt_norm_pred'],
    s=0.1,
    alpha=0.1,
)
ax.set_xlabel('RT observed')
ax.set_ylabel('RT predicted')


In [None]:

# Plot the per-epoch statistics in `stats` (set by the RT fine-tuning cell).
# Panel 0 overlays train and test loss; every other column of `stats` except
# "epoch" gets a panel of its own, filled row by row.
fig_col = 2
columns_to_plot = stats.columns.drop(["epoch", "train_loss", "test_loss"])
num_plots = len(columns_to_plot)
num_panels = num_plots + 1  # +1 for the combined loss panel
fig_row = int(np.ceil(num_panels / fig_col))

# squeeze=False keeps `ax` two-dimensional even when only one row is needed,
# so `ax[row, col]` indexing works when `stats` has no extra columns.
fig, ax = plt.subplots(
    fig_row, fig_col, figsize=(15, 5 * fig_row), squeeze=False
)

x_axis = stats["epoch"]

# Train and test loss share the first panel.
loss_ax = ax[0, 0]
loss_ax.plot(x_axis, stats["train_loss"], label="train")
loss_ax.plot(x_axis, stats["test_loss"], label="test")
loss_ax.set_title("Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

# Remaining columns: panel index starts at 1 because panel 0 is the loss plot.
for panel_idx, column_name in enumerate(columns_to_plot, start=1):
    row, col = divmod(panel_idx, fig_col)
    ax[row, col].plot(x_axis, stats[column_name])
    ax[row, col].set_title(column_name)
    ax[row, col].set_xlabel("Epoch")
    ax[row, col].set_ylabel(column_name)

# Hide any trailing panel that received no data instead of showing empty axes.
for unused_ax in ax.ravel()[num_panels:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

## Charge Fine-tuning

In [None]:

# Fine-tune the charge model on the transfer-library precursors; `stats` is
# rebound to the value returned for this run and plotted in the next cell.
stats = tune_mgr.finetune_charge(
    psm_df=transfer_lib.precursor_df,
)

In [None]:

# Plot the per-epoch statistics in `stats` (set by the charge fine-tuning cell).
# Panel 0 overlays train and test loss; every other column of `stats` except
# "epoch" gets a panel of its own, filled row by row.
fig_col = 2
columns_to_plot = stats.columns.drop(["epoch", "train_loss", "test_loss"])
num_plots = len(columns_to_plot)
num_panels = num_plots + 1  # +1 for the combined loss panel
fig_row = int(np.ceil(num_panels / fig_col))

# squeeze=False keeps `ax` two-dimensional even when only one row is needed,
# so `ax[row, col]` indexing works when `stats` has no extra columns.
fig, ax = plt.subplots(
    fig_row, fig_col, figsize=(15, 5 * fig_row), squeeze=False
)

x_axis = stats["epoch"]

# Train and test loss share the first panel.
loss_ax = ax[0, 0]
loss_ax.plot(x_axis, stats["train_loss"], label="train")
loss_ax.plot(x_axis, stats["test_loss"], label="test")
loss_ax.set_title("Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

# Remaining columns: panel index starts at 1 because panel 0 is the loss plot.
for panel_idx, column_name in enumerate(columns_to_plot, start=1):
    row, col = divmod(panel_idx, fig_col)
    ax[row, col].plot(x_axis, stats[column_name])
    ax[row, col].set_title(column_name)
    ax[row, col].set_xlabel("Epoch")
    ax[row, col].set_ylabel(column_name)

# Hide any trailing panel that received no data instead of showing empty axes.
for unused_ax in ax.ravel()[num_panels:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

## MS2 Fine-tuning

In [None]:

# Fine-tune the MS2 model using the library precursors together with the
# library's fragment intensity table; `stats` is rebound to the value returned
# for this run and plotted in the next cell.
stats = tune_mgr.finetune_ms2(
    psm_df=transfer_lib.precursor_df,
    matched_intensity_df=transfer_lib.fragment_intensity_df,
)

In [None]:

# Plot the per-epoch statistics in `stats` (set by the MS2 fine-tuning cell).
# Panel 0 overlays train and test loss; every other column of `stats` except
# "epoch" gets a panel of its own, filled row by row.
fig_col = 2
columns_to_plot = stats.columns.drop(["epoch", "train_loss", "test_loss"])
num_plots = len(columns_to_plot)
num_panels = num_plots + 1  # +1 for the combined loss panel
fig_row = int(np.ceil(num_panels / fig_col))

# squeeze=False keeps `ax` two-dimensional even when only one row is needed,
# so `ax[row, col]` indexing works when `stats` has no extra columns.
fig, ax = plt.subplots(
    fig_row, fig_col, figsize=(15, 5 * fig_row), squeeze=False
)

x_axis = stats["epoch"]

# Train and test loss share the first panel.
loss_ax = ax[0, 0]
loss_ax.plot(x_axis, stats["train_loss"], label="train")
loss_ax.plot(x_axis, stats["test_loss"], label="test")
loss_ax.set_title("Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

# Remaining columns: panel index starts at 1 because panel 0 is the loss plot.
for panel_idx, column_name in enumerate(columns_to_plot, start=1):
    row, col = divmod(panel_idx, fig_col)
    ax[row, col].plot(x_axis, stats[column_name])
    ax[row, col].set_title(column_name)
    ax[row, col].set_xlabel("Epoch")
    ax[row, col].set_ylabel(column_name)

# Hide any trailing panel that received no data instead of showing empty axes.
for unused_ax in ax.ravel()[num_panels:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()