In [11]:
#default_exp pretrained_models

In [12]:
#hide
# Point `__file__` at the installed peptdeep.pretrained_models module so that
# path computations based on `__file__` in later cells resolve to the package
# folder when this notebook is executed interactively.
import peptdeep.pretrained_models
__file__ = peptdeep.pretrained_models.__file__

`peptdeep.pretrained_models` handles the pretrained models, including downloading, installing, and loading the models.

## 1. Downloading and installing the models
For continuous model deployment, we uploaded several pretrained models (compressed as a ZIP file) onto a net disk. peptdeep will automatically download the ZIP file as `sandbox/installed_models/pretrained_models.zip` when importing peptdeep.pretrained_models. The models will be downloaded only once. To update them to the latest version, call `download_models(overwrite=True)`.

In [13]:
#export
import os
import io
import pandas as pd
from zipfile import ZipFile
from tarfile import TarFile
from typing import Tuple
import torch
import urllib
import socket
import logging
import shutil
from pickle import UnpicklingError

from peptdeep.settings import global_settings

# Folder named "installed_models" located next to this module; the installed
# model archive is stored here.
sandbox_dir = os.path.join(
    os.path.dirname(
        os.path.abspath(__file__)
    ),
    "installed_models"
)

# Create the folder at import time if it does not exist yet.
if not os.path.exists(sandbox_dir):
    os.makedirs(sandbox_dir)

# Names read from the global settings:
#   model_name   - file name of the installed model archive inside sandbox_dir
#   model_url    - default source passed to download_models()
#   url_zip_name - base name of the downloaded wrapper archive; also used as
#                  the folder name inside that archive
model_name = global_settings['local_model_zip_name']
model_url = global_settings['model_url']
url_zip_name = global_settings['model_url_zip_name']

# Full path of the installed model archive.
model_zip = os.path.join(
    sandbox_dir, model_name
)

def is_model_zip(downloaded_zip):
    """Check whether a file is the model zip itself.

    The model zip is recognised by containing the entry 'regular/ms2.pth'.

    Args:
        downloaded_zip (str or file-like): path (or stream) of the file to check.

    Returns:
        bool: True if the file is a zip archive containing 'regular/ms2.pth'.
          False if it is a zip archive without that entry, or if it is not
          a zip archive at all (e.g. a tar file).

    Raises:
        FileNotFoundError: If `downloaded_zip` is a path that does not exist.
    """
    # Local import: the file only imports `ZipFile` from zipfile at module level.
    from zipfile import BadZipFile
    try:
        with ZipFile(downloaded_zip) as archive:
            return 'regular/ms2.pth' in archive.namelist()
    except BadZipFile:
        # Not a zip archive, hence not a model zip. Answer the question
        # instead of propagating the parser error.
        return False

def download_models(
    url:str=model_url, overwrite=True
):
    """Download (or take from a local file) the model archive and install it.

    If `url` is an existing local file, it is installed directly and left
    untouched on disk. Otherwise `url` is downloaded into a temporary file
    `{url_zip_name}.zip` inside `sandbox_dir`, installed with
    `install_models()`, and the temporary file is removed afterwards.

    Args:
        url (str, optional): remote or local path. 
          Defaults to peptdeep.pretrained_models.model_url.
        overwrite (bool, optional): overwrite old model files. If False and 
          a model archive is already installed, nothing is downloaded.
          Defaults to True.

    Raises:
        FileNotFoundError: If remote url is not accessible.
    """
    if os.path.isfile(url):
        # Local archive supplied by the user: install it, but never delete it.
        install_models(url, overwrite=overwrite)
        return

    # Keep the currently installed model until a new one has actually been
    # downloaded; install_models() replaces it when `overwrite` is True.
    if os.path.exists(model_zip) and not overwrite:
        return

    # Import the submodules explicitly: a bare `import urllib` does not
    # guarantee that `urllib.request` and `urllib.error` are loaded.
    import ssl
    import urllib.request
    import urllib.error

    downloaded_zip = os.path.join(
        sandbox_dir,f'{url_zip_name}.zip'
    )

    logging.info(f'Downloading {model_name} ...')
    try:
        # NOTE(security): certificate verification is disabled here, so the
        # download is not protected against tampering, and the downloaded
        # models are later deserialized. Kept as in the original behavior;
        # consider ssl.create_default_context() if the host allows it.
        context = ssl._create_unverified_context()
        with urllib.request.urlopen(
            url, context=context, timeout=10
        ) as response, open(downloaded_zip, 'wb') as f:
            # Stream to disk instead of holding the whole archive in memory.
            shutil.copyfileobj(response, f)
    except (
        socket.timeout, 
        urllib.error.URLError, 
        urllib.error.HTTPError
    ) as e:
        # Do not leave a partially written download behind.
        if os.path.exists(downloaded_zip):
            os.remove(downloaded_zip)
        raise FileNotFoundError(
            'Downloading model failed! Please download the '
            f'zip or tar file by yourself from "{url}",'
            ' and use \n'
            f'"peptdeep --install-model /path/to/{url_zip_name}.tar (or .zip)"\n'
            ' to install the models'
        ) from e

    try:
        install_models(downloaded_zip, overwrite=overwrite)
    finally:
        # Only the temporary download is removed, also when installing fails.
        if os.path.exists(downloaded_zip):
            os.remove(downloaded_zip)

