In [1]:
import numpy as np
import pandas as pd

# A2AR activity table (ChEMBL-style columns). Only the SMILES and the pChEMBL value are needed,
# and rows without a pChEMBL value cannot serve as regression targets, so they are dropped.
# Chaining .dropna() (instead of inplace=True) keeps the cell a single, side-effect-free expression.
df = pd.read_csv(
    'data/A2AR_raw.txt',
    sep='\t',
    header=0,
    usecols=('CANONICAL_SMILES', 'PCHEMBL_VALUE'),
    na_values=('NA', 'nan', 'NaN'),
).dropna(subset=['PCHEMBL_VALUE'])
df.head()

Unnamed: 0,CANONICAL_SMILES,PCHEMBL_VALUE
0,CCCCn1cc2c(nc(NC(=O)Nc3ccc(cc3)S(=O)(=O)O)n4nc...,6.45
2,NC1=Nc2c(cnn2CCN3CCC(CC3)N4CCOCC4)C5=NN(Cc6ccc...,7.55
4,COc1cncc(c1)c2cc(NC(=O)CN3CCOCC3)nc(n2)n4nc(C)...,8.4
5,Cc1ccc2c(NN)c(Cc3ccccc3)cnc2n1,6.14
6,Nc1nc2c(cnn2CCCc3ccc(O\C=C\c4ccccc4)cc3)c5nc(n...,6.51


In [2]:
from drugex.training.scorers.predictors import Predictor
from rdkit import Chem

# Parse every SMILES once. RDKit returns None for strings it cannot parse; missing SMILES (NaN)
# are skipped by na_action='ignore' instead of crashing MolFromSmiles.
mols = df['CANONICAL_SMILES'].map(Chem.MolFromSmiles, na_action='ignore')
is_valid = mols.notna()
if not is_valid.all():
    print(f"Dropped {(~is_valid).sum()} rows with missing or unparseable SMILES")

# Features and labels are filtered with the same mask so they stay row-aligned.
X = Predictor.calc_physchem(mols[is_valid].tolist())
y = df.loc[is_valid, 'PCHEMBL_VALUE']
assert len(X) == len(y), "descriptor matrix and labels must stay aligned"

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

SEED = 42

# This regressor becomes the activity oracle for reinforcement learning, so estimate how well it
# generalises on a 20% hold-out before trusting it.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=SEED)
rf_holdout = RandomForestRegressor(random_state=SEED).fit(X_train, y_train)
print(f"Hold-out R^2: {rf_holdout.score(X_val, y_val):.3f}")

# The model used for scoring is refit on all available data.
rf = RandomForestRegressor(random_state=SEED)
rf.fit(X, y)

In [4]:
from drugex.training.interfaces import Scorer

class ModelScorer(Scorer):
    """DrugEx scorer that rates molecules with a fitted regressor.

    Molecules are featurised with ``Predictor.calc_physchem`` (the same descriptors the model
    is trained on in this notebook), and the model's predictions are returned as scores.

    Args:
        model: fitted estimator exposing ``predict(X)``.
        modifier: optional score modifier (e.g. ``ClippedScore``) forwarded to ``Scorer``.
            Without one, the raw predictions are returned; for the model trained here these
            are pChEMBL-scale values rather than values in [0, 1].
    """

    def __init__(self, model, modifier=None):
        super().__init__(modifier)
        self.model = model

    def getScores(self, mols, frags=None):
        """Predict one score per molecule.

        Args:
            mols: SMILES strings and/or RDKit ``Mol`` objects.
            frags: unused; part of the ``Scorer`` interface.

        Returns:
            1-D numpy array with one score per input. Molecules that cannot be parsed
            (or are None) score 0.0 and are never passed to the descriptor calculator.
        """
        # Accept both SMILES and already-parsed molecules.
        parsed = [Chem.MolFromSmiles(m) if isinstance(m, str) else m for m in mols]
        scores = np.zeros(len(parsed))
        valid_idx = [i for i, mol in enumerate(parsed) if mol is not None]
        if valid_idx:  # skip predict() entirely when nothing is scoreable (e.g. empty batch)
            X = Predictor.calc_physchem([parsed[i] for i in valid_idx])
            scores[valid_idx] = self.model.predict(X)
        return scores

    def getKey(self):
        """Key identifying this scorer, e.g. ``ModelScorerRandomForestRegressor``."""
        # Use the bare class name rather than the "<class '...'>" repr.
        return f"ModelScorer{type(self.model).__name__}"
    
# Sanity check: score a well-known compound with the freshly trained model.
CAFFEINE = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
scorer = ModelScorer(rf)
scorer([CAFFEINE])

array([4.78069152])

In [5]:
from drugex.training.scorers.properties import Property
from drugex.training.scorers.modifiers import ClippedScore

# Property objectives. Assumes ClippedScore maps upper_x -> 1 and lower_x -> 0 with a linear
# ramp in between (i.e. lower logP / MW is preferred) — TODO confirm against the DrugEx docs.
logP = Property("logP", modifier=ClippedScore(lower_x=6, upper_x=4))
mw = Property("MW", modifier=ClippedScore(lower_x=1000, upper_x=500))

# Each scorer is written next to its threshold so the two lists cannot drift out of order.
# NOTE(review): `scorer` has no modifier and returns raw pChEMBL-scale predictions (caffeine
# scored ~4.78 above), so the 0.99 threshold is effectively always met — confirm intent.
scorers, thresholds = map(list, zip(
    (scorer, 0.99),
    (logP, 0.5),
    (mw, 0.5),
))

In [6]:
from drugex.training.environment import DrugExEnvironment

# Bundle the activity, logP and MW scorers with their thresholds (presumably matched by list
# position — confirm) into the environment that the explorer below is scored against.
environment = DrugExEnvironment(scorers, thresholds)

In [7]:
from drugex.training.models.explorer import GraphExplorer
from drugex.training.models.transform import GraphModel
from drugex.corpus.vocabulary import VocGraph
import torch

# Deserialise checkpoints onto the GPU when one is available, otherwise onto the CPU, so the
# notebook no longer fails outright on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vocabulary = VocGraph()  # TODO: show how to load a vocabulary from file


def load_graph_model(path, voc=vocabulary):
    """Create a GraphModel over `voc` and load the state dict stored at `path`.

    NOTE: torch.load unpickles the file, which can execute arbitrary code — only load
    checkpoints from trusted sources.
    """
    model = GraphModel(voc_trg=voc)
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    return model


# Loaded in the same order as before: finetuned first, then pretrained.
finetuned = load_graph_model('data/models/finetuned/A2AR_finetuned.pkg')
pretrained = load_graph_model('data/models/pretrained/chembl27_graph.pkg')

# The pretrained model is the agent being optimised; the finetuned model is passed as `mutate`.
explorer = GraphExplorer(pretrained, environment, mutate=finetuned)

In [8]:
from drugex.datasets.processing import GraphFragDataSet


def _load_graph_split(name, path):
    """Build a GraphFragDataSet called `name` and populate it from the file at `path`."""
    dataset = GraphFragDataSet(name)
    dataset.fromFile(path)
    return dataset


train = _load_graph_split('train', 'data/inputs/graph/train.txt')
test = _load_graph_split('test', 'data/inputs/graph/test.txt')

In [9]:
import logging

# Surface INFO-level progress messages (e.g. the per-batch "Forward pass" lines) in the notebook.
logging.basicConfig(level="INFO")

In [10]:
from drugex.training.trainers import Reinforcer
from drugex.training.monitors import FileMonitor

BATCH_SIZE = 512
N_EPOCHS = 2  # short demo run; the saved run above was interrupted manually

reinforcer = Reinforcer(explorer)
monitor = FileMonitor("data/models/reinforced/a2ar_RL")

train_loader = train.asDataLoader(batch_size=BATCH_SIZE)
test_loader = test.asDataLoader(batch_size=BATCH_SIZE)
reinforcer.fit(train_loader, test_loader, monitor=monitor, epochs=N_EPOCHS)

INFO:root:
----------
ITERATION 0/ 1
----------
  0%|                                                                               | 0/2 [00:00<?, ?it/s]INFO:root:Forward pass. Batch 0/23.
INFO:root:Forward pass. Batch 1/23.
INFO:root:Forward pass. Batch 2/23.
INFO:root:Forward pass. Batch 3/23.
  0%|                                                                               | 0/2 [01:32<?, ?it/s]


KeyboardInterrupt: 