In [169]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
import csv
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import StratifiedKFold, ParameterGrid
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_predict, GridSearchCV
from itertools import product
from sklearn.model_selection import RepeatedStratifiedKFold
import random
from tqdm import tqdm
import warnings
warnings.filterwarnings(action='ignore')

In [170]:
# Read the training set from CSV into a DataFrame
Training_csv_file_path = './data/TrainingSet/mRNA_sublocation_TrainingSet_NC-BERTdata.csv'
Training_data = pd.read_csv(Training_csv_file_path)

# The 'label' column is the target; every other column is used as a feature
y_Train = Training_data['label'].values
X_Train = Training_data.drop(columns=['label']).values

# Standardize the training features. Only training data is scaled in this cell;
# `sc` keeps the fitted statistics so the same scaling can be applied elsewhere.
sc = StandardScaler()
X_Train = sc.fit_transform(X_Train)

In [171]:
from sklearn.decomposition import PCA
import numpy as np

# Project the standardized training features onto 72 principal components.
#
# random_state is set so the projection is reproducible: with the default
# svd_solver='auto', scikit-learn may choose the randomized SVD solver for large
# inputs, which is stochastic unless seeded. The exact solvers ignore this argument.
#
# NOTE: this cell overwrites X_Train with its own output, so re-running it without
# first re-running the loading/scaling cell feeds already-projected data back into PCA.
pca = PCA(n_components=72, random_state=43)
X_Train = pca.fit_transform(X_Train)


In [172]:
import os

# Grid search over logistic-regression hyperparameters using repeated stratified
# 5-fold cross-validation. Per-fold scores are written to a CSV file and the
# mean/std of each metric is printed for every hyperparameter combination.
#
# NOTE(review): X_Train was standardized and PCA-projected on all rows in the
# earlier cells, i.e. before the folds are split here, so each held-out fold has
# already influenced the preprocessing. Fitting the scaler and PCA inside each
# training fold (e.g. via a Pipeline) would remove that leakage.
#
# NOTE(review): the output recorded with this notebook shows identical metrics for
# max_iter = 20/30/40 at a given C, which suggests the solver stops before 20
# iterations and that this axis only multiplies the runtime - confirm before keeping it.
param_grid = {
    'C': [0.1, 0.5],
    'max_iter': [20, 30, 40],
}

# Expand the grid into a list of {'C': ..., 'max_iter': ...} dicts
grid = list(ParameterGrid(param_grid))

# Not populated by this cell; kept so the name remains defined exactly as before.
results = []

# 100 repeats of stratified 5-fold CV. The integer random_state makes every call
# to rskf.split() yield the same folds, so all hyperparameter combinations are
# scored on identical splits.
rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=100, random_state=43)

# Create the output directory up front so open() cannot fail on a fresh checkout
output_csv_path = './result/LR-100times-5-fold cv.csv'
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)

with open(output_csv_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    # Header: one row per (hyperparameter combination, fold)
    writer.writerow(['C', 'max_iter', 'accuracy', 'precision', 'recall', 'f1'])

    for params in tqdm(grid, desc="Hyperparameter search"):
        acc_scores = []
        prec_scores = []
        rec_scores = []
        f1_scores = []

        for train_index, test_index in rskf.split(X_Train, y_Train):
            X_train, X_test = X_Train[train_index], X_Train[test_index]
            y_train, y_test = y_Train[train_index], y_Train[test_index]

            # The model is fitted without an intercept term (fit_intercept=False)
            clf = LogisticRegression(C=params['C'], max_iter=params['max_iter'], fit_intercept=False)
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_test)

            acc_scores.append(accuracy_score(y_test, y_pred))
            prec_scores.append(precision_score(y_test, y_pred))
            rec_scores.append(recall_score(y_test, y_pred))
            f1_scores.append(f1_score(y_test, y_pred))

        # Persist the per-fold scores for this hyperparameter combination
        for acc, prec, rec, f1 in zip(acc_scores, prec_scores, rec_scores, f1_scores):
            writer.writerow([params['C'], params['max_iter'], acc, prec, rec, f1])

        # Summarize across all 500 folds (population std, numpy default ddof=0)
        acc_mean, acc_std = np.mean(acc_scores), np.std(acc_scores)
        prec_mean, prec_std = np.mean(prec_scores), np.std(prec_scores)
        rec_mean, rec_std = np.mean(rec_scores), np.std(rec_scores)
        f1_mean, f1_std = np.mean(f1_scores), np.std(f1_scores)
        print(f" params: {params}")
        print(f" acc_mean: {acc_mean}, acc_std: {acc_std}")
        print(f" prec_mean: {prec_mean}, prec_std: {prec_std}")
        print(f" rec_mean: {rec_mean}, rec_std: {rec_std}")
        print(f" f1_mean: {f1_mean}, f1_std: {f1_std}")

