In [62]:
# Import packages
import torch
from torch.optim import Adam
from Model.SURGE import SURGE
from Reinforcement_Learning.mol_env import vectorized_mol_env
from Model.graph_embedding import batch_from_states
import numpy as np
from tqdm import trange
import datetime
import wandb
import copy
import rdkit.Chem.QED as QED
import sys
import os
from rdkit.Chem import RDConfig # SAS
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

In [19]:
model = SURGE()
max_steps = 200
num_envs = 4
env = vectorized_mol_env(num_envs = num_envs, max_steps = max_steps) # Vectorized molecular environment
lr = 0.02
optimizer = Adam(lr = lr, params = model.parameters())

In [20]:
gamma = 0.99
eps = np.finfo(np.float32).eps.item() # Small constant to decrease numerical instability
num_episodes = 500

In [21]:
current_datetime = datetime.datetime.now()
current_time = current_datetime.strftime("%Y-%m-%d-%H-%M-%S")

wandb.init(
    project = 'RL_Drug_Generation',
    name= f'REINFORCE---{current_time}',
    config={
        'lr': lr,
        'architecture': str(model),
        'episodes': num_episodes,
        'gamma': gamma,
        'num_envs': num_envs
    })

[34m[1mwandb[0m: Currently logged in as: [33mmaxwelljchen[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [63]:
def evaluate(saved_states, saved_actions):
    """
    This methods evaluates the molecules generated during the episode and returns diversity, novelty, and validity.
    It also analyzes the molecule metrics of drug-likeness (QED), synthetic accessibility (SAS), and size.
    It only analyzes the finished molecules (i.e., after SURGE tells generation to terminate).
    """
    
    flattened_states = []
    flattened_t = []
    for i in range(saved_states.shape[1]):
        states_col = saved_states[:, i].flatten().tolist()
        t_col = saved_actions['t'][:, i].flatten().tolist()
        if i == 0:
            flattened_states = states_col
            flattened_t = t_col
        else:
            flattened_states += states_col
            flattened_t += t_col
    
    idx = [i for i, t in enumerate(flattened_t) if t == 1]
    finished_states = [flattened_states[idx[i]] for i in range(len(idx))]
    
    
    # QED: From 0 to 1 where 1 is the most drug-like
    # SAS: From 1 to 10 where 1 is the easiest to synthesize
    metric_names = ['size', 'QED', 'SAS']
    num_mols = len(finished_states)
    metrics = dict()
    for name in metric_names:
        metrics[name] = 0
        
    for i in range(num_mols):
        metrics['size'] += finished_states[i].GetNumAtoms()
        metrics['QED'] += QED.weights_mean(finished_states[i])
        metrics['SAS'] += sascorer.calculateScore(finished_states[i])
    
    for name in metric_names:
        metrics[name] /= num_mols
    
    return metrics

block_head = copy.deepcopy(saved_states)
block_bed = copy.deepcopy(saved_actions)
print(evaluate(block_head, block_bed))

[22:11:54] 

****
Pre-condition Violation
getNumImplicitHs() called without preceding call to calcImplicitValence()
Violation occurred on line 299 in file /Users/runner/work/rdkit-pypi/rdkit-pypi/build/temp.macosx-11.0-arm64-cpython-38/rdkit/Code/GraphMol/Atom.cpp
Failed Expression: d_implicitValence > -1
****



RuntimeError: Pre-condition Violation
	getNumImplicitHs() called without preceding call to calcImplicitValence()
	Violation occurred on line 299 in file Code/GraphMol/Atom.cpp
	Failed Expression: d_implicitValence > -1
	RDKIT: 2023.03.3
	BOOST: 1_78


In [29]:
# Training loop
best_reward_model = copy.deepcopy(model)
best_reward = 0
best_episode = 0
for episode in range(num_episodes):

    # Reset environment after each episode
    states = env.reset()

    # Episode loggers
    keys = ['t', 'nmol', 'nfull', 'b']
    saved_actions = {'t': [], 'nmol': [], 'nfull': [], 'b': []}
    saved_log_probs = {'t': [], 'nmol': [], 'nfull': [], 'b': []}
    saved_rewards = []
    saved_states = np.array(states, dtype = object)

    # Episode computation
    pbar = trange(max_steps, unit="steps")
    for step in pbar:
        pbar.set_description(f"Episode {episode}")
        
        # Compute actions and log probabilities
        batch = batch_from_states(states)
        actions, log_probs = model.act(batch)

        # Record in episode loggers
        for key in keys:
            if step == 0:
                saved_actions[key] = actions[key]
                saved_log_probs[key] = log_probs[key]
            else:
                saved_actions[key] = np.vstack((saved_actions[key], actions[key]))
                saved_log_probs[key] = torch.vstack((saved_log_probs[key], log_probs[key]))

        # Take a step in environment
        states, rewards, valids, timestep = env.step(actions['t'], actions['nmol'], actions['nfull'], actions['b'])
        saved_states = np.vstack((saved_states, states))

        # Record rewards
        if step == 0:
            saved_rewards = torch.tensor(rewards)
        else:
            saved_rewards = torch.vstack((saved_rewards, torch.tensor(rewards)))

    # 3. Loss Calculation & Gradient Ascent
    cumulative_reward = torch.sum(saved_rewards) / num_envs

    # Calculate returns
    all_returns = torch.tensor(num_envs)
    returns = torch.zeros(num_envs)
    for idx in reversed(range(max_steps)):
        returns = saved_rewards[idx, :] + gamma * returns
        if idx == max_steps - 1:
            all_returns = returns
        else:
            all_returns = torch.vstack((returns, all_returns))
    all_returns = (all_returns - all_returns.mean()) / (all_returns.std() + eps)

    # Calculate loss
    all_loss = dict()
    cumulative_loss = 0
    for key in keys:

        # Find the average loss among vectorized environments for each SURGE component
        individual_loss = -1 * all_returns * saved_log_probs[key] / num_envs
        cumulative_loss += torch.sum(individual_loss)
        all_loss[key] = torch.sum(individual_loss)

    # Perform gradient ascent
    optimizer.zero_grad()
    cumulative_loss.backward()
    optimizer.step()
    
    print(f"Cumulative Reward: {cumulative_reward}")
    print(f"Cumulative Loss: {cumulative_loss}")
    print()
    
    break
    
    if cumulative_reward > best_reward:
        best_reward = cumulative_reward
        best_model = copy.deepcopy(model)
        best_episode = episode

    wandb.log({"Cumulative Reward": cumulative_reward, "Cumulative Loss": cumulative_loss,
               "Termination Loss": all_loss['t'], "Nmol Loss": all_loss['nmol'],
               "Nfull Loss": all_loss['nfull'], "Bond Loss": all_loss['b']})

Episode 0: 100%|██████████| 200/200 [00:02<00:00, 98.55steps/s] 


Cumulative Reward: 54.903324127197266
Cumulative Loss: 2.005544662475586


In [23]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Bond Loss,▂▂▄▂▁█▂▂
Cumulative Loss,▇▇█▁▁█▇▆
Cumulative Reward,▆▇▆▁▂▅██
Nfull Loss,▇██▂▁▇█▇
Nmol Loss,▇██▁▃▅▇▆
Termination Loss,▃▂▁█▆▅▂▃

0,1
Bond Loss,-0.02426
Cumulative Loss,-3.38729
Cumulative Reward,50.15389
Nfull Loss,-0.16317
Nmol Loss,-4.67441
Termination Loss,1.47454
