In [105]:
# Import packages
import torch
from torch.optim import Adam
from Model.SURGE import SURGE
from Reinforcement_Learning.mol_env import vectorized_mol_env
from Model.graph_embedding import batch_from_states
import numpy as np
from tqdm import trange
import datetime
import wandb
import copy
import rdkit.Chem as Chem
import rdkit.Chem.QED as QED # QED
import sys # SAS
import os
from rdkit.Chem import RDConfig
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer
from rdkit.Chem import Descriptors # MW
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [106]:
model = SURGE()
max_steps = 200
num_envs = 4
env = vectorized_mol_env(num_envs = num_envs, max_steps = max_steps) # Vectorized molecular environment
lr = 0.02
optimizer = Adam(lr = lr, params = model.parameters())

In [98]:
gamma = 0.99
eps = np.finfo(np.float32).eps.item() # Small constant to decrease numerical instability
num_episodes = 500

In [None]:
current_datetime = datetime.datetime.now()
current_time = current_datetime.strftime("%Y-%m-%d-%H-%M-%S")

wandb.init(
    project = 'RL_Drug_Generation',
    name= f'REINFORCE---{current_time}',
    config={
        'lr': lr,
        'architecture': str(model),
        'episodes': num_episodes,
        'gamma': gamma,
        'num_envs': num_envs,
    })

In [99]:
def evaluate(saved_states, saved_actions):
    """
    This methods evaluates the molecules generated during the episode and returns diversity, novelty, and validity.
    It also analyzes the molecule metrics of drug-likeness (QED), synthetic accessibility (SAS), and size.
    It only analyzes the finished molecules (i.e., after SURGE tells generation to terminate).
    """
    
    if 1 in saved_actions['t']:
        flattened_states = []
        flattened_t = []
        for i in range(saved_states.shape[1]):
            states_col = saved_states[:, i].flatten().tolist()
            t_col = saved_actions['t'][:, i].flatten().tolist()
            if i == 0:
                flattened_states = states_col
                flattened_t = t_col
            else:
                flattened_states += states_col
                flattened_t += t_col
        
        idx = [i for i, t in enumerate(flattened_t) if t == 1]
        finished_states = [flattened_states[idx[i]] for i in range(len(idx))]
    else:
        finished_states = [saved_states[i, saved_states.shape[1] - 1] for i in range(saved_states.shape[0])]
    
    
    # QED: From 0 to 1 where 1 is the most drug-like
    # SAS: From 1 to 10 where 1 is the easiest to synthesize
    metric_names = ['Average Size', 'Average MW', 'Average QED', 'Average SAS']
    num_mols = len(finished_states)
    metrics = dict()
    for name in metric_names:
        metrics[name] = 0
        
    for i in range(num_mols):
        finished_states[i].UpdatePropertyCache()
        if finished_states[i].GetNumAtoms() == 1:
            metrics['Average SAS'] = 0
        else:
            metrics['Average SAS'] += sascorer.calculateScore(finished_states[i])
        
        metrics['Average Size'] += finished_states[i].GetNumAtoms()
        metrics['Average MW'] += Descriptors.MolWt(finished_states[i])
        metrics['Average QED'] += QED.weights_mean(finished_states[i])
    
    for name in metric_names:
        metrics[name] /= num_mols
    
    return metrics

In [100]:
def calc_returns(saved_rewards, gamma):
    """
    Given multidimensional matrix of rewards, computes the returns.
    """
    prev_row = []
    for reward_idx in reversed(range(saved_rewards.shape[0])):
        cur_row = saved_rewards[reward_idx, :]
        if reward_idx == saved_rewards.shape[0] - 1:
            all_returns = cur_row
            prev_row = cur_row
        else:
            all_returns = torch.vstack((prev_row * gamma + cur_row, all_returns))
            prev_row *= gamma
            prev_row += cur_row
    return all_returns

