In [70]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import sklearn.metrics as metrics
from rdkit import Chem
from rdkit.Chem import rdchem as utils

import numpy as np
from numpy import exp
from numpy.random import normal


In [72]:
# Paths to the raw SDF files for each split (r = reactants, ts = transition states, p = products).
base_folder = 'data/'

train_r_file = base_folder + 'train_reactants.sdf'
train_ts_file = base_folder + 'train_ts.sdf'
train_p_file = base_folder + 'train_products.sdf'

test_r_file = base_folder + 'test_reactants.sdf'
test_ts_file = base_folder + 'test_ts.sdf'
test_p_file = base_folder + 'test_products.sdf'


def _load_sdf(path):
    """Read every molecule in the SDF file at ``path`` into a list.

    Hydrogens are kept (``removeHs=False``) and RDKit sanitisation is skipped
    (``sanitize=False``) so the molecules mirror the raw file contents.
    NOTE(review): RDKit may yield ``None`` for entries it cannot parse — check before use.
    """
    return list(Chem.ForwardSDMolSupplier(path, removeHs=False, sanitize=False))


train_r = _load_sdf(train_r_file)
train_ts = _load_sdf(train_ts_file)
train_p = _load_sdf(train_p_file)

test_r = _load_sdf(test_r_file)
test_ts = _load_sdf(test_ts_file)
test_p = _load_sdf(test_p_file)


In [None]:
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from torch_scatter import scatter
from torch_geometric.data import Data

# Element symbol -> index used for the one-hot atom-type feature.
types = dict(H=0, C=1, N=2, O=3, F=4)
# RDKit bond type -> index used for the one-hot edge feature.
bonds = {bond_type: idx for idx, bond_type in enumerate((BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC))}

# Indexable supplier (unlike the forward-only ones above); the loop below also reads raw item text from it.
# NOTE(review): this path is under 'data/raw/' while the loading cell uses base_folder ('data/') — confirm which is intended.
geometries = Chem.SDMolSupplier('data/raw/test_reactants.sdf', removeHs=False, sanitize=False)

# Reaction centers refer to the pairs of atoms that lose/form a bond in the reactions.

In [None]:
# Convert each molecule in `geometries` into a PyTorch Geometric `Data` graph.
# The list is created once, outside the loop, so every graph is kept.
data_list = []
for i, mol in enumerate(geometries):
    N = mol.GetNumAtoms()

    # Atom positions: lines 4 .. 4+N of the raw molfile text are the atom block
    # (assumes V2000 layout), whose first three columns are x, y, z.
    # The text must come from the same supplier `mol` was read from so index i lines up.
    atom_data = geometries.GetItemText(i).split('\n')[4:4 + N]
    atom_positions = [[float(x) for x in line.split()[:3]] for line in atom_data]
    # node position matrix, shape [num_atoms, 3]
    atom_positions = torch.tensor(atom_positions, dtype=torch.float)

    type_idx = []
    atomic_number = []
    aromatic = []
    sp = []
    sp2 = []
    sp3 = []

    # NOTE(review): the supplier uses sanitize=False, so RDKit may not have perceived
    # aromaticity/hybridisation — confirm the aromatic/sp/sp2/sp3 features are not all zero.
    for atom in mol.GetAtoms():
        # remember not to use these features in any special MP
        type_idx.append(types[atom.GetSymbol()])
        atomic_number.append(atom.GetAtomicNum())
        aromatic.append(1 if atom.GetIsAromatic() else 0)

        hybridisation = atom.GetHybridization()
        sp.append(1 if hybridisation == HybridizationType.SP else 0)
        sp2.append(1 if hybridisation == HybridizationType.SP2 else 0)
        sp3.append(1 if hybridisation == HybridizationType.SP3 else 0)

    # Each undirected bond is stored as two directed edges: start->end and end->start.
    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        # one bond-type label per directed edge, hence two copies
        edge_type += 2 * [bonds[bond.GetBondType()]]

    # edge_index[0] = [start_1, end_1, start_2, end_2, ...]
    # edge_index[1] = [end_1, start_1, end_2, start_2, ...]
    # PyG COO connectivity, shape [2, num_edges] with num_edges = 2 * num_bonds.
    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_type = torch.tensor(edge_type, dtype=torch.long)
    # one-hot edge features, shape [num_edges, len(bonds)]
    edge_attr = functional.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)

    # Sort edges by (source, target) so edge_index is in row-major order.
    perm = (edge_index[0] * N + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]
    edge_attr = edge_attr[perm]

    row, col = edge_index
    z = torch.tensor(atomic_number, dtype=torch.long)
    hs = (z == 1).to(torch.float)  # 1.0 for hydrogen atoms
    # Hydrogen-neighbour count per atom: sum each edge's source H indicator into its target.
    # https://abderhasan.medium.com/pytorchs-scatter-function-a-visual-explanation-351d25c05c73
    num_hs = scatter(hs[row], col, dim_size=N).tolist()  # length N

    # Node features: one-hot element type followed by
    # [atomic_number, aromatic, sp, sp2, sp3, num_hs] per atom.
    x1 = functional.one_hot(torch.tensor(type_idx), num_classes=len(types))
    x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs], dtype=torch.float).t().contiguous()
    x = torch.cat([x1.to(torch.float), x2], dim=-1)

    # `_Name` property holds Lucky's name for each entry.
    name = mol.GetProp('_Name')
    # No `y` target: the model decodes to the TS instead.
    data = Data(x=x, z=z, pos=atom_positions, edge_index=edge_index, edge_attr=edge_attr,
                name=name, idx=i)
    data_list.append(data)

# TODO: persist once this moves into a PyG Dataset.process() (needs `self`):
# torch.save(self.collate(data_list), self.processed_paths[0])

In [71]:
# Figuring out padding: the largest molecule across every split fixes the padded size.
# The value is bound to a name and shown as the cell's last expression; as a bare
# mid-cell expression it was computed and then discarded.
all_splits = (train_r, train_ts, train_p, test_r, test_ts, test_p)
max_atoms = max(mol.GetNumAtoms() for split in all_splits for mol in split)
# train_r alone gave 21 (same for ts, p) — re-check now that test splits are included.
# min(mol.GetNumAtoms() for mol in train_r) # = 4
# need to get more

# train_r_small = train_r[0:100]
# train_ts_small = train_ts[0:100]
# train_p_small = train_p[0:100]

print(torch.__version__)
# TODO: do AE, then get to grips with PTG
max_atoms



21

In [None]:
# Notes — sketch of the AE encoder, NOT runnable yet: latent_space_dim, data_dim and
# encoder are not defined, and `...` stands for elided hidden layers (whose sizes must
# chain, e.g. ... -> 400 before the last layer).
# Presumably 2 * latent_space_dim = one mean and one (log-)variance per latent dim — TODO confirm.
# num_latent_params = 2 * latent_space_dim
# network = nn.Sequential(nn.Linear(data_dim, 300), ..., nn.Linear(400, num_latent_params))
# encoder(network)