In [2]:
import sys
sys.path.append('..')

import torch
import pickle
import argparse
import yaml
import numpy as np


from torch.utils.data.dataloader import DataLoader
from reaction_predictors.graph_model.models_yield import RGCNNTrClassifier
from utils.graph_utils import get_bonds, get_nodes
from utils.torch_dataset import Dataset, graph_collate
from reaction_predictors.graph_model.model_utils import train_epoch, evaluate, test, evaluate1
from collections import namedtuple


def prune_dataset_by_length(dataset, max_len):
    """Keep only reactions whose product fits and whose reactant mask is unambiguous.

    An entry survives when ``len(entry['target_main_product']) <= max_len`` and
    the strictly positive values of ``entry['reactants']['mask']`` are all
    distinct.

    Args:
        dataset: Mapping of reaction id -> entry dict.
        max_len: Maximum allowed length of ``target_main_product``.

    Returns:
        A new dict holding the surviving entries (same objects, original order).
    """
    def _is_kept(entry):
        # Boolean indexing below assumes the mask is array-like (numpy) — TODO confirm.
        positive = entry['reactants']['mask']
        positive = positive[positive > 0]
        fits = len(entry['target_main_product']) <= max_len
        return fits and len(np.unique(positive)) == len(positive)

    return {idx: entry for idx, entry in dataset.items() if _is_kept(entry)}

def delete_idx(dataset):
    """Filter out entries whose ``'mined yield'`` is above 1.

    Despite the name, nothing is deleted in place: a new dict is returned
    containing only entries with ``entry['mined yield'] <= 1``. Entries whose
    yield is NaN also fail that comparison and are dropped.
    Presumably yields are meant to be fractions in [0, 1] — confirm upstream.

    Args:
        dataset: Mapping of reaction id -> entry dict with a ``'mined yield'`` key.

    Returns:
        A new dict with the kept entries (same objects, original order).
    """
    return {
        idx: entry
        for idx, entry in dataset.items()
        if entry['mined yield'] <= 1
    }


def convert(dictionary):
    """Recursively convert a (possibly nested) dict into namedtuples.

    Nested dict values become nested namedtuples, so config sections can be read
    with attribute syntax (e.g. ``cfg.n_hidden``). Unlike the previous version,
    the input dict is left untouched. It used to overwrite nested dict values
    with namedtuples in place, which silently altered the caller's config.

    Args:
        dictionary: Mapping whose keys are valid Python identifiers.

    Returns:
        A ``GenericDict`` namedtuple with the same keys, in the same order.

    Raises:
        ValueError: If a key is not a valid namedtuple field name.
    """
    fields = {
        key: convert(value) if isinstance(value, dict) else value
        for key, value in dictionary.items()
    }
    return namedtuple('GenericDict', fields.keys())(**fields)


