In [1]:
# data processing
from ts_vae.data_processors.grambow_processor import ReactionDataset

# my GAEs
from ts_vae.gaes.n_gae import Node_AE, train_node_ae, test_node_ae
from ts_vae.gaes.ne_gae import NodeEdge_AE, train_ne_ae, test_ne_ae
from ts_vae.gaes.nec_gae import NodeEdgeCoord_AE, train_nec_ae, test_nec_ae, main

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset

# torch geometric
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_adj

# other
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
# Wipe every entry under data/processed/ before the dataset is built below.
# NOTE(review): presumably this forces ReactionDataset to re-process the raw
# files instead of reusing a cached copy — confirm against ReactionDataset.
import os
import glob

files = glob.glob(r'data/processed/*')
for stale_path in files:
    os.remove(stale_path)

In [2]:
### New Data Processing
# Build the reaction dataset and split it sequentially (first 80% train, rest test).

rxns = ReactionDataset(r'data')

num_rxns = len(rxns)
train_ratio = 0.8
num_train = int(np.floor(train_ratio * num_rxns))

batch_size = 1
# Keys for which the loader should build per-attribute batch-assignment vectors.
# Every entry needs its own comma: adjacent string literals are silently
# concatenated by Python (a missing comma here would merge 'edge_attr_p' and
# 'pos_r' into one bogus key and neither would be followed).
to_follow = ['edge_index_r', 'edge_index_ts', 'edge_index_p',
             'edge_attr_r', 'edge_attr_ts', 'edge_attr_p',
             'pos_r', 'pos_ts', 'pos_p',
             'x_r', 'x_ts', 'x_p']


train_loader = DataLoader(rxns[: num_train], batch_size = batch_size, follow_batch = to_follow)
test_loader = DataLoader(rxns[num_train: ], batch_size = batch_size, follow_batch = to_follow)

In [9]:
### NEW! NodeEdgeCoord AE
# Derive model input widths from the training split. Every reaction must share
# the same atom/bond feature widths, so validate that while scanning once.

# Fetch the reference reaction a single time instead of once per compared element.
ref_rxn = train_loader.dataset[0]
num_atom_fs = ref_rxn.num_atom_fs.item()
num_bond_fs = ref_rxn.num_bond_fs.item()

# One pass over the training set: check feature widths and track the largest molecule.
max_num_atoms = 0
for rxn in train_loader.dataset:
    assert rxn.num_atom_fs.item() == num_atom_fs, "Inconsistent number of atom features across reactions"
    assert rxn.num_bond_fs.item() == num_bond_fs, "Inconsistent number of bond features across reactions"
    max_num_atoms = max(max_num_atoms, rxn.num_atoms.item())

h_nf = 5
emb_nf = 2

# model and opt
nec_ae = NodeEdgeCoord_AE(in_node_nf = num_atom_fs, in_edge_nf = num_bond_fs, h_nf = h_nf, out_nf = h_nf, emb_nf = emb_nf)
nec_opt = torch.optim.Adam(nec_ae.parameters(), lr = 1e-3)

In [10]:
# Multiplier intended for the coordinate MSE loss, to bring it onto a scale
# comparable with the adjacency BCE loss (observed: coord_loss ~0.015,
# adj_loss ~0.7, ratio 0.7/0.015 ~= 47).
# NOTE(review): the value used (10) is well below that observed ratio, and no
# active code below references this constant — confirm intent.
COORD_LOSS_SCALER = 10
# NOTE(review): dense_to_sparse is not used by the code in this section.
from torch_geometric.utils.sparse import dense_to_sparse

