In [None]:
#default_exp model.model_interface

# Model Interface

## Description

This notebook defines the basic interface used to interact with the deep learning models. Its 'public' methods are intended to stay unchanged across the project, while the specific workings of the interface can be customized in sub-classes (i.e. polymorphism). For example, models are always loaded with the `load()` method, and the details of loading can be changed by inheriting the interface and overriding the methods that `load()` calls. More details are given below.


## Imports

In [None]:
#export
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from zipfile import ZipFile
from typing import IO, Tuple, List, Union
from alphabase.yaml_utils import save_yaml
from alphabase.peptide.precursor import is_precursor_sorted

from peptdeep.settings import model_const
from peptdeep.utils import logging

from peptdeep.model.building_block import *

## Utility functions

In [None]:
#export
from torch.optim.lr_scheduler import LambdaLR

# copied from huggingface
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps,
    num_training_steps, num_cycles=0.5,
    last_epoch=-1
):
    """Build a LambdaLR schedule: linear warmup, then cosine decay.

    For the first `num_warmup_steps` steps the multiplier grows linearly
    from 0 towards 1. Afterwards it follows
    ``0.5 * (1 + cos(2 * pi * num_cycles * progress))`` (clipped at 0),
    where `progress` runs from 0 to 1 over the remaining
    `num_training_steps - num_warmup_steps` steps.

    Args:
        optimizer: torch optimizer whose learning rate is scheduled.
        num_warmup_steps (int): length of the linear warmup phase.
        num_training_steps (int): total number of scheduler steps.
        num_cycles (float, optional): number of cosine cycles after warmup.
            Defaults to 0.5 (a single half-cosine decay from 1 to 0).
        last_epoch (int, optional): passed through to `LambdaLR`.
            Defaults to -1.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: the configured scheduler.
    """
    # Denominators are clamped to 1 so zero-length phases never divide by 0.
    warmup_span = max(1, num_warmup_steps)
    decay_span = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            cosine = np.cos(np.pi * num_cycles * 2.0 * progress)
            return max(0.0, 0.5 * (1.0 + cosine))
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)

## Interface Class
The `ModelInterface` below is intended to provide a standardized way to handle deep learning models. It does not contain the PyTorch models themselves, but provides methods to `load()`, `save()`, `build()`, `train()` and `predict()` them. These methods are intended to stay unchanged. To adapt the interface to a new use case, we inherit the interface in a new class and re-implement the relevant methods `_get_features_from_batch_df()`, `_get_targets_from_batch_df()` and `_prepare_predict_data_df()`.
The training and prediction procedures then adapt to the new model through these re-implemented methods. The implementation below automatically empties the GPU cache at the end of `train()` and `predict()` to save GPU memory.

For example, to design a new peptide model for a different purpose, such as RT prediction, we need to:

- Design the pytorch model (`class RTPrediction(torch.nn.Module):...`)
- Design the sub-class inherited from PeptideModelInterfaceBase (`class RTPredictionModel(PeptideModelInterfaceBase):...`)
- Re-implement `def _get_features_from_batch_df(self, batch_df): ... return torch.LongTensor(aa_indices), torch.Tensor(mod_features)` to tell the base class how to get the input features from the dataframe.
- Re-implement `def _get_targets_from_batch_df(self, batch_df): return torch.Tensor(batch_df['rt'].values)` to tell the base class how to get the target values for training from the dataframe.
- Re-implement `def _prepare_predict_data_df(self, precursor_df): self._predict_column_in_df = 'rt_pred'...` to initialize the column which will store the predicted values.
- [Optional] Re-implement `def _set_loss_function(self): self.loss_func=...` to define the loss function. Defaults to `torch.nn.L1Loss()`.
- Finally, run the model in a Python script or a notebook:
```
model = RTPredictionModel()
model.build(model_class=RTPrediction)
df = ... # the training data
model.train(df)
pred_df = model.predict(df)
```

