In [14]:
from sklearn.model_selection import KFold,LeaveOneOut,LeavePOut,ShuffleSplit 
from feat_selection import *
from model_new import *
from data_config import *
from data_process import *
from catboost import CatBoostRegressor, Pool
import shap
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report, mean_squared_error
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from keras.callbacks import EarlyStopping, ModelCheckpoint
import numpy as np
import h5py

In [15]:
# Emulate command-line arguments inside the notebook as [method, split index].
#   method      -> one of 'CatBOOST_2022', 'deepsynergy_preuer_2018', 'XGBOOST_janizek_2018'
#   split index -> one of 1, 2, 3, 4, 5 (selects train_idx_<i>.npy / test_idx_<i>.npy)
sys.argv = ['CatBOOST_2022', 1]

In [16]:
# Cell-line feature types offered by each cell dataset.
available_feat_type_list = {
    'NCI_60': ['met', 'mut', 'cop', 'exp'],
}

# Cancer-specific cell-line subsets: dataset -> cancer type -> cell names.
available_cancer_specific_cell_list = {
    'NCI_60': {
        'TNBC': ['MDA-MB-231', 'MDA-MB-435', 'BT-549', 'HS 578T'],
    },
}

# A combination is labelled positive when its synergy score exceeds this value.
SYNERGY_THRES = 10

# Training settings used by the neural-network branch of `training`.
BATCH_SIZE, N_EPOCHS, PATIENCE = 256, 200, 30
  
    
    