def install_models(downloaded_zip:str, overwrite=True):
    """ Install the model archive as the local model zip (`model_zip`).

    Note that if the `downloaded_zip` is downloaded using download_models(), 
    it is a zip file; if it is downloaded using a browser, it will be a tar file.

    Accepted inputs:

    - a tar archive containing the member `{url_zip_name}/{model_name}`,
    - a zip archive containing the member `{url_zip_name}/{model_name}`,
    - the model zip itself (it contains 'regular/ms2.pth'), copied as is.

    The new archive is first written to a temporary file next to `model_zip`
    and only then moved into place, so an existing installation survives a
    failed install.

    Args:
        downloaded_zip (str): path to the downloaded file
        overwrite (bool, optional): Overwrite the existing model. 
          Defaults to True.

    Raises:
        FileNotFoundError: If `downloaded_zip` does not exist.
        KeyError: If the archive does not contain `{url_zip_name}/{model_name}`.
        zipfile.BadZipFile: If the file is neither a tar nor a zip archive.
    """
    # Local imports: only `ZipFile` and `TarFile` are imported at file level.
    import tarfile
    from zipfile import BadZipFile, is_zipfile

    if os.path.exists(model_zip) and not overwrite:
        return

    inner_path = f'{url_zip_name}/{model_name}'
    staged_zip = f'{model_zip}.part'
    try:
        # Tar is tested first. NOTE(review): zipfile locates archives from the
        # end of the file, so a tar wrapping a zip can presumably also be
        # opened by ZipFile; checking tar first avoids that ambiguity.
        if tarfile.is_tarfile(downloaded_zip):
            with tarfile.open(downloaded_zip) as tar:
                src = tar.extractfile(inner_path)
                if src is None:
                    # extractfile() returns None for non-regular members.
                    raise KeyError(
                        f'"{inner_path}" is not a regular file in "{downloaded_zip}"'
                    )
                with src, open(staged_zip, 'wb') as dst:
                    shutil.copyfileobj(src, dst)
        elif is_zipfile(downloaded_zip):
            if is_model_zip(downloaded_zip):
                # Already the model zip itself: nothing to unpack.
                shutil.copy(downloaded_zip, staged_zip)
            else:
                with ZipFile(downloaded_zip) as wrapper:
                    with wrapper.open(inner_path) as src, open(
                        staged_zip, 'wb'
                    ) as dst:
                        shutil.copyfileobj(src, dst)
        else:
            raise BadZipFile(
                f'"{downloaded_zip}" is neither a tar nor a zip archive'
            )
        # Replace the old model only after the new one is complete.
        os.replace(staged_zip, model_zip)
    finally:
        # Remove the temporary file if anything above failed.
        if os.path.exists(staged_zip):
            os.remove(staged_zip)
    logging.info(f'Installed {model_name}')

In [14]:
#export
# Runs at import time: fetch and install the pretrained models only when no
# model archive is installed yet. download_models() raises FileNotFoundError
# if the download fails.
if not os.path.exists(model_zip):
    download_models()

In [15]:
#hide
# Sanity check: the installed archive must contain 'regular/ms2.pth'.
assert is_model_zip(model_zip)

## 2. Loading the models
peptdeep provides convenient APIs to load models from ZIP files.

`load_models()` will load the regular models for unmodified peptides, and `load_phos_models()` will load the phospho models. Note that the CCS/mobility prediction model is shared by the regular and phospho models because it was trained on both regular and phospho peptides.

In [16]:
#export
from peptdeep.model.ms2 import (
    pDeepModel, normalize_training_intensities
)
from peptdeep.model.rt import AlphaRTModel
from peptdeep.model.ccs import AlphaCCSModel
from peptdeep.utils import uniform_sampling

from peptdeep.settings import global_settings
# Shortcut to the 'model_mgr' section of the global settings; used below for
# default argument values and fine-tuning/prediction parameters.
mgr_settings = global_settings['model_mgr']

def count_mods(psm_df)->pd.DataFrame:
    """Count in how many PSMs each modification occurs.

    Only rows with a non-empty `mods` string are considered; `mods` is split
    on ';' and every distinct modification of a PSM is counted once for that
    PSM. Names of the form 'XXX->YYYYY' (3 characters, '->', 5 characters)
    are pooled under the single key 'mutation', which is always present.

    Args:
        psm_df (pd.DataFrame): dataframe with a string column `mods`.

    Returns:
        pd.DataFrame: columns 'mod' and 'spec_count', sorted by 'spec_count'
          in descending order with a fresh 0..n-1 index.
    """
    def _looks_like_mutation(mod_name):
        parts = mod_name.split('->')
        return (
            len(parts)==2
            and len(parts[0])==3
            and len(parts[1])==5
        )

    # 'mutation' is inserted first so that it is always reported.
    spec_counts = {'mutation': 0}
    modified_psms = psm_df[psm_df.mods.str.len()>0]
    for mods_str in modified_psms.mods.values:
        for mod_name in set(mods_str.split(';')):
            key = 'mutation' if _looks_like_mutation(mod_name) else mod_name
            spec_counts[key] = spec_counts.get(key, 0) + 1

    count_df = pd.DataFrame.from_dict(
        {key: {'spec_count': num} for key, num in spec_counts.items()},
        orient='index'
    )
    count_df = count_df.reset_index(drop=False)
    count_df = count_df.rename(columns={'index':'mod'})
    count_df = count_df.sort_values('spec_count',ascending=False)
    return count_df.reset_index(drop=True)

def psm_sampling_with_important_mods(
    psm_df, n_sample, 
    top_n_mods = 10,
    n_sample_each_mod = 0, 
    uniform_sampling_column = None,
    random_state=1337,
):
    """Sample PSMs, optionally adding extra samples for the most frequent mods.

    First `n_sample` PSMs are drawn from `psm_df`. If `n_sample_each_mod` is
    positive, the modifications reported by `count_mods()` (excluding
    'mutation') are cut to the first `top_n_mods` entries, and for each of
    them `n_sample_each_mod` PSMs whose `mods` string contains the
    modification name are drawn as well. All draws are concatenated, so a PSM
    can appear more than once.

    Args:
        psm_df (pd.DataFrame): PSM dataframe with a `mods` column.
        n_sample (int): number of PSMs for the general draw.
        top_n_mods (int, optional): number of modifications to add samples for.
          Defaults to 10.
        n_sample_each_mod (int, optional): PSMs to draw per modification;
          0 disables the per-modification draws. Defaults to 0.
        uniform_sampling_column (str, optional): if given, draws are delegated
          to `uniform_sampling()` with this column as `target`; otherwise
          plain random sampling without replacement is used. Defaults to None.
        random_state (int, optional): seed passed to the sampler.
          Defaults to 1337.

    Returns:
        pd.DataFrame: concatenated samples with a fresh index.
    """
    def _draw(frame, size):
        if uniform_sampling_column is not None:
            return uniform_sampling(
                frame, target=uniform_sampling_column,
                n_train = size, random_state=random_state
            )
        if size >= len(frame):
            # Not enough rows to sample from: take them all.
            return frame.copy()
        return frame.sample(
            size, replace=False,
            random_state=random_state
        ).copy()

    sampled_parts = [_draw(psm_df, n_sample)]

    if n_sample_each_mod > 0:
        mod_table = count_mods(psm_df)
        mod_table = mod_table[mod_table['mod']!='mutation']
        if len(mod_table) > top_n_mods:
            mod_table = mod_table.iloc[:top_n_mods,:]
        sampled_parts.extend(
            _draw(
                psm_df[psm_df.mods.str.contains(mod_name, regex=False)],
                n_sample_each_mod,
            )
            for mod_name in mod_table['mod'].values
        )

    if not sampled_parts:
        return pd.DataFrame()
    return pd.concat(sampled_parts, ignore_index=True)

