In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import numpy as np
import pandas as pd
import networkx as nx
import pickle
import math
import random
from tqdm import tqdm

import rdkit
import rdkit.Chem.AllChem
from rdkit import Chem
from rdkit.Chem import rdMolHash                                                        

from molecule_builder import build_molecules

In [14]:
class MolNode(rdkit.Chem.Mol):
    
    max_depth = 10
    constant = math.sqrt(2)
    
    def __init__(self, *args, graph=None, terminal=False, **kwargs):
        super(MolNode, self).__init__(*args, **kwargs)
        self.terminal = terminal if self.GetNumAtoms() < MolNode.max_depth else True
        self.G = graph
    
    def __hash__(self):
        return hash(Chem.MolToSmiles(self) + ('[t]' if self.terminal else ''))

    def __eq__(self, other):
        return self.__hash__() == other.__hash__()
    
    def __repr__(self):
        return '<{}{}>'.format(Chem.MolToSmiles(self), '[t]' if self.terminal else '')
    
    def children(self):
        if self.terminal:
            raise RuntimeError("Attemping to get children of terminal node")
        
        for mol in build_molecules(self, stereoisomers=False):
            yield MolNode(mol, graph=self.G)
        
        # Add this node as a terminal state (stop option)
        yield MolNode(self, graph=self.G, terminal=True)
    
    
    def ucb1(self, parent):
        if parent.visits == 0:
            raise RuntimeError("Child {} of parent {} with zero visits".format(self, parent))
        if self.visits == 0:
            return math.inf
        return self.value + self.constant * math.sqrt(2 * math.log(parent.visits) / self.visits)
    
    def update(self, reward):
        node = self.G.nodes[self]
        if 'visits' in node:
            node['visits'] += 1
        else:
            node['visits'] = 1
            
        if 'total_value' in node:
            node['total_value'] += reward
        else:
            node['total_value'] = reward
        
    @property
    def visits(self):
        node = self.G.nodes[self]
        try:
            return node['visits']
        except KeyError:
            return 0
        
    @property
    def value(self):
        node = self.G.nodes[self]
        
        try:
            total_value = node['total_value']
        except KeyError:
            total_value = 0
        
        return total_value / self.visits if self.visits > 0 else 0

In [15]:
start = MolNode(Chem.MolFromSmiles('C'))
list(start.children())

[<C#C>, <CC>, <C=C>, <CO>, <C=O>, <C=N>, <CN>, <C#N>, <C[t]>]

In [16]:
from rdkit.Chem.Descriptors import qed
from functools import lru_cache

@lru_cache(maxsize=int(1E5))
def reward(mol):
    return qed(mol)

In [17]:
def build_nested(mol, num):
    """ I dont use this in mcts; only to profile the reward func for larger molecules """
    for mol in build_molecules(mol, stereoisomers=False):
        if num == 1:
            yield mol
            
        else:
            yield from build_nested(mol, num=num-1)
            
            
%timeit reward(next(build_nested(start, MolNode.max_depth)))

3.76 ms ± 42.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
def tree_policy(G, parent):

    yield parent

    if not parent.terminal:
        sorted_successors = sorted(G.successors(parent), key=lambda x: x.ucb1(parent), reverse=True)
        if sorted_successors:
            yield from tree_policy(G, sorted_successors[0])


def expand(G, node):
    G.add_edges_from(((node, child) for child in node.children()))

    
def rollout(node):
    """ It's expensive to enumerate all possible children of a molecule;
    so node.children() is a generator that only builds one at a time. But 
    I need some way of simulating stopping early, so I just give a random
    chance of stopping at any given state along the way. """
    if node.terminal or (random.random() < rollout.early_stop_frac):
        return reward(node)
    else:
        return rollout(next(node.children()))

rollout.early_stop_frac = 0.1


def mcts_step(G, start):
    
    # Perform the tree policy search
    history = list(tree_policy(G, start))
    leaf = history[-1]
    
    # Expand if the leaf has been visited previously
    if (leaf.visits > 0) and not leaf.terminal:
        expand(G, leaf)
    
    # Perform the rollout and see reward
    reward = rollout(leaf)

    # perform backprop
    for node in history:
        node.update(reward)    

In [22]:
MolNode.max_depth = 4
G = nx.DiGraph()
start = MolNode(Chem.MolFromSmiles('C'), graph=G)
G.add_node(start)

for i in tqdm(range(1000)):
    mcts_step(G, start)

# Read out graph data
df = pd.DataFrame.from_records(({**{'name': node[0], 'num_atoms': node[0].GetNumAtoms()}, **node[1]} for node in G.nodes(data=True)))
df = df.fillna(0)
df['avg_value'] = df.total_value / df.visits
df.sort_values('avg_value', ascending=False).head()

100%|██████████| 1000/1000 [00:04<00:00, 232.43it/s]


Unnamed: 0,name,num_atoms,visits,total_value,avg_value
99,<CCCO[t]>,4,3.0,1.391352,0.463784
165,<CCCN[t]>,4,3.0,1.389813,0.463271
91,<C=CNC[t]>,4,3.0,1.372774,0.457591
238,<CCNC[t]>,4,3.0,1.366574,0.455525
92,<C=CN=N[t]>,4,3.0,1.312525,0.437508


In [23]:
MolNode.max_depth = 10
G = nx.DiGraph()
start = MolNode(Chem.MolFromSmiles('C'), graph=G)
G.add_node(start)

for i in tqdm(range(10000)):
    mcts_step(G, start)

# Read out graph data
df = pd.DataFrame.from_records(({**{'name': node[0], 'num_atoms': node[0].GetNumAtoms()}, **node[1]} for node in G.nodes(data=True)))
df = df.fillna(0)
df['avg_value'] = df.total_value / df.visits
df.sort_values('avg_value', ascending=False).head()

100%|██████████| 10000/10000 [01:08<00:00, 145.73it/s]


Unnamed: 0,name,num_atoms,visits,total_value,avg_value
3685,<C=C(C#N)OO>,6,1.0,0.646983,0.646983
3979,<CONC#CO>,6,1.0,0.617213,0.617213
4243,<CCCC(C)=N>,6,1.0,0.593545,0.593545
2692,<CC(N)=CN=N>,6,1.0,0.591516,0.591516
7445,<C=C(C)CN=N>,6,1.0,0.581348,0.581348


In [24]:
df.sort_values('num_atoms', ascending=False).head()

Unnamed: 0,name,num_atoms,visits,total_value,avg_value
12162,<C=COOOC#N>,7,0.0,0.0,
12157,<N#COOOCO>,7,0.0,0.0,
12159,<N#COOOCN>,7,0.0,0.0,
12160,<N#COOOC=N>,7,0.0,0.0,
12161,<N#COOOC#N>,7,0.0,0.0,


In [25]:
df.sort_values('num_atoms', ascending=False)

Unnamed: 0,name,num_atoms,visits,total_value,avg_value
12162,<C=COOOC#N>,7,0.0,0.000000,
12157,<N#COOOCO>,7,0.0,0.000000,
12159,<N#COOOCN>,7,0.0,0.000000,
12160,<N#COOOC=N>,7,0.0,0.000000,
12161,<N#COOOC#N>,7,0.0,0.000000,
...,...,...,...,...,...
41,<CN[t]>,2,201.0,77.316300,0.384658
45,<CC[t]>,2,234.0,87.231820,0.372786
55,<CO[t]>,2,204.0,78.598019,0.385284
9,<C[t]>,1,1422.0,511.614182,0.359785


In [26]:
reward.cache_info()

CacheInfo(hits=6129, misses=5682, maxsize=100000, currsize=5682)