def main(config, device, n_epochs=30):
    """Train an RGCNNTrClassifier yield model and record per-epoch metrics.

    Args:
        config: Parsed YAML config with ``model``, ``dataset``, ``train`` and
            ``paths`` sections (each a dict).
        device: Torch device string (e.g. ``'cuda:0'``) for model and datasets.
        n_epochs: Number of training epochs (default 30, as before).

    Returns:
        Tuple ``(losses, results, valid_scores, test_scores, valid_scores_1,
        test_scores_1)``:
        - ``losses``: one ``train_epoch`` return value per epoch.
        - ``valid_scores`` / ``test_scores``: ``evaluate`` results on the
          validation and test loaders (logged as "R2").
        - ``valid_scores_1`` / ``test_scores_1``: ``evaluate1`` results on the
          same loaders (logged as "L1").
        - ``results``: output of ``test`` on the training loader.
    """
    model_cfg = convert(config["model"])
    data_cfg = convert(config["dataset"])
    train_cfg = convert(config["train"])
    paths = convert(config["paths"])

    def _load(name):
        # Context manager closes the handle (previously leaked). pickle can run
        # arbitrary code on load: only use trusted, locally produced files.
        with open(paths.dataset_path + name, 'rb') as fh:
            return pickle.load(fh)

    meta = _load('meta.pkl')

    node2label = get_nodes(meta['node'], n_molecule_level=data_cfg.n_molecule_level,
                           n_reaction_level=data_cfg.n_reaction_level)
    bond2label = get_bonds(meta['type'], n_molecule_level=data_cfg.n_molecule_level,
                           n_reaction_level=data_cfg.n_reaction_level,
                           self_bond=data_cfg.self_bond)
    if data_cfg.same_bond:
        # Collapse every bond key listed in meta['type'] onto label 0.
        bond2label = {i: 0 if i in meta['type'] else bond2label[i] for i in bond2label}
    num_rels = len(bond2label)
    # NOTE(review): the constant 15 (extra slots per molecule level) is not
    # explained here; confirm it matches the padding used by Dataset.
    pad_length = data_cfg.max_num_atoms + 15 * data_cfg.n_molecule_level + \
                 data_cfg.n_molecule_level * data_cfg.n_reaction_level
    num_nodes = pad_length

    model = RGCNNTrClassifier([len(node2label)] + data_cfg.feature_sizes,
                              num_nodes,
                              train_cfg.batch_size,
                              [model_cfg.n_hidden] + [model_cfg.feature_embed_size] * len(data_cfg.feature_sizes),
                              num_rels,
                              model_cfg.num_conv_layers,
                              model_cfg.num_trans_layers,
                              model_cfg.num_fcn_layers,
                              model_cfg.num_attention_heads,
                              model_cfg.num_model_heads,
                              )
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg.lr, betas=(0.9, 0.98), eps=1e-9)
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=train_cfg.exp_step_size, gamma=0.1)

    def _make_loader(split_name):
        # Load one split, apply the length/mask and yield filters, then wrap it
        # in a Dataset + DataLoader (drop_last keeps batch size constant).
        split = _load(split_name + '.pkl')
        split = prune_dataset_by_length(split, data_cfg.max_num_atoms)
        split = delete_idx(split)
        dataset = Dataset(split, device=device, pad_length=pad_length,
                          bond2label=bond2label, node2label=node2label, feature_idxs=data_cfg.feature_idxs,
                          target_main_product=data_cfg.target_main_product, target_center=data_cfg.target_center,
                          n_molecule_level=data_cfg.n_molecule_level, n_reaction_level=data_cfg.n_reaction_level)
        return DataLoader(dataset, train_cfg.batch_size, drop_last=True, collate_fn=graph_collate)

    train_loader = _make_loader('train')
    test_loader = _make_loader('test')
    valid_loader = _make_loader('valid')

    valid_scores = []
    test_scores = []
    valid_scores_1 = []
    test_scores_1 = []
    losses = []
    print('Training is starting')
    for epoch in range(n_epochs):
        losses.append(train_epoch(model, train_loader, optimizer, exp_lr_scheduler))
        loss = np.array(losses[-1])
        score_v = evaluate(model, valid_loader)
        score_t = evaluate(model, test_loader)
        score_v_1 = evaluate1(model, valid_loader)
        score_t_1 = evaluate1(model, test_loader)
        valid_scores.append(score_v)
        test_scores.append(score_t)
        # Bug fix: these previously appended score_v / score_t, so the
        # evaluate1 histories were copies of the evaluate histories.
        valid_scores_1.append(score_v_1)
        test_scores_1.append(score_t_1)
        print(f'Epoch number - {epoch},loss = {loss.mean()}, R2_v = {score_v}, R2_t = {score_t}, L1_v = {score_v_1}, L1_t = {score_t_1}')
    # NOTE(review): the final `test` call runs on the *training* loader;
    # confirm this is intended rather than test_loader.
    results = test(model, train_loader)
    return losses, results, valid_scores, test_scores, valid_scores_1, test_scores_1

In [4]:
# Train on the MT_EGBF_sm_vis configuration and keep every per-epoch history.
with open('../scripts/graph_models/MT_EGBF_sm_vis.yml', 'r') as ymlfile:
    config = yaml.load(ymlfile, Loader=yaml.FullLoader)

(losses_EGBF, results_EGBF,
 valid_scores_EGBF, test_scores_EGBF,
 valid_scores_1_EGBF, test_scores_1_EGBF) = main(config, device='cuda:0')

