# Finetuning a Pretrained Model

A pretrained model usually captures general chemistry rules, so it is best trained on a diverse set of molecules (see [pretraining](./pretraining.ipynb)). In drug discovery, however, we usually want to bias such a model toward a more focused area of chemical space, and that is what finetuning does. In this tutorial, we will use the finetuning data sets created [previously](./datasets.ipynb) to bias existing pretrained models. Each model featured here is [available online for download](https://drive.google.com/file/d/1_j5tXXwDDD4wiX9DaPmp3cZM_icm1bUZ/view?usp=sharing). This tutorial assumes that these models have already been placed in the `data/models/pretrained` folder.
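Because every step below depends on the downloaded files being in place, a quick check up front can save a confusing error later. The following is a minimal sketch using only the standard library. The helper name `find_missing_files` is hypothetical (it is not part of DrugEx), and the two file names are the graph model files used in the next section.

```python
from pathlib import Path


def find_missing_files(folder, names):
    """Return the names from `names` that do not exist as files in `folder`."""
    folder = Path(folder)
    return [name for name in names if not (folder / name).is_file()]


# hypothetical helper applied to the files used in the next section
missing = find_missing_files(
    "data/models/pretrained",
    ["chembl27_graph.pkg", "chembl27_graph_voc.txt"],
)
if missing:
    print("Missing pretrained files:", ", ".join(missing))
```

If the list is not empty, download the models from the link above and place them in the folder before continuing.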

## Finetuning a Graph-Based Model

The model we will use was pretrained on the ChEMBL database (version 27), and its states are saved at `data/models/pretrained/chembl27_graph.pkg`. Before we can load the states, the model needs to be initialized, which requires the model's vocabulary. It is saved alongside the model as `data/models/pretrained/chembl27_graph_voc.txt`:

In [1]:
from drugex.training.models.transform import GraphModel
from drugex.corpus.vocabulary import VocGraph

vocabulary = VocGraph.fromFile('data/models/pretrained/chembl27_graph_voc.txt')
pretrained = GraphModel(voc_trg=vocabulary)


Now we can load the states of the model:

In [2]:
import torch

# load the saved states onto the GPU
# (see https://pytorch.org/tutorials/beginner/saving_loading_models.html)
pretrained.load_state_dict(torch.load('data/models/pretrained/chembl27_graph.pkg', map_location=torch.device('cuda')))

<All keys matched successfully>

In [3]:
# attach the model to the GPU with index 0
pretrained.attachToDevices((0,))
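The `torch.load` call above maps the saved tensors to a CUDA device, which raises an error on a machine without a GPU. As a hedged sketch, the device string could instead be chosen at run time and passed to `map_location`, which also accepts a plain string such as `'cpu'`. The helper `pick_device` is hypothetical and not part of DrugEx or PyTorch.

```python
def pick_device():
    """Return 'cuda' if PyTorch reports a usable GPU, otherwise 'cpu'."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so there is nothing to place on a GPU
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print("Loading model states on:", device)
# assumed usage with the objects defined in the cells above:
# pretrained.load_state_dict(
#     torch.load('data/models/pretrained/chembl27_graph.pkg', map_location=device)
# )
```

This only affects where the saved states are loaded. Attaching the model to device 0, as done above, presumably still requires a GPU.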

In [7]:
from drugex.datasets.processing import GraphFragDataSet

# load the finetuning training set created in the data sets tutorial
finetuning_data_train = GraphFragDataSet("train")
finetuning_data_train.fromFile("data/inputs/graph/train.txt")

# load the matching test set
finetuning_data_test = GraphFragDataSet("test")
finetuning_data_test.fromFile("data/inputs/graph/test.txt")

In [8]:
from drugex.training.trainers import FineTuner
from drugex.training.monitors import FileMonitor

finetuner = FineTuner(pretrained)
# the monitor tracks training and writes its files under this path
monitor = FileMonitor('data/models/finetuned/A2AR_finetuned')
finetuner.fit(
    finetuning_data_train.asDataLoader(128),
    finetuning_data_test.asDataLoader(128),
    epochs=3, # only 3 epochs to speed things up
    monitor=monitor
)

# retrieve the finetuned model from the monitor
ft_model = monitor.getModel()
print("Finetuning done.")

100%|███████████████████████████████████████████████████████████████████████| 5/5 [06:01<00:00, 72.23s/it]


OrderedDict([('emb_word.weight',
              tensor([[-0.0205, -0.0666, -0.0142,  ...,  0.0925,  0.0784,  0.0773],
                      [ 0.0902,  0.0458, -0.0819,  ..., -0.0799,  0.0197,  0.0698],
                      [ 0.0079, -0.0163,  0.0373,  ..., -0.0541,  0.0064, -0.0043],
                      ...,
                      [ 0.0149, -0.0911,  0.1481,  ...,  0.0099,  0.0519, -0.0899],
                      [ 0.1012,  0.1131, -0.1165,  ...,  0.2805, -0.1792, -0.0867],
                      [ 0.1037, -0.0345,  0.0282,  ..., -0.0145, -0.0384, -0.1007]],
                     device='cuda:0')),
             ('emb_atom.weight',
              tensor([[-0.0183, -0.0838, -0.0696,  ...,  0.0588,  0.0752,  0.0655],
                      [-0.1270, -0.4553, -0.3298,  ...,  0.4359,  0.0338, -0.1721],
                      [-0.2488, -0.0036,  0.0822,  ...,  0.1169, -0.2072, -0.0414],
                      ...,
                      [-0.0074, -0.0176,  0.1349,  ..., -0.0855, -0.0814, -0.0909],