In [1]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, classification_report
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_predict, GridSearchCV
import random
from tqdm import tqdm

In [2]:
# training dataset loading
Training_csv_file_path = './data/TrainingSet/mRNA_sublocation_TrainingSet_DNABERT_data.csv'
Training_data= pd.read_csv(Training_csv_file_path)
Training_data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,759,760,761,762,763,764,765,766,767,label
0,-0.051352,0.110863,0.017913,-0.117001,-0.130711,-0.051733,-0.087305,-0.113558,0.015367,0.044985,...,-0.047920,-0.044799,0.009836,-0.027684,-0.060108,0.051503,0.074873,0.094912,0.142718,1
1,-0.028948,0.076648,0.053617,-0.151656,-0.125982,-0.033365,-0.052527,-0.106612,0.023717,0.048803,...,-0.063734,-0.007245,0.007870,-0.022348,-0.066921,0.064694,0.101658,0.116233,0.136837,1
2,-0.053768,0.076979,0.011430,-0.088812,-0.123918,-0.040801,-0.099212,-0.122977,0.022925,0.024804,...,-0.080484,-0.063357,0.037052,-0.023793,-0.052618,0.047666,0.088847,0.100192,0.156532,1
3,-0.013776,0.108522,0.030686,-0.113928,-0.116355,-0.049274,-0.067821,-0.104782,0.024167,0.016909,...,-0.044457,-0.043053,0.013077,-0.001783,-0.055099,0.052132,0.093488,0.079941,0.162534,1
4,-0.041918,0.088144,0.067097,-0.108955,-0.116393,-0.039382,-0.087630,-0.126116,0.024942,0.020732,...,-0.056244,-0.031527,0.027563,-0.006933,-0.035199,0.053570,0.106644,0.127546,0.129778,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4657,-0.045684,0.073154,0.031592,-0.088793,-0.112148,-0.057297,-0.112139,-0.115460,0.027914,0.033002,...,-0.071451,-0.064951,0.026554,-0.024152,-0.040399,0.054827,0.108673,0.107649,0.108352,0
4658,-0.051314,0.101583,0.049372,-0.130485,-0.146191,-0.078210,-0.049869,-0.120071,0.025753,0.044124,...,-0.069669,-0.118523,0.001044,-0.004282,-0.044469,0.064367,0.085585,0.090057,0.146606,0
4659,-0.072383,0.114879,0.043321,-0.113051,-0.119378,-0.046503,-0.109755,-0.090195,0.016729,0.033405,...,-0.073933,-0.074885,0.044387,-0.000744,-0.048722,-0.010145,0.098738,0.119175,0.117543,0
4660,-0.077676,0.075969,0.107718,-0.085082,-0.079427,-0.056554,-0.071221,-0.067641,0.003955,0.040569,...,-0.071051,-0.068419,-0.018095,0.009988,-0.081945,0.103257,-0.006743,0.098912,0.106491,0


In [3]:
# test dataset loading
Test_csv_file_path = './data/TestSet/mRNA_sublocation_TestSet_DNABERT_data.csv'
Test_data= pd.read_csv(Test_csv_file_path)
Test_data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,759,760,761,762,763,764,765,766,767,label
0,-0.029158,0.100403,0.067271,-0.116924,-0.109726,-0.056563,-0.095235,-0.063101,0.026661,0.055791,...,-0.053218,-0.046008,0.024874,-0.005043,-0.057149,0.022523,0.123339,0.146821,0.165012,1
1,-0.051528,0.083082,0.069916,-0.125172,-0.106272,-0.032767,-0.076169,-0.062987,0.034608,0.097928,...,-0.060201,-0.053725,0.023453,0.007158,-0.104437,0.018186,0.066489,0.168612,0.168197,1
2,-0.042795,0.084680,0.087846,-0.136758,-0.113154,-0.070079,-0.075557,-0.065823,0.027558,0.086018,...,-0.086440,-0.074154,0.037891,0.010463,-0.048493,0.094847,0.098022,0.145162,0.164766,1
3,-0.011203,0.055997,0.090977,-0.172956,-0.144822,-0.070694,-0.123100,-0.074747,0.023865,0.037438,...,-0.072469,-0.075094,0.079060,-0.053585,-0.033657,-0.002991,0.147035,0.143220,0.139921,1
4,-0.023741,0.064950,0.127039,-0.152309,-0.082382,-0.070835,-0.145091,-0.057588,0.054729,0.018183,...,-0.060701,-0.105159,0.082175,-0.042828,-0.013750,-0.040331,0.205573,0.173816,0.187638,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
513,-0.038825,0.066094,0.105642,-0.127595,-0.096262,-0.075581,-0.077916,-0.082450,0.028816,0.078428,...,-0.083052,-0.081839,0.025406,0.034709,-0.057688,0.085145,0.059397,0.119146,0.114285,0
514,-0.048585,0.073876,0.038217,-0.139759,-0.106150,-0.056394,-0.057778,-0.105467,0.015478,0.053942,...,-0.050953,-0.018395,-0.001208,-0.011797,-0.064337,0.017005,0.087882,0.118566,0.137715,0
515,-0.039191,0.074916,0.047489,-0.125525,-0.114935,-0.063233,-0.059096,-0.098679,-0.006991,0.081666,...,-0.077407,-0.028342,0.029947,-0.024831,-0.072220,-0.001702,0.085262,0.091931,0.144309,0
516,-0.028589,0.071419,0.104370,-0.118703,-0.104961,-0.059899,-0.096033,-0.123387,0.024864,0.105596,...,-0.115380,-0.046133,0.049324,-0.001430,-0.059596,0.059717,0.050933,0.101675,0.132865,0


