### this script is trying to decode TMS conditions (sham, cTBS) based on functional connectivity patterns

* functional connectivity patterns are derived from AAL atlas
* FC is calculated between 170 AAL ROIs and 4 TMS-used masks
* classification models to try: logistic regression, SVM, neural nets


In [15]:
import numpy as np
from scipy.io import loadmat
import nibabel as nib
import pandas as pd
import os
from pathlib import Path
import matplotlib.pyplot as plt

In [16]:
from sklearn import datasets, svm
from sklearn.feature_selection import SelectPercentile, f_classif
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.inspection import DecisionBoundaryDisplay

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GroupKFold

In [17]:
# Load the Excel file into a DataFrame

SubInfo = pd.read_excel('/Users/liuq13/NODEAP_scripts/ProcessedData/SubConds.xlsx')
Subs = SubInfo[SubInfo['Include'] == 1]['SubID']


In [18]:
SubInfo

Unnamed: 0,SubID,Gender,Birth.Year,StimLoc,StimOrder,Odors,StartOdor,Age,Sex,Include,btS1,btS2,btS3,btS1S2,btS2S3
0,NODEAP_06,Male,1992,Posterior,321,"Chocolate, Garlic",Sweet,30.83,M,1,1,1,1,13,15
1,NODEAP_07,Female,1991,Anterior,312,"Pineapple, Pizza",Savory,32.01,F,1,1,1,1,20,13
2,NODEAP_08,Male,1996,Posterior,132,"Chocolate, Garlic",Sweet,26.19,M,1,1,1,1,12,14
3,NODEAP_09,Female,2000,Posterior,321,"Gingerbread, Garlic",Savory,22.32,F,1,1,1,1,14,27
4,NODEAP_10,Male,1999,Posterior,321,"Chocolate, Garlic",Savory,23.62,M,1,1,1,1,13,12
5,NODEAP_12,Female,2002,Anterior,123,"Pineapple, Potato",Savory,20.89,F,1,1,1,1,19,27
6,NODEAP_13,Female,1990,Anterior,231,"Yellow, Garlic",Sweet,33.0,F,1,1,1,1,11,14
7,NODEAP_15,Female,1996,Posterior,231,"Gingerbread, Pizza",Sweet,26.62,F,1,1,1,1,13,21
8,NODEAP_16,Female,2000,Posterior,231,"Chocolate, Potato",Savory,22.37,F,1,1,1,1,13,34
9,NODEAP_17,Female,1996,Posterior,213,"Gingerbread, Garlic",Savory,26.39,F,1,1,1,1,13,27