def prepare_data():
    """Build cell/drug feature matrices and binary synergy labels.

    Every setting is read from the module-level ``config`` dict (assigned in
    the ``__main__`` block); the loaders and filters (``input_synergy_data``,
    ``filter_drug``, ``filter_cell``, ``filter_cell_features``, ...) come from
    the star imports at the top of the file.

    Returns:
        X_cell: one 2-D array with a row per synergy record when
            ``config['cell_integrate']`` is True; otherwise a dict mapping each
            entry of ``config['cell_feats']`` to its own 2-D array.
        X_drug: one 2-D array (all drug feature blocks concatenated
            column-wise) when ``config['drug_integrate']`` is True; otherwise a
            dict of 2-D arrays keyed by feature name.
        Y: int array, 1 where the record's ``score`` exceeds ``SYNERGY_THRES``
            and 0 otherwise.
    """
    synergy_data = input_synergy_data(config['synergy_data'])  # load synergy records
    cell_data_dicts = input_cellline_data(config['cell_data'])
    drug_data_dict = input_drug_data()

    # NOTE(review): drug names are collected from the 'drug1' column only. A drug
    # that never appears as drug1 is missing from drug_list, so every row using it
    # as drug2 is dropped by the isin() filter below -- confirm this is intended.
    drug_list = synergy_data['drug1'].unique().tolist()  # list of drug names (drug1 column)
    cell_list = synergy_data['cell'].unique().tolist()  # list of all cell names

    drug_list = filter_drug(drug_list, config['drug_feat_filter'])  # filter drugs by config['drug_feat_filter']
    cell_list = filter_cell(cell_list, config['cell_list'],
                            available_cancer_specific_cell_list[config['cell_data']])  # filter cells
    # Keep only rows whose drug1, drug2 and cell all survived the filters above.
    synergy_data = synergy_data[
        (synergy_data['drug1'].isin(drug_list)) & synergy_data['drug2'].isin(drug_list) & synergy_data['cell'].isin(
            cell_list)]

    # Filter cell features; returns the features plus the cells that still have them.
    cell_feats, selected_cells = filter_cell_features(cell_data_dicts, cell_list, config['cell_feats'],
                                                      config['cell_feat_filter'], config['cell_integrate'])
    # Keep only rows whose cell has features after feature filtering.
    synergy_data = synergy_data[synergy_data['cell'].isin(selected_cells)]

    print("\n")
    print("number of drugs:", len(drug_list))
    print("number of cells:", len(selected_cells))
    print("number of data:", synergy_data.shape)
    print("\n")

    # Build the cell feature matrix. cell_feats is indexed by cell name and its
    # shape[0] is used as the feature dimension -- presumably a DataFrame with one
    # column per cell (confirm against filter_cell_features).
    if config['cell_integrate'] == True:
        X_cell = np.zeros((synergy_data.shape[0], cell_feats.shape[0]))  # rows = records, cols = cell features
        for i in tqdm(range(synergy_data.shape[0])):
            row = synergy_data.iloc[i]
            X_cell[i, :] = cell_feats[row['cell']].values  # copy this record's cell features into row i
    else:
        X_cell = {}
        for feat_type in config['cell_feats']:  # one matrix per cell feature type
            print(feat_type, cell_feats[feat_type].shape[0])  # here cell_feats is a dict keyed by feature type
            temp_cell = np.zeros((synergy_data.shape[0], cell_feats[feat_type].shape[0]))
            for i in tqdm(range(synergy_data.shape[0])):
                row = synergy_data.iloc[i]
                temp_cell[i, :] = cell_feats[feat_type][row['cell']].values
            X_cell[feat_type] = temp_cell

    if config['cell_integrate'] == True:
        print("cell features: ", X_cell.shape)
    else:
        print("cell features:", list(X_cell.keys()))

    print("\ngenerating drug feats...")
    # Build a drug1 matrix and a drug2 matrix for every drug feature type.
    drug_matrix_dict = {}
    for feat_type in config['drug_feats']:
        # 'monetherapy' (sic) is the runtime key for monotherapy response; it must
        # match the spelling used in the config and in drug_data_dict.
        if feat_type != 'monetherapy':
            # Static per-drug features, looked up by integer drug id.
            dim = drug_data_dict[feat_type].shape[0]
            temp_X_1 = np.zeros((synergy_data.shape[0], dim))
            temp_X_2 = np.zeros((synergy_data.shape[0], dim))
            for i in tqdm(range(synergy_data.shape[0])):
                row = synergy_data.iloc[i]
                temp_X_1[i, :] = drug_data_dict[feat_type][int(row['drug1'])]
                temp_X_2[i, :] = drug_data_dict[feat_type][int(row['drug2'])]
        else:
            # Monotherapy features depend on the (cell, drug) pair; 3 values per pair.
            dim = 3
            temp_X_1 = np.zeros((synergy_data.shape[0], dim))
            temp_X_2 = np.zeros((synergy_data.shape[0], dim))
            for i in tqdm(range(synergy_data.shape[0])):
                row = synergy_data.iloc[i]
                temp_X_1[i, :] = drug_data_dict[feat_type].loc[row['cell'], int(row['drug1'])]
                temp_X_2[i, :] = drug_data_dict[feat_type].loc[row['cell'], int(row['drug2'])]
        drug_matrix_dict[feat_type + '_1'] = temp_X_1
        drug_matrix_dict[feat_type + '_2'] = temp_X_2

    # Decide whether the drug1/drug2 matrices stay separate (drug_indep).
    X_drug_temp = {}
    if config['drug_indep'] == False:
        for feat_type in config['drug_feats']:
            if feat_type != 'monetherapy':
                # Element-wise sum of the two drugs' features: the result does not
                # depend on which drug is listed as drug1 vs drug2.
                temp_X = drug_matrix_dict[feat_type + '_1'] + drug_matrix_dict[feat_type + '_2']
                X_drug_temp[feat_type] = temp_X
            else:  # monotherapy responses are kept separate per drug
                X_drug_temp[feat_type + '_1'] = drug_matrix_dict[feat_type + '_1']
                X_drug_temp[feat_type + '_2'] = drug_matrix_dict[feat_type + '_2']
    else:
        X_drug_temp = drug_matrix_dict

    # Optionally merge all drug feature blocks column-wise (dict insertion order).
    if config['drug_integrate'] == False:
        X_drug = X_drug_temp
    else:
        X_drug = np.concatenate(list(X_drug_temp.values()), axis=1)

    if config['drug_integrate'] == True:
        print("drug features: ", X_drug.shape)
    else:
        print("drug features")
        print(list(X_drug.keys()))
        for key, value in X_drug.items():
            print(key, value.shape)

    # Binary labels: 1 when the synergy score exceeds SYNERGY_THRES.
    Y = (synergy_data['score'] > SYNERGY_THRES).astype(int).values

    return X_cell, X_drug, Y


def idx(i, idx_dir=None):
    """Load the saved train/test row indices for split ``i``.

    Reads ``train_idx_<i>.npy`` and ``test_idx_<i>.npy`` from ``idx_dir``.

    Args:
        i: split identifier (int or str); it is formatted into the file names.
        idx_dir: directory holding the index files. Defaults to ``../idx``
            relative to the current working directory, built with
            ``os.path.join`` so it works on both Windows and POSIX. (The original
            raw-string ``'..\\idx\\...'`` path only resolved on Windows.)

    Returns:
        (train_idx, test_idx) as loaded by ``np.load``.
    """
    import os

    if idx_dir is None:
        idx_dir = os.path.join('..', 'idx')
    train_idx = np.load(os.path.join(idx_dir, 'train_idx_%s.npy' % (i,)))
    test_idx = np.load(os.path.join(idx_dir, 'test_idx_%s.npy' % (i,)))
    return train_idx, test_idx


