### Experiment Tracking with W&B

- config: stores the hyperparameters and metadata for each run
- wandb.init
- wandb.watch: log model gradients and params over time (helps detect bugs e.g. weird grad behaviour)
- wandb.log: log stuff we care about
- wandb.save: save online

Use `wandb.init` inside a `with` block (context-manager syntax).

In [None]:
# Log in to Weights & Biases once, before any run is started in later cells.
import wandb
wandb.login()

In [None]:
# Hyperparameters for one run (epoch count and the edge-split ratios).
config = {
    'epochs': 50,
    'val_ratio': 0,
    'test_ratio': 0.2,
}

In [None]:
def _load_split_data(base_path, geo_file, val_ratio, test_ratio):
    """Load one 'individual' ReactionDataset and run train_test_split_edges on its data."""
    data = ReactionDataset(base_path, geo_file = geo_file, dataset_type = 'individual').data
    # reset masks and labels before splitting, exactly as was done for both datasets before
    data.train_mask = data.val_mask = data.test_mask = data.y = None
    return train_test_split_edges(data = data, val_ratio = val_ratio, test_ratio = test_ratio)


def make(base_path, val_ratio, test_ratio, encode_data_name, decode_data_name, latent_dim, lr = 0.01):
    """Build the GAE, its Adam optimiser, and the edge-split encode/decode data.

    Args:
        base_path: root directory handed to ReactionDataset.
        val_ratio: validation fraction passed to train_test_split_edges.
        test_ratio: test fraction passed to train_test_split_edges.
        encode_data_name: geo_file of the dataset that is encoded.
        decode_data_name: geo_file of the dataset that is decoded.
        latent_dim: second argument given to MolEncoder.
        lr: Adam learning rate (defaults to the previously hard-coded 0.01).

    Returns:
        Tuple (gae, opt, encode_data, decode_data).

    Note:
        Nothing is moved to an accelerator here. The earlier version copied
        x / train_pos_edge_index to a device and then discarded the copies,
        so the returned objects were never on that device either.
    """
    # TODO: move model and edge tensors to the target device where they are consumed

    # the encode dataset is prepared first, then the decode dataset (same order as before)
    encode_data = _load_split_data(base_path, encode_data_name, val_ratio, test_ratio)
    decode_data = _load_split_data(base_path, decode_data_name, val_ratio, test_ratio)

    # model creation
    gae = GAE(MolEncoder(encode_data.num_node_features, latent_dim))
    opt = torch.optim.Adam(gae.parameters(), lr = lr)

    return gae, opt, encode_data, decode_data

In [None]:
def model_pipeline(hps):
    """Open a W&B run configured with `hps` and build the model, optimiser and data.

    Args:
        hps: hyperparameters passed to wandb.init as the run config. The
            'val_ratio' and 'test_ratio' entries are read back from
            wandb.config; when absent, the former hard-coded values
            (0 and 0.2) are used.

    Returns:
        None. NOTE: the objects returned by make() are not trained, logged
        or returned yet; this function only constructs them inside the run.
    """
    # start wandb
    with wandb.init(project="test", config=hps):

        # access hps through wandb.config so logging matches execution
        config = wandb.config

        # split ratios come from the logged config, falling back to the old literals
        val_ratio = config.get('val_ratio', 0)
        test_ratio = config.get('test_ratio', 0.2)

        # make model, data, opt problem
        ts_r_gae, ts_r_opt, r_data, ts_data = make(r'data/', val_ratio, test_ratio, 'train_r', 'train_ts', 2)

### Testing GAEs

In [1]:
# data processing
from ts_vae.data_processors.new_pyg_processor import ReactionDataset

# Node AE
from ts_vae.gaes.n_gae import Node_AE, train_node_ae, test_node_ae
# NodeEdge AE
from ts_vae.gaes.ne_gae import NodeEdge_AE, train_ne_ae, test_ne_ae
# NodeEdgeCoord AE
from ts_vae.gaes.nec_gae import NodeEdgeCoord_AE, train_nec_ae, test_nec_ae

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F

# torch geometric
from torch_geometric.data import DataLoader
# should double check this func works okay
from torch_geometric.utils import to_dense_adj

# other
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
# remove processed files
import os
import glob

