In [1]:
#default_exp pretrained_models

In [2]:
#hide
# A notebook kernel does not define `__file__`; set it here so that code
# resolving paths relative to the module file (via `os.path.abspath(__file__)`)
# also runs inside the notebook. Presumably this is the path of the module
# exported from this notebook -- confirm if the package layout changes.
__file__ = '../alphadeep/pretrained_models.py'

In [3]:
#export
import os
import io
import wget
import pandas as pd
from zipfile import ZipFile
from typing import Tuple

# Downloaded models are kept in a `sandbox` folder located next to this module.
sandbox_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'sandbox'
)
if not os.path.exists(sandbox_dir):
    os.makedirs(sandbox_dir)

# Path of the model archive inside the sandbox folder.
model_zip = os.path.join(sandbox_dir, 'released_models/alphadeep_models.zip')

def download_models(
    url='https://datashare.biochem.mpg.de/s/ABnQuD2KkXfIGF3/download',
    overwrite=True
):
    """Download the released model bundle and unpack `alphadeep_models.zip`.

    The bundle at `url` is saved as `released_models.zip` inside `sandbox_dir`,
    the member `released_models/alphadeep_models.zip` is extracted into
    `sandbox_dir` (i.e. to `model_zip`), and the downloaded bundle is removed.

    Args:
        url (str, optional): Download link of the released model bundle.
        overwrite (bool, optional): If False and `model_zip` already exists,
            nothing is downloaded. If True the existing archive is replaced.
            Defaults to True.
    """
    if os.path.exists(model_zip) and not overwrite:
        return

    downloaded_zip = os.path.join(
        sandbox_dir,'released_models.zip'
    )
    # `wget.download` does not overwrite an existing target: it saves under a
    # renamed path instead. Remove leftovers of an aborted run so that we never
    # unpack a stale (possibly truncated) bundle.
    if os.path.exists(downloaded_zip):
        os.remove(downloaded_zip)

    print('[Start] Downloading alphadeep_models.zip ...')
    # Use the path reported by wget in case it still had to pick another name.
    downloaded_zip = wget.download(url, downloaded_zip)
    try:
        # The previous `model_zip` is intentionally not deleted beforehand:
        # `extract` overwrites it, so a failed download keeps the old models.
        with ZipFile(downloaded_zip) as _zip:
            _zip.extract('released_models/alphadeep_models.zip', sandbox_dir)
    finally:
        # Always clean up the temporary bundle, even if extraction failed.
        os.remove(downloaded_zip)
    print('[Done] Downloading alphadeep_models.zip')

# Executed when this cell/module is run: fetch the released models once
# if the model archive is not on disk yet.
if not os.path.exists(model_zip):
    download_models()

In [4]:
#export
from alphadeep.model.ms2 import (
    pDeepModel, normalize_training_intensities
)
from alphadeep.model.rt import AlphaRTModel, uniform_sampling
from alphadeep.model.ccs import AlphaCCSModel

def load_phos_models(mask_phos_modloss=False):
    """Build the MS2, RT and CCS models and load the phospho weights.

    The MS2 and RT weights are read from the `phospho/` folder inside
    `model_zip`; the CCS weights are read from `regular/ccs.pth`.

    Args:
        mask_phos_modloss (bool, optional): Forwarded to `pDeepModel` as
            `mask_modloss`. Defaults to False.

    Returns:
        tuple: (ms2_model, rt_model, ccs_model)
    """
    # (factory, weights path inside the archive), in the order they are returned
    recipes = (
        (
            lambda: pDeepModel(mask_modloss=mask_phos_modloss),
            'phospho/ms2_phos.pth',
        ),
        (AlphaRTModel, 'phospho/irt_phos.pth'),
        (AlphaCCSModel, 'regular/ccs.pth'),
    )
    loaded = []
    for build, weights_path in recipes:
        model = build()
        model.load(model_zip, model_path_in_zip=weights_path)
        loaded.append(model)
    return tuple(loaded)

def load_models():
    """Build the MS2, RT and CCS models and load the regular weights.

    All three weight files are read from the `regular/` folder inside
    `model_zip`.

    Returns:
        tuple: (ms2_model, rt_model, ccs_model)
    """
    # (model class, weights path inside the archive), in the order returned
    recipes = (
        (pDeepModel, 'regular/ms2.pth'),
        (AlphaRTModel, 'regular/rt.pth'),
        (AlphaCCSModel, 'regular/ccs.pth'),
    )
    loaded = []
    for model_cls, weights_path in recipes:
        model = model_cls()
        model.load(model_zip, model_path_in_zip=weights_path)
        loaded.append(model)
    return tuple(loaded)