In [107]:
# Training loop
best_reward_model = copy.deepcopy(model)
best_reward = 0
best_episode = 0
for episode in range(1, num_episodes + 1):

    # Reset environment after each episode
    states = env.reset()

    # Episode loggers
    keys = ['t', 'nmol', 'nfull', 'b']
    saved_actions = {'t': [], 'nmol': [], 'nfull': [], 'b': []}
    saved_log_probs = {'t': [], 'nmol': [], 'nfull': [], 'b': []}
    saved_rewards = []
    saved_states = np.array(states, dtype = object)

    # Episode computation
    pbar = trange(max_steps, unit="steps")
    for step in pbar:
        pbar.set_description(f"Episode {episode}")
        
        # Compute actions and log probabilities
        batch = batch_from_states(states)
        actions, log_probs = model.act(batch)

        # Record in episode loggers
        for key in keys:
            if step == 0:
                saved_actions[key] = actions[key]
                saved_log_probs[key] = log_probs[key]
            else:
                saved_actions[key] = np.vstack((saved_actions[key], actions[key]))
                saved_log_probs[key] = torch.vstack((saved_log_probs[key], log_probs[key]))

        # Take a step in environment
        states, rewards, valids, timestep = env.step(actions['t'], actions['nmol'], actions['nfull'], actions['b'])
        saved_states = np.vstack((saved_states, states))

        # Record rewards
        if step == 0:
            saved_rewards = torch.tensor(rewards)
        else:
            saved_rewards = torch.vstack((saved_rewards, torch.tensor(rewards)))

    # Loss calculation and gradient ascent
    cumulative_reward = torch.sum(saved_rewards) / num_envs
    
    metrics = evaluate(saved_states, saved_actions)
    
    # Returns
    all_returns = calc_returns(saved_rewards, gamma)
    returns = torch.zeros(num_envs)
    all_returns = (all_returns - all_returns.mean()) / (all_returns.std() + eps) # Normalize returns for better stability
    
    # Calculate loss
    all_loss = dict()
    cumulative_loss = 0
    for key in keys:
        # Find the average loss among vectorized environments for each SURGE component
        individual_loss = -1 * all_returns * saved_log_probs[key] / num_envs
        cumulative_loss += torch.sum(individual_loss)
        all_loss[key] = torch.sum(individual_loss)

    # Perform gradient ascent
    optimizer.zero_grad()
    cumulative_loss.backward()
    optimizer.step()
    
    metrics = evaluate(saved_states, saved_actions)
    metrics['Cumulative Reward'] = cumulative_reward.item()
    
    for key in metrics.keys():
        print(f"{key}: {metrics[key]}")
    
    metrics['Cumulative Loss'] = cumulative_loss.item()
    metrics['Nmol Loss'] = all_loss['nmol'].item()
    metrics['Nfull Loss'] = all_loss['nfull'].item()
    metrics['Bond Loss'] = all_loss['b'].item()
    metrics['Termination Loss'] = all_loss['t'].item()
    
    if cumulative_reward > best_reward:
        best_reward = cumulative_reward
        best_model = copy.deepcopy(model)
        best_episode = episode
    
    # wandb.log(metrics)

Episode 1: 100%|██████████| 200/200 [00:02<00:00, 95.59steps/s]


Average Size: 1.5012345679012347
Average MW: 29.4209679012344
Average QED: 0.3578981877099945
Average SAS: 0.0
Cumulative Reward: -416.8781498549757


Episode 2: 100%|██████████| 200/200 [00:02<00:00, 94.79steps/s]


Average Size: 1.8773946360153257
Average MW: 39.06759770114932
Average QED: 0.35473127564372364
Average SAS: 0.09885798123236081
Cumulative Reward: -201.08150524561486


Episode 3: 100%|██████████| 200/200 [00:02<00:00, 98.34steps/s]


Average Size: 2.5307692307692307
Average MW: 51.04302307692306
Average QED: 0.34516120807285156
Average SAS: 0.0
Cumulative Reward: -104.81753391900634


Episode 4: 100%|██████████| 200/200 [00:02<00:00, 97.20steps/s] 


Average Size: 4.109090909090909
Average MW: 78.81438181818181
Average QED: 0.32221684034497455
Average SAS: 0.31355702741897434
Cumulative Reward: -64.83104561805558


Episode 5: 100%|██████████| 200/200 [00:02<00:00, 96.87steps/s]


Average Size: 4.303030303030303
Average MW: 90.6930303030303
Average QED: 0.338750895284795
Average SAS: 1.1075870925456364
Cumulative Reward: -77.81102909334112


Episode 6: 100%|██████████| 200/200 [00:02<00:00, 86.97steps/s]


