In [None]:
import os

import gentrl
import pandas as pd
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Pool, Process
# torch.cuda.set_device(3)

In [None]:
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol


def get_num_rings_6(mol):
    """Count the rings of ``mol`` that contain more than six atoms.

    Despite the ``_6`` suffix, six-membered rings are NOT counted; only
    rings whose atom tuple is strictly longer than six are.

    Args:
        mol: Object exposing ``GetRingInfo().AtomRings()``, an iterable of
            per-ring atom-index sequences.

    Returns:
        int: Number of rings with ``len(ring) > 6``.
    """
    ring_atoms = mol.GetRingInfo().AtomRings()
    return sum(1 for ring in ring_atoms if len(ring) > 6)


def penalized_logP(mol_or_smiles, masked=False, default=-5):
    """Score a molecule as ``logP - SA - (number of rings with > 6 atoms)``.

    Args:
        mol_or_smiles: Input handed to ``get_mol``; presumably a SMILES
            string or an already-built molecule object (verify against moses).
        masked: When true, molecules that fail ``mol_passes_filters`` receive
            ``default`` instead of their computed score.
        default: Value returned when ``get_mol`` yields ``None`` or when the
            molecule is masked out.

    Returns:
        The penalized score, or ``default`` in the cases above.
    """
    molecule = get_mol(mol_or_smiles)
    if molecule is None:
        return default

    # The score is computed before the filter check, exactly as before, so any
    # error raised by the scoring helpers still surfaces for masked molecules.
    score = logP(molecule) - SA(molecule) - get_num_rings_6(molecule)
    rejected = masked and not mol_passes_filters(molecule)
    return default if rejected else score

In [None]:
def run(rank, size, backend, flag='vae'):
    """Run one GENTRL training stage inside a distributed worker.

    Builds the ``args`` dict handed to ``gentrl.distributed_gentrl`` and
    dispatches on ``flag``.

    Args:
        rank: Index of this worker; stored as ``args['rank']``.
        size: Total number of workers; stored as ``args['size']``.
        backend: Distributed backend name; stored as ``args['backend']``.
        flag: ``'vae'`` calls ``train_as_vaelp``; ``'rl'`` calls
            ``train_as_rl`` with the extra RL settings below.

    Raises:
        ValueError: If ``flag`` is neither ``'vae'`` nor ``'rl'``. Previously
            such a value silently skipped training altogether.
    """
    # Validate first so a typo fails loudly instead of producing a no-op run.
    if flag not in ('vae', 'rl'):
        raise ValueError(
            "unknown flag {!r}; expected 'vae' or 'rl'".format(flag)
        )

    args = {
        'rank': rank,
        'size': size,
        'backend': backend,
        'batch_size': 2500,
        'num_epochs': 20,
        'verbose_step': 50,
        'lr': 1e-4,
        # NOTE(review): the key says "dir" but the value is a CSV file path.
        'data_dir': '../examples/train_plogp_plogpm.csv',
    }

    if flag == 'vae':
        print("Start Training VAE")
        gentrl.distributed_gentrl.train_as_vaelp(args)
        print("End Training VAE")
    else:  # flag == 'rl'
        args['reward_fn'] = penalized_logP
        # Author's sizing note: 10000*200/1500 ~= 1333 (the note said 1350).
        # TODO confirm why 5000 iterations were chosen.
        args['num_iterations'] = 5000
        args['batch_size'] = 1500  # overrides the 2500 used for the VAE stage
        args['cond_lb'] = -2
        args['cond_rb'] = 0
        args['lr_lp'] = 1e-5
        args['lr_dec'] = 1e-6
        print("Start Training RL")
        gentrl.distributed_gentrl.train_as_rl(args)
        print("End Training RL")
    
def init_process(rank, size, flag, fn, backend='gloo'):
    """Initialize the distributed environment, then run ``fn`` in it.

    Args:
        rank: Rank passed to ``dist.init_process_group`` and to ``fn``.
        size: World size passed to ``dist.init_process_group`` and to ``fn``.
        flag: Forwarded unchanged as the last argument of ``fn``.
        fn: Callable invoked as ``fn(rank, size, backend, flag)``.
        backend: Backend name for ``dist.init_process_group``.
    """
    # Rendezvous endpoint is fixed to localhost:29500; any existing values of
    # these environment variables are overwritten.
    os.environ.update({
        'MASTER_ADDR': '127.0.0.1',
        'MASTER_PORT': '29500',
    })
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, backend, flag)

In [None]:
# Build the training CSV: keep only rows whose SPLIT column is 'train' and
# add a 'plogP' column computed from each SMILES string.
df = pd.read_csv('dataset_v1.csv')
# Take an explicit copy so the column assignment below writes to an
# independent frame rather than to a filtered view of the original
# (avoids pandas' chained-assignment / SettingWithCopy warning).
df = df.loc[df['SPLIT'] == 'train'].copy()
df['plogP'] = df['SMILES'].apply(penalized_logP)
# index=False is the documented way to omit the row index from the output.
df.to_csv('train_plogp_plogpm.csv', index=False)

In [None]:
# Assemble the GENTRL model from an RNN encoder and a dilated-conv decoder.
latent_size = 50  # shared by the encoder output and the decoder input
enc = gentrl.RNNEncoder(latent_size=latent_size)
dec = gentrl.DilConvDecoder(latent_input_size=latent_size)
# Third argument: 50 copies of ('c', 20); fourth: a single ('c', 20).
# Presumably per-latent-dimension and per-property descriptors -- TODO confirm
# against the GENTRL constructor.
latent_spec = [('c', 20)] * latent_size
property_spec = [('c', 20)]
model = gentrl.GENTRL(enc, dec, latent_spec, property_spec, beta=0.001)
model.cuda();

In [None]:
from torch.utils.data import DataLoader

# Single data source: the CSV written by the data-preparation cell, with the
# 'plogP' column exposed as the one property.
plogp_source = {
    'path': 'train_plogp_plogpm.csv',
    'smiles': 'SMILES',
    'prob': 1,
    'plogP': 'plogP',
}
md = gentrl.MolecularDataset(sources=[plogp_source], props=['plogP'])

# Shuffled batches of 50; the last incomplete batch is dropped.
train_loader = DataLoader(
    md,
    batch_size=50,
    shuffle=True,
    num_workers=1,
    drop_last=True,
)

In [None]:
# Train `model` on batches from `train_loader` with learning rate 1e-4.
# As the last expression of the cell, whatever train_as_vaelp returns is
# displayed by the notebook.
model.train_as_vaelp(train_loader, lr=1e-4)

In [None]:
# Create the checkpoint directory if it does not exist yet. Pure Python
# instead of a shell escape, so it also works where `mkdir -p` is unavailable;
# exist_ok=True keeps the cell safe to re-run.
os.makedirs('saved_gentrl', exist_ok=True)

In [None]:
# Save the trained model under ./saved_gentrl/.
# Presumably the directory must already exist (it is created by the
# preceding cell) -- TODO confirm against gentrl's save().
model.save('./saved_gentrl/')