In [1]:
import math
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import GlobalAveragePooling1D, Conv1D, GlobalMaxPooling1D, BatchNormalization, Dropout, Dense
from keras.layers import Dense
from keras.regularizers import L1L2
from tensorflow.keras.models import Model
from tensorflow.keras import layers, losses, regularizers
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score, KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, confusion_matrix, auc, confusion_matrix, roc_curve, auc
from sklearn.metrics import roc_auc_score, precision_recall_curve
from openpyxl import Workbook
import pandas as pd
import time
import os

In [None]:
def normalize(data):
    """Min-max scale ``data`` to the range [0, 1] along axis 0.

    Parameters
    ----------
    data : array-like
        Numeric samples. The minimum and maximum are taken along axis 0,
        i.e. per column for a 2-D array.

    Returns
    -------
    numpy.ndarray
        Float array of the same shape holding ``(data - min) / (max - min)``.
        A column whose values are all identical has zero range; it is mapped
        to 0 instead of producing NaN through a 0/0 division.
    """
    data = np.asarray(data, dtype=float)
    col_min = np.min(data, axis=0)
    col_range = np.max(data, axis=0) - col_min
    # A constant column would otherwise yield NaN, which silently propagates
    # into every later computation; dividing by 1 leaves it at exactly 0.
    safe_range = np.where(col_range == 0, 1.0, col_range)
    return (data - col_min) / safe_range

def create_model(filter_sz=64, kernel_sz1=3, kernel_sz2=3, dropout=0.4, input_shape=[None, 1]):
    """Build and compile the 1-D convolutional binary classifier.

    The network is three Conv1D layers (named 'L1', 'L2', 'L3'), global
    average pooling, two hidden Dense layers and a single sigmoid output
    named 'classification'. Dropout is applied after the first convolution
    and after the pooling layer.

    Parameters
    ----------
    filter_sz : int
        Number of filters in the first two convolution layers.
    kernel_sz1, kernel_sz2 : int
        Kernel sizes of the first and second convolution layers.
    dropout : float
        Rate used by both Dropout layers.
    input_shape : sequence
        Shape passed as ``input_shape`` to the first convolution layer.

    Returns
    -------
    A Sequential model compiled with binary cross-entropy loss, an Adam
    optimizer (learning rate 1e-4) and the 'accuracy' metric.
    """
    # Layers are instantiated in the same order they are stacked.
    layer_stack = [
        Conv1D(filters=filter_sz, kernel_size=kernel_sz1, strides=1, activation='relu',
               input_shape=input_shape, name='L1'),
        Dropout(dropout),
        Conv1D(filters=filter_sz, kernel_size=kernel_sz2, strides=1, activation='relu', name='L2'),
        # Kernel size 1: mixes the channels at each position, down to 32 filters.
        Conv1D(filters=32, kernel_size=1, strides=1, activation='relu', name='L3'),
        GlobalAveragePooling1D(),
        Dropout(dropout),
        Dense(64, activation='relu'),
        Dense(32, activation='relu'),
        Dense(1, activation='sigmoid', name='classification'),
    ]
    network = Sequential(layer_stack)

    optimizer = tf.keras.optimizers.Adam(learning_rate=1.e-04)
    network.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return network

def compute_metrics(y_true, y_pred):
    """Compute summary classification metrics from continuous scores.

    The decision threshold is the ROC threshold that maximises the geometric
    mean of the true-positive rate and (1 - false-positive rate); the
    confusion-matrix based metrics are evaluated at that threshold.

    Parameters
    ----------
    y_true : array-like
        Ground-truth binary labels. The 4-way unpacking of the confusion
        matrix below requires exactly two classes overall (0/1 in the
        calling code of this notebook).
    y_pred : array-like
        Predicted scores; larger means more likely positive.

    Returns
    -------
    tuple
        ``(accuracy, area, MCC, sensitivity, specificity)`` where ``area``
        is the area under the ROC curve. ``MCC`` is 0.0 when its denominator
        is zero (every sample assigned to one class at the chosen threshold,
        e.g. a model with constant output) instead of NaN.
    """
    y_pred = np.asarray(y_pred)  # allow plain lists for the threshold comparison
    fpr, tpr, thresholds = roc_curve(y_true, y_pred)
    area = auc(fpr, tpr)

    # Geometric mean of sensitivity and specificity at every candidate threshold.
    gmeans = np.sqrt(tpr * (1 - fpr))
    ix = np.argmax(gmeans)

    tn, fp, fn, tp = confusion_matrix(y_true, (y_pred >= thresholds[ix])).ravel()
    N = tn + fp + fn + tp
    S = (tp + fn) / N  # fraction of samples that are truly positive
    P = (tp + fp) / N  # fraction of samples predicted positive

    # MCC = (TP/N - S*P) / sqrt(P*S*(1-S)*(1-P)). The denominator vanishes
    # when P (or S) is 0 or 1; the original expression then evaluated 0/0,
    # returned NaN and emitted a RuntimeWarning. Report 0 (no correlation).
    mcc_denominator = np.sqrt(P * S * (1 - S) * (1 - P))
    if mcc_denominator > 0:
        MCC = ((tp / N) - S * P) / mcc_denominator
    else:
        MCC = 0.0

    accuracy = (tp + tn) / N
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return accuracy, area, MCC, sensitivity, specificity

