# Specify, train and evaluate a PyTorch-based model using AbstractTorchFSMolModel

We provide framework code so that few-shot models can be evaluated according to our benchmarking procedure using the `AbstractTorchFSMolModel` base class.


In [None]:
# Setting up local details:
import os
import sys

# This should be the location of the checkout of the FS-Mol repository:
FS_MOL_CHECKOUT_PATH = os.path.join(os.environ['HOME'], "Projects", "FS-Mol")
FS_MOL_DATASET_PATH = os.path.join(os.environ['HOME'], "Datasets", "FS-Mol")

os.chdir(FS_MOL_CHECKOUT_PATH)
sys.path.insert(0, FS_MOL_CHECKOUT_PATH)

## Implementing a new `AbstractTorchFSMolModel`

As an example, we use `fs_mol/models/gnn_multitask.py`, our implementation of a GNN model. In this notebook, we repeat the function definitions that integrate it with the larger framework in simplified form.

We need to implement five methods required by the `AbstractTorchFSMolModel` class:

* `forward(self, batch: BatchFeaturesType) -> BatchOutputType`:
  This has to implement the core model, as usual.
  In the GNNMultitask setting, it consumes a `FSMolMultitaskBatch` object, which is created by our batching pipeline and extends the standard `FSMolBatch` with a mapping from each graph in the batch to a unique task ID.
   
  As is standard, our `GNNMultitaskModel` creates all needed sublayers in `__init__`, so `forward` only has to plug them together.
  Concretely, it chains four components:
  * Use `model.init_node_proj` to project the node features to the input for the GNN.
  * Use `model.gnn` to implement the graph-based message passing.
  * Use `model.readout` to compute a graph-level representation for all graphs in a minibatch.
  * Use `model.tail_mlp` to compute per-task predictions for each graph in the minibatch.

  The output of this method should be a `TorchFSMolModelOutput` object, which at least contains the predictions that can be used to calculate a differentiable loss when combined with graph labels from the batch.

* `get_model_state(self) -> ModelStateType` and `load_model_state(self, model_state: ModelStateType, load_task_specific_weights: bool, quiet: bool = False) -> None`:
  Together, these two methods should be sufficient to capture the current state of the model (e.g., its parameters) and to reset the model to that state later.
  They are used to retrieve the best state found during the training loop and to restore it when evaluating on the test set.
  Setting `load_task_specific_weights` to `False` signals that only the generic parameters of a model should be loaded, e.g., when preparing to fine-tune on a new task.

* `is_param_task_specific(self, param_name: str) -> bool`:
  This method determines which parameters are task-specific, so that different learning rates can be used for different parts of the model during fine-tuning.

* `build_from_model_file(...) -> AbstractTorchFSMolModel[BatchFeaturesType]`:
  This is the core factory method creating a fresh model from scratch.
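
To make the interface above more concrete, the following is a minimal sketch of a toy model with the same overall shape. It is not the FS-Mol implementation: `ToyMultitaskModel`, its simplified `forward` arguments, the single linear layer standing in for the GNN, the sum-based readout and the plain dictionary used as model state are all assumptions made for illustration. The factory method `build_from_model_file` is left out.

```python
from typing import Dict

import torch
from torch import nn


class ToyMultitaskModel(nn.Module):
    """Hypothetical stand-in for a multitask GNN; not the FS-Mol implementation."""

    def __init__(self, node_dim: int, hidden_dim: int, num_tasks: int):
        super().__init__()
        self.init_node_proj = nn.Linear(node_dim, hidden_dim)
        self.gnn = nn.Linear(hidden_dim, hidden_dim)  # placeholder for real message passing
        self.tail_mlp = nn.Linear(hidden_dim, num_tasks)

    def forward(
        self,
        node_feats: torch.Tensor,     # [num_nodes, node_dim]
        edges: torch.Tensor,          # [2, num_edges]: source row, target row
        node_to_graph: torch.Tensor,  # [num_nodes]: graph index of each node
        num_graphs: int,
    ) -> torch.Tensor:
        # 1) Project raw node features to the GNN input dimension.
        h = self.init_node_proj(node_feats)
        # 2) One round of message passing: sum source states into target nodes.
        src, dst = edges
        messages = torch.zeros_like(h).index_add_(0, dst, h[src])
        h = torch.relu(self.gnn(h + messages))
        # 3) Readout: sum the node states belonging to each graph.
        graph_repr = h.new_zeros(num_graphs, h.shape[1]).index_add_(0, node_to_graph, h)
        # 4) One logit per task for every graph in the minibatch.
        return self.tail_mlp(graph_repr)

    def is_param_task_specific(self, param_name: str) -> bool:
        # Assumption of this sketch: only the output head holds per-task weights.
        return param_name.startswith("tail_mlp.")

    def get_model_state(self) -> Dict[str, torch.Tensor]:
        # Clone so that later training steps do not change the snapshot.
        return {name: value.detach().clone() for name, value in self.state_dict().items()}

    def load_model_state(
        self, model_state: Dict[str, torch.Tensor], load_task_specific_weights: bool
    ) -> None:
        kept = {
            name: value
            for name, value in model_state.items()
            if load_task_specific_weights or not self.is_param_task_specific(name)
        }
        # strict=False tolerates the task-specific entries that were skipped.
        self.load_state_dict(kept, strict=False)


torch.manual_seed(0)
toy_model = ToyMultitaskModel(node_dim=4, hidden_dim=8, num_tasks=3)
node_feats = torch.randn(5, 4)                # five nodes spread over two graphs
edges = torch.tensor([[0, 1, 3], [1, 2, 4]])  # edges 0->1, 1->2 and 3->4
node_to_graph = torch.tensor([0, 0, 0, 1, 1])
logits = toy_model(node_feats, edges, node_to_graph, num_graphs=2)

# Take a snapshot, perturb all weights, then restore only the generic ones.
snapshot = toy_model.get_model_state()
with torch.no_grad():
    for param in toy_model.parameters():
        param.add_(1.0)
toy_model.load_model_state(snapshot, load_task_specific_weights=False)
```

After the final call, the shared layers are back at their snapshot values while the output head keeps its perturbed weights, which mirrors the situation of preparing a pre-trained model for fine-tuning on a new task.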

## Training an `AbstractTorchFSMolModel` model

We provide a default training loop for implementations of `AbstractTorchFSMolModel`, implemented as `train_loop` in `fs_mol.models.abstract_torch_fsmol_model`. It is particularly useful for fine-tuning a pre-trained model, but can be used to train full models as well.

To use it, we need to assemble a number of components:
* An actual model, implementing `AbstractTorchFSMolModel`.
* A data pipeline providing appropriate minibatches to the model.
* A validation function which will evaluate the model on the validation tasks during training.

We will show these steps in detail below.

### Creating the `GNNMultitask` model

We first create the actual model - this is very much like any other PyTorch model, just requiring the interface discussed at the top of this notebook. We do not go into the details of the `GNNMultitask` model here.

In [None]:
import torch
from fs_mol.models.gnn_multitask import (
    GNNMultitaskConfig,
    GNNMultitaskModel,
    GNNConfig,
    create_model,
)
from fs_mol.data.fsmol_dataset import FSMolDataset, DataFold

fsmol_dataset = FSMolDataset.from_directory(FS_MOL_DATASET_PATH)

# Set up an output directory in which to save a model
out_dir = os.path.join(os.getcwd(), "test")
os.makedirs(out_dir, exist_ok=True)

# Set up the model configuration that specifies a GNNMultitaskModel, using mostly default parameters.
# Consult fs_mol/models/gnn_multitask.py for a full list.
model_config = GNNMultitaskConfig(
    num_tasks=fsmol_dataset.get_num_fold_tasks(DataFold.TRAIN),
    gnn_config=GNNConfig(
        type="PNA",
        hidden_dim=128,
    ),
)

# create an instance of a model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model(model_config, device=device)

print(f"Num parameters {sum(p.numel() for p in model.parameters())}")
print(f"Model:\n{model}")

### Setting up Batching

To train the model, we first need to set up a data pipeline, which we can do using our existing batching infrastructure (see notebooks/dataset.ipynb for more details):

In [None]:
from fs_mol.data.multitask import MultitaskTaskSampleBatchIterable

# note we need this dictionary to connect task numbers with names
train_task_name_to_id = {
    name: i for i, name in enumerate(fsmol_dataset.get_task_names(data_fold=DataFold.TRAIN))
}

train_data = MultitaskTaskSampleBatchIterable(
    fsmol_dataset,
    data_fold=DataFold.TRAIN,
    task_name_to_id=train_task_name_to_id,
    max_num_graphs=256,
)
print(f"Have {len(train_task_name_to_id)} training tasks.")
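
The roles of `task_name_to_id` and `max_num_graphs` can be illustrated with a small, self-contained sketch. It does not use the FS-Mol batching code: the `samples` list, the `iter_batches` helper and the dictionary-shaped batches are hypothetical, and the actual iterable works on featurised molecular graphs and may apply further limits.

```python
from typing import Dict, Iterator, List, Tuple

# Hypothetical toy data: (task name, molecule identifier) pairs.
samples: List[Tuple[str, str]] = [
    ("task_a", "mol_0"),
    ("task_b", "mol_1"),
    ("task_a", "mol_2"),
    ("task_c", "mol_3"),
    ("task_b", "mol_4"),
]

# Assign each task name a stable integer ID, as in the cell above.
toy_task_names = sorted({task for task, _ in samples})
toy_task_name_to_id: Dict[str, int] = {name: i for i, name in enumerate(toy_task_names)}


def iter_batches(
    data: List[Tuple[str, str]], name_to_id: Dict[str, int], max_num_graphs: int
) -> Iterator[dict]:
    """Yield batches of at most `max_num_graphs` graphs, each graph tagged with a task ID."""
    for start in range(0, len(data), max_num_graphs):
        chunk = data[start : start + max_num_graphs]
        yield {
            "graphs": [mol for _, mol in chunk],
            "task_ids": [name_to_id[task] for task, _ in chunk],
        }


toy_batches = list(iter_batches(samples, toy_task_name_to_id, max_num_graphs=2))
for batch in toy_batches:
    print(batch)
```

The per-graph task IDs are what allow a multitask model to pick the right output for each graph when computing the loss.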

### Defining a Validation Function

Our generic `train_loop` takes a function to evaluate the current model to determine when to stop training (using early stopping) and to identify the best checkpoint.
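
How a validation function can drive early stopping and checkpoint selection is sketched below in plain Python. This is not the FS-Mol `train_loop`: `toy_train_loop`, its `patience` argument, the per-epoch validation schedule and the dictionary used as "model state" are assumptions chosen to keep the example short.

```python
from typing import Callable, Dict, Tuple

State = Dict[str, float]


def toy_train_loop(
    state: State,
    train_step: Callable[[State], None],
    valid_fn: Callable[[State], float],
    max_num_epochs: int,
    patience: int,
) -> Tuple[float, State]:
    """Train until `valid_fn` has not improved for `patience` consecutive epochs."""
    best_metric = valid_fn(state)
    best_state = dict(state)  # copy, so that further training does not overwrite it
    epochs_without_improvement = 0
    for _ in range(max_num_epochs):
        train_step(state)
        metric = valid_fn(state)
        if metric > best_metric:
            # New best checkpoint: remember both the score and the state.
            best_metric, best_state = metric, dict(state)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping
    return best_metric, best_state


# Toy problem: the "model" is a single number, and validation peaks at w == 3.
def toy_step(state: State) -> None:
    state["w"] += 1.0


def toy_score(state: State) -> float:
    return -abs(state["w"] - 3.0)


toy_best_metric, toy_best_state = toy_train_loop(
    {"w": 0.0}, toy_step, toy_score, max_num_epochs=10, patience=2
)
print(toy_best_metric, toy_best_state)
```

Training continues past the optimum for `patience` further epochs, but the returned state is the one that scored best on validation, not the last one reached.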

While there are many potential validation functions, a particularly obvious one is to evaluate how well the model performs on the validation tasks when fine-tuned on them, for which we provide an `eval_model_by_finetuning_on_task` function. Internally, that function will kick off another `train_loop` on a single task sample and report the results of the fine-tuned model on the test set of that task sample.

In [None]:
import tempfile
import numpy as np

from fs_mol.data import FSMolTaskSample
from fs_mol.data.multitask import get_multitask_inference_batcher
from fs_mol.models.abstract_torch_fsmol_model import save_model, eval_model_by_finetuning_on_task
from fs_mol.utils.metrics import BinaryEvalMetrics
from fs_mol.utils.test_utils import eval_model

def validate_by_finetuning_on_tasks(
    model: GNNMultitaskModel,
    seed: int = 0,
) -> float:
    with tempfile.TemporaryDirectory() as tempdir:
        # First, store the current state of the model, so that we can just load it back in
        # repeatedly as starting point during finetuning:
        current_model_path = os.path.join(tempdir, "cur_model.pt")
        save_model(current_model_path, model)

        # Move model off GPU to make space for validation model:
        model_device = model.device
        model = model.to(torch.device("cpu"))

        def test_model_fn(
            task_sample: FSMolTaskSample, temp_out_folder: str, seed: int
        ) -> BinaryEvalMetrics:
            return eval_model_by_finetuning_on_task(
                current_model_path,
                model_cls=GNNMultitaskModel,
                task_sample=task_sample,
                batcher=get_multitask_inference_batcher(max_num_graphs=256),
                learning_rate=0.00005,
                task_specific_learning_rate=0.0001,
                seed=seed,
                quiet=True,
                device=model_device,
            )

        task_to_results = eval_model(
            test_model_fn=test_model_fn,
            dataset=fsmol_dataset,
            train_set_sample_sizes=[16],
            num_samples=1,
            valid_size_or_ratio=0.2,
            test_size_or_ratio=128,
            fold=DataFold.VALIDATION,
            seed=seed,
        )

        model = model.to(model_device)

        # Compute mean of average precisions per task:
        return np.mean(
            [
                np.mean([task_result.avg_precision for task_result in task_results])
                for task_results in task_to_results.values()
            ]
        )
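
The value returned by the validation function is a mean of per-task means, which weights every task equally regardless of how many samples were drawn for it. The following sketch shows the arithmetic on made-up numbers; `ToyMetrics` is a hypothetical stand-in that models only the `avg_precision` field.

```python
from dataclasses import dataclass
from statistics import mean
from typing import Dict, List


@dataclass
class ToyMetrics:
    """Hypothetical stand-in for a metrics object; only `avg_precision` is modelled."""

    avg_precision: float


# Made-up results: two samples for task_a, one sample for task_b.
toy_task_to_results: Dict[str, List[ToyMetrics]] = {
    "task_a": [ToyMetrics(0.5), ToyMetrics(0.75)],
    "task_b": [ToyMetrics(1.0)],
}

# First average within each task, then average across tasks.
per_task = {
    name: mean(result.avg_precision for result in results)
    for name, results in toy_task_to_results.items()
}
mean_of_means = mean(per_task.values())

# For comparison: pooling all samples lets tasks with more samples dominate.
pooled_mean = mean(
    result.avg_precision for results in toy_task_to_results.values() for result in results
)
print(per_task, mean_of_means, pooled_mean)
```

With one sample per task, as configured above via `num_samples=1`, both averages coincide; they differ only when tasks contribute different numbers of samples.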

### Running Training

We are now ready to run training combining all of our components:

In [None]:
from fs_mol.models.abstract_torch_fsmol_model import (
    train_loop,
    create_optimizer,
)

validate_by_finetuning_on_tasks(model)

# create a specific optimizer with learning rate for training
optimizer, lr_scheduler = create_optimizer(model, lr=0.00005, task_specific_lr=0.0001)

# Run the training loop for a single epoch. This takes ~15 minutes on a P100 and considerably longer on a CPU-only machine.
best_valid_metric, best_model_state = train_loop(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_data=train_data,
    valid_fn=validate_by_finetuning_on_tasks,
    max_num_epochs=1,
)

# restore best model parameters:
model.load_model_state(best_model_state, load_task_specific_weights=True)
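
One common way to give generic and task-specific parameters different learning rates in PyTorch is to use optimizer parameter groups. The sketch below shows that technique on a toy model; it is an illustration under stated assumptions, not the body of `create_optimizer`. The names `toy_net` and `toy_is_param_task_specific`, the choice of Adam and the rule that the output head is task-specific are all hypothetical.

```python
import torch
from torch import nn

# Toy model with a shared trunk and a per-task output head.
toy_net = nn.Sequential()
toy_net.add_module("shared", nn.Linear(8, 8))
toy_net.add_module("task_head", nn.Linear(8, 3))


def toy_is_param_task_specific(param_name: str) -> bool:
    # Assumption of this sketch: parameters of the output head are task-specific.
    return param_name.startswith("task_head.")


# Split the parameters into two groups by name.
shared_params, task_params = [], []
for name, param in toy_net.named_parameters():
    if toy_is_param_task_specific(name):
        task_params.append(param)
    else:
        shared_params.append(param)

# Each group gets its own learning rate (values taken from the cell above).
toy_optimizer = torch.optim.Adam(
    [
        {"params": shared_params, "lr": 0.00005},
        {"params": task_params, "lr": 0.0001},
    ]
)
print([group["lr"] for group in toy_optimizer.param_groups])
```

Using a larger rate for the task-specific part lets a freshly initialised head adapt quickly while the pre-trained shared layers change more slowly.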

### Evaluating the Model

Finally, we can use the generic evaluation infrastructure (see notebooks/evaluation.ipynb) to evaluate the trained model:

In [None]:
import tempfile
from fs_mol.models.abstract_torch_fsmol_model import save_model

with tempfile.TemporaryDirectory() as temp_dir:
    model_weights_file = os.path.join(temp_dir, "model.pt")
    save_model(model_weights_file, model)

    def test_model_fn(
        task_sample: FSMolTaskSample, temp_out_folder: str, seed: int
    ) -> BinaryEvalMetrics:
        return eval_model_by_finetuning_on_task(
            model_weights_file,
            model_cls=GNNMultitaskModel,
            task_sample=task_sample,
            batcher=get_multitask_inference_batcher(max_num_graphs=256),
            learning_rate=0.00005,
            task_specific_learning_rate=0.0001,
            seed=seed,
            quiet=True,
            device=device,
        )

    eval_results = eval_model(
        test_model_fn=test_model_fn,
        dataset=fsmol_dataset,
        # Require a validation set so that fine-tuning can work
        valid_size_or_ratio=0.2,
        # Restrict number of samples to one per task:
        train_set_sample_sizes=[16],
        num_samples=1,
    )

eval_results