def training(X_cell, X_drug, Y, train_idx, test_idx):
    """Fit the model named by ``config['model_name']`` on the training rows.

    Args:
        X_cell: cell features, either one 2-D array or a dict of 2-D arrays
            (the form ``prepare_data`` returns when cell features are not
            integrated).
        X_drug: drug features, same convention as ``X_cell``.
        Y: label vector aligned with the feature rows.
        train_idx, test_idx: row indices selecting the train / test split.

    Returns:
        (model, X_test, Y_test): the fitted model and the held-out rows.
    """
    def _as_matrix(feats):
        # Dict-form features are concatenated column-wise in insertion order,
        # mirroring how prepare_data merges drug blocks; arrays pass through.
        if isinstance(feats, dict):
            return np.concatenate(list(feats.values()), axis=1)
        return feats

    X = np.concatenate([_as_matrix(X_cell), _as_matrix(X_drug)], axis=1)
    X_train, Y_train = X[train_idx], Y[train_idx]
    X_test, Y_test = X[test_idx], Y[test_idx]

    model = get_model(config['model_name'])

    if config['model_name'] == 'NN':
        model.compile(loss='binary_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])
        callbacks = [
            # restore_best_weights=True: otherwise the returned model keeps the
            # weights of the final epoch (up to PATIENCE epochs past the best
            # val_loss) rather than the best ones saved by ModelCheckpoint.
            EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True),
            ModelCheckpoint(filepath='best_model_%s.h5' % config['model_name'], monitor='val_loss',
                            save_best_only=True),
        ]
        model.fit(X_train, Y_train,
                  batch_size=BATCH_SIZE,
                  epochs=N_EPOCHS,
                  verbose=1,
                  validation_split=0.1,
                  callbacks=callbacks)
    elif config['model_name'] == 'CatBoost':
        # verbose=3: CatBoost prints a progress line every 3 iterations.
        model.fit(X_train, Y_train, verbose=3)
    else:
        model.fit(X_train, Y_train)

    return model, X_test, Y_test


def evaluate(model, X_test, Y_test):
    """Score a fitted model on the held-out split.

    How predictions are obtained depends on ``config['model_name']``:
      * ``'NN'``       -> first column of ``model.predict``;
      * ``'CatBoost'`` -> ``model.predict`` output as-is;
      * anything else  -> positive-class column of ``model.predict_proba``.

    Returns:
        dict with AUC, AUPR, RMSE, Pearson correlation ('RPearson') and the
        raw labels/predictions as plain lists, so the dict can be JSON-dumped.
    """
    predictors = {
        'NN': lambda m, feats: m.predict(feats)[:, 0],
        'CatBoost': lambda m, feats: m.predict(feats),
    }
    default_predictor = lambda m, feats: m.predict_proba(feats)[:, 1]
    scores = predictors.get(config['model_name'], default_predictor)(model, X_test)

    # Values are evaluated in literal order, so metrics are computed before the
    # label/prediction lists are materialised.
    return {
        'AUC': roc_auc_score(Y_test, scores),
        'AUPR': average_precision_score(Y_test, scores),
        'RMSE': np.sqrt(mean_squared_error(Y_test, scores)),
        'RPearson': np.corrcoef(np.array(Y_test), np.array(scores))[0, 1],
        'Y_test': Y_test.tolist(),
        'Y_pred': scores.tolist(),
    }


def main(method, i, results_dir=None):
    """Run one prepare / train / evaluate cycle and save the metrics as JSON.

    Args:
        method: key of ``method_config_dict``; also used as the output file prefix.
        i: split identifier passed to ``idx`` (int or str).
        results_dir: output directory. Defaults to ``../results`` relative to
            the working directory, built with ``os.path.join`` so the path works
            on POSIX as well as Windows (the original ``..\\results\\`` raw
            string only worked on Windows).

    The output file is ``<results_dir>/<method>_<random int in 1..1000000>.json``.
    """
    # Explicit imports: these names were previously only available via the
    # star imports at the top of the file.
    import json
    import os
    import random

    if results_dir is None:
        results_dir = os.path.join('..', 'results')

    X_cell, X_drug, Y = prepare_data()
    print("data loaded")

    train_idx, test_idx = idx(i)

    model, X_test, Y_test = training(X_cell, X_drug, Y, train_idx, test_idx)
    print("training finished")
    val_results = evaluate(model, X_test, Y_test)

    # Save results. The random suffix presumably keeps repeated runs from
    # overwriting each other's files.
    rand_num = random.randint(1, 1000000)
    out_path = os.path.join(results_dir, "%s_%s.json" % (method, str(rand_num)))
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(val_results, f)