In [4]:
# separate sequence features and labels
X_train = Training_data.drop(columns=['label']).values
y_train = Training_data['label'].values

X_test = Test_data.drop(columns=['label']).values
y_test = Test_data['label'].values

# the training data and test data are standardized
sc = StandardScaler()
sc.fit(X_train)
X_train = sc.transform(X_train)
X_test = sc.transform(X_test)

In [5]:
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, classification_report
from sklearn.metrics import roc_curve, auc
results = []
best_acc = 0
# Hyperparameter tuning
for k in tqdm(range(1,101),desc="Hyperparameter Search"):
    
    val_accuracy_scores = []
    val_precision_scores = []
    val_recall_scores = []
    val_f1_scores = []
#     5-fold cross-validation    
    kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    for fold, (train_index, val_index) in enumerate(kf.split(X_train,y_train), 1):
        X_train_fold, X_val_fold = X_train[train_index], X_train[val_index]
        y_train_fold, y_val_fold = y_train[train_index], y_train[val_index]
        
        clf = KNeighborsClassifier(n_neighbors=k,weights='distance',algorithm='ball_tree')
        clf.fit(X_train_fold,y_train_fold)
        
        val_predictions = clf.predict(X_val_fold)
        val_accuracy = accuracy_score(y_val_fold, val_predictions)
        val_precision = precision_score(y_val_fold, val_predictions)
        val_recall = recall_score(y_val_fold, val_predictions)
        val_f1 = f1_score(y_val_fold, val_predictions)
#         Saving metrics
        val_accuracy_scores.append(val_accuracy)
        val_precision_scores.append(val_precision)
        val_recall_scores.append(val_recall)
        val_f1_scores.append(val_f1)
    #   The average value of each metric was calculated
    val_ACC = np.mean(val_accuracy_scores)
    val_Precision = np.mean(val_precision_scores)
    val_Recall = np.mean(val_recall_scores)
    val_F1 = np.mean(val_f1_scores)
# Independent testing    
    clf.fit(X_train,y_train)
    test_predictions = clf.predict(X_test)
    cm = confusion_matrix(y_test,test_predictions)
    TP = cm[1, 1]
    TN = cm[0, 0]
    FP = cm[0, 1]
    FN = cm[1, 0]
# Calculating test metrics
    test_ACC = accuracy_score(y_test, test_predictions)
    test_Precision = precision_score(y_test, test_predictions)
    test_Recall = recall_score(y_test, test_predictions)
    test_F1 = f1_score(y_test, test_predictions)
    mcc = (TP * TN - FP * FN) / ((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))**0.5
    fpr, tpr, thresholds = roc_curve(y_test, test_predictions, pos_label=1)
    roc_auc = auc(fpr, tpr)
    
    
    
        
    results.append({
        "k": k,
        "val_ACC": val_ACC,
        "val_Precision": val_Precision,
        "val_Recall": val_Recall,
        "val_F1":val_F1,
        "test_ACC":test_ACC,
        "test_Precision":test_Precision,
        "test_Recall":test_Recall,
        "test_F1":test_F1,
        "test_MCC":mcc,
        "test_roc_auc":roc_auc
    })
    

Hyperparameter Search: 100%|█████████████████████████████████████████████████████████| 100/100 [37:11<00:00, 22.31s/it]


In [6]:
# result ranking
sorted_results = sorted(results, key=lambda x: x["test_ACC"], reverse=True)

for i,result in enumerate(sorted_results):
    print("超参数: k :", result["k"])
    print("val_ACC:", result["val_ACC"])
    print("val_Precision:", result["val_Precision"])
    print("val_Recall:", result["val_Recall"])
    print("val_F1:", result["val_F1"])
    print("test_ACC:", result["test_ACC"])
    print("test_Precision:", result["test_Precision"])
    print("test_Recall:", result["test_Recall"])
    print("test_F1:", result["test_F1"])
    print("test_MCC:", result["test_MCC"])
    print("test_roc_auc:", result["test_roc_auc"])
    print("-" * 60)

超参数: k : 15
val_ACC: 0.6758888444217508
val_Precision: 0.6574187896367677
val_Recall: 0.5315404731496686
val_F1: 0.5875447876679762
test_ACC: 0.7277992277992278
test_Precision: 0.7403314917127072
test_Recall: 0.5877192982456141
test_F1: 0.6552567237163816
test_MCC: 0.44316470218761184
test_roc_auc: 0.7128251663641864
------------------------------------------------------------
超参数: k : 16
val_ACC: 0.6763187189784212
val_Precision: 0.6600945756561173
val_Recall: 0.5270960287052241
val_F1: 0.5858699950111849
test_ACC: 0.7258687258687259
test_Precision: 0.7443181818181818
test_Recall: 0.5745614035087719
test_F1: 0.6485148514851485
test_MCC: 0.4395558457360397
test_roc_auc: 0.7096944948578342
------------------------------------------------------------
超参数: k : 17
val_ACC: 0.6769615757927034
val_Precision: 0.6628247276130506
val_Recall: 0.524141580003649
val_F1: 0.585007444501729
test_ACC: 0.7258687258687259
test_Precision: 0.7388888888888889
test_Recall: 0.5833333333333334
test_F1: 0.6519