In [1]:
from sklearn.model_selection import KFold,LeaveOneOut,LeavePOut,ShuffleSplit 
from feat_selection import *
from model_new import *
from catboost import CatBoostRegressor, Pool
import shap
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report, mean_squared_error
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from keras.callbacks import EarlyStopping, ModelCheckpoint
import numpy as np

In [2]:
available_feat_type_list = {'NCI_60':['met','mut','cop','exp']}
available_cancer_specific_cell_list = {'NCI_60':{'TNBC':['MDA-MB-231','MDA-MB-435','BT-549','HS 578T']}}

SYNERGY_THRES = 10
BATCH_SIZE = 256
N_EPOCHS = 200
PATIENCE = 30

method = 'CatBOOST_2022'
config = method_config_dict[method]
print(config)
synergy_data = input_synergy_data(config['synergy_data'])#import data

cell_data_dicts=input_cellline_data(config['cell_data'])

drug_data_dict=input_drug_data()

drug_list=synergy_data['drug1'].unique().tolist()# Extract a list of all drug names,102 drugs

cell_list=synergy_data['cell'].unique().tolist()#Extract a list of all cell names, 60 cell lines

drug_list=filter_drug(drug_list,config['drug_feat_filter'])#Filter drug features,68 drugs

cell_list=filter_cell(cell_list,config['cell_list'],available_cancer_specific_cell_list[config['cell_data']])#Filter cells,60 cell lines

synergy_data=synergy_data[(synergy_data['drug1'].isin(drug_list)) & synergy_data['drug2'].isin(drug_list) & synergy_data['cell'].isin(cell_list)]

#Data containing common drugs and cells were selected, resulting 59 cell lines
cell_feats,selected_cells=filter_cell_features(cell_data_dicts,cell_list,config['cell_feats'],config['cell_feat_filter'],config['cell_integrate'])

#Filter cell feeature
synergy_data = synergy_data[synergy_data['cell'].isin(selected_cells)]

#Select data that contains common cells

print("\n")
print("number of drugs:", len(drug_list))
print("number of cells:", len(selected_cells))
print("number of data:", synergy_data.shape)
print("\n")

if config['cell_integrate'] == True:#Construct the cell line matrix
    X_cell=np.zeros((synergy_data.shape[0],cell_feats.shape[0]))#Constructing matrix
    for i in tqdm(range(synergy_data.shape[0])):
        row=synergy_data.iloc[i]
        X_cell[i,:]=cell_feats[row['cell']].values#Assign values to the matrix according to the cells

else:
    X_cell = {}
    for feat_type in config['cell_feats']:#feature types of cell_feats
        print(feat_type, cell_feats[feat_type].shape[0])#cell_feats is a dict
        temp_cell = np.zeros((synergy_data.shape[0], cell_feats[feat_type].shape[0]))
        for i in tqdm(range(synergy_data.shape[0])):
            row = synergy_data.iloc[i]
            temp_cell[i, :] = cell_feats[feat_type][row['cell']].values
        X_cell[feat_type] = temp_cell

if config['cell_integrate'] == True:
    print("cell features: ", X_cell.shape)
else:
    print("cell features:", list(X_cell.keys()))

print("\ngenerating drug feats...")
drug_matrix_dict={}
for feat_type in config['drug_feats']:#Constructing drug matrix
    if feat_type != 'monetherapy':
        dim=drug_data_dict[feat_type].shape[0]
        temp_X_1=np.zeros((synergy_data.shape[0],dim))
        temp_X_2=np.zeros((synergy_data.shape[0],dim))
        for i in tqdm(range(synergy_data.shape[0])):
            row=synergy_data.iloc[i]
            temp_X_1[i,:]=drug_data_dict[feat_type][int(row['drug1'])]
            temp_X_2[i,:]=drug_data_dict[feat_type][int(row['drug2'])]
    else:
        dim=3
        temp_X_1=np.zeros((synergy_data.shape[0],dim))
        temp_X_2=np.zeros((synergy_data.shape[0],dim))
        for i in tqdm(range(synergy_data.shape[0])):
            row=synergy_data.iloc[i]
            temp_X_1[i,:]=drug_data_dict[feat_type].loc[row['cell'],int(row['drug1'])]
            temp_X_2[i,:]=drug_data_dict[feat_type].loc[row['cell'],int(row['drug2'])]
    drug_matrix_dict[feat_type + '_1']=temp_X_1
    drug_matrix_dict[feat_type + '_2']=temp_X_2

X_drug_temp = {}#The next question is whether the drug feature matrix should be independent
if config['drug_indep']==False:
    for feat_type in config['drug_feats']:
        if feat_type != 'monetherapy':
            temp_X=drug_matrix_dict[feat_type+'_1']+drug_matrix_dict[feat_type+'_2']#matrix addition
            X_drug_temp[feat_type]=temp_X
        else:#Monotherapy, so separate
            X_drug_temp[feat_type+'_1']=drug_matrix_dict[feat_type + '_1']
            X_drug_temp[feat_type+'_2']=drug_matrix_dict[feat_type + '_2']
else:
    X_drug_temp = drug_matrix_dict

if config['drug_integrate'] == False:#Whether drug features need to be combined
    X_drug=X_drug_temp
else:
    X_drug=np.concatenate(list(X_drug_temp.values()),axis=1)

if config['drug_integrate'] == True:
    print("drug features: ", X_drug.shape)