In [19]:
SubInfo.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 48 entries, 0 to 47
Data columns (total 15 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   SubID       48 non-null     object 
 1   Gender      48 non-null     object 
 2   Birth.Year  48 non-null     int64  
 3   StimLoc     48 non-null     object 
 4   StimOrder   48 non-null     int64  
 5   Odors       48 non-null     object 
 6   StartOdor   48 non-null     object 
 7   Age         48 non-null     float64
 8   Sex         48 non-null     object 
 9   Include     48 non-null     int64  
 10  btS1        48 non-null     int64  
 11  btS2        48 non-null     int64  
 12  btS3        48 non-null     int64  
 13  btS1S2      48 non-null     int64  
 14  btS2S3      48 non-null     int64  
dtypes: float64(1), int64(8), object(6)
memory usage: 5.8+ KB


In [20]:
TMS_types = []

# Loop through each row and assign TMS types based on 'StimOrder'
for i, row in SubInfo.iterrows():
    tmp_order = row['StimOrder']
    
    # Assign TMS types based on the 'StimOrder' value
    if tmp_order == 123:
        TMS_types.append(['N', 'C', 'S', 'S', 'C', 'S', 'S'])
    elif tmp_order == 132:
        TMS_types.append(['N', 'C', 'S', 'S', 'S', 'S', 'C'])
    elif tmp_order == 213:
        TMS_types.append(['N', 'S', 'C', 'C', 'S', 'S', 'S'])
    elif tmp_order == 231:
        TMS_types.append(['N', 'S', 'C', 'S', 'S', 'C', 'S'])
    elif tmp_order == 312:
        TMS_types.append(['N', 'S', 'S', 'C', 'S', 'S', 'C'])
    elif tmp_order == 321:
        TMS_types.append(['N', 'S', 'S', 'S', 'C', 'C', 'S'])

In [21]:
TMS_types

[['N', 'S', 'S', 'S', 'C', 'C', 'S'],
 ['N', 'S', 'S', 'C', 'S', 'S', 'C'],
 ['N', 'C', 'S', 'S', 'S', 'S', 'C'],
 ['N', 'S', 'S', 'S', 'C', 'C', 'S'],
 ['N', 'S', 'S', 'S', 'C', 'C', 'S'],
 ['N', 'C', 'S', 'S', 'C', 'S', 'S'],
 ['N', 'S', 'C', 'S', 'S', 'C', 'S'],
 ['N', 'S', 'C', 'S', 'S', 'C', 'S'],
 ['N', 'S', 'C', 'S', 'S', 'C', 'S'],
 ['N', 'S', 'C', 'C', 'S', 'S', 'S'],
 ['N', 'S', 'S', 'S', 'C', 'C', 'S'],
 ['N', 'S', 'C', 'C', 'S', 'S', 'S'],
 ['N', 'S', 'C', 'C', 'S', 'S', 'S'],
 ['N', 'C', 'S', 'S', 'S', 'S', 'C'],
 ['N', 'C', 'S', 'S', 'C', 'S', 'S'],
 ['N', 'C', 'S', 'S', 'C', 'S', 'S'],
 ['N', 'S', 'S', 'S', 'C', 'C', 'S'],
 ['N', 'S', 'C', 'C', 'S', 'S', 'S'],
 ['N', 'C', 'S', 'S', 'C', 'S', 'S'],
 ['N', 'C', 'S', 'S', 'C', 'S', 'S'],
 ['N', 'S', 'S', 'C', 'S', 'S', 'C'],
 ['N', 'S', 'C', 'C', 'S', 'S', 'S'],
 ['N', 'C', 'S', 'S', 'C', 'S', 'S'],
 ['N', 'S', 'S', 'S', 'C', 'C', 'S'],
 ['N', 'S', 'C', 'S', 'S', 'C', 'S'],
 ['N', 'S', 'S', 'C', 'S', 'S', 'C'],
 ['N', 'S', 

In [22]:
base_nifti_folder = '/Volumes/X9Pro/NODEAP/FuncConn_AAL_ROIs_PAID'
sessions = ['D0', 'S1D1', 'S1D2', 'S2D1', 'S2D2', 'S3D1', 'S3D2']
nifti_paths_by_tms_type = {'N': [], 'C': [], 'S': []}

In [25]:
all_corr_data = []
all_tms_type = []
all_subject_id = []

for i, subject_id in enumerate(Subs):
    tms_types = TMS_types[i]  
    for j, session in enumerate(sessions):
        mat_name = 'conn_matrix_' + session + '.mat'
        mat_file = os.path.join(base_nifti_folder, subject_id, mat_name)
        if os.path.exists(mat_file):
            matdat = loadmat(mat_file)
            dat_corr = matdat['correlation_matrix']  
            dat_corr_use = dat_corr[:,1] # use corr data from mask 1
            tms_type = tms_types[j]  
            all_tms_type.append(tms_type)
            all_corr_data.append(dat_corr_use)
            all_subject_id.append(i)
        else:
            print(f"matfile not found: {mat_file}")

matfile not found: /Volumes/X9Pro/NODEAP/FuncConn_AAL_ROIs_PAID/NODEAP_30/conn_matrix_S3D2.mat
matfile not found: /Volumes/X9Pro/NODEAP/FuncConn_AAL_ROIs_PAID/NODEAP_44/conn_matrix_S1D1.mat
matfile not found: /Volumes/X9Pro/NODEAP/FuncConn_AAL_ROIs_PAID/NODEAP_83/conn_matrix_S3D1.mat
matfile not found: /Volumes/X9Pro/NODEAP/FuncConn_AAL_ROIs_PAID/NODEAP_87/conn_matrix_D0.mat
matfile not found: /Volumes/X9Pro/NODEAP/FuncConn_AAL_ROIs_PAID/NODEAP_88/conn_matrix_D0.mat


In [26]:
# Concatenate all vectors as columns in a single matrix
concatenated_matrix = np.vstack(all_corr_data)
concatenated_matrix = concatenated_matrix[:, ~np.isnan(concatenated_matrix).any(axis=0)]
print(concatenated_matrix.shape)

(331, 163)


In [27]:
char_to_num = {'C': 1, 'N': 2, 'S': 0}
all_tms_type = [char_to_num[char] for char in all_tms_type]
all_tms_type = np.array(all_tms_type)
all_tms_type.shape

(331,)

In [28]:
all_subject_id = np.array(all_subject_id)
all_subject_id.shape

(331,)

In [29]:
mask = (all_tms_type == 0) | (all_tms_type == 1)  # only using C and S data
concatenated_matrix = concatenated_matrix[mask]
all_tms_type = all_tms_type[mask]
all_subject_id = all_subject_id[mask]

In [30]:
# training-test split
X_train, X_test, y_train, y_test = train_test_split(concatenated_matrix, all_tms_type, test_size=0.2, random_state=4)
print("X_train.shape", X_train.shape, "y_train.shape", y_train.shape)
print("X_test.shape", X_test.shape, "y_test.shape", y_test.shape)

X_train.shape (228, 163) y_train.shape (228,)
X_test.shape (57, 163) y_test.shape (57,)


In [31]:
# Initialize logistic regression with L2 regularization (default)
model = LogisticRegression(penalty='l2', 
                           C=0.001,  # smaller means stronger regularization
                           solver='lbfgs', 
                           verbose=0,
                           max_iter=1000)

In [32]:
# Initialize your GroupKFold
group_kfold = GroupKFold(n_splits=len(set(all_subject_id)))  # Set the number of splits to the number of unique persons

In [33]:
for train_index, test_index in group_kfold.split(concatenated_matrix, all_tms_type, groups=all_subject_id):

    X_train, X_test = concatenated_matrix[train_index], concatenated_matrix[test_index]
    y_train, y_test = all_tms_type[train_index], all_tms_type[test_index]

    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Test accuracy for one subject: {accuracy}")
    print(classification_report(y_test, y_pred))

Test accuracy for one subject: 0.6666666666666666
              precision    recall  f1-score   support

           0       0.67      1.00      0.80         4
           1       0.00      0.00      0.00         2

    accuracy                           0.67         6
   macro avg       0.33      0.50      0.40         6
weighted avg       0.44      0.67      0.53         6

Test accuracy for one subject: 0.6666666666666666
              precision    recall  f1-score   support

           0       0.67      1.00      0.80         4
           1       0.00      0.00      0.00         2

    accuracy                           0.67         6
   macro avg       0.33      0.50      0.40         6
weighted avg       0.44      0.67      0.53         6

Test accuracy for one subject: 0.6666666666666666
              precision    recall  f1-score   support

           0       0.67      1.00      0.80         4
           1       0.00      0.00      0.00         2

    accuracy                    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

In [34]:
model = SVC()
for train_index, test_index in group_kfold.split(concatenated_matrix, all_tms_type, groups=all_subject_id):
    X_train, X_test = concatenated_matrix[train_index], concatenated_matrix[test_index]
    y_train, y_test = all_tms_type[train_index], all_tms_type[test_index]
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Test accuracy for one subject: {accuracy}")

Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
Test accuracy for one subject: 0.6666666666666666
