In [None]:
from hyperopt import fmin, hp, tpe, Trials
import numpy as np

import os
import sys

BASEDIR = os.path.dirname(os.getcwd())
sys.path.append(BASEDIR)

from copy import deepcopy
from functools import partial
from dgl.dataloading import GraphDataLoader
from src.model.attentivefp_gp import attentivefpGP
from src.utils.mol.attfp_graph import MoleculeDataset, collate_fn
from src.config.attentivefp_gp import attentivefpGPArgs
from src.pipeline.ensemble import training_ensemble_models
from src.utils.basic.logger import Writer, NpEncoder
from src.utils.model.metrics import accuracy, roc_auc_score, prc_auc, F1, MCC, expected_calibration_error, OverconfidentFalseRate, OverconfidentFalseNegatives, Brier
import torch
import pandas as pd
import json

# Initialize: runtime settings and hyperparameter search space

In [None]:
torch.set_num_threads(4)  # cap the number of CPU threads torch uses for intra-op work

# Seed the global numpy / torch RNGs so training runs are repeatable.
# NOTE(review): hyperopt's fmin draws proposals from its own RNG (its `rstate`
# argument), so the sampled hyperparameters are presumably not covered by this seed.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

target_list = ['ames', 'BBB_Martins', 'Pgp_Broccatelli', 'CYP3A4_Veith', 'CYP2C9_Veith']
gpu_num     = 0

# Bayesian optimization searching space.
# The first argument of each hp.* call is the label that fmin's returned `best`
# dict is keyed by, so every label is kept identical to its dictionary key.
# hp.quniform yields floats (e.g. 400.0); the keys in INT_KEYS must therefore be
# cast to int before being handed to the model.
SPACE = {
    'hidden_size'       : hp.quniform('hidden_size', low=300, high=600, q=100),
    'radius'            : hp.quniform('radius', low=2, high=6, q=1),
    'T'                 : hp.quniform('T', low=1, high=5, q=1),
    'p_dropout'         : hp.quniform('p_dropout', low=0.0, high=0.5, q=0.05),
    'init_lr'           : hp.loguniform('init_lr', low=np.log(1e-4), high=np.log(1e-2)),
    'ffn_num_layers'    : hp.quniform('ffn_num_layers', low=2, high=4, q=1),
    'n_inducing_points' : hp.quniform('n_inducing_points', low=100, high=300, q=50)
}

INT_KEYS = ['hidden_size', 'radius', 'T', 'ffn_num_layers', 'n_inducing_points']

# Hyperparameter search: train and evaluate each proposal on five train/validation splits per target

In [None]:
# Hyperparameter search for every target in `target_list`.
#
# For each target, hyperopt's TPE sampler proposes AttentiveFP-GP settings from
# SPACE. Each proposal ("round") trains one model on each of N_SPLITS
# train/validation splits (all sharing the same test CSV). It logs per-split
# mean/std metrics and metrics of the split-averaged test prediction, then
# returns the negated mean validation ROC-AUC as the loss for fmin to minimise.

N_SPLITS = 5             # number of {target}_train_{i}.csv / {target}_valid_{i}.csv pairs
LOADER_BATCH_SIZE = 256  # batch size of the GraphDataLoaders built below
# NOTE(review): hyper_args.batch_size is set to 128 in _build_args, but the
# loaders batch by LOADER_BATCH_SIZE; confirm which value is intended.
MAX_EVALS = 1            # hyperopt evaluations per target
N_STARTUP_JOBS = 1       # random proposals before TPE starts modelling the loss
# NOTE(review): with MAX_EVALS = 1 only a single point is evaluated, i.e. no
# real search takes place; raise it for a full run.

# Metric label -> column of *_prediction_performance.csv (read from disk, not recomputed).
PERF_COLUMNS = {"ROC-AUC": "roc-auc", "PRC-AUC": "prc-auc", "ACC": "accuracy", "MCC": "MCC", "F1": "F1"}
# Log order of all per-split metrics; the last four are computed from raw predictions.
METRIC_LABELS = list(PERF_COLUMNS) + ["ECE", "OFR", "OFN", "Brier"]


