In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import networkx as nx
import pickle
import math
import random
from tqdm import tqdm

import rdkit
import rdkit.Chem.AllChem
from rdkit import Chem
from rdkit.Chem import rdMolHash                                                        

from molecule_builder import build_molecules

In [3]:
class MolNode(rdkit.Chem.Mol):
    
    max_depth = 10
    constant = math.sqrt(2)
    
    def __init__(self, *args, graph=None, terminal=False, **kwargs):
        super(MolNode, self).__init__(*args, **kwargs)
        self.terminal = terminal if self.GetNumAtoms() < MolNode.max_depth else True
        self.G = graph
    
    def __hash__(self):
        return hash(Chem.MolToSmiles(self) + ('[t]' if self.terminal else ''))

    def __eq__(self, other):
        return self.__hash__() == other.__hash__()
    
    def __repr__(self):
        return '<{}{}>'.format(Chem.MolToSmiles(self), '[t]' if self.terminal else '')
    
    def children(self):
        if self.terminal:
            raise RuntimeError("Attemping to get children of terminal node")
        
        for mol in build_molecules(self, stereoisomers=False):
            yield MolNode(mol, graph=self.G)
        
        # Add this node as a terminal state (stop option)
        yield MolNode(self, graph=self.G, terminal=True)
    
    
    def ucb1(self, parent):
        if parent.visits == 0:
            raise RuntimeError("Child {} of parent {} with zero visits".format(self, parent))
        if self.visits == 0:
            return math.inf
        return self.value + self.constant * math.sqrt(2 * math.log(parent.visits) / self.visits)
    
    def update(self, reward):
        node = self.G.nodes[self]
        if 'visits' in node:
            node['visits'] += 1
        else:
            node['visits'] = 1
            
        if 'total_value' in node:
            node['total_value'] += reward
        else:
            node['total_value'] = reward
        
    @property
    def visits(self):
        node = self.G.nodes[self]
        try:
            return node['visits']
        except KeyError:
            return 0
        
    @property
    def value(self):
        node = self.G.nodes[self]
        
        try:
            total_value = node['total_value']
        except KeyError:
            total_value = 0
        
        return total_value / self.visits if self.visits > 0 else 0

In [4]:
start = MolNode(Chem.MolFromSmiles('C'))
list(start.children())

[<C#N>, <C=N>, <CN>, <CO>, <C=O>, <C#C>, <C=C>, <CC>, <C[t]>]

In [5]:
from rdkit.Chem.Descriptors import qed
from functools import lru_cache

@lru_cache(maxsize=int(1E5))
def reward(mol):
    return qed(mol)

In [6]:
def build_nested(mol, num):
    """ I dont use this in mcts; only to profile the reward func for larger molecules """
    for mol in build_molecules(mol, stereoisomers=False):
        if num == 1:
            yield mol
            
        else:
            yield from build_nested(mol, num=num-1)
            
            
%timeit reward(next(build_nested(start, MolNode.max_depth)))

3.86 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
def tree_policy(G, parent):

    yield parent

    if not parent.terminal:
        sorted_successors = sorted(G.successors(parent), key=lambda x: x.ucb1(parent), reverse=True)
        if sorted_successors:
            yield from tree_policy(G, sorted_successors[0])


def expand(G, node):
    G.add_edges_from(((node, child) for child in node.children()))

    
def rollout(node):
    """ It's expensive to enumerate all possible children of a molecule;
    so node.children() is a generator that only builds one at a time. But 
    I need some way of simulating stopping early, so I just give a random
    chance of stopping at any given state along the way. """
    if node.terminal or (random.random() < rollout.early_stop_frac):
        return reward(node)
    else:
        return rollout(next(node.children()))

rollout.early_stop_frac = 0.1


def mcts_step(G, start):
    
    # Perform the tree policy search
    history = list(tree_policy(G, start))
    leaf = history[-1]
    
    # Expand if the leaf has been visited previously
    if (leaf.visits > 0) and not leaf.terminal:
        expand(G, leaf)
    
    # Perform the rollout and see reward
    reward = rollout(leaf)

    # perform backprop
    for node in history:
        node.update(reward)    

In [8]:
MolNode.max_depth = 4
G = nx.DiGraph()
start = MolNode(Chem.MolFromSmiles('C'), graph=G)
G.add_node(start)

for i in tqdm(range(1000)):
    mcts_step(G, start)

# Read out graph data
df = pd.DataFrame.from_records(({**{'name': node[0], 'num_atoms': node[0].GetNumAtoms()}, **node[1]} for node in G.nodes(data=True)))
df = df.fillna(0)
df['avg_value'] = df.total_value / df.visits
df.sort_values('avg_value', ascending=False).head()

100%|██████████| 1000/1000 [00:03<00:00, 313.82it/s]


Unnamed: 0,name,num_atoms,visits,total_value,avg_value
107,<CCCO[t]>,4,3.0,1.391352,0.463784
111,<CCCN[t]>,4,3.0,1.389813,0.463271
102,<C=CNC[t]>,4,3.0,1.372774,0.457591
157,<CCNC[t]>,4,3.0,1.366574,0.455525
101,<C=CN=N[t]>,4,3.0,1.312525,0.437508


In [9]:
MolNode.max_depth = 10
G = nx.DiGraph()
start = MolNode(Chem.MolFromSmiles('C'), graph=G)
G.add_node(start)

for i in tqdm(range(10000)):
    mcts_step(G, start)

# Read out graph data
df = pd.DataFrame.from_records(({**{'name': node[0], 'num_atoms': node[0].GetNumAtoms()}, **node[1]} for node in G.nodes(data=True)))
df = df.fillna(0)
df['avg_value'] = df.total_value / df.visits
df.sort_values('avg_value', ascending=False).head()

100%|██████████| 10000/10000 [01:07<00:00, 147.77it/s]


Unnamed: 0,name,num_atoms,visits,total_value,avg_value
9041,<C=C1N=NN1>,5,1.0,0.63714,0.63714
11722,<Cn1on1C>,5,1.0,0.619977,0.619977
7009,<CC=CN(C)C>,6,1.0,0.606135,0.606135
11034,<CCCCC=O>,6,1.0,0.595939,0.595939
6630,<CNC(C)N=N>,6,1.0,0.579466,0.579466


In [10]:
df.sort_values('num_atoms', ascending=False).head()

Unnamed: 0,name,num_atoms,visits,total_value,avg_value
12334,<N#CON=NC=N>,7,0.0,0.0,
12327,<C#CN=NOC#N>,7,0.0,0.0,
12333,<N#CN=NOC#N>,7,0.0,0.0,
12332,<N#CON=NCN>,7,0.0,0.0,
12331,<N#CON=NCO>,7,0.0,0.0,


In [11]:
df.sort_values('num_atoms', ascending=False)

Unnamed: 0,name,num_atoms,visits,total_value,avg_value
12334,<N#CON=NC=N>,7,0.0,0.000000,
12327,<C#CN=NOC#N>,7,0.0,0.000000,
12333,<N#CN=NOC#N>,7,0.0,0.000000,
12332,<N#CON=NCN>,7,0.0,0.000000,
12331,<N#CON=NCO>,7,0.0,0.000000,
...,...,...,...,...,...
1,<C=C>,2,1172.0,402.345379,0.343298
40,<CO[t]>,2,202.0,77.827450,0.385284
22,<C=O[t]>,2,219.0,78.976756,0.360624
9,<C[t]>,1,1423.0,511.973967,0.359785


In [12]:
reward.cache_info()

CacheInfo(hits=6120, misses=5691, maxsize=100000, currsize=5691)