def load_phos_models(mask_modloss=True):
    """Create MS2/RT/CCS models and load the phospho weights from `model_zip`.

    The MS2 and RT weights come from the 'phospho/' folder of the archive,
    the CCS weights from 'regular/ccs.pth'.

    Args:
        mask_modloss (bool, optional): passed to `pDeepModel`. Defaults to True.

    Returns:
        tuple: (ms2_model, rt_model, ccs_model)
    """
    # (constructor, path inside the archive); each model is built and loaded
    # before the next one is created.
    recipes = (
        (lambda: pDeepModel(mask_modloss=mask_modloss), 'phospho/ms2_phos.pth'),
        (AlphaRTModel, 'phospho/rt_phos.pth'),
        (AlphaCCSModel, 'regular/ccs.pth'),
    )
    loaded = []
    for build_model, path_in_zip in recipes:
        model = build_model()
        model.load(model_zip, model_path_in_zip=path_in_zip)
        loaded.append(model)
    return tuple(loaded)

def load_models(mask_modloss=True):
    """Create MS2/RT/CCS models and load the regular weights from `model_zip`.

    All three weight files are read from the 'regular/' folder of the archive.

    Args:
        mask_modloss (bool, optional): passed to `pDeepModel`. Defaults to True.

    Returns:
        tuple: (ms2_model, rt_model, ccs_model)
    """
    # (constructor, path inside the archive); each model is built and loaded
    # before the next one is created.
    recipes = (
        (lambda: pDeepModel(mask_modloss=mask_modloss), 'regular/ms2.pth'),
        (AlphaRTModel, 'regular/rt.pth'),
        (AlphaCCSModel, 'regular/ccs.pth'),
    )
    loaded = []
    for build_model, path_in_zip in recipes:
        model = build_model()
        model.load(model_zip, model_path_in_zip=path_in_zip)
        loaded.append(model)
    return tuple(loaded)

def load_models_by_model_type_in_zip(model_type_in_zip:str, mask_modloss=True):
    """Create MS2/RT/CCS models and load weights from one folder of `model_zip`.

    The files 'ms2.pth', 'rt.pth' and 'ccs.pth' are read from the folder
    `model_type_in_zip` inside the archive.

    NOTE(review): the phospho weights loaded by `load_phos_models()` are named
    'ms2_phos.pth'/'rt_phos.pth', which does not follow this naming scheme --
    confirm which folders this function is meant to be used with.

    Args:
        model_type_in_zip (str): folder name inside the model archive.
        mask_modloss (bool, optional): passed to `pDeepModel`. Defaults to True.

    Returns:
        tuple: (ms2_model, rt_model, ccs_model)
    """
    # (constructor, path inside the archive); each model is built and loaded
    # before the next one is created.
    recipes = (
        (lambda: pDeepModel(mask_modloss=mask_modloss), f'{model_type_in_zip}/ms2.pth'),
        (AlphaRTModel, f'{model_type_in_zip}/rt.pth'),
        (AlphaCCSModel, f'{model_type_in_zip}/ccs.pth'),
    )
    loaded = []
    for build_model, path_in_zip in recipes:
        model = build_model()
        model.load(model_zip, model_path_in_zip=path_in_zip)
        loaded.append(model)
    return tuple(loaded)


## 3. Using `ModelManager`

For users, the `ModelManager` class is all that is needed to manage models (loading, transfer learning, etc.). Depending on its `model_type` argument, `ModelManager.load_installed_models()` loads the regular, phospho, or digly models from the installed ZIP file. For external models, `ModelManager.load_external_models()` loads them from a file path or a file stream. Here is an example:

```
from zipfile import ZipFile

admodel = ModelManager()
ext_zip = 'external_models.zip' # model compressed in ZIP
rt_model_path = '/path/to/rt.pth' # model as file path
with ZipFile(ext_zip) as model_zip:
    with model_zip.open('regular/ms2.pth','r') as ms2_file:
        admodel.load_external_models(ms2_model_file=ms2_file, rt_model_file=rt_model_path)
```

Transfer learning for different models could also be done in `ModelManager` by using the given training dataframes.

In [17]:
#export
from alphabase.peptide.fragment import (
    create_fragment_mz_dataframe,
    get_charged_frag_types,
    concat_precursor_fragment_dataframes
)
from alphabase.peptide.precursor import (
    refine_precursor_df,
    update_precursor_mz
)
from alphabase.peptide.mobility import (
    mobility_to_ccs_for_df
)

from peptdeep.settings import global_settings

import torch.multiprocessing as mp
from typing import Dict
from peptdeep.utils import logging, process_bar

def clear_error_modloss_intensities(
    fragment_mz_df, fragment_intensity_df
):
    """Zero modloss intensities wherever the matching fragment m/z is 0.

    For every column of `fragment_mz_df` whose name contains 'modloss', the
    rows with m/z equal to 0 are set to 0 in the same column of
    `fragment_intensity_df`. The intensity dataframe is modified in place.

    Args:
        fragment_mz_df (pd.DataFrame): fragment m/z values.
        fragment_intensity_df (pd.DataFrame): fragment intensities to clean.
    """
    modloss_columns = [
        name for name in fragment_mz_df.columns.values
        if 'modloss' in name
    ]
    for name in modloss_columns:
        mz_is_zero = fragment_mz_df[name]==0
        fragment_intensity_df.loc[mz_is_zero, name] = 0

