In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import numpy as np
import pandas as pd
import networkx as nx
import pickle
import math
import random
from tqdm import tqdm

import rdkit
import rdkit.Chem.AllChem
from rdkit import Chem
from rdkit.Chem import rdMolHash                                                        

from molecule_builder import build_molecules

In [4]:
class MolNode(rdkit.Chem.Mol):
    """An RDKit molecule that doubles as a node of the MCTS search graph.

    Node identity (hash / equality / repr) is the canonical SMILES, suffixed
    with ``[t]`` for terminal nodes. The same molecule reached along different
    build paths therefore maps to one graph node, while the terminal
    ("stop here") and non-terminal ("keep building") versions stay distinct.

    Search statistics are not stored on the object. They live as node
    attributes (``visits``, ``total_value``) of the networkx graph ``self.G``,
    keyed by this node.
    """

    # Molecules with this many atoms or more are forced to be terminal.
    max_depth = 10
    # Exploration coefficient used by ucb1().
    constant = math.sqrt(2)

    def __init__(self, *args, graph=None, terminal=False, **kwargs):
        """Build the molecule and attach it to a search graph.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``rdkit.Chem.Mol`` (e.g. a Mol to copy).
        graph : optional
            Search graph that stores this node's statistics. Only needed by
            ``update``, ``visits``, ``value`` and ``ucb1``.
        terminal : bool
            Whether this node ends an episode. It is forced to True once the
            molecule has ``MolNode.max_depth`` or more atoms.
        """
        super(MolNode, self).__init__(*args, **kwargs)
        self.terminal = terminal if self.GetNumAtoms() < MolNode.max_depth else True
        self.G = graph

    def _key(self):
        """Identity string: canonical SMILES plus '[t]' for terminal nodes."""
        return Chem.MolToSmiles(self) + ('[t]' if self.terminal else '')

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        # Compare the identity strings themselves, not their hashes. Comparing
        # hashes made a node equal to any object with the same hash (e.g. the
        # bare SMILES string) and would merge molecules on a hash collision.
        if not isinstance(other, MolNode):
            return NotImplemented
        return self._key() == other._key()

    def __repr__(self):
        return '<{}>'.format(self._key())

    def children(self):
        """Return the possible next states of this (non-terminal) node.

        These are one ``MolNode`` per molecule produced by ``build_molecules``
        (stereoisomers disabled), plus a terminal copy of this molecule. All of
        them share this node's graph.

        Raises
        ------
        RuntimeError
            If this node is terminal.
        """
        if self.terminal:
            raise RuntimeError("Attemping to get children of terminal node")

        grown = [MolNode(mol, graph=self.G) for mol in build_molecules(self, stereoisomers=False)]
        # The terminal copy is the "stop building here" action.
        return grown + [MolNode(self, graph=self.G, terminal=True)]

    def ucb1(self, parent):
        """UCB1 selection score of this node as a child of ``parent``.

        Returns ``value + constant * sqrt(2 * ln(parent.visits) / visits)``.
        An unvisited node returns ``math.inf``, so every unvisited child is
        tried before any visited sibling.

        NOTE(review): with ``constant = sqrt(2)`` the exploration bonus is
        ``2 * sqrt(ln N / n)``, i.e. sqrt(2) times Auer et al.'s UCB1 bonus
        ``sqrt(2 ln N / n)``. Confirm this exploration strength is intended.

        Raises
        ------
        RuntimeError
            If ``parent`` has never been visited (ln 0 is undefined).
        """
        if parent.visits == 0:
            raise RuntimeError("Child {} of parent {} with zero visits".format(self, parent))
        if self.visits == 0:
            return math.inf
        return self.value + self.constant * math.sqrt(2 * math.log(parent.visits) / self.visits)

    def update(self, reward):
        """Backpropagate one rollout: add a visit and ``reward`` to this node's stats."""
        node = self.G.nodes[self]
        node['visits'] = node.get('visits', 0) + 1
        node['total_value'] = node.get('total_value', 0) + reward

    @property
    def visits(self):
        """Number of rollouts backpropagated through this node (0 if none)."""
        return self.G.nodes[self].get('visits', 0)

    @property
    def value(self):
        """Mean rollout reward of this node, or 0 if it has never been visited."""
        visits = self.visits
        if visits == 0:
            return 0
        return self.G.nodes[self].get('total_value', 0) / visits

