# Manual data processing (fallback if the dataset loader has issues)

In [11]:
from numpy.core.fromnumeric import product
from scipy.sparse import data
import torch
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.data import InMemoryDataset, DataLoader # , Data
from torch_geometric.data.data import Data
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from tqdm import tqdm

def process_geometry_file(geometry_file, list = None, limit = 100):
    """Convert the molecules of an SDF geometry file into PyTorch Geometric ``Data`` graphs.

    Code mostly lifted from QM9 dataset creation
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/qm9.html
    Transforms molecules to their atom features and adjacency lists.

    Args:
        geometry_file: Path suffix appended verbatim to ``'data'``, so it has to carry its
            own leading separator (e.g. ``'/raw/train_reactants.sdf'``).
        list: Optional existing list that the new graphs are appended to. The name shadows
            the builtin; it is kept so existing keyword callers keep working.
        limit: Maximum number of molecules read from the file; ``None`` reads all of them.
            Defaults to 100, the cap that used to be hard-coded as a temporary workaround
            for memory issues.

    Returns:
        The list passed in (or a new one) with one ``Data`` object appended per molecule,
        carrying ``x``, ``z``, ``pos``, ``edge_index``, ``edge_attr`` and ``idx`` (the
        molecule's index within this file).

    Raises:
        ValueError: If RDKit yields ``None`` for a record, i.e. it could not be parsed.
        KeyError: If an atom symbol is not one of H, C, N, O, F, or a bond type is not
            single, double, triple or aromatic.
    """
    types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
    bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

    # `is not None` (not truthiness) so that a caller-supplied empty list is extended
    # in place instead of being silently swapped for a fresh one
    data_list = list if list is not None else []
    full_path = r'data' + geometry_file
    geometries = Chem.SDMolSupplier(full_path, removeHs=False, sanitize=False)

    # get atom and edge features for each geometry
    for i, mol in enumerate(tqdm(geometries)):

        # temp soln cos of split edge memory issues
        if limit is not None and i >= limit:
            break

        if mol is None:
            # fail loudly: skipping would misalign reactant / ts / product lists
            raise ValueError(
                'could not parse molecule {} of {}'.format(i, full_path))

        N = mol.GetNumAtoms()
        # get atom positions as matrix w shape [num_nodes, num_dimensions] = [num_atoms, 3];
        # the slice skips the first 4 lines of the record text and takes the next N lines
        atom_data = geometries.GetItemText(i).split('\n')[4:4 + N]
        atom_positions = [[float(x) for x in line.split()[:3]] for line in atom_data]
        atom_positions = torch.tensor(atom_positions, dtype=torch.float)

        # atom/node features
        type_idx = []
        atomic_number = []
        aromatic = []
        sp = []
        sp2 = []
        sp3 = []
        for atom in mol.GetAtoms():
            type_idx.append(types[atom.GetSymbol()])
            atomic_number.append(atom.GetAtomicNum())
            aromatic.append(1 if atom.GetIsAromatic() else 0)
            hybridisation = atom.GetHybridization()
            sp.append(1 if hybridisation == HybridizationType.SP else 0)
            sp2.append(1 if hybridisation == HybridizationType.SP2 else 0)
            sp3.append(1 if hybridisation == HybridizationType.SP3 else 0)
            # !!! should do the features that lucky does: whether bonded, 3d_rbf

        # bond/edge features
        row, col, edge_type = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            # edge type for each bond type; *2 because both ways
            edge_type += 2 * [bonds[bond.GetBondType()]]
        # edge_index is graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_type = torch.tensor(edge_type, dtype=torch.long)
        # edge_attr is edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)

        # order edges by (source, target) ascending; source * N + target is a unique key
        perm = (edge_index[0] * N + edge_index[1]).argsort() # TODO
        edge_index = edge_index[:, perm]
        edge_type = edge_type[perm]
        edge_attr = edge_attr[perm]

        row, col = edge_index
        z = torch.tensor(atomic_number, dtype=torch.long)
        hs = (z == 1).to(torch.float) # hydrogens
        # per atom, sum the hydrogen flags of its bonded neighbours
        num_hs = scatter(hs[row], col, dim_size=N).tolist()

        # x = [one-hot atom type | atomic number, aromatic, sp, sp2, sp3, num_hs]
        x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types))
        x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs], dtype=torch.float).t().contiguous()
        x = torch.cat([x1.to(torch.float), x2], dim=-1)

        graph = Data(x=x, z=z, pos=atom_positions, edge_index=edge_index, edge_attr=edge_attr, idx=i)

        data_list.append(graph)

    return data_list

In [12]:
# For each geometry kind, process the train file first and then pass its result on to
# the test file's call, so train and test molecules end up concatenated in one list.
reactants, ts, products = [
    process_geometry_file(test_file, process_geometry_file(train_file, []))
    for train_file, test_file in (
        ('/raw/train_reactants.sdf', '/raw/test_reactants.sdf'),
        ('/raw/train_ts.sdf', '/raw/test_ts.sdf'),
        ('/raw/train_products.sdf', '/raw/test_products.sdf'),
    )
]

