In [1]:
import os

In [2]:
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem.rdchem import RWMol

In [3]:
smile_csv = pd.read_csv('../data/train_set.ReorgE.csv', index_col=0)

In [4]:
m1 = Chem.MolFromSmiles(smile_csv.iloc[0]['SMILES'])

In [5]:
m1

<rdkit.Chem.rdchem.Mol at 0x7f6231eabac0>

In [6]:
atoms_info = [ (atom.GetIdx(), atom.GetAtomicNum(), atom.GetSymbol()) for atom in m1.GetAtoms()]
bonds_info = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond.GetBondType(), bond.GetBondTypeAsDouble()) for bond in m1.GetBonds()]

In [7]:
atoms_info

[(0, 6, 'C'),
 (1, 6, 'C'),
 (2, 6, 'C'),
 (3, 6, 'C'),
 (4, 6, 'C'),
 (5, 6, 'C'),
 (6, 6, 'C'),
 (7, 7, 'N'),
 (8, 6, 'C'),
 (9, 8, 'O'),
 (10, 6, 'C'),
 (11, 6, 'C'),
 (12, 8, 'O'),
 (13, 6, 'C'),
 (14, 8, 'O'),
 (15, 6, 'C'),
 (16, 6, 'C'),
 (17, 6, 'C'),
 (18, 8, 'O'),
 (19, 6, 'C'),
 (20, 7, 'N'),
 (21, 6, 'C'),
 (22, 6, 'C'),
 (23, 6, 'C'),
 (24, 6, 'C'),
 (25, 6, 'C'),
 (26, 6, 'C'),
 (27, 7, 'N')]

In [8]:
bonds_info

