# 2.2 Adding a new model type to Anvil
<div style="text-align: center">
<img src="anvil_diagram.png" alt="Anvil diagram" width="500"/>  
</div>

### Background
When we work with model types at OpenADMET, we aim to implement them in our main training harness, Anvil, which we then use to mass-produce models across datasets. This makes training and comparing models easy, and makes it simpler to select the best-performing model types.  

Our Anvil harness is implemented in the [openadmet-models repo](https://github.com/OpenADMET/openadmet-models) (its location may change), along with all of the model wrappers and comparison code.  
Generally, at OpenADMET we don’t have a mandate to develop our own model architectures from scratch; rather, we compare and tweak existing architectures, making adjustments where they are well supported by evidence.  

#### What we mean by model "type"
Broadly, we mean an algorithm that takes input features and converts them to some kind of representation that can then be used to predict on new data.  

Model types can be thought of as having two components:

- **Features**
- **Architecture / Algorithm**

| Type A - sklearn | Type B - neural network |
|----------|----------|
| For Type A models, you can select from a number of prebuilt features (e.g. descriptors, molecular fingerprints, SMILES strings as text) and feed them into an array of prebuilt machine learning algorithms (e.g. XGBoost gradient-boosted trees, SVMs, ridge regression). Models of this kind often support a scikit-learn-style interface. The first important step in implementing a model type is recognizing that each of these two components maps to a class that is pulled from a registry (more detail later). | For Type B models, the features and the model architecture are far more tightly linked. This is often the case for neural networks such as GNNs (e.g. GINs), where you need to feed a molecular graph into the input of the network. Technical choices also matter more for neural networks, as various model implementations require the data to be pushed through the library in which the architecture is implemented (e.g. PyTorch). |

#### Type A - sklearn
In our Anvil infrastructure, this is captured in the `procedure` section of the recipe YAML file. Here, we combine some features (RDKit 2D descriptors) with a TabPFN classifier model.
```yaml
procedure:
  split:
...

  feat:
    type: DescriptorFeaturizer
    params:
      descr_type: "desc2d"

  model:
    type: TabPFNClassifierModel
    params:
      ignore_pretraining_limits: True
      device: cpu

  train:
    type: SKLearnBasicTrainer
```
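Each `type:` key in the recipe names a class that is pulled from a registry. The sketch below illustrates that mechanism with a minimal, hypothetical `Registry` class and stub classes that only borrow their names from the recipe above; it is not Anvil's actual registry code, and the real `DescriptorFeaturizer` and `TabPFNClassifierModel` do real work. The `procedure` dictionary mirrors what the `feat` and `model` steps of the YAML would look like once parsed.

```python
from typing import Any, Callable


class Registry:
    """Minimal name -> class registry (illustrative, not Anvil's implementation)."""

    def __init__(self, kind: str):
        self.kind = kind
        self._classes: dict[str, type] = {}

    def register(self, name: str) -> Callable[[type], type]:
        def decorator(cls: type) -> type:
            if name in self._classes:
                raise KeyError(f"{name!r} is already a registered {self.kind}")
            self._classes[name] = cls
            return cls  # hand the class back unchanged
        return decorator

    def get(self, name: str) -> type:
        if name not in self._classes:
            raise KeyError(f"unknown {self.kind} type {name!r}")
        return self._classes[name]


featurizers = Registry("featurizer")
models = Registry("model")


# Stub classes that only borrow their names from the recipe
@featurizers.register("DescriptorFeaturizer")
class DescriptorFeaturizer:
    def __init__(self, descr_type: str = "desc2d"):
        self.descr_type = descr_type


@models.register("TabPFNClassifierModel")
class TabPFNClassifierModel:
    def __init__(self, **params: Any):
        self.params = params


# The `feat` and `model` steps of the recipe's procedure section, as parsed YAML
procedure = {
    "feat": {"type": "DescriptorFeaturizer", "params": {"descr_type": "desc2d"}},
    "model": {
        "type": "TabPFNClassifierModel",
        "params": {"ignore_pretraining_limits": True, "device": "cpu"},
    },
}


def build_step(registry: Registry, step: dict) -> Any:
    """Look up the class named by `type` and instantiate it with `params`."""
    cls = registry.get(step["type"])
    return cls(**step.get("params", {}))


feat = build_step(featurizers, procedure["feat"])
model = build_step(models, procedure["model"])
print(type(feat).__name__, feat.descr_type)  # DescriptorFeaturizer desc2d
```

Because registration happens when the decorator runs, a class is only available to recipes once its module has been imported, which is why new featurizers and models also need to be imported in `registries.py`.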

#### Type B - neural network
Alternatively, for a Type B (neural network) model such as Chemprop, we would use the recipe below. Chemprop is a PyTorch model, so we use a featurizer that computes molecular graphs and wraps them in a PyTorch dataloader. We then have a wrapper around the network architecture that can be trained using PyTorch Lightning.
```yaml
procedure:

  split:
...
  feat:
    type: ChemPropFeaturizer


  model:
    type: ChemPropSingleTaskRegressorModel

  train:
    type: LightningTrainer
    params:
      max_epochs: 1
      accelerator: "cpu"
      use_wandb: false
      wandb_project: "openadmet-testing"
```
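Why are features and architecture so tightly linked for Type B models? A graph neural network consumes a molecular graph (atoms as nodes, bonds as edges) rather than a fixed-length vector, and only produces a fixed-length molecule representation after message passing and pooling. The NumPy sketch below runs one round of simple sum-aggregation message passing on a hand-written graph for ethanol. It is a generic illustration, not Chemprop's implementation, which passes messages along directed bonds and uses learned weights.

```python
import numpy as np

# Ethanol (CCO) written out by hand as a graph of heavy atoms:
# nodes 0 and 1 are carbons, node 2 is the oxygen
atom_symbols = ["C", "C", "O"]
bonds = [(0, 1), (1, 2)]

# Node features: one-hot encoding of the element
elements = ["C", "O"]
X = np.array([[float(sym == el) for el in elements] for sym in atom_symbols])

# Symmetric adjacency matrix built from the bond list
A = np.zeros((len(atom_symbols), len(atom_symbols)))
for i, j in bonds:
    A[i, j] = A[j, i] = 1.0

# One round of message passing: each atom adds its neighbours' features
# to its own (learned weights and nonlinearity omitted for clarity)
H = X + A @ X

# Sum readout: pool the atom states into one fixed-length molecule vector
mol_vector = H.sum(axis=0)
print(mol_vector)  # [5. 2.]
```

The point of the example is the data shape: the network needs the bond list to do its work, so the featurizer must emit a graph in exactly the form the architecture expects.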
### Requirements
- Familiarity with the Anvil workflow, see the [2.1_training_models_with_Anvil](./2.1_Training_models_with_Anvil/) demo.
- You should have the `openadmet-models` repo in an IDE or coding tool of your choice.
## 1. Overview
There's no real code demo in this notebook. Instead, we walk through some guidelines on how to research models and correctly implement them in Anvil for easy use.

## 2. Research models
1. Find a new model architecture and learn more about it:
    1. Does it have a reference implementation? 
    2. Does it have solid community support? 
    3. How easy is it to use? 
    4. How well established is it?
2. Decide whether your model is **Type A or Type B**, i.e. trainable through scikit-learn or through PyTorch. This should be reasonably obvious from looking at the reference implementation. 
    1. What kind of input features does it expect? See the relevant sections below.
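If the documentation does not make it obvious, a quick look at the reference implementation's class usually settles the question. The helper below is a rough, hypothetical heuristic (not part of Anvil): given a model instance, it reports Type B if the class inherits from anything in the `torch`, `lightning`, or `pytorch_lightning` packages, and Type A if the object exposes scikit-learn-style `fit` and `predict` methods. It only inspects module names, so PyTorch does not need to be installed.

```python
def guess_model_type(model: object) -> str:
    """Rough heuristic: 'B' for PyTorch-based classes, 'A' for sklearn-style ones."""
    torch_packages = {"torch", "lightning", "pytorch_lightning"}
    # Walk the class hierarchy and look at each class's top-level package
    for klass in type(model).__mro__:
        package = (klass.__module__ or "").split(".")[0]
        if package in torch_packages:
            return "B"
    # scikit-learn style estimators expose fit() and predict()
    has_fit = callable(getattr(model, "fit", None))
    has_predict = callable(getattr(model, "predict", None))
    if has_fit and has_predict:
        return "A"
    return "unknown"


class SklearnStyleModel:
    """Toy object with a scikit-learn style interface."""

    def fit(self, X, y):
        return self

    def predict(self, X):
        return X


print(guess_model_type(SklearnStyleModel()))  # A
print(guess_model_type(object()))  # unknown
```

Treat the result as a hint only; some libraries wrap PyTorch models behind a scikit-learn-style interface, and those need a closer look.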

## 3. Implement features (if needed)
#### Type A models
Traditional ML and scikit-learn-compatible models generally work with most input feature sets, e.g. descriptors, fingerprints, or precomputed vectors from pretrained models. However, you should check which input features are recommended for your model and whether they are already implemented in `openadmet/models/features`. If not, you should implement them!  

The API for an example featurizer is shown below:
```python
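from typing import Iterable

import numpy as np

from openadmet.models.features.feature_base import featurizers, FeatureBase


@featurizers.register("MyFeaturizer")
class MyFeaturizer(FeatureBase):
    """
    Example featurizer for molecules: turns SMILES strings into a feature matrix
    """

    def featurize(self, smiles: Iterable[str]) -> tuple[np.ndarray, np.ndarray]:
        """
        Featurize a list of SMILES strings

        Returns the feature matrix and the indices of the inputs
        that were featurized successfully
        """
        smiles = list(smiles)
        feat = ...  # compute your features: array of shape (n_molecules, n_features)
        indices = np.arange(len(smiles))  # drop the positions of any inputs that failed
        return feat, indices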
```
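To make the return contract concrete, here is a self-contained, hypothetical featurizer that follows the same `featurize` signature but does not subclass `FeatureBase`, so it runs without `openadmet-models` installed. Its two features (SMILES string length and number of `(` branch openings) are toy string statistics, and empty strings stand in for SMILES that a real toolkit would fail to parse. Failed inputs are skipped, so `indices` lists only the positions that produced a row of `feat`.

```python
from typing import Iterable

import numpy as np


class SmilesStringFeaturizer:
    """Hypothetical featurizer following the (feat, indices) contract."""

    n_features = 2

    def _featurize_one(self, smi: str) -> np.ndarray:
        if not smi:
            # Stand-in for a SMILES string that could not be parsed
            raise ValueError("empty SMILES")
        return np.array([len(smi), smi.count("(")], dtype=float)

    def featurize(self, smiles: Iterable[str]) -> tuple[np.ndarray, np.ndarray]:
        rows, indices = [], []
        for i, smi in enumerate(smiles):
            try:
                rows.append(self._featurize_one(smi))
            except ValueError:
                continue  # failures are simply absent from `indices`
            indices.append(i)  # remember which input this row came from
        feat = np.vstack(rows) if rows else np.empty((0, self.n_features))
        return feat, np.array(indices, dtype=int)


feat, indices = SmilesStringFeaturizer().featurize(["CCO", "", "CC(=O)O"])
print(indices)  # [0 2]
```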


#### Type B models
PyTorch-compatible models are a little more complicated, because you need to map from the usual tabular machine learning representation to a setup that can be trained easily using PyTorch.  

For a neural network, you need to transform your data into a PyTorch `DataLoader`, which can then be consumed by the PyTorch Lightning training framework.  
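Conceptually, a `DataLoader` hands the training loop the dataset in mini-batches. The plain-Python generator below, `iterate_batches` (a hypothetical helper), mimics only the `shuffle=False` batching behaviour; the real `torch.utils.data.DataLoader` additionally collates each batch into tensors, can shuffle, and can load data in worker processes.

```python
from typing import Iterator, Sequence


def iterate_batches(items: Sequence, batch_size: int) -> Iterator[list]:
    """Yield consecutive mini-batches, like a DataLoader with shuffle=False.

    The last batch may be smaller than batch_size, as with the
    DataLoader default of drop_last=False.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


graphs = [f"graph_{i}" for i in range(5)]  # stand-ins for featurized molecules
for batch in iterate_batches(graphs, batch_size=2):
    print(batch)
# ['graph_0', 'graph_1']
# ['graph_2', 'graph_3']
# ['graph_4']
```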

A featurizer for Chemprop is shown below. Somewhat like the Type A featurizers, it returns the featurized data (here a dataloader) and the indices of the successfully featurized inputs, but also a scaler (which can be `None` for no scaling) and the raw PyTorch `Dataset`.
```python
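from typing import Any, Iterable, Optional, Union

import numpy as np
import pandas as pd
from chemprop.data import (
    MoleculeDatapoint,
    MoleculeDataset,
    MulticomponentDataset,
    ReactionDataset,
)
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader

from openadmet.models.features.feature_base import featurizers

# Also import the DeepLearningFeaturizer base class from openadmet-models


@featurizers.register("ChemPropFeaturizer")
class ChemPropFeaturizer(DeepLearningFeaturizer):
    """
    ChemPropFeaturizer featurizer for molecules, relies on chemprop
    """

    normalize_targets: bool = True
    n_jobs: int = 4
    batch_size: int = 128
    shuffle: bool = False

    def _prepare(self):
        """
        Prepare the featurizer
        """

    def featurize(
        self, smiles: Iterable[str], y: Optional[Iterable[Any]] = None
    ) -> tuple[
        DataLoader,
        np.ndarray,
        Optional[StandardScaler],
        Union[MoleculeDataset, ReactionDataset, MulticomponentDataset],
    ]:
        """
        Featurize a list of SMILES strings

        #TODO: we likely want to separate the scaling from the featurization
        """
        # Materialize the input so it can be both iterated and measured
        smiles = list(smiles)

        if y is not None:
            # Accept pandas objects, lists or numpy arrays
            if isinstance(y, (pd.DataFrame, pd.Series)):
                y = y.to_numpy()
            y = np.asarray(y)
            # Targets must be 2D: (n_molecules, n_tasks)
            y = y.reshape(-1, 1) if y.ndim == 1 else y

            dataset = MoleculeDataset(
                [MoleculeDatapoint.from_smi(smi, y_) for smi, y_ in zip(smiles, y)]
            )
            # Standardize the dataset's targets; the fitted scaler is returned
            # so predictions can be mapped back to the original units
            scaler = dataset.normalize_targets() if self.normalize_targets else None
        else:
            dataset = MoleculeDataset(
                [MoleculeDatapoint.from_smi(smi) for smi in smiles]
            )
            scaler = None

        dataloader = self.dataset_to_dataloader(
            dataset,
            num_workers=self.n_jobs,
            shuffle=self.shuffle,
            batch_size=self.batch_size,
        )

        # Need to also return an index of the original input for which the features were computed
        indices = np.arange(len(smiles))

        return dataloader, indices, scaler, dataset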
```
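The scaler returned by the featurizer matters at prediction time: when `normalize_targets` is on, the network is trained on standardized targets, so its raw outputs are on that standardized scale and need to be mapped back to the original units. The NumPy sketch below shows the arithmetic of the standardization (population standard deviation, `ddof=0`, matching scikit-learn's `StandardScaler`) and its inverse on toy targets; it is an illustration, not the code Anvil runs.

```python
import numpy as np

# Toy regression targets with one task, shape (n_molecules, 1)
y = np.array([[1.0], [2.0], [3.0], [6.0]])

# "Fit" the scaler: per-task mean and population standard deviation (ddof=0)
mean = y.mean(axis=0)
std = y.std(axis=0)

# Forward transform: these standardized values are what the network learns
y_scaled = (y - mean) / std

# Pretend the network predicted the standardized targets perfectly...
pred_scaled = y_scaled.copy()
# ...then undo the scaling to report predictions in the original units
pred = pred_scaled * std + mean
print(mean)  # [3.]
```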
A few points to note: 

- You need to compute your feature matrix from the input iterable (in this case SMILES, but it could be anything).
- You should compute features for the whole input array, returning a single feature matrix.
- You can use multiprocessing inside the featurizer; the returned `indices` array records which inputs were featurized successfully.
- You need to register your featurizer by adding the `register` decorator and importing its Python module in `openadmet/models/registries.py`.
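On the consuming side, `indices` is what keeps features and targets aligned when some inputs fail. In the sketch below the feature values are hand-written placeholders and we simply assume the featurizer failed on input 1; indexing the targets with `indices` puts row *k* of `feat` next to entry *k* of the filtered targets, and the same information can be expressed as a boolean mask over the original inputs.

```python
import numpy as np

smiles = ["CCO", "not-a-smiles", "c1ccccc1", "CC(=O)O"]
y = np.array([0.5, 1.2, 2.0, 0.1])  # toy targets, one per input

# Assume the featurizer failed on input 1 and returned these (placeholder values)
feat = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
indices = np.array([0, 2, 3])

# Keep only the targets (and SMILES) whose features were computed,
# so row k of `feat` lines up with entry k of `y_ok`
y_ok = y[indices]
smiles_ok = [smiles[i] for i in indices]

# The same information as a boolean mask over the original inputs
mask = np.zeros(len(smiles), dtype=bool)
mask[indices] = True
print(mask)  # [ True False  True  True]
```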

## 4. Implement model wrapper
Implement your model wrapper in `openadmet/models/architecture/my_model.py`. If the model doesn’t have an external reference implementation, you can build the logic directly into this class.  
The three key API touch points are:  
1. **`build`:** Construct the internal representation of the model so that it is ready for training.
2. **`predict`:** Predict using a trained model.
3. **`train`:** Train the model directly (it is often better to use a **Trainer**).

The API looks like this:
```python
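import logging
from typing import ClassVar, Optional

import numpy as np

from openadmet.models.architecture.model_base import PickleableModelBase, models
from my_model_reference_implementation import MyRefModel  # placeholder: your model's package

logger = logging.getLogger(__name__)


@models.register("MyModel")
class MyModel(PickleableModelBase):
    """
    MyModel: Anvil wrapper around the MyRefModel reference implementation
    """

    type: ClassVar[str] = "MyModel"
    mod_class: ClassVar = MyRefModel
    # Hyperparameters forwarded to MyRefModel (placeholder names, types and defaults)
    option1: int = 1
    option2: float = 0.1

    @classmethod
    def from_params(cls, class_params: Optional[dict] = None, mod_params: Optional[dict] = None):
        """
        Create a model from parameters
        """
        # Avoid mutable default arguments; fall back to empty dicts
        instance = cls(**(class_params or {}), mod_params=mod_params or {})
        instance.build()
        return instance

    def train(self, X: np.ndarray, y: np.ndarray):
        """
        Train the model
        """
        self.build()
        self.estimator = self.estimator.fit(X, y)

    def build(self):
        """
        Prepare the model
        """
        # Compare with None: some estimators (e.g. ensembles) define __len__,
        # so their truthiness is not a reliable "exists" check
        if self.estimator is None:
            self.estimator = self.mod_class(option1=self.option1, option2=self.option2)
        else:
            logger.warning("Model already exists, skipping build")

    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """
        Predict using the model
        """
        if self.estimator is None:
            raise ValueError("Model not trained")
        # Return a 2D array of shape (n_samples, 1)
        return np.expand_dims(self.estimator.predict(X), axis=1)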
```
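If you want to experiment with the wrapper pattern before touching the repo, the self-contained sketch below reproduces the same `from_params` / `build` / `train` / `predict` shape with a plain dataclass instead of `PickleableModelBase`. `LeastSquaresRegressor` is a hypothetical stand-in for a reference implementation with a scikit-learn-style interface, and `ToyModelWrapper` is likewise illustrative, not an Anvil class.

```python
import logging
from dataclasses import dataclass
from typing import Any, ClassVar, Optional

import numpy as np

logger = logging.getLogger(__name__)


class LeastSquaresRegressor:
    """Hypothetical stand-in for a reference implementation (sklearn-style API)."""

    def __init__(self, fit_intercept: bool = True):
        self.fit_intercept = fit_intercept
        self.coef_: Optional[np.ndarray] = None

    def _design(self, X: np.ndarray) -> np.ndarray:
        # Append a column of ones so the last coefficient acts as the intercept
        return np.hstack([X, np.ones((len(X), 1))]) if self.fit_intercept else X

    def fit(self, X: np.ndarray, y: np.ndarray) -> "LeastSquaresRegressor":
        self.coef_, *_ = np.linalg.lstsq(self._design(X), y, rcond=None)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        if self.coef_ is None:
            raise ValueError("Model not trained")
        return self._design(X) @ self.coef_


@dataclass
class ToyModelWrapper:
    """Hypothetical wrapper with the same shape as the Anvil template above."""

    mod_class: ClassVar[type] = LeastSquaresRegressor
    fit_intercept: bool = True  # hyperparameter forwarded to mod_class
    estimator: Any = None

    @classmethod
    def from_params(cls, class_params: Optional[dict] = None) -> "ToyModelWrapper":
        instance = cls(**(class_params or {}))
        instance.build()
        return instance

    def build(self) -> None:
        if self.estimator is None:
            self.estimator = self.mod_class(fit_intercept=self.fit_intercept)
        else:
            logger.warning("Model already exists, skipping build")

    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        if self.estimator is None:
            self.build()
        self.estimator = self.estimator.fit(X, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        if self.estimator is None:
            raise ValueError("Model not trained")
        # One column per task: shape (n_samples, 1)
        return np.expand_dims(self.estimator.predict(X), axis=1)


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * X[:, 0] + 1.0  # exact linear relationship y = 2x + 1
model = ToyModelWrapper.from_params({"fit_intercept": True})
model.train(X, y)
print(model.predict(np.array([[4.0]])))  # approximately [[9.]]
```

Keeping construction (`build`) separate from fitting (`train`) is what lets a Trainer take over the fitting step without the wrapper needing to know about it.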

**As with the featurizer, you have to register the class and import its module in `registries.py`.**

## 5. Testing implementation
First, check that things work the way you expect by doing some Anvil runs with your new setup. Some of the recipes in [`openadmet/models/tests/test_data`](/openadmet/models/tests/test_data/) should serve as good examples, and the [Training a model with Anvil](/demos/2.1_Training_models_with_Anvil/) demo shows how to run them easily. Then, write and test an Anvil recipe that chains together all of the classes you want to use.
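It can also help to smoke-test the wrapper's contract directly, before running a full recipe. The hypothetical `check_model_contract` helper below assumes a wrapper exposing `train(X, y)` and a `predict(X)` that returns an array of shape `(n_samples, 1)`, as in the wrapper template in section 4; `MeanModel` is a trivial model included only so the example runs. In a pytest suite you would call the helper from a test function with your real wrapper.

```python
import numpy as np


def check_model_contract(model, X: np.ndarray, y: np.ndarray) -> None:
    """Train `model` on (X, y) and check that predict() honours the wrapper contract."""
    model.train(X, y)
    preds = model.predict(X)
    expected = (len(X), 1)
    if preds.shape != expected:
        raise AssertionError(f"expected predictions of shape {expected}, got {preds.shape}")
    if not np.all(np.isfinite(preds)):
        raise AssertionError("predictions contain NaN or inf")


class MeanModel:
    """Trivial model that always predicts the training-set mean."""

    def __init__(self):
        self.estimator = None

    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        self.estimator = float(np.mean(y))

    def predict(self, X: np.ndarray) -> np.ndarray:
        if self.estimator is None:
            raise ValueError("Model not trained")
        return np.full((len(X), 1), self.estimator)


X = np.arange(6, dtype=float).reshape(3, 2)
y = np.array([1.0, 2.0, 3.0])
check_model_contract(MeanModel(), X, y)
print("contract OK")
```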

✨✨✨✨✨✨✨