# Finetuning a Pretrained Model

A pretrained model usually captures general chemistry rules and is therefore best trained on a diverse set of molecules (see [pretraining](./pretraining.ipynb)). In drug discovery, however, we usually want to bias such a model toward a more focused area of chemical space, which is what finetuning does. In this tutorial, we will use the finetuning data sets created [previously](./datasets.ipynb) to bias already existing models. The pretrained models used in this tutorial should be placed in the `jupyter/models/pretrained` folder.

## Finetuning a Graph-Based Model

The model we will use was pretrained on the ChEMBL database (version 27), and its states are saved at `jupyter/models/pretrained/graph/chembl27/chembl27_graph.pkg`. Before we can load these states, the model needs to be initialized, and for that we first need its vocabulary:

In [1]:
from drugex.training.models.transform import GraphModel
from drugex.data.corpus.vocabulary import VocGraph

vocabulary = VocGraph.fromFile('jupyter/models/pretrained/graph/chembl27/chembl27_graph_voc.txt')
pretrained = GraphModel(voc_trg=vocabulary)

Now we can load the states of the model:

In [2]:
pretrained.loadStatesFromFile('jupyter/models/pretrained/graph/chembl27/chembl27_graph.pkg')
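
If either of the two cells above fails with a file-not-found error, it can help to check that the expected files are in place first. The following is a minimal sketch using only the Python standard library. The helper `find_missing` is hypothetical and not part of DrugEx; the directory and file names are the ones used in this tutorial.

```python
from pathlib import Path

# Files this tutorial expects (paths taken from the cells above).
PRETRAINED_DIR = Path("jupyter/models/pretrained/graph/chembl27")
REQUIRED_FILES = ["chembl27_graph_voc.txt", "chembl27_graph.pkg"]


def find_missing(directory, file_names):
    """Return the names from `file_names` that are not regular files in `directory`.

    Hypothetical helper for illustration; it is not part of the DrugEx API.
    """
    directory = Path(directory)
    return [name for name in file_names if not (directory / name).is_file()]


missing = find_missing(PRETRAINED_DIR, REQUIRED_FILES)
if missing:
    print(f"Missing from {PRETRAINED_DIR}: {', '.join(missing)}")
else:
    print("All pretrained model files found.")
```

Note that the paths are relative, so the result depends on the directory the notebook is started from.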

### Finetuning

This is not required for pretraining, but it does not hurt to test that the model is working as expected and can generate molecules. 

First, we need to load the finetuning data we created in the [previous tutorial](./datasets.ipynb):

In [11]:
from drugex.data.datasets import GraphFragDataSet

# autoload=True requests that the data from the file is read upon initialization
finetuning_data_train = GraphFragDataSet("data/model_inputs/graph/ligand_train.tsv", autoload=True)
finetuning_data_test = GraphFragDataSet("data/model_inputs/graph/ligand_test.tsv", autoload=True)

We can verify that the data has been loaded by converting it to a pandas `DataFrame`:

In [12]:
finetuning_data_train.getDataFrame().shape

(1165, 400)

In [13]:
finetuning_data_test.getDataFrame().shape

(166, 400)

Finally, we can proceed to finetuning the model (note that we need to transform each `DataSet` into a `DataLoader` first):

In [17]:
from drugex.training.trainers import FineTuner
from drugex.training.monitors import FileMonitor

class SimpleMonitor(FileMonitor):
    """
    Simplified implementation of the FileMonitor. We only need to write basic information for this tutorial.
    """
    
    def savePerformanceInfo(self, current_step=None, current_epoch=None, loss=None, *args, loss_valid=None, best=None, **kwargs):
        if current_epoch:
            self.out.write(f"Current training loss: {loss} \n")
            self.out.write(f"Current validation loss: {loss_valid} \n")
            self.out.write(f"\tBest validation loss: {best} \n")
            self.out.flush()

finetuner = FineTuner(pretrained, gpus=(0,)) # you can specify multiple GPUs to be used with models that support them.
monitor = SimpleMonitor('data/models/finetuned/ligand_finetuned')
finetuner.fit(
    finetuning_data_train.asDataLoader(128),
    finetuning_data_test.asDataLoader(128), 
    epochs=10, # only 10 epochs to speed things up
    monitor=monitor
)

print("Finetuning done.")

100%|█████████████████████████████████████████████████████████████████████| 10/10 [04:31<00:00, 27.16s/it]

Finetuning done.
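
To put the numbers above in perspective: with a batch size of 128, the 1165 training rows and 166 test rows amount to only a handful of batches per epoch. The sketch below is plain arithmetic. It assumes that each `DataFrame` row is one sample and that the data loader keeps the final, smaller batch; neither assumption is confirmed by this tutorial.

```python
import math

BATCH_SIZE = 128   # value passed to asDataLoader() above
TRAIN_ROWS = 1165  # rows reported for the training set
TEST_ROWS = 166    # rows reported for the test set


def batches_per_epoch(n_rows, batch_size):
    """Number of batches when the last partial batch is kept (assumption)."""
    return math.ceil(n_rows / batch_size)


train_batches = batches_per_epoch(TRAIN_ROWS, BATCH_SIZE)
test_batches = batches_per_epoch(TEST_ROWS, BATCH_SIZE)
test_fraction = TEST_ROWS / (TRAIN_ROWS + TEST_ROWS)

print(f"training batches per epoch: {train_batches}")
print(f"test batches per epoch: {test_batches}")
print(f"share of rows held out for testing: {test_fraction:.1%}")
```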





If you want to follow the training progress, you can periodically check the `data/models/finetuned/ligand_finetuned.log` file, which the monitor creates along with the saved model states. The monitor also lets you retrieve the model's states directly:

In [19]:
ft_model = monitor.getModel()
ft_model['emb_word.weight']

tensor([[-0.0205, -0.0666, -0.0142,  ...,  0.0925,  0.0784,  0.0773],
        [ 0.0902,  0.0458, -0.0819,  ..., -0.0799,  0.0197,  0.0698],
        [ 0.0079, -0.0163,  0.0373,  ..., -0.0541,  0.0064, -0.0043],
        ...,
        [ 0.0148, -0.0911,  0.1478,  ...,  0.0099,  0.0517, -0.0901],
        [ 0.1012,  0.1127, -0.1173,  ...,  0.2809, -0.1795, -0.0868],
        [ 0.1036, -0.0344,  0.0282,  ..., -0.0143, -0.0385, -0.1008]],
       device='cuda:0')

You can therefore use the states returned by the monitor to initialize another model instance easily:

In [21]:
other_model = GraphModel(voc_trg=vocabulary)
other_model.load_state_dict(ft_model)

<All keys matched successfully>

For consistency, we will also save the vocabulary we used alongside the model:

In [22]:
vocabulary.toFile('data/models/finetuned/ligand_finetuned.vocab')
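
To review the training progress after the run, the lines written by `SimpleMonitor` can be parsed with a few lines of standard-library Python. This is only a sketch: it assumes that the output of `savePerformanceInfo` ends up in `data/models/finetuned/ligand_finetuned.log` and that losses are written as plain numbers. The helper `read_validation_losses` is hypothetical, and the sample log contents are made up so that the example runs without a finished training run.

```python
import re
import tempfile
from pathlib import Path

# Matches the validation-loss line that SimpleMonitor writes once per epoch.
VALID_LOSS = re.compile(r"^Current validation loss: (\S+)")


def read_validation_losses(log_path):
    """Return the validation losses found in the log, in the order written."""
    losses = []
    for line in Path(log_path).read_text().splitlines():
        match = VALID_LOSS.match(line)
        if not match:
            continue
        try:
            losses.append(float(match.group(1)))
        except ValueError:
            # Skip values that are not plain numbers (for example "None").
            continue
    return losses


# Made-up sample log in the format produced by SimpleMonitor above.
sample_text = (
    "Current training loss: 0.52 \n"
    "Current validation loss: 0.61 \n"
    "\tBest validation loss: 0.61 \n"
    "Current training loss: 0.47 \n"
    "Current validation loss: 0.58 \n"
    "\tBest validation loss: 0.58 \n"
)
with tempfile.TemporaryDirectory() as tmp:
    sample_log = Path(tmp) / "sample.log"
    sample_log.write_text(sample_text)
    losses = read_validation_losses(sample_log)

# Position (counting from 1) of the logged entry with the lowest loss.
best_entry = min(range(len(losses)), key=losses.__getitem__) + 1
print(f"validation losses: {losses}")
print(f"lowest validation loss at logged entry {best_entry}")
```

To use it on a real run, pass the path of the log file to `read_validation_losses` instead of `sample_log`. Lines that do not match the pattern are ignored, so any additional output in the file should not interfere.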