Training is starting
Epoch number - 0,loss = 0.06547720770218543, R2_v = 0.05137474757114513, R2_t = 0.06771218704750148, L1_v = 0.20119811594486237, L1_t = 0.2018478810787201
Epoch number - 1,loss = 0.05742754489834503, R2_v = 0.08087361579288177, R2_t = 0.1011871283976884, L1_v = 0.1983325183391571, L1_t = 0.19860967993736267
Epoch number - 2,loss = 0.05643356917001375, R2_v = 0.0702149787043369, R2_t = 0.0893805819998944, L1_v = 0.20050658285617828, L1_t = 0.20118995010852814
Epoch number - 3,loss = 0.055816380367939865, R2_v = 0.08529059743532252, R2_t = 0.10624824546332312, L1_v = 0.1975090056657791, L1_t = 0.19793958961963654
Epoch number - 4,loss = 0.05537014166289504, R2_v = 0.07068170839172472, R2_t = 0.08980011149384393, L1_v = 0.20096848905086517, L1_t = 0.20160731673240662
Epoch number - 5,loss = 0.055113989864065564, R2_v = 0.09818690835468802, R2_t = 0.11592582243740968, L1_v = 0.196182519197464, L1_t = 0.19678440690040588
Epoch number - 6,loss = 0.05481929879629911, R2_v

In [3]:
# Train on the MP_EGTB configuration and keep every per-epoch history.
with open('../scripts/graph_models/MP_EGTB.yml', 'r') as ymlfile:
    config = yaml.load(ymlfile, Loader=yaml.FullLoader)

(losses_EGB, results_EGB,
 valid_scores_EGB, test_scores_EGB,
 valid_scores_1_EGB, test_scores_1_EGB) = main(config, device='cuda:0')

Training is starting
Epoch number - 0,loss = 0.0646040803699055, R2_v = -0.005370480736561767, R2_t = 0.0008591722856584605, L1_v = 0.21344107389450073, L1_t = 0.2159481942653656
Epoch number - 1,loss = 0.058099549851100045, R2_v = 0.051341839976877135, R2_t = 0.060947271408841264, L1_v = 0.20463180541992188, L1_t = 0.20630759000778198
Epoch number - 2,loss = 0.05709024431540611, R2_v = 0.08038353453190217, R2_t = 0.09223722854631544, L1_v = 0.1985352486371994, L1_t = 0.20002521574497223
Epoch number - 3,loss = 0.056710945104042176, R2_v = -0.019758833183930014, R2_t = -0.008596810790498788, L1_v = 0.21417921781539917, L1_t = 0.21631261706352234
Epoch number - 4,loss = 0.05628831335838805, R2_v = 0.022733606740373502, R2_t = 0.0353757528999431, L1_v = 0.20815624296665192, L1_t = 0.21006542444229126
Epoch number - 5,loss = 0.05601808732788795, R2_v = 0.020380564285174096, R2_t = 0.03320473564958237, L1_v = 0.2089471071958542, L1_t = 0.2106465846300125
Epoch number - 6,loss = 0.055777958

In [30]:
import sys
sys.path.append('..')

import torch
import pickle
import argparse
import yaml
import numpy as np


from torch.utils.data.dataloader import DataLoader
from reaction_predictors.graph_model.models_yield import RGCNNTrClassifier
from utils.graph_utils import get_bonds, get_nodes
from utils.torch_dataset import Dataset, graph_collate
from reaction_predictors.graph_model.model_utils import train_epoch, evaluate, test, evaluate1
from collections import namedtuple


def prune_dataset_by_length(dataset, max_len):
    """Return the entries that fit within ``max_len`` and have distinct positive mask values.

    NOTE(review): this cell re-defines a helper already defined earlier in the
    notebook with the same behavior; consider deleting one copy.

    Args:
        dataset: Mapping of reaction id -> entry dict.
        max_len: Maximum allowed length of ``entry['target_main_product']``.

    Returns:
        A new dict with the surviving entries (same objects, original order).
    """
    kept = {}
    for key, entry in dataset.items():
        mask = entry['reactants']['mask']
        positive_mask = mask[mask > 0]
        if not len(entry['target_main_product']) <= max_len:
            continue
        if len(np.unique(positive_mask)) != len(positive_mask):
            # Duplicate positive mask values -> ambiguous reactant mapping.
            continue
        kept[key] = entry
    return kept

def delete_idx(dataset):
    """Return a new dict with only entries whose ``'mined yield'`` is <= 1.

    NaN yields fail the comparison and are dropped. The input is not modified.
    NOTE(review): duplicate of the helper defined earlier in the notebook.
    """
    keep_keys = [key for key in dataset if dataset[key]['mined yield'] <= 1]
    return {key: dataset[key] for key in keep_keys}