if __name__ == "__main__":
    # Two invocation styles are supported:
    #   * notebook: sys.argv has been overwritten as [method, i], so the method
    #     name sits in sys.argv[0];
    #   * command line: `python <script> METHOD I`, where sys.argv[0] is the
    #     script path and the arguments follow it.
    # Reading sys.argv[0] unconditionally would make a command-line run look up
    # the script path in method_config_dict and fail with a KeyError.
    if sys.argv[0] in method_config_dict:
        method, i = sys.argv[0], sys.argv[1]
    else:
        method, i = sys.argv[1], sys.argv[2]
    # `config` is a module-level global read by prepare_data, training and evaluate.
    config = method_config_dict[method]
    print(method, config, i)
    main(method, i)

CatBOOST_2022 {'synergy_data': 'NCI_ALMANAC', 'cell_data': 'NCI_60', 'cell_list': 'all', 'drug_feats': ['morgan_fingerprint', 'drug_target', 'monetherapy'], 'cell_feats': ['exp'], 'cell_feat_filter': 'cancer', 'drug_feat_filter': 'target', 'model_name': 'CatBoost', 'cell_integrate': True, 'drug_integrate': True, 'drug_indep': False} 1


number of drugs: 68
number of cells: 59
number of data: (130182, 4)




100%|███████████████████████████████████████████████████████████████████████| 130182/130182 [00:09<00:00, 13489.10it/s]


cell features:  (130182, 470)

generating drug feats...


100%|███████████████████████████████████████████████████████████████████████| 130182/130182 [00:11<00:00, 10974.33it/s]
100%|███████████████████████████████████████████████████████████████████████| 130182/130182 [00:11<00:00, 10972.39it/s]
100%|███████████████████████████████████████████████████████████████████████| 130182/130182 [00:12<00:00, 10511.72it/s]


drug features:  (130182, 1594)
data loaded
0:	learn: 0.1587915	total: 201ms	remaining: 2m
3:	learn: 0.1534658	total: 1.03s	remaining: 2m 33s
6:	learn: 0.1494612	total: 1.69s	remaining: 2m 23s
9:	learn: 0.1466548	total: 2.36s	remaining: 2m 19s
12:	learn: 0.1446260	total: 3.05s	remaining: 2m 17s
15:	learn: 0.1432336	total: 3.69s	remaining: 2m 14s
18:	learn: 0.1421000	total: 4.33s	remaining: 2m 12s
21:	learn: 0.1410988	total: 5s	remaining: 2m 11s
24:	learn: 0.1401973	total: 5.69s	remaining: 2m 10s
27:	learn: 0.1395702	total: 6.36s	remaining: 2m 10s
30:	learn: 0.1388889	total: 7.07s	remaining: 2m 9s
33:	learn: 0.1384166	total: 7.78s	remaining: 2m 9s
36:	learn: 0.1378715	total: 8.44s	remaining: 2m 8s
39:	learn: 0.1373324	total: 9.09s	remaining: 2m 7s
42:	learn: 0.1369984	total: 9.76s	remaining: 2m 6s
45:	learn: 0.1366213	total: 10.4s	remaining: 2m 5s
48:	learn: 0.1363002	total: 11.1s	remaining: 2m 4s
51:	learn: 0.1360590	total: 11.7s	remaining: 2m 3s
54:	learn: 0.1356646	total: 12.4s	remain

468:	learn: 0.1065877	total: 1m 45s	remaining: 29.4s
471:	learn: 0.1064036	total: 1m 45s	remaining: 28.7s
474:	learn: 0.1062363	total: 1m 46s	remaining: 28s
477:	learn: 0.1061707	total: 1m 47s	remaining: 27.3s
480:	learn: 0.1059869	total: 1m 47s	remaining: 26.7s
483:	learn: 0.1059005	total: 1m 48s	remaining: 26s
486:	learn: 0.1057931	total: 1m 48s	remaining: 25.3s
489:	learn: 0.1055777	total: 1m 49s	remaining: 24.6s
492:	learn: 0.1054729	total: 1m 50s	remaining: 23.9s
495:	learn: 0.1053443	total: 1m 50s	remaining: 23.3s
498:	learn: 0.1051778	total: 1m 51s	remaining: 22.6s
501:	learn: 0.1050258	total: 1m 52s	remaining: 21.9s
504:	learn: 0.1048694	total: 1m 52s	remaining: 21.2s
507:	learn: 0.1047258	total: 1m 53s	remaining: 20.6s
510:	learn: 0.1045237	total: 1m 54s	remaining: 19.9s
513:	learn: 0.1043453	total: 1m 55s	remaining: 19.2s
516:	learn: 0.1042965	total: 1m 55s	remaining: 18.6s
519:	learn: 0.1042296	total: 1m 56s	remaining: 17.9s
522:	learn: 0.1040613	total: 1m 56s	remaining: 17.