def train(nec_ae, opt, loader):
    """Run one training epoch: predict the TS adjacency matrix from the reactant graph.

    For each batch the model is fed the reactant graph (x_r, edge_index_r,
    edge_attr_r, pos_r). Its predicted dense adjacency is compared, with binary
    cross-entropy, against the dense adjacency built from edge_index_ts.

    Args:
        nec_ae: model called as nec_ae(node_feats, edge_index, edge_attr, coords)
            returning a 6-tuple whose 5th element is the predicted adjacency; it
            must lie in [0, 1] for F.binary_cross_entropy.
        opt: optimizer over nec_ae's parameters.
        loader: iterable of reaction batches exposing the *_r / *_ts attributes,
            `idx` and `num_atoms`.

    Returns:
        (mean_loss, res, adj_save) where
            mean_loss: average adjacency loss over the batches actually trained on,
                or NaN if none were (empty loader or every batch skipped).
            res: dict with 'total_loss' (sum), 'batch_counter', and per-batch loss
                lists; only 'adj_loss_arr' is filled, 'coord_loss_arr' and
                'node_loss_arr' stay empty.
            adj_save: detached ground-truth / predicted adjacency matrices for the
                non-skipped batches with index < 10, for inspection.
    """

    res = {'total_loss': 0, 'batch_counter': 0, 'coord_loss_arr': [], 'adj_loss_arr': [],
           'node_loss_arr': []}

    adj_save = {'adj_gt_arr': [], 'adj_pred_arr': []}

    # Training mode only needs to be set once per epoch, not once per batch.
    nec_ae.train()

    for i, rxn_batch in enumerate(loader):

        opt.zero_grad()

        # init required variables
        r_node_feats, r_edge_index, r_edge_attr, r_coords = rxn_batch.x_r, rxn_batch.edge_index_r, rxn_batch.edge_attr_r, rxn_batch.pos_r
        ts_node_feats, ts_edge_index, ts_edge_attr, ts_coords = rxn_batch.x_ts, rxn_batch.edge_index_ts, rxn_batch.edge_attr_ts, rxn_batch.pos_ts
        batch_size = len(rxn_batch.idx)
        # Size the dense adjacency from the declared atom count: atoms left without
        # any bond (e.g. after bonds break) would otherwise be missing from it.
        max_num_atoms = sum(rxn_batch.num_atoms).item()

        # run model on reactant
        node_emb, edge_emb, recon_node_fs, recon_edge_fs, adj_pred, coord_out = nec_ae(r_node_feats, r_edge_index, r_edge_attr, r_coords)

        # ground truth: dense TS adjacency matrix
        adj_gt = to_dense_adj(ts_edge_index, max_num_nodes = max_num_atoms).squeeze(dim = 0)
        assert adj_gt.shape == adj_pred.shape, f"Your adjacency matrices don't have the same shape! \n \
               GT shape: {adj_gt.shape}, Pred shape: {adj_pred.shape}, Batch size: {batch_size} \n \
               {ts_edge_index}, {ts_node_feats.shape}, \n \
                   num_atoms: {rxn_batch.num_atoms}"

        # losses and opt step
        # F.binary_cross_entropy raises RuntimeError when adj_pred leaves [0, 1]
        # (and ValueError on a size mismatch): skip such batches instead of
        # aborting the epoch, without swallowing unrelated errors such as
        # KeyboardInterrupt (which a bare `except:` would catch).
        try:
            adj_loss = F.binary_cross_entropy(adj_pred, adj_gt)
        except (RuntimeError, ValueError) as err:
            print(f"Skipping batch {i}: {err}")
            continue

        # Only the adjacency loss is optimised. Alternatives tried earlier:
        #   coord: COORD_LOSS_SCALER * F.mse_loss(coord_out, ts_coords)  -> barely any change
        #   node:  F.mse_loss(recon_node_fs, ts_node_feats)              -> 0.2 -> 0.01 in 20 epochs
        total_loss = adj_loss

        total_loss.backward()
        opt.step()

        # record batch results
        res['total_loss'] += total_loss.item()
        res['batch_counter'] += 1
        res['adj_loss_arr'].append(adj_loss.item())

        if i < 10:
            # Detach so the saved tensors don't keep this batch's autograd graph alive.
            adj_save['adj_gt_arr'].append(adj_gt.detach())
            adj_save['adj_pred_arr'].append(adj_pred.detach())

    # Guard against ZeroDivisionError when no batch was trained on.
    if res['batch_counter'] == 0:
        return float('nan'), res, adj_save
    return res['total_loss'] / res['batch_counter'], res, adj_save