In [5]:
# Demo: list the one-step children of methane ("C"). No graph is attached here,
# which is fine because children() only passes self.G on to the new nodes.
start = MolNode(Chem.MolFromSmiles('C'), graph=None)
start.children()

[<CO>, <C=O>, <CC>, <C=C>, <C#C>, <C=N>, <C#N>, <CN>, <C[t]>]

In [6]:
from rdkit.Chem.Descriptors import qed
from functools import lru_cache

@lru_cache(maxsize=100000)
def reward(mol):
    """QED drug-likeness score of ``mol`` (RDKit ``qed``), memoized.

    The cache is keyed on the argument's hash/equality. For ``MolNode`` that is
    its SMILES-based identity, so revisiting a molecule skips recomputing QED.
    """
    score = qed(mol)
    return score

In [7]:
def build_nested(mol, num):
    """Lazily yield the molecules reachable from ``mol`` in exactly ``num`` build steps.

    Each step applies ``build_molecules`` (stereoisomers disabled) to the
    previous result, depth-first. Different paths may yield the same molecule
    more than once. ``num <= 0`` yields ``mol`` itself (zero steps).
    """
    if num <= 0:
        # Without this guard, num <= 0 never reaches the `num == 1` base case,
        # and the recursion runs until RecursionError.
        yield mol
        return

    for child in build_molecules(mol, stereoisomers=False):
        if num == 1:
            yield child
        else:
            yield from build_nested(child, num=num - 1)
            
            
# Time reward() on the first molecule build_nested reaches MolNode.max_depth
# build steps below `start`.
# NOTE(review): reward() is lru_cached. If build_nested's first result is the same
# on every loop (presumably so, if build_molecules is deterministic), only the
# first call computes QED, and the reported time is mostly molecule building plus
# a cache lookup.
%timeit reward(next(build_nested(start, MolNode.max_depth)))

3.77 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
def tree_policy(G, parent):
    """Selection phase: walk down from ``parent`` along the highest-UCB1 child.

    Yields every node on the path, starting with ``parent`` itself. The walk
    stops at a terminal node, or at a node with no successors in ``G`` yet
    (an unexpanded leaf).
    """
    node = parent
    while True:
        yield node
        if node.terminal:
            return
        # max() returns the first of several equally scored children. That is
        # the same pick as element 0 of a stable descending sort, but O(n)
        # instead of O(n log n).
        best = max(G.successors(node), key=lambda child: child.ucb1(node), default=None)
        if best is None:
            return
        node = best


def expand(G, node):
    """Expansion phase: add every child of ``node`` to ``G`` as an edge ``node -> child``."""
    new_edges = [(node, child) for child in node.children()]
    G.add_edges_from(new_edges)

    
def rollout(node):
    """Simulation phase: follow uniformly random children until a terminal node.

    Returns ``reward`` of the terminal node that is reached. A terminal starting
    node is scored directly.
    """
    current = node
    while not current.terminal:
        current = random.choice(current.children())
    return reward(current)


def mcts_step(G, start):
    """Run one MCTS iteration from ``start`` on the search graph ``G``.

    1. Selection: follow ``tree_policy`` from ``start`` down to a leaf.
    2. Expansion: if that leaf was visited before and is not terminal, add its
       children to ``G``. The rollout below still starts from the leaf itself;
       the new children first get visits on later iterations.
    3. Simulation: ``rollout`` a random episode from the leaf.
    4. Backpropagation: record the rollout reward on every node of the path.
    """
    path = list(tree_policy(G, start))
    leaf = path[-1]

    if leaf.visits > 0 and not leaf.terminal:
        expand(G, leaf)

    # Deliberately not called `reward`, which would shadow the reward() function.
    episode_reward = rollout(leaf)

    for node in path:
        node.update(episode_reward)

In [9]:
# Small search: molecules capped at 4 atoms, 1000 MCTS iterations from methane.
# rollout() draws with random.choice, so seed it to make reruns repeatable
# (assuming build_molecules itself is deterministic).
random.seed(0)
MolNode.max_depth = 4
G = nx.DiGraph()
start = MolNode(Chem.MolFromSmiles('C'), graph=G)
G.add_node(start)

for _ in tqdm(range(1000)):
    mcts_step(G, start)

# Read out graph data: one row per node plus its MCTS statistics. Nodes that were
# never visited have no 'visits'/'total_value' attributes; fillna(0) sets both to
# 0, so their avg_value is 0/0 = NaN.
df = pd.DataFrame.from_records(
    {'name': node, 'num_atoms': node.GetNumAtoms(), **attrs}
    for node, attrs in G.nodes(data=True)
)
df = df.fillna(0)
df['avg_value'] = df.total_value / df.visits
df.sort_values('avg_value', ascending=False).head()

100%|██████████| 1000/1000 [00:04<00:00, 245.81it/s]


Unnamed: 0,name,num_atoms,visits,total_value,avg_value
112,<CCCO[t]>,4,3.0,1.391352,0.463784
108,<CCCN[t]>,4,3.0,1.389813,0.463271
133,<C=CNC[t]>,4,3.0,1.372774,0.457591
140,<CCNC[t]>,4,3.0,1.366574,0.455525
131,<C=CN=N[t]>,4,3.0,1.312525,0.437508


In [10]:
# Larger search: molecules capped at 10 atoms, 10000 MCTS iterations from methane.
# rollout() draws with random.choice, so seed it to make reruns repeatable
# (assuming build_molecules itself is deterministic).
random.seed(0)
MolNode.max_depth = 10
G = nx.DiGraph()
start = MolNode(Chem.MolFromSmiles('C'), graph=G)
G.add_node(start)

for _ in tqdm(range(10000)):
    mcts_step(G, start)

# Read out graph data: one row per node plus its MCTS statistics. Nodes that were
# never visited have no 'visits'/'total_value' attributes; fillna(0) sets both to
# 0, so their avg_value is 0/0 = NaN.
df = pd.DataFrame.from_records(
    {'name': node, 'num_atoms': node.GetNumAtoms(), **attrs}
    for node, attrs in G.nodes(data=True)
)
df = df.fillna(0)
df['avg_value'] = df.total_value / df.visits
df.sort_values('avg_value', ascending=False).head()

100%|██████████| 10000/10000 [01:36<00:00, 103.91it/s]


Unnamed: 0,name,num_atoms,visits,total_value,avg_value
5570,<C=COC#CC>,6,1.0,0.61999,0.61999
6547,<N#CCN=NN>,6,1.0,0.597074,0.597074
6936,<O=CC=NCO>,6,1.0,0.587391,0.587391
7746,<N=NON=CN>,6,1.0,0.585683,0.585683
5087,<C=C(C)N=C=N>,6,1.0,0.58184,0.58184


In [11]:
# Five nodes with the most atoms, i.e. the deepest molecules the search added.
df.sort_values(by='num_atoms', ascending=False).head(5)

Unnamed: 0,name,num_atoms,visits,total_value,avg_value
12065,<N#COOON=N>,7,0.0,0.0,
12331,<CCOOOC#N>,7,0.0,0.0,
12285,<N#COOOON>,7,0.0,0.0,
12286,<COOOOC#N>,7,0.0,0.0,
12287,<N#COOOOO>,7,0.0,0.0,


In [12]:
# Every node ordered from most to fewest atoms. pandas truncates the rendered
# table to its first and last rows, so both the deepest and shallowest nodes show.
df.sort_values(by='num_atoms', ascending=False)

Unnamed: 0,name,num_atoms,visits,total_value,avg_value
12065,<N#COOON=N>,7,0.0,0.000000,
12331,<CCOOOC#N>,7,0.0,0.000000,
12285,<N#COOOON>,7,0.0,0.000000,
12286,<COOOOC#N>,7,0.0,0.000000,
12287,<N#COOOOO>,7,0.0,0.000000,
...,...,...,...,...,...
16,<C=O[t]>,2,220.0,79.337381,0.360624
41,<CC[t]>,2,234.0,87.231820,0.372786
1,<CN>,2,1061.0,354.496440,0.334115
9,<C[t]>,1,1427.0,513.413106,0.359785


In [13]:
# Hit/miss counters of reward()'s lru_cache. They are cumulative since reward was
# defined, so they include the %timeit cell's calls as well as every MCTS rollout.
reward.cache_info()

CacheInfo(hits=6225, misses=5586, maxsize=100000, currsize=5586)