def cross_val_cumulative_accuracy(X, y, folds=5):
    """Cross-validate the CNN and score the pooled out-of-fold predictions.

    A fresh model is trained on each stratified fold; the predictions for
    all held-out folds are concatenated and passed to ``compute_metrics``
    once, so the metrics are computed over the pooled predictions rather
    than averaged per fold.

    Parameters
    ----------
    X : numpy.ndarray
        2-D array (samples x features). Each sample is reshaped to
        (features, 1) before being fed to the network.
    y : numpy.ndarray
        Binary labels, indexable by the fold index arrays.
    folds : int
        Number of stratified folds (no shuffling).

    Returns
    -------
    tuple
        Whatever ``compute_metrics`` returns for the pooled predictions:
        ``(accuracy, auc, mcc, sensitivity, specificity)``.
    """
    kf = StratifiedKFold(n_splits=folds, shuffle=False)
    _, n = X.shape
    filter_size = 64

    # Kernel sizes shrink with the number of features so that the two
    # stacked convolutions still fit inside a short input sequence.
    if n > 12:
        kernel_sz1, kernel_sz2 = 5, 7
    elif n >= 5:
        kernel_sz1, kernel_sz2 = n - 2, 3
    else:
        kernel_sz1, kernel_sz2 = 1, 1

    fold_truths, fold_preds = [], []
    for train_index, test_index in kf.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        # Add the trailing channel axis expected by Conv1D.
        X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
        X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

        # This function is called many times during feature elimination and
        # builds a new model per fold; drop the state Keras keeps for earlier
        # models so memory use does not grow with every call.
        tf.keras.backend.clear_session()
        model = create_model(filter_sz=filter_size, kernel_sz1=kernel_sz1, kernel_sz2=kernel_sz2,
                             input_shape=[X_train.shape[1], 1])

        # NOTE(review): the held-out fold doubles as the early-stopping
        # validation set (with restore_best_weights), so the scores reported
        # on that fold are not independent of the stopping decision. Consider
        # carving a validation split out of the training fold instead.
        early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20,
                                                          restore_best_weights=True)
        model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=2000, batch_size=16,
                  shuffle=True, verbose=0, callbacks=[early_stopping])

        # verbose=0 keeps prediction as quiet as training; otherwise every
        # call may print a progress bar.
        fold_preds.append(model.predict(X_test, verbose=0).flatten())
        fold_truths.append(y_test)

    # Pool all folds; cast to float so dtypes match the previous
    # np.append-based accumulation, which started from an empty float array.
    y_true = np.concatenate(fold_truths).astype(float)
    y_pred = np.concatenate(fold_preds).astype(float)
    return compute_metrics(y_true, y_pred)

def _feature_label_line(idx, connectivity_list):
    """Return the indented console line describing feature ``idx``."""
    if connectivity_list:
        return f"  Index {idx}: {connectivity_list[idx]}"
    return f"  Index {idx}"


