In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import json


In [2]:
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import scale
from imblearn.over_sampling import BorderlineSMOTE
from imblearn.under_sampling import EditedNearestNeighbours


In [3]:
def res_to_csv(res, bacteria, mode, encoding, method, drug_name):
    tp = []
    for k in res.keys():
        if k != 'accuracy':
            tp.append(list(res[k].values()))
        else:
            tp.append([np.nan, np.nan, res[k], res['macro avg']['support']])
    tp = pd.DataFrame(tp, index=res.keys(), columns=res["S"].keys())
    tp.to_csv(f"results/{bacteria}/{encoding}/{bacteria}_{mode}_{encoding}_{method}_{drug_name}.csv", index=True, encoding='utf-8')


In [4]:
def MLPredModel(X, Y, bacteria, mode, encoding, method, drug_name, Normalization=True, seed=7, save_res='easy'):
    
    print(f"Training {method} model ...")
    MODEL = {
        'LR': LogisticRegression(solver='lbfgs', max_iter=1500),
        'RF': RandomForestClassifier(n_estimators=200, random_state=0),
        'SVM': SVC(kernel='linear', probability=True)
    }
    if encoding == 'FCGR':
        X = X.reshape(X.shape[0], -1)
    if Normalization:
        X = scale(X)
    x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=seed)
    # sample_solver= BorderlineSMOTE()
    # sample_solver=EditedNearestNeighbours()
    # x_train,y_train=sample_solver.fit_resample(x_train,y_train)
    model = MODEL[method]
    model.fit(x_train, y_train)
    # plot_feature_importance(model, bacteria, mode, encoding, method, drug_name, POS)
    preds = model.predict(x_test)
    print("Result for {}".format(drug_name))
    print(classification_report(y_test, preds, target_names=['S', 'R']))
    res=classification_report(y_test, preds, target_names=['S', 'R'],output_dict=True)
    if save_res=='easy':
        res=list(res['weighted avg'].values())[:3]+[res['accuracy']]
        return res
    else:
        res_to_csv(res, bacteria=bacteria, mode=mode,encoding=encoding, method=method, drug_name=drug_name)
        return -1


In [5]:
def plot_feature_importance(model, bacteria, mode, encoding, method, drug_name, POS, n=10):
    if method == 'RF':
        index = np.argsort(-model.feature_importances_)[:n].tolist()
        value = model.feature_importances_[index]
    if method == 'LR':
        index=np.argsort(-abs(model.coef_[0]))[:n].tolist()
        value = model.coef_[0][index]
        
    print(f"Most {n} important position: \n{POS[index]}")
    fig, ax = plt.subplots()
    ax.barh(range(n), value, align='center', color='c')
    ax.set_yticks(range(n))
    ax.set_yticklabels([str(i) for i in POS[index]])
    ax.invert_yaxis()
    plt.savefig(f'results/{bacteria}/{encoding}/{mode}_{method}_{drug_name}_TOP{n}_feature_importance.png')
    plt.show()
    plt.close()


In [6]:
Methods = ['LR','RF','SVM']
# Methods = ['LR']
# Encodings = ['FCGR',]
Encoding = 'Label_Encoding'
Drug_list = ['AMP', 'AMX', 'AMC', 'TZP', 'CXM', 'CET', 'TBM', 'TMP', 'CIP', 'CTX', 'CTZ', 'GEN']
Bacteria = 'E.coli'
Mode = 'ToN'
save_res='easy'
seed = 7



In [7]:
# set(X.flatten().tolist())

In [8]:
# for m in Methods:
#     res=[]
#     for d in Drug_list:
#         data = np.load(f'data/{Bacteria}/preprocessed/{Encoding}/{Bacteria}_{Mode}_{Encoding}_{d}.npz', allow_pickle=True)
#         X, Y, POS = data['X'].astype('float32'), data['Y'].astype('int32'), data['POS'].astype('int32')
#         print(X.shape, X.dtype, Y.shape, Y.dtype, POS.shape, POS.dtype)
#         res.append(MLPredModel(X, Y, bacteria=Bacteria, mode=Mode,encoding=Encoding, method=m, drug_name=d,Normalization=False,seed=seed))
#     if save_res=='easy':
#         res=pd.DataFrame(res,columns=['precision','recall','f1','accuaracy'],index=Drug_list)
#         res.to_csv(f'results/E.coli/raw_all_drug_{m}_result.csv')