# every reaction needs a reactant, a transition state and a product
assert len(reactants) == len(ts) == len(products)

print(type(reactants[0]), type(ts[0]), type(products[0]))

  1%|▏         | 100/6739 [00:00<00:23, 286.92it/s]
 12%|█▏        | 100/842 [00:00<00:02, 268.14it/s]
  1%|▏         | 100/6739 [00:00<00:16, 414.71it/s]
 12%|█▏        | 100/842 [00:00<00:01, 404.30it/s]
  1%|▏         | 100/6739 [00:00<00:12, 548.55it/s]
 12%|█▏        | 100/842 [00:00<00:02, 314.45it/s]

<class 'torch_geometric.data.data.Data'> <class 'torch_geometric.data.data.Data'> <class 'torch_geometric.data.data.Data'>





In [57]:
class ReactionTriple(Data):
    """Bundle the reactant, transition-state and product graphs of one reaction.

    Attributes are set in ``__init__``:
        r:  reactant graph
        ts: transition-state graph
        p:  product graph
    """

    def __init__(self, r = None, ts = None, p = None):
        super().__init__()
        self.r = r
        self.ts = ts
        self.p = p

    def __inc__(self, key, value, *args, **kwargs):
        """Return the index offset to apply to attribute ``key`` when batching.

        For the three sub-graphs the offset is that graph's node count. The previous
        ``edge_index.size(0)`` is always 2 for a ``[2, num_edges]`` tensor, so it could
        never track the number of nodes.
        NOTE(review): node count is the offset PyG's pair-of-graphs example uses for
        edge indices - confirm this is what downstream batching code expects.

        Extra positional/keyword arguments are accepted and forwarded, so the override
        also works when the batching code passes more than ``(key, value)``.
        """
        if key in ('r', 'ts', 'p'):
            return getattr(self, key).num_nodes
        return super().__inc__(key, value, *args, **kwargs)

In [58]:
# Combine the i-th reactant, transition state and product into one ReactionTriple.
# zip() silently truncates to the shortest input, so check the lengths explicitly first.
if not (len(reactants) == len(ts) == len(products)):
    raise ValueError(
        'mismatched geometry counts: {} reactants, {} ts, {} products'.format(
            len(reactants), len(ts), len(products)))
rxns = [
    ReactionTriple(reactant, transition_state, product)
    for reactant, transition_state, product in zip(reactants, ts, products)
]

# Normal data processing from files

In [4]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GAE
from torch_geometric.data import DataLoader
from ts_vae.data_processors.new_pyg_processor import ReactionDataset
from ts_vae.gae import GAE, MolEncoder, InnerProductDecoder
import numpy as np

In [5]:
# Load the processed reaction dataset rooted at 'data'.
rxns = ReactionDataset(r'data')

# The first 80% of reactions (rounded down) are for training, the rest for testing.
train_ratio = 0.8
num_rxns = len(rxns)
num_train = int(np.floor(train_ratio * num_rxns))
train_split = rxns[: num_train]
test_split = rxns[num_train:]

# follow_batch is set for the 'r' and 'p' attributes on both loaders
train_loader = DataLoader(
    train_split,
    batch_size = 2,
    follow_batch = ['r', 'p'],
)
test_loader = DataLoader(
    test_split,
    batch_size = 3,
    follow_batch = ['r', 'p'],
)

# batch = next(iter(train_loader))
# batch.p
# batch.r

In [9]:
# Peek at the reactant graphs of the first two test batches only; stop afterwards
# instead of collating every remaining batch just to discard it.
for i, b in enumerate(test_loader):
    if i >= 2:
        break
    print(b.r)

[Data(edge_attr=[30, 4], edge_index=[2, 30], idx=48, pos=[15, 3], x=[15, 11], z=[15]), Data(edge_attr=[26, 4], edge_index=[2, 26], idx=49, pos=[14, 3], x=[14, 11], z=[14]), Data(edge_attr=[30, 4], edge_index=[2, 30], idx=50, pos=[14, 3], x=[14, 11], z=[14])]
[Data(edge_attr=[38, 4], edge_index=[2, 38], idx=51, pos=[17, 3], x=[17, 11], z=[17]), Data(edge_attr=[32, 4], edge_index=[2, 32], idx=52, pos=[15, 3], x=[15, 11], z=[15]), Data(edge_attr=[34, 4], edge_index=[2, 34], idx=53, pos=[17, 3], x=[17, 11], z=[17])]


In [None]:
# have to convert the training scheme: no more edge sampling and now use batches
# TODO: create data, opt, model; data and model to device

# simple R->R GAE, then build up