# Delete every regular file directly under data/processed
# (presumably so the dataset is re-processed on the next load; confirm).
files = glob.glob(r'data/processed/*')
for f in files:
    # os.remove raises on directories, so only regular files are deleted
    if os.path.isfile(f):
        os.remove(f)

## New try

In [2]:
# Load the reactions and take the reactant / transition-state / product collections.
rxns = ReactionDataset(r'data')
reactants, transition_states, products = rxns.data.r, rxns.data.ts, rxns.data.p

# earlier single-loader variant, kept for reference:
# train_loader = DataLoader(rxns[: num_train], batch_size = 2, follow_batch = ['r', 'p'])
# test_loader = DataLoader(rxns[num_train:], batch_size = 2, follow_batch = ['r', 'p'])

# positional split: the first num_train reactions train, the remainder test
train_ratio = 0.8
num_rxns = len(rxns)
num_train = int(np.floor(train_ratio * num_rxns))

batch_size = 2

# need to be able to recover original reactants after encoding
# note: no padding, since PyG automatically factors this in
graphs_by_species = {'r': reactants, 'ts': transition_states, 'p': products}

train_loaders = {species: DataLoader(graphs[: num_train], batch_size)
                 for species, graphs in graphs_by_species.items()}

test_loaders = {species: DataLoader(graphs[num_train: ], batch_size)
                for species, graphs in graphs_by_species.items()}

In [3]:
### Node AE
r_train_graphs = train_loaders['r'].dataset

# largest node count (length of z) among the training reactants
max_num_nodes = max(g.z.size(0) for g in r_train_graphs)

# every training graph must have the same node-feature width
assert len({g.x.size(1) for g in r_train_graphs}) == 1

# feature sizes are read from the first training graph
first_graph = r_train_graphs[0]
num_node_fs = first_graph.x.size(1)
num_edge_fs = first_graph.edge_attr.size(1)
h_nf, emb_nf = 5, 2

# in_node_nf + in_edge_nf >= h_nf >= out_nf > emb_nf
node_ae = Node_AE(in_node_nf = num_node_fs, in_edge_nf = num_edge_fs,
                  h_nf = h_nf, out_nf = h_nf, emb_nf = emb_nf)
node_opt = torch.optim.Adam(node_ae.parameters(), lr = 1e-3)

# one training call and one test call; the epoch loop is added later
train_loss, train_res = train_node_ae(node_ae, node_opt, train_loaders['r'])
# NOTE(review): this test call also receives the optimiser, unlike the
# test_ne_ae / test_nec_ae calls in this notebook; confirm test_node_ae's signature
test_loss, test_res = test_node_ae(node_ae, node_opt, test_loaders['r'])

In [4]:
### NodeEdge AE
r_train_graphs = train_loaders['r'].dataset

# largest node count (length of z) among the training reactants
max_num_nodes = max(g.z.size(0) for g in r_train_graphs)

# all training graphs have to share one node-feature width
assert len({g.x.size(1) for g in r_train_graphs}) == 1

# input feature sizes, taken from the first training graph
first_graph = r_train_graphs[0]
num_node_fs = first_graph.x.size(1)
num_edge_fs = first_graph.edge_attr.size(1)
h_nf, emb_nf = 5, 2

# model and opt
ne_ae = NodeEdge_AE(in_node_nf = num_node_fs, in_edge_nf = num_edge_fs,
                    h_nf = h_nf, out_nf = h_nf, emb_nf = emb_nf)
ne_opt = torch.optim.Adam(ne_ae.parameters(), lr = 1e-3)

# one training call followed by one test call on the reactant loaders
train_loss, train_res = train_ne_ae(ne_ae, ne_opt, train_loaders['r'])
test_loss, test_res = test_ne_ae(ne_ae, test_loaders['r'])

In [3]:
### NodeEdgeCoord AE
r_train_graphs = train_loaders['r'].dataset

# largest node count (length of z) among the training reactants
max_num_nodes = max(g.z.size(0) for g in r_train_graphs)

# the node-feature width must be identical across the training graphs
assert len({g.x.size(1) for g in r_train_graphs}) == 1

# input feature sizes, taken from the first training graph
first_graph = r_train_graphs[0]
num_node_fs = first_graph.x.size(1)
num_edge_fs = first_graph.edge_attr.size(1)
h_nf, emb_nf = 5, 2