def recursive_feature_elimination(X, y, min_features=1, connectivity_list=None, output_dir=None, st_num=None):
    """Backward feature elimination driven by cross-validated accuracy.

    Starting from all columns of ``X``, the current feature set is scored
    with ``cross_val_cumulative_accuracy``. Then every feature is left out in
    turn and the one whose removal gives the highest accuracy is dropped.
    This repeats until ``min_features`` features remain. The feature set with
    the highest accuracy seen along the way is returned.

    Parameters
    ----------
    X : numpy.ndarray
        2-D array (samples x features). It is standardised with
        ``StandardScaler`` fitted on all samples before any cross-validation.
    y : numpy.ndarray
        Class labels forwarded to ``cross_val_cumulative_accuracy``.
    min_features : int
        Size of the smallest feature set to evaluate.
    connectivity_list : sequence of str, optional
        Human-readable name for each feature index, used in the console
        output and in the Excel sheet. Indices are used when omitted.
    output_dir : str, optional
        Directory for the Excel report; created if it does not exist.
    st_num : optional
        Identifier inserted into the report file name. The report is only
        saved when both ``output_dir`` and ``st_num`` are given.

    Returns
    -------
    tuple
        ``(best_feature_set, best_metrics)``: the list of column indices of
        the best-scoring set and its
        ``(accuracy, auc, mcc, sensitivity, specificity)`` tuple.

    Raises
    ------
    ValueError
        If ``min_features`` exceeds the number of columns in ``X``; no
        feature set could be evaluated in that case.
    """
    n_features = X.shape[1]
    if min_features > n_features:
        # Previously this surfaced much later as "max() arg is an empty sequence".
        raise ValueError(
            f"min_features ({min_features}) cannot exceed the number of features ({n_features})"
        )

    feature_indices = list(range(n_features))
    print(feature_indices)

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    metrics_list = []   # one metrics tuple per evaluated feature set
    feature_sets = []   # the matching lists of column indices

    # Create a new Excel workbook: one row per evaluated feature set, with a
    # fixed column per original feature (blank once the feature is removed).
    wb = Workbook()
    ws = wb.active
    ws.title = "Feature Sets"
    ws.append(["Set", "Accuracy", "AUC", "MCC", "Sensitivity", "Specificity"] + [f"Feature_{i}" for i in range(n_features)])

    set_counter = 1

    while len(feature_indices) >= min_features:
        accuracy, auc_score, mcc, sensitivity, specificity = cross_val_cumulative_accuracy(X_scaled[:, feature_indices], y)
        metrics_list.append((accuracy, auc_score, mcc, sensitivity, specificity))
        feature_sets.append(feature_indices.copy())

        print(f"\nCurrent number of features: {len(feature_indices)}")
        print(f"Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}, MCC: {mcc:.4f}, Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}")
        print("Current feature set:")
        for idx in feature_indices:
            print(_feature_label_line(idx, connectivity_list))

        # Write to Excel; the six metric columns come first, hence the offset.
        excel_row = [set_counter, accuracy, auc_score, mcc, sensitivity, specificity] + [''] * n_features
        for idx in feature_indices:
            excel_row[idx + 6] = connectivity_list[idx] if connectivity_list else idx
        ws.append(excel_row)
        set_counter += 1

        if len(feature_indices) == min_features:
            break

        # Leave each feature out in turn and record the resulting accuracy.
        feature_importances = []
        for candidate in feature_indices:
            remaining = [f for f in feature_indices if f != candidate]
            if remaining:  # Only compute metrics if there's at least one feature left
                candidate_metrics = cross_val_cumulative_accuracy(X_scaled[:, remaining], y)
                feature_importances.append((candidate, candidate_metrics[0]))  # Using accuracy for feature importance

        if not feature_importances:
            break  # Stop if we can't remove any more features

        # The feature whose removal leaves the highest accuracy contributes
        # the least, so it is the one to drop (first one wins on ties).
        worst_feature, _ = max(feature_importances, key=lambda item: item[1])
        feature_indices.remove(worst_feature)

        if connectivity_list:
            print(f"Removed feature {worst_feature}: {connectivity_list[worst_feature]}")
        else:
            print(f"Removed feature {worst_feature}")

    # Find the point where accuracy is highest
    best_metrics_index = max(range(len(metrics_list)), key=lambda i: metrics_list[i][0])  # Using accuracy for best set
    best_feature_set = feature_sets[best_metrics_index]
    best_metrics = metrics_list[best_metrics_index]

    print("\nMetrics for each feature set:")
    for i, (metrics, feat_set) in enumerate(zip(metrics_list, feature_sets)):
        print(f"Features: {len(feat_set)}, Accuracy: {metrics[0]:.4f}, AUC: {metrics[1]:.4f}, MCC: {metrics[2]:.4f}, Sensitivity: {metrics[3]:.4f}, Specificity: {metrics[4]:.4f}")
        if i == best_metrics_index:
            print("^ Best performing set")
            print("Best feature set:")
            for idx in feat_set:
                print(_feature_label_line(idx, connectivity_list))

    print("\n" + "="*50)
    print(f"Best Feature Set: {best_feature_set}")
    if connectivity_list:
        print("Best Feature Set Connectivities:")
        for idx in best_feature_set:
            print(f"  Index {idx}: {connectivity_list[idx]}")
    print(f"Best Metrics - Accuracy: {best_metrics[0]:.4f}, AUC: {best_metrics[1]:.4f}, MCC: {best_metrics[2]:.4f}, Sensitivity: {best_metrics[3]:.4f}, Specificity: {best_metrics[4]:.4f}")
    print("="*50)

    # Save Excel file
    if output_dir and st_num is not None:
        # Make sure the target directory exists so the long search is not
        # lost to a failed save at the very end.
        os.makedirs(output_dir, exist_ok=True)
        excel_filename = os.path.join(output_dir, f'P_{st_num}_feature_sets_metrics.xlsx')
        wb.save(excel_filename)
        print(f"Excel file saved: {excel_filename}")

    return best_feature_set, best_metrics