def train_gae(gae, opt, loader):
    """Train ``gae`` on the reactants of every batch in ``loader`` (one pass).

    Args:
        gae: Autoencoder exposing ``train()``, ``encode(reactants)`` and
            ``recon_loss(z)``.
        opt: Optimizer exposing ``zero_grad()`` and ``step()``.
        loader: Iterable of reaction batches, each with the reactants in ``.r``.

    Returns:
        The mean reconstruction loss per batch as a float, or ``0.0`` when the loader
        yields no batches.
    """
    gae.train() # set flags (on the model passed in, not a global `model`)

    total_loss = 0.0
    num_batches = 0

    # one iteration over different batches
    for rxn_batch in loader:
        reactants = rxn_batch.r

        opt.zero_grad() # zero gradients

        # TODO: pad mols for batch/maybe just pad all with max_num_atoms

        # encode reactant batch and calculate loss
        z_r = gae.encode(reactants)
        loss = gae.recon_loss(z_r)

        # modify gradients
        loss.backward()
        opt.step()

        # accumulate; returning here would stop training after the first batch
        total_loss += float(loss)
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0


# my train has automatic train-test edge split but this is what i was going for before [mmvae]
def train(epoch, agg):
    """Run one optimisation pass over the global ``train_loader``.

    Relies on module-level ``model``, ``train_loader``, ``optimizer``, ``objective``,
    ``unpack_data``, ``device`` and ``args``.

    Args:
        epoch: Epoch number, used only in the summary line.
        agg: Mapping whose ``'train_loss'`` list receives this epoch's summed batch
            losses divided by the dataset size.
    """
    model.train()
    running_loss = 0
    for step, raw_batch in enumerate(train_loader):
        batch = unpack_data(raw_batch, device=device)
        optimizer.zero_grad()
        # the objective is maximised, so its negation is the loss
        loss = - objective(model, batch, K=args.K)
        loss.backward()
        optimizer.step()
        step_loss = loss.item()
        running_loss += step_loss
        should_log = args.print_freq > 0 and step % args.print_freq == 0
        if should_log:
            print("iteration {:04d}: loss: {:6.3f}".format(step, step_loss / args.batch_size))
    epoch_loss = running_loss / len(train_loader.dataset)
    agg['train_loss'].append(epoch_loss)
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1]))

In [None]:
def train_gae(gae, opt, x, train_pos_edge_index):
    """Take one optimisation step of ``gae`` on a single graph.

    Encodes ``x`` over ``train_pos_edge_index``, computes the reconstruction loss on the
    same edges, backpropagates and steps ``opt``.

    Returns:
        The reconstruction loss of this step as a float.
    """
    gae.train()
    opt.zero_grad()

    latent = gae.encode(x, train_pos_edge_index)
    recon = gae.recon_loss(latent, train_pos_edge_index)

    recon.backward()
    opt.step()
    return float(recon)

def test_gae(gae, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
    gae.eval()
    with torch.no_grad():
        z = gae.encode(x, train_pos_edge_index)
    return gae.test(z, test_pos_edge_index, test_neg_edge_index)

# start from freshly initialised weights
r_ae.reset_parameters()

epochs = 100
for epoch in range(1, epochs + 1):
    # one optimisation step on the training edges, then score the held-out edges
    loss = train_gae(r_ae, r_opt, r_x, r_data.train_pos_edge_index)
    auc, ap = test_gae(
        r_ae,
        r_x,
        r_data.train_pos_edge_index,
        r_data.test_pos_edge_index,
        r_data.test_neg_edge_index,
    )
    # report only every 10th epoch
    if epoch % 10 != 0:
        continue
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

In [None]:
# my train has automatic train-test edge split but this is what i was going for before [mmvae]
def train(epoch, agg):
    """Train the global ``model`` for one epoch and log the result.

    Uses the module-level ``train_loader``, ``optimizer``, ``objective``,
    ``unpack_data``, ``device`` and ``args``.

    Args:
        epoch: Epoch number shown in the end-of-epoch summary.
        agg: Mapping with a ``'train_loss'`` list; the summed batch losses divided by
            the dataset size are appended to it.
    """
    model.train()
    loss_sum = 0
    log_every = args.print_freq
    for batch_no, packed in enumerate(train_loader):
        inputs = unpack_data(packed, device=device)
        optimizer.zero_grad()
        # minimise the negated objective
        loss = - objective(model, inputs, K=args.K)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        if log_every > 0 and batch_no % log_every == 0:
            print("iteration {:04d}: loss: {:6.3f}".format(batch_no, loss.item() / args.batch_size))
    agg['train_loss'].append(loss_sum / len(train_loader.dataset))
    latest = agg['train_loss'][-1]
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, latest))


# Loop over epochs
# Skeleton of a generic epoch/batch training loop; `max_epochs`, `loader` and `device`
# are not defined in this cell and the model computations are still a placeholder.
for epoch in range(max_epochs):
    # Training
    for batch, labels in loader:
        # Transfer to GPU if available
        batch, labels = batch.to(device), labels.to(device)
        # Model computations
        # placeholder: `[...]` is just a list holding Ellipsis, i.e. a no-op statement
        [...]