In [5]:
#export
class AlphaDeepModels(object):
    """Container for the MS2, RT and CCS models with loading and fine-tuning helpers.

    Attributes set in `__init__`:
        ms2_model, rt_model, ccs_model: The wrapped models; None until one of
            the `load_*` methods has been called.
        grid_nce_search (bool): Whether `fine_tune_ms2_model` always runs the
            MS2 model's NCE/instrument grid search. Taken from the
            `grid_nce_search` keyword argument, defaults to True.
        n_ms2_tune (int): Number of PSMs sampled for MS2 fine-tuning.
        epoch_ms2_tune (int): Epochs used for MS2 fine-tuning.
        n_rt_ccs_tune (int): `n_train` passed to `uniform_sampling` for
            RT and CCS fine-tuning.
        epoch_rt_ccs_tune (int): Epochs used for RT and CCS fine-tuning.
    """
    def __init__(self, **kwargs):
        self.ms2_model:pDeepModel = None
        self.rt_model:AlphaRTModel = None
        self.ccs_model:AlphaCCSModel = None

        self.grid_nce_search = kwargs.get('grid_nce_search', True)

        self.n_ms2_tune = 3000
        self.epoch_ms2_tune = 5
        self.n_rt_ccs_tune = 1000
        self.epoch_rt_ccs_tune = 10

    def load_installed_models(self, phospho=False, mask_modloss=True):
        """ Load built-in MS2/CCS/RT models.
        Args:
            phospho (bool, optional): If True, load the phospho models via
                `load_phos_models`; otherwise load the regular models via
                `load_models`. Defaults to False.
            mask_modloss (bool, optional): If modloss ions are masked to zeros
                in the ms2 model. `modloss` ions are mostly
                useful for phospho MS2 prediction model. It is only forwarded
                when `phospho` is True, because `load_models` takes no
                arguments. Defaults to True.
        """
        if phospho:
            models = load_phos_models(mask_modloss)
        else:
            # Regular models: previously this branch also loaded the phospho
            # models, which made the `phospho` flag a no-op.
            models = load_models()
        self.ms2_model, self.rt_model, self.ccs_model = models

    def load_external_models(self, 
        ms2_model_file: Tuple[str, io.BytesIO], 
        rt_model_file: Tuple[str, io.BytesIO],
        ccs_model_file: Tuple[str, io.BytesIO],
        mask_modloss=True
    ):
        """Load external MS2/RT/CCS models 

        Args:
            ms2_model_file (Tuple[str, io.BytesIO]): ms2 model file or stream
            rt_model_file (Tuple[str, io.BytesIO]): rt model file or stream
            ccs_model_file (Tuple[str, io.BytesIO]): ccs model or stream
            mask_modloss (bool, optional): If modloss ions are masked to zeros
                in the ms2 model. Defaults to True.
        """
        self.ms2_model = pDeepModel(mask_modloss=mask_modloss)
        self.ms2_model.load(ms2_model_file)
        self.rt_model = AlphaRTModel()
        self.rt_model.load(rt_model_file)
        self.ccs_model = AlphaCCSModel()
        self.ccs_model.load(ccs_model_file)

    def fine_tune_rt_model(self,
        psm_df:pd.DataFrame
    ):
        """Fine-tune the RT model on PSMs sampled with `uniform_sampling`.

        Does nothing when `n_rt_ccs_tune` is not positive.

        Args:
            psm_df (pd.DataFrame): PSMs to sample the training set from;
                sampling uses `target='rt_norm'`.
        """
        if self.n_rt_ccs_tune > 0:
            tr_df = uniform_sampling(
                psm_df, target='rt_norm',
                n_train=self.n_rt_ccs_tune, 
                return_test_df=False
            )
            self.rt_model.train(tr_df, 
                epoch=self.epoch_rt_ccs_tune
            )

    def fine_tune_ccs_model(self,
        psm_df:pd.DataFrame,
    ):
        """Fine-tune the CCS model on PSMs sampled with `uniform_sampling`.

        Does nothing when `n_rt_ccs_tune` is not positive.

        Args:
            psm_df (pd.DataFrame): PSMs to sample the training set from;
                sampling uses `target='ccs'`.
        """
        # Guard on the RT/CCS sample size (the value actually passed as
        # `n_train`), not on the MS2 one.
        if self.n_rt_ccs_tune > 0:
            tr_df = uniform_sampling(
                psm_df, target='ccs',
                n_train=self.n_rt_ccs_tune, 
                return_test_df=False
            )
            self.ccs_model.train(tr_df, 
                epoch=self.epoch_rt_ccs_tune
            )

    def fine_tune_ms2_model(self,
        psm_df: pd.DataFrame,
        matched_intensity_df: pd.DataFrame
    ):
        """Fine-tune the MS2 model on a random sample of PSMs.

        At most `n_ms2_tune` PSMs are sampled (all of them if `psm_df` is
        smaller). Nothing happens when `n_ms2_tune` is not positive or
        `psm_df` is empty.

        If `grid_nce_search` is set, or `psm_df` lacks an `nce` or `instrument`
        column, the MS2 model's `grid_nce_search` result is written to the
        `nce` and `instrument` columns of both the training sample and
        `psm_df` itself (i.e. `psm_df` is modified in place).

        Args:
            psm_df (pd.DataFrame): PSMs to sample the training set from.
            matched_intensity_df (pd.DataFrame): Matched fragment intensities,
                passed to `normalize_training_intensities` together with the
                sampled PSMs.
        """
        # `DataFrame.sample(n)` raises if n exceeds the number of rows.
        n_train = min(self.n_ms2_tune, len(psm_df))
        if n_train <= 0:
            return

        tr_df = psm_df.sample(n_train).copy()
        tr_df, frag_df = normalize_training_intensities(
            tr_df, matched_intensity_df
        )
        # Keep exactly the fragment types of the model, in model order; types
        # missing from `frag_df` become all-zero columns. Doing this in one
        # step keeps the zeros even when the first type is a missing one.
        tr_inten_df = frag_df.reindex(
            columns=list(self.ms2_model.charged_frag_types),
            fill_value=0
        )

        if (
            self.grid_nce_search 
            or 'nce' not in psm_df.columns
            or 'instrument' not in psm_df.columns
        ):
            nce, instrument = self.ms2_model.grid_nce_search(
                tr_df, tr_inten_df
            )
            tr_df['nce'] = nce
            tr_df['instrument'] = instrument
            psm_df['nce'] = nce
            psm_df['instrument'] = instrument

        self.ms2_model.train(tr_df, 
            fragment_inten_df=tr_inten_df,
            epoch=self.epoch_ms2_tune
        )


In [6]:
#hide
# Smoke test: construct the model container and load the built-in models.
model = AlphaDeepModels()
model.load_installed_models()