In [9]:
def get_POSweight(drug_name,sd):
    np.random.seed(sd)#CTX
    # np.random.seed(0)#CTZ
    # np.random.seed(7)

    data = np.load(f'data/{Bacteria}/preprocessed/{Encoding}/{Bacteria}_{Mode}_{Encoding}_{drug_name}.npz', allow_pickle=True)
    X, Y, POS = data['X'].astype('float32'), data['Y'].astype('int32'), data['POS'].astype('int32')
    shuffe_idx=list(range(len(X)))
    np.random.shuffle(shuffe_idx)
    X1=X[shuffe_idx]
    Y1=Y[shuffe_idx]
    model=RandomForestClassifier(n_estimators=200, random_state=0)
    model.fit(X1,Y1)
    index = np.argsort(-model.feature_importances_).tolist()
    print(f"TOP 10 POS: {POS[index][:10]}")
    with open(f'results/E.coli/Label_Encoding/{drug_name}_POS_weight.json', 'w', encoding='utf-8') as f:
        json.dump(model.feature_importances_.tolist(), f, ensure_ascii=False, indent=4)
    


    x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=seed)
    model=RandomForestClassifier(n_estimators=200, random_state=0)
    model.fit(x_train,y_train)
    preds = model.predict(x_test)
    print("raw Result for {}".format(drug_name))
    print(classification_report(y_test, preds, target_names=['S', 'R']))
    # res1=classification_report(y_test, preds, target_names=['S', 'R'],output_dict=True)
    X2=X[:,index[:50]]
    X2=(X2!=4)
    x_train, x_test, y_train, y_test = train_test_split(X2, Y, test_size=0.2, random_state=seed)
    model=RandomForestClassifier(n_estimators=200, random_state=0)
    model.fit(x_train,y_train)
    preds = model.predict(x_test)
    print("POS Result for {}".format(drug_name))
    print(classification_report(y_test, preds, target_names=['S', 'R']))
    # res2=classification_report(y_test, preds, target_names=['S', 'R'],output_dict=True)



    # X2=X[:,index[:50]]
    # X2=(X2!=4)
    # x_train, x_test, y_train, y_test = train_test_split(X2, Y, test_size=0.2, random_state=seed)
    # model=LogisticRegression(solver='lbfgs', max_iter=1500)
    # model.fit(x_train,y_train)
    # preds = model.predict(x_test)
    # print("Result for {}".format(drug_name))
    # print(classification_report(y_test, preds, target_names=['S', 'R']))
    # return res1['accuracy'],res2['accuracy']

In [17]:
get_POSweight('AMX',1)

TOP 10 POS: [ 905642 1191833 3544549 2210125 4107784 2423513  101145  619397 2133122
 3076007]
raw Result for AMX
              precision    recall  f1-score   support

           S       0.55      0.47      0.51        77
           R       0.73      0.79      0.76       141

    accuracy                           0.68       218
   macro avg       0.64      0.63      0.63       218
weighted avg       0.67      0.68      0.67       218

POS Result for AMX
              precision    recall  f1-score   support

           S       0.49      0.49      0.49        77
           R       0.72      0.72      0.72       141

    accuracy                           0.64       218
   macro avg       0.61      0.61      0.61       218
weighted avg       0.64      0.64      0.64       218



In [10]:
# res={'CTX':[],'CTZ':[]}
# for d in ['CTX','CTZ']:
#     for sd in range(100):
#         res1,res2=get_POSweight(d,sd)
#         if res2>res1:
#             print(d+f'  {res1}  {res2}')
#             res[d].append({'seed':sd,'res1':res1,'res2':res2})
#         if len(res[d])>3: break
# print(res)

TOP 10 POS: [3149226 2421630 3997576 4156280 3373874  180155  784449 1319685 4135148
 3053818]
TOP 10 POS: [3893315 2526966 2138483 2421630 3997576 3077200 4135148 3149226 1256926
 1319685]
TOP 10 POS: [ 180155 2322653 3997576 2421630 2560525 4135148 2715202 2339440 4203069
  763976]
TOP 10 POS: [3149226 2421630 4156318 3997576  180155 4135148 3068635 2715202 3278126
 4418266]
TOP 10 POS: [3997576 4135148 4445918   79614 1007391 3507357  783190 1192710  486768
  893052]
TOP 10 POS: [ 180155 3997576 2421630 2560525 4445549 3900673 4135148 3074152 1256926
 2322653]
TOP 10 POS: [ 189081 2421630  180155 2322653 4135148 1319685 4424011  713492  713864
 3077209]
TOP 10 POS: [2421630 3997576 3507357 3997381  180155 2322653 1256926 2393581  486768
 3149226]
TOP 10 POS: [ 180155  189081 1192104 2421630 4186773 3893315 4412860 3373874 4203077
 3997576]
TOP 10 POS: [ 180155 3997576 2421630 4135148 3961074 4114364 4203077  189081 1319685
 2431302]
TOP 10 POS: [3149226 3997576  180155 1263711 30686

In [1]:
# for d in Drug_list:
#     get_POSweight(d)

## 验证关键位置的突变与耐药性的相关性

In [16]:
# snp = pd.read_csv('/data/HWK/DeepGene/data/E.coli/preprocessed/FCGR/E.coli_ToN_FCGR_input.csv', sep=',', encoding='utf-8')
# snp

In [17]:
# snp=snp.set_index(data['POS'])
# snp

In [18]:
# plt.bar(range(len(data['Y'])),data['Y'])

In [19]:
# imp_pos=POS[index][:10]
# tmp=(snp.loc[imp_pos[0]]!='N')
# plt.bar(range(len(tmp)),tmp)

In [20]:
# np.corrcoef(data['Y'],tmp.to_numpy())