## Deep neural network for ASD classification using resting-state fMRI

This notebook evaluates a deep neural network (DNN) for ASD diagnosis using functional time series extracted from brain regions of interest (ROIs). The resting-state fMRI data come from the ABIDE dataset and were preprocessed by the **Preprocessed Connectomes Project (PCP)** using four pipelines, covering 1,100 subjects from multiple international sites.

### Configure data loading

The variables needed to load the neuroimaging data are defined here. `pipeline` and `rois` specify the preprocessing pipeline and the atlas used to extract the ROI time series. `all_sites` lists every neuroimaging site available in the dataset, `sites` selects those included in the analysis, and `test_site` sets the site held out for testing.

In [1]:
pipeline = 'fsl'    # Preprocessing pipeline
rois = 'rois_ho'    # Atlas used to extract the ROI time series (Harvard-Oxford)

# List of all available neuroimaging sites in the dataset
all_sites = [
    'caltech', 'cmu', 'kki', 'leuven_1', 'leuven_2', 'max_mun', 'nyu', 
    'ohsu', 'olin', 'pitt', 'sbl', 'sdsu', 'stanford', 'trinity', 
    'ucla_1', 'ucla_2', 'um_1', 'um_2', 'usm', 'yale'
]

# Sites included in the analysis
sites = all_sites

# Testing site
test_site = 'yale'

### ROI data loading function

The `load_rois_data(pipeline, rois, sites)` function retrieves subject time series and diagnostic labels for each neuroimaging site in `sites`. It reads the phenotypic information from a per-site CSV file, then loads the ROI time series of each subject listed there. Subjects whose data file is missing or contains NaN values are skipped with a message, so only complete time series enter the analysis.

In [2]:
import os
import csv
import numpy as np


def load_rois_data(pipeline, rois, sites):
    """
    Loads time series and diagnostic labels from neuroimaging data files for the specified sites.
    
    Parameters:
        pipeline (str): Preprocessing pipeline used for the data.
        rois (str): Atlas defining regions of interest.
        sites (list of str): List of site names to load data from.

    Returns:
        rois_time_series (dict): Contains time series data for each site.
        rois_labels (dict): Contains diagnostic labels for each site.
    """

    rois_time_series = {}  # Dictionary to store time series data for each site
    rois_labels = {}  # Dictionary to store labels for each site

    for site in sites:
        # Define path for phenotypic data for the current site
        phenotypic_path = f"data/phenotypic_fsl/{site}/phenotypic.csv"

        try:
            with open(phenotypic_path, 'r') as file:
                reader = csv.DictReader(file)
                site_time_series = []  # List to store time series for each subject at the site
                site_labels = []  # List to store labels for each subject at the site

                for row in reader:
                    file_id = row['file_id']  # Unique subject identifier
                    dx_group = row['dx_group']  # Diagnostic group as stored in the CSV ('1' = ASD; other values, e.g. '2' in ABIDE, = control)

                    # Define path for the time series data file
                    data_file_path = os.path.join("data", pipeline, rois, site, f"{file_id}_{rois}.1D")

                    # Check if the data file exists
                    if not os.path.exists(data_file_path):
                        print(f"File Not Found Error: Data file not found at path {data_file_path}")
                        continue
                    
                    data = np.loadtxt(data_file_path)

                    # Check for NaN values and add time series to the site list
                    if np.isnan(data).any():
                        print(f"Value Error: NaN value found for subject {file_id}")
                    else:
                        site_time_series.append(data)
                        site_labels.append(1 if dx_group == '1' else 0)  # Assign 1 for ASD, 0 for control

                # Store loaded data for the current site in the dictionaries
                rois_time_series[site] = site_time_series
                rois_labels[site] = np.array(site_labels)
                print(f"Loaded {len(site_time_series)} subjects from site {site}.")
                
        except FileNotFoundError:
            print(f"File Not Found Error: Phenotypic data not found for site {site}")

    return rois_time_series, rois_labels

Load the data used in the analysis with the parameters defined above.

In [3]:
rois_time_series, rois_labels = load_rois_data(pipeline, rois, sites)

Loaded 4 subjects from site caltech.
Loaded 3 subjects from site cmu.
Loaded 6 subjects from site kki.
Loaded 3 subjects from site leuven_1.
Loaded 4 subjects from site leuven_2.
Loaded 6 subjects from site max_mun.
Loaded 19 subjects from site nyu.
Loaded 3 subjects from site ohsu.
Loaded 4 subjects from site olin.
Loaded 6 subjects from site pitt.
Loaded 3 subjects from site sbl.
Loaded 4 subjects from site sdsu.
Loaded 4 subjects from site stanford.
Loaded 5 subjects from site trinity.
Loaded 8 subjects from site ucla_1.
Loaded 3 subjects from site ucla_2.
Loaded 10 subjects from site um_1.
Loaded 4 subjects from site um_2.
Loaded 10 subjects from site usm.
Loaded 14 subjects from site yale.
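
After loading, it can be worth checking that every subject has the same number of ROIs, since the connectivity matrices compared later must all have the same size, while the number of time points may legitimately differ between sites. A minimal sketch, assuming each time series is a 2-D array of shape (time points, ROIs) as returned by `np.loadtxt`; `summarize_roi_shapes` is a hypothetical helper, not part of the original notebook:

```python
import numpy as np


def summarize_roi_shapes(time_series_by_site):
    """Hypothetical helper: list the distinct series lengths at each site and
    check that all subjects share the same number of ROIs."""
    roi_counts = set()
    lengths_by_site = {}
    for site, subjects in time_series_by_site.items():
        # Each subject is assumed to be a (time points x ROIs) array
        shapes = [np.asarray(ts).shape for ts in subjects]
        roi_counts.update(shape[1] for shape in shapes)
        lengths_by_site[site] = sorted({shape[0] for shape in shapes})
    if len(roi_counts) > 1:
        raise ValueError(f"Inconsistent ROI counts across subjects: {sorted(roi_counts)}")
    return lengths_by_site


# Toy data (not ABIDE): two sites with different scan lengths but 3 ROIs each
example = {
    "site_a": [np.zeros((100, 3)), np.zeros((120, 3))],
    "site_b": [np.zeros((80, 3))],
}
print(summarize_roi_shapes(example))  # {'site_a': [100, 120], 'site_b': [80]}
```

Applied to `rois_time_series`, it would return the distinct series lengths per site, or raise if the ROI counts disagree.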


### Tangent space embedding

Tangent space embedding maps connectivity matrices (covariance or correlation matrices estimated from fMRI time series) into a vector space where standard Euclidean machine learning techniques can be applied, while preserving the geometric properties of the data: these matrices are symmetric positive-definite and do not form a flat Euclidean space. It is particularly useful for analyzing brain connectivity and for classifying neurological conditions from it.

The workflow involves two main steps:

**Estimate the reference tangent space**: Compute the mean covariance matrix of a training population and use the tangent space at this mean as the reference. This establishes the "reference space" onto which individual subjects, including test subjects, can later be projected.

**Project subjects onto the reference space**: Using the reference tangent space precomputed from the population, the covariance matrix of a new subject can be projected onto this space. This yields a tangent-space connectivity matrix for that subject that is aligned with those of the population.
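
The two steps are commonly written as follows under the affine-invariant Riemannian geometry of symmetric positive-definite matrices, which underlies nilearn's `kind='tangent'` option (standard background, stated here as an assumption about the formulation rather than derived in this notebook). $C_i$ is the covariance matrix of training subject $i$, $\bar{C}$ the geometric mean, $\log$ the matrix logarithm, and $T_i$ the tangent-space matrix of subject $i$:

```latex
\bar{C} = \arg\min_{C \succ 0} \sum_{i=1}^{N} \left\| \log\!\left(C^{-1/2} C_i C^{-1/2}\right) \right\|_F^2
\qquad
T_i = \log\!\left(\bar{C}^{-1/2} C_i \bar{C}^{-1/2}\right)
```

Each $T_i$ is a symmetric $n \times n$ matrix for $n$ ROIs; keeping only its strictly lower triangle (the diagonal is discarded) gives $n(n-1)/2$ features. The 6105 features in the training output later in this notebook therefore correspond to $n = 111$ ROIs, since $111 \cdot 110 / 2 = 6105$.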

#### Create the training population

To keep a separate test set, the data from `test_site` are excluded from the population used to estimate the reference tangent space.

In [4]:
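def create_population(time_series_data):
    """Flatten a collection of per-site subject lists into one list of subject time series."""
    # Initialize an empty list for the population data
    population_data = []

    # Loop through the per-site lists of time series
    for item in time_series_data:
        # Add every subject from this site to the population
        population_data.extend(item)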

    print(f"Total subjects in population data: {len(population_data)}")
    return population_data

#### Function to estimate tangent-space functional connectivity

The `estimate_tangent_space(data)` function fits the tangent space from the geometric mean of the covariance matrices of a training population. This creates a reference space that reflects the average connectivity pattern across that population.

In the tangent-space representation, each subject's functional connectivity matrix is expressed relative to this group-average reference, so individual matrices can be compared in a common, standardized space and used directly as feature vectors for classification.

In [5]:
from nilearn.connectome import ConnectivityMeasure

def estimate_tangent_space(data):
    """
    Estimate the tangent space functional connectivity.

    Parameters:
    -----------
    data : list or ndarray
        List or array of time series data for the training population, where each entry corresponds 
        to a subject's time series (time points x regions).

    Returns:
    --------
    ConnectivityMeasure
        Fitted ConnectivityMeasure object configured for tangent space transformation.
    """
    # Instantiate ConnectivityMeasure for tangent space, vectorizing and discarding the diagonal
    connectivity_measure = ConnectivityMeasure(kind='tangent', vectorize=True, discard_diagonal=True)

    # Fit the measure on the population data to establish a reference tangent space
    connectivity_measure.fit(data)

    return connectivity_measure
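
To make the tangent computation concrete, the sketch below reimplements it with plain NumPy on toy data. It assumes the formulation in which the reference is the geometric (Karcher) mean of the covariances, computed here with a simple fixed-point iteration, and each subject is whitened by the inverse square root of the reference before taking a matrix logarithm. It is an illustrative approximation, not nilearn's implementation: nilearn uses a Ledoit-Wolf covariance estimator by default and its own iterative mean routine, so values will differ. All function names are hypothetical.

```python
import numpy as np


def spd_apply(mat, func):
    """Apply a scalar function to the eigenvalues of a symmetric positive-definite matrix."""
    sym = (mat + mat.T) / 2  # remove tiny numerical asymmetry before eigh
    vals, vecs = np.linalg.eigh(sym)
    return (vecs * func(vals)) @ vecs.T


def geometric_mean(covs, n_iter=100, tol=1e-10):
    """Geometric (Karcher) mean under the affine-invariant metric, via fixed-point iteration."""
    mean = np.mean(covs, axis=0)  # start from the arithmetic mean
    for _ in range(n_iter):
        sqrt_mean = spd_apply(mean, np.sqrt)
        inv_sqrt_mean = spd_apply(mean, lambda v: 1.0 / np.sqrt(v))
        # Average the subjects' log-maps at the current estimate
        step = np.mean([spd_apply(inv_sqrt_mean @ c @ inv_sqrt_mean, np.log) for c in covs], axis=0)
        # Move the estimate along that averaged direction (exponential map)
        mean = sqrt_mean @ spd_apply(step, np.exp) @ sqrt_mean
        if np.linalg.norm(step) < tol:
            break
    return mean


def tangent_embedding(covs, reference):
    """Whiten each covariance by the reference, then take the matrix logarithm."""
    inv_sqrt_ref = spd_apply(reference, lambda v: 1.0 / np.sqrt(v))
    return np.array([spd_apply(inv_sqrt_ref @ c @ inv_sqrt_ref, np.log) for c in covs])


def vectorize_lower(mats):
    """Keep the strictly lower triangle of each matrix (diagonal discarded)."""
    rows, cols = np.tril_indices(mats.shape[-1], k=-1)
    return mats[:, rows, cols]


# Toy data (not ABIDE): 6 subjects, 200 time points, 5 ROIs
rng = np.random.default_rng(0)
covs = np.array([np.cov(rng.standard_normal((200, 5)), rowvar=False) for _ in range(6)])
reference = geometric_mean(covs)
features = vectorize_lower(tangent_embedding(covs, reference))
print(features.shape)  # (6, 10)
```

At the geometric mean, the tangent matrices of the training subjects average to zero, which is a convenient check that the reference was estimated correctly.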


### Function to create deep neural network models

The `build_model(input_shape)` function creates a DNN with the following architecture:

Input Layer: Takes the number of features in the input data, followed by dropout (rate 0.2).

Dense Layer 1: 128 neurons, ReLU activation, L2 regularization (0.001) to reduce overfitting, followed by dropout (rate 0.4).

Dense Layer 2: 64 neurons, ReLU activation, L2 regularization (0.001), followed by dropout (rate 0.4).

Output Layer: A single neuron with sigmoid activation for binary classification.

The model is compiled with the Adam optimizer and binary cross-entropy loss, since subjects are classified into two classes (ASD vs. typical controls). Accuracy is included as a metric to track performance during training and evaluation.
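
A quick count of trainable parameters shows how large this model is relative to the data. The sketch assumes the 6105 tangent-space features seen in the training output later in the notebook; dropout and the L2 penalties add no parameters. The helper names are hypothetical; `count_params()` on the built Keras model should report the same total.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def count_params(n_features, hidden=(128, 64)):
    """Trainable parameters of the Dense stack described above (dropout adds none)."""
    sizes = [n_features, *hidden, 1]
    return sum(dense_params(a, b) for a, b in zip(sizes[:-1], sizes[1:]))


print(dense_params(6105, 128))  # 781568
print(count_params(6105))       # 789889
```

Roughly 99% of the parameters sit in the first Dense layer (781,568 of 789,889), against 87 training subjects in the final run, which is why the dropout and L2 regularization matter here.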

In [6]:
from keras import layers, models, regularizers

# Define the deep neural network model architecture
def build_model(input_shape):
    """
    Builds and compiles a deep neural network model for binary classification.

    Parameters:
    - input_shape: int, the number of input features in the dataset

    Returns:
    - model: compiled Keras Sequential model ready for training
    """
    
    model = models.Sequential()

    # Input layer: a feature vector of length input_shape
    model.add(layers.Input(shape=(input_shape,)))
    model.add(layers.Dropout(0.2))

    # Hidden layers
    model.add(layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
    model.add(layers.Dropout(0.4))

    model.add(layers.Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
    model.add(layers.Dropout(0.4))

    # Output layer for binary classification (ASD vs. Healthy)
    model.add(layers.Dense(1, activation='sigmoid'))

    # Compile the model with Adam optimizer and binary cross-entropy loss
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

### Training Callbacks

To optimize training, we set up three callbacks:

**EarlyStopping:** Stops training once validation loss has not improved for 10 consecutive epochs, limiting overfitting, and restores the weights from the best epoch.

**ReduceLROnPlateau:** Halves the learning rate when validation loss has not improved for 5 epochs, down to a minimum of 1e-5, so optimization takes smaller steps once progress stalls.

**ModelCheckpoint:** Saves the model with the lowest validation loss to `best_model.keras`, so the best version can be reloaded later.
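
The learning-rate floor can be checked with simple arithmetic. Assuming Adam's default learning rate of 0.001 in Keras (the notebook does not set it explicitly), halving down to `min_lr=1e-5` visits the levels computed below. This is plain arithmetic, not a simulation of the Keras callbacks; the helper name is hypothetical.

```python
def plateau_lr_levels(initial_lr=1e-3, factor=0.5, min_lr=1e-5):
    """Learning rates visited by repeated multiplicative reductions, clipped at min_lr."""
    levels = [initial_lr]
    while levels[-1] > min_lr:
        levels.append(max(levels[-1] * factor, min_lr))
    return levels


levels = plateau_lr_levels()
print(len(levels) - 1)  # 7 reductions before the floor is reached
```

Each reduction needs 5 consecutive epochs without improvement, whereas EarlyStopping ends training after 10 such epochs, so the floor is only reached if validation loss still improves now and then.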

In [7]:
from keras import callbacks

# Early stopping to prevent overfitting by stopping training when validation loss stops improving
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',             # Monitor validation loss for early stopping
    patience=10,                    # Stop training if val_loss does not improve for 10 epochs
    restore_best_weights=True       # Restore the model weights from the epoch with the lowest val_loss
)

# Reduce learning rate when the validation loss plateaus
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',             # Monitor validation loss for learning rate reduction
    factor=0.5,                     # Reduce learning rate by a factor of 0.5
    patience=5,                     # Trigger after 5 epochs without improvement in val_loss
    min_lr=1e-5                     # Set a floor on the learning rate to avoid overly small values
)

# Save the best model based on validation loss
checkpoint = callbacks.ModelCheckpoint(
    'best_model.keras',             # Filename for the best model
    monitor='val_loss',             # Monitor validation loss for checkpoint saving
    save_best_only=True             # Only save the model when it achieves a new best val_loss
)

# Callbacks list passed to the model
callbacks_list = [early_stopping, reduce_lr, checkpoint]

### Function to adjust class balance

The `adjust_class_balance(indices, labels)` function balances the two classes (0 and 1) by undersampling the majority class to the size of the minority class. This matters in supervised learning, where imbalanced classes can bias a model toward the majority class. The function returns a shuffled NumPy array of indices representing a class-balanced subset of the data.

In [8]:
import numpy as np

def adjust_class_balance(indices, labels):
    """
    Adjusts the balance of classes by undersampling the majority class.

    Parameters:
    ----------
    indices : list or ndarray
        Indices of the dataset.
    labels : list or ndarray
        Class labels corresponding to the indices.

    Returns:
    -------
    balanced_indices : ndarray
        Indices of the balanced dataset.
    """
    # Class labels to consider
    CLASS_LABELS = [0, 1]

    # Separate indices by class
    class_indices = {label: [idx for idx in indices if labels[idx] == label] for label in CLASS_LABELS}

    # Determine the minimum class count
    min_class_count = min(len(idx_list) for idx_list in class_indices.values())

    # Adjust class balance by undersampling the majority class
    balanced_indices = []
    for label, class_list in class_indices.items():
        if len(class_list) > min_class_count:
            sampled_indices = np.random.choice(class_list, size=min_class_count, replace=False)
            balanced_indices.extend(sampled_indices)
        else:
            balanced_indices.extend(class_list)

    # Shuffle the indices for randomization
    np.random.shuffle(balanced_indices)
    return np.array(balanced_indices)
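
Because `adjust_class_balance` draws from NumPy's global random state, the undersampled folds change from run to run. A hypothetical variant that takes an explicit `numpy.random.Generator` makes the subsampling reproducible; it otherwise follows the same logic (keep the minority class, randomly subsample the majority class without replacement, shuffle the result).

```python
import numpy as np


def balance_indices(indices, labels, rng=None):
    """Hypothetical reproducible variant of adjust_class_balance."""
    rng = np.random.default_rng() if rng is None else rng
    indices = np.asarray(indices)
    labels = np.asarray(labels)
    # Split the candidate indices by class label (0 = control, 1 = ASD)
    per_class = [indices[labels[indices] == c] for c in (0, 1)]
    n_keep = min(len(idx) for idx in per_class)
    # A random permutation truncated to n_keep samples each class without replacement
    kept = np.concatenate([rng.permutation(idx)[:n_keep] for idx in per_class])
    # Shuffle so the two classes are interleaved
    return rng.permutation(kept)


labels = np.array([1, 1, 1, 0, 0, 1])
balanced = balance_indices(np.arange(6), labels, rng=np.random.default_rng(42))
print(np.bincount(labels[balanced]))  # [2 2]
```

Passing the same seeded generator again reproduces exactly the same subset and order.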

### Stratified cross-validation setup for model training and validation

Set up stratified 10-fold cross-validation for each site (excluding `test_site`) to evaluate model performance across multiple splits. Here’s an overview of the process:

**Stratified k-folds**: `StratifiedKFold` keeps the class proportions (ASD vs. TC, typical controls) similar across folds, reducing potential bias.

**Fold processing**: For each site, 10 training and validation folds are generated, and their indices are stored in the `train_indices` and `val_indices` dictionaries. The majority class is undersampled within each site and fold, so the training (and validation) data remain class-balanced after the folds from all sites are combined.

**Class balance checks**: For each fold, the number of ASD and TC samples is printed to confirm that each split has a balanced class distribution.

Splitting each site into `n_folds` folds requires at least `n_folds` subjects per site; as the error below shows, this fails for sites that contribute only a few subjects (e.g., 4 at `caltech`).

In [9]:
from sklearn.model_selection import StratifiedKFold

# Number of cross-validation folds
n_folds = 10

# Dictionaries to store the training and validation indices
train_indices = {}
val_indices = {}

# Perform stratified k-fold cross-validation for each site, excluding 'test_site' for testing
for site in sites:
    if site == test_site:
        continue

    features = rois_time_series[site]
    labels = rois_labels[site]

    # Initialize StratifiedKFold with shuffle to ensure data randomization
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True)
    
    site_train_indices = []
    site_val_indices = []

    # Loop through each fold in the stratified split
    for fold, (train_idx, val_idx) in enumerate(skf.split(features, labels)):
        print(f"Processing fold #{fold} for site `{site}`")
      
        train_idx = adjust_class_balance(train_idx, labels)
        val_idx = adjust_class_balance(val_idx, labels)

        # Append training and validation indices for each fold
        site_train_indices.append(np.array(train_idx))
        site_val_indices.append(np.array(val_idx))
        
        # Print class distribution for training and validation sets for each fold
        print(f"Balance of classes in training -> ASD: {np.count_nonzero(labels[site_train_indices[fold]] == 1)} and TC: {np.count_nonzero(labels[site_train_indices[fold]] == 0)}")
        
        print(f"Balance of classes in validation -> ASD: {np.count_nonzero(labels[site_val_indices[fold]] == 1)} and TC: {np.count_nonzero(labels[site_val_indices[fold]] == 0)}")

    # Store indices for each fold in the dictionaries
    train_indices[site] = site_train_indices
    val_indices[site] = site_val_indices

ValueError: Cannot have number of splits n_splits=10 greater than the number of samples: n_samples=4.
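
The error above occurs because `StratifiedKFold` cannot create more splits than there are samples, and several sites contribute only 3 to 6 subjects. One possible workaround, an alternative to the per-site scheme rather than the author's method, is to pool the subjects from all non-test sites and stratify on the pooled labels, capping the number of folds by the size of the smallest class. The helper name is hypothetical, and the folds it returns hold `(site, subject index)` pairs, so the later cells would need to gather `rois_time_series[site][idx]` from them instead of using `train_indices[site][split]`.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def pooled_stratified_folds(labels_by_site, test_site, n_folds=10, seed=0):
    """Hypothetical helper: stratified folds over subjects pooled from all
    non-test sites. Each fold is a (train, val) pair of (site, index) lists."""
    refs, pooled_labels = [], []
    for site, site_labels in labels_by_site.items():
        if site == test_site:
            continue
        for idx, label in enumerate(site_labels):
            refs.append((site, idx))
            pooled_labels.append(int(label))
    pooled_labels = np.array(pooled_labels)

    # Never ask for more folds than the smallest class can fill
    n_splits = min(n_folds, int(np.bincount(pooled_labels, minlength=2).min()))
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    folds = []
    # StratifiedKFold only needs the labels; a zeros placeholder stands in for X
    for train_pos, val_pos in skf.split(np.zeros(len(pooled_labels)), pooled_labels):
        folds.append(([refs[i] for i in train_pos], [refs[i] for i in val_pos]))
    return folds


# Toy labels (not ABIDE): two small training sites plus a held-out site
toy_labels = {"a": np.array([0, 1, 0, 1]), "b": np.array([1, 0, 1]), "yale": np.array([0, 1])}
folds = pooled_stratified_folds(toy_labels, test_site="yale")
print(len(folds))  # 3: capped by the 3 control subjects
```

With the 109 non-test subjects loaded above (55 ASD and 54 TC, per the later training output), the cap would not bind and 10 folds would be produced.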

### Function to calculate evaluation metrics

In [10]:
from sklearn.metrics import confusion_matrix, roc_auc_score

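def calculate_metrics(y_true, y_pred, y_pred_prob):
    # labels=[0, 1] keeps the confusion matrix 2x2 even when a class is absent from y_true and y_pred
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()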
    
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    auc = roc_auc_score(y_true, y_pred_prob)
    
    return accuracy, sensitivity, precision, specificity, auc, cm
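
The formulas in `calculate_metrics` can be sanity-checked against scikit-learn's own scorers on a toy example (hand-made labels, not study results): sensitivity is the recall of the positive (ASD) class, specificity is the recall of the negative (control) class, and precision is the positive predictive value.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Toy labels: 1 = ASD, 0 = control
y_true = np.array([1, 1, 1, 0, 0, 0])
y_pred = np.array([1, 1, 0, 0, 0, 1])

# With labels=[0, 1], ravel() yields tn, fp, fn, tp in this order
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

sensitivity = tp / (tp + fn)  # 2 / 3
specificity = tn / (tn + fp)  # 2 / 3
precision = tp / (tp + fp)    # 2 / 3

print(np.isclose(sensitivity, recall_score(y_true, y_pred, pos_label=1)))  # True
print(np.isclose(specificity, recall_score(y_true, y_pred, pos_label=0)))  # True
print(np.isclose(precision, precision_score(y_true, y_pred, pos_label=1)))  # True
```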

### Function to print metrics

In [11]:
def print_metrics(split, dataset_type, accuracy, sensitivity, precision, specificity, auc, cm):
    print(f"{dataset_type.capitalize()} Metrics for split {split + 1}:")
    print(f"  Accuracy: {accuracy * 100:.2f}%")
    print(f"  Sensitivity (Recall): {sensitivity * 100:.2f}%")
    print(f"  Precision: {precision * 100:.2f}%")
    print(f"  Specificity: {specificity * 100:.2f}%")
    print(f"  AUC-ROC Score: {auc * 100:.2f}%")
    print(f"  Confusion Matrix:\n{cm}")

### Estimating a tangent space for each split

In [12]:
connectivity_list = []

# Cross-validation across all splits
for split in range(n_folds):
    print(f"\n--- Split {split + 1} ---")

    # Aggregate training and validation data across all sites
    X_train_time_series, X_val_time_series = [], []
    y_train, y_val = [], []

    for site in sites:
        if site == test_site:
            continue
        X_train_time_series.extend([rois_time_series[site][idx] for idx in train_indices[site][split]])
        y_train.extend([rois_labels[site][idx] for idx in train_indices[site][split]])
        X_val_time_series.extend([rois_time_series[site][idx] for idx in val_indices[site][split]])
        y_val.extend([rois_labels[site][idx] for idx in val_indices[site][split]])

    # Prepare tangent space for feature extraction
    connectivity_m = estimate_tangent_space(X_train_time_series)
    connectivity_list.append(connectivity_m)


--- Split 1 ---


KeyError: 'caltech'

### Testing the model

In [13]:
# Initialize accumulators for metrics
metrics = {
    "validation": {"accuracy": 0, "sensitivity": 0, "precision": 0, "specificity": 0, "auc": 0},
    "test": {"accuracy": 0, "sensitivity": 0, "precision": 0, "specificity": 0, "auc": 0}
}

# Cross-validation across all splits
for split in range(n_folds):
    print(f"\n--- Split {split + 1} ---")

    # Aggregate training and validation data across all sites
    X_train_time_series, X_val_time_series = [], []
    y_train, y_val = [], []

    for site in sites:
        if site == test_site:
            continue
        X_train_time_series.extend([rois_time_series[site][idx] for idx in train_indices[site][split]])
        y_train.extend([rois_labels[site][idx] for idx in train_indices[site][split]])
        X_val_time_series.extend([rois_time_series[site][idx] for idx in val_indices[site][split]])
        y_val.extend([rois_labels[site][idx] for idx in val_indices[site][split]])

    # Prepare tangent space for feature extraction
    X_train = connectivity_list[split].transform(X_train_time_series)
    X_val = connectivity_list[split].transform(X_val_time_series)
    X_test = connectivity_list[split].transform(rois_time_series[test_site])
    y_test = rois_labels[test_site]

    X_train, X_val, X_test = map(np.array, [X_train, X_val, X_test])
    y_train, y_val, y_test = map(np.array, [y_train, y_val, y_test])

    # Print dataset statistics
    print(f"Training set shape: {X_train.shape}, class balance: ASD={np.sum(y_train == 1)}, TC={np.sum(y_train == 0)}")
    print(f"Validation set shape: {X_val.shape}, class balance: ASD={np.sum(y_val == 1)}, TC={np.sum(y_val == 0)}")
    print(f"Test set shape: {X_test.shape}, class balance: ASD={np.sum(y_test == 1)}, TC={np.sum(y_test == 0)}")

    # Build and train the model
    dnn = build_model(X_train.shape[1])
    history = dnn.fit(X_train, y_train, validation_data=(X_val, y_val), batch_size=32, epochs=100, callbacks=callbacks_list)

    # Evaluate on validation set
    validation_pred_prob = dnn.predict(X_val).ravel()
    validation_pred = (validation_pred_prob > 0.5).astype(int)
    acc, sens, prec, spec, auc, cm = calculate_metrics(y_val, validation_pred, validation_pred_prob)
    metrics["validation"]["accuracy"] += acc
    metrics["validation"]["sensitivity"] += sens
    metrics["validation"]["precision"] += prec
    metrics["validation"]["specificity"] += spec
    metrics["validation"]["auc"] += auc
    print_metrics(split, "validation", acc, sens, prec, spec, auc, cm)

    # Evaluate on test set
    test_pred_prob = dnn.predict(X_test).ravel()
    test_pred = (test_pred_prob > 0.5).astype(int)
    acc, sens, prec, spec, auc, cm = calculate_metrics(y_test, test_pred, test_pred_prob)
    metrics["test"]["accuracy"] += acc
    metrics["test"]["sensitivity"] += sens
    metrics["test"]["precision"] += prec
    metrics["test"]["specificity"] += spec
    metrics["test"]["auc"] += auc
    print_metrics(split, "test", acc, sens, prec, spec, auc, cm)

# Print mean metrics
for dataset in metrics:
    print(f"\n--- Mean {dataset.capitalize()} Metrics Across All Splits ---")
    for metric, value in metrics[dataset].items():
        print(f"{metric.capitalize()}: {(value / n_folds) * 100:.2f}%")


--- Split 1 ---


KeyError: 'caltech'

### Training on the whole dataset (all non-test sites)

In [14]:
from sklearn.model_selection import train_test_split

# Aggregate all training data
X_train_time_series, y_data = [], []

for site in sites:
    if site == test_site:
        continue

    X_train_time_series.extend(rois_time_series[site])
    y_data.extend(rois_labels[site])

# Prepare tangent space for feature extraction
connectivity_m = estimate_tangent_space(X_train_time_series)
print("Tangent space estimated.")

# Transform data into feature vectors
X_data = connectivity_m.transform(X_train_time_series)
X_test = connectivity_m.transform(rois_time_series[test_site])
y_test = rois_labels[test_site]

Tangent space estimated.


In [15]:
# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X_data, y_data, stratify=y_data, test_size=0.2)

# Convert features and labels to NumPy arrays
X_train, X_val, X_test = map(np.array, [X_train, X_val, X_test])
y_train, y_val, y_test = map(np.array, [y_train, y_val, y_test])

# Print dataset statistics
print(f"Training set shape: {X_train.shape}, class balance: ASD={np.sum(y_train == 1)}, TC={np.sum(y_train == 0)}")
print(f"Validation set shape: {X_val.shape}, class balance: ASD={np.sum(y_val == 1)}, TC={np.sum(y_val == 0)}")
print(f"Test set shape: {X_test.shape}, class balance: ASD={np.sum(y_test == 1)}, TC={np.sum(y_test == 0)}")

# Build and train the model
dnn = build_model(X_train.shape[1])
history = dnn.fit(X_train, y_train, validation_data=(X_val, y_val), batch_size=32, epochs=100, callbacks=callbacks_list)

# Evaluate on training set
train_pred_prob = dnn.predict(X_train).ravel()
train_pred = (train_pred_prob > 0.5).astype(int)
acc, sens, prec, spec, auc, cm = calculate_metrics(y_train, train_pred, train_pred_prob)
print_metrics(0, "training", acc, sens, prec, spec, auc, cm)

# Evaluate on validation set
val_pred_prob = dnn.predict(X_val).ravel()
val_pred = (val_pred_prob > 0.5).astype(int)
acc, sens, prec, spec, auc, cm = calculate_metrics(y_val, val_pred, val_pred_prob)
print_metrics(0, "validation", acc, sens, prec, spec, auc, cm)

# Evaluate on test set
test_pred_prob = dnn.predict(X_test).ravel()
test_pred = (test_pred_prob > 0.5).astype(int)
acc, sens, prec, spec, auc, cm = calculate_metrics(y_test, test_pred, test_pred_prob)
print_metrics(0, "test", acc, sens, prec, spec, auc, cm)

Training set shape: (87, 6105), class balance: ASD=44, TC=43
Validation set shape: (22, 6105), class balance: ASD=11, TC=11
Test set shape: (14, 6105), class balance: ASD=7, TC=7
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/10