else:
    print("drug features")
    print(list(X_drug.keys()))
    for key, value in X_drug.items():
        print(key, value.shape)

Y = (synergy_data['score']>SYNERGY_THRES).astype(int).values
#Y=synergy_data['score'].values
X=np.concatenate([X_cell,X_drug],axis=1)

{'synergy_data': 'NCI_ALMANAC', 'cell_data': 'NCI_60', 'cell_list': 'all', 'drug_feats': ['morgan_fingerprint', 'drug_target', 'monetherapy'], 'cell_feats': ['exp'], 'cell_feat_filter': 'cancer', 'drug_feat_filter': 'target', 'model_name': 'CatBoost', 'cell_integrate': True, 'drug_integrate': True, 'drug_indep': False}


number of drugs: 68
number of cells: 59
number of data: (130182, 4)




100%|███████████████████████████████████████████████████████████████████████| 130182/130182 [00:09<00:00, 13958.55it/s]


cell features:  (130182, 470)

generating drug feats...


100%|███████████████████████████████████████████████████████████████████████| 130182/130182 [00:10<00:00, 12318.87it/s]
100%|███████████████████████████████████████████████████████████████████████| 130182/130182 [00:12<00:00, 10788.39it/s]
100%|███████████████████████████████████████████████████████████████████████| 130182/130182 [00:12<00:00, 10353.31it/s]


drug features:  (130182, 1594)


In [4]:
i=1
train_idx = np.load(r'..\idx\train_idx_%s.npy' % (i))
test_idx = np.load(r'..\idx\test_idx_%s.npy' % (i))
X_train,Y_train=X[train_idx],Y[train_idx]
X_test,Y_test=X[test_idx],Y[test_idx]

In [5]:
params = {
    'iterations':600,
    'learning_rate':0.1,
    'depth':9,
    'loss_function':'RMSE'}
model = CatBoostRegressor(**params)
model.fit(X_train, Y_train,verbose=3)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

0:	learn: 0.1587915	total: 480ms	remaining: 4m 47s
3:	learn: 0.1534658	total: 1.59s	remaining: 3m 57s
6:	learn: 0.1494612	total: 2.44s	remaining: 3m 26s
9:	learn: 0.1466548	total: 3.22s	remaining: 3m 10s
12:	learn: 0.1446260	total: 4.01s	remaining: 3m 1s
15:	learn: 0.1432336	total: 4.75s	remaining: 2m 53s
18:	learn: 0.1421000	total: 5.54s	remaining: 2m 49s
21:	learn: 0.1410988	total: 6.36s	remaining: 2m 47s
24:	learn: 0.1401973	total: 7.12s	remaining: 2m 43s
27:	learn: 0.1395702	total: 7.9s	remaining: 2m 41s
30:	learn: 0.1388889	total: 8.72s	remaining: 2m 40s
33:	learn: 0.1384166	total: 9.59s	remaining: 2m 39s
36:	learn: 0.1378715	total: 10.3s	remaining: 2m 37s
39:	learn: 0.1373324	total: 11s	remaining: 2m 34s
42:	learn: 0.1369984	total: 11.8s	remaining: 2m 32s
45:	learn: 0.1366213	total: 12.5s	remaining: 2m 30s
48:	learn: 0.1363002	total: 13.2s	remaining: 2m 28s
51:	learn: 0.1360590	total: 14.1s	remaining: 2m 28s
54:	learn: 0.1356646	total: 15s	remaining: 2m 28s
57:	learn: 0.1352551	t

468:	learn: 0.1065877	total: 2m 9s	remaining: 36.1s
471:	learn: 0.1064036	total: 2m 9s	remaining: 35.2s
474:	learn: 0.1062363	total: 2m 10s	remaining: 34.4s
477:	learn: 0.1061707	total: 2m 11s	remaining: 33.5s
480:	learn: 0.1059869	total: 2m 12s	remaining: 32.7s
483:	learn: 0.1059005	total: 2m 12s	remaining: 31.8s
486:	learn: 0.1057931	total: 2m 13s	remaining: 31s
489:	learn: 0.1055777	total: 2m 14s	remaining: 30.1s
492:	learn: 0.1054729	total: 2m 14s	remaining: 29.3s
495:	learn: 0.1053443	total: 2m 15s	remaining: 28.4s
498:	learn: 0.1051778	total: 2m 16s	remaining: 27.6s
501:	learn: 0.1050258	total: 2m 17s	remaining: 26.8s
504:	learn: 0.1048694	total: 2m 18s	remaining: 26s
507:	learn: 0.1047258	total: 2m 18s	remaining: 25.1s
510:	learn: 0.1045237	total: 2m 19s	remaining: 24.3s
513:	learn: 0.1043453	total: 2m 20s	remaining: 23.5s
516:	learn: 0.1042965	total: 2m 21s	remaining: 22.6s
519:	learn: 0.1042296	total: 2m 21s	remaining: 21.8s
522:	learn: 0.1040613	total: 2m 22s	remaining: 21s
5

In [6]:
feature_importance = pd.DataFrame()

feature_importance['feature'] = range(X.shape[1])

feature_importance['importance'] = np.abs(shap_values).mean(0)

feature_importance.sort_values('importance', ascending=False)

Unnamed: 0,feature,importance
2063,2063,0.007569
994,994,0.004904
2060,2060,0.004247
661,661,0.003374
2059,2059,0.002786
...,...,...
1256,1256,0.000000
1255,1255,0.000000
1254,1254,0.000000
1253,1253,0.000000