class ModelManager(object):
    def __init__(self, 
        mask_modloss:bool=mgr_settings['mask_modloss'],
        device:str='gpu',
    ):
        """ The manager class to access MS2/RT/CCS models.

        On construction the installed models are loaded first, then any
        external models configured in the global settings.

        Args:
            mask_modloss (bool, optional): If modloss ions are masked to zeros
                in the ms2 model. `modloss` ions are mostly useful for phospho 
                MS2 prediction model. 
                Defaults to :py:data:`global_settings`['model_mgr']['mask_modloss'].
            device (str, optional): Device for DL models, could be 'gpu' ('cuda') or 'cpu'.
                if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'.
                Defaults to 'gpu'.
                
        Attributes:
            ms2_model (:py:class:`peptdeep.model.ms2.pDeepModel`): The MS2 (pDeep) 
                prediction model.
            rt_model (:py:class:`peptdeep.model.rt.AlphaRTModel`): The RT prediction model.
            ccs_model (:py:class:`peptdeep.model.ccs.AlphaCCSModel`): The CCS prediction model.
            psm_num_to_tune_ms2 (int): Number of PSMs to fine-tune the MS2 model. 
                Defaults to 5000.
            psm_num_per_mod_to_tune_ms2 (int): Number of additional PSMs per
                modification to fine-tune the MS2 model. Defaults to 0.
            epoch_to_tune_ms2 (int): Number of epochs to fine-tune the MS2 model. 
                Defaults to global_settings['model_mgr']['fine_tune']['epoch_ms2'].
            batch_size_to_tune_ms2 (int): Batch size to fine-tune the MS2 model.
                Defaults to 512.
            psm_num_to_tune_rt_ccs (int): Number of PSMs to fine-tune RT/CCS model. 
                Defaults to 3000.
            psm_num_per_mod_to_tune_rt_ccs (int): Number of additional PSMs per
                modification to fine-tune RT/CCS model. Defaults to 0.
            epoch_to_tune_rt_ccs (int): Number of epochs to fine-tune RT/CCS model. 
                Defaults to global_settings['model_mgr']['fine_tune']['epoch_rt_ccs'].
            batch_size_to_tune_rt_ccs (int): Batch size to fine-tune RT/CCS model.
                Defaults to 1024.
            top_n_mods_to_tune (int): Number of most frequent modifications
                considered when sampling PSMs for fine-tuning. Defaults to 10.
            nce (float): Default NCE value for a precursor_df without the 'nce' column.
                Defaults to global_settings['model_mgr']['predict']['default_nce'].
            instrument (str): Default instrument type for a precursor_df without the 'instrument' column.
                Defaults to global_settings['model_mgr']['predict']['default_instrument'].
            verbose (bool): Whether prediction steps are logged.
                Defaults to global_settings['model_mgr']['predict']['verbose'].
            use_grid_nce_search (bool): If self.ms2_model uses 
                :py:meth:`peptdeep.model.ms2.pDeepModel.grid_nce_search` to determine optimal
                NCE and instrument type. This will change `self.nce` and `self.instrument` values.
                Defaults to global_settings['model_mgr']['fine_tune']['grid_nce_search'].
        """
        self.ms2_model:pDeepModel = pDeepModel(mask_modloss=mask_modloss, device=device)
        self.rt_model:AlphaRTModel = AlphaRTModel(device=device)
        self.ccs_model:AlphaCCSModel = AlphaCCSModel(device=device)

        self.load_installed_models()
        self.load_external_models()

        self.use_grid_nce_search = mgr_settings[
            'fine_tune'
        ]['grid_nce_search']

        self.psm_num_to_tune_ms2 = 5000
        self.psm_num_per_mod_to_tune_ms2 = 0
        self.epoch_to_tune_ms2 = mgr_settings[
            'fine_tune'
        ]['epoch_ms2']
        self.batch_size_to_tune_ms2 = 512

        self.psm_num_to_tune_rt_ccs = 3000
        self.psm_num_per_mod_to_tune_rt_ccs = 0
        self.epoch_to_tune_rt_ccs = mgr_settings[
            'fine_tune'
        ]['epoch_rt_ccs']
        self.batch_size_to_tune_rt_ccs = 1024

        self.top_n_mods_to_tune = 10

        self.nce = mgr_settings[
            'predict'
        ]['default_nce']
        self.instrument = mgr_settings[
            'predict'
        ]['default_instrument']
        self.verbose = mgr_settings[
            'predict'
        ]['verbose']

    def set_default_nce_instrument(self, df):
        """Add the 'nce' and/or 'instrument' column to `df` if missing.

        Missing columns are filled in place with `self.nce` and
        `self.instrument`; existing columns are left untouched.

        Args:
            df (pd.DataFrame): dataframe to complete.
        """
        if 'nce' not in df.columns:
            df['nce'] = self.nce
        if 'instrument' not in df.columns:
            df['instrument'] = self.instrument

    def set_default_nce(self, df):
        """Alias of :py:meth:`set_default_nce_instrument`."""
        self.set_default_nce_instrument(df)

    def load_installed_models(self, 
        model_type:str=mgr_settings['model_type']
    ):
        """ Load built-in MS2/CCS/RT models.

        The model type is matched case-insensitively. The CCS model is always
        loaded from 'regular/ccs.pth'. Unknown model types log a warning and
        fall back to 'regular'.

        Args:
            model_type (str, optional): To load the installed MS2/RT/CCS models 
                or phos MS2/RT/CCS models. It could be 'phospho', 'digly', 
                'HLA', or 'regular' (several aliases are accepted).
                Currently, HLA and regular share the same models.
                Defaults to `global_settings['model_mgr']['model_type']` ('regular').
        """
        model_type_lower = model_type.lower()
        if model_type_lower in [
            'phospho','phos','phosphorylation'
        ]:
            self.ms2_model.load(
                model_zip,
                model_path_in_zip='phospho/ms2_phos.pth'
            )
            self.rt_model.load(
                model_zip, 
                model_path_in_zip='phospho/rt_phos.pth'
            )
            self.ccs_model.load(
                model_zip, 
                model_path_in_zip='regular/ccs.pth'
            )
        elif model_type_lower in [
            'digly','glygly','ubiquitylation', 
            'ubiquitination','ubiquitinylation'
        ]:
            self.ms2_model.load(
                model_zip,
                model_path_in_zip='digly/ms2_digly.pth'
            )
            self.rt_model.load(
                model_zip, 
                model_path_in_zip='digly/rt_digly.pth'
            )
            self.ccs_model.load(
                model_zip, 
                model_path_in_zip='regular/ccs.pth'
            )
        elif model_type_lower in ['regular','common']:
            self.ms2_model.load(
                model_zip, model_path_in_zip='regular/ms2.pth'
            )
            self.rt_model.load(
                model_zip, model_path_in_zip='regular/rt.pth'
            )
            self.ccs_model.load(
                model_zip, model_path_in_zip='regular/ccs.pth'
            )
        elif model_type_lower in [
            'hla','unspecific','non-specific', 'nonspecific'
        ]:
            # HLA/unspecific digestion currently shares the regular models.
            self.load_installed_models(model_type="regular")
        else:
            logging.warning(
                f"model_type='{model_type}' is not supported, use 'regular' instead."
            )
            self.load_installed_models(model_type="regular")

    def load_external_models(self,
        *,
        ms2_model_file: Tuple[str, io.BytesIO]=mgr_settings['external_ms2_model'],
        rt_model_file: Tuple[str, io.BytesIO]=mgr_settings['external_rt_model'],
        ccs_model_file: Tuple[str, io.BytesIO]=mgr_settings['external_ccs_model'],
    ):
        """Load external MS2/RT/CCS models.

        Each argument is either a file path (str) or a binary stream. A string
        that does not point to an existing file (including '') is ignored, so
        the previously loaded model stays in use. If a file cannot be
        unpickled, a message is logged and the previously loaded model is kept.

        Args:
            ms2_model_file (Tuple[str, io.BytesIO], optional): ms2 model file or stream.
                Do nothing if the value is ''. Defaults to global_settings['model_mgr']['external_ms2_model'].
            rt_model_file (Tuple[str, io.BytesIO], optional): rt model file or stream.
                Do nothing if the value is ''. Defaults to global_settings['model_mgr']['external_rt_model'].
            ccs_model_file (Tuple[str, io.BytesIO], optional): ccs model or stream.
                Do nothing if the value is ''. Defaults to global_settings['model_mgr']['external_ccs_model'].
        """

        def _load_file(model, model_file):
            # A string is a file path: skip '' and paths that do not exist.
            if isinstance(model_file, str) and not os.path.isfile(model_file):
                return
            try:
                # Load exactly once, for both paths and streams.
                model.load(model_file)
            except UnpicklingError as e:
                logging.info(f"Cannot load {model_file} as {model.__class__} model, peptdeep will use the pretrained model instead.")

        _load_file(self.ms2_model, ms2_model_file)
        _load_file(self.rt_model, rt_model_file)
        _load_file(self.ccs_model, ccs_model_file)

    def fine_tune_rt_model(self,
        psm_df:pd.DataFrame,
    ):
        """ Fine-tune the RT model. The fine-tuning will be skipped 
            if `self.psm_num_to_tune_rt_ccs` is zero.

        Args:
            psm_df (pd.DataFrame): training psm_df which contains 'rt_norm' column.
        """
        if self.psm_num_to_tune_rt_ccs > 0:
            tr_df = psm_sampling_with_important_mods(
                psm_df, self.psm_num_to_tune_rt_ccs,
                self.top_n_mods_to_tune,
                self.psm_num_per_mod_to_tune_rt_ccs,
                uniform_sampling_column='rt_norm'
            )
            if len(tr_df) > 0:
                self.rt_model.train_with_warmup(tr_df, 
                    batch_size=self.batch_size_to_tune_rt_ccs,
                    epoch=self.epoch_to_tune_rt_ccs,
                    warmup_epoch=self.epoch_to_tune_rt_ccs//2,
                )

    def fine_tune_ccs_model(self,
        psm_df:pd.DataFrame,
    ):
        """ Fine-tune the CCS model. The fine-tuning will be skipped 
            if `self.psm_num_to_tune_rt_ccs` is zero, or if `psm_df`
            has neither a 'ccs' nor a 'mobility' column.

        Args:
            psm_df (pd.DataFrame): training psm_df which contains the 'ccs' 
                column, or a 'mobility' column from which 'ccs' is computed 
                and added to `psm_df` in place.
        """

        if 'ccs' not in psm_df.columns:
            if 'mobility' not in psm_df.columns:
                # Nothing to train on.
                return
            psm_df['ccs'] = mobility_to_ccs_for_df(
                psm_df, 'mobility'
            )

        if self.psm_num_to_tune_rt_ccs > 0:
            tr_df = psm_sampling_with_important_mods(
                psm_df, self.psm_num_to_tune_rt_ccs,
                self.top_n_mods_to_tune,
                self.psm_num_per_mod_to_tune_rt_ccs,
                uniform_sampling_column='ccs'
            )
            if len(tr_df) > 0:
                self.ccs_model.train_with_warmup(tr_df, 
                    batch_size=self.batch_size_to_tune_rt_ccs,
                    epoch=self.epoch_to_tune_rt_ccs,
                    warmup_epoch=self.epoch_to_tune_rt_ccs//2,
                )

    def fine_tune_ms2_model(self,
        psm_df: pd.DataFrame,
        matched_intensity_df: pd.DataFrame,
    ):
        """Using matched_intensity_df to fine-tune the ms2 model. 
        1. It will sample `n=self.psm_num_to_tune_ms2` PSMs into training dataframe (`tr_df`) for fine-tuning.
        2. This method will also consider some important PTMs (`n=self.top_n_mods_to_tune`) into `tr_df` for fine-tuning. 
        3. If `self.use_grid_nce_search==True`, this method will call `self.ms2_model.grid_nce_search` to find the best NCE and instrument.

        The fine-tuning will be skipped if `self.psm_num_to_tune_ms2` is zero.

        Args:
            psm_df (pd.DataFrame): PSM dataframe for fine-tuning.
            matched_intensity_df (pd.DataFrame): The matched fragment intensities for `psm_df`.
        """
        if self.psm_num_to_tune_ms2 > 0:
            tr_df = psm_sampling_with_important_mods(
                psm_df, self.psm_num_to_tune_ms2,
                self.top_n_mods_to_tune,
                self.psm_num_per_mod_to_tune_ms2
            )
            if len(tr_df) > 0:
                tr_df, frag_df = normalize_training_intensities(
                    tr_df, matched_intensity_df
                )
                # One column per fragment type of the model, in model order;
                # fragment types missing from the matched data become 0.
                # reindex keeps frag_df's rows even when the first fragment
                # type is missing (filling an empty frame with a scalar
                # would not).
                tr_inten_df = frag_df.reindex(
                    columns=self.ms2_model.charged_frag_types,
                    fill_value=0,
                )

                if self.use_grid_nce_search:
                    self.nce, self.instrument = self.ms2_model.grid_nce_search(
                        tr_df, tr_inten_df,
                        nce_first=mgr_settings['fine_tune'][
                            'grid_nce_first'
                        ],
                        nce_last=mgr_settings['fine_tune'][
                            'grid_nce_last'
                        ],
                        nce_step=mgr_settings['fine_tune'][
                            'grid_nce_step'
                        ],
                        search_instruments=mgr_settings['fine_tune'][
                            'grid_instrument'
                        ],
                    )
                    tr_df['nce'] = self.nce
                    tr_df['instrument'] = self.instrument
                else:
                    self.set_default_nce_instrument(tr_df)

                self.ms2_model.train_with_warmup(tr_df, 
                    fragment_intensity_df=tr_inten_df,
                    batch_size=self.batch_size_to_tune_ms2,
                    epoch=self.epoch_to_tune_ms2,
                    warmup_epoch=self.epoch_to_tune_ms2//2,
                )

    def predict_ms2(self, precursor_df:pd.DataFrame, 
        *, 
        batch_size:int=mgr_settings[
            'predict'
        ]['batch_size_ms2'],
        reference_frag_df:pd.DataFrame = None,
    )->pd.DataFrame:
        """Predict MS2 for the given precursor_df

        Missing 'nce'/'instrument' columns are added to `precursor_df` in 
        place using `self.nce` and `self.instrument`.

        Args:
            precursor_df (pd.DataFrame): precursor dataframe for MS2 prediction.
            batch_size (int, optional): Batch size for prediction. 
              Defaults to mgr_settings[ 'predict' ]['batch_size_ms2'].
            reference_frag_df (pd.DataFrame, optional): 
              If precursor_df has 'frag_start_idx' pointing to reference_frag_df. 
              Defaults to None.

        Returns:
            pd.DataFrame: predicted fragment intensity dataframe. 
              If there are no such two columns in precursor_df, 
              it will insert 'frag_start_idx' and `frag_end_idx` in 
              precursor_df pointing to this predicted fragment dataframe.
        """
        self.set_default_nce_instrument(precursor_df)
        if self.verbose:
            logging.info('Predicting MS2 ...')
        return self.ms2_model.predict(precursor_df, 
            batch_size=batch_size,
            reference_frag_df=reference_frag_df,
            verbose=self.verbose
        )

    def predict_rt(self, precursor_df:pd.DataFrame,
        *, 
        batch_size:int=mgr_settings[
            'predict'
        ]['batch_size_rt_ccs']
    )->pd.DataFrame:
        """ Predict RT ('rt_pred') inplace into `precursor_df`.

        Args:
            precursor_df (pd.DataFrame): precursor_df for RT prediction
            batch_size (int, optional): Batch size for prediction. 
              Defaults to mgr_settings[ 'predict' ]['batch_size_rt_ccs']. 
              mgr_settings=peptdeep.settings.global_settings['model_mgr'].

        Returns:
            pd.DataFrame: df with 'rt_pred' and 'rt_norm_pred' columns.
        """
        if self.verbose:
            logging.info("Predicting RT ...")
        df = self.rt_model.predict(precursor_df, 
            batch_size=batch_size, verbose=self.verbose
        )
        # 'rt_norm_pred' is a copy of the predicted RT.
        df['rt_norm_pred'] = df.rt_pred
        return df

    def predict_mobility(self, precursor_df:pd.DataFrame,
        *, 
        batch_size:int=mgr_settings[
            'predict'
        ]['batch_size_rt_ccs']
    )->pd.DataFrame:
        """ Predict mobility ('ccs_pred' and `mobility_pred`) inplace into `precursor_df`.

        Args:
            precursor_df (pd.DataFrame): precursor_df for CCS/mobility prediction
            batch_size (int, optional): Batch size for prediction. 
              Defaults to mgr_settings[ 'predict' ]['batch_size_rt_ccs']. 
              mgr_settings=peptdeep.settings.global_settings['model_mgr'].

        Returns:
            pd.DataFrame: df with 'ccs_pred' and 'mobility_pred' columns.
        """
        if self.verbose:
            logging.info("Predicting mobility ...")
        precursor_df = self.ccs_model.predict(precursor_df,
            batch_size=batch_size, verbose=self.verbose
        )
        return self.ccs_model.ccs_to_mobility_pred(
            precursor_df
        )

    def _predict_all_for_mp(self, arg_dict):
        """Internal function, for multiprocessing"""
        return self.predict_all(
            multiprocessing=False, **arg_dict
        )

    def predict_all(self, precursor_df:pd.DataFrame,
        *, 
        predict_items:list = [
            'rt' ,'mobility' ,'ms2'
        ], 
        frag_types:list =  None,
        multiprocessing:bool = mgr_settings['predict']['multiprocessing'],
        process_num:int = global_settings['thread_num'],
        min_required_precursor_num_for_mp:int = 3000,
        mp_batch_size:int = 500000,
    )->Dict[str, pd.DataFrame]:
        """ predict all items defined by `predict_items`, 
        which may include rt, mobility, fragment_mz 
        and fragment_intensity.

        Prediction runs in the current process if CUDA is available, if 
        `multiprocessing` is False, or if there are fewer precursors than
        `min_required_precursor_num_for_mp`. Otherwise the precursors are
        grouped by 'nAA', split into batches and predicted in a process pool;
        batches are collected in completion order, so the row order of the
        returned dataframes may differ from the input.

        Args:
            precursor_df (pd.DataFrame): precursor dataframe contains 
              `sequence`, `mods`, `mod_sites`, `charge` ... columns. 
            predict_items (list, optional): items ('rt', 'mobility', 
              'ms2') to predict.
              Defaults to ['rt' ,'mobility' ,'ms2'].
            frag_types (list, optional): fragment types to predict. If it is None,
              it then depends on `self.ms2_model.charged_frag_types` and 
              `self.ms2_model.model._mask_modloss`.
              Defaults to None.
            multiprocessing (bool, optional): if use multiprocessing.
              Defaults to global_settings['model_mgr']['predict']['multiprocessing'].
            process_num (int, optional): Defaults to global_settings['thread_num']
            min_required_precursor_num_for_mp (int, optional): It will not use 
              multiprocessing when the number of precursors in precursor_df 
              is lower than this value. Defaults to 3000.
            mp_batch_size (int, optional): Maximum number of precursors in one
              batch sent to a worker process. Defaults to 500000.
              
        Returns:
            Dict[str, pd.DataFrame]: {'precursor_df': precursor_df}
              if 'ms2' in predict_items, it also contains:
              {
                  'fragment_mz_df': fragment_mz_df,
                  'fragment_intensity_df': fragment_intensity_df
              }
        """
        def refine_df(df):
            if 'ms2' in predict_items:
                refine_precursor_df(df)
            else:
                refine_precursor_df(df, drop_frag_idx=False)

        if frag_types is None:
            if self.ms2_model.model._mask_modloss:
                frag_types = [
                    frag for frag in self.ms2_model.charged_frag_types
                    if 'modloss' not in frag
                ]
            else:
                frag_types = self.ms2_model.charged_frag_types

        if 'precursor_mz' not in precursor_df.columns:
            update_precursor_mz(precursor_df)

        if (
            torch.cuda.is_available() or not multiprocessing
            or len(precursor_df) < min_required_precursor_num_for_mp
        ):
            refine_df(precursor_df)
            if 'rt' in predict_items:
                self.predict_rt(precursor_df)
            if 'mobility' in predict_items:
                self.predict_mobility(precursor_df)
            if 'ms2' in predict_items:
                fragment_mz_df = create_fragment_mz_dataframe(
                    precursor_df, frag_types
                )

                precursor_df.drop(
                    columns=['frag_start_idx'], inplace=True
                )
                
                fragment_intensity_df = self.predict_ms2(
                    precursor_df
                )

                # Keep only the requested fragment types.
                fragment_intensity_df.drop(
                    columns=[
                        col for col in fragment_intensity_df.columns
                        if col not in frag_types
                    ], inplace=True
                )

                clear_error_modloss_intensities(
                    fragment_mz_df, fragment_intensity_df
                )

                return {
                    'precursor_df': precursor_df, 
                    'fragment_mz_df': fragment_mz_df,
                    'fragment_intensity_df': fragment_intensity_df, 
                }
            else:
                return {'precursor_df': precursor_df}
        else:
            self.ms2_model.model.share_memory()
            self.rt_model.model.share_memory()
            self.ccs_model.model.share_memory()

            df_groupby = precursor_df.groupby('nAA')

            def get_batch_num_mp(df_groupby):
                # Total number of batches, used as the progress bar length.
                batch_num = 0
                for group_len in df_groupby.size().values:
                    for i in range(0, group_len, mp_batch_size):
                        batch_num += 1
                return batch_num

            def mp_param_generator(df_groupby):
                # One keyword-argument dict per batch for _predict_all_for_mp.
                for nAA, df in df_groupby:
                    for i in range(0, len(df), mp_batch_size):
                        yield {
                            'precursor_df': df.iloc[i:i+mp_batch_size,:],
                            'predict_items': predict_items,
                            'frag_types': frag_types,
                        }

            precursor_df_list = []
            if 'ms2' in predict_items:
                fragment_mz_df_list = []
                fragment_intensity_df_list = []
            else:
                fragment_mz_df_list = None

            if self.verbose:
                logging.info(
                    f'Predicting {",".join(predict_items)} ...'
                )
            # Silence per-batch logging while the pool is running; the flag is
            # restored even if a worker raises.
            verbose_bak = self.verbose
            self.verbose = False

            try:
                with mp.Pool(process_num) as p:
                    for ret_dict in process_bar(
                        p.imap_unordered(
                            self._predict_all_for_mp, 
                            mp_param_generator(df_groupby)
                        ), 
                        get_batch_num_mp(df_groupby)
                    ):
                        precursor_df_list.append(ret_dict['precursor_df'])
                        if fragment_mz_df_list is not None:
                            fragment_mz_df_list.append(
                                ret_dict['fragment_mz_df']
                            )
                            fragment_intensity_df_list.append(
                                ret_dict['fragment_intensity_df']
                            )
            finally:
                self.verbose = verbose_bak

            if fragment_mz_df_list is not None:
                (
                    precursor_df, fragment_mz_df, fragment_intensity_df
                ) = concat_precursor_fragment_dataframes(
                    precursor_df_list,
                    fragment_mz_df_list,
                    fragment_intensity_df_list,
                )
                
                return {
                    'precursor_df': precursor_df, 
                    'fragment_mz_df': fragment_mz_df,
                    'fragment_intensity_df': fragment_intensity_df, 
                }
            else:
                precursor_df = pd.concat(precursor_df_list)
                precursor_df.reset_index(drop=True, inplace=True)
                
                return {'precursor_df': precursor_df} 