# Directory and file details
# Names of the directed connectivity features; recursive_feature_elimination
# looks up entry i for feature column i.
# NOTE(review): assumed to be in the same order as the CSV feature columns — confirm.
connectivity = ['LPFC-->RPFC', 'LPFC-->LPMC', 'LPFC-->RPMC', 'LPFC-->SMA', 'RPFC-->LPFC', 'RPFC-->LPMC', 'RPFC-->RPMC', 'RPFC-->SMA', 'LPMC-->LPFC', 'LPMC-->RPFC',
                'LPMC-->RPMC', 'LPMC-->SMA', 'RPMC-->LPFC', 'RPMC-->RPFC', 'RPMC-->LPMC', 'RPMC-->SMA', 'SMA-->LPFC', 'SMA-->RPFC', 'SMA-->LPMC', 'SMA-->RPMC']
Dayx = 'D1'
Dayy = 'Retention'
subtask = ['ST1','ST2','ST3','ST4','ST5','ST6','ST7','ST8','ST9','ST10','ST11','ST12','ST13']

# TODO: set these three paths before running. They were left blank in the
# original cell (a syntax error); '.' is only a placeholder so the cell parses.
dir1 = '.'              # directory of connectivities data set1
dir2 = '.'              # directory of connectivities data set2
output_directory = '.'  # directory to save the outputs

# Loop over st_num (one full feature-elimination run per subtask)
for st_num in range(len(subtask)):
    # NOTE(review): subtask[st_num] already starts with 'ST', so filename1 is
    # e.g. 'D1_LC_STST1.csv' — confirm this matches the files on disk.
    filename1 = Dayx+'_LC_ST'+subtask[st_num]+'.csv'
    filename2 = 'Reten'+subtask[st_num]+'.csv'
    data1 = pd.read_csv(os.path.join(dir1,filename1))
    data1 = data1.drop('DaySubjTrial', axis = 1).values
    data1[:,-1] = 0  # convert the class label: data set 1 is class 0
    data2 = pd.read_csv(os.path.join(dir2,filename2)).values
    data2[:,-1] = 1  # convert the class label: data set 2 is class 1
    data = np.concatenate((data1, data2))  # join two datasets
    # Delete self-causal columns from the datasets (presumably the diagonal
    # of a row-major 5x5 connectivity matrix — confirm against the CSV layout).
    data = np.delete(data,[0,6,12,18,24],1)

    # Last column is the class label; everything before it is a feature.
    m,n = data.shape
    X = normalize(data[:,0:n-1])
    print(X.shape)
    y = data[:,n-1]
    print(y.shape)

    # When calling the function, pass the connectivity list, output directory, and st_num
    # (st_num + 1 so the report files are numbered from 1).
    best_features, best_metrics = recursive_feature_elimination(X, y, min_features=1, connectivity_list=connectivity, output_dir=output_directory, st_num=st_num+1)
    # best_features is in column-index order, not ranked by importance.
    print("\nSelected features (column indices of the best-scoring set):", best_features)
    print(f"Optimal number of features: {len(best_features)}")
    print(f"Best metrics - Accuracy: {best_metrics[0]:.4f}, AUC: {best_metrics[1]:.4f}, MCC: {best_metrics[2]:.4f}, Sensitivity: {best_metrics[3]:.4f}, Specificity: {best_metrics[4]:.4f}")
