### SDI Calculation with normalization


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.linalg import eigh
from collections import defaultdict

# Apply seaborn's 'whitegrid' style to the figures created below.
sns.set(style='whitegrid')

def compute_unnormalized_laplacian(A):
    """Return the unnormalized graph Laplacian ``D - A`` of adjacency matrix ``A``.

    ``D`` is the diagonal matrix holding the row sums of ``A``.
    """
    row_sums = np.sum(A, axis=1)
    degree_matrix = np.diag(row_sums)
    laplacian = degree_matrix - A
    return laplacian

def preprocess_structural_connectivity(A, threshold=0.05):
    """Threshold a structural connectivity matrix and rescale it by its maximum.

    Args:
        A: Connectivity matrix; it is copied, never modified in place.
        threshold: Entries strictly below this value are set to 0.

    Returns:
        The thresholded copy divided by its largest entry, or the thresholded
        copy unchanged when that largest entry is not positive.
    """
    cleaned = np.array(A, copy=True)
    weak = cleaned < threshold
    cleaned[weak] = 0

    # Rescale only when something positive survived the threshold.
    peak = np.max(cleaned)
    return cleaned / peak if peak > 0 else cleaned


def compute_functional_connectivity(X):
    """Compute the correlation-based functional connectivity matrix of a time series.

    Args:
        X: 2-D array; the columns are correlated with one another.

    Returns:
        A square array with one row/column per column of ``X``. Entries between
        columns whose standard deviation exceeds 1e-10 hold their correlation
        coefficient, entries involving any other column are 0, and the diagonal
        is always 1. If no column has usable variance, the identity matrix is
        returned.
    """
    n_regions = X.shape[1]

    # Constant columns make the correlation undefined, so they are left out.
    valid_regions = np.std(X, axis=0) > 1e-10
    if not np.any(valid_regions):
        return np.eye(n_regions)

    with np.errstate(invalid='ignore', divide='ignore'):
        # np.corrcoef collapses to a 0-d result when only one column is valid;
        # atleast_2d keeps the block assignment below well-formed in that case.
        fc_valid = np.atleast_2d(np.corrcoef(X[:, valid_regions].T))
    # Replace any remaining NaN/inf values with 0.
    fc_valid = np.nan_to_num(fc_valid, nan=0.0, posinf=0.0, neginf=0.0)

    # Scatter the valid block into the full matrix in one vectorised assignment.
    fc_full = np.zeros((n_regions, n_regions))
    valid_idx = np.flatnonzero(valid_regions)
    fc_full[np.ix_(valid_idx, valid_idx)] = fc_valid

    # Self-correlation is 1 for every column, including the excluded ones.
    np.fill_diagonal(fc_full, 1.0)
    return fc_full

def compute_esd_and_sdi(X, A):
    """Project z-scored time series onto the Laplacian eigenbasis of ``A`` and split the energy at the median eigenvalue.

    Args:
        X: 2-D time-series array; every column is z-scored before projection.
        A: Square adjacency matrix used to build the unnormalized Laplacian.

    Returns:
        Tuple ``(eigvals, esd, median_cutoff, sdi)``: the Laplacian eigenvalues
        (clipped below at 1e-8), the summed squared projection per eigenvector,
        the median eigenvalue, and the ratio of the energy at or below the
        median to the energy above it.
    """
    n_samples, n_nodes = X.shape  # unpacking also rejects input that is not 2-D

    # Z-score each column; near-constant columns get unit scale so the division is safe.
    col_mean = np.mean(X, axis=0)
    col_scale = np.std(X, axis=0)
    col_scale[col_scale < 1e-10] = 1.0
    signals = (X - col_mean) / col_scale

    laplacian = compute_unnormalized_laplacian(A)

    # Tiny diagonal shift before the decomposition, then floor the eigenvalues.
    jitter = 1e-12 * np.eye(laplacian.shape[0])
    eigvals, eigvecs = eigh(laplacian + jitter)
    eigvals = np.clip(eigvals, 1e-8, None)

    # Energy carried by each eigenvector, summed over rows of the projection.
    esd = np.sum((signals @ eigvecs) ** 2, axis=0)

    # Median eigenvalue separates the "low" from the "high" part of the spectrum.
    median_cutoff = np.median(eigvals)
    low_freq_energy = np.sum(esd[eigvals <= median_cutoff])
    high_freq_energy = np.sum(esd[eigvals > median_cutoff])

    # An energy is unusable as a ratio term when it is zero, NaN or infinite.
    high_unusable = high_freq_energy == 0 or not np.isfinite(high_freq_energy)
    low_unusable = low_freq_energy == 0 or not np.isfinite(low_freq_energy)

    if not high_unusable:
        sdi = low_freq_energy / high_freq_energy
    elif low_unusable:
        sdi = 1.0  # both terms unusable: fall back to a neutral value
    else:
        sdi = np.inf  # only the denominator is unusable

    return eigvals, esd, median_cutoff, sdi

def extract_pid(fpath):
    """Extract the subject identifier from a file path.

    The identifier is the text up to the first underscore of the file name;
    when the name contains ``sub-``, only the text following that marker is
    considered.
    """
    base = os.path.basename(fpath)
    candidate = base.split("sub-")[1] if "sub-" in base else base
    return candidate.split("_")[0]

def load_phenotype_csv(csv_path):
    """Load a phenotype CSV and index it by a string ``PatientID`` column.

    ``PatientID`` is derived from ``SUB_ID`` when that column exists, otherwise
    from ``subject``. If neither is present, the file must already provide a
    ``PatientID`` column or ``set_index`` raises ``KeyError``.
    """
    table = pd.read_csv(csv_path)
    # First matching source column wins.
    for source_column in ('SUB_ID', 'subject'):
        if source_column in table.columns:
            table['PatientID'] = table[source_column].astype(str)
            break
    return table.set_index('PatientID')

def collect_file_maps(structural_dir, functional_dir):
    """Map subject IDs to their structural and functional CSV paths.

    The structural directory is listed non-recursively; the functional
    directory is walked recursively. When several files yield the same ID, the
    one encountered last wins.

    Returns:
        Tuple ``(struct_map, func_map)`` of ``{pid: path}`` dictionaries.
    """
    struct_map = {
        extract_pid(name): os.path.join(structural_dir, name)
        for name in os.listdir(structural_dir)
        if name.endswith('.csv')
    }
    func_map = {
        extract_pid(name): os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(functional_dir)
        for name in names
        if name.endswith('.csv')
    }
    return struct_map, func_map