Average Size: 4.25
Average MW: 83.96845
Average QED: 0.33093793066466665
Average SAS: 4.519265414346751
Cumulative Reward: -84.99653052795418


Episode 7: 100%|██████████| 200/200 [00:02<00:00, 97.56steps/s] 


Average Size: 4.615384615384615
Average MW: 77.107
Average QED: 0.3306032465682463
Average SAS: 5.078762521712268
Cumulative Reward: -91.81118810718903


Episode 8: 100%|██████████| 200/200 [00:02<00:00, 98.93steps/s]


Average Size: 4.0625
Average MW: 76.69731249999998
Average QED: 0.35823040080229934
Average SAS: 4.940507059437719
Cumulative Reward: -91.706116395392


Episode 9: 100%|██████████| 200/200 [00:02<00:00, 98.30steps/s] 


Average Size: 3.619047619047619
Average MW: 63.9227619047619
Average QED: 0.33530403937405356
Average SAS: 2.123123045733815
Cumulative Reward: -84.96850548924999


Episode 10: 100%|██████████| 200/200 [00:01<00:00, 101.23steps/s]


Average Size: 4.555555555555555
Average MW: 77.69155555555554
Average QED: 0.3278292431335518
Average SAS: 5.052902429775141
Cumulative Reward: -93.45929168123837


Episode 11: 100%|██████████| 200/200 [00:01<00:00, 101.07steps/s]


Average Size: 4.066666666666666
Average MW: 80.62633333333335
Average QED: 0.3313980532414583
Average SAS: 5.136969167698374
Cumulative Reward: -92.58480128198492


Episode 12: 100%|██████████| 200/200 [00:02<00:00, 96.28steps/s]


Average Size: 4.9375
Average MW: 96.49931249999999
Average QED: 0.3538122486371012
Average SAS: 0.3601086519427103
Cumulative Reward: -82.10014477729442


Episode 13: 100%|██████████| 200/200 [00:02<00:00, 92.82steps/s]


Average Size: 9.2
Average MW: 218.53359999999998
Average QED: 0.3564327378875797
Average SAS: 5.758277306011797
Cumulative Reward: -84.05821883089297


Episode 14: 100%|██████████| 200/200 [00:02<00:00, 93.77steps/s]


Average Size: 5.411764705882353
Average MW: 130.6599411764706
Average QED: 0.32091731376633864
Average SAS: 3.304898383277883
Cumulative Reward: -81.67743651954675


Episode 15: 100%|██████████| 200/200 [00:02<00:00, 84.08steps/s]


Average Size: 7.857142857142857
Average MW: 150.51742857142858
Average QED: 0.3054636077185177
Average SAS: 7.164425593980437
Cumulative Reward: -85.46916970296645


Episode 16: 100%|██████████| 200/200 [00:02<00:00, 93.69steps/s]


Average Size: 5.666666666666667
Average MW: 106.93758333333335
Average QED: 0.3498054950195361
Average SAS: 0.8640669565587297
Cumulative Reward: -72.43128276744584


Episode 17: 100%|██████████| 200/200 [00:02<00:00, 94.94steps/s]


Average Size: 4.833333333333333
Average MW: 84.74536666666667
Average QED: 0.31076723390974303
Average SAS: 1.2509338789765354
Cumulative Reward: -72.72614519210968


Episode 18: 100%|██████████| 200/200 [00:02<00:00, 96.71steps/s]


Average Size: 4.67741935483871
Average MW: 77.51422580645162
Average QED: 0.3395989273161674
Average SAS: 1.0174124222847882
Cumulative Reward: -76.40215435275621


Episode 19:  17%|█▋        | 34/200 [00:00<00:01, 91.90steps/s]


KeyboardInterrupt: 

In [None]:
wandb.finish()

In [None]:
env = vectorized_mol_env()
best_model.eval()
states = env.reset()
print(best_model.act(batch_from_states(states)))
for i in range(200):
    actions = best_model.act(batch_from_states(states))
    states, rewards, valids, timestep = env.step(actions['t'], actions['nmol'], actions['nfull'], actions['b'])
    print(f'{i}:\t{rewards}\t{states}')
    if i % 200 == 0:
        states = env.reset()
    env.visualize()