In [None]:
# Import packages
import torch
from torch.optim import Adam
from Model.SURGE import SURGE
from Reinforcement_Learning.mol_env import vectorized_mol_env
from Model.graph_embedding import batch_from_states
import numpy as np
from tqdm import trange
import datetime
import wandb
import copy
import rdkit.Chem.QED as QED
import sys
import os
from rdkit.Chem import RDConfig # SAS
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

In [None]:
# Seed the torch and numpy RNGs before building the model so that weight
# initialisation and sampling are reproducible on Restart & Run All.
# NOTE(review): if vectorized_mol_env uses Python's `random` module, seed it too.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

model = SURGE()  # policy network
max_steps = 200  # steps per episode (length of the training-loop progress bar)
num_envs = 4  # number of parallel environments
env = vectorized_mol_env(num_envs = num_envs, max_steps = max_steps) # Vectorized molecular environment
lr = 0.02
optimizer = Adam(lr = lr, params = model.parameters())

In [None]:
# Discount factor applied to future rewards when computing returns.
gamma = 0.99
# Machine epsilon of float32 as a plain Python float; added to denominators
# (e.g. the std when normalizing returns) so they are never exactly zero.
eps = float(np.finfo(np.float32).eps)
# Number of training episodes to run.
num_episodes = 500

In [None]:
# Name the run after its start time so separate runs are distinguishable.
current_datetime = datetime.datetime.now()
current_time = current_datetime.strftime("%Y-%m-%d-%H-%M-%S")

wandb.init(
    project='RL_Drug_Generation',
    name=f'REINFORCE---{current_time}',
    config={
        'lr': lr,
        'architecture': str(model),
        'episodes': num_episodes,
        'gamma': gamma,
        'num_envs': num_envs,
        # Episode length determines how many rewards feed each return, so it
        # must be recorded for runs to be comparable/reproducible.
        'max_steps': max_steps,
    },
)

In [87]:
def evaluate(saved_states, saved_actions):
    """
    Average size, drug-likeness (QED) and synthetic accessibility (SAS) of the
    molecules finished during an episode.

    A molecule counts as finished at step ``s`` of environment ``e`` when
    ``saved_actions['t'][s, e] == 1``. The molecule evaluated is the state the
    policy acted on at that step, ``saved_states[s, e]``.

    Parameters
    ----------
    saved_states : 2-D object array, shape (n_steps or n_steps + 1, num_envs)
        Per-step molecule states. An extra trailing row (the state produced by
        the final action) is ignored, so that every state lines up with the
        action taken on it in every environment.
    saved_actions : dict
        Must contain ``'t'``, the termination decisions. Its shape is
        (n_steps, num_envs), or any shape reshapeable to it, such as
        (n_steps, num_envs, 1).

    Returns
    -------
    dict
        ``'Average Size'``, ``'Average QED'`` and ``'Average SAS'``.
        QED ranges from 0 to 1, where 1 is the most drug-like.
        SAS ranges from 1 to 10, where 1 is the easiest to synthesize.
        All values are 0 when no molecule finished. SAS is averaged only over
        molecules with more than one atom, because single-atom molecules are
        not scored.
    """
    terminations = np.asarray(saved_actions['t'])
    n_steps = terminations.shape[0]
    # saved_states may carry one extra row (the post-final-step state). Trim it
    # so that states[s, e] pairs with terminations[s, e] in *every* env column;
    # otherwise column-wise flattening would shift env i by i positions.
    states = np.asarray(saved_states, dtype=object)[:n_steps]
    terminations = terminations.reshape(n_steps, states.shape[1])

    # (env, step) index pairs in env-major order, matching the original traversal.
    finished_states = [states[step, env_idx]
                       for env_idx, step in np.argwhere(terminations.T == 1)]

    num_mols = len(finished_states)
    if num_mols == 0:
        # Nothing finished this episode: report zeros instead of dividing by zero.
        return {'Average Size': 0, 'Average QED': 0, 'Average SAS': 0}

    total_size = 0
    total_qed = 0
    total_sas = 0
    num_sas_scored = 0
    for mol in finished_states:
        mol_size = mol.GetNumAtoms()
        total_size += mol_size
        total_qed += QED.weights_mean(mol)
        # Single-atom molecules are skipped for SAS. Skipping must not reset
        # the sum accumulated from previous molecules.
        if mol_size != 1:
            total_sas += sascorer.calculateScore(mol)
            num_sas_scored += 1

    return {
        'Average Size': total_size / num_mols,
        'Average QED': total_qed / num_mols,
        'Average SAS': total_sas / num_sas_scored if num_sas_scored else 0,
    }