def plot_connectivity_matrices(fc_matrix, sc_matrix, pid, group_label):
    """Plot FC, SC, and their difference (FC - SC) side by side and save the figure.

    Args:
        fc_matrix: Functional connectivity matrix, drawn on a fixed [-1, 1] colour scale.
        sc_matrix: Structural connectivity matrix of the same shape, drawn as passed in.
        pid: Subject identifier used in the panel titles and the output file name.
        group_label: Group label used in the output file name.

    Side effects:
        Writes ``connectivity_comparison_{pid}_{group_label}.png`` to the current
        working directory, shows the figure, and then closes it so repeated
        calls do not accumulate open figures.
    """
    diff_matrix = fc_matrix - sc_matrix

    # Symmetric colour range for the difference panel; fall back to 1.0 when
    # there is no finite value or the finite values are all zero.
    finite_diff = diff_matrix[np.isfinite(diff_matrix)]
    diff_max = np.max(np.abs(finite_diff)) if finite_diff.size else 1.0
    if diff_max == 0:
        diff_max = 1.0

    # (matrix, title suffix, colour-bar label, imshow keyword arguments)
    panels = [
        (fc_matrix, 'fMRI Functional Connectivity', 'Correlation',
         dict(cmap='RdBu_r', vmin=-1, vmax=1)),
        (sc_matrix, 'Structural Connectivity', 'Connection Strength',
         dict(cmap='Reds', vmin=0)),
        (diff_matrix, 'Difference (FC - SC)', 'Difference',
         dict(cmap='RdBu_r', vmin=-diff_max, vmax=diff_max)),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        for ax, (matrix, title, cbar_label, imshow_kwargs) in zip(axes, panels):
            image = ax.imshow(matrix, **imshow_kwargs)
            ax.set_title(f'Patient {pid}: {title}')
            ax.set_xlabel('Brain Region')
            ax.set_ylabel('Brain Region')
            cbar = fig.colorbar(image, ax=ax, shrink=0.6)
            cbar.set_label(cbar_label, rotation=270, labelpad=15)

        fig.tight_layout()
        fig.savefig(f"connectivity_comparison_{pid}_{group_label}.png", dpi=150, bbox_inches='tight')
        plt.show()
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

def plot_group_esd(group_esds, group_name, median_cutoff):
    """Draw one group's median energy spectral density and interquartile band on the current axes.

    Args:
        group_esds: Dict with ``'eigvals'`` and ``'esds'`` lists; the x positions
            are taken from the first ``'eigvals'`` entry only, and the ``'esds'``
            entries are stacked row-wise (so they must share a length).
        group_name: Legend label for the median curve.
        median_cutoff: Eigenvalue at which the shaded background switches colour.

    Note:
        Every call adds its own pair of translucent background spans, so
        calling this for several groups on the same axes stacks them.
    """
    x_vals = group_esds['eigvals'][0]
    stacked = np.vstack(group_esds['esds'])
    centre = np.median(stacked, axis=0)
    lower_q, upper_q = np.percentile(stacked, [25, 75], axis=0)

    plt.plot(x_vals, centre, label=group_name, lw=2)
    plt.fill_between(x_vals, lower_q, upper_q, alpha=0.25)

    # Background shading: one colour below the cutoff, another above it.
    shaded_spans = (
        (min(x_vals), median_cutoff, '#cceeff'),
        (median_cutoff, max(x_vals), '#ffcccc'),
    )
    for left, right, shade in shaded_spans:
        plt.axvspan(left, right, color=shade, alpha=0.6)

    # Both axes are logarithmic.
    for set_scale in (plt.xscale, plt.yscale):
        set_scale('log')

    plt.xlabel(r'$\lambda$ (eigenvalue)', fontsize=14)
    plt.ylabel(r'$\xi$ (energy spectral density)', fontsize=14)
    plt.title('Energy spectral density (split at median eigenvalue)', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.grid(True, which='both', ls='--', lw=0.5)

def process_site(structural_dir, functional_dir, phenotype_csv, group_esds, median_cutoffs, group_sdis):
    """Process every usable subject of one site and append the results to the shared accumulators.

    A subject is processed when it has a structural CSV, a functional CSV and a
    row in the phenotype table. A failure on one subject is reported and does
    not stop the remaining subjects.

    Args:
        structural_dir: Directory scanned for structural connectivity CSV files.
        functional_dir: Directory tree scanned for time-series CSV files.
        phenotype_csv: Path to the phenotype CSV; must provide ``DX_GROUP``.
        group_esds: Mapping ``label -> {'eigvals': [...], 'esds': [...]}``, appended to in place.
        median_cutoffs: List receiving each subject's median eigenvalue, appended to in place.
        group_sdis: Mapping ``label -> [sdi, ...]``, appended to in place.
    """
    structural_map, functional_map = collect_file_maps(structural_dir, functional_dir)
    phenotype_df = load_phenotype_csv(phenotype_csv)
    common_pids = sorted(set(structural_map) & set(functional_map) & set(phenotype_df.index.astype(str)))

    print(f"Found {len(common_pids)} valid subjects at site.")

    # Example connectivity figures are drawn for the first few successfully
    # processed subjects of the site, regardless of their group.
    plot_count = 0
    max_plots = 2

    for pid in common_pids:
        try:
            A = np.loadtxt(structural_map[pid], delimiter=',')
            X = np.loadtxt(functional_map[pid], delimiter=',')

            # SC must be square and have one row/column per time-series column.
            if A.ndim != 2 or X.ndim != 2 or A.shape[0] != A.shape[1] or X.shape[1] != A.shape[0]:
                print(f"Skipping {pid}: incompatible shapes (SC {A.shape}, time series {X.shape})")
                continue

            A = preprocess_structural_connectivity(A, threshold=0.05)
            eigvals, esd, cutoff_value, sdi = compute_esd_and_sdi(X, A)

            # NOTE(review): the public ABIDE phenotypic key codes DX_GROUP as
            # 1 = Autism, 2 = Control. Confirm these CSVs were recoded; if not,
            # the two labels below are swapped. Any value other than 1
            # (including a missing value) is labelled 'ASD'.
            group = phenotype_df.loc[pid, 'DX_GROUP']
            label = 'TD' if group == 1 else 'ASD'

            group_esds[label]['eigvals'].append(eigvals)
            group_esds[label]['esds'].append(esd)
            group_sdis[label].append(sdi)
            median_cutoffs.append(cutoff_value)

            if plot_count < max_plots:
                # FC is only needed for the example figure, so compute it on demand.
                FC = compute_functional_connectivity(X)
                plot_connectivity_matrices(FC, A, pid, label)
                plot_count += 1

        except Exception as e:
            # Best-effort batch run: report the failure and move on.
            print(f"Error processing {pid}: {e}")

def _default_sites_config():
    """Build the built-in list of per-site input locations."""
    base = "/Users/arnavkarnik/Documents/Classification2"
    # (name used in the time-series folder, prefix used for connectome/phenotype files)
    sites = [("bni", "BNI_1"), ("ip", "IP_1"), ("nyu1", "NYU_1"), ("nyu2", "NYU_2"), ("sdsu", "SDSU_1")]
    return [
        {
            "site": site,
            "functional_dir": f"{base}/Time_Series_ABIDE2/{site}_time_series/schaefer_400/cleaned-1",
            "structural_dir": f"{base}/SC_Connectomes_ABIDE2/{prefix}_connectomes",
            "phenotype_csv": f"{base}/Phenotypes_ABIDE2/{prefix}_phenotypes.csv",
        }
        for site, prefix in sites
    ]


def run_all_sites(sites_config=None):
    """Run the pipeline over all sites, then plot and report the TD vs ASD comparison.

    Args:
        sites_config: Optional list of dicts with the keys ``site``,
            ``functional_dir``, ``structural_dir`` and ``phenotype_csv``.
            Defaults to the built-in site list.

    Side effects:
        Prints progress and statistics, shows the figures, and writes
        ``esd_comparison_median_split.png`` and ``sdi_comparison.png`` to the
        current working directory.
    """
    from scipy.stats import ttest_ind

    if sites_config is None:
        sites_config = _default_sites_config()

    group_esds = defaultdict(lambda: {'eigvals': [], 'esds': []})
    group_sdis = defaultdict(list)
    median_cutoffs = []

    for config in sites_config:
        print(f"\n=== Processing site: {config['site'].upper()} ===")
        process_site(config["structural_dir"], config["functional_dir"], config["phenotype_csv"],
                     group_esds, median_cutoffs, group_sdis)

    if not median_cutoffs:
        # Nothing was processed: the median and every plot below would be undefined.
        print("No subjects were processed; nothing to plot or report.")
        return

    global_median_cutoff = np.median(median_cutoffs)
    print(f"\nGlobal median eigenvalue cutoff: {global_median_cutoff:.5f}")

    # --- ESD comparison -------------------------------------------------
    plt.figure(figsize=(10, 6))
    for label in ('TD', 'ASD'):
        if group_esds[label]['esds']:
            plot_group_esd(group_esds[label], label, global_median_cutoff)

    # Add labelled spans once so the legend explains the background colours.
    plt.axvspan(0, global_median_cutoff, color='#cceeff', alpha=0.6, label='Low eigenvalue region')
    plt.axvspan(global_median_cutoff, plt.xlim()[1], color='#ffcccc', alpha=0.6, label='High eigenvalue region')

    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig("esd_comparison_median_split.png", dpi=300)
    plt.show()

    # --- SDI comparison -------------------------------------------------
    # Keep only finite SDI values; they feed both the box plot and the statistics.
    td_sdi = np.asarray(group_sdis['TD'], dtype=float)
    asd_sdi = np.asarray(group_sdis['ASD'], dtype=float)
    td_sdi = td_sdi[np.isfinite(td_sdi)]
    asd_sdi = asd_sdi[np.isfinite(asd_sdi)]

    # The t-test needs at least one finite value in each group.
    p_val = None
    if len(td_sdi) > 0 and len(asd_sdi) > 0:
        _, p_val = ttest_ind(td_sdi, asd_sdi)

    sdi_df = pd.DataFrame({
        'SDI': np.concatenate([td_sdi, asd_sdi]),
        'Group': ['TD'] * len(td_sdi) + ['ASD'] * len(asd_sdi),
    })

    if len(sdi_df) > 0:
        plt.figure(figsize=(8, 6))
        sns.boxplot(data=sdi_df, x='Group', y='SDI', palette=['blue', 'red'])
        plt.title('Structure-Dynamics Index (SDI) Comparison\n(Low freq energy / High freq energy)', fontsize=14)
        plt.ylabel('SDI', fontsize=12)
        plt.xlabel('Group', fontsize=12)

        if p_val is not None:
            plt.text(0.5, plt.ylim()[1]*0.9, f'p-value: {p_val:.4f}',
                     ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat'))

        plt.tight_layout()
        plt.savefig("sdi_comparison.png", dpi=300)
        plt.show()
    else:
        print("No valid SDI values for plotting")

    # --- SDI statistics (always reported once subjects were processed) ---
    print(f"\nSDI Statistics:")
    for label, values in (('TD', td_sdi), ('ASD', asd_sdi)):
        if len(values) > 0:
            print(f"{label} SDI: mean={np.mean(values):.4f}, std={np.std(values):.4f}, n={len(values)}")
        else:
            print(f"No valid {label} SDI values found")
    if p_val is not None:
        print(f"T-test p-value: {p_val:.4f}")

    # Summary of the eigenvalue split.
    print(f"\nSplit statistics:")
    print(f"Global median eigenvalue: {global_median_cutoff:.5f}")
    print(f"Number of subjects contributing to median: {len(median_cutoffs)}")

# Entry point: run the full pipeline when this code is executed directly.
if __name__ == "__main__":
    run_all_sites()

### SDI calculation with no normalization

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.linalg import eigh
from collections import defaultdict

# Apply seaborn's 'whitegrid' style to the figures created below.
sns.set(style='whitegrid')

def compute_unnormalized_laplacian(A):
    """Return the unnormalized graph Laplacian ``D - A`` of adjacency matrix ``A``.

    ``D`` is the diagonal matrix holding the row sums of ``A``.
    """
    row_sums = np.sum(A, axis=1)
    degree_matrix = np.diag(row_sums)
    laplacian = degree_matrix - A
    return laplacian

def preprocess_structural_connectivity(A, threshold=0.05):
    """Return a copy of ``A`` in which every entry strictly below ``threshold`` is 0.

    No rescaling is applied in this variant; the input is never modified.
    """
    thresholded = np.array(A, copy=True)
    below_cutoff = thresholded < threshold
    thresholded[below_cutoff] = 0
    return thresholded


def compute_functional_connectivity(X):
    """Compute the correlation-based functional connectivity matrix of a time series.

    Args:
        X: 2-D array; the columns are correlated with one another.

    Returns:
        A square array with one row/column per column of ``X``. Entries between
        columns whose standard deviation exceeds 1e-10 hold their correlation
        coefficient, entries involving any other column are 0, and the diagonal
        is always 1. If no column has usable variance, the identity matrix is
        returned.
    """
    n_regions = X.shape[1]

    # Constant columns make the correlation undefined, so they are left out.
    valid_regions = np.std(X, axis=0) > 1e-10
    if not np.any(valid_regions):
        return np.eye(n_regions)

    with np.errstate(invalid='ignore', divide='ignore'):
        # np.corrcoef collapses to a 0-d result when only one column is valid;
        # atleast_2d keeps the block assignment below well-formed in that case.
        fc_valid = np.atleast_2d(np.corrcoef(X[:, valid_regions].T))
    # Replace any remaining NaN/inf values with 0.
    fc_valid = np.nan_to_num(fc_valid, nan=0.0, posinf=0.0, neginf=0.0)

    # Scatter the valid block into the full matrix in one vectorised assignment.
    fc_full = np.zeros((n_regions, n_regions))
    valid_idx = np.flatnonzero(valid_regions)
    fc_full[np.ix_(valid_idx, valid_idx)] = fc_valid

    # Self-correlation is 1 for every column, including the excluded ones.
    np.fill_diagonal(fc_full, 1.0)
    return fc_full

def compute_esd_and_sdi(X, A):
    """Project z-scored time series onto the Laplacian eigenbasis of ``A`` and split the energy at the median eigenvalue.

    Args:
        X: 2-D time-series array; every column is z-scored before projection.
        A: Square adjacency matrix used to build the unnormalized Laplacian.

    Returns:
        Tuple ``(eigvals, esd, median_cutoff, sdi)``: the Laplacian eigenvalues
        (clipped below at 1e-8), the summed squared projection per eigenvector,
        the median eigenvalue, and the ratio of the energy at or below the
        median to the energy above it.
    """
    n_samples, n_nodes = X.shape  # unpacking also rejects input that is not 2-D

    # Z-score each column; near-constant columns get unit scale so the division is safe.
    col_mean = np.mean(X, axis=0)
    col_scale = np.std(X, axis=0)
    col_scale[col_scale < 1e-10] = 1.0
    signals = (X - col_mean) / col_scale

    laplacian = compute_unnormalized_laplacian(A)

    # Tiny diagonal shift before the decomposition, then floor the eigenvalues.
    jitter = 1e-12 * np.eye(laplacian.shape[0])
    eigvals, eigvecs = eigh(laplacian + jitter)
    eigvals = np.clip(eigvals, 1e-8, None)

    # Energy carried by each eigenvector, summed over rows of the projection.
    esd = np.sum((signals @ eigvecs) ** 2, axis=0)

    # Median eigenvalue separates the "low" from the "high" part of the spectrum.
    median_cutoff = np.median(eigvals)
    low_freq_energy = np.sum(esd[eigvals <= median_cutoff])
    high_freq_energy = np.sum(esd[eigvals > median_cutoff])

    # An energy is unusable as a ratio term when it is zero, NaN or infinite.
    high_unusable = high_freq_energy == 0 or not np.isfinite(high_freq_energy)
    low_unusable = low_freq_energy == 0 or not np.isfinite(low_freq_energy)

    if not high_unusable:
        sdi = low_freq_energy / high_freq_energy
    elif low_unusable:
        sdi = 1.0  # both terms unusable: fall back to a neutral value
    else:
        sdi = np.inf  # only the denominator is unusable

    return eigvals, esd, median_cutoff, sdi

def extract_pid(fpath):
    """Extract the subject identifier from a file path.

    The identifier is the text up to the first underscore of the file name;
    when the name contains ``sub-``, only the text following that marker is
    considered.
    """
    base = os.path.basename(fpath)
    candidate = base.split("sub-")[1] if "sub-" in base else base
    return candidate.split("_")[0]

def load_phenotype_csv(csv_path):
    """Load a phenotype CSV and index it by a string ``PatientID`` column.

    ``PatientID`` is derived from ``SUB_ID`` when that column exists, otherwise
    from ``subject``. If neither is present, the file must already provide a
    ``PatientID`` column or ``set_index`` raises ``KeyError``.
    """
    table = pd.read_csv(csv_path)
    # First matching source column wins.
    for source_column in ('SUB_ID', 'subject'):
        if source_column in table.columns:
            table['PatientID'] = table[source_column].astype(str)
            break
    return table.set_index('PatientID')

def collect_file_maps(structural_dir, functional_dir):
    """Map subject IDs to their structural and functional CSV paths.

    The structural directory is listed non-recursively; the functional
    directory is walked recursively. When several files yield the same ID, the
    one encountered last wins.

    Returns:
        Tuple ``(struct_map, func_map)`` of ``{pid: path}`` dictionaries.
    """
    struct_map = {
        extract_pid(name): os.path.join(structural_dir, name)
        for name in os.listdir(structural_dir)
        if name.endswith('.csv')
    }
    func_map = {
        extract_pid(name): os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(functional_dir)
        for name in names
        if name.endswith('.csv')
    }
    return struct_map, func_map

def plot_connectivity_matrices(fc_matrix, sc_matrix, pid, group_label):
    """Plot FC, SC, and their difference (FC - SC) side by side and save the figure.

    Args:
        fc_matrix: Functional connectivity matrix, drawn on a fixed [-1, 1] colour scale.
        sc_matrix: Structural connectivity matrix of the same shape, drawn as passed in.
        pid: Subject identifier used in the panel titles and the output file name.
        group_label: Group label used in the output file name.

    Side effects:
        Writes ``connectivity_comparison_{pid}_{group_label}.png`` to the current
        working directory, shows the figure, and then closes it so repeated
        calls do not accumulate open figures.
    """
    diff_matrix = fc_matrix - sc_matrix

    # Symmetric colour range for the difference panel; fall back to 1.0 when
    # there is no finite value or the finite values are all zero.
    finite_diff = diff_matrix[np.isfinite(diff_matrix)]
    diff_max = np.max(np.abs(finite_diff)) if finite_diff.size else 1.0
    if diff_max == 0:
        diff_max = 1.0

    # (matrix, title suffix, colour-bar label, imshow keyword arguments)
    panels = [
        (fc_matrix, 'fMRI Functional Connectivity', 'Correlation',
         dict(cmap='RdBu_r', vmin=-1, vmax=1)),
        (sc_matrix, 'Structural Connectivity', 'Connection Strength',
         dict(cmap='Reds', vmin=0)),
        (diff_matrix, 'Difference (FC - SC)', 'Difference',
         dict(cmap='RdBu_r', vmin=-diff_max, vmax=diff_max)),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        for ax, (matrix, title, cbar_label, imshow_kwargs) in zip(axes, panels):
            image = ax.imshow(matrix, **imshow_kwargs)
            ax.set_title(f'Patient {pid}: {title}')
            ax.set_xlabel('Brain Region')
            ax.set_ylabel('Brain Region')
            cbar = fig.colorbar(image, ax=ax, shrink=0.6)
            cbar.set_label(cbar_label, rotation=270, labelpad=15)

        fig.tight_layout()
        fig.savefig(f"connectivity_comparison_{pid}_{group_label}.png", dpi=150, bbox_inches='tight')
        plt.show()
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

def plot_group_esd(group_esds, group_name, median_cutoff):
    """Draw one group's median energy spectral density and interquartile band on the current axes.

    Args:
        group_esds: Dict with ``'eigvals'`` and ``'esds'`` lists; the x positions
            are taken from the first ``'eigvals'`` entry only, and the ``'esds'``
            entries are stacked row-wise (so they must share a length).
        group_name: Legend label for the median curve.
        median_cutoff: Eigenvalue at which the shaded background switches colour.

    Note:
        Every call adds its own pair of translucent background spans, so
        calling this for several groups on the same axes stacks them.
    """
    x_vals = group_esds['eigvals'][0]
    stacked = np.vstack(group_esds['esds'])
    centre = np.median(stacked, axis=0)
    lower_q, upper_q = np.percentile(stacked, [25, 75], axis=0)

    plt.plot(x_vals, centre, label=group_name, lw=2)
    plt.fill_between(x_vals, lower_q, upper_q, alpha=0.25)

    # Background shading: one colour below the cutoff, another above it.
    shaded_spans = (
        (min(x_vals), median_cutoff, '#cceeff'),
        (median_cutoff, max(x_vals), '#ffcccc'),
    )
    for left, right, shade in shaded_spans:
        plt.axvspan(left, right, color=shade, alpha=0.6)

    # Both axes are logarithmic.
    for set_scale in (plt.xscale, plt.yscale):
        set_scale('log')

    plt.xlabel(r'$\lambda$ (eigenvalue)', fontsize=14)
    plt.ylabel(r'$\xi$ (energy spectral density)', fontsize=14)
    plt.title('Energy spectral density (split at median eigenvalue)', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.grid(True, which='both', ls='--', lw=0.5)

def process_site(structural_dir, functional_dir, phenotype_csv, group_esds, median_cutoffs, group_sdis):
    """Process every usable subject of one site and append the results to the shared accumulators.

    A subject is processed when it has a structural CSV, a functional CSV and a
    row in the phenotype table. A failure on one subject is reported and does
    not stop the remaining subjects.

    Args:
        structural_dir: Directory scanned for structural connectivity CSV files.
        functional_dir: Directory tree scanned for time-series CSV files.
        phenotype_csv: Path to the phenotype CSV; must provide ``DX_GROUP``.
        group_esds: Mapping ``label -> {'eigvals': [...], 'esds': [...]}``, appended to in place.
        median_cutoffs: List receiving each subject's median eigenvalue, appended to in place.
        group_sdis: Mapping ``label -> [sdi, ...]``, appended to in place.
    """
    structural_map, functional_map = collect_file_maps(structural_dir, functional_dir)
    phenotype_df = load_phenotype_csv(phenotype_csv)
    common_pids = sorted(set(structural_map) & set(functional_map) & set(phenotype_df.index.astype(str)))

    print(f"Found {len(common_pids)} valid subjects at site.")

    # Example connectivity figures are drawn for the first few successfully
    # processed subjects of the site, regardless of their group.
    plot_count = 0
    max_plots = 2

    for pid in common_pids:
        try:
            A = np.loadtxt(structural_map[pid], delimiter=',')
            X = np.loadtxt(functional_map[pid], delimiter=',')

            # SC must be square and have one row/column per time-series column.
            if A.ndim != 2 or X.ndim != 2 or A.shape[0] != A.shape[1] or X.shape[1] != A.shape[0]:
                print(f"Skipping {pid}: incompatible shapes (SC {A.shape}, time series {X.shape})")
                continue

            A = preprocess_structural_connectivity(A, threshold=0.05)
            eigvals, esd, cutoff_value, sdi = compute_esd_and_sdi(X, A)

            # NOTE(review): the public ABIDE phenotypic key codes DX_GROUP as
            # 1 = Autism, 2 = Control. Confirm these CSVs were recoded; if not,
            # the two labels below are swapped. Any value other than 1
            # (including a missing value) is labelled 'ASD'.
            group = phenotype_df.loc[pid, 'DX_GROUP']
            label = 'TD' if group == 1 else 'ASD'

            group_esds[label]['eigvals'].append(eigvals)
            group_esds[label]['esds'].append(esd)
            group_sdis[label].append(sdi)
            median_cutoffs.append(cutoff_value)

            if plot_count < max_plots:
                # FC is only needed for the example figure, so compute it on demand.
                FC = compute_functional_connectivity(X)
                plot_connectivity_matrices(FC, A, pid, label)
                plot_count += 1

        except Exception as e:
            # Best-effort batch run: report the failure and move on.
            print(f"Error processing {pid}: {e}")

def _default_sites_config():
    """Build the built-in list of per-site input locations."""
    base = "/Users/arnavkarnik/Documents/Classification2"
    # (name used in the time-series folder, prefix used for connectome/phenotype files)
    sites = [("bni", "BNI_1"), ("ip", "IP_1"), ("nyu1", "NYU_1"), ("nyu2", "NYU_2"), ("sdsu", "SDSU_1")]
    return [
        {
            "site": site,
            "functional_dir": f"{base}/Time_Series_ABIDE2/{site}_time_series/schaefer_400/cleaned-1",
            "structural_dir": f"{base}/SC_Connectomes_ABIDE2/{prefix}_connectomes",
            "phenotype_csv": f"{base}/Phenotypes_ABIDE2/{prefix}_phenotypes.csv",
        }
        for site, prefix in sites
    ]


def run_all_sites(sites_config=None):
    """Run the pipeline over all sites, then plot and report the TD vs ASD comparison.

    Args:
        sites_config: Optional list of dicts with the keys ``site``,
            ``functional_dir``, ``structural_dir`` and ``phenotype_csv``.
            Defaults to the built-in site list.

    Side effects:
        Prints progress and statistics, shows the figures, and writes
        ``esd_comparison_median_split.png`` and ``sdi_comparison.png`` to the
        current working directory.
    """
    from scipy.stats import ttest_ind

    if sites_config is None:
        sites_config = _default_sites_config()

    group_esds = defaultdict(lambda: {'eigvals': [], 'esds': []})
    group_sdis = defaultdict(list)
    median_cutoffs = []

    for config in sites_config:
        print(f"\n=== Processing site: {config['site'].upper()} ===")
        process_site(config["structural_dir"], config["functional_dir"], config["phenotype_csv"],
                     group_esds, median_cutoffs, group_sdis)

    if not median_cutoffs:
        # Nothing was processed: the median and every plot below would be undefined.
        print("No subjects were processed; nothing to plot or report.")
        return

    global_median_cutoff = np.median(median_cutoffs)
    print(f"\nGlobal median eigenvalue cutoff: {global_median_cutoff:.5f}")

    # --- ESD comparison -------------------------------------------------
    plt.figure(figsize=(10, 6))
    for label in ('TD', 'ASD'):
        if group_esds[label]['esds']:
            plot_group_esd(group_esds[label], label, global_median_cutoff)

    # Add labelled spans once so the legend explains the background colours.
    plt.axvspan(0, global_median_cutoff, color='#cceeff', alpha=0.6, label='Low eigenvalue region')
    plt.axvspan(global_median_cutoff, plt.xlim()[1], color='#ffcccc', alpha=0.6, label='High eigenvalue region')

    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig("esd_comparison_median_split.png", dpi=300)
    plt.show()

    # --- SDI comparison -------------------------------------------------
    # Keep only finite SDI values; they feed both the box plot and the statistics.
    td_sdi = np.asarray(group_sdis['TD'], dtype=float)
    asd_sdi = np.asarray(group_sdis['ASD'], dtype=float)
    td_sdi = td_sdi[np.isfinite(td_sdi)]
    asd_sdi = asd_sdi[np.isfinite(asd_sdi)]

    # The t-test needs at least one finite value in each group.
    p_val = None
    if len(td_sdi) > 0 and len(asd_sdi) > 0:
        _, p_val = ttest_ind(td_sdi, asd_sdi)

    sdi_df = pd.DataFrame({
        'SDI': np.concatenate([td_sdi, asd_sdi]),
        'Group': ['TD'] * len(td_sdi) + ['ASD'] * len(asd_sdi),
    })

    if len(sdi_df) > 0:
        plt.figure(figsize=(8, 6))
        sns.boxplot(data=sdi_df, x='Group', y='SDI', palette=['blue', 'red'])
        plt.title('Structure-Dynamics Index (SDI) Comparison\n(Low freq energy / High freq energy)', fontsize=14)
        plt.ylabel('SDI', fontsize=12)
        plt.xlabel('Group', fontsize=12)

        if p_val is not None:
            plt.text(0.5, plt.ylim()[1]*0.9, f'p-value: {p_val:.4f}',
                     ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat'))

        plt.tight_layout()
        plt.savefig("sdi_comparison.png", dpi=300)
        plt.show()
    else:
        print("No valid SDI values for plotting")

    # --- SDI statistics (always reported once subjects were processed) ---
    print(f"\nSDI Statistics:")
    for label, values in (('TD', td_sdi), ('ASD', asd_sdi)):
        if len(values) > 0:
            print(f"{label} SDI: mean={np.mean(values):.4f}, std={np.std(values):.4f}, n={len(values)}")
        else:
            print(f"No valid {label} SDI values found")
    if p_val is not None:
        print(f"T-test p-value: {p_val:.4f}")

    # Summary of the eigenvalue split.
    print(f"\nSplit statistics:")
    print(f"Global median eigenvalue: {global_median_cutoff:.5f}")
    print(f"Number of subjects contributing to median: {len(median_cutoffs)}")

# Entry point: run the full pipeline when this code is executed directly.
if __name__ == "__main__":
    run_all_sites()

### SDI calculation with normalization, after dropping the first eigenvalue (normalized SC)

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.linalg import eigh
from collections import defaultdict

# Apply seaborn's 'whitegrid' style to the figures created below.
sns.set(style='whitegrid')

def compute_unnormalized_laplacian(A):
    """Return the unnormalized graph Laplacian ``D - A`` of adjacency matrix ``A``.

    ``D`` is the diagonal matrix holding the row sums of ``A``.
    """
    row_sums = np.sum(A, axis=1)
    degree_matrix = np.diag(row_sums)
    laplacian = degree_matrix - A
    return laplacian

def preprocess_structural_connectivity(A, threshold=0.05):
    """Threshold a structural connectivity matrix and rescale it by its maximum.

    Args:
        A: Connectivity matrix; it is copied, never modified in place.
        threshold: Entries strictly below this value are set to 0.

    Returns:
        The thresholded copy divided by its largest entry, or the thresholded
        copy unchanged when that largest entry is not positive.
    """
    cleaned = np.array(A, copy=True)
    weak = cleaned < threshold
    cleaned[weak] = 0

    # Rescale only when something positive survived the threshold.
    peak = np.max(cleaned)
    return cleaned / peak if peak > 0 else cleaned


def compute_functional_connectivity(X):
    """Compute the correlation-based functional connectivity matrix of a time series.

    Args:
        X: 2-D array; the columns are correlated with one another.

    Returns:
        A square array with one row/column per column of ``X``. Entries between
        columns whose standard deviation exceeds 1e-10 hold their correlation
        coefficient, entries involving any other column are 0, and the diagonal
        is always 1. If no column has usable variance, the identity matrix is
        returned.
    """
    n_regions = X.shape[1]

    # Constant columns make the correlation undefined, so they are left out.
    valid_regions = np.std(X, axis=0) > 1e-10
    if not np.any(valid_regions):
        return np.eye(n_regions)

    with np.errstate(invalid='ignore', divide='ignore'):
        # np.corrcoef collapses to a 0-d result when only one column is valid;
        # atleast_2d keeps the block assignment below well-formed in that case.
        fc_valid = np.atleast_2d(np.corrcoef(X[:, valid_regions].T))
    # Replace any remaining NaN/inf values with 0.
    fc_valid = np.nan_to_num(fc_valid, nan=0.0, posinf=0.0, neginf=0.0)

    # Scatter the valid block into the full matrix in one vectorised assignment.
    fc_full = np.zeros((n_regions, n_regions))
    valid_idx = np.flatnonzero(valid_regions)
    fc_full[np.ix_(valid_idx, valid_idx)] = fc_valid

    # Self-correlation is 1 for every column, including the excluded ones.
    np.fill_diagonal(fc_full, 1.0)
    return fc_full

def compute_esd_and_sdi(X, A):
    """Compute the energy spectral density of ``X`` on the graph ``A`` and the SDI.

    ``X`` (unpacked as ``T x N``) is z-scored per column, projected onto the
    eigenvectors of the unnormalized Laplacian of ``A`` (first eigenpair
    dropped), and the per-mode energy is split at the median of the
    remaining eigenvalues.

    Returns
    -------
    tuple
        ``(eigvals, esd, median_cutoff, sdi)`` where ``eigvals`` and ``esd``
        have one entry per retained eigenmode and ``sdi`` is the ratio of
        low-band to high-band energy.
    """
    T, N = X.shape

    # z-score every column; (near-)constant columns get a unit divisor so
    # they become zeros instead of NaN.
    column_mean = np.mean(X, axis=0)
    column_std = np.std(X, axis=0)
    column_std[column_std < 1e-10] = 1.0
    X = (X - column_mean) / column_std

    laplacian = compute_unnormalized_laplacian(A)

    # Tiny diagonal shift to avoid numerical issues in the decomposition.
    shifted = laplacian + 1e-12 * np.eye(laplacian.shape[0])
    spectrum, modes = eigh(shifted)

    # Drop the first eigenpair (usually the eigenvalue ≈ 0 mode).
    eigvals, eigvecs = spectrum[1:], modes[:, 1:]

    # Graph Fourier coefficients, (T x N) @ (N x N-1), then energy per mode.
    coefficients = X @ eigvecs
    esd = np.sum(coefficients**2, axis=0)

    # Median eigenvalue separates the "low" and "high" graph frequencies.
    median_cutoff = np.median(eigvals)
    low_freq_energy = np.sum(esd[eigvals <= median_cutoff])
    high_freq_energy = np.sum(esd[eigvals > median_cutoff])

    # SDI = low / high; degenerate denominators fall back to 1.0 (both bands
    # unusable) or +inf (only the high band unusable).
    high_usable = high_freq_energy != 0 and np.isfinite(high_freq_energy)
    if high_usable:
        sdi = low_freq_energy / high_freq_energy
    elif low_freq_energy == 0 or not np.isfinite(low_freq_energy):
        sdi = 1.0
    else:
        sdi = np.inf

    return eigvals, esd, median_cutoff, sdi


def extract_pid(fpath):
    """Extract the subject ID from a connectome / time-series file name.

    The ID is the text after the first ``sub-`` marker (or from the start of
    the file name when there is no marker) up to the next underscore. When
    no underscore follows the ID, the file extension is stripped, so
    ``sub-29006.csv`` yields ``29006`` rather than ``29006.csv``.
    """
    fname = os.path.basename(fpath)
    stem = fname.split("sub-")[1] if "sub-" in fname else fname
    pid, underscore, _ = stem.partition("_")
    if not underscore:
        # Nothing terminated the ID, so an extension may still be attached.
        pid = os.path.splitext(pid)[0]
    return pid

def load_phenotype_csv(csv_path):
    """Load a phenotype table indexed by a string ``PatientID``.

    ``PatientID`` is filled from ``SUB_ID`` when that column exists,
    otherwise from ``subject``. If neither is present the table must already
    contain a ``PatientID`` column, or ``set_index`` raises ``KeyError``.
    """
    table = pd.read_csv(csv_path)
    for id_column in ('SUB_ID', 'subject'):
        if id_column in table.columns:
            table['PatientID'] = table[id_column].astype(str)
            break
    return table.set_index('PatientID')

def collect_file_maps(structural_dir, functional_dir):
    """Map subject IDs to their structural and functional CSV paths.

    Structural files are taken from the top level of ``structural_dir``;
    functional files are searched recursively under ``functional_dir``.
    When several files resolve to the same ID, the last one visited wins.
    """
    struct_map = {
        extract_pid(name): os.path.join(structural_dir, name)
        for name in os.listdir(structural_dir)
        if name.endswith('.csv')
    }

    func_map = {}
    for folder, _subdirs, names in os.walk(functional_dir):
        func_map.update(
            (extract_pid(name), os.path.join(folder, name))
            for name in names
            if name.endswith('.csv')
        )

    return struct_map, func_map

def plot_connectivity_matrices(fc_matrix, sc_matrix, pid, group_label):
    """Draw FC, SC and their difference (FC - SC) side by side.

    The figure is saved as ``connectivity_comparison_{pid}_{group_label}.png``
    in the current working directory and then shown.
    """
    fig, (fc_ax, sc_ax, diff_ax) = plt.subplots(1, 3, figsize=(18, 5))

    def label_panel(ax, image, title, colorbar_label):
        # Axis labels and colorbar setup shared by all three panels.
        ax.set_title(title)
        ax.set_xlabel('Brain Region')
        ax.set_ylabel('Brain Region')
        colorbar = plt.colorbar(image, ax=ax, shrink=0.6)
        colorbar.set_label(colorbar_label, rotation=270, labelpad=15)

    # Functional connectivity on a fixed [-1, 1] colour range.
    fc_image = fc_ax.imshow(fc_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
    label_panel(fc_ax, fc_image, f'Patient {pid}: fMRI Functional Connectivity', 'Correlation')

    # Structural connectivity, drawn as passed in.
    sc_image = sc_ax.imshow(sc_matrix, cmap='Reds', vmin=0)
    label_panel(sc_ax, sc_image, f'Patient {pid}: Structural Connectivity', 'Connection Strength')

    # Difference panel: symmetric colour range taken from the largest finite
    # absolute difference, falling back to 1.0 for empty or all-zero input.
    delta = fc_matrix - sc_matrix
    finite = np.isfinite(delta)
    if np.any(finite):
        limit = np.max(np.abs(delta[finite]))
        limit = 1.0 if limit == 0 else limit
    else:
        limit = 1.0

    diff_image = diff_ax.imshow(delta, cmap='RdBu_r', vmin=-limit, vmax=limit)
    label_panel(diff_ax, diff_image, f'Patient {pid}: Difference (FC - SC)', 'Difference')

    plt.tight_layout()
    plt.savefig(f"connectivity_comparison_{pid}_{group_label}.png", dpi=150, bbox_inches='tight')
    plt.show()

def plot_group_esd(group_esds, group_name, median_cutoff):
    """Plot one group's median ESD curve with its interquartile band.

    Draws onto the current matplotlib figure. The x-axis uses the
    eigenvalues of the first subject stored in ``group_esds['eigvals']``;
    the curve and band are computed across all rows of
    ``group_esds['esds']``.
    """
    # The first subject's (already trimmed) eigenvalues serve as the x-axis.
    eigvals = group_esds['eigvals'][0]

    # One row per subject, one column per retained eigenmode.
    stacked = np.vstack(group_esds['esds'])
    median_esd = np.median(stacked, axis=0)
    lower_q, upper_q = np.percentile(stacked, [25, 75], axis=0)

    plt.plot(eigvals, median_esd, label=group_name, lw=2)
    plt.fill_between(eigvals, lower_q, upper_q, alpha=0.25)

    # Tint the bands on either side of the cutoff.
    lowest, highest = min(eigvals), max(eigvals)
    plt.axvspan(lowest, median_cutoff, color='#cceeff', alpha=0.6)
    plt.axvspan(median_cutoff, highest, color='#ffcccc', alpha=0.6)

    # Log-log axes and labelling.
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel(r'$\lambda$ (eigenvalue)', fontsize=14)
    plt.ylabel(r'$\xi$ (energy spectral density)', fontsize=14)
    plt.title('Energy Spectral Density (excluding 1st eigenvalue)', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.grid(True, which='both', ls='--', lw=0.5)


def process_site(structural_dir, functional_dir, phenotype_csv, group_esds, median_cutoffs, group_sdis):
    """Run the ESD/SDI pipeline for every usable subject of one site.

    A subject is processed when it has a structural CSV, a functional CSV
    and a row in the phenotype table. Results are appended in place to the
    shared accumulators:

    * ``group_esds[label]['eigvals']`` / ``['esds']`` - per-subject spectra,
    * ``group_sdis[label]`` - per-subject SDI values,
    * ``median_cutoffs`` - per-subject cutoff eigenvalues.

    Subjects whose matrices have incompatible shapes are reported and
    skipped; any other per-subject failure is reported and processing
    continues with the next subject.
    """
    structural_map, functional_map = collect_file_maps(structural_dir, functional_dir)
    phenotype_df = load_phenotype_csv(phenotype_csv)
    common_pids = sorted(set(structural_map) & set(functional_map) & set(phenotype_df.index.astype(str)))

    print(f"Found {len(common_pids)} valid subjects at site.")

    # Example connectivity figures: at most ``max_plots`` per site (the
    # counter is shared by both groups).
    plot_count = 0
    max_plots = 2

    for pid in common_pids:
        try:
            A = np.loadtxt(structural_map[pid], delimiter=',')
            X = np.loadtxt(functional_map[pid], delimiter=',')

            # SC must be square and match the number of time-series columns.
            shapes_ok = (
                A.ndim == 2
                and X.ndim == 2
                and A.shape[0] == A.shape[1]
                and X.shape[1] == A.shape[0]
            )
            if not shapes_ok:
                print(f"Skipping {pid}: incompatible shapes (SC {A.shape}, time series {X.shape})")
                continue

            A = preprocess_structural_connectivity(A, threshold=0.05)
            eigvals, esd, cutoff_value, sdi = compute_esd_and_sdi(X, A)

            # NOTE(review): the official ABIDE phenotypic key codes DX_GROUP
            # as 1 = Autism, 2 = Control, whereas this mapping treats 1 as
            # TD - confirm against the phenotype CSVs used here.
            group = phenotype_df.loc[pid, 'DX_GROUP']
            label = 'TD' if group == 1 else 'ASD'

            group_esds[label]['eigvals'].append(eigvals)
            group_esds[label]['esds'].append(esd)
            group_sdis[label].append(sdi)

            median_cutoffs.append(cutoff_value)

            # FC is only needed for the example figures, so it is computed
            # just for the subjects that are actually plotted.
            if plot_count < max_plots:
                FC = compute_functional_connectivity(X)
                plot_connectivity_matrices(FC, A, pid, label)
                plot_count += 1

        except Exception as e:
            # Best-effort batch run: report the failure and keep going.
            print(f"Error processing {pid}: {e}")

def run_all_sites(sites_config=None):
    """Run the SDI analysis over several sites and produce summary output.

    Parameters
    ----------
    sites_config : list of dict, optional
        One dict per site with the keys ``"site"``, ``"functional_dir"``,
        ``"structural_dir"`` and ``"phenotype_csv"``. When omitted, the
        five built-in site locations (bni, ip, nyu1, nyu2, sdsu) are used.

    Side effects
    ------------
    Saves ``esd_comparison_median_split.png`` and ``sdi_comparison.png`` in
    the current working directory, shows both figures, and prints per-group
    SDI statistics. Returns ``None``; if no subject could be processed it
    prints a message and returns before plotting.
    """
    if sites_config is None:
        base = "/Users/arnavkarnik/Documents/Classification2"
        site_tags = [
            ("bni", "BNI_1"),
            ("ip", "IP_1"),
            ("nyu1", "NYU_1"),
            ("nyu2", "NYU_2"),
            ("sdsu", "SDSU_1"),
        ]
        sites_config = [
            {
                "site": site,
                "functional_dir": f"{base}/Time_Series_ABIDE2/{site}_time_series/schaefer_400/cleaned-1",
                "structural_dir": f"{base}/SC_Connectomes_ABIDE2/{tag}_connectomes",
                "phenotype_csv": f"{base}/Phenotypes_ABIDE2/{tag}_phenotypes.csv",
            }
            for site, tag in site_tags
        ]

    group_esds = defaultdict(lambda: {'eigvals': [], 'esds': []})
    group_sdis = defaultdict(list)
    median_cutoffs = []

    for config in sites_config:
        print(f"\n=== Processing site: {config['site'].upper()} ===")
        process_site(config["structural_dir"], config["functional_dir"], config["phenotype_csv"], group_esds, median_cutoffs, group_sdis)

    # Nothing was processed: there is no cutoff to compute and nothing to plot.
    if not median_cutoffs:
        print("No SDI data available")
        return

    global_median_cutoff = np.median(median_cutoffs)
    print(f"\nGlobal median eigenvalue cutoff: {global_median_cutoff:.5f}")

    # Plot ESD comparison
    plt.figure(figsize=(10, 6))
    for label in ['TD', 'ASD']:
        if group_esds[label]['esds']:
            plot_group_esd(group_esds[label], label, global_median_cutoff)

    # Add region labels once for the entire plot
    plt.axvspan(0, global_median_cutoff, color='#cceeff', alpha=0.6, label='Low eigenvalue region')
    plt.axvspan(global_median_cutoff, plt.xlim()[1], color='#ffcccc', alpha=0.6, label='High eigenvalue region')

    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig("esd_comparison_median_split.png", dpi=300)
    plt.show()

    # Keep only finite SDI values; infinite/NaN ratios cannot be plotted or
    # averaged.
    finite_sdis = {}
    for label in ['TD', 'ASD']:
        values = np.asarray(group_sdis[label], dtype=float)
        finite_sdis[label] = values[np.isfinite(values)]
    td_sdi, asd_sdi = finite_sdis['TD'], finite_sdis['ASD']

    # The group comparison is computed once and reused by plot and printout.
    p_val = None
    if len(td_sdi) > 0 and len(asd_sdi) > 0:
        from scipy.stats import ttest_ind
        _, p_val = ttest_ind(td_sdi, asd_sdi)

    # Plot SDI comparison
    if len(td_sdi) + len(asd_sdi) > 0:
        sdi_df = pd.DataFrame({
            'SDI': np.concatenate([td_sdi, asd_sdi]),
            'Group': ['TD'] * len(td_sdi) + ['ASD'] * len(asd_sdi),
        })

        plt.figure(figsize=(8, 6))
        sns.boxplot(data=sdi_df, x='Group', y='SDI', palette=['blue', 'red'])
        plt.title('Structure-Dynamics Index (SDI) Comparison\n(Low freq energy / High freq energy)', fontsize=14)
        plt.ylabel('SDI', fontsize=12)
        plt.xlabel('Group', fontsize=12)

        if p_val is not None:
            plt.text(0.5, plt.ylim()[1]*0.9, f'p-value: {p_val:.4f}',
                     ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat'))

        plt.tight_layout()
        plt.savefig("sdi_comparison.png", dpi=300)
        plt.show()
    else:
        print("No valid SDI values for plotting")

    # Print SDI statistics (always, independent of whether a plot was drawn)
    print("\nSDI Statistics:")
    for label, values in (('TD', td_sdi), ('ASD', asd_sdi)):
        if len(values) > 0:
            print(f"{label} SDI: mean={np.mean(values):.4f}, std={np.std(values):.4f}, n={len(values)}")
        else:
            print(f"No valid {label} SDI values found")
    if p_val is not None:
        print(f"T-test p-value: {p_val:.4f}")

    # Optional: Print some statistics about the split
    print("\nSplit statistics:")
    print(f"Global median eigenvalue: {global_median_cutoff:.5f}")
    print(f"Number of subjects contributing to median: {len(median_cutoffs)}")

# Cell entry point: run the multi-site analysis when this cell/module is
# executed directly (``__name__`` is ``"__main__"``), not when imported.
if __name__ == "__main__":
    run_all_sites()

### SDI calculation with normalization after dropping the first eigenvalue -> Not Normalized

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.linalg import eigh
from collections import defaultdict

sns.set(style='whitegrid')

def compute_unnormalized_laplacian(A):
    """Return the unnormalized graph Laplacian ``D - A``.

    ``D`` is the diagonal matrix holding the row sums of ``A``.
    """
    row_sums = np.sum(A, axis=1)
    degree = np.diag(row_sums)
    return degree - A

def preprocess_structural_connectivity(A, threshold=0.05):
    """Return an unmodified copy of ``A`` (no thresholding, no rescaling).

    ``threshold`` is accepted but not used by this variant.
    """
    return np.copy(A)


def compute_functional_connectivity(X):
    """Compute a Pearson-correlation functional connectivity matrix.

    Parameters
    ----------
    X : array-like, shape (T, N)
        Time series with one column per region; columns are the variables
        that get correlated.

    Returns
    -------
    numpy.ndarray, shape (N, N)
        Correlations between the regions whose standard deviation exceeds
        1e-10. Rows/columns of (near-)constant regions are zero off the
        diagonal, and the diagonal is always 1. When no region has usable
        variance the identity matrix is returned.
    """
    X = np.asarray(X)
    n_regions = X.shape[1]

    # Constant regions have undefined correlations, so they are excluded.
    valid_regions = np.std(X, axis=0) > 1e-10
    if not np.any(valid_regions):
        return np.eye(n_regions)

    with np.errstate(invalid='ignore', divide='ignore'):
        # np.corrcoef collapses to a 0-d scalar when only one region is
        # valid; atleast_2d keeps the result a (k, k) matrix in every case.
        fc_valid = np.atleast_2d(np.corrcoef(X[:, valid_regions].T))
    fc_valid = np.nan_to_num(fc_valid, nan=0.0, posinf=0.0, neginf=0.0)

    # Scatter the valid-region correlations into the full N x N matrix.
    fc_full = np.zeros((n_regions, n_regions))
    valid_idx = np.flatnonzero(valid_regions)
    fc_full[np.ix_(valid_idx, valid_idx)] = fc_valid

    # Self-correlation is 1 for every region, including the excluded ones.
    np.fill_diagonal(fc_full, 1.0)
    return fc_full

def compute_esd_and_sdi(X, A):
    """Compute the energy spectral density of ``X`` on the graph ``A`` and the SDI.

    ``X`` (unpacked as ``T x N``) is z-scored per column, projected onto the
    eigenvectors of the unnormalized Laplacian of ``A`` (first eigenpair
    dropped), and the per-mode energy is split at the median of the
    remaining eigenvalues.

    Returns
    -------
    tuple
        ``(eigvals, esd, median_cutoff, sdi)`` where ``eigvals`` and ``esd``
        have one entry per retained eigenmode and ``sdi`` is the ratio of
        low-band to high-band energy.
    """
    T, N = X.shape

    # z-score every column; (near-)constant columns get a unit divisor so
    # they become zeros instead of NaN.
    column_mean = np.mean(X, axis=0)
    column_std = np.std(X, axis=0)
    column_std[column_std < 1e-10] = 1.0
    X = (X - column_mean) / column_std

    laplacian = compute_unnormalized_laplacian(A)

    # Tiny diagonal shift to avoid numerical issues in the decomposition.
    shifted = laplacian + 1e-12 * np.eye(laplacian.shape[0])
    spectrum, modes = eigh(shifted)

    # Drop the first eigenpair (usually the eigenvalue ≈ 0 mode).
    eigvals, eigvecs = spectrum[1:], modes[:, 1:]

    # Graph Fourier coefficients, (T x N) @ (N x N-1), then energy per mode.
    coefficients = X @ eigvecs
    esd = np.sum(coefficients**2, axis=0)

    # Median eigenvalue separates the "low" and "high" graph frequencies.
    median_cutoff = np.median(eigvals)
    low_freq_energy = np.sum(esd[eigvals <= median_cutoff])
    high_freq_energy = np.sum(esd[eigvals > median_cutoff])

    # SDI = low / high; degenerate denominators fall back to 1.0 (both bands
    # unusable) or +inf (only the high band unusable).
    high_usable = high_freq_energy != 0 and np.isfinite(high_freq_energy)
    if high_usable:
        sdi = low_freq_energy / high_freq_energy
    elif low_freq_energy == 0 or not np.isfinite(low_freq_energy):
        sdi = 1.0
    else:
        sdi = np.inf

    return eigvals, esd, median_cutoff, sdi


def extract_pid(fpath):
    """Extract the subject ID from a connectome / time-series file name.

    The ID is the text after the first ``sub-`` marker (or from the start of
    the file name when there is no marker) up to the next underscore. When
    no underscore follows the ID, the file extension is stripped, so
    ``sub-29006.csv`` yields ``29006`` rather than ``29006.csv``.
    """
    fname = os.path.basename(fpath)
    stem = fname.split("sub-")[1] if "sub-" in fname else fname
    pid, underscore, _ = stem.partition("_")
    if not underscore:
        # Nothing terminated the ID, so an extension may still be attached.
        pid = os.path.splitext(pid)[0]
    return pid

def load_phenotype_csv(csv_path):
    """Load a phenotype table indexed by a string ``PatientID``.

    ``PatientID`` is filled from ``SUB_ID`` when that column exists,
    otherwise from ``subject``. If neither is present the table must already
    contain a ``PatientID`` column, or ``set_index`` raises ``KeyError``.
    """
    table = pd.read_csv(csv_path)
    for id_column in ('SUB_ID', 'subject'):
        if id_column in table.columns:
            table['PatientID'] = table[id_column].astype(str)
            break
    return table.set_index('PatientID')

def collect_file_maps(structural_dir, functional_dir):
    """Map subject IDs to their structural and functional CSV paths.

    Structural files are taken from the top level of ``structural_dir``;
    functional files are searched recursively under ``functional_dir``.
    When several files resolve to the same ID, the last one visited wins.
    """
    struct_map = {
        extract_pid(name): os.path.join(structural_dir, name)
        for name in os.listdir(structural_dir)
        if name.endswith('.csv')
    }

    func_map = {}
    for folder, _subdirs, names in os.walk(functional_dir):
        func_map.update(
            (extract_pid(name), os.path.join(folder, name))
            for name in names
            if name.endswith('.csv')
        )

    return struct_map, func_map

def plot_connectivity_matrices(fc_matrix, sc_matrix, pid, group_label):
    """Draw FC, SC and their difference (FC - SC) side by side.

    The figure is saved as ``connectivity_comparison_{pid}_{group_label}.png``
    in the current working directory and then shown.
    """
    fig, (fc_ax, sc_ax, diff_ax) = plt.subplots(1, 3, figsize=(18, 5))

    def label_panel(ax, image, title, colorbar_label):
        # Axis labels and colorbar setup shared by all three panels.
        ax.set_title(title)
        ax.set_xlabel('Brain Region')
        ax.set_ylabel('Brain Region')
        colorbar = plt.colorbar(image, ax=ax, shrink=0.6)
        colorbar.set_label(colorbar_label, rotation=270, labelpad=15)

    # Functional connectivity on a fixed [-1, 1] colour range.
    fc_image = fc_ax.imshow(fc_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
    label_panel(fc_ax, fc_image, f'Patient {pid}: fMRI Functional Connectivity', 'Correlation')

    # Structural connectivity, drawn as passed in.
    sc_image = sc_ax.imshow(sc_matrix, cmap='Reds', vmin=0)
    label_panel(sc_ax, sc_image, f'Patient {pid}: Structural Connectivity', 'Connection Strength')

    # Difference panel: symmetric colour range taken from the largest finite
    # absolute difference, falling back to 1.0 for empty or all-zero input.
    delta = fc_matrix - sc_matrix
    finite = np.isfinite(delta)
    if np.any(finite):
        limit = np.max(np.abs(delta[finite]))
        limit = 1.0 if limit == 0 else limit
    else:
        limit = 1.0

    diff_image = diff_ax.imshow(delta, cmap='RdBu_r', vmin=-limit, vmax=limit)
    label_panel(diff_ax, diff_image, f'Patient {pid}: Difference (FC - SC)', 'Difference')

    plt.tight_layout()
    plt.savefig(f"connectivity_comparison_{pid}_{group_label}.png", dpi=150, bbox_inches='tight')
    plt.show()

def plot_group_esd(group_esds, group_name, median_cutoff):
    """Plot one group's median ESD curve with its interquartile band.

    Draws onto the current matplotlib figure. The x-axis uses the
    eigenvalues of the first subject stored in ``group_esds['eigvals']``;
    the curve and band are computed across all rows of
    ``group_esds['esds']``.
    """
    # The first subject's (already trimmed) eigenvalues serve as the x-axis.
    eigvals = group_esds['eigvals'][0]

    # One row per subject, one column per retained eigenmode.
    stacked = np.vstack(group_esds['esds'])
    median_esd = np.median(stacked, axis=0)
    lower_q, upper_q = np.percentile(stacked, [25, 75], axis=0)

    plt.plot(eigvals, median_esd, label=group_name, lw=2)
    plt.fill_between(eigvals, lower_q, upper_q, alpha=0.25)

    # Tint the bands on either side of the cutoff.
    lowest, highest = min(eigvals), max(eigvals)
    plt.axvspan(lowest, median_cutoff, color='#cceeff', alpha=0.6)
    plt.axvspan(median_cutoff, highest, color='#ffcccc', alpha=0.6)

    # Log-log axes and labelling.
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel(r'$\lambda$ (eigenvalue)', fontsize=14)
    plt.ylabel(r'$\xi$ (energy spectral density)', fontsize=14)
    plt.title('Energy Spectral Density (excluding 1st eigenvalue)', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.grid(True, which='both', ls='--', lw=0.5)


def process_site(structural_dir, functional_dir, phenotype_csv, group_esds, median_cutoffs, group_sdis):
    """Run the ESD/SDI pipeline for every usable subject of one site.

    A subject is processed when it has a structural CSV, a functional CSV
    and a row in the phenotype table. Results are appended in place to the
    shared accumulators:

    * ``group_esds[label]['eigvals']`` / ``['esds']`` - per-subject spectra,
    * ``group_sdis[label]`` - per-subject SDI values,
    * ``median_cutoffs`` - per-subject cutoff eigenvalues.

    Subjects whose matrices have incompatible shapes are reported and
    skipped; any other per-subject failure is reported and processing
    continues with the next subject.
    """
    structural_map, functional_map = collect_file_maps(structural_dir, functional_dir)
    phenotype_df = load_phenotype_csv(phenotype_csv)
    common_pids = sorted(set(structural_map) & set(functional_map) & set(phenotype_df.index.astype(str)))

    print(f"Found {len(common_pids)} valid subjects at site.")

    # Example connectivity figures: at most ``max_plots`` per site (the
    # counter is shared by both groups).
    plot_count = 0
    max_plots = 2

    for pid in common_pids:
        try:
            A = np.loadtxt(structural_map[pid], delimiter=',')
            X = np.loadtxt(functional_map[pid], delimiter=',')

            # SC must be square and match the number of time-series columns.
            shapes_ok = (
                A.ndim == 2
                and X.ndim == 2
                and A.shape[0] == A.shape[1]
                and X.shape[1] == A.shape[0]
            )
            if not shapes_ok:
                print(f"Skipping {pid}: incompatible shapes (SC {A.shape}, time series {X.shape})")
                continue

            # In this cell the preprocessing step only copies the matrix.
            A = preprocess_structural_connectivity(A, threshold=0.05)
            eigvals, esd, cutoff_value, sdi = compute_esd_and_sdi(X, A)

            # NOTE(review): the official ABIDE phenotypic key codes DX_GROUP
            # as 1 = Autism, 2 = Control, whereas this mapping treats 1 as
            # TD - confirm against the phenotype CSVs used here.
            group = phenotype_df.loc[pid, 'DX_GROUP']
            label = 'TD' if group == 1 else 'ASD'

            group_esds[label]['eigvals'].append(eigvals)
            group_esds[label]['esds'].append(esd)
            group_sdis[label].append(sdi)

            median_cutoffs.append(cutoff_value)

            # FC is only needed for the example figures, so it is computed
            # just for the subjects that are actually plotted.
            if plot_count < max_plots:
                FC = compute_functional_connectivity(X)
                plot_connectivity_matrices(FC, A, pid, label)
                plot_count += 1

        except Exception as e:
            # Best-effort batch run: report the failure and keep going.
            print(f"Error processing {pid}: {e}")

def run_all_sites(sites_config=None):
    """Run the SDI analysis over several sites and produce summary output.

    Parameters
    ----------
    sites_config : list of dict, optional
        One dict per site with the keys ``"site"``, ``"functional_dir"``,
        ``"structural_dir"`` and ``"phenotype_csv"``. When omitted, the
        five built-in site locations (bni, ip, nyu1, nyu2, sdsu) are used.

    Side effects
    ------------
    Saves ``esd_comparison_median_split.png`` and ``sdi_comparison.png`` in
    the current working directory, shows both figures, and prints per-group
    SDI statistics. Returns ``None``; if no subject could be processed it
    prints a message and returns before plotting.
    """
    if sites_config is None:
        base = "/Users/arnavkarnik/Documents/Classification2"
        site_tags = [
            ("bni", "BNI_1"),
            ("ip", "IP_1"),
            ("nyu1", "NYU_1"),
            ("nyu2", "NYU_2"),
            ("sdsu", "SDSU_1"),
        ]
        sites_config = [
            {
                "site": site,
                "functional_dir": f"{base}/Time_Series_ABIDE2/{site}_time_series/schaefer_400/cleaned-1",
                "structural_dir": f"{base}/SC_Connectomes_ABIDE2/{tag}_connectomes",
                "phenotype_csv": f"{base}/Phenotypes_ABIDE2/{tag}_phenotypes.csv",
            }
            for site, tag in site_tags
        ]

    group_esds = defaultdict(lambda: {'eigvals': [], 'esds': []})
    group_sdis = defaultdict(list)
    median_cutoffs = []

    for config in sites_config:
        print(f"\n=== Processing site: {config['site'].upper()} ===")
        process_site(config["structural_dir"], config["functional_dir"], config["phenotype_csv"], group_esds, median_cutoffs, group_sdis)

    # Nothing was processed: there is no cutoff to compute and nothing to plot.
    if not median_cutoffs:
        print("No SDI data available")
        return

    global_median_cutoff = np.median(median_cutoffs)
    print(f"\nGlobal median eigenvalue cutoff: {global_median_cutoff:.5f}")

    # Plot ESD comparison
    plt.figure(figsize=(10, 6))
    for label in ['TD', 'ASD']:
        if group_esds[label]['esds']:
            plot_group_esd(group_esds[label], label, global_median_cutoff)

    # Add region labels once for the entire plot
    plt.axvspan(0, global_median_cutoff, color='#cceeff', alpha=0.6, label='Low eigenvalue region')
    plt.axvspan(global_median_cutoff, plt.xlim()[1], color='#ffcccc', alpha=0.6, label='High eigenvalue region')

    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig("esd_comparison_median_split.png", dpi=300)
    plt.show()

    # Keep only finite SDI values; infinite/NaN ratios cannot be plotted or
    # averaged.
    finite_sdis = {}
    for label in ['TD', 'ASD']:
        values = np.asarray(group_sdis[label], dtype=float)
        finite_sdis[label] = values[np.isfinite(values)]
    td_sdi, asd_sdi = finite_sdis['TD'], finite_sdis['ASD']

    # The group comparison is computed once and reused by plot and printout.
    p_val = None
    if len(td_sdi) > 0 and len(asd_sdi) > 0:
        from scipy.stats import ttest_ind
        _, p_val = ttest_ind(td_sdi, asd_sdi)

    # Plot SDI comparison
    if len(td_sdi) + len(asd_sdi) > 0:
        sdi_df = pd.DataFrame({
            'SDI': np.concatenate([td_sdi, asd_sdi]),
            'Group': ['TD'] * len(td_sdi) + ['ASD'] * len(asd_sdi),
        })

        plt.figure(figsize=(8, 6))
        sns.boxplot(data=sdi_df, x='Group', y='SDI', palette=['blue', 'red'])
        plt.title('Structure-Dynamics Index (SDI) Comparison\n(Low freq energy / High freq energy)', fontsize=14)
        plt.ylabel('SDI', fontsize=12)
        plt.xlabel('Group', fontsize=12)

        if p_val is not None:
            plt.text(0.5, plt.ylim()[1]*0.9, f'p-value: {p_val:.4f}',
                     ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat'))

        plt.tight_layout()
        plt.savefig("sdi_comparison.png", dpi=300)
        plt.show()
    else:
        print("No valid SDI values for plotting")

    # Print SDI statistics (always, independent of whether a plot was drawn)
    print("\nSDI Statistics:")
    for label, values in (('TD', td_sdi), ('ASD', asd_sdi)):
        if len(values) > 0:
            print(f"{label} SDI: mean={np.mean(values):.4f}, std={np.std(values):.4f}, n={len(values)}")
        else:
            print(f"No valid {label} SDI values found")
    if p_val is not None:
        print(f"T-test p-value: {p_val:.4f}")

    # Optional: Print some statistics about the split
    print("\nSplit statistics:")
    print(f"Global median eigenvalue: {global_median_cutoff:.5f}")
    print(f"Number of subjects contributing to median: {len(median_cutoffs)}")

# Cell entry point: run the multi-site analysis when this cell/module is
# executed directly (``__name__`` is ``"__main__"``), not when imported.
if __name__ == "__main__":
    run_all_sites()

### Split based on energy and not eigenvalue -> Normalized

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.linalg import eigh
from collections import defaultdict

sns.set(style='whitegrid')

def compute_unnormalized_laplacian(A):
    """Return the unnormalized graph Laplacian ``D - A``.

    ``D`` is the diagonal matrix holding the row sums of ``A``.
    """
    row_sums = np.sum(A, axis=1)
    degree = np.diag(row_sums)
    return degree - A


def preprocess_structural_connectivity(A, threshold=0.05):
    """Threshold and max-normalize a structural connectivity matrix.

    Operates on a copy, so the caller's array is not modified. Entries
    strictly below ``threshold`` are set to zero first; the result is then
    divided by its largest remaining entry when that entry is positive.
    """
    sc = np.copy(A)

    # Suppress weak connections before rescaling.
    weak = sc < threshold
    sc[weak] = 0

    # Rescale so the strongest surviving connection equals 1; an all-zero
    # (or non-positive) matrix is returned as is.
    peak = np.max(sc)
    if peak > 0:
        return sc / peak
    return sc

def compute_functional_connectivity(X):
    """Compute a Pearson-correlation functional connectivity matrix.

    Parameters
    ----------
    X : array-like, shape (T, N)
        Time series with one column per region; columns are the variables
        that get correlated.

    Returns
    -------
    numpy.ndarray, shape (N, N)
        Correlations between the regions whose standard deviation exceeds
        1e-10. Rows/columns of (near-)constant regions are zero off the
        diagonal, and the diagonal is always 1. When no region has usable
        variance the identity matrix is returned.
    """
    X = np.asarray(X)
    n_regions = X.shape[1]

    # Constant regions have undefined correlations, so they are excluded.
    valid_regions = np.std(X, axis=0) > 1e-10
    if not np.any(valid_regions):
        return np.eye(n_regions)

    with np.errstate(invalid='ignore', divide='ignore'):
        # np.corrcoef collapses to a 0-d scalar when only one region is
        # valid; atleast_2d keeps the result a (k, k) matrix in every case.
        fc_valid = np.atleast_2d(np.corrcoef(X[:, valid_regions].T))
    fc_valid = np.nan_to_num(fc_valid, nan=0.0, posinf=0.0, neginf=0.0)

    # Scatter the valid-region correlations into the full N x N matrix.
    fc_full = np.zeros((n_regions, n_regions))
    valid_idx = np.flatnonzero(valid_regions)
    fc_full[np.ix_(valid_idx, valid_idx)] = fc_valid

    # Self-correlation is 1 for every region, including the excluded ones.
    np.fill_diagonal(fc_full, 1.0)
    return fc_full

def compute_esd_and_sdi(X, A):
    """Compute the ESD of ``X`` on graph ``A`` and an SDI from a 50% energy split.

    ``X`` (unpacked as ``T x N``) is z-scored per column and projected onto
    the eigenvectors of the unnormalized Laplacian of ``A`` (first eigenpair
    dropped). The split point is the first mode at which the cumulative
    energy fraction reaches 0.5; that mode is counted in the low band.

    Returns
    -------
    tuple
        ``(eigvals, esd, cutoff_value, sdi)`` where ``cutoff_value`` is the
        eigenvalue at the split index and ``sdi`` is low-band energy divided
        by high-band energy.
    """
    _, N = X.shape

    # z-score every column; (near-)constant columns get a unit divisor.
    column_mean = np.mean(X, axis=0)
    column_std = np.std(X, axis=0)
    column_std[column_std < 1e-10] = 1.0
    X = (X - column_mean) / column_std

    # Laplacian spectrum with a tiny diagonal shift; the first eigenpair is
    # discarded.
    laplacian = compute_unnormalized_laplacian(A)
    spectrum, modes = eigh(laplacian + 1e-12 * np.eye(N))
    eigvals, eigvecs = spectrum[1:], modes[:, 1:]

    # Energy per retained eigenmode.
    coefficients = X @ eigvecs
    esd = np.sum(coefficients ** 2, axis=0)

    # Locate the mode where the running energy share first reaches one half.
    cumulative_energy = np.cumsum(esd)
    energy_fraction = cumulative_energy / cumulative_energy[-1]
    cutoff_idx = np.searchsorted(energy_fraction, 0.5)
    cutoff_value = eigvals[cutoff_idx]

    # Modes up to and including cutoff_idx form the low band, so the low
    # band holds at least half of the total energy by construction.
    low_freq_energy = np.sum(esd[:cutoff_idx + 1])
    high_freq_energy = np.sum(esd[cutoff_idx + 1:])

    # Degenerate denominators fall back to 1.0 (both bands unusable) or
    # +inf (only the high band unusable).
    if high_freq_energy != 0 and np.isfinite(high_freq_energy):
        sdi = low_freq_energy / high_freq_energy
    elif low_freq_energy == 0 or not np.isfinite(low_freq_energy):
        sdi = 1.0
    else:
        sdi = np.inf
    return eigvals, esd, cutoff_value, sdi

def extract_pid(fpath):
    """Extract the subject ID from a connectome / time-series file name.

    The ID is the text after the first ``sub-`` marker (or from the start of
    the file name when there is no marker) up to the next underscore. When
    no underscore follows the ID, the file extension is stripped, so
    ``sub-29006.csv`` yields ``29006`` rather than ``29006.csv``.
    """
    fname = os.path.basename(fpath)
    stem = fname.split("sub-")[1] if "sub-" in fname else fname
    pid, underscore, _ = stem.partition("_")
    if not underscore:
        # Nothing terminated the ID, so an extension may still be attached.
        pid = os.path.splitext(pid)[0]
    return pid

def load_phenotype_csv(csv_path):
    """Load a phenotype table indexed by a string ``PatientID``.

    ``PatientID`` is filled from ``SUB_ID`` when that column exists,
    otherwise from ``subject``. If neither is present the table must already
    contain a ``PatientID`` column, or ``set_index`` raises ``KeyError``.
    """
    table = pd.read_csv(csv_path)
    for id_column in ('SUB_ID', 'subject'):
        if id_column in table.columns:
            table['PatientID'] = table[id_column].astype(str)
            break
    return table.set_index('PatientID')

def collect_file_maps(structural_dir, functional_dir):
    """Map subject IDs to their structural and functional CSV paths.

    Structural files are taken from the top level of ``structural_dir``;
    functional files are searched recursively under ``functional_dir``.
    When several files resolve to the same ID, the last one visited wins.
    """
    struct_map = {
        extract_pid(name): os.path.join(structural_dir, name)
        for name in os.listdir(structural_dir)
        if name.endswith('.csv')
    }

    func_map = {}
    for folder, _subdirs, names in os.walk(functional_dir):
        func_map.update(
            (extract_pid(name), os.path.join(folder, name))
            for name in names
            if name.endswith('.csv')
        )

    return struct_map, func_map

def plot_connectivity_matrices(fc_matrix, sc_matrix, pid, group_label):
    """Plot FC, SC and ``FC - SC`` side by side, save the figure and show it.

    The figure is written to ``connectivity_comparison_{pid}_{group_label}.png``
    in the current working directory. The difference panel uses a symmetric
    colour range based on the largest finite absolute difference (1.0 when
    there is none or it is zero).
    """
    def _labelled_colorbar(image, axis, text):
        # Shared colour-bar styling for the three panels.
        bar = plt.colorbar(image, ax=axis, shrink=0.6)
        bar.set_label(text, rotation=270, labelpad=15)

    _, (fc_ax, sc_ax, diff_ax) = plt.subplots(1, 3, figsize=(18, 5))

    # Panel 1: functional connectivity on a fixed [-1, 1] scale.
    fc_image = fc_ax.imshow(fc_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
    fc_ax.set_title(f'Patient {pid}: fMRI Functional Connectivity')
    _labelled_colorbar(fc_image, fc_ax, 'Correlation')

    # Panel 2: structural connectivity, lower bound pinned at 0.
    sc_image = sc_ax.imshow(sc_matrix, cmap='Reds', vmin=0)
    sc_ax.set_title(f'Patient {pid}: Structural Connectivity')
    _labelled_colorbar(sc_image, sc_ax, 'Connection Strength')

    # Panel 3: element-wise difference with a symmetric colour range.
    difference = fc_matrix - sc_matrix
    finite_cells = np.isfinite(difference)
    if np.any(finite_cells):
        color_limit = np.max(np.abs(difference[finite_cells]))
    else:
        color_limit = 1.0
    if not color_limit > 0:
        color_limit = 1.0
    diff_image = diff_ax.imshow(difference, cmap='RdBu_r', vmin=-color_limit, vmax=color_limit)
    diff_ax.set_title(f'Patient {pid}: Difference (FC - SC)')
    _labelled_colorbar(diff_image, diff_ax, 'Difference')

    plt.tight_layout()
    plt.savefig(f"connectivity_comparison_{pid}_{group_label}.png", dpi=150)
    plt.show()

def plot_group_esd(group_esds, group_name, energy_split_cutoff):
    """Draw one group's median ESD curve with its inter-quartile band.

    Plots onto the current matplotlib axes, shades the regions below and
    above ``energy_split_cutoff``, and switches both axes to log scale.

    Parameters
    ----------
    group_esds : dict
        Needs the keys ``'eigvals'`` and ``'esds'`` (lists with one array per subject).
    group_name : str
        Legend label of the median curve.
    energy_split_cutoff : float
        Eigenvalue separating the two shaded regions.
    """
    # NOTE(review): only the first subject's eigenvalues are used as the
    # x-axis for the whole group.
    x_values = group_esds['eigvals'][0]
    stacked = np.vstack(group_esds['esds'])

    upper_q = np.percentile(stacked, 75, axis=0)
    lower_q = np.percentile(stacked, 25, axis=0)
    median_esd = np.median(stacked, axis=0)

    plt.plot(x_values, median_esd, label=group_name, lw=2)
    plt.fill_between(x_values, lower_q, upper_q, alpha=0.25)

    # Background shading on either side of the split.
    plt.axvspan(min(x_values), energy_split_cutoff, color='#cceeff', alpha=0.6)
    plt.axvspan(energy_split_cutoff, max(x_values), color='#ffcccc', alpha=0.6)

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel(r'$\lambda$ (eigenvalue)', fontsize=14)
    plt.ylabel(r'$\xi$ (energy spectral density)', fontsize=14)
    plt.title('Energy Spectral Density (50% ESD Split)', fontsize=14)
    plt.grid(True, which='both', ls='--', lw=0.5)

def process_site(structural_dir, functional_dir, phenotype_csv, group_esds, energy_cutoffs, group_sdis):
    """Process every subject of one site and append the results to the accumulators.

    Only subjects that have a structural file, a functional file and a
    phenotype row are considered. Connectivity plots are produced for the
    first two subjects that are processed successfully.

    Parameters
    ----------
    structural_dir, functional_dir : str
        Directories scanned by ``collect_file_maps``.
    phenotype_csv : str
        Phenotype table; must contain a ``DX_GROUP`` column.
    group_esds : mapping
        ``group_esds[label]['eigvals']`` and ``['esds']`` are appended to in place.
    energy_cutoffs : list
        Receives one cutoff eigenvalue per processed subject (in place).
    group_sdis : mapping
        ``group_sdis[label]`` receives one SDI per processed subject (in place).
    """
    structural_map, functional_map = collect_file_maps(structural_dir, functional_dir)
    phenotype_df = load_phenotype_csv(phenotype_csv)
    common_pids = sorted(set(structural_map) & set(functional_map) & set(phenotype_df.index.astype(str)))
    print(f"Found {len(common_pids)} valid subjects at site.")

    plot_count = 0
    for pid in common_pids:
        # Best-effort loop: a failure on one subject is reported and skipped
        # so the remaining subjects are still processed.
        try:
            A = np.loadtxt(structural_map[pid], delimiter=',')
            X = np.loadtxt(functional_map[pid], delimiter=',')
            if A.ndim != 2 or X.ndim != 2 or A.shape[0] != A.shape[1] or X.shape[1] != A.shape[0]:
                print(f"Skipping {pid}: incompatible shapes (SC {A.shape}, time series {X.shape})")
                continue

            # ABIDE phenotypic files code DX_GROUP as 1 = Autism, 2 = Control.
            # NOTE(review): adjust if these CSVs were recoded locally.
            group = phenotype_df.loc[pid, 'DX_GROUP']
            if group == 1:
                label = 'ASD'
            elif group == 2:
                label = 'TD'
            else:
                print(f"Skipping {pid}: unrecognised DX_GROUP value {group!r}")
                continue

            A = preprocess_structural_connectivity(A, threshold=0.05)
            FC = compute_functional_connectivity(X)
            eigvals, esd, cutoff_value, sdi = compute_esd_and_sdi(X, A)

            group_esds[label]['eigvals'].append(eigvals)
            group_esds[label]['esds'].append(esd)
            group_sdis[label].append(sdi)
            energy_cutoffs.append(cutoff_value)

            if plot_count < 2:
                plot_connectivity_matrices(FC, A, pid, label)
                plot_count += 1
        except Exception as e:
            print(f"Error processing {pid}: {e}")

def run_all_sites(sites_config=None):
    """Run the ESD/SDI analysis over all sites and produce the summary figures.

    Writes ``esd_comparison_median_split.png`` and ``sdi_comparison.png`` to
    the current working directory and prints a short summary.

    Parameters
    ----------
    sites_config : list of dict, optional
        One dict per site with the keys ``"site"``, ``"functional_dir"``,
        ``"structural_dir"`` and ``"phenotype_csv"``. Defaults to the five
        built-in sites (bni, ip, nyu1, nyu2, sdsu).
    """
    from scipy.stats import ttest_ind

    if sites_config is None:
        base_dir = "/Users/arnavkarnik/Documents/Classification2"
        # (lower-case site key, prefix used by the connectome/phenotype files)
        default_sites = (
            ("bni", "BNI_1"),
            ("ip", "IP_1"),
            ("nyu1", "NYU_1"),
            ("nyu2", "NYU_2"),
            ("sdsu", "SDSU_1"),
        )
        sites_config = [
            {
                "site": site,
                "functional_dir": f"{base_dir}/Time_Series_ABIDE2/{site}_time_series/schaefer_400/cleaned-1",
                "structural_dir": f"{base_dir}/SC_Connectomes_ABIDE2/{prefix}_connectomes",
                "phenotype_csv": f"{base_dir}/Phenotypes_ABIDE2/{prefix}_phenotypes.csv",
            }
            for site, prefix in default_sites
        ]

    group_esds = defaultdict(lambda: {'eigvals': [], 'esds': []})
    group_sdis = defaultdict(list)
    energy_cutoffs = []

    # The sites are processed exactly once; the accumulators are filled in place.
    for config in sites_config:
        print(f"\n=== Processing site: {config['site'].upper()} ===")
        process_site(config["structural_dir"], config["functional_dir"], config["phenotype_csv"],
                     group_esds, energy_cutoffs, group_sdis)

    if not energy_cutoffs:
        # The median of an empty list is NaN and every plot would be empty.
        print("\nNo subjects were processed successfully; nothing to plot.")
        return

    global_energy_cutoff = np.median(energy_cutoffs)
    print(f"\nGlobal energy-based eigenvalue cutoff: {global_energy_cutoff:.5f}")

    # ESD comparison plot
    plt.figure(figsize=(10, 6))
    for label in ['TD', 'ASD']:
        if group_esds[label]['esds']:
            plot_group_esd(group_esds[label], label, global_energy_cutoff)
    plt.axvspan(0, global_energy_cutoff, color='#cceeff', alpha=0.6, label='Low eigenvalue region')
    plt.axvspan(global_energy_cutoff, plt.xlim()[1], color='#ffcccc', alpha=0.6, label='High eigenvalue region')
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig("esd_comparison_median_split.png", dpi=300)
    plt.show()

    # SDI comparison boxplot
    plt.figure(figsize=(8, 6))
    sdi_data, group_labels = [], []
    for label in ['TD', 'ASD']:
        sdi_data.extend(group_sdis[label])
        group_labels.extend([label] * len(group_sdis[label]))
    sdi_df = pd.DataFrame({'SDI': sdi_data, 'Group': group_labels})
    # Infinite/NaN ratios cannot be plotted or tested; drop them.
    sdi_df = sdi_df[np.isfinite(sdi_df['SDI'])]
    if not sdi_df.empty:
        sns.boxplot(data=sdi_df, x='Group', y='SDI', palette=['blue', 'red'])
        plt.title('SDI Comparison (Low freq / High freq)', fontsize=14)
        td_sdi = sdi_df[sdi_df['Group'] == 'TD']['SDI'].values
        asd_sdi = sdi_df[sdi_df['Group'] == 'ASD']['SDI'].values
        if len(td_sdi) > 0 and len(asd_sdi) > 0:
            _, p_val = ttest_ind(td_sdi, asd_sdi)
            plt.text(0.5, plt.ylim()[1]*0.9, f'p-value: {p_val:.4f}', ha='center', fontsize=10,
                     bbox=dict(boxstyle='round', facecolor='wheat'))
        plt.ylabel('SDI', fontsize=12)
        plt.xlabel('Group', fontsize=12)
        plt.tight_layout()
        plt.savefig("sdi_comparison.png", dpi=300)
        plt.show()

    print("\nSplit statistics:")
    print(f"Global energy cutoff eigenvalue: {global_energy_cutoff:.5f}")
    print(f"Number of subjects contributing to energy split: {len(energy_cutoffs)}")

# Entry point: run the multi-site analysis when this code is executed as the
# main module (it is skipped when the module is imported).
if __name__ == "__main__":
    run_all_sites()


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.linalg import eigh
from collections import defaultdict

# Apply seaborn's 'whitegrid' style globally, so it affects every figure
# created after this point.
sns.set(style='whitegrid')

def compute_unnormalized_laplacian(A):
    """Return the unnormalized graph Laplacian ``D - A``.

    ``D`` is the diagonal matrix holding the row sums of ``A``.
    """
    row_sums = np.sum(A, axis=1)
    return np.diag(row_sums) - A

def preprocess_structural_connectivity(A, threshold=0.05):
    """Return an independent copy of ``A``; the input is left untouched.

    ``threshold`` is accepted but not used here: thresholding is skipped
    intentionally in this variant.
    """
    return np.array(A, copy=True)

def compute_functional_connectivity(X):
    """Compute the region-by-region correlation matrix of ``X``.

    Columns whose standard deviation is at most 1e-10 are excluded from the
    correlation; their rows and columns in the result are 0 except for the
    diagonal, which is always 1.

    Parameters
    ----------
    X : ndarray
        2-D array of time series with one column per region.

    Returns
    -------
    ndarray
        Square matrix with one row/column per column of ``X``. The identity
        matrix when no column has usable variance.
    """
    n_regions = X.shape[1]
    valid_regions = np.std(X, axis=0) > 1e-10
    if not np.any(valid_regions):
        return np.eye(n_regions)

    with np.errstate(invalid='ignore', divide='ignore'):
        fc_valid = np.corrcoef(X[:, valid_regions].T)
    fc_valid = np.nan_to_num(fc_valid, nan=0.0, posinf=0.0, neginf=0.0)
    # With a single valid region corrcoef returns a 0-d value; promote it to
    # 1x1 so it can be written into the full matrix like any other block.
    fc_valid = np.atleast_2d(fc_valid)

    # Scatter the valid sub-matrix into the full matrix in one step.
    fc_full = np.zeros((n_regions, n_regions))
    valid_idx = np.flatnonzero(valid_regions)
    fc_full[np.ix_(valid_idx, valid_idx)] = fc_valid

    np.fill_diagonal(fc_full, 1.0)
    return fc_full

def compute_esd_and_sdi(X, A):
    """Compute the graph energy spectral density of ``X`` and a low/high energy ratio.

    Every column of ``X`` is z-scored, the result is projected onto the
    eigenvectors of the unnormalized Laplacian of ``A``, and the squared
    coefficients are summed over rows. The spectrum is then split at the
    first eigenvalue where the cumulative energy reaches 50% of the total.

    Parameters
    ----------
    X : ndarray
        2-D array of time series with one column per region.
    A : ndarray
        Square connectivity matrix with one row/column per column of ``X``.

    Returns
    -------
    eigvals : ndarray
        Laplacian eigenvalues as returned by ``eigh`` with the first one dropped.
    esd : ndarray
        Energy carried by each retained eigenvector.
    cutoff_value : float
        Eigenvalue at which the cumulative energy first reaches 50%.
    sdi : float
        Energy at or below the cutoff divided by the energy above it.
        ``np.inf`` when only the upper band is zero or non-finite, ``1.0``
        when both bands are zero or non-finite.

    Raises
    ------
    ValueError
        If ``X`` is not 2-D or has fewer than two columns.
    """
    _, n_regions = X.shape
    if n_regions < 2:
        # With one region nothing is left after dropping the first eigenpair.
        raise ValueError("compute_esd_and_sdi needs at least two regions")

    # Z-score each column; (near-)constant columns are divided by 1 instead
    # of ~0 so they become all zeros rather than NaN/inf.
    column_mean = np.mean(X, axis=0)
    column_std = np.std(X, axis=0)
    column_std[column_std < 1e-10] = 1.0
    X = (X - column_mean) / column_std

    # Tiny diagonal shift before the symmetric eigen-decomposition.
    L = compute_unnormalized_laplacian(A)
    L_reg = L + 1e-12 * np.eye(n_regions)
    eigvals, eigvecs = eigh(L_reg)

    # Discard the first eigenpair returned by eigh.
    eigvals = eigvals[1:]
    eigvecs = eigvecs[:, 1:]

    # Graph Fourier coefficients per time point, then energy per eigenvector.
    X_hat = X @ eigvecs
    esd = np.sum(X_hat ** 2, axis=0)

    cumulative_energy = np.cumsum(esd)
    total_energy = cumulative_energy[-1]
    if total_energy == 0:
        # No energy at all: avoid the 0/0 division and keep the cutoff on the
        # first retained eigenvalue.
        cutoff_idx = 0
    else:
        cutoff_idx = int(np.searchsorted(cumulative_energy / total_energy, 0.5))
    cutoff_value = eigvals[cutoff_idx]

    # The cutoff eigenvector itself counts towards the low band.
    positions = np.arange(len(eigvals))
    low_freq_energy = np.sum(esd[positions <= cutoff_idx])
    high_freq_energy = np.sum(esd[positions > cutoff_idx])

    if high_freq_energy == 0 or not np.isfinite(high_freq_energy):
        sdi = 1.0 if low_freq_energy == 0 or not np.isfinite(low_freq_energy) else np.inf
    else:
        sdi = low_freq_energy / high_freq_energy
    return eigvals, esd, cutoff_value, sdi

def extract_pid(fpath):
    """Return the subject ID encoded in a file name.

    The directory part and the file extension are removed first. If the
    remaining stem contains ``"sub-"`` the ID is the text after it up to the
    next underscore; otherwise it is the text before the first underscore.

    Parameters
    ----------
    fpath : str
        File name or path.

    Returns
    -------
    str
        The extracted ID.
    """
    # Strip the extension up front: otherwise a name without an underscore
    # after the ID (e.g. "sub-29006.csv") would yield "29006.csv".
    stem = os.path.splitext(os.path.basename(fpath))[0]
    if "sub-" in stem:
        stem = stem.split("sub-")[1]
    return stem.split("_")[0]

def load_phenotype_csv(csv_path):
    """Load a phenotype CSV indexed by subject ID as a string.

    The ID is taken from the ``SUB_ID`` column, else from ``subject``, else
    from an existing ``PatientID`` column. It is always cast to ``str`` so
    that look-ups with string IDs work.

    Parameters
    ----------
    csv_path : str
        Path of the CSV file.

    Returns
    -------
    pandas.DataFrame
        The table with ``PatientID`` as index.

    Raises
    ------
    KeyError
        If none of the supported ID columns is present.
    """
    df = pd.read_csv(csv_path)
    for id_column in ('SUB_ID', 'subject', 'PatientID'):
        if id_column in df.columns:
            df['PatientID'] = df[id_column].astype(str)
            break
    else:
        # Fail with a message that names the file and the accepted columns
        # instead of pandas' generic set_index error.
        raise KeyError(
            f"{csv_path}: expected one of the columns 'SUB_ID', 'subject' or 'PatientID'"
        )
    return df.set_index('PatientID')

def collect_file_maps(structural_dir, functional_dir):
    """Index the ``.csv`` files of both directories by subject ID.

    ``structural_dir`` is listed non-recursively, ``functional_dir`` is
    walked recursively. IDs come from ``extract_pid``; when several files
    share an ID the one visited last wins.

    Returns
    -------
    tuple of dict
        ``(struct_map, func_map)``, each mapping ID -> file path.
    """
    struct_map = {}
    for name in os.listdir(structural_dir):
        if not name.endswith('.csv'):
            continue
        struct_map[extract_pid(name)] = os.path.join(structural_dir, name)

    func_map = {}
    for folder, _subdirs, names in os.walk(functional_dir):
        for name in names:
            if not name.endswith('.csv'):
                continue
            func_map[extract_pid(name)] = os.path.join(folder, name)

    return struct_map, func_map

def plot_connectivity_matrices(fc_matrix, sc_matrix, pid, group_label):
    """Plot FC, SC and ``FC - SC`` side by side, save the figure and show it.

    The figure is written to ``connectivity_comparison_{pid}_{group_label}.png``
    in the current working directory. The difference panel uses a symmetric
    colour range based on the largest finite absolute difference (1.0 when
    there is none or it is zero).
    """
    def _labelled_colorbar(image, axis, text):
        # Shared colour-bar styling for the three panels.
        bar = plt.colorbar(image, ax=axis, shrink=0.6)
        bar.set_label(text, rotation=270, labelpad=15)

    _, (fc_ax, sc_ax, diff_ax) = plt.subplots(1, 3, figsize=(18, 5))

    # Panel 1: functional connectivity on a fixed [-1, 1] scale.
    fc_image = fc_ax.imshow(fc_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
    fc_ax.set_title(f'Patient {pid}: fMRI Functional Connectivity')
    _labelled_colorbar(fc_image, fc_ax, 'Correlation')

    # Panel 2: structural connectivity, lower bound pinned at 0.
    sc_image = sc_ax.imshow(sc_matrix, cmap='Reds', vmin=0)
    sc_ax.set_title(f'Patient {pid}: Structural Connectivity')
    _labelled_colorbar(sc_image, sc_ax, 'Connection Strength')

    # Panel 3: element-wise difference with a symmetric colour range.
    difference = fc_matrix - sc_matrix
    finite_cells = np.isfinite(difference)
    if np.any(finite_cells):
        color_limit = np.max(np.abs(difference[finite_cells]))
    else:
        color_limit = 1.0
    if not color_limit > 0:
        color_limit = 1.0
    diff_image = diff_ax.imshow(difference, cmap='RdBu_r', vmin=-color_limit, vmax=color_limit)
    diff_ax.set_title(f'Patient {pid}: Difference (FC - SC)')
    _labelled_colorbar(diff_image, diff_ax, 'Difference')

    plt.tight_layout()
    plt.savefig(f"connectivity_comparison_{pid}_{group_label}.png", dpi=150)
    plt.show()

def plot_group_esd(group_esds, group_name, energy_split_cutoff):
    """Draw one group's median ESD curve with its inter-quartile band.

    Plots onto the current matplotlib axes, shades the regions below and
    above ``energy_split_cutoff``, and switches both axes to log scale.

    Parameters
    ----------
    group_esds : dict
        Needs the keys ``'eigvals'`` and ``'esds'`` (lists with one array per subject).
    group_name : str
        Legend label of the median curve.
    energy_split_cutoff : float
        Eigenvalue separating the two shaded regions.
    """
    # NOTE(review): only the first subject's eigenvalues are used as the
    # x-axis for the whole group.
    x_values = group_esds['eigvals'][0]
    stacked = np.vstack(group_esds['esds'])

    upper_q = np.percentile(stacked, 75, axis=0)
    lower_q = np.percentile(stacked, 25, axis=0)
    median_esd = np.median(stacked, axis=0)

    plt.plot(x_values, median_esd, label=group_name, lw=2)
    plt.fill_between(x_values, lower_q, upper_q, alpha=0.25)

    # Background shading on either side of the split.
    plt.axvspan(min(x_values), energy_split_cutoff, color='#cceeff', alpha=0.6)
    plt.axvspan(energy_split_cutoff, max(x_values), color='#ffcccc', alpha=0.6)

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel(r'$\lambda$ (eigenvalue)', fontsize=14)
    plt.ylabel(r'$\xi$ (energy spectral density)', fontsize=14)
    plt.title('Energy Spectral Density (50% ESD Split)', fontsize=14)
    plt.grid(True, which='both', ls='--', lw=0.5)

def process_site(structural_dir, functional_dir, phenotype_csv, group_esds, energy_cutoffs, group_sdis):
    """Process every subject of one site and append the results to the accumulators.

    Only subjects that have a structural file, a functional file and a
    phenotype row are considered. Connectivity plots are produced for the
    first two subjects that are processed successfully.

    Parameters
    ----------
    structural_dir, functional_dir : str
        Directories scanned by ``collect_file_maps``.
    phenotype_csv : str
        Phenotype table; must contain a ``DX_GROUP`` column.
    group_esds : mapping
        ``group_esds[label]['eigvals']`` and ``['esds']`` are appended to in place.
    energy_cutoffs : list
        Receives one cutoff eigenvalue per processed subject (in place).
    group_sdis : mapping
        ``group_sdis[label]`` receives one SDI per processed subject (in place).
    """
    structural_map, functional_map = collect_file_maps(structural_dir, functional_dir)
    phenotype_df = load_phenotype_csv(phenotype_csv)
    common_pids = sorted(set(structural_map) & set(functional_map) & set(phenotype_df.index.astype(str)))
    print(f"Found {len(common_pids)} valid subjects at site.")

    plot_count = 0
    for pid in common_pids:
        # Best-effort loop: a failure on one subject is reported and skipped
        # so the remaining subjects are still processed.
        try:
            A = np.loadtxt(structural_map[pid], delimiter=',')
            X = np.loadtxt(functional_map[pid], delimiter=',')
            if A.ndim != 2 or X.ndim != 2 or A.shape[0] != A.shape[1] or X.shape[1] != A.shape[0]:
                print(f"Skipping {pid}: incompatible shapes (SC {A.shape}, time series {X.shape})")
                continue

            # ABIDE phenotypic files code DX_GROUP as 1 = Autism, 2 = Control.
            # NOTE(review): adjust if these CSVs were recoded locally.
            group = phenotype_df.loc[pid, 'DX_GROUP']
            if group == 1:
                label = 'ASD'
            elif group == 2:
                label = 'TD'
            else:
                print(f"Skipping {pid}: unrecognised DX_GROUP value {group!r}")
                continue

            A = preprocess_structural_connectivity(A, threshold=0.05)
            FC = compute_functional_connectivity(X)
            eigvals, esd, cutoff_value, sdi = compute_esd_and_sdi(X, A)

            group_esds[label]['eigvals'].append(eigvals)
            group_esds[label]['esds'].append(esd)
            group_sdis[label].append(sdi)
            energy_cutoffs.append(cutoff_value)

            if plot_count < 2:
                plot_connectivity_matrices(FC, A, pid, label)
                plot_count += 1
        except Exception as e:
            print(f"Error processing {pid}: {e}")

def run_all_sites(sites_config=None):
    """Run the ESD/SDI analysis over all sites and produce the summary figures.

    Writes ``esd_comparison_median_split.png`` and ``sdi_comparison.png`` to
    the current working directory and prints a short summary.

    Parameters
    ----------
    sites_config : list of dict, optional
        One dict per site with the keys ``"site"``, ``"functional_dir"``,
        ``"structural_dir"`` and ``"phenotype_csv"``. Defaults to the five
        built-in sites (bni, ip, nyu1, nyu2, sdsu).
    """
    from scipy.stats import ttest_ind

    if sites_config is None:
        base_dir = "/Users/arnavkarnik/Documents/Classification2"
        # (lower-case site key, prefix used by the connectome/phenotype files)
        default_sites = (
            ("bni", "BNI_1"),
            ("ip", "IP_1"),
            ("nyu1", "NYU_1"),
            ("nyu2", "NYU_2"),
            ("sdsu", "SDSU_1"),
        )
        sites_config = [
            {
                "site": site,
                "functional_dir": f"{base_dir}/Time_Series_ABIDE2/{site}_time_series/schaefer_400/cleaned-1",
                "structural_dir": f"{base_dir}/SC_Connectomes_ABIDE2/{prefix}_connectomes",
                "phenotype_csv": f"{base_dir}/Phenotypes_ABIDE2/{prefix}_phenotypes.csv",
            }
            for site, prefix in default_sites
        ]

    group_esds = defaultdict(lambda: {'eigvals': [], 'esds': []})
    group_sdis = defaultdict(list)
    energy_cutoffs = []

    # The sites are processed exactly once; the accumulators are filled in place.
    for config in sites_config:
        print(f"\n=== Processing site: {config['site'].upper()} ===")
        process_site(config["structural_dir"], config["functional_dir"], config["phenotype_csv"],
                     group_esds, energy_cutoffs, group_sdis)

    if not energy_cutoffs:
        # The median of an empty list is NaN and every plot would be empty.
        print("\nNo subjects were processed successfully; nothing to plot.")
        return

    global_energy_cutoff = np.median(energy_cutoffs)
    print(f"\nGlobal energy-based eigenvalue cutoff: {global_energy_cutoff:.5f}")

    # ESD comparison plot
    plt.figure(figsize=(10, 6))
    for label in ['TD', 'ASD']:
        if group_esds[label]['esds']:
            plot_group_esd(group_esds[label], label, global_energy_cutoff)
    plt.axvspan(0, global_energy_cutoff, color='#cceeff', alpha=0.6, label='Low eigenvalue region')
    plt.axvspan(global_energy_cutoff, plt.xlim()[1], color='#ffcccc', alpha=0.6, label='High eigenvalue region')
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig("esd_comparison_median_split.png", dpi=300)
    plt.show()

    # SDI comparison boxplot
    plt.figure(figsize=(8, 6))
    sdi_data, group_labels = [], []
    for label in ['TD', 'ASD']:
        sdi_data.extend(group_sdis[label])
        group_labels.extend([label] * len(group_sdis[label]))
    sdi_df = pd.DataFrame({'SDI': sdi_data, 'Group': group_labels})
    # Infinite/NaN ratios cannot be plotted or tested; drop them.
    sdi_df = sdi_df[np.isfinite(sdi_df['SDI'])]
    if not sdi_df.empty:
        sns.boxplot(data=sdi_df, x='Group', y='SDI', palette=['blue', 'red'])
        plt.title('SDI Comparison (Low freq / High freq)', fontsize=14)
        td_sdi = sdi_df[sdi_df['Group'] == 'TD']['SDI'].values
        asd_sdi = sdi_df[sdi_df['Group'] == 'ASD']['SDI'].values
        if len(td_sdi) > 0 and len(asd_sdi) > 0:
            _, p_val = ttest_ind(td_sdi, asd_sdi)
            plt.text(0.5, plt.ylim()[1]*0.9, f'p-value: {p_val:.4f}', ha='center', fontsize=10,
                     bbox=dict(boxstyle='round', facecolor='wheat'))
        plt.ylabel('SDI', fontsize=12)
        plt.xlabel('Group', fontsize=12)
        plt.tight_layout()
        plt.savefig("sdi_comparison.png", dpi=300)
        plt.show()

    print("\nSplit statistics:")
    print(f"Global energy cutoff eigenvalue: {global_energy_cutoff:.5f}")
    print(f"Number of subjects contributing to energy split: {len(energy_cutoffs)}")

# Entry point: run the multi-site analysis when this code is executed as the
# main module (it is skipped when the module is imported).
if __name__ == "__main__":
    run_all_sites()
