In [1]:
import torch, os
torch.manual_seed(0) 
import warnings;warnings.filterwarnings("ignore")
from HINT.dataloader import csv_three_feature_2_dataloader, generate_admet_dataloader_lst
from HINT.molecule_encode import MPNN, ADMET
from HINT.icdcode_encode import GRAM, build_icdcode2ancestor_dict
from HINT.protocol_encode import Protocol_Embedding
from HINT.model import HINTModel 
device = torch.device("cpu")  ## cuda:0
if not os.path.exists("figure"):
    os.makedirs("figure")
    
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time

import xgboost as xgb
from catboost import CatBoostClassifier
import lightgbm as lgb
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, precision_recall_curve, precision_score, recall_score, auc

In [2]:
mpnn_model = MPNN(mpnn_hidden_size = 50, mpnn_depth=3, device = device)
icdcode2ancestor_dict = build_icdcode2ancestor_dict()
gram_model = GRAM(embedding_dim = 50, icdcode2ancestor = icdcode2ancestor_dict, device = device)
protocol_model = Protocol_Embedding(output_dim = 50, highway_num=3, device = device)

molecule_encoder = mpnn_model
disease_encoder = gram_model
protocol_encoder = protocol_model

In [14]:
def get_embed(dataloader):
    nctid_lst, label_lst, smiles_lst2, icdcode_lst3, criteria_lst = [], [], [], [], []
    for nctid, label, smiles, icdcode, criteria in dataloader:
        nctid_lst.extend(nctid)
        label_lst.extend([i.item() for i in label])
        smiles_lst2.extend(smiles)
        icdcode_lst3.extend(icdcode)
        criteria_lst.extend(criteria)
        
    molecule_embed = molecule_encoder.forward_smiles_lst_lst(smiles_lst2)
    icd_embed = disease_encoder.forward_code_lst3(icdcode_lst3)
    protocol_embed = protocol_encoder.forward(criteria_lst)
    print(molecule_embed.shape, icd_embed.shape, protocol_embed.shape)
    return molecule_embed, icd_embed, protocol_embed

def preprocess(file, loader):
    df = pd.read_csv(file)
    df.drop(['phase', 'why_stop'], axis=1, inplace=True)
    df.drop(['icdcodes', 'smiless', 'criteria'], axis=1, inplace=True)
    df.drop(['diseases', 'drugs'], axis=1, inplace=True) ## FE later
    
    molecule_embed, icd_embed, protocol_embed = get_embed(loader)

    molecule_df = pd.DataFrame(molecule_embed.detach().numpy(), columns=[f'molecule_feature_{i}' for i in range(len(molecule_embed[0]))])
    icd_df = pd.DataFrame(icd_embed.detach().numpy(), columns=[f'icd_feature_{i}' for i in range(len(icd_embed[0]))])
    protocol_df = pd.DataFrame(protocol_embed.detach().numpy(), columns=[f'protocol_feature_{i}' for i in range(len(protocol_embed[0]))])

    df = pd.concat([df, molecule_df, icd_df, protocol_df], axis=1)
    return df

def print_metrics(y_true, y_pred, label):
    print(f"{label} ROC AUC: {round(roc_auc_score(y_true, y_pred),3)}")
    print(f"{label} F1: {round(f1_score(y_true, y_pred),3)}")
    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    print(f"{label} PR-AUC: {round(auc(recall, precision),3)}")
    print(f"{label} Precision: {round(precision_score(y_true, y_pred),3)}")
    print(f"{label} Recall: {round(recall_score(y_true, y_pred),3)}")
    print(f"{label} Accuracy: {round(accuracy_score(y_true, y_pred),3)}")
    print(f"{label} Predict 1 ratio: {round(sum(y_pred) / len(y_pred),3)}")
    print(f"{label} Label 1 ratio: {round(sum(y_true) / len(y_true),3)}")

In [28]:
class CustomEnsemble:
    def __init__(self, xgb_params, cat_params, lgbm_params):
        self.xgb_model = xgb.XGBClassifier(**xgb_params)
        self.cat_model = CatBoostClassifier(**cat_params)
        self.lgbm_model = lgb.LGBMClassifier(**lgbm_params, verbosity=-1)
        
    def fit(self, train_df, valid_df, test_df):
        # pre-process for xgb
        train_idx = range(len(train_df))
        valid_idx = range(len(train_df), len(train_df) + len(valid_df))
        test_idx = range(len(train_df) + len(valid_df), len(train_df) + len(valid_df) + len(test_df))
        combined_df = pd.DataFrame()
        combined_df = pd.concat([train_df, valid_df, test_df], axis=0)
        combined_df.reset_index(drop=True, inplace=True)
        combined_df = pd.get_dummies(combined_df, columns=['status'], drop_first=True)
        train_xgb = combined_df.loc[train_idx]
        valid_xgb = combined_df.loc[valid_idx]
        test_xgb = combined_df.loc[test_idx]
        X_train_xgb, y_train = train_xgb.drop(['nctid','label'], axis=1), train_xgb['label']  
        X_valid_xgb, y_valid = valid_xgb.drop(['nctid','label'], axis=1), valid_xgb['label']  
        X_test_xgb, y_test = test_xgb.drop(['nctid','label'], axis=1), test_xgb['label']  
        self.test_xgb = X_test_xgb
        
        self.xgb_model.fit(X_train_xgb, y_train)

        # pre-process for catboost, lgbm
        X_train_cl, y_train = train_df.drop(['nctid','label'], axis=1), train_df['label']  
        X_valid_cl, y_valid = valid_df.drop(['nctid','label'], axis=1), valid_df['label']  
        X_test_cl, y_test = test_df.drop(['nctid','label'], axis=1), test_df['label']
        X_train_cl['status'] = train_df['status'].astype('category')
        X_valid_cl['status'] = valid_df['status'].astype('category')
        X_test_cl['status'] = test_df['status'].astype('category')
        self.test_cl = X_test_cl
        self.y_true = y_test
        
        self.cat_model.fit(X_train_cl, y_train, cat_features=[0], verbose=0)
        self.lgbm_model.fit(X_train_cl, y_train)
        
    def predict(self):
        xgb_prob = self.xgb_model.predict_proba(self.test_xgb)[:, 1]
        cat_prob = self.cat_model.predict_proba(self.test_cl)[:, 1]
        lgbm_prob = self.lgbm_model.predict_proba(self.test_cl)[:, 1]
        final_prob = (xgb_prob + cat_prob + lgbm_prob) / 3
        threshold = 0.5
        final_pred = (final_prob >= threshold).astype(int)
        return final_pred
        
    def evaluate(self):
        y_true = self.y_true
        y_pred = self.predict()
        print_metrics(y_true, y_pred, 'Test')