[(0, 1, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (1, 2, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (2, 3, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (3, 4, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (4, 5, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (5, 6, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (6, 7, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (7, 8, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (8, 9, rdkit.Chem.rdchem.BondType.DOUBLE, 2.0),
 (8, 10, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (10, 11, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (10, 12, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (12, 13, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (13, 14, rdkit.Chem.rdchem.BondType.DOUBLE, 2.0),
 (13, 15, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (15, 16, rdkit.Chem.rdchem.BondType.AROMATIC, 1.5),
 (16, 17, rdkit.Chem.rdchem.BondType.SINGLE, 1.0),
 (16, 18, rdkit.Chem.rdchem.BondType.AROMATIC, 1.5),
 (18, 19, rdkit.Chem.rdchem.BondType.AROMATIC, 1.5),
 (19, 20, rdkit.Chem.rdchem.BondType.SINGLE,

In [9]:
import torch

from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DimeNet

In [10]:
Model = DimeNet

In [11]:
path = '../data/QM9'

In [12]:
dataset = QM9(path)

In [13]:
loader = DataLoader(dataset, batch_size=1)

In [14]:
x = next(iter(loader))

In [15]:
x

DataBatch(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], z=[5], name=[1], idx=[1], batch=[5], ptr=[2])

In [16]:
x.x

tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 4.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

In [17]:
x.edge_index

tensor([[0, 0, 0, 0, 1, 2, 3, 4],
        [1, 2, 3, 4, 0, 0, 0, 0]])

In [18]:
x.edge_attr

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])

In [19]:
x.y

tensor([[    0.0000,    13.2100,   -10.5499,     3.1865,    13.7363,    35.3641,
             1.2177, -1101.4878, -1101.4098, -1101.3840, -1102.0229,     6.4690,
           -17.1722,   -17.2868,   -17.3897,   -16.1519,   157.7118,   157.7100,
           157.7070]])

In [20]:
x.pos

tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
        [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
        [-5.4080e-01,  1.4475e+00, -8.7660e-01],
        [-5.2380e-01,  1.4379e+00,  9.0640e-01]])

In [21]:
x.z

tensor([6, 1, 1, 1, 1])

In [22]:
x.name

['gdb_1']

In [23]:
x.idx

tensor([0])

In [24]:
x.batch

tensor([0, 0, 0, 0, 0])

In [25]:
x.ptr

tensor([0, 5])

In [26]:
qm9 = pd.read_csv('../data/QM9/raw/gdb9.sdf.csv')

In [27]:
qm9

Unnamed: 0,mol_id,A,B,C,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv,u0_atom,u298_atom,h298_atom,g298_atom
0,gdb_1,157.71180,157.709970,157.706990,0.0000,13.21,-0.3877,0.1171,0.5048,35.3641,0.044749,-40.478930,-40.476062,-40.475117,-40.498597,6.469,-395.999595,-398.643290,-401.014647,-372.471772
1,gdb_2,293.60975,293.541110,191.393970,1.6256,9.46,-0.2570,0.0829,0.3399,26.1563,0.034358,-56.525887,-56.523026,-56.522082,-56.544961,6.316,-276.861363,-278.620271,-280.399259,-259.338802
2,gdb_3,799.58812,437.903860,282.945450,1.8511,6.31,-0.2928,0.0687,0.3615,19.0002,0.021375,-76.404702,-76.401867,-76.400922,-76.422349,6.002,-213.087624,-213.974294,-215.159658,-201.407171
3,gdb_4,0.00000,35.610036,35.610036,0.0000,16.28,-0.2845,0.0506,0.3351,59.5248,0.026841,-77.308427,-77.305527,-77.304583,-77.327429,8.574,-385.501997,-387.237686,-389.016047,-365.800724
4,gdb_5,0.00000,44.593883,44.593883,2.8937,12.99,-0.3604,0.0191,0.3796,48.7476,0.016601,-93.411888,-93.409370,-93.408425,-93.431246,6.278,-301.820534,-302.906752,-304.091489,-288.720028
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
133880,gdb_133881,3.59483,2.198990,1.904230,1.6637,69.37,-0.2254,0.0588,0.2842,760.7472,0.127406,-400.633868,-400.628599,-400.627654,-400.663098,23.658,-1603.983913,-1614.898804,-1623.788097,-1492.819438
133881,gdb_133882,3.65648,2.142370,1.904390,1.2976,69.52,-0.2393,0.0608,0.3002,762.6354,0.127495,-400.629713,-400.624444,-400.623500,-400.658942,23.697,-1601.376613,-1612.291504,-1621.181424,-1490.211511
133882,gdb_133883,3.67118,2.143140,1.895010,1.2480,73.60,-0.2233,0.0720,0.2953,780.3553,0.140458,-380.753918,-380.748619,-380.747675,-380.783148,23.972,-1667.045429,-1678.830048,-1688.312964,-1549.143391
133883,gdb_133884,3.52845,2.151310,1.865820,1.9576,77.40,-0.2122,0.0881,0.3003,803.1904,0.152222,-364.720374,-364.714974,-364.714030,-364.749650,24.796,-1794.600439,-1807.210860,-1817.286772,-1670.349892


In [28]:
sdf = Chem.SDMolSupplier('../data/QM9/raw/gdb9.sdf')

In [29]:
mol = sdf[0]

In [30]:
conf = mol.GetConformer()

In [31]:
pos = conf.GetPositions()

In [32]:
pos

array([[-0.0127,  1.0858,  0.008 ]])

In [33]:
pos = torch.tensor(pos, dtype=torch.float)

In [34]:
pos

tensor([[-0.0127,  1.0858,  0.0080]])

 # Mol

In [40]:
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem.rdchem import HybridizationType

In [52]:
import torch.nn.functional as F

In [85]:
from torch_scatter import scatter
from torch_geometric.data import Data

In [35]:
m1 = Chem.MolFromMolFile('../data/mol_files/train_set/train_0_ex.mol')

In [36]:
N = m1.GetNumAtoms()

In [37]:
conf = mol.GetConformer()
pos = conf.GetPositions()
pos = torch.tensor(pos, dtype=torch.float)


In [41]:
types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

In [45]:
type_idx = []
atomic_number = []
aromatic = []
sp = []
sp2 = []
sp3 = []
num_hs = []
for atom in m1.GetAtoms():
    type_idx.append(types[atom.GetSymbol()])
    atomic_number.append(atom.GetAtomicNum())
    aromatic.append(1 if atom.GetIsAromatic() else 0)
    hybridization = atom.GetHybridization()
    sp.append(1 if hybridization == HybridizationType.SP else 0)
    sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
    sp3.append(1 if hybridization == HybridizationType.SP3 else 0)

In [91]:
atomic_number

[6,
 6,
 6,
 6,
 6,
 6,
 6,
 7,
 6,
 8,
 6,
 6,
 8,
 6,
 8,
 6,
 6,
 6,
 8,
 6,
 7,
 6,
 6,
 6,
 6,
 6,
 6,
 7]

In [48]:
z = torch.tensor(atomic_number, dtype=torch.long)

In [57]:
row, col, edge_type = [], [], []
for bond in m1.GetBonds():
    start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    row += [start, end]
    col += [end, start]
    edge_type += 2 * [bonds[bond.GetBondType()]]

In [58]:
edge_index = torch.tensor([row, col], dtype=torch.long)
edge_type = torch.tensor(edge_type, dtype=torch.long)
edge_attr = F.one_hot(edge_type,
                      num_classes=len(bonds)).to(torch.float)

In [62]:
perm = (edge_index[0] * N + edge_index[1]).argsort()
edge_index = edge_index[:, perm]
edge_type = edge_type[perm]
edge_attr = edge_attr[perm]

In [63]:
perm

tensor([ 0,  1,  2,  3,  4,  6,  5,  8,  9, 10, 11, 12, 13, 14,  7, 15, 16, 17,
        18, 20, 19, 21, 22, 24, 23, 25, 26, 27, 28, 30, 29, 31, 32, 34, 33, 36,
        38, 37, 39, 40, 41, 42, 44, 43, 46, 48, 47, 50, 51, 52, 53, 54, 49, 55,
        35, 45, 56, 57, 58, 59])

In [67]:
row, col = edge_index
hs = (z == 1).to(torch.float)
num_hs = scatter(hs[row], col, dim_size=N).tolist()

In [68]:
row

tensor([ 0,  1,  1,  2,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  7,  8,
         8,  8,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13, 14, 15, 15, 15, 16, 16,
        16, 17, 18, 18, 19, 19, 19, 20, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24,
        25, 25, 25, 26, 26, 27])

In [69]:
col

tensor([ 1,  0,  2,  1,  3,  7,  2,  4,  3,  5,  4,  6,  5,  7,  2,  6,  8,  7,
         9, 10,  8,  8, 11, 12, 10, 10, 13, 12, 14, 15, 13, 13, 16, 25, 15, 17,
        18, 16, 16, 19, 18, 20, 25, 19, 21, 24, 20, 22, 21, 23, 22, 24, 20, 23,
        15, 19, 26, 25, 27, 26])

In [70]:
hs

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.])

In [71]:
num_hs

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [72]:
x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types))
x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs],
                  dtype=torch.float).t().contiguous()
x = torch.cat([x1.to(torch.float), x2], dim=-1)

In [73]:
x1

tensor([[0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0]])

In [74]:
x2

tensor([[6., 0., 0., 0., 1., 0.],
        [6., 0., 0., 0., 1., 0.],
        [6., 0., 0., 0., 1., 0.],
        [6., 0., 0., 0., 1., 0.],
        [6., 0., 0., 0., 1., 0.],
        [6., 0., 0., 0., 1., 0.],
        [6., 0., 0., 0., 1., 0.],
        [7., 0., 0., 1., 0., 0.],
        [6., 0., 0., 1., 0., 0.],
        [8., 0., 0., 1., 0., 0.],
        [6., 0., 0., 0., 1., 0.],
        [6., 0., 0., 0., 1., 0.],
        [8., 0., 0., 1., 0., 0.],
        [6., 0., 0., 1., 0., 0.],
        [8., 0., 0., 1., 0., 0.],
        [6., 1., 0., 1., 0., 0.],
        [6., 1., 0., 1., 0., 0.],
        [6., 0., 0., 0., 1., 0.],
        [8., 1., 0., 1., 0., 0.],
        [6., 1., 0., 1., 0., 0.],
        [7., 1., 0., 1., 0., 0.],
        [6., 1., 0., 1., 0., 0.],
        [6., 1., 0., 1., 0., 0.],
        [6., 1., 0., 1., 0., 0.],
        [6., 1., 0., 1., 0., 0.],
        [6., 1., 0., 1., 0., 0.],
        [6., 0., 1., 0., 0., 0.],
        [7., 0., 1., 0., 0., 0.]])

In [75]:
x

tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 7., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 8., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 8., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 8., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 1., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 1., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.,

In [77]:
# y = target[i].unsqueeze(0)
name = m1.GetProp('_Name')

In [78]:
name

''

In [87]:
data = Data(x=x, z=z, pos=pos, edge_index=edge_index,
                        edge_attr=edge_attr, name=name)

In [88]:
data

Data(x=[28, 11], edge_index=[2, 60], edge_attr=[60, 4], pos=[1, 3], z=[28], name='')

In [89]:
data.x

tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 7., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 8., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 8., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 8., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 1., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 1., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 6., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.,

In [90]:
data.edge_index

tensor([[ 0,  1,  1,  2,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  7,  8,
          8,  8,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13, 14, 15, 15, 15, 16, 16,
         16, 17, 18, 18, 19, 19, 19, 20, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24,
         25, 25, 25, 26, 26, 27],
        [ 1,  0,  2,  1,  3,  7,  2,  4,  3,  5,  4,  6,  5,  7,  2,  6,  8,  7,
          9, 10,  8,  8, 11, 12, 10, 10, 13, 12, 14, 15, 13, 13, 16, 25, 15, 17,
         18, 16, 16, 19, 18, 20, 25, 19, 21, 24, 20, 22, 21, 23, 22, 24, 20, 23,
         15, 19, 26, 25, 27, 26]])