## Deep neural network for ASD classification using resting-state fMRI

This notebook evaluates a deep neural network for ASD diagnosis using functional time series from brain regions of interest (ROIs). The resting-state fMRI data come from the ABIDE dataset and were preprocessed by the **Preprocessed Connectome Project (PCP)** using four pipelines; they cover 1100 subjects from multiple international sites.

### Configure data loading

This cell defines the parameters needed to load the neuroimaging data: the preprocessing `pipeline` and the list of ROI `atlases` to extract. It also lists all neuroimaging sites available in the dataset (`all_sites`), selects those to include in the analysis through the `sites` parameter, and names the held-out `test_site`.

In [12]:
# Preprocessing pipeline
pipeline = 'cpac'  

# List of ROI brain atlases
atlases = ['rois_cc200', 'rois_aal']

# List of all available neuroimaging sites in the dataset
all_sites = [
    'caltech', 'cmu', 'kki', 'leuven_1', 'leuven_2', 'max_mun', 'nyu', 
    'ohsu', 'olin', 'pitt', 'sbl', 'sdsu', 'stanford', 'trinity', 
    'ucla_1', 'ucla_2', 'um_1', 'um_2', 'usm', 'yale'
]

# Sites included in the analysis
sites = all_sites

# Testing site
test_site = 'yale'

### Data loading function

The `load_atlas_data(pipeline, atlases, sites)` function retrieves subject time series and diagnostic labels for each neuroimaging site in `sites`. It reads phenotypic information from CSV files, then loads the time series of each subject for every atlas in `atlases`. Subjects with a missing data file or NaN values are reported and skipped to ensure data integrity before analysis.

In [13]:
import os
import csv
import numpy as np


def load_atlas_data(pipeline, atlases, sites):
    """
    Loads time series and diagnostic labels from neuroimaging data files for the specified sites.
    
    Parameters:
        pipeline (str): Preprocessing pipeline used for the data.
        atlases (list of str): Atlases defining regions of interest; labels are derived from the first one.
        sites (list of str): List of site names to load data from.

    Returns:
        atlas_time_series (list of dict): One dictionary per atlas, mapping each site to its list of time series.
        atlas_labels (dict): Maps each site to an array of diagnostic labels.
    """

    atlas_time_series = []  # One {site: time series} dictionary per atlas
    atlas_labels = {}  # Dictionary to store labels for each site
    
    
    for site in sites:
        # Define path for phenotypic data for the current site
        phenotypic_path = f"data/phenotypic/{site}/phenotypic.csv"

        try:
            with open(phenotypic_path, 'r') as file:
                reader = csv.DictReader(file)
                site_labels = []  # List to store labels for each subject at the site

                for row in reader:
                    file_id = row['file_id']  # Unique subject identifier
                    dx_group = row['dx_group']  # Diagnostic group (ASD=1, Control=0)

                    # Define path for the time series data file
                    data_file_path = os.path.join(f"data/{pipeline}/{atlases[0]}/{site}", f"{file_id}_{atlases[0]}.1D")

                    # Check if the data file exists
                    if not os.path.exists(data_file_path):
                        print(f"File Not Found Error: Data file not found at path {data_file_path}")
                        continue
                    
                    data = np.loadtxt(data_file_path)

                    # Check for NaN values and keep the label only for valid subjects
                    if np.isnan(data).any():
                        print(f"Value Error: NaN value found for subject {file_id}")
                    else:            
                        site_labels.append(1 if dx_group == '1' else 0)  # Assign 1 for ASD, 0 for control

                # Store loaded data for the current site in the dictionaries
                atlas_labels[site] = np.array(site_labels)
                print(f"Loaded labels {len(site_labels)} subjects from site {site}.")
                
        except FileNotFoundError:
            print(f"File Not Found Error: Phenotypic data not found for site {site}")

    for atlas in atlases:
        atlas_data = {}
        for site in sites:
            # Define path for phenotypic data for the current site
            phenotypic_path = f"data/phenotypic/{site}/phenotypic.csv"

            try:
                with open(phenotypic_path, 'r') as file:
                    reader = csv.DictReader(file)
                    site_time_series = []  # List to store time series for each subject at the site
                
                    for row in reader:
                        file_id = row['file_id']  # Unique subject identifier
                    
                        # Define path for the time series data file
                        data_file_path = os.path.join(f"data/{pipeline}/{atlas}/{site}", f"{file_id}_{atlas}.1D")

                        # Check if the data file exists
                        if not os.path.exists(data_file_path):
                            print(f"File Not Found Error: Data file not found at path {data_file_path}")
                            continue
                        
                        data = np.loadtxt(data_file_path)

                        # Check for NaN values and add time series to the site list
                        if np.isnan(data).any():
                            print(f"Value Error: NaN value found for subject {file_id}")
                        else:
                            site_time_series.append(data)
                    
                    # Store loaded data for the current site in the dictionaries
                    atlas_data[site] = site_time_series
                    print(f"Loaded {len(site_time_series)} subjects from site {site}.")
                    
            except FileNotFoundError:
                print(f"File Not Found Error: Phenotypic data not found for site {site}")
    
        atlas_time_series.append(atlas_data)
    
    return atlas_time_series, atlas_labels
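
The integrity checks inside `load_atlas_data` can be isolated into a small helper, which makes them easier to test on their own. The sketch below is only an illustration: `load_subject_time_series` is a hypothetical name that does not appear in the notebook, and the demo files are written to a temporary directory rather than the real `data/` tree.

```python
import os
import tempfile

import numpy as np


def load_subject_time_series(data_file_path):
    """Return the time series stored at data_file_path, or None if it is unusable.

    Hypothetical helper mirroring the checks in load_atlas_data:
    a missing file or any NaN value causes the subject to be skipped.
    """
    if not os.path.exists(data_file_path):
        print(f"File Not Found Error: Data file not found at path {data_file_path}")
        return None

    data = np.loadtxt(data_file_path)
    if np.isnan(data).any():
        print(f"Value Error: NaN value found in {data_file_path}")
        return None
    return data


# Demo on synthetic files (time points x regions)
with tempfile.TemporaryDirectory() as tmp_dir:
    good_path = os.path.join(tmp_dir, "good.1D")
    bad_path = os.path.join(tmp_dir, "bad.1D")
    np.savetxt(good_path, np.arange(12, dtype=float).reshape(4, 3))
    np.savetxt(bad_path, np.array([[1.0, np.nan], [2.0, 3.0]]))

    good = load_subject_time_series(good_path)      # valid file
    bad = load_subject_time_series(bad_path)        # contains a NaN
    missing = load_subject_time_series(os.path.join(tmp_dir, "missing.1D"))
```

Returning `None` for unusable subjects lets the caller apply the same skip decision to both the label and the time series, which keeps the two in step.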

Load data to be used in the analysis based on specified parameters.

In [14]:
atlas_time_series, atlas_labels = load_atlas_data(pipeline, atlases, sites)

Loaded labels 38 subjects from site caltech.
Loaded labels 27 subjects from site cmu.
Loaded labels 55 subjects from site kki.
Loaded labels 29 subjects from site leuven_1.
Loaded labels 35 subjects from site leuven_2.
Loaded labels 57 subjects from site max_mun.
Loaded labels 184 subjects from site nyu.
Loaded labels 28 subjects from site ohsu.
Loaded labels 36 subjects from site olin.
Loaded labels 57 subjects from site pitt.
Loaded labels 30 subjects from site sbl.
Loaded labels 36 subjects from site sdsu.
Loaded labels 40 subjects from site stanford.
Loaded labels 49 subjects from site trinity.
Loaded labels 73 subjects from site ucla_1.
Loaded labels 26 subjects from site ucla_2.
Loaded labels 108 subjects from site um_1.
Loaded labels 35 subjects from site um_2.
Loaded labels 101 subjects from site usm.
Loaded labels 56 subjects from site yale.
Loaded 38 subjects from site caltech.
Loaded 27 subjects from site cmu.
Loaded 55 subjects from site kki.
Loaded 29 subjects from site le

### Tangent space embedding

Tangent space embedding translates connectivity matrices derived from fMRI data into a form that is compatible with Euclidean machine learning techniques while preserving the important geometric properties of the data. It is particularly useful when analyzing covariance or correlation matrices in tasks involving brain connectivity and the classification of neurological conditions.

The workflow used in this notebook involves two main steps:

**Estimate the reference tangent space**: Calculate the tangent space projection based on the mean covariance matrix of a training population. This establishes the "reference space" onto which individual test subjects can later be projected.

**Project subjects onto the reference space**: Using the precomputed reference tangent space from the population, project the covariance matrix of each new subject onto this space. This yields a tangent space connectivity matrix for the subject that is aligned with those of the population.
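
The two steps can be illustrated with a NumPy-only sketch of the underlying math. It is not nilearn's implementation: the function names are hypothetical, the data are synthetic, and the geometric mean is estimated with a simple fixed-point iteration chosen for brevity.

```python
import numpy as np


def _apply_to_eigenvalues(matrix, func):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    eigenvalues, eigenvectors = np.linalg.eigh(matrix)
    return (eigenvectors * func(eigenvalues)) @ eigenvectors.T


def geometric_mean(covariances, n_iter=50, tol=1e-10):
    """Step 1: iteratively estimate the geometric mean of SPD matrices."""
    mean = np.mean(covariances, axis=0)  # start from the arithmetic mean
    for _ in range(n_iter):
        sqrt = _apply_to_eigenvalues(mean, np.sqrt)
        inv_sqrt = _apply_to_eigenvalues(mean, lambda v: 1.0 / np.sqrt(v))
        # Average of all matrices expressed in the tangent space at `mean`
        tangent_mean = np.mean(
            [_apply_to_eigenvalues(inv_sqrt @ c @ inv_sqrt, np.log) for c in covariances],
            axis=0,
        )
        mean = sqrt @ _apply_to_eigenvalues(tangent_mean, np.exp) @ sqrt
        if np.linalg.norm(tangent_mean) < tol:
            break
    return mean


def project_to_tangent_space(covariance, reference):
    """Step 2: express one covariance matrix relative to the reference."""
    inv_sqrt = _apply_to_eigenvalues(reference, lambda v: 1.0 / np.sqrt(v))
    return _apply_to_eigenvalues(inv_sqrt @ covariance @ inv_sqrt, np.log)


rng = np.random.default_rng(0)
# Synthetic population: 20 subjects, 100 time points, 5 regions
population = [rng.standard_normal((100, 5)) for _ in range(20)]
covariances = np.array([np.cov(ts, rowvar=False) for ts in population])

reference = geometric_mean(covariances)
new_subject = np.cov(rng.standard_normal((100, 5)), rowvar=False)
tangent_matrix = project_to_tangent_space(new_subject, reference)

# At the geometric mean, the population's tangent matrices average to ~zero
population_tangent_mean = np.mean(
    [project_to_tangent_space(c, reference) for c in covariances], axis=0
)
```

The resulting `tangent_matrix` is symmetric, so only its upper triangle carries independent information; this is why a vectorized representation is usually passed to the classifier.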

#### Create the population data

To maintain a separate testing set, we exclude the `test_site` data from the population data used to estimate the reference tangent space.

In [18]:
def create_population(time_series_data):
    # Initialize an empty list for the population data 
    population_data = []

    # Loop through the time series data
    for item in time_series_data:
        # Extend each item
        population_data.extend(item)

    print(f"Total subjects in population data: {len(population_data)}")
    return population_data
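
`create_population` simply concatenates whatever it is given, so the test site has to be left out by the caller. A minimal sketch of that filtering step, assuming a `{site: list of time series}` dictionary like the ones stored in `atlas_time_series`; `collect_training_population` and the toy data are hypothetical:

```python
def collect_training_population(site_time_series, sites, test_site):
    """Gather the time series of every site except the held-out test site."""
    population = []
    for site in sites:
        if site == test_site:
            continue  # keep the test site out of the reference estimation
        population.extend(site_time_series.get(site, []))
    return population


# Toy stand-in for one atlas dictionary: strings replace the real time series arrays
toy_time_series = {
    "nyu": ["nyu_s1", "nyu_s2", "nyu_s3"],
    "pitt": ["pitt_s1", "pitt_s2"],
    "yale": ["yale_s1", "yale_s2"],
}
population = collect_training_population(toy_time_series, ["nyu", "pitt", "yale"], "yale")
print(f"Total subjects in population data: {len(population)}")  # 5 subjects
```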

#### Estimate the reference tangent space

The reference tangent space is computed from the mean covariance matrix of a training population dataset, so it reflects the average connectivity patterns across that population.

#### Function for estimating tangent space functional connectivity

The `estimate_tangent_space(data)` function calculates the tangent space based on the geometric mean covariance matrix of a training population dataset and returns the fitted estimator.

The tangent space representation of functional connectivity is a powerful tool for analyzing brain connectivity. It allows individual functional connectivity matrices to be compared in a standardized space, computed relative to a group average matrix.

In [19]:
from nilearn.connectome import ConnectivityMeasure

def estimate_tangent_space(data):
    """
    Estimate the tangent space functional connectivity.

    Parameters:
    -----------
    data : list or ndarray
        List or array of time series data for the training population, where each entry corresponds 
        to a subject's time series (time points x regions).

    Returns:
    --------
    ConnectivityMeasure
        Fitted ConnectivityMeasure object configured for tangent space transformation.
    """
    # Instantiate ConnectivityMeasure for tangent space, vectorizing and discarding the diagonal
    connectivity_measure = ConnectivityMeasure(kind='tangent', vectorize=True, discard_diagonal=True)

    # Fit the measure on the population data to establish a reference tangent space
    connectivity_measure.fit(data)

    return connectivity_measure


### Function to create deep neural network models

The `build_model(input_shape)` function creates DNN models with the following architecture:

Input Layer: Takes in the number of features from the input data, followed by dropout (rate 0.2).

Dense Layer 1: 128 neurons, ReLU activation, with L2 regularization (0.001) to reduce overfitting, followed by dropout (rate 0.4).

Dense Layer 2: 64 neurons, ReLU activation, L2 regularization (0.001), followed by dropout (rate 0.4).

Output Layer: A single neuron with sigmoid activation for binary classification.

The model is compiled with the Adam optimizer and binary cross-entropy loss, as we aim to classify subjects into two classes. We also include accuracy as a performance metric to track model performance during training and evaluation.

In [20]:
from keras import layers, models, regularizers

# Define the deep neural network model architecture
def build_model(input_shape):
    """
    Builds and compiles a deep neural network model for binary classification.

    Parameters:
    - input_shape: int, the shape of the input layer, matching the number of features in the dataset

    Returns:
    - model: compiled Keras Sequential model ready for training
    """
    
    model = models.Sequential()

    # Input layer (the shape must be a tuple, so the integer feature count is wrapped)
    model.add(layers.Input(shape=(input_shape,)))
    model.add(layers.Dropout(0.2))

    # Hidden layers
    model.add(layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
    model.add(layers.Dropout(0.4))

    model.add(layers.Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
    model.add(layers.Dropout(0.4))

    # Output layer for binary classification (ASD vs. control)
    model.add(layers.Dense(1, activation='sigmoid'))

    # Compile the model with Adam optimizer and binary cross-entropy loss
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model
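
To get a feel for the size of this network, the number of trainable parameters can be worked out by hand. The sketch below assumes the 26570 input features reported later in the notebook output and the Dense layer sizes used in the code of `build_model`; `dense_parameter_count` is a hypothetical helper, and dropout layers are ignored because they have no weights.

```python
def dense_parameter_count(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units


n_features = 26570          # feature count reported for the combined atlases
layer_units = [128, 64, 1]  # Dense layers defined in build_model

total = 0
n_inputs = n_features
for n_units in layer_units:
    count = dense_parameter_count(n_inputs, n_units)
    print(f"Dense({n_units}): {count:,} parameters")
    total += count
    n_inputs = n_units

print(f"Total trainable parameters: {total:,}")  # 3,409,409
```

Almost all of these weights sit in the first layer, while each split trains on fewer than a thousand subjects, which is why dropout and L2 regularization matter for this model.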

### Training Callbacks

To optimize training, we set up three callbacks:

**EarlyStopping:** Stops training if validation loss doesn't improve for 10 epochs, preventing overfitting and restoring the best weights.

**ReduceLROnPlateau:** Reduces the learning rate by 50% when validation loss plateaus for 5 epochs, ensuring gradual and effective model convergence.

**ModelCheckpoint:** Saves the model with the best validation loss to 'best_model.keras', allowing easy access to the optimal version of the model.

In [21]:
from keras import callbacks

# Early stopping to prevent overfitting by stopping training when validation loss stops improving
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',             # Monitor validation loss for early stopping
    patience=10,                    # Stop training if val_loss does not improve for 10 epochs
    restore_best_weights=True       # Restore the model weights from the epoch with the lowest val_loss
)

# Reduce learning rate when the validation loss plateaus
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',             # Monitor validation loss for learning rate reduction
    factor=0.5,                     # Reduce learning rate by a factor of 0.5
    patience=5,                     # Trigger after 5 epochs without improvement in val_loss
    min_lr=1e-5                     # Set a floor on the learning rate to avoid overly small values
)

# Save the best model based on validation loss
checkpoint = callbacks.ModelCheckpoint(
    'best_model.keras',             # Filename for the best model
    monitor='val_loss',             # Monitor validation loss for checkpoint saving
    save_best_only=True             # Only save the model when it achieves a new best val_loss
)

# Callbacks list passed to the model
callbacks_list = [early_stopping, reduce_lr, checkpoint]
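
The interaction between the two patience values can be explored with a small pure-Python simulation. This is a simplified sketch, not Keras code: `simulate_callbacks` is a hypothetical function, it ignores options such as `min_delta` and `cooldown`, and the starting learning rate of 1e-3 is an assumption (the default for Adam in Keras).

```python
def simulate_callbacks(val_losses, lr=1e-3, factor=0.5, lr_patience=5,
                       min_lr=1e-5, stop_patience=10):
    """Return (epoch, learning rate) pairs until early stopping would trigger."""
    best = float("inf")
    lr_wait = stop_wait = 0
    history = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # Improvement: remember it and reset both counters
            best = loss
            lr_wait = stop_wait = 0
        else:
            lr_wait += 1
            stop_wait += 1
            if lr_wait >= lr_patience:
                lr = max(lr * factor, min_lr)  # halve, but never go below the floor
                lr_wait = 0
        history.append((epoch, lr))
        if stop_wait >= stop_patience:
            break  # early stopping
    return history


# Validation loss improves for 3 epochs, then stalls
history = simulate_callbacks([1.0, 0.9, 0.8] + [0.85] * 20)
for epoch, lr in history:
    print(f"epoch {epoch:2d}  lr={lr:.6f}")
```

With these settings the learning rate is halved at most twice during a stalled stretch before training stops, so the `min_lr` floor is only reached if the loss keeps improving intermittently.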

### Function to adjust class balance

The `adjust_class_balance(indices, labels)` function ensures equal representation of both classes by undersampling the majority class. This is particularly important in supervised learning, where imbalanced classes can lead to biased models. The function returns a shuffled array of indices representing a class-balanced subset of the dataset.

In [22]:
import numpy as np

def adjust_class_balance(indices, labels):
    """
    Adjusts the balance of classes by undersampling the majority class.

    Parameters:
    ----------
    indices : list or ndarray
        Indices of the dataset.
    labels : list or ndarray
        Class labels corresponding to the indices.

    Returns:
    -------
    balanced_indices : ndarray
        Indices of the balanced dataset.
    """
    # Class labels to consider
    CLASS_LABELS = [0, 1]

    # Separate indices by class
    class_indices = {label: [idx for idx in indices if labels[idx] == label] for label in CLASS_LABELS}

    # Determine the minimum class count
    min_class_count = min(len(class_list) for class_list in class_indices.values())

    # Adjust class balance by undersampling the majority class
    balanced_indices = []
    for label, class_list in class_indices.items():
        if len(class_list) > min_class_count:
            sampled_indices = np.random.choice(class_list, size=min_class_count, replace=False)
            balanced_indices.extend(sampled_indices)
        else:
            balanced_indices.extend(class_list)

    # Shuffle the indices for randomization
    np.random.shuffle(balanced_indices)
    return np.array(balanced_indices)
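
Because the function above draws from NumPy's global random state, repeated runs select different subjects. One possible variant, shown here only as a sketch, takes an explicit seed so that the undersampling is reproducible. `balanced_subset` is a hypothetical name, and unlike `adjust_class_balance` it balances whichever classes are present rather than the fixed pair `[0, 1]`.

```python
import numpy as np


def balanced_subset(indices, labels, seed=None):
    """Undersample every class to the size of the smallest one (reproducibly)."""
    rng = np.random.default_rng(seed)
    indices = np.asarray(indices)
    labels = np.asarray(labels)

    subset_labels = labels[indices]
    classes, counts = np.unique(subset_labels, return_counts=True)
    min_count = counts.min()

    # Draw min_count indices from each class without replacement
    selected = [
        rng.choice(indices[subset_labels == cls], size=min_count, replace=False)
        for cls in classes
    ]
    return rng.permutation(np.concatenate(selected))


toy_labels = np.array([1, 1, 1, 1, 1, 0, 0, 0])
toy_indices = np.arange(len(toy_labels))
balanced = balanced_subset(toy_indices, toy_labels, seed=42)
print(np.bincount(toy_labels[balanced]))  # [3 3]
```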

### Stratified cross-validation setup for model training and validation

We set up stratified 10-fold cross-validation for each site (excluding `test_site`) to evaluate model performance across multiple splits. Here’s an overview of the process:

**Stratified k-folds**: `StratifiedKFold` maintains the balance of classes (ASD vs. NC) across each fold, reducing potential bias.

**Fold processing**: For each site, 10 training and validation folds are generated, and their indices are stored in the `train_indices` and `val_indices` dictionaries. To ensure class balance after the folds from all sites are combined into the training data, the majority class in each site is undersampled.

**Class balance checks**: For each fold, the number of ASD and NC samples is printed to confirm that each split maintains a similar distribution.
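
Because each site's folds are balanced on their own, the pooled training set for a split is balanced as well. A toy illustration with hand-written indices (the site names, labels, and the `pool_labels` helper are hypothetical and stand in for the real per-site folds):

```python
import numpy as np

# Toy per-site labels and precomputed, class-balanced fold indices (two folds per site)
toy_labels = {
    "site_a": np.array([1, 0, 1, 0]),
    "site_b": np.array([0, 1, 0, 1, 1, 0]),
}
toy_train_indices = {
    "site_a": [np.array([0, 1]), np.array([2, 3])],
    "site_b": [np.array([0, 1, 2, 3]), np.array([2, 3, 4, 5])],
}


def pool_labels(labels_by_site, indices_by_site, split):
    """Concatenate the labels selected for one split across all sites."""
    pooled = [labels_by_site[site][indices_by_site[site][split]] for site in labels_by_site]
    return np.concatenate(pooled)


pooled_counts = []
for split in range(2):
    y_train = pool_labels(toy_labels, toy_train_indices, split)
    n_asd = np.count_nonzero(y_train == 1)
    n_nc = np.count_nonzero(y_train == 0)
    pooled_counts.append((n_asd, n_nc))
    print(f"Split {split}: ASD={n_asd}, NC={n_nc}")
```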

In [None]:
from sklearn.model_selection import StratifiedKFold

# Number of cross-validation folds
n_folds = 10

# Dictionaries to save the training and validation indices
train_indices = {}
val_indices = {}

# Perform stratified k-fold cross-validation for each site, excluding 'test_site' for testing
for site in sites:
    if site == test_site:
        continue

    # atlas_time_series is a list with one {site: time series} dictionary per atlas;
    # the first atlas is enough here because only the number of subjects matters
    features = atlas_time_series[0][site]
    labels = atlas_labels[site]

    # Initialize StratifiedKFold with shuffle to ensure data randomization
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True)
    
    site_train_indices = []
    site_val_indices = []

    # Loop through each fold in the stratified split
    for fold, (train_idx, val_idx) in enumerate(skf.split(features, labels)):
        print(f"Processing fold #{fold} for site `{site}`")

        train_idx = adjust_class_balance(train_idx, labels)
        val_idx = adjust_class_balance(val_idx, labels)

        # Append training and validation indices for each fold
        site_train_indices.append(np.array(train_idx))
        site_val_indices.append(np.array(val_idx))
        
        # Print class distribution for training and validation sets for each fold
        print(f"Balance of classes in training -> ASD: {np.count_nonzero(labels[site_train_indices[fold]] == 1)} and TC: {np.count_nonzero(labels[site_train_indices[fold]] == 0)}")
        
        print(f"Balance of classes in validation -> ASD: {np.count_nonzero(labels[site_val_indices[fold]] == 1)} and TC: {np.count_nonzero(labels[site_val_indices[fold]] == 0)}")

    # Store indices for each fold in the dictionaries
    train_indices[site] = site_train_indices
    val_indices[site] = site_val_indices

### Project test set data into the tangent space and evaluate class balance

The test data is transformed with the precomputed reference tangent space, and the class balance is evaluated.

In [14]:

    
# `connectivities` (assumed to hold one fitted ConnectivityMeasure per atlas) and
# `atlas_feature_vectors` must already be defined by earlier cells
site_data = []
for i in range(len(atlases)):
    # Project the test-site time series of atlas i into its reference tangent space
    site_feature_vectors = connectivities[i].transform(atlas_time_series[i][test_site])
    site_data.append(np.asarray(site_feature_vectors))

# Concatenate the per-atlas feature vectors of each subject into a single vector
combined_features = np.hstack(site_data)

# Update the dictionary with site-specific feature vectors
atlas_feature_vectors[test_site] = combined_features

# Print the number of subjects processed for the current site
print(f"Feature vectors for site {test_site}: with {combined_features.shape[0]} subjects and {combined_features.shape[1]} features")

Feature vectors for site yale: with 56 subjects and 26570 features
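
The feature count follows from the vectorization settings: with `vectorize=True` and `discard_diagonal=True`, each atlas contributes one value per unordered pair of regions. A quick check, assuming 200 regions for `rois_cc200` and 116 for `rois_aal` (these counts are assumptions, not stated in the notebook, but they reproduce the reported total):

```python
def n_tangent_features(n_rois):
    """Number of off-diagonal entries in the upper triangle of an n_rois x n_rois matrix."""
    return n_rois * (n_rois - 1) // 2


# Assumed ROI counts for the two atlases (not stated in the notebook)
roi_counts = {"rois_cc200": 200, "rois_aal": 116}

per_atlas = {atlas: n_tangent_features(n) for atlas, n in roi_counts.items()}
total_features = sum(per_atlas.values())

print(per_atlas)        # {'rois_cc200': 19900, 'rois_aal': 6670}
print(total_features)   # 26570
```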


In [15]:
X_test, y_test = atlas_feature_vectors.get(test_site), atlas_labels.get(test_site)

print(f"Shape of X_test: {X_test.shape}")
print(f"Shape of y_test: {y_test.shape}")

print(f"Balance of classes in test - ASD: {np.count_nonzero(y_test == 1)}, TC: {np.count_nonzero(y_test == 0)}")

Shape of X_test: (56, 26570)
Shape of y_test: (56,)
Balance of classes in test - ASD: 28, TC: 28


### Data Normalization

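To standardize the feature values, we use `StandardScaler`. Each split is normalized independently: the scaler is fit on the training data of that split only, and the resulting mean and variance are then used to scale the validation and test sets, which prevents data leakage. In the cell below, `StandardScaler` is imported but not applied to `X_train`, `X_validation`, or `X_test`, so the reported results correspond to unscaled features.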
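
A minimal NumPy sketch of that per-split procedure follows. It applies the same kind of transformation as `StandardScaler` (subtract the mean, divide by the standard deviation), but the helper names and the synthetic data are hypothetical, and this is not the notebook's own code.

```python
import numpy as np


def fit_standardizer(X_train):
    """Compute per-feature mean and standard deviation on the training data only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    std[std == 0] = 1.0  # leave constant features unscaled instead of dividing by zero
    return mean, std


def apply_standardizer(X, mean, std):
    """Scale any split with statistics that were learned from the training split."""
    return (X - mean) / std


rng = np.random.default_rng(0)
X_train = rng.normal(loc=5.0, scale=2.0, size=(100, 4))
X_validation = rng.normal(loc=5.0, scale=2.0, size=(20, 4))
X_test = rng.normal(loc=5.0, scale=2.0, size=(20, 4))

# Fit once per split on the training data, then reuse the statistics everywhere
mean, std = fit_standardizer(X_train)
X_train_scaled = apply_standardizer(X_train, mean, std)
X_validation_scaled = apply_standardizer(X_validation, mean, std)
X_test_scaled = apply_standardizer(X_test, mean, std)
```

With scikit-learn, the corresponding pattern is to call `fit` on `X_train` and then `transform` on each of the three sets, creating a fresh scaler inside the loop for every split.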

In [20]:
from sklearn.feature_selection import RFE
from sklearn.linear_model import LogisticRegression  # or RandomForestClassifier
from sklearn.metrics import confusion_matrix, roc_auc_score, recall_score, precision_score, mean_squared_error
from sklearn.preprocessing import StandardScaler

total_accuracy_validation = 0
total_sensitivity_validation = 0
total_specificity_validation = 0
total_auc_validation = 0

total_accuracy_test = 0
total_sensitivity_test = 0
total_specificity_test = 0
total_auc_test = 0

# Cross-validation across all splits
for split in range(n_folds):
    print(f"Split # {split + 1}")
    
    X_train, X_validation, y_train, y_validation = [], [], [], []

    # Aggregate training and validation data from all sites
    for site in sites:
        if site == test_site:
            continue

        # Training and validation data for each site split
        X_train.extend(atlas_feature_vectors[site][train_indices[site][split]])
        X_validation.extend(atlas_feature_vectors[site][val_indices[site][split]])

        y_train.extend(atlas_labels[site][train_indices[site][split]])
        y_validation.extend(atlas_labels[site][val_indices[site][split]])

    # Convert lists to numpy arrays
    X_train, X_validation = map(np.array, [X_train, X_validation])
    y_train, y_validation = map(np.array, [y_train, y_validation])
    
    print(f"Balance of classes in training - ASD: {np.count_nonzero(y_train == 1)}, TC: {np.count_nonzero(y_train == 0)}")
    print(f"Balance of classes in validation - ASD: {np.count_nonzero(y_validation == 1)}, TC: {np.count_nonzero(y_validation == 0)}")

    print(f"Shape of X_train: {X_train.shape}")
    print(f"Shape of y_train: {y_train.shape}")

    print(f"Shape of X_validation: {X_validation.shape}")
    print(f"Shape of y_validation: {y_validation.shape}")

    # Build and train the deep neural network model
    dnn = build_model(X_train.shape[1])

    history = dnn.fit(
        X_train, y_train,
        validation_data=(X_validation, y_validation),
        batch_size=64,
        epochs=200,
        callbacks=callbacks_list
    )

    # Evaluate on the validation set: predict probabilities once, then threshold at 0.5
    validation_probabilities = dnn.predict(X_validation).ravel()
    validation_predictions = (validation_probabilities > 0.5).astype(int)
    cm_validation = confusion_matrix(y_validation, validation_predictions)
    tn_validation, fp_validation, fn_validation, tp_validation = cm_validation.ravel()

    # Calculate evaluation metrics
    accuracy_validation = (tp_validation + tn_validation) / (tp_validation + tn_validation + fp_validation + fn_validation)
    sensitivity_validation = tp_validation / (tp_validation + fn_validation) if tp_validation + fn_validation > 0 else 0
    specificity_validation = tn_validation / (tn_validation + fp_validation) if tn_validation + fp_validation > 0 else 0
    auc_validation = roc_auc_score(y_validation, validation_probabilities)

    # Print validation metrics
    print(f"Validation Accuracy: {accuracy_validation*100:.2f}")
    print(f"Validation Sensitivity (Recall): {sensitivity_validation*100:.2f}")
    print(f"Validation Specificity: {specificity_validation*100:.2f}")
    print(f"Validation AUC-ROC Score: {auc_validation*100:.2f}")
    print(f"Validation Confusion Matrix:\n{cm_validation}")
    
    total_accuracy_validation += accuracy_validation
    total_sensitivity_validation += sensitivity_validation
    total_specificity_validation += specificity_validation
    total_auc_validation += auc_validation

    # Print mean validation metrics
    print(f"Validation Mean Accuracy: {(total_accuracy_validation / (split + 1))*100:.2f}")
    print(f"Validation Mean Sensitivity (Recall): {(total_sensitivity_validation / (split + 1))*100:.2f}")
    print(f"Validation Mean Specificity: {(total_specificity_validation / (split + 1))*100:.2f}")
    print(f"Validation Mean AUC-ROC Score: {(total_auc_validation / (split + 1))*100:.2f}")

    # Evaluate on the test set: predict probabilities once, then threshold at 0.5
    test_probabilities = dnn.predict(X_test).ravel()
    test_predictions = (test_probabilities > 0.5).astype(int)
    cm_test = confusion_matrix(y_test, test_predictions)
    tn_test, fp_test, fn_test, tp_test = cm_test.ravel()

    # Calculate evaluation metrics
    accuracy_test = (tp_test + tn_test) / (tp_test + tn_test + fp_test + fn_test)
    sensitivity_test = tp_test / (tp_test + fn_test) if tp_test + fn_test > 0 else 0
    specificity_test = tn_test / (tn_test + fp_test) if tn_test + fp_test > 0 else 0
    auc_test = roc_auc_score(y_test, test_probabilities)

    # Print test metrics
    print(f"Test Accuracy: {accuracy_test*100:.2f}")
    print(f"Test Sensitivity (Recall): {sensitivity_test*100:.2f}")
    print(f"Test Specificity: {specificity_test*100:.2f}")
    print(f"Test AUC-ROC Score: {auc_test*100:.2f}")
    print(f"Test Confusion Matrix:\n{cm_test}")
    
    total_accuracy_test += accuracy_test
    total_sensitivity_test += sensitivity_test
    total_specificity_test += specificity_test
    total_auc_test += auc_test

    # Print mean test metrics
    print(f"Test Mean Accuracy: {(total_accuracy_test / (split + 1))*100:.2f}")
    print(f"Test Mean Sensitivity (Recall): {(total_sensitivity_test / (split + 1))*100:.2f}")
    print(f"Test Mean Specificity: {(total_specificity_test / (split + 1))*100:.2f}")
    print(f"Test Mean AUC-ROC Score: {(total_auc_test / (split + 1))*100:.2f}")


Split # 1
Balance of classes in training - ASD: 417, TC: 417
Balance of classes in validation - ASD: 47, TC: 47
Shape of X_train: (834, 26570)
Shape of y_train: (834,)
Shape of X_validation: (94, 26570)
Shape of y_validation: (94,)
Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Validation Accuracy: 67.02
Validation Sensitivity (Recall): 68.09
Validation Specificity: 65.96
Validation AUC-ROC Score: 72.43
Validation Confusion Matrix:
[[31 16]
 [15 32]]
Validation Mean Accuracy: 67.02
Validation Mean Sensitivity (Recall): 68.09
Validation Mean Specificity: 65.96
Validation Mean AUC-ROC Score: 72.43
Test Accuracy: 73.21
Test Sensitivity (Recall): 78.57
Test Specificity: 67.86
Test AUC-ROC Score: 76.40
Test Confusion Matrix:
[