# Pretraining Models

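This notebook shows how to pretrain a DrugEx generative model on a random sample of molecules taken from the ChEMBL 30 data set. The pretrained model states are written to disk so that they can be loaded again later, as demonstrated at the end of the notebook.

**TODO:** expand this introduction (background, data provenance, expected run time).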

## Pretraining the Graph-Based Transformer

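The steps in this section are:

1. Read the SMILES column of the ChEMBL 30 file and draw a random sample of `MAX_MOLS` molecules.
2. Standardize the sampled molecules.
3. Fragment the molecules (BRICS), encode the resulting fragment-molecule pairs for the graph model, and collect them into a training set and a test set.
4. Train a `GraphModel` for `N_EPOCHS` epochs while a `FileMonitor` records the progress.
5. Inspect the recorded losses and the SMILES logged during training.
6. Load the saved model states into a new `GraphModel` and sample molecules from a few input fragments.

**TODO:** describe the model and the choice of parameters in more detail.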

In [1]:
import pandas as pd

# Run-wide settings; adjust these to the available hardware and time budget.
N_PROC = 12  # passed as n_proc to the standardizer and the fragment encoder
MAX_MOLS = 10_000  # number of rows randomly drawn from the input file
USE_GPUS = (2, 3, 4)  # forwarded to GraphModel(use_gpus=...); machine-specific
N_EPOCHS = 20  # passed as epochs to model.fit

In [2]:
# Load only the 'Smiles' column of the tab-separated ChEMBL file, squeeze the
# one-column frame into a Series, then draw MAX_MOLS rows at random.
# The fixed random_state makes the draw repeatable, so "Restart Kernel -> Run All"
# trains on the same molecules every time; change the value to get another subset.
smiles = (
    pd.read_table('jupyter/data/chembl_30_ALL.smi', sep='\t', header=0, usecols=('Smiles',))
    .squeeze('columns')
    .sample(MAX_MOLS, random_state=42)
)
# Displayed as the cell output: shape of the sampled Series.
smiles.shape

(10000,)

In [3]:
from utils import initLogger

# Set up logging through the project's helper, passing 'pretraining_graph.log'
# (presumably the name of the log file -- confirm in utils.initLogger).
initLogger('pretraining_graph.log')

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from drugex.data.processing import Standardization

# Standardize the sampled SMILES, passing N_PROC as n_proc.
# NOTE(review): `smiles` is rebound to the standardizer's output, so re-running
# this cell on its own feeds the already-processed data back in.
standardizer = Standardization(n_proc=N_PROC)
smiles = standardizer.apply(smiles)
# Displayed as the cell output: number of entries after standardization.
len(smiles)

Standardizing molecules (batch processing): 100%|██████████| 1/1 [00:02<00:00,  2.62s/it]


10000

In [5]:
from drugex.data.datasets import GraphFragDataSet
import os

# Folder that receives the encoded training and test sets.
graph_input_folder = "data/sets/graph/pretraining"
# exist_ok=True creates the folder (including parents) when it is missing and
# does nothing when it already exists, without a separate existence check that
# could race with another process creating the same folder.
os.makedirs(graph_input_folder, exist_ok=True)

# Two data set objects backed by TSV files in that folder.
# NOTE(review): rewrite=True presumably discards any existing file content --
# confirm against GraphFragDataSet.
train = GraphFragDataSet(f"{graph_input_folder}/chembl30_train.tsv", rewrite=True)
test = GraphFragDataSet(f"{graph_input_folder}/chembl30_test.tsv", rewrite=True)

In [6]:
from drugex.data.fragments import FragmentCorpusEncoder, GraphFragmentEncoder, FragmentPairsSplitter
from drugex.molecules.converters.fragmenters import Fragmenter
from drugex.data.corpus.vocabulary import VocGraph

# Build the parts of the encoding pipeline one at a time (same construction
# order as a single nested call would use), then wire them together.
vocabulary = VocGraph(n_frags=4)
fragmenter = Fragmenter(4, 4, 'brics')
graph_encoder = GraphFragmentEncoder(vocabulary)
pairs_splitter = FragmentPairsSplitter(0.1, 1000, make_unique=False)

encoder = FragmentCorpusEncoder(
    fragmenter=fragmenter,
    encoder=graph_encoder,
    pairs_splitter=pairs_splitter,
    n_proc=N_PROC,
    chunk_size=500,
)

# The collectors are passed as [test, train]; keep this order, since it decides
# which data set receives which part of the split.
encoder.apply(smiles, encodingCollectors=[test, train])

Creating fragment-molecule pairs (batch processing): 100%|██████████| 1/1 [00:12<00:00, 12.35s/it]
Encoding fragment-molecule pairs. (batch processing): 100%|██████████| 1/1 [00:01<00:00,  1.82s/it]
Encoding fragment-molecule pairs. (batch processing): 100%|██████████| 3/3 [00:17<00:00,  5.82s/it]


In [7]:
from drugex.training.models.transform.gpt2graph import GraphModel
from drugex.data.corpus.vocabulary import VocGraph
from drugex.training.monitors import FileMonitor

# A fresh vocabulary with the same n_frags value as the one used for encoding.
vocabulary = VocGraph(n_frags=4)
# The model to pretrain, placed on the GPUs listed in USE_GPUS.
model = GraphModel(voc_trg=vocabulary, use_gpus=USE_GPUS)
# The monitor path carries the sample size so that runs with different
# MAX_MOLS values do not share the same path.
monitor_path = f'data/models/pretrained/graph/chembl_sample_{MAX_MOLS}'
monitor = FileMonitor(monitor_path, verbose=True)

In [None]:
# Train for N_EPOCHS, passing a loader over `train` first and a loader over `test`
# second, with `monitor` attached.
# NOTE(review): 512 is presumably the batch size -- confirm against asDataLoader.
model.fit(train.asDataLoader(512), test.asDataLoader(512), monitor=monitor, epochs=N_EPOCHS)

  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Read the table at monitor.outDF (presumably the per-epoch training log written
# by the monitor -- confirm) and display it as the cell output.
df_info = pd.read_table(monitor.outDF)
df_info

In [None]:
# Draw the validation loss and the mean training loss as two lines in one
# figure, plotted against the row index of df_info.
loss_columns = ['loss_valid', 'mean_train_loss']
df_info[loss_columns].plot.line()

In [None]:
# Read the table at monitor.outSmiles (presumably the SMILES logged by the monitor
# during training -- confirm) and display it as the cell output.
df_smiles = pd.read_table(monitor.outSmiles)
df_smiles

In [None]:
from utils import smilesToGrid

# Show only the SMILES logged at epochs divisible by 5, which keeps the grid small.
every_fifth_epoch = df_smiles['Epoch'] % 5 == 0
smilesToGrid(df_smiles.loc[every_fifth_epoch, 'SMILES'])

In [None]:
from drugex.training.models.transform.gpt2graph import GraphModel

# Build a new model with the same vocabulary and GPU settings as the trained one,
# then load states into it from disk.
pretrained = GraphModel(voc_trg=vocabulary, use_gpus=USE_GPUS)
# The states file name is the monitor's path with a '.pkg' suffix appended.
states_file = f'{monitor.path}.pkg'
pretrained.loadStatesFromFile(states_file)

In [None]:
# SMILES strings handed to the loaded model as sampling inputs.
inputs = [
    "c1ccncc1CC2CC2",
    "CC2CC2",
]

# NOTE(review): this rebinds `smiles`, which earlier cells use for the pretraining
# molecules; re-running the encoding cell after this one would encode the sampled
# output instead. Consider a distinct name here and in the cell that displays it.
smiles, frags = pretrained.sampleFromSmiles(inputs, min_samples=100)
# Displayed as the cell output: the distinct values in `frags`.
set(frags)

In [None]:
# Display the `smiles` returned by sampleFromSmiles; smilesToGrid comes from the
# `from utils import smilesToGrid` statement in an earlier cell.
smilesToGrid(smiles)