# One training pass over train_loader; adj_save keeps the adjacency matrices of
# the first (index < 10) non-skipped batches for the inspection cells below.
loss, res, adj_save = train(nec_ae, nec_opt, train_loader)

In [14]:
# Inspect one saved ground-truth TS adjacency matrix (adj_save holds at most 10).
# `i` is reused by the next cell to show the matching prediction.
i = 4
adj_save['adj_gt_arr'][i]

tensor([[0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1.],
        [0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])

In [15]:
# Predicted adjacency for the same saved sample index `i` (values fed to BCE,
# so expected to lie in [0, 1]).
adj_save['adj_pred_arr'][i]

tensor([[0.0000, 0.0289, 0.0208, 0.0040, 0.0141, 0.0037, 0.2207, 0.2782, 0.1380,
         0.2109, 0.2470],
        [0.0289, 0.0000, 0.0103, 0.0362, 0.0167, 0.0302, 0.1049, 0.1382, 0.0530,
         0.0927, 0.1257],
        [0.0208, 0.0103, 0.0000, 0.0377, 0.0324, 0.0333, 0.1746, 0.2075, 0.1081,
         0.1553, 0.2012],
        [0.0040, 0.0362, 0.0377, 0.0000, 0.0084, 0.0003, 0.2103, 0.2753, 0.1289,
         0.2063, 0.2340],
        [0.0141, 0.0167, 0.0324, 0.0084, 0.0000, 0.0058, 0.1366, 0.1926, 0.0721,
         0.1340, 0.1563],
        [0.0037, 0.0302, 0.0333, 0.0003, 0.0058, 0.0000, 0.1967, 0.2594, 0.1181,
         0.1923, 0.2200],
        [0.2207, 0.1049, 0.1746, 0.2103, 0.1366, 0.1967, 0.0000, 0.0090, 0.0104,
         0.0023, 0.0010],
        [0.2782, 0.1382, 0.2075, 0.2753, 0.1926, 0.2594, 0.0090, 0.0000, 0.0324,
         0.0056, 0.0093],
        [0.1380, 0.0530, 0.1081, 0.1289, 0.0721, 0.1181, 0.0104, 0.0324, 0.0000,
         0.0112, 0.0168],
        [0.2109, 0.0927, 0.1553, 0.20

In [17]:
### NodeEdgeCoord Model: multi-epoch training loop
# Runs `epochs` passes of train() over train_loader and records per-epoch results
# in final_res. The test_* / best_* entries are placeholders that this cell never fills.

torch.set_printoptions(precision = 3)
epochs = 10

final_res = dict(
    train_loss_arr=[],
    train_res_arr=[],
    test_loss_arr=[],
    test_res_arr=[],
    best_test=1e10,
    best_epoch=0,
)

for epoch in range(1, epochs + 1):
    train_loss, train_res, adj = train(nec_ae, nec_opt, train_loader)
    final_res['train_loss_arr'].append(train_loss)
    final_res['train_res_arr'].append(train_res)

    # Log each of the first 9 epochs, then every 10th epoch (the two cases never overlap).
    should_log = epoch < 10 or epoch % 10 == 0
    if should_log:
        print(f"===== Training epoch {epoch:03d} complete with loss: {train_loss:.4f} ====")

===== Training epoch 001 complete with loss: 0.3339 ====
===== Training epoch 002 complete with loss: 0.3328 ====
===== Training epoch 003 complete with loss: 0.3324 ====
===== Training epoch 004 complete with loss: 0.3310 ====
===== Training epoch 005 complete with loss: 0.3299 ====
===== Training epoch 006 complete with loss: 0.3292 ====
===== Training epoch 007 complete with loss: 0.3283 ====
===== Training epoch 008 complete with loss: 0.3278 ====
===== Training epoch 009 complete with loss: 0.3274 ====
===== Training epoch 010 complete with loss: 0.3276 ====


In [28]:
# Per-batch results recorded by train() during the first training epoch.
final_res['train_res_arr'][0]

{'total_loss': 11.281865656375885,
 'batch_counter': 16,
 'coord_loss_arr': [],
 'adj_loss_arr': [0.7181799411773682,
  0.7143734693527222,
  0.709192156791687,
  0.7101082801818848,
  0.7094240188598633,
  0.7085379362106323,
  0.7043125033378601,
  0.7037222385406494,
  0.70466548204422,
  0.702354371547699,
  0.6996329426765442,
  0.7027063965797424,
  0.6990594863891602,
  0.6990302205085754,
  0.6985227465629578,
  0.6980434656143188],
 'node_loss_arr': []}

In [52]:
# NOTE(review): repeats the earlier first-epoch cell, and the output below is stale.
# It shows a 'num_rxns' key and a filled 'coord_loss_arr'. The current train()
# returns 'batch_counter' instead and only fills 'adj_loss_arr'. Re-run or delete.
final_res['train_res_arr'][0]

{'total_loss': 13.69211345911026,
 'num_rxns': 320,
 'coord_loss_arr': [0.4517003893852234,
  0.3646640181541443,
  0.379293292760849,
  0.4266248345375061,
  0.3577047884464264,
  0.4440578818321228,
  0.5548912882804871,
  0.45330995321273804,
  0.45632728934288025,
  0.3787263035774231,
  0.44680121541023254,
  0.36229124665260315,
  0.30292972922325134,
  0.42477449774742126,
  0.4116053283214569,
  0.4654482305049896,
  0.42196935415267944,
  0.3555295169353485,
  0.4371601343154907,
  0.35873374342918396,
  0.39976420998573303,
  0.4350343942642212,
  0.36323943734169006,
  0.4794536828994751,
  0.36955004930496216,
  0.3972865343093872,
  0.4190240800380707,
  0.4411216676235199,
  0.4474222660064697,
  0.4987640678882599,
  0.6237435936927795,
  0.5631664395332336],
 'adj_loss_arr': [],
 'node_loss_arr': []}

In [55]:
# Per-batch results from the most recent training epoch. Index with -1 rather than
# a fixed epoch number: [19] only exists if the loop ran at least 20 epochs, and the
# training cell runs `epochs = 10`, so [19] raises IndexError on a fresh kernel.
# NOTE(review): the output shown below comes from an older run (it has 'num_rxns').
final_res['train_res_arr'][-1]

{'total_loss': 13.650706440210342,
 'num_rxns': 320,
 'coord_loss_arr': [0.44959649443626404,
  0.36572614312171936,
  0.37784436345100403,
  0.4274539649486542,
  0.35596367716789246,
  0.44366729259490967,
  0.5529890656471252,
  0.44835323095321655,
  0.4550268352031708,
  0.3751155436038971,
  0.4419046640396118,
  0.36208486557006836,
  0.30449697375297546,
  0.42465561628341675,
  0.4112507998943329,
  0.4648565948009491,
  0.4198797643184662,
  0.3543452322483063,
  0.4375874698162079,
  0.35692909359931946,
  0.3987211585044861,
  0.4339248239994049,
  0.363501638174057,
  0.47662532329559326,
  0.3682370185852051,
  0.3952687680721283,
  0.4194740355014801,
  0.442590594291687,
  0.44637081027030945,
  0.49775010347366333,
  0.6186178922653198,
  0.5598965883255005],
 'adj_loss_arr': [],
 'node_loss_arr': []}