In [9]:
import numpy as np

import torch

import rdkit.Chem as Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole

from MolJuncTree import MolJuncTree


In [10]:
# Build the junction tree of a single example molecule from its SMILES string.
junc_tree = MolJuncTree(
    'CC(C)[C@@H]1CC[C@@H](C)C[C@H]1OC(=O)[C@H]1O[C@H](n2cc(F)c(N)nc2=O)CS1'
)

In [11]:
# Example SMILES string (same molecule as the one used for `junc_tree`).
smiles = (
    'CC(C)[C@@H]1CC[C@@H](C)C[C@H]1OC(=O)[C@H]1O[C@H](n2cc(F)c(N)nc2=O)CS1'
)

In [12]:
# Encode the SMILES string as an integer array of Unicode code points.
arr = np.array(list(map(ord, smiles)))

In [13]:
# Show the code-point array built from `smiles` (one integer per character).
arr

array([ 67,  67,  40,  67,  41,  91,  67,  64,  64,  72,  93,  49,  67,
        67,  91,  67,  64,  64,  72,  93,  40,  67,  41,  67,  91,  67,
        64,  72,  93,  49,  79,  67,  40,  61,  79,  41,  91,  67,  64,
        72,  93,  49,  79,  91,  67,  64,  72,  93,  40, 110,  50,  99,
        99,  40,  70,  41,  99,  40,  78,  41, 110,  99,  50,  61,  79,
        41,  67,  83,  49])

In [15]:
# Wrap the numpy array as a tensor; the tensor shares memory with `arr` and
# keeps its integer dtype.
x = torch.as_tensor(arr)

In [16]:
# One entry per SMILES character.
x.size()

torch.Size([69])

In [21]:
str = ""
for item in x:
    str += chr(item)       
print(str)

CC(C)[C@@H]1CC[C@@H](C)C[C@H]1OC(=O)[C@H]1O[C@H](n2cc(F)c(N)nc2=O)CS1


In [22]:
# Same decoding as the previous cell, as a single expression: map each code
# point back to its character and join them.
''.join(map(chr, x.tolist()))

'CC(C)[C@@H]1CC[C@@H](C)C[C@H]1OC(=O)[C@H]1O[C@H](n2cc(F)c(N)nc2=O)CS1'

In [4]:
# torch.tensor expects array-like numeric data and a torch.dtype for `dtype`,
# so a MolJuncTree object cannot be converted to a tensor this way. The
# TypeError is caught and displayed so this cell no longer aborts "Run All".
try:
    junc_tree_tensor = torch.tensor(junc_tree, dtype='MolJuncTree')
except TypeError as err:
    print(f'TypeError: {err}')

TypeError: tensor(): argument 'dtype' must be torch.dtype, not str

In [5]:
def mol_with_atom_index(mol):
    """Tag every atom of ``mol`` with its own atom index.

    Sets the ``molAtomMapNumber`` property of each atom to the string form of
    that atom's index. The molecule is modified in place and also returned.
    """
    for atom in mol.GetAtoms():
        atom.SetProp('molAtomMapNumber', str(atom.GetIdx()))
    return mol

In [8]:
# 10x5 standard-normal tensor. A local, explicitly seeded generator makes the
# values reproducible on "Restart & Run All" without touching the global RNG.
x = torch.randn(10, 5, generator=torch.Generator().manual_seed(0))

In [9]:
# Show the 10x5 random tensor created in the previous cell.
x

tensor([[-0.2561,  1.1097, -0.9956,  1.2825,  1.7023],
        [-0.0237, -1.0353,  0.4413, -0.9329,  0.6193],
        [-1.6397, -0.2367, -1.0390,  1.2540, -0.7572],
        [ 0.1845,  1.2306, -0.5463, -0.6967, -0.2789],
        [ 0.5653,  1.5654, -0.8289,  1.5294,  0.6850],
        [ 0.5515, -0.4253, -0.8693,  0.7448, -1.8184],
        [ 1.4574,  0.0664, -0.0902, -0.2066, -1.1983],
        [ 0.4142, -2.0299, -0.4837, -1.6217,  1.6142],
        [-0.2006,  1.4563,  0.8201, -0.4498,  1.5089],
        [ 0.8570, -1.2105,  1.7694, -2.3028,  0.7238]])