In [18]:
#hide
# Sanity check: the model archive must exist on disk and every model file
# listed below must be openable from it.
assert os.path.isfile(model_zip)
with ZipFile(model_zip) as _zip:
    for _model_path in (
        'regular/ms2.pth',
        'regular/rt.pth',
        'regular/ccs.pth',
        'phospho/ms2_phos.pth',
        'phospho/rt_phos.pth',
    ):
        # Opening the member is the check itself; nothing is read from it.
        with _zip.open(_model_path):
            pass

In [19]:
#hide
from io import StringIO

# Reference fragment-intensity table embedded as CSV text (one literal per
# row). The unnamed first CSV column is used as the row index.
_matched_csv_text = (
    ',b_z1,b_z2,y_z1,y_z2,b_modloss_z1,b_modloss_z2,y_modloss_z1,y_modloss_z2\r\n'
    '0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0\r\n'
    '1,0.13171915994341352,0.0,0.0,0.0,0.0,0.0,0.0,0.0\r\n'
    '2,0.09560456716002332,0.0,0.0,0.0,0.0,0.0,0.0,0.0\r\n'
    '3,0.032392355556351476,0.0,0.0,0.0,0.0,0.0,0.0,0.0\r\n'
    '4,0.06267661211925589,0.0,0.0,0.0,0.0,0.0,0.0,0.0\r\n'
    '5,0.10733421416437268,0.0,0.0,0.0,0.0,0.0,0.0,0.0\r\n'
    '6,0.07955175724673087,0.0,0.0,0.0,0.0,0.0,0.0,0.0\r\n'
    '7,0.08283861204882843,0.0,0.03294760940125559,0.0,0.0,0.0,0.0,0.0\r\n'
    '8,0.0914959582993716,0.0,0.09471333271745186,0.0,0.0,0.0,0.0,0.0\r\n'
    '9,0.10283525167783934,0.0,0.29624251030302834,0.0,0.0,0.0,0.0,0.0\r\n'
    '10,0.02220051360812495,0.0272619351931404,0.8077539764174795,0.0,0.0,0.0,0.0,0.0\r\n'
    '11,0.0,0.02411148245999131,0.851474013001872,0.0,0.0,0.0,0.0,0.0\r\n'
    '12,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0\r\n'
    '13,0.0,0.0,0.22244818653184315,0.0,0.0,0.0,0.0,0.0\r\n'
    '14,0.0,0.0,0.21824010319946407,0.0,0.0,0.0,0.0,0.0\r\n'
    '15,0.0,0.0,0.16690493688692923,0.0,0.0,0.0,0.0,0.0\r\n'
)
matched_df = pd.read_csv(StringIO(_matched_csv_text), index_col=0)

