source activate my-rdkit-env
#### 코드 수정 후 (After modifying the code)
GENTRL 폴더에서 `python setup.py install` 수행 (run `python setup.py install` from the GENTRL folder)

In [None]:
import gentrl
import os

import torch.distributed as dist
from torch.multiprocessing import Process

In [None]:
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol


def get_num_rings_6(mol):
    """Count the rings of ``mol`` that contain more than six atoms.

    ``mol`` must provide ``GetRingInfo()``, whose result provides
    ``AtomRings()`` yielding one sized collection of atoms per ring.
    """
    ring_info = mol.GetRingInfo()
    return sum(1 for ring in ring_info.AtomRings() if len(ring) > 6)


def penalized_logP(mol_or_smiles, masked=False, default=-5):
    """Score a molecule as ``logP - SA - get_num_rings_6``.

    Args:
        mol_or_smiles: Input handed to ``get_mol`` to obtain the molecule.
        masked: When true, a molecule that fails ``mol_passes_filters``
            receives ``default`` instead of its score.
        default: Value returned when ``get_mol`` yields ``None`` or when the
            molecule is masked out.

    Returns:
        The penalized score, or ``default`` in the two cases above.
    """
    molecule = get_mol(mol_or_smiles)
    if molecule is None:
        return default

    # Score first, then apply the optional filter mask.
    score = logP(molecule) - SA(molecule) - get_num_rings_6(molecule)
    rejected = masked and not mol_passes_filters(molecule)
    return default if rejected else score

In [None]:
def run(rank, size, backend, flag='vae'):
    """Assemble the training settings and start the stage chosen by ``flag``.

    Args:
        rank: Stored under ``args['rank']``.
        size: Stored under ``args['size']``.
        backend: Stored under ``args['backend']``.
        flag: ``'vae'`` calls ``gentrl.distributed_gentrl.train_as_vaelp``;
            ``'rl'`` calls ``gentrl.distributed_gentrl.train_as_rl`` with
            extra RL settings. Any other value builds the settings and
            returns without training.
    """
    # Settings shared by both stages.
    args = {
        'rank': rank,
        'size': size,
        'backend': backend,
        'batch_size': 2500,
        'num_epochs': 2,
        'verbose_step': 50,
        'lr': 1e-4,
        'hvd': True,
        'data_dir': '../examples/train_plogp_plogpm.csv',
        # Mixed-precision related options (apex disabled here).
        'apex': False,
        'sync_bn': False,
        'opt_level': 'O1',
        'keep_batchnorm_fp32': None,
        'loss_scale': None,
    }

    if flag == 'vae':
        print("Start Training VAE")
        gentrl.distributed_gentrl.train_as_vaelp(args)
        print("End Training VAE")
    elif flag == 'rl':
        # RL-specific settings; 'batch_size' replaces the shared value above.
        # NOTE(review): the author's note for num_iterations read
        # "10000*200/1500 = 1350"; its relation to 5000 is unclear - confirm.
        args.update({
            'reward_fn': penalized_logP,
            'num_iterations': 5000,
            'batch_size': 1500,
            'cond_lb': -2,
            'cond_rb': 0,
            'lr_lp': 1e-5,
            'lr_dec': 1e-6,
        })
        print("Start Training RL")
        gentrl.distributed_gentrl.train_as_rl(args)
        print("End Training RL")
    
def init_process(rank, size, flag, fn, backend='nccl'):
    """ Set up the distributed process group, then run ``fn``.

    Sets ``MASTER_ADDR`` to ``127.0.0.1`` and ``MASTER_PORT`` to ``29500``
    in the environment, calls ``dist.init_process_group`` with ``backend``,
    ``rank`` and ``world_size=size``, and finally invokes
    ``fn(rank, size, backend, flag)``.
    """
    os.environ.update(MASTER_ADDR='127.0.0.1', MASTER_PORT='29500')
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, backend, flag)

In [None]:
# ! mkdir -p saved_gentrl_after_rl
# ! mkdir -p saved_gentrl

In [None]:
%%time
# Launch one worker process per rank for each stage listed in `flags`.
size = 4
flags = ['vae']

for flag in flags:
    processes = []
    try:
        # Each worker runs init_process(rank, size, flag, run).
        for rank in range(size):
            p = Process(target=init_process, args=(rank, size, flag, run))
            p.start()
            processes.append(p)

        # Block until every worker of this stage has finished.
        for p in processes:
            p.join()
    finally:
        # Runs on success, on error and on a notebook interrupt: stop and reap
        # every worker that is still alive (not only the last one started), so
        # no worker process is left orphaned.
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()

In [None]:
%%time
# Launch one worker process per rank for each stage listed in `flags`.
size = 4
flags = ['rl']

for flag in flags:
    processes = []
    try:
        # Each worker runs init_process(rank, size, flag, run).
        for rank in range(size):
            p = Process(target=init_process, args=(rank, size, flag, run))
            p.start()
            print(p)
            processes.append(p)

        # Block until every worker of this stage has finished.
        for p in processes:
            p.join()
    finally:
        # Runs on success, on error and on a notebook interrupt: stop and reap
        # every worker that is still alive (not only the last one started), so
        # no worker process is left orphaned.
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()