In [12]:
# Take the 5 rows of `x` starting at row 2 (rows 2..6), keeping every column.
# No reduction here: the next cell computes the column-wise sum, and calling
# `.sum()` at this point would already collapse `y` to a 0-d scalar instead
# of the 5 column sums that `y.sum(dim=0)` is shown to produce.
y = x.narrow(0, 2, 5)

In [14]:
# Reduce over the row dimension.
torch.sum(y, dim=0)

tensor([ 1.1190,  2.2005, -3.3737,  2.6249, -3.3678])

In [None]:
# Two example molecules, given as SMILES strings.
smiles_batch = [
    'CC(C)[C@@H]1CC[C@@H](C)C[C@H]1OC(=O)[C@H]1O[C@H](n2cc(F)c(N)nc2=O)CS1',
    'CCCN(CC1CC1)C(=O)Nc1cc(NC(C)=O)ccc1Cl',
]

In [4]:
# One junction tree per SMILES string, in the same order as `smiles_batch`.
junc_tree_batch = list(map(MolJuncTree, smiles_batch))

In [8]:
# Atom count of the `mol` held by the first junction tree of the batch.
junc_tree_batch[0].mol.GetNumAtoms()

27

In [5]:
# Give every node of `junc_tree` a sequential id (0, 1, 2, ...) in list order.
# `get_bottom_up_traversal_order` relies on `node.idx` to track visited nodes
# and to label the edges it returns.
for idx, node in enumerate(junc_tree.nodes):
    node.idx = idx

for node in junc_tree.nodes:
    print(node.idx, node.smiles)

0 CO
1 CO
2 CC
3 CN
4 CN
5 C=O
6 CC
7 CN
8 CC
9 C1=CC=NC=C1
10 C1C[NH2+]CCN1
11 C1=CC=CC=C1
12 C


In [7]:
from collections import deque

In [20]:
    def get_bottom_up_traversal_order(root):
        """
        Compute the bottom-up (leaves -> root) message-passing order of a junction tree.

        A breadth-first search from ``root`` assigns every reachable node a ``depth``
        (root = 0) and records, for each non-root node ``y`` first discovered from a
        node ``x``, the edge ``(y.idx, x.idx)``. Edges are grouped by the depth of
        their child node, and the groups are returned deepest-first so that messages
        flow from the leaves towards the root. Only the bottom-up order is produced;
        no top-down order is returned.

        Side effect: sets ``node.depth`` on every node reached by the search.

        Nodes are identified by ``node.idx`` (used for the visited set), so the ids
        must be distinct among the nodes of this tree.
        NOTE(review): ids were originally described as unique across all junction
        trees of the dataset; this function only needs per-tree uniqueness.

        Args:
            root: Root node of a junction tree. Nodes are expected to expose
                ``idx`` and ``neighbors`` attributes.

        Returns:
            traversal_order: List of lists of ``(child_idx, parent_idx)`` tuples.
                ``traversal_order[0]`` holds the edges whose child is deepest and
                ``traversal_order[-1]`` holds the edges into the root. A tree
                consisting of a single node yields ``[]``.
        """
        # FIFO queue for the breadth-first search.
        fifo_queue = deque([root])

        # Ids of nodes already discovered; prevents walking back up to a parent.
        visited = {root.idx}

        root.depth = 0

        # bottom_up[d - 1] collects the edges whose child node lies at depth d.
        bottom_up = []

        while fifo_queue:
            x = fifo_queue.popleft()

            for y in x.neighbors:
                if y.idx in visited:
                    continue
                visited.add(y.idx)
                fifo_queue.append(y)
                y.depth = x.depth + 1

                # BFS discovers depths in non-decreasing order, so a new depth is
                # at most one past the deepest level recorded so far.
                if y.depth > len(bottom_up):
                    bottom_up.append([])
                bottom_up[y.depth - 1].append((y.idx, x.idx))

        # Deepest level first: leaves send their messages before the root does.
        return bottom_up[::-1]

In [21]:
# Root the traversal at the first node of the junction tree.
traversal_order = get_bottom_up_traversal_order(root=junc_tree.nodes[0])

In [23]:
# One line per tree depth, deepest first: the (child, parent) edges at that depth.
for edges_at_depth in traversal_order:
    print(edges_at_depth)

[(11, 8)]
[(8, 10)]
[(10, 7)]
[(7, 9)]
[(9, 6)]
[(5, 12), (6, 12)]
[(12, 4)]
[(4, 3)]
[(3, 2)]
[(2, 1)]
[(1, 0)]