# model and opt
nec_ae = NodeEdgeCoord_AE(in_node_nf = num_node_fs, in_edge_nf = num_edge_fs,
                          h_nf = h_nf, out_nf = h_nf, emb_nf = emb_nf)
nec_opt = torch.optim.Adam(nec_ae.parameters(), lr = 1e-3)

# one training call followed by one test call on the reactant loaders
train_loss, train_res = train_nec_ae(nec_ae, nec_opt, train_loaders['r'])
test_loss, test_res = test_nec_ae(nec_ae, test_loaders['r'])

In [4]:
# display the current train and test loss values as the cell output
train_loss, test_loss

(3.447507619857788, 2.907812794049581)

In [158]:
# DONE! R->R: node ae, train 
# DONE! R->R: node + edge ae, train
# DONE! R->R: node + edge + coords, train
# Coords: training loop, more data (~1000), R->TS (ground truth as ['ts']) on NEC
#   - Compare N, NE, NEC on more data (~1000), more training epochs
#       - Need to figure out input num fs for train and test when on bigger sets
# Tues slides
#   - Work out which figures I need for tomorrow: D_init distr
#   - Presentation outline

# After meeting:
#   - MIT coords
#   - NE -> coords for comparison
#   - P->TS, (R,P)->TS
#   - All the data
#   - Gradients for diff NNs within func?
#   - Coordinates to interatomic
#   - RMSD from github

# Long after:
#   - DGL for improvements: k-hop graph func + khop adj util func

Batch(batch=[15], edge_attr=[30, 4], edge_index=[2, 30], idx=[1], pos=[15, 3], ptr=[2], x=[15, 11], z=[15])

In [19]:
# Train the NodeEdgeCoord AE for `epochs` epochs, evaluating on the test
# loader every `test_interval` epochs and tracking the best test loss.
epochs = 10
test_interval = 5

# Run record: 'epochs' and the train_* lists gain one entry per epoch;
# the test_* lists gain one entry per evaluation (every test_interval epochs).
final_res = {'epochs': [], 'train_loss_arr': [], 'train_res_arr': [], 
             'test_loss_arr': [], 'test_res_arr': [], 'best_test': 1e10, 'best_epoch': 0}
# NOTE(review): these standalone lists are never written by this loop; they are
# kept only in case other cells reference them. Confirm and remove if unused.
train_loss_arr = []
train_res_arr = []
test_loss_arr = []
test_res_arr = []

# r_ae.reset_parameters()

for epoch in range(1, epochs + 1):
    
    train_loss, train_res = train_nec_ae(nec_ae, nec_opt, train_loaders['r'])
    # record the epoch index so the train arrays can be plotted against it
    # (previously final_res['epochs'] was created but never filled)
    final_res['epochs'].append(epoch)
    final_res['train_loss_arr'].append(train_loss)
    final_res['train_res_arr'].append(train_res)
    print(f"===== Training epoch {epoch:03d} complete with loss: {train_loss:.4f} ====")
    
    if epoch % test_interval == 0:
    
        test_loss, test_res = test_nec_ae(nec_ae, test_loaders['r'])
        final_res['test_loss_arr'].append(test_loss)
        final_res['test_res_arr'].append(test_res)
        print(f'===== Testing epoch: {epoch:03d}, Loss: {test_loss:.4f} ===== \n')
        
        # keep the lowest test loss seen so far and the epoch it occurred at
        if test_loss < final_res['best_test']:
            final_res['best_test'] = test_loss
            final_res['best_epoch'] = epoch

===== Training epoch 001 complete with loss: 3.1545 ====
===== Training epoch 002 complete with loss: 3.1154 ====
===== Training epoch 003 complete with loss: 3.1296 ====
===== Training epoch 004 complete with loss: 3.1444 ====
===== Training epoch 005 complete with loss: 3.0177 ====
===== Testing epoch: 005, Loss: 2.9221 ===== 

===== Training epoch 006 complete with loss: 3.1225 ====
===== Training epoch 007 complete with loss: 3.1225 ====
===== Training epoch 008 complete with loss: 3.2075 ====
===== Training epoch 009 complete with loss: 3.1491 ====
===== Training epoch 010 complete with loss: 3.1824 ====
===== Testing epoch: 010, Loss: 2.9582 ===== 