def _calibration_metrics(labels, preds):
    """Return the metrics computed here from raw predictions, keyed by log label."""
    # Argument order differs between helpers: ECE/Brier take (labels, preds),
    # OFR/OFN take (preds, labels).
    return {
        "ECE": expected_calibration_error(labels, preds, bins=10),
        "OFR": OverconfidentFalseRate(preds, labels),
        "OFN": OverconfidentFalseNegatives(preds, labels),
        "Brier": Brier(labels, preds),
    }


def _ensemble_metrics(labels, preds):
    """Return all METRIC_LABELS metrics for one prediction vector (the split ensemble)."""
    return {
        "ROC-AUC": roc_auc_score(labels, preds),
        "PRC-AUC": prc_auc(labels, preds),
        "ACC": accuracy(labels, preds),
        "MCC": MCC(labels, preds),
        "F1": F1(labels, preds),
        **_calibration_metrics(labels, preds),
    }


def _build_args(hyperparams):
    """Return a fresh attentivefpGPArgs with `hyperparams` and the fixed run settings applied."""
    hyper_args = deepcopy(attentivefpGPArgs().parse_args([], known_only=True))
    for key, value in hyperparams.items():
        setattr(hyper_args, key, value)
    fixed_settings = {
        "dataset_type": "classification",
        "metric": "roc-auc",
        "extra_metrics": ["MCC", "prc-auc", "accuracy", "F1"],
        "ffn_hidden_size": hyper_args.hidden_size,  # FFN width tied to the sampled hidden_size
        "early_stopping_num": 30,
        "gpu": gpu_num,
        "log_frequency": 100,
        "batch_size": 128,
        "at_least_epoch": 0,
    }
    for key, value in fixed_settings.items():
        setattr(hyper_args, key, value)
    return hyper_args


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a GraphDataLoader and attach its SMILES list as `.smiles`."""
    loader = GraphDataLoader(dataset=dataset, collate_fn=collate_fn,
                             batch_size=LOADER_BATCH_SIZE, drop_last=False, shuffle=shuffle)
    # One single-element list per molecule; presumably read by
    # training_ensemble_models when writing prediction files — TODO confirm.
    loader.smiles = [[s] for s in dataset.smiles_list]
    return loader


def _train_split(data_dir, target, round_dir, hyper_args, split):
    """Train one model on split `split`; outputs are written under round_dir/fold_{split}."""
    train_dataset = MoleculeDataset(os.path.join(data_dir, f"{target}_train_{split}.csv"))
    valid_dataset = MoleculeDataset(os.path.join(data_dir, f"{target}_valid_{split}.csv"))
    test_dataset  = MoleculeDataset(os.path.join(data_dir, f"{target}_test.csv"))

    # [len - sum, sum] of the training targets, i.e. [negatives, positives]
    # assuming 0/1 labels; passed to the model as hyper_args.N.
    train_targets = [t for _, t in train_dataset]
    n_positive = np.sum(train_targets).astype(np.int64)
    hyper_args.N = torch.tensor([len(train_targets) - n_positive, n_positive], dtype=torch.float64)

    training_ensemble_models(os.path.join(round_dir, f"fold_{split}"),
                             attentivefpGP,
                             hyper_args,
                             _make_loader(train_dataset, shuffle=True),
                             valid_dataloader=_make_loader(valid_dataset, shuffle=False),
                             test_dataloader=_make_loader(test_dataset, shuffle=False),
                             ensemble_num=1)


def _read_split_results(round_dir, split):
    """Load one trained split's outputs.

    Returns {"Valid": (metrics, preds, labels), "Test": (metrics, preds, labels)},
    where `metrics` maps every METRIC_LABELS entry to a scalar. Each CSV is read once.
    """
    model_dir = os.path.join(round_dir, f"fold_{split}", "model_0")
    results = {}
    for subset in ("Valid", "Test"):
        prefix = subset.lower()
        pred_df = pd.read_csv(os.path.join(model_dir, f"{prefix}_prediction.csv"))
        perf_df = pd.read_csv(os.path.join(model_dir, f"{prefix}_prediction_performance.csv"))
        preds = pred_df["property_pred"].to_numpy()
        labels = pred_df["property_label"].to_numpy()
        metrics = {label: perf_df[column].iloc[0] for label, column in PERF_COLUMNS.items()}
        metrics.update(_calibration_metrics(labels, preds))
        results[subset] = (metrics, preds, labels)
    return results


def _log_round(logger, round_no, split_metrics, test_preds, test_label):
    """Log mean/std of every per-split metric, then metrics of the split-averaged test prediction."""
    for subset, per_metric in split_metrics.items():
        for label in METRIC_LABELS:
            values = per_metric[label]
            # Label padded to 7 columns to keep the historical log layout.
            logger(f'ROUND {round_no} {subset} {label:<7} {np.mean(values)} +/- {np.std(values)}')
        logger(' ')
    ensemble_pred = np.mean(test_preds, axis=0)
    for label, value in _ensemble_metrics(test_label, ensemble_pred).items():
        logger(f'ROUND {round_no} Ensemble Test {label} {value}')
    logger(' ')


for target_name in target_list:

    SAVEDIR = os.path.join(BASEDIR, "results", "TDC", target_name, 'AttFpGP')
    DATADIR = os.path.join(BASEDIR, "data", target_name)
    n = 0  # round counter for this target; advanced by func via `global n`
    logger = Writer(os.path.join(SAVEDIR, "history.log"))

    def func(hyperparams):
        """Hyperopt objective: train one round on all splits and return -mean valid ROC-AUC.

        SAVEDIR, DATADIR, target_name and logger are module globals looked up at
        call time; this is safe because fmin runs within the same loop iteration.
        """
        global n
        logger(" ")
        logger(" ")
        n = n + 1
        logger(f"ROUND {n}")

        # Cast on a copy (hp.quniform yields floats) so config.txt records the
        # values actually applied and hyperopt's dict is left untouched.
        params = dict(hyperparams)
        for key in INT_KEYS:
            params[key] = int(params[key])

        round_dir = os.path.join(SAVEDIR, f"ROUND_{n}")
        config_writer = Writer(os.path.join(round_dir, "config.txt"))
        for key, value in params.items():
            config_writer(f"{key}\t{value}")

        hyper_args = _build_args(params)
        print(hyper_args)

        # Train every split first; results are read back from disk afterwards.
        for split in range(1, N_SPLITS + 1):
            _train_split(DATADIR, target_name, round_dir, hyper_args, split)

        split_metrics = {subset: {label: [] for label in METRIC_LABELS} for subset in ("Valid", "Test")}
        test_preds = []  # one test prediction vector per split, averaged into the ensemble
        test_label = None
        for split in range(1, N_SPLITS + 1):
            results = _read_split_results(round_dir, split)
            for subset, (metrics, _, _) in results.items():
                for label in METRIC_LABELS:
                    split_metrics[subset][label].append(metrics[label])
            _, split_test_pred, test_label = results["Test"]  # every split shares the same test set
            test_preds.append(split_test_pred)

        _log_round(logger, n, split_metrics, test_preds, test_label)
        return -np.mean(split_metrics["Valid"]["ROC-AUC"])

    algo = partial(tpe.suggest, n_startup_jobs=N_STARTUP_JOBS)
    best = fmin(func, SPACE, algo=algo, max_evals=MAX_EVALS)  # dict keyed by hp labels
    # NpEncoder presumably lets json serialise numpy scalars — TODO confirm.
    json_str = json.dumps(best, cls=NpEncoder)
    logger(json_str)