# Module-level manager that `pred_one` below reads as a global. It is
# constructed with mask_modloss=True and its `verbose` flag is switched off.
model_mgr = ModelManager(mask_modloss=True)
model_mgr.verbose=False
def pred_one(seq, mods, mod_sites, charge):
    """Predict the fragment intensities of a single precursor.

    A one-row precursor dataframe is assembled from the arguments, with
    ``nce`` fixed to 35 and ``instrument`` fixed to "Lumos", and handed to
    the module-level ``model_mgr``.

    Args:
        seq: value stored in the ``sequence`` column.
        mods: value stored in the ``mods`` column.
        mod_sites: value stored in the ``mod_sites`` column.
        charge: value stored in the ``charge`` column.

    Returns:
        The ``'fragment_intensity_df'`` entry of the dict returned by
        ``model_mgr.predict_all``.
    """
    single_precursor = pd.DataFrame({
        "sequence": [seq],
        "mods": [mods],
        "mod_sites": [mod_sites],
        # Scalars are broadcast to the single row defined by the lists above.
        "charge": charge,
        "nce": 35,
        "instrument": "Lumos",
    })
    return model_mgr.predict_all(
        single_precursor,
        predict_items=['mobility','rt','ms2'],
        multiprocessing=False,
    )['fragment_intensity_df']