def convert(dictionary):
    """Recursively convert a (possibly nested) dict into namedtuples, without mutating it.

    Nested dict values become nested namedtuples, enabling attribute access on
    config sections. The previous version overwrote nested dict values with
    namedtuples in place, altering the caller's config as a side effect.

    Args:
        dictionary: Mapping whose keys are valid Python identifiers.

    Returns:
        A ``GenericDict`` namedtuple with the same keys, in the same order.

    Raises:
        ValueError: If a key is not a valid namedtuple field name.
    """
    fields = {
        key: convert(value) if isinstance(value, dict) else value
        for key, value in dictionary.items()
    }
    return namedtuple('GenericDict', fields.keys())(**fields)


def main(config, device):
    """Build the train/test/valid ``Dataset`` objects without training a model.

    NOTE(review): this re-defines ``main`` with a different return type than the
    training ``main`` defined earlier in the notebook. Cells executed after this
    one get this dataset-only version; consider renaming (e.g. ``build_datasets``).

    The previous version also constructed an RGCNNTrClassifier, an Adam
    optimizer and a StepLR scheduler that were never used, allocating the model
    on ``device`` for nothing. That has been removed, so only the ``dataset``
    and ``paths`` config sections are read.

    Args:
        config: Parsed YAML config containing ``dataset`` and ``paths`` sections.
        device: Device passed through to each ``Dataset``.

    Returns:
        Tuple ``(train_dataset, test_dataset, valid_dataset)`` of ``Dataset``
        objects built from the pruned, yield-filtered pickled splits.
    """
    data_cfg = convert(config["dataset"])
    paths = convert(config["paths"])

    def _load(name):
        # Context manager closes the handle (previously leaked). pickle can run
        # arbitrary code on load: only use trusted, locally produced files.
        with open(paths.dataset_path + name, 'rb') as fh:
            return pickle.load(fh)

    meta = _load('meta.pkl')

    node2label = get_nodes(meta['node'], n_molecule_level=data_cfg.n_molecule_level,
                           n_reaction_level=data_cfg.n_reaction_level)
    bond2label = get_bonds(meta['type'], n_molecule_level=data_cfg.n_molecule_level,
                           n_reaction_level=data_cfg.n_reaction_level,
                           self_bond=data_cfg.self_bond)
    if data_cfg.same_bond:
        # Collapse every bond key listed in meta['type'] onto label 0.
        bond2label = {i: 0 if i in meta['type'] else bond2label[i] for i in bond2label}
    # NOTE(review): the constant 15 (extra slots per molecule level) is not
    # explained here; confirm it matches the padding used by Dataset.
    pad_length = data_cfg.max_num_atoms + 15 * data_cfg.n_molecule_level + \
                 data_cfg.n_molecule_level * data_cfg.n_reaction_level

    def _build(split_name):
        # Load one split, apply the length/mask and yield filters, wrap in a Dataset.
        split = _load(split_name + '.pkl')
        split = prune_dataset_by_length(split, data_cfg.max_num_atoms)
        split = delete_idx(split)
        return Dataset(split, device=device, pad_length=pad_length,
                       bond2label=bond2label, node2label=node2label, feature_idxs=data_cfg.feature_idxs,
                       target_main_product=data_cfg.target_main_product, target_center=data_cfg.target_center,
                       n_molecule_level=data_cfg.n_molecule_level, n_reaction_level=data_cfg.n_reaction_level)

    # Tuple elements are evaluated left to right: train, then test, then valid.
    return _build('train'), _build('test'), _build('valid')

In [31]:
# Build (but do not train on) the MP_EGTB train/test/valid datasets for inspection.
with open('../scripts/graph_models/MP_EGTB.yml', 'r') as ymlfile:
    config = yaml.load(ymlfile, Loader=yaml.FullLoader)

tar_tr, tar_ts, tar_vl = main(config, device='cuda:0')

In [None]:
# Peek at item 1 of the training Dataset, then element 1 of that item.
# What that element holds depends on Dataset.__getitem__ in
# utils.torch_dataset (not visible here) — TODO confirm.
tar_tr[1][1]

In [None]:
tr = []
for i in range(len(tar_tr)):
    tr.append)