Hyperparameter search:  17%|██████████▏                                                  | 1/6 [00:12<01:01, 12.21s/it]

 params: {'C': 0.1, 'max_iter': 20}
 acc_mean: 0.6722951667287673, acc_std: 0.012969961031479117
 prec_mean: 0.6486376920566854, prec_std: 0.018885182142049258
 rec_mean: 0.5388819315210119, rec_std: 0.02280325758668309
 f1_mean: 0.5884417695788493, f1_std: 0.018098985087038316


Hyperparameter search:  33%|████████████████████▎                                        | 2/6 [00:24<00:48, 12.18s/it]

 params: {'C': 0.1, 'max_iter': 30}
 acc_mean: 0.6722951667287673, acc_std: 0.012969961031479117
 prec_mean: 0.6486376920566854, prec_std: 0.018885182142049258
 rec_mean: 0.5388819315210119, rec_std: 0.02280325758668309
 f1_mean: 0.5884417695788493, f1_std: 0.018098985087038316


Hyperparameter search:  50%|██████████████████████████████▌                              | 3/6 [00:37<00:38, 12.73s/it]

 params: {'C': 0.1, 'max_iter': 40}
 acc_mean: 0.6722951667287673, acc_std: 0.012969961031479117
 prec_mean: 0.6486376920566854, prec_std: 0.018885182142049258
 rec_mean: 0.5388819315210119, rec_std: 0.02280325758668309
 f1_mean: 0.5884417695788493, f1_std: 0.018098985087038316


Hyperparameter search:  67%|████████████████████████████████████████▋                    | 4/6 [00:51<00:26, 13.01s/it]

 params: {'C': 0.5, 'max_iter': 20}
 acc_mean: 0.6722930208060206, acc_std: 0.01298025546880926
 prec_mean: 0.6486149890757782, prec_std: 0.01888484203465516
 rec_mean: 0.5389312534209085, rec_std: 0.02284683473603032
 f1_mean: 0.5884614422351592, f1_std: 0.018126956421275677


Hyperparameter search:  83%|██████████████████████████████████████████████████▊          | 5/6 [01:04<00:13, 13.11s/it]

 params: {'C': 0.5, 'max_iter': 30}
 acc_mean: 0.6722930208060206, acc_std: 0.01298025546880926
 prec_mean: 0.6486149890757782, prec_std: 0.01888484203465516
 rec_mean: 0.5389312534209085, rec_std: 0.02284683473603032
 f1_mean: 0.5884614422351592, f1_std: 0.018126956421275677


Hyperparameter search: 100%|█████████████████████████████████████████████████████████████| 6/6 [01:17<00:00, 12.99s/it]

 params: {'C': 0.5, 'max_iter': 40}
 acc_mean: 0.6722930208060206, acc_std: 0.01298025546880926
 prec_mean: 0.6486149890757782, prec_std: 0.01888484203465516
 rec_mean: 0.5389312534209085, rec_std: 0.02284683473603032
 f1_mean: 0.5884614422351592, f1_std: 0.018126956421275677



