In [1]:
# data processing
from ts_vae.data_processors.grambow_processor import ReactionDataset

# my GAEs
from ts_vae.gaes.n_gae import Node_AE, train_node_ae, test_node_ae
from ts_vae.gaes.ne_gae import NodeEdge_AE, train_ne_ae, test_ne_ae
from ts_vae.gaes.nec_gae import NodeEdgeCoord_AE, train_nec_ae, test_nec_ae, main

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset

# torch geometric
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_adj

# other
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
# remove processed files, presumably so ReactionDataset in the next cell rebuilds them
# from the raw data (TODO confirm against ReactionDataset's processing logic)
import os
import glob

processed_paths = glob.glob(r'data/processed/*')
for path in processed_paths:
    # glob('*') also matches sub-directories, which os.remove cannot delete; skip them
    if os.path.isfile(path):
        os.remove(path)

In [4]:
### New Data Processing

rxns = ReactionDataset(r'data')

# deterministic (unshuffled) train/test split by position in the dataset
num_rxns = len(rxns)
train_ratio = 0.8
num_train = int(np.floor(train_ratio * num_rxns))

# single source of truth for the loader batch size (previously an unused `2` next to hard-coded 10s)
batch_size = 10

# follow_batch asks the loader to create an assignment batch vector for each listed key.
# Keys are generated (edge_index_r, edge_index_ts, edge_index_p, edge_attr_r, ..., x_p) rather than
# typed out, because a missing comma in a hand-written list silently concatenates adjacent names.
to_follow = [f'{feat}_{stage}'
             for feat in ('edge_index', 'edge_attr', 'pos', 'x')
             for stage in ('r', 'ts', 'p')]

train_loader = DataLoader(rxns[: num_train], batch_size = batch_size, follow_batch = to_follow)
test_loader = DataLoader(rxns[num_train: ], batch_size = batch_size, follow_batch = to_follow)

In [5]:
### NEW! NodeEdgeCoord AE
train_rxns = train_loader.dataset
ref_rxn = train_rxns[0]

# size (in atoms) of the largest reaction in the training split
max_num_atoms = max(rxn.num_atoms.item() for rxn in train_rxns)

# the model is built from one in_node_nf / in_edge_nf, so every training reaction
# must share the reference reaction's atom and bond feature widths
num_atom_fs = ref_rxn.num_atom_fs.item()
assert all(rxn.num_atom_fs.item() == num_atom_fs for rxn in train_rxns)
num_bond_fs = ref_rxn.num_bond_fs.item()
assert all(rxn.num_bond_fs.item() == num_bond_fs for rxn in train_rxns)

# hidden / embedding widths
h_nf, emb_nf = 5, 2

# model and optimiser
nec_ae = NodeEdgeCoord_AE(
    in_node_nf = num_atom_fs,
    in_edge_nf = num_bond_fs,
    h_nf = h_nf,
    out_nf = h_nf,
    emb_nf = emb_nf,
)
nec_opt = torch.optim.Adam(nec_ae.parameters(), lr = 1e-3)

In [6]:
def train(nec_ae, opt, loader):

    res = {'total_loss': 0, 'num_rxns': 0, 'coord_loss_arr': []}

    for i, rxn_batch in enumerate(loader):

        nec_ae.train()
        opt.zero_grad()

        # init required variables
        r_node_feats, r_edge_index, r_edge_attr, r_coords = rxn_batch.x_r, rxn_batch.edge_index_r, rxn_batch.edge_attr_r, rxn_batch.pos_r
        ts_node_feats, ts_edge_index, ts_edge_attr, ts_coords = rxn_batch.x_ts, rxn_batch.edge_index_ts, rxn_batch.edge_attr_ts, rxn_batch.pos_ts
        batch_size = len(rxn_batch.idx)

        # run model on reactant
        node_emb, edge_emb, recon_node_fs, recon_edge_fs, adj_pred, coord_out = nec_ae(r_node_feats, r_edge_index, r_edge_attr, r_coords)

        # ground truth values
        adj_gt = to_dense_adj(ts_edge_index).squeeze(dim = 0)
        assert adj_gt.shape == adj_pred.shape, f"Your adjacency matrices don't have the same shape! \
                GT shape: {adj_gt.shape}, Pred shape: {adj_pred.shape}, Batch size: {batch_size}"
        
        # losses and opt step
        coord_loss = F.mse_loss(coord_out, ts_coords)
        total_loss = coord_loss
        total_loss.backward()
        opt.step()

        # record batch results
        res['total_loss'] += total_loss.item()
        res['num_rxns'] += batch_size
        res['coord_loss_arr'].append(coord_loss.item())
    
    return res['total_loss'] / res['num_rxns'], res

# warm-up pass: `train` calls opt.step(), so this already updates nec_ae's weights
# before the training loop in the next cell (re-running this cell trains further)
loss, res = train(nec_ae, opt = nec_opt, loader = train_loader)

In [7]:
### NodeEdgeCoord Model 

epochs = 50
log_interval = 10       # print the training loss every `log_interval` epochs
test_interval = 1000    # evaluate every `test_interval` epochs; since it exceeds `epochs`, evaluation never runs

torch.set_printoptions(precision = 2)

# best_test starts at +inf so the first evaluated loss always becomes the best
final_res = {'train_loss_arr': [], 'train_res_arr': [], 'test_loss_arr': [], 'test_res_arr': [], 
             'best_test': float('inf'), 'best_epoch': 0}

for epoch in range(1, epochs + 1):

    train_loss, train_res = train(nec_ae, nec_opt, train_loader)
    final_res['train_loss_arr'].append(train_loss)
    final_res['train_res_arr'].append(train_res)

    if epoch % log_interval == 0:
        print(f"===== Training epoch {epoch:03d} complete with loss: {train_loss:.4f} ====")

    if epoch % test_interval == 0:
        # NOTE(review): requires a `test(model, loader) -> (loss, res)` function in scope;
        # `test` is not among the imports at the top of this notebook
        test_loss, test_res = test(nec_ae, test_loader)
        final_res['test_loss_arr'].append(test_loss)
        final_res['test_res_arr'].append(test_res)
        print(f'===== Testing epoch: {epoch:03d}, Loss: {test_loss:.4f} ===== \n')

        if test_loss < final_res['best_test']:
            final_res['best_test'] = test_loss
            final_res['best_epoch'] = epoch

===== Training epoch 010 complete with loss: 0.0201 ====
===== Training epoch 020 complete with loss: 0.0200 ====
===== Training epoch 030 complete with loss: 0.0200 ====
===== Training epoch 040 complete with loss: 0.0200 ====
===== Training epoch 050 complete with loss: 0.0200 ====
