In [None]:

import os
import sklearn
import gentrl
import torch
import pandas as pd
from tqdm import tqdm

# moses and rdkit
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit import RDLogger
from torch.utils.data import DataLoader
RDLogger.DisableLog('rdApp.*')

import warnings
warnings.filterwarnings("ignore",category=UserWarning)

import random

import numpy as np

# Restrict which GPUs this process may see. setdefault keeps a device list that
# was already exported in the environment instead of silently overriding it.
# NOTE: CUDA only honours this variable if it is set before CUDA is first
# initialised in the process.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1,2,3')

# Use the first visible GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Only call into the CUDA API when CUDA is available; an unconditional
    # torch.cuda.set_device(0) fails on CPU-only machines.
    torch.cuda.set_device(device)

# Seed every RNG the run may draw from (DataLoader shuffling, weight init, ...)
# so that a Restart & Run All reproduces the same training run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Training hyper-parameters and the directory that receives training outputs.
MODEL_PATH = "ddr1_models"
BATCH_SIZE = 1000
NUM_EPOCHS = 100
LR = 1e-3

# Data sources for gentrl.MolecularDataset, one dict per CSV file:
#   path   - CSV file to read
#   smiles - name of the column holding the SMILES strings
#   prob   - per-source weight (presumably the relative sampling probability;
#            the weights below add up to 1.0)
#   label  - column mapped onto the 'label' property
DATA = [
    {'path': csv_path, 'smiles': smiles_column, 'prob': weight, 'label': 'label'}
    for csv_path, smiles_column, weight in (
        ('ddr1_datasets/ZINC_IHAD_~100k_clean.csv', 'smiles', 0.175),
        ('ddr1_datasets/train_moses_all.csv', 'SMILES', 0.175),
        ('ddr1_datasets/ddr1_inhibitors.csv', 'smi', 0.35),
        ('ddr1_datasets/common_inhibitors.csv', 'smi', 0.15),
        ('ddr1_datasets/none_kinase_target_compounds.csv', 'smi', 0.15),
    )
]

In [None]:
def init_model(latent_size=50, hidden_size=128, n_components=10, tt_int=30, beta=0.001):
    """Build a GENTRL model (RNN encoder + dilated-conv decoder) on ``device``.

    The latent dimensionality appears in three places that must agree: the
    encoder output, the decoder input and the length of the latent descriptor
    list. All three are therefore driven by the single ``latent_size`` argument.
    Defaults reproduce the original hard-coded configuration.

    Args:
        latent_size: size of the latent vector shared by encoder and decoder.
        hidden_size: hidden size forwarded to ``gentrl.RNNEncoder``.
        n_components: second element of every ``('c', k)`` descriptor.
            Presumably the per-variable component count used by GENTRL's
            prior; confirm against the gentrl documentation.
        tt_int: forwarded to ``gentrl.GENTRL`` as ``tt_int``.
        beta: forwarded to ``gentrl.GENTRL`` as ``beta``.

    Returns:
        The ``gentrl.GENTRL`` model, moved to the module-level ``device``.
    """
    enc = gentrl.RNNEncoder(latent_size=latent_size, hidden_size=hidden_size)
    dec = gentrl.DilConvDecoder(latent_input_size=latent_size)
    # One ('c', k) descriptor per latent dimension...
    latent_descr = latent_size * [('c', n_components)]
    # ...and one for the single property the dataset provides
    # (presumably 'label', the only prop requested from MolecularDataset).
    feature_descr = [('c', n_components)]
    model = gentrl.GENTRL(enc, dec, latent_descr, feature_descr, tt_int=tt_int, beta=beta)
    return model.to(device)

In [None]:
# Combine all CSV sources listed in DATA into one dataset, requesting only the
# 'label' property (each DATA entry maps 'label' to a column of its CSV).
model_data = gentrl.MolecularDataset(props=['label'], sources=DATA)

# Shuffled mini-batches for training. drop_last=True discards a trailing
# partial batch so every batch has exactly BATCH_SIZE samples.
train_loader = DataLoader(
    model_data,
    batch_size=BATCH_SIZE,
    drop_last=True,
    num_workers=4,
    shuffle=True,
)

In [None]:
# Create the output directory. exist_ok=True lets this cell be re-run (e.g.
# after a kernel restart) instead of failing with FileExistsError.
# NOTE(review): files written into MODEL_PATH by a previous run (presumably
# checkpoints from increase_vaelp_validity) may be overwritten.
os.makedirs(MODEL_PATH, exist_ok=True)

model = init_model()
global_stats, local_stats, stats_dictionary = model.increase_vaelp_validity(
    train_loader,
    lr=LR,
    num_epochs=NUM_EPOCHS,
    file_path=MODEL_PATH,
    dec_ratio=0,
)

# Persist the returned loss history next to the other training outputs,
# without writing the DataFrame index as an extra column.
pd.DataFrame(stats_dictionary).to_csv(os.path.join(MODEL_PATH, "losses.csv"), index=False)