In [1]:
# data processing
from ts_vae.data_processors.grambow_processor import ReactionDataset

# my GAEs
from ts_vae.gaes.n_gae import Node_AE, train_node_ae, test_node_ae
from ts_vae.gaes.ne_gae import NodeEdge_AE, train_ne_ae, test_ne_ae
from ts_vae.gaes.nec_gae import NodeEdgeCoord_AE, train_nec, test_nec, main_nec

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset

# torch geometric
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_adj
from torch_geometric.utils.sparse import dense_to_sparse

# other
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
# Clear out everything currently sitting in data/processed.
# NOTE(review): presumably this forces ReactionDataset to rebuild its processed
# files from the raw data the next time it is instantiated -- confirm.
import glob
import os

# Snapshot the matching paths first, then delete them one at a time.
files = glob.glob(r'data/processed/*')
for stale_path in files:
    os.remove(stale_path)

In [2]:
### New Data Processing

# Instantiate the reaction dataset with 'data' as its argument
# (presumably the dataset root directory -- confirm against ReactionDataset).
rxns = ReactionDataset(r'data')

# Positional train/test split: the first `train_ratio` fraction of the
# reactions is used for training, the remainder for testing (no shuffling).
num_rxns = len(rxns)
train_ratio = 0.8
num_train = int(np.floor(train_ratio * num_rxns))

batch_size = 10
# Attribute names handed to the loaders via `follow_batch`.
# Every name must be its own list element: adjacent string literals with no
# comma between them are silently concatenated by Python, which previously
# merged 'edge_attr_p' and 'pos_r' into the bogus key 'edge_attr_ppos_r' and
# dropped both real keys from the list.
to_follow = ['edge_index_r', 'edge_index_ts', 'edge_index_p', 'edge_attr_r', 'edge_attr_ts', 'edge_attr_p',
             'pos_r', 'pos_ts', 'pos_p', 'x_r', 'x_ts', 'x_p']


# Both loaders share the same batch size and followed attributes.
train_loader = DataLoader(rxns[: num_train], batch_size = batch_size, follow_batch = to_follow)
test_loader = DataLoader(rxns[num_train: ], batch_size = batch_size, follow_batch = to_follow)

In [4]:
### NEW! NodeEdgeCoord AE
train_rxns = train_loader.dataset

# Largest atom count over the training reactions (not used further in this cell).
max_num_atoms = max(rxn.num_atoms.item() for rxn in train_rxns)

# Look the reference reaction up once instead of re-indexing the dataset for
# every element inside the consistency checks below.
ref_rxn = train_rxns[0]

# Every training reaction must report the same number of atom features and
# the same number of bond features, since single values are passed to the model.
num_atom_fs = ref_rxn.num_atom_fs.item()
assert all(rxn.num_atom_fs.item() == num_atom_fs for rxn in train_rxns), \
    'training reactions disagree on num_atom_fs'
num_bond_fs = ref_rxn.num_bond_fs.item()
assert all(rxn.num_bond_fs.item() == num_bond_fs for rxn in train_rxns), \
    'training reactions disagree on num_bond_fs'

# NOTE(review): names suggest hidden and embedding feature sizes -- confirm
# against NodeEdgeCoord_AE.
h_nf = 5
emb_nf = 2

# model and opt (the model's out_nf is set equal to h_nf)
nec_ae = NodeEdgeCoord_AE(in_node_nf = num_atom_fs, in_edge_nf = num_bond_fs, h_nf = h_nf, out_nf = h_nf, emb_nf = emb_nf)
nec_opt = torch.optim.Adam(nec_ae.parameters(), lr = 1e-3)

# train, test, main
# NOTE(review): train_nec / test_nec are each called once here before main_nec
# is run on the same model and optimizer; presumably main_nec is the full
# multi-epoch loop that prints the per-epoch log -- confirm in nec_gae.
train_loss, train_res = train_nec(nec_ae, nec_opt, train_loader)
test_loss, test_res = test_nec(nec_ae, test_loader)
final_res = main_nec(nec_ae, nec_opt, train_loader, test_loader)

===== Training epoch 001 complete with loss: 4.7408 ====
===== Training epoch 002 complete with loss: 4.4853 ====
===== Training epoch 003 complete with loss: 4.2326 ====
===== Training epoch 004 complete with loss: 3.9473 ====
===== Training epoch 005 complete with loss: 3.6093 ====
===== Testing epoch: 005, Loss: 3.3752 ===== 

===== Training epoch 006 complete with loss: 3.2334 ====
===== Training epoch 007 complete with loss: 2.8690 ====
===== Training epoch 008 complete with loss: 2.5583 ====
===== Training epoch 009 complete with loss: 2.3064 ====
===== Training epoch 010 complete with loss: 2.0526 ====
===== Testing epoch: 010, Loss: 1.8907 ===== 

===== Training epoch 011 complete with loss: 1.6853 ====
===== Training epoch 012 complete with loss: 1.3897 ====
===== Training epoch 013 complete with loss: 1.2721 ====
===== Training epoch 014 complete with loss: 1.2093 ====
===== Training epoch 015 complete with loss: 1.1642 ====
===== Testing epoch: 015, Loss: 1.1696 ===== 

====