Check out `peptdeep.model.rt.AlphaRTModel` for details; `peptdeep.model.ccs.AlphaCCSModel` is similar. The MS2 prediction model is more complicated because its output for a peptide is not a scalar; see `peptdeep.model.ms2.pDeepModel`.

In [None]:
#export
class ModelInterface(object):
    """
    Provides standardized methods to interact
    with ml models. Inherit into new class and override
    the abstract (i.e. not implemented) methods.

    Sub-classes must re-implement `_get_features_from_batch_df()`,
    `_get_targets_from_batch_df()` (needed by training) and
    `_prepare_predict_data_df()` (needed by prediction).
    """
    def __init__(self,
        **kwargs
    ):
        """
        Args:
            GPU (bool, optional): Use CUDA when it is available.
                Defaults to True.
        """
        self.model:torch.nn.Module = None
        self.optimizer = None
        self.model_params = None
        device_type = self._get_device_type_to_use(kwargs)
        self.set_device(device_type)


    def set_device(self, device_type = 'cuda'):
        """
        Sets the device (e.g. gpu (cuda), cpu) to be used in the model.
        Falls back to 'cpu' whenever CUDA is not available, and moves an
        already-built model onto the selected device.
        """
        if not torch.cuda.is_available():
            device_type = 'cpu'
        self.device = torch.device(device_type)
        if self.model is not None:
            self.model.to(self.device)

    def set_GPU_state(self, use_GPU:bool=True):
        """
        Legacy GPU/CPU switch kept for backward compatibility:
        calls `set_device('cuda')` if `use_GPU` else `set_device('cpu')`.
        """
        self.set_device('cuda' if use_GPU else 'cpu')

    def build(self,
        model_class: torch.nn.Module,
        **kwargs
    ):
        """
        Builds the model by specifying the pyTorch module,
        the parameters, the device and the loss function.

        Args:
            model_class: torch.nn.Module sub-class to instantiate.
            **kwargs: passed to `model_class(...)` and stored in
                `self.model_params` (written by `save()`).
        """
        self.model = model_class(**kwargs)
        self.model_params = kwargs
        self.model.to(self.device)
        self._set_loss_function()

    def train_with_warmup(self,
        precursor_df: pd.DataFrame,
        *,
        batch_size=1024,
        epoch=10,
        warmup_epoch=5,
        lr=1e-4,
        verbose=False,
        verbose_each_epoch=False,
        **kwargs
    ):
        """
        Trains the model according to specifications. Includes a warmup
        phase in which the learning rate increases linearly, followed by
        a cosine decay (see `get_cosine_schedule_with_warmup`). The
        scheduler is stepped once per epoch.

        Args:
            precursor_df (pd.DataFrame): training data. A 'nAA' column is
                added from 'sequence' if missing.
            batch_size (int): mini-batch size within each 'nAA' group.
            epoch (int): number of training epochs.
            warmup_epoch (int): number of warmup epochs.
            lr (float): base learning rate.
            verbose (bool): print the mean loss after each epoch.
            verbose_each_epoch (bool): show a tqdm bar within each epoch.
            **kwargs: forwarded to the data/feature/target hooks.
        """
        self._pre_training(precursor_df, lr, **kwargs)

        lr_scheduler = get_cosine_schedule_with_warmup(
            self.optimizer, warmup_epoch, epoch
        )

        for i_epoch in range(epoch):
            batch_cost = self._train_one_epoch(
                precursor_df, i_epoch,
                batch_size, verbose_each_epoch,
                **kwargs
            )
            lr_scheduler.step()
            if verbose: print(
                f'[Training] Epoch={i_epoch+1}, lr={lr_scheduler.get_last_lr()[0]}, loss={np.mean(batch_cost)}'
            )

        torch.cuda.empty_cache()

    def train(self,
        precursor_df: pd.DataFrame,
        *,
        batch_size=1024,
        epoch=10,
        lr=1e-4,
        verbose=False,
        verbose_each_epoch=False,
        **kwargs
    ):
        """
        Trains the model according to specifications with a constant
        learning rate.

        Args:
            precursor_df (pd.DataFrame): training data. A 'nAA' column is
                added from 'sequence' if missing.
            batch_size (int): mini-batch size within each 'nAA' group.
            epoch (int): number of training epochs.
            lr (float): learning rate.
            verbose (bool): print the mean loss after each epoch.
            verbose_each_epoch (bool): show a tqdm bar within each epoch.
            **kwargs: forwarded to the data/feature/target hooks.
        """
        self._pre_training(precursor_df, lr, **kwargs)

        for i_epoch in range(epoch):
            batch_cost = self._train_one_epoch(
                precursor_df, i_epoch,
                batch_size, verbose_each_epoch,
                **kwargs
            )
            if verbose: print(f'[Training] Epoch={i_epoch+1}, Mean Loss={np.mean(batch_cost)}')

        torch.cuda.empty_cache()

    def predict(self,
        precursor_df:pd.DataFrame,
        *,
        batch_size=1024,
        verbose=False,
        **kwargs
    )->pd.DataFrame:
        """
        The model predicts the properties based on the inputs it has been trained for.
        Returns the output as a pandas dataframe (`self.predict_df`, as set
        up by `_prepare_predict_data_df()`).

        Note: if `precursor_df` has no 'nAA' column, it is added and the
        dataframe is sorted by 'nAA' and re-indexed IN PLACE.

        Args:
            precursor_df (pd.DataFrame): input data.
            batch_size (int): mini-batch size within each 'nAA' group.
            verbose (bool): show a tqdm bar over the 'nAA' groups.
            **kwargs: forwarded to the prediction hooks.
        """
        precursor_df = self._add_nAA_column_if_missing(precursor_df)
        self._check_predict_in_order(precursor_df)
        self._prepare_predict_data_df(precursor_df,**kwargs)
        self.model.eval()

        _grouped = precursor_df.groupby('nAA')
        if verbose:
            batch_tqdm = tqdm(_grouped)
        else:
            batch_tqdm = _grouped
        with torch.no_grad():
            for nAA, df_group in batch_tqdm:
                for i in range(0, len(df_group), batch_size):
                    batch_end = i+batch_size

                    batch_df = df_group.iloc[i:batch_end,:]

                    features = self._get_features_from_batch_df(
                        batch_df, **kwargs
                    )

                    predicts = self._predict_one_batch(*features)

                    self._set_batch_predict_data(
                        batch_df, predicts,
                        **kwargs
                    )

        torch.cuda.empty_cache()
        return self.predict_df

    def save(self, filename):
        """
        Save the model state, the constants used, the code defining the model and the model parameters.

        Writes `filename` (state dict), `filename.txt` (model repr),
        `filename.model_const.yaml`, `filename.model.py` and
        `filename.param.yaml`, creating the target directory if needed.
        """
        # TODO save tf.keras.Model
        save_dir = os.path.dirname(filename)
        if not save_dir: save_dir = './'
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.model.state_dict(), filename)
        with open(filename+'.txt','w') as f: f.write(str(self.model))
        save_yaml(filename+'.model_const.yaml', model_const)
        self._save_codes(filename+'.model.py')
        save_yaml(filename+'.param.yaml', self.model_params)

    def load(
        self,
        model_file: Union[str, IO],
        model_path_in_zip: str = None,
        **kwargs
    ):
        """
        Load model weights (a state dict) from a zip file, a regular file
        path or an already-opened binary file stream.

        Args:
            model_file (str | IO): path to a '.zip' archive, path to a
                weights file, or a binary file-like object.
            model_path_in_zip (str, optional): member name of the weights
                file inside the zip archive. Required when `model_file`
                is a '.zip' path.
        """
        # TODO load tf.keras.Model
        if isinstance(model_file, str):
            # We may release all models (msms, rt, ccs, ...) in a single zip file
            if model_file.lower().endswith('.zip'):
                self._load_model_from_zipfile(model_file, model_path_in_zip)
            else:
                self._load_model_from_textfile(model_file)
        else:
            self._load_model_from_filestream(model_file)



    def get_parameter_num(self):
        """
        Get total number of parameters in model.
        """
        return np.sum([p.numel() for p in self.model.parameters()])

    def build_from_py_codes(self,
        model_code_file:str,
        code_file_in_zip:str=None,
        **kwargs
    ):
        """
        Build the model based on a python file. Must contain a pyTorch
        model implemented as 'class Model(...'

        Args:
            model_code_file (str): path to a '.py' file or a '.zip' archive.
            code_file_in_zip (str, optional): member name of the python
                file inside the zip archive.
            **kwargs: passed to `Model(...)` and stored in `self.model_params`.

        Raises:
            NameError: if the executed code does not define `Model`.
        """
        if model_code_file.lower().endswith('.zip'):
            with ZipFile(model_code_file, 'r') as model_zip:
                with model_zip.open(code_file_in_zip) as f:
                    codes = f.read()
        else:
            with open(model_code_file, 'r') as f:
                codes = f.read()
        codes = compile(
            codes,
            filename='model_file_py',
            mode='exec'
        )
        # SECURITY: this executes arbitrary Python code; only use trusted files.
        # A bare `exec()` inside a method cannot bind `Model` as a name this
        # method can see, so run the code in an explicit namespace seeded with
        # this module's globals (torch, building blocks, ...).
        namespace = dict(globals())
        namespace.pop('Model', None)
        exec(codes, namespace) #codes must contains torch model codes 'class Model(...'
        if 'Model' not in namespace:
            raise NameError(
                f"'{model_code_file}' does not define a 'Model' class"
            )
        self.model = namespace['Model'](**kwargs)
        self.model_params = kwargs
        self.model.to(self.device)
        self._set_loss_function()

    def _set_loss_function(self):
        """Set `self.loss_func`; defaults to L1 loss. Override to change it."""
        self.loss_func = torch.nn.L1Loss()


    def _load_model_from_zipfile(self, model_file, model_path_in_zip=None):
        """Load the state dict stored as `model_path_in_zip` inside `model_file`."""
        if model_path_in_zip is None:
            raise ValueError(
                "`model_path_in_zip` is required when loading a model from a zip file"
            )
        with ZipFile(model_file) as model_zip:
            with model_zip.open(model_path_in_zip,'r') as pt_file:
                self._load_model_from_filestream(pt_file)

    def _load_model_from_textfile(self, model_file):
        """Load the state dict from a file path (opened in binary mode)."""
        with open(model_file,'rb') as pt_file:
            self._load_model_from_filestream(pt_file)

    def _load_model_from_filestream(self, stream):
        """
        Load a state dict from a binary stream non-strictly: missing and
        unexpected parameter names are logged as warnings, not raised.
        """
        (
            missing_keys, unexpect_keys
        ) = self.model.load_state_dict(torch.load(
            stream, map_location=self.device),
            strict=False
        )
        if len(missing_keys) > 0:
            logging.warning(f"nn parameters {missing_keys} are MISSING while loading models in {self.__class__}")
        if len(unexpect_keys) > 0:
            logging.warning(f"nn parameters {unexpect_keys} are UNEXPECTED while loading models in {self.__class__}")

    def _save_codes(self, save_as):
        """
        Write the source of the model class to `save_as`, renamed to
        `Model`, so that `build_from_py_codes()` can rebuild it.
        """
        import inspect
        # NOTE(review): this header imports `peptdeep.model.base`, while this
        # module takes its building blocks from `peptdeep.model.building_block`;
        # confirm the saved code's references still resolve.
        code = '''import torch\nimport peptdeep.model.base as model_base\n'''
        class_code = inspect.getsource(self.model.__class__)
        # Replace the original class name with `Model`, keeping the bases.
        code += 'class Model' + class_code[class_code.find('('):]
        with open(save_as, 'w') as f:
            f.write(code)

    def _train_one_epoch(self,
        precursor_df, epoch, batch_size, verbose_each_epoch,
        **kwargs
    ):
        """
        Training for an epoch.

        Rows are shuffled, grouped by 'nAA' (so each batch has a single
        peptide length), the groups are visited in random order, and each
        group is split into mini-batches of `batch_size`.

        Returns:
            list[float]: loss value of every mini-batch.
        """
        batch_cost = []
        _grouped = list(precursor_df.sample(frac=1).groupby('nAA'))
        rnd_nAA = np.random.permutation(len(_grouped))
        if verbose_each_epoch:
            batch_tqdm = tqdm(rnd_nAA)
        else:
            batch_tqdm = rnd_nAA
        for i_group in batch_tqdm:
            nAA, df_group = _grouped[i_group]
            for i in range(0, len(df_group), batch_size):
                batch_end = i+batch_size

                batch_df = df_group.iloc[i:batch_end,:]
                targets = self._get_targets_from_batch_df(
                    batch_df, **kwargs
                )
                features = self._get_features_from_batch_df(
                    batch_df, **kwargs
                )

                batch_cost.append(
                    self._train_one_batch(targets, *features)
                )

            if verbose_each_epoch:
                batch_tqdm.set_description(
                    f'Epoch={epoch+1}, nAA={nAA}, batch={len(batch_cost)}, loss={batch_cost[-1]:.4f}'
                )
        return batch_cost

    def _train_one_batch(
        self,
        targets:torch.Tensor,
        *features,
    ):
        """
        Training for a mini batch: one optimizer step with gradient-norm
        clipping at 1.0. Returns the loss as a Python float.
        """
        self.optimizer.zero_grad()
        predicts = self.model(*[fea.to(self.device) for fea in features])
        cost = self.loss_func(predicts, targets.to(self.device))
        cost.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return cost.item()

    def _predict_one_batch(self,
        *features
    ):
        """Predicting for a mini batch; returns a numpy array on the CPU."""
        return self.model(
            *[fea.to(self.device) for fea in features]
        ).cpu().detach().numpy()

    def _get_targets_from_batch_df(self,
        batch_df:pd.DataFrame, **kwargs,
    )->torch.Tensor:
        """Tell the `train()` method how to get target values from the `batch_df`.
           All sub-classes must re-implement this method.

        Args:
            batch_df (pd.DataFrame): Dataframe of each mini batch.
            nAA (int, optional): Peptide length. Defaults to None.

        Raises:
            NotImplementedError: 'Must implement _get_targets_from_batch_df() method'

        Returns:
            torch.Tensor: Target value tensor
        """
        raise NotImplementedError(
            'Must implement _get_targets_from_batch_df() method'
        )

    def _get_features_from_batch_df(self,
        batch_df:pd.DataFrame, **kwargs,
    )->Tuple[torch.Tensor]:
        """Tell `train()` and `predict()` methods how to get feature tensors from the `batch_df`.
           All sub-classes must re-implement this method.

        Args:
            batch_df (pd.DataFrame): Dataframe of each mini batch.
            nAA (int, optional): Peptide length. Defaults to None.

        Raises:
            NotImplementedError: 'Must implement _get_features_from_batch_df() method'

        Returns:
            Tuple[torch.Tensor]: A feature tensor or multiple feature tensors.
        """
        raise NotImplementedError(
            'Must implement _get_features_from_batch_df() method'
        )

    def _prepare_predict_data_df(self,
        precursor_df:pd.DataFrame,
        **kwargs
    ):
        """
        This method must define `self._predict_column_in_df` and create a `self.predict_df` dataframe.
        All sub-classes must re-implement this method.

        For example for RT prediction:
        >>> self._predict_column_in_df = 'rt_pred'
        >>> precursor_df[self._predict_column_in_df] = 0 (initialize the predict column in the df)
        >>> self.predict_df = precursor_df
        ...
        """
        raise NotImplementedError(
            'Must implement _prepare_predict_data_df() method'
        )

    def _prepare_train_data_df(self,
        precursor_df:pd.DataFrame,
        **kwargs
    ):
        """Modifications to the training dataframe can be implemented here.

        Args:
            precursor_df (pd.DataFrame): Dataframe containing the training data.
        """
        pass

    def _set_batch_predict_data(self,
        batch_df:pd.DataFrame,
        predict_values:np.array,
        **kwargs
    ):
        """Set predicted values into `self.predict_df`.

        Negative predictions are clipped to 0 (in place on `predict_values`).

        Args:
            batch_df (pd.DataFrame): Dataframe of mini batch when predicting
            predict_values (np.array): Predicted values
        """
        predict_values[predict_values<0] = 0.0
        if self._predict_in_order:
            # Fast path: write a contiguous slice by position. Assumes the
            # index labels equal row positions (e.g. a 0..n-1 index) — TODO
            # confirm for dataframes that already had an 'nAA' column.
            self.predict_df.loc[:,self._predict_column_in_df].values[
                batch_df.index.values[0]:batch_df.index.values[-1]+1
            ] = predict_values
        else:
            self.predict_df.loc[
                batch_df.index,self._predict_column_in_df
            ] = predict_values

    def _init_optimizer(self, lr):
        """Set optimizer"""
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=lr
        )

    def set_lr(self, lr:float):
        """Set learning rate; creates the optimizer on first use."""
        if self.optimizer is None:
            self._init_optimizer(lr)
        else:
            for g in self.optimizer.param_groups:
                g['lr'] = lr

    def _pre_training(self, precursor_df, lr, **kwargs):
        """
        Adds a missing 'nAA' column to `precursor_df` in place, runs the
        `_prepare_train_data_df()` hook, switches the model into training
        mode and sets the learning rate.
        """
        if 'nAA' not in precursor_df.columns:
            precursor_df['nAA'] = precursor_df.sequence.str.len()
        self._prepare_train_data_df(precursor_df, **kwargs)
        self.model.train()

        self.set_lr(lr)

    def _check_predict_in_order(self, precursor_df:pd.DataFrame):
        """Record whether `precursor_df` is sorted, enabling slice writes."""
        self._predict_in_order = bool(is_precursor_sorted(precursor_df))

    def _get_device_type_to_use(self, kwargs):
        """Map the optional `GPU` kwarg to a device type string."""
        return 'cuda' if self._check_if_GPU_should_be_used(kwargs) else 'cpu'

    @staticmethod
    def _check_if_GPU_should_be_used(kwargs):
        """Return `kwargs['GPU']` if given, otherwise True."""
        return kwargs.get('GPU', True)

    @staticmethod
    def _add_nAA_column_if_missing(precursor_df):
        """
        Add the 'nAA' column (number of amino acids, from 'sequence') if it
        is missing; in that case the dataframe is also sorted by 'nAA' and
        its index reset, both in place. Returns the same dataframe.
        """
        if 'nAA' not in precursor_df.columns:
            precursor_df['nAA'] = precursor_df.sequence.str.len()
            precursor_df.sort_values('nAA', inplace=True)
            precursor_df.reset_index(drop=True,inplace=True)
        return precursor_df
    
    

#legacy
PeptideModelInterfaceBase = ModelInterface

def _legacy_use_GPU(self, use_GPU:bool=True):
    """Legacy `use_GPU()`: select 'cuda' if `use_GPU` else 'cpu' via `set_device()`."""
    self.set_device('cuda' if use_GPU else 'cpu')

# `ModelInterface.set_GPU_state` may not be defined; referencing it directly
# raises AttributeError at import time, so fall back to the wrapper above.
ModelInterface.use_GPU = getattr(ModelInterface, 'set_GPU_state', _legacy_use_GPU)
ModelInterface._init_for_train = ModelInterface._set_loss_function