In [26]:
def ensemble_evaluate(phase):
    params_dict = {
        'phase_I': {
            'xgb': {'alpha': 10, 'learning_rate': 0.15, 'max_depth': 3, 'n_estimators': 100},
            'cat': {'iterations': 200, 'depth': 3, 'l2_leaf_reg': 1},
            'lgbm': {'learning_rate': 0.05, 'max_depth': 3, 'n_estimators': 100, 'num_leaves': 31},
        },
        'phase_II': {
            'xgb': {'alpha': 5, 'learning_rate': 0.05, 'max_depth': 3, 'n_estimators': 100},
            'cat': {'iterations': 100, 'depth': 6, 'l2_leaf_reg': 5},
            'lgbm': {'learning_rate': 0.05, 'max_depth': 5, 'n_estimators': 100, 'num_leaves': 31},
        },
        'phase_III': {
            'xgb': {'alpha': 5, 'learning_rate': 0.05, 'max_depth': 3, 'n_estimators': 100},
            'cat': {'iterations': 200, 'depth': 5, 'l2_leaf_reg': 1},
            'lgbm': {'learning_rate': 0.05, 'max_depth': 6, 'n_estimators': 100, 'num_leaves': 31},
        },
    }
    base_name = phase
    datafolder = "data" 
    train_file = os.path.join(datafolder, base_name + '_train.csv')
    valid_file = os.path.join(datafolder, base_name + '_valid.csv')
    test_file = os.path.join(datafolder, base_name + '_test.csv')
    
    train_loader = csv_three_feature_2_dataloader(train_file, shuffle=True, batch_size=32) 
    valid_loader = csv_three_feature_2_dataloader(valid_file, shuffle=False, batch_size=32) 
    test_loader = csv_three_feature_2_dataloader(test_file, shuffle=False, batch_size=32) 

    train_df = preprocess(train_file, train_loader)
    valid_df = preprocess(valid_file, valid_loader)
    test_df = preprocess(test_file, test_loader)
    print(train_df.shape, valid_df.shape, test_df.shape)

    phase_params = params_dict[phase]
    ensemble_model = CustomEnsemble(phase_params['xgb'], phase_params['cat'], phase_params['lgbm'])
    ensemble_model.fit(train_df, valid_df, test_df)
    ensemble_model.evaluate()    

In [29]:
ensemble_evaluate('phase_I')

torch.Size([1044, 50]) torch.Size([1044, 50]) torch.Size([1044, 50])
torch.Size([117, 50]) torch.Size([117, 50]) torch.Size([117, 50])
torch.Size([627, 50]) torch.Size([627, 50]) torch.Size([627, 50])
(1044, 153) (117, 153) (627, 153)
Test ROC AUC: 0.864
Test F1: 0.894
Test PR-AUC: 0.91
Test Precision: 0.838
Test Recall: 0.957
Test Accuracy: 0.874
Test Predict 1 ratio: 0.632
Test Label 1 ratio: 0.553


In [30]:
ensemble_evaluate('phase_II')

torch.Size([4005, 50]) torch.Size([4005, 50]) torch.Size([4005, 50])
torch.Size([446, 50]) torch.Size([446, 50]) torch.Size([446, 50])
torch.Size([1654, 50]) torch.Size([1654, 50]) torch.Size([1654, 50])
(4005, 153) (446, 153) (1654, 153)
Test ROC AUC: 0.783
Test F1: 0.845
Test PR-AUC: 0.868
Test Precision: 0.751
Test Recall: 0.965
Test Accuracy: 0.803
Test Predict 1 ratio: 0.713
Test Label 1 ratio: 0.555


In [31]:
ensemble_evaluate('phase_III')

torch.Size([3094, 50]) torch.Size([3094, 50]) torch.Size([3094, 50])
torch.Size([344, 50]) torch.Size([344, 50]) torch.Size([344, 50])
torch.Size([1146, 50]) torch.Size([1146, 50]) torch.Size([1146, 50])
(3094, 153) (344, 153) (1146, 153)
Test ROC AUC: 0.714
Test F1: 0.875
Test PR-AUC: 0.913
Test Precision: 0.851
Test Recall: 0.901
Test Accuracy: 0.807
Test Predict 1 ratio: 0.794
Test Label 1 ratio: 0.75