# Predict fragment intensities for this sequence with empty `mods` /
# `mod_sites` strings and charge 3.
pred_df = pred_one('ANEKTESSSAQQVAVSR', '', '', 3)

def get_pcc(matched_df, pred_df):
    """Correlate two intensity tables after flattening them.

    ``matched_df`` is first restricted and reordered to the columns of
    ``pred_df``. Each table is then shifted by its own overall mean,
    flattened to 1-D, and the cosine similarity of the two centered
    vectors is returned.

    Args:
        matched_df: dataframe containing at least the columns of ``pred_df``.
        pred_df: dataframe whose columns define the comparison.

    Returns:
        A zero-dimensional ``torch`` tensor holding the similarity.
    """
    aligned_df = matched_df[pred_df.columns.values]
    pred_centered = pred_df.values - pred_df.values.mean()
    matched_centered = aligned_df.values - aligned_df.values.mean()
    return torch.nn.functional.cosine_similarity(
        torch.tensor(pred_centered.reshape(-1)),
        torch.tensor(matched_centered.reshape(-1)),
        dim=0,
    )
# The predicted intensities must correlate with the reference table above 0.95.
assert get_pcc(matched_df, pred_df) > 0.95

In [20]:
#hide
# Rebind the module-level manager for the phospho checks: constructed with
# mask_modloss=False, then load_installed_models('phos') is called on it and
# its `verbose` flag is switched off.
model_mgr = ModelManager(mask_modloss=False)
model_mgr.load_installed_models('phos')
model_mgr.verbose=False
def pred_one(seq, mods, mod_sites, charge):
    """Predict the fragment intensities of a single precursor.

    A one-row precursor dataframe is assembled from the arguments, with
    ``nce`` fixed to 30 and ``instrument`` fixed to "Lumos", and handed to
    the module-level ``model_mgr``.

    Args:
        seq: value stored in the ``sequence`` column.
        mods: value stored in the ``mods`` column.
        mod_sites: value stored in the ``mod_sites`` column.
        charge: value stored in the ``charge`` column.

    Returns:
        The ``'fragment_intensity_df'`` entry of the dict returned by
        ``model_mgr.predict_all``.
    """
    single_precursor = pd.DataFrame({
        "sequence": [seq],
        "mods": [mods],
        "mod_sites": [mod_sites],
        # Scalars are broadcast to the single row defined by the lists above.
        "charge": charge,
        "nce": 30,
        "instrument": "Lumos",
    })
    return model_mgr.predict_all(
        single_precursor,
        predict_items=['mobility','rt','ms2'],
        multiprocessing=False,
    )['fragment_intensity_df']

# Phospho on S or T: at least one y_modloss_z1 value must exceed 0.5.
for _seq, _mods, _sites in (
    ('ANEKTESSSAQQVAVSR', 'Phospho@S', '9'),
    ('ANEKTESSTAQQVAVSR', 'Phospho@T', '9'),
    ('ANEKTESSSAQQVAVSR', 'Phospho@S', '16'),
):
    pred_df = pred_one(_seq, _mods, _sites, 2)
    assert (pred_df.y_modloss_z1.values > 0.5).any()

# Phospho on Y: every y_modloss_z1 value must be <= 0.
pred_df = pred_one('ANEKTESSYAQQVAVSR', 'Phospho@Y', '9', 2)
assert (pred_df.y_modloss_z1.values <= 0).all()