In [88]:
def calc_returns(saved_rewards, gamma):
    """
    Discounted returns for a (time, env) matrix of rewards.

        returns[t, e] = sum_{k >= t} gamma**(k - t) * saved_rewards[k, e]

    Parameters
    ----------
    saved_rewards : torch.Tensor, shape (n_steps, num_envs)
        Reward received at each step in each environment. It is NOT modified.
    gamma : float
        Discount factor.

    Returns
    -------
    torch.Tensor, shape (n_steps, num_envs)
        Discounted returns. Floating-point rewards keep their dtype. Integer
        rewards are promoted to torch's default floating dtype, because
        discounting produces fractions.
    """
    if saved_rewards.is_floating_point():
        rewards = saved_rewards
    else:
        rewards = saved_rewards.to(torch.get_default_dtype())

    returns = torch.empty_like(rewards)
    running = rewards.new_zeros(rewards.shape[1:])  # G_{T} = 0 past the last step
    # Walk backwards so each return reuses the next one: G_t = r_t + gamma * G_{t+1}.
    # Out-of-place arithmetic keeps the caller's reward tensor untouched.
    for step in reversed(range(rewards.shape[0])):
        running = rewards[step] + gamma * running
        returns[step] = running
    return returns

In [None]:
# Training loop: one REINFORCE update per episode across all vectorized envs.
best_model = copy.deepcopy(model)  # snapshot of the policy with the best episode reward so far
best_reward = float('-inf')  # -inf so the first episode registers even if rewards are negative
best_episode = 0
keys = ['t', 'nmol', 'nfull', 'b']  # SURGE action heads
for episode in range(1, num_episodes + 1):

    # Reset environment at the start of each episode
    states = env.reset()

    # Episode loggers: one entry per step, stacked row-wise after the episode
    saved_actions = {key: [] for key in keys}
    saved_log_probs = {key: [] for key in keys}
    saved_rewards = []
    saved_states = [np.array(states, dtype = object)]  # includes the initial state

    # Episode computation
    pbar = trange(max_steps, unit="steps", desc=f"Episode {episode}")
    for step in pbar:
        # Compute actions and log probabilities
        batch = batch_from_states(states)
        actions, log_probs = model.act(batch)

        # Record in episode loggers
        for key in keys:
            saved_actions[key].append(actions[key])
            saved_log_probs[key].append(log_probs[key])

        # Take a step in environment
        states, rewards, valids, timestep = env.step(actions['t'], actions['nmol'], actions['nfull'], actions['b'])
        saved_states.append(np.array(states, dtype = object))
        saved_rewards.append(torch.tensor(rewards))

    # Stack per-step records once (row = step). This avoids re-copying on every step.
    saved_actions = {key: np.vstack(saved_actions[key]) for key in keys}
    saved_log_probs = {key: torch.vstack(saved_log_probs[key]) for key in keys}
    saved_rewards = torch.vstack(saved_rewards)
    saved_states = np.vstack(saved_states)

    # Average (undiscounted) episode reward per environment
    cumulative_reward = torch.sum(saved_rewards) / num_envs

    # Discounted returns, normalized for better stability
    all_returns = calc_returns(saved_rewards, gamma)
    all_returns = (all_returns - all_returns.mean()) / (all_returns.std() + eps)

    # REINFORCE loss per SURGE head: -return * log-prob, averaged over environments
    all_loss = {
        key: torch.sum(-1 * all_returns * saved_log_probs[key] / num_envs)
        for key in keys
    }
    cumulative_loss = sum(all_loss.values())

    # Perform gradient ascent
    optimizer.zero_grad()
    cumulative_loss.backward()
    optimizer.step()

    metrics = evaluate(saved_states, saved_actions)
    metrics['Cumulative Reward'] = cumulative_reward.item()
    metrics['Cumulative Loss'] = cumulative_loss.item()
    metrics['Nmol Loss'] = all_loss['nmol'].item()
    metrics['Nfull Loss'] = all_loss['nfull'].item()
    metrics['Bond Loss'] = all_loss['b'].item()
    metrics['Termination Loss'] = all_loss['t'].item()

    for key in metrics.keys():
        print(f"{key}: {metrics[key]}")

    # Track the best policy under one consistent name, storing the reward as a float
    if cumulative_reward.item() > best_reward:
        best_reward = cumulative_reward.item()
        best_model = copy.deepcopy(model)
        best_episode = episode

    wandb.log(metrics)

In [None]:
# End the W&B run opened by wandb.init earlier in the notebook.
wandb.finish()