# Classification of ABIDE 2 

### Check the size of connectomes and time-series files (run once per site by changing the file paths)

In [None]:
import os
import numpy as np

# Structural-connectome folder of the site being inspected (edit per site).
sc_dir = '/Users/arnavkarnik/Documents/Classification/SC_Connectomes_ABIDE2/BNI_1_connectomes'

# Only CSVs that sit directly inside the folder are considered.
csv_files = list(filter(lambda name: name.endswith('.csv'), os.listdir(sc_dir)))
print(f"Number of CSV files in SC directory: {len(csv_files)}\n")

# Report the shape of every matrix; an unreadable file is reported, not fatal.
for filename in csv_files:
    filepath = os.path.join(sc_dir, filename)
    try:
        matrix = np.loadtxt(filepath, delimiter=',')
    except Exception as e:
        print(f"Failed to read {filename}: {e}")
    else:
        print(f"{filename}: shape = {matrix.shape}")


In [None]:
import os
import numpy as np

# Time-series folder of the site being inspected (edit per site).
fmri_dir = '/Users/arnavkarnik/Documents/Classification/Time_Series_ABIDE2/bni_time_series/schaefer_400'

# Walk the whole tree so CSVs inside nested subfolders are included.
csv_files = []
for root, dirs, files in os.walk(fmri_dir):
    csv_files.extend(
        os.path.join(root, file) for file in files if file.endswith('.csv')
    )

print(f"Number of CSV files in fMRI directory (including subfolders): {len(csv_files)}\n")

# Report the shape of every time-series matrix; unreadable files are reported.
for filepath in csv_files:
    try:
        matrix = np.loadtxt(filepath, delimiter=',')
    except Exception as e:
        print(f"Failed to read {os.path.basename(filepath)}: {e}")
    else:
        print(f"{os.path.basename(filepath)}: shape = {matrix.shape}")


In [None]:
import os

sc_dir = '/Users/arnavkarnik/Documents/Classification/SC_Connectomes_ABIDE2/BNI_1_connectomes'
fmri_dir = '/Users/arnavkarnik/Documents/Classification/Time_Series_ABIDE2/bni_time_series/schaefer_400'

# IDs in this cell are whole filename stems: only the extension is stripped,
# so two files match only when their names are otherwise identical.
sc_files = [f for f in os.listdir(sc_dir) if f.endswith('.csv')]
sc_ids = {os.path.splitext(f)[0] for f in sc_files}

# Same stem extraction for the fMRI CSVs, searched recursively.
fmri_ids = set()
for root, dirs, files in os.walk(fmri_dir):
    fmri_ids.update(os.path.splitext(f)[0] for f in files if f.endswith('.csv'))

# Stems present on one side only.
missing_in_fmri = sc_ids.difference(fmri_ids)
missing_in_sc = fmri_ids.difference(sc_ids)

print(f"Patients missing in fMRI data: {missing_in_fmri}")
print(f"Patients missing in SC data: {missing_in_sc}")


In [None]:
import os
import re

# Input folders for the ID cross-check below: structural connectomes and
# fMRI time series of one site (both currently point at BNI).
sc_dir = '/Users/arnavkarnik/Documents/Classification/SC_Connectomes_ABIDE2/BNI_1_connectomes'
fmri_dir = '/Users/arnavkarnik/Documents/Classification/Time_Series_ABIDE2/bni_time_series/schaefer_400'

def extract_id_sc(filename):
    """Return the part of *filename* that precedes its first underscore.

    Example: '29006_parcels.csv' -> '29006'. A name without any underscore
    is returned unchanged.
    """
    prefix, _, _ = filename.partition('_')
    return prefix

def extract_id_fmri(filename):
    """Return the run of digits that follows 'sub-' in *filename*.

    Example: 'sub-29006_ses-1_task-rest_cleaned-1_bold.csv' -> '29006'.
    Returns None when the name contains no 'sub-<digits>' fragment.
    """
    found = re.search(r'sub-(\d+)', filename)
    return found.group(1) if found else None

# Patient IDs found among the structural connectome files.
sc_files = [f for f in os.listdir(sc_dir) if f.endswith('.csv')]
sc_ids = {extract_id_sc(f) for f in sc_files}

# Patient IDs found among the fMRI CSVs (searched recursively); files whose
# name yields no ID are ignored.
fmri_ids = set()
for root, dirs, files in os.walk(fmri_dir):
    for f in files:
        if not f.endswith('.csv'):
            continue
        pid = extract_id_fmri(f)
        if pid:
            fmri_ids.add(pid)

# IDs present in one modality but not in the other.
missing_in_fmri = sc_ids.difference(fmri_ids)
missing_in_sc = fmri_ids.difference(sc_ids)

print(f"Patients missing in fMRI data: {missing_in_fmri}")
print(f"Patients missing in SC data: {missing_in_sc}")


### SDI Calculation

In [None]:
import os
import re
import numpy as np
from scipy.linalg import eigh
import csv
import matplotlib.pyplot as plt
import seaborn as sns

def compute_structural_laplacian(A):
    """Return the symmetric normalized graph Laplacian of adjacency matrix A.

    Computes ``L = I - D^{-1/2} A D^{-1/2}`` where ``D`` is the diagonal
    matrix of the row sums of ``A``.

    Parameters
    ----------
    A : numpy.ndarray
        Square adjacency (connectivity) matrix.

    Returns
    -------
    numpy.ndarray
        Laplacian with the same shape as ``A``. A node whose row sum is zero
        gets a scaling factor of 0, so its row and column of the normalized
        adjacency vanish and its diagonal entry of ``L`` is 1.
    """
    degree = np.sum(A, axis=1)
    # 1/sqrt(0) yields inf; silence the warning and zero those entries below.
    with np.errstate(divide='ignore'):
        inv_sqrt_degree = 1.0 / np.sqrt(degree)
    inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0

    # Scaling rows and columns by broadcasting equals D^{-1/2} A D^{-1/2}
    # without building dense diagonal matrices or doing two full matmuls.
    normalized_adjacency = inv_sqrt_degree[:, None] * A * inv_sqrt_degree[None, :]
    return np.eye(A.shape[0]) - normalized_adjacency

def graph_spectral_phase_randomize(X, eigvecs, seed=None):
    """Build a surrogate of X by randomizing temporal phases per graph mode.

    X is projected onto ``eigvecs``; for every resulting column the temporal
    FFT magnitudes are kept while the phases of all positive-frequency bins
    are replaced by uniform random phases (the mirrored negative-frequency
    bins receive the negated phases, so the inverse FFT is real). The DC bin
    and, for even T, the Nyquist bin keep their original phase. The result
    is projected back with ``eigvecs.T``.

    Parameters
    ----------
    X : numpy.ndarray
        2-D array indexed (time, node).
    eigvecs : numpy.ndarray
        Matrix whose columns are the graph Fourier basis vectors.
    seed : int or None
        When given, a private legacy ``RandomState(seed)`` is used, which
        produces the same stream as ``np.random.seed(seed)`` but leaves the
        global NumPy RNG untouched. When None, the global RNG is used.

    Returns
    -------
    numpy.ndarray
        Surrogate signal with the same shape as X.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    X_hat = X @ eigvecs
    T, N = X_hat.shape
    X_surr = np.zeros_like(X_hat)

    # Number of positive-frequency bins that have a distinct mirror bin:
    # T//2 - 1 for even T, (T - 1)//2 for odd T. The former expression
    # (T//2 - 1 for every T) skipped the top pair when T is odd and produced
    # invalid slices for T <= 3.
    n_rand = max((T - 1) // 2, 0)

    for i in range(N):
        fft_coeff = np.fft.fft(X_hat[:, i])
        mag = np.abs(fft_coeff)
        new_phase = np.angle(fft_coeff)  # fresh array, safe to modify

        random_phases = rng.uniform(0, 2 * np.pi, n_rand)
        new_phase[1:1 + n_rand] = random_phases
        # Index from the front so that n_rand == 0 selects nothing
        # (a "-0:" slice would select the whole array).
        new_phase[T - n_rand:] = -random_phases[::-1]

        X_surr[:, i] = np.fft.ifft(mag * np.exp(1j * new_phase)).real

    return X_surr @ eigvecs.T

def compute_SDI_informed_energy_split(A, X, pid, num_surrogates=100, seed=None):
    """Compute a per-node structural decoupling index (SDI) from surrogates.

    Steps:
      1. z-score every column of X over time.
      2. Eigendecompose the normalized Laplacian of A.
      3. Choose the graph-frequency cutoff where the cumulative energy of the
         graph-spectral coefficients of X reaches half of the total; fall
         back to N // 2 when that index is 0 or N.
      4. For each surrogate, split it into the low-frequency part (X_c) and
         the high-frequency part (X_d) and take the per-node norm over time.
      5. SDI = ||X_d|| / ||X_c|| per node and surrogate, averaged over
         surrogates.

    NOTE(review): the returned SDI is the mean over phase-randomized
    surrogates only; the empirical signal is used for the energy cutoff and
    as the source of the surrogates. Confirm this is the intended definition.

    Parameters
    ----------
    A : numpy.ndarray
        Square structural connectivity matrix (N x N).
    X : numpy.ndarray
        Functional data indexed (time, node), with N columns.
    pid : str
        Patient identifier, used only in the progress message.
    num_surrogates : int
        Number of surrogates to average over.
    seed : int or None
        Base seed; surrogate ``s`` uses ``seed + s``. None leaves seeding to
        the surrogate generator.

    Returns
    -------
    tuple
        ``(mean_SDI, cutoff_index)``: array of length N and the index of the
        first eigenvector assigned to the high-frequency part.
    """
    X = X - X.mean(axis=0)
    X = X / (X.std(axis=0) + 1e-10)  # epsilon guards constant columns

    _, N = X.shape
    L = compute_structural_laplacian(A)
    _, eigvecs = eigh(L)  # eigenvectors ordered by ascending eigenvalue

    energy = np.sum((X @ eigvecs) ** 2, axis=0)
    cutoff_index = np.searchsorted(np.cumsum(energy), 0.5 * np.sum(energy))
    if cutoff_index <= 0 or cutoff_index >= N:
        cutoff_index = N // 2

    Vlow, Vhigh = eigvecs[:, :cutoff_index], eigvecs[:, cutoff_index:]

    N_c_surr = np.empty((N, num_surrogates))
    N_d_surr = np.empty((N, num_surrogates))

    for s in range(num_surrogates):
        surrogate_seed = None if seed is None else seed + s
        X_surr = graph_spectral_phase_randomize(X, eigvecs, seed=surrogate_seed)
        X_hat = X_surr @ eigvecs

        X_c = X_hat[:, :cutoff_index] @ Vlow.T
        X_d = X_hat[:, cutoff_index:] @ Vhigh.T

        # Column-wise norms over time replace the former per-node loop.
        N_c_surr[:, s] = np.linalg.norm(X_c, axis=0)
        N_d_surr[:, s] = np.linalg.norm(X_d, axis=0)

    SDI = N_d_surr / (N_c_surr + 1e-10)
    mean_SDI = np.mean(SDI, axis=1)
    print(f"Patient {pid}: SDI mean={mean_SDI.mean():.4f}, cutoff={cutoff_index}")
    return mean_SDI, cutoff_index

def extract_patient_id_structural(filename):
    """Return the leading digits of a '<digits>_parcels...' file name.

    The pattern is anchored at the start of *filename*; None is returned
    when the name does not begin with digits followed by '_parcels'.
    """
    found = re.match(r'(\d+)_parcels', filename)
    if found is None:
        return None
    return found.group(1)

def extract_patient_id_functional(filepath):
    """Return the digits between 'sub-' and the next '_' in the file name.

    Only the final path component of *filepath* is inspected. Returns None
    when it contains no 'sub-<digits>_' fragment.
    """
    found = re.search(r'sub-(\d+)_', os.path.basename(filepath))
    if found is None:
        return None
    return found.group(1)

def get_structural_files_map(structural_dir):
    """Map patient ID -> path for the CSVs directly inside *structural_dir*.

    Files without a '.csv' suffix, and CSVs from which
    ``extract_patient_id_structural`` yields no ID, are left out. When two
    files yield the same ID, the one listed later replaces the earlier one.
    """
    csv_names = (name for name in os.listdir(structural_dir) if name.endswith('.csv'))
    files_map = {}
    for name in csv_names:
        pid = extract_patient_id_structural(name)
        if not pid:
            continue
        files_map[pid] = os.path.join(structural_dir, name)
    return files_map

def get_functional_files_map(functional_dir):
    """Map patient ID -> path for every CSV found under *functional_dir*.

    The directory tree is walked recursively. CSVs from which
    ``extract_patient_id_functional`` yields no ID are left out. When several
    files yield the same ID, the one visited last replaces the earlier ones.
    """
    files_map = {}
    for root, _, files in os.walk(functional_dir):
        for name in files:
            if not name.endswith('.csv'):
                continue
            full_path = os.path.join(root, name)
            pid = extract_patient_id_functional(full_path)
            if pid:
                files_map[pid] = full_path
    return files_map

def plot_fmri_sc(X, A, pid):
    """Show three heatmaps side by side for one patient.

    Panels, left to right: the correlation matrix of the columns of X
    (functional connectivity), the structural matrix A, and their
    difference (FC - SC). *pid* only appears in the panel titles.
    """
    fc_mat = np.corrcoef(X.T)

    # (matrix, colormap, title) for each panel, in display order.
    panels = [
        (fc_mat, 'coolwarm', f'Patient {pid}: fMRI Functional Connectivity'),
        (A, 'coolwarm', f'Patient {pid}: Structural Connectivity'),
        (fc_mat - A, 'bwr', f'Patient {pid}: Difference (FC - SC)'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (matrix, cmap, title) in zip(axes, panels):
        sns.heatmap(matrix, ax=ax, cmap=cmap, center=0, square=True, cbar=True)
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

def main(functional_dir, structural_dir, output_csv_path, num_surrogates=100, seed=42,
         expected_nodes=400):
    """Compute normalized SDI vectors for every patient of one site.

    Patients are those with both a functional and a structural CSV. For each
    one the matrices are loaded and validated, the structural matrix is
    scaled by its maximum, diagnostic heatmaps are shown, the SDI is computed,
    log2-transformed and min-max scaled to [0, 1] per patient. All results
    are written to one CSV, and a histogram of the cutoff indices is shown.

    Parameters
    ----------
    functional_dir : str
        Root folder searched recursively for functional time-series CSVs.
    structural_dir : str
        Folder holding the structural connectome CSVs.
    output_csv_path : str
        Destination CSV; its parent folder is created when it has one.
    num_surrogates : int
        Number of surrogates per patient.
    seed : int or None
        Base seed forwarded to the SDI computation.
    expected_nodes : int
        Required number of nodes; patients with a different node count are
        skipped. Also fixes the number of SDI columns in the output header.

    A patient whose processing raises an exception is reported and skipped.
    """
    func_files = get_functional_files_map(functional_dir)
    struct_files = get_structural_files_map(structural_dir)

    common_patients = set(func_files) & set(struct_files)
    print(f"Found {len(common_patients)} patients with matching data.")

    results = []
    cutoff_indices = []

    for pid in sorted(common_patients):
        try:
            X = np.loadtxt(func_files[pid], delimiter=',')
            A = np.loadtxt(struct_files[pid], delimiter=',')

            if X.size == 0 or A.size == 0:
                print(f"Skipping {pid}: empty data")
                continue

            if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
                print(f"Skipping {pid}: Structural matrix not square {A.shape}")
                continue

            if len(X.shape) != 2:
                print(f"Skipping {pid}: Functional data not 2D {X.shape}")
                continue

            N_f = X.shape[1]
            N_s = A.shape[0]

            if N_f != N_s:
                print(f"Skipping {pid}: functional nodes != structural nodes ({N_f} vs {N_s})")
                continue

            if N_f != expected_nodes:
                print(f"Skipping {pid}: expected {expected_nodes} nodes, got {N_f}")
                continue

            # Scale connection weights by the largest entry of the matrix.
            A = A / (np.max(A) + 1e-10)
            plot_fmri_sc(X, A, pid)

            sdi, cutoff_index = compute_SDI_informed_energy_split(
                A, X, pid, num_surrogates=num_surrogates, seed=seed
            )

            # Log transform and normalize to (0,1)
            sdi = np.log2(sdi + 1e-10)
            sdi_min, sdi_max = sdi.min(), sdi.max()
            sdi = (sdi - sdi_min) / (sdi_max - sdi_min + 1e-10)

            if len(sdi) != expected_nodes:
                print(f"Skipping {pid}: SDI length is {len(sdi)}, expected {expected_nodes}.")
                continue

            results.append((pid, sdi))
            cutoff_indices.append(cutoff_index)

        except Exception as e:
            # Per-patient boundary: report the failure and keep going.
            print(f"Error processing {pid}: {e}")

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_csv_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_csv_path, 'w', newline='') as f_out:
        writer = csv.writer(f_out)
        writer.writerow(['PatientID'] + [f'SDI_Node_{i+1}' for i in range(expected_nodes)])
        writer.writerows([pid] + list(sdi) for pid, sdi in results)

    print(f"Saved all SDI results to {output_csv_path}")

    if not cutoff_indices:
        # Nothing was processed, so there is no distribution to draw.
        print("No cutoff indices to plot.")
        return

    # Plot histogram of energy cutoff indices
    plt.figure(figsize=(8, 5))
    sns.histplot(cutoff_indices, bins=20, kde=False)
    plt.xlabel('Energy-Based Cutoff Index')
    plt.ylabel('Number of Subjects')
    plt.title('Distribution of Graph Frequency Split Points')
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # One entry per site: a display key plus the three paths main() needs.
    sites_config = [
        {
            "site": "bni",
            "functional_dir": "/Users/arnavkarnik/Documents/Classification/Time_Series_ABIDE2/bni_time_series/schaefer_400/cleaned-1",
            "structural_dir": "/Users/arnavkarnik/Documents/Classification/SC_Connectomes_ABIDE2/BNI_1_connectomes",
            "output_csv_path": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv",
        },
        {
            "site": "ip",
            "functional_dir": "/Users/arnavkarnik/Documents/Classification/Time_Series_ABIDE2/ip_time_series/schaefer_400/cleaned-1",
            "structural_dir": "/Users/arnavkarnik/Documents/Classification/SC_Connectomes_ABIDE2/IP_1_connectomes",
            "output_csv_path": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_ip.csv",
        },
        {
            "site": "nyu1",
            "functional_dir": "/Users/arnavkarnik/Documents/Classification/Time_Series_ABIDE2/nyu1_time_series/schaefer_400/cleaned-1",
            "structural_dir": "/Users/arnavkarnik/Documents/Classification/SC_Connectomes_ABIDE2/NYU_1_connectomes",
            "output_csv_path": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu1.csv",
        },
        {
            "site": "nyu2",
            "functional_dir": "/Users/arnavkarnik/Documents/Classification/Time_Series_ABIDE2/nyu2_time_series/schaefer_400/cleaned-1",
            "structural_dir": "/Users/arnavkarnik/Documents/Classification/SC_Connectomes_ABIDE2/NYU_2_connectomes",
            "output_csv_path": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu2.csv",
        },
        {
            "site": "sdsu",
            "functional_dir": "/Users/arnavkarnik/Documents/Classification/Time_Series_ABIDE2/sdsu_time_series/schaefer_400/cleaned-1",
            "structural_dir": "/Users/arnavkarnik/Documents/Classification/SC_Connectomes_ABIDE2/SDSU_1_connectomes",
            "output_csv_path": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_sdsu.csv",
        },
    ]

    # Run the pipeline site by site with the same surrogate count and seed.
    for config in sites_config:
        print(f"\n=== Processing site: {config['site'].upper()} ===")
        path_args = {key: value for key, value in config.items() if key != "site"}
        main(**path_args, num_surrogates=100, seed=42)


### Visualization of SDI data for an individual site (set `csv_path` to that site's results file)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import csv

# Path to your SDI results CSV
csv_path = "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv"

# Load SDI data from CSV: first column is the patient ID, the remaining
# columns are SDI values. Empty cells become NaN.
patient_ids = []
sdi_data = []

with open(csv_path, 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader)  # discard the header row
    for row in reader:
        patient_ids.append(row[0])
        sdi_data.append([float(x) if x != '' else np.nan for x in row[1:]])

sdi_array = np.array(sdi_data)  # one row per patient, one column per node

# Plot 1: Histogram of all SDI values across all nodes and patients
# (boolean indexing already returns a flat array of the non-NaN values).
plt.figure(figsize=(12,6))
plt.hist(sdi_array[~np.isnan(sdi_array)], bins=90, color='c', alpha=0.7)
plt.title("Histogram of all SDI values across all nodes and patients")
plt.xlabel("SDI value")
plt.ylabel("Frequency")
plt.show()

# Plot 2: Mean ± Std Dev of SDI per brain node across patients
mean_sdi = np.nanmean(sdi_array, axis=0)
std_sdi = np.nanstd(sdi_array, axis=0)
nodes = np.arange(1, len(mean_sdi)+1)

plt.figure(figsize=(14,6))
plt.errorbar(nodes, mean_sdi, yerr=std_sdi, fmt='-o', ecolor='r', capsize=5)
plt.title("Mean ± Std Dev of SDI per brain node across patients")
plt.xlabel("Brain node")
plt.ylabel("SDI")
plt.grid(True)
plt.show()

# Plot 3: Distribution of total SDI per patient (NaNs count as 0 in the sum)
total_sdi_per_patient = np.nansum(sdi_array, axis=1)

plt.figure(figsize=(10,5))
plt.hist(total_sdi_per_patient, bins=30, color='m', alpha=0.7)
plt.title("Distribution of total SDI scores across patients")
plt.xlabel("Total SDI")
plt.ylabel("Number of patients")
plt.show()

# Plot 4: Heatmap of SDI for first 5 patients (or less if fewer patients)
# Show only first 5 patients and first 50 nodes
num_patients_to_show = min(5, sdi_array.shape[0])
num_nodes_to_show = min(50, sdi_array.shape[1])

plt.figure(figsize=(12, 4))
sns.heatmap(
    sdi_array[:num_patients_to_show, :num_nodes_to_show],
    cmap='viridis',
    xticklabels=np.arange(1, num_nodes_to_show + 1),
    yticklabels=patient_ids[:num_patients_to_show]
)
plt.title("Zoomed-in SDI heatmap (first 5 patients × first 50 brain nodes)")
plt.xlabel("Brain node")
plt.ylabel("Patient ID")
plt.tight_layout()
plt.show()


# Second zoomed heatmap: nodes 51-150. end_node is clamped to the number of
# columns so the tick labels always match the slice that is drawn.
start_node = 50
end_node = min(150, sdi_array.shape[1])
num_patients_to_show = min(5, sdi_array.shape[0])

plt.figure(figsize=(12, 4))
sns.heatmap(
    sdi_array[:num_patients_to_show, start_node:end_node],
    cmap='viridis',
    xticklabels=np.arange(start_node + 1, end_node + 1),
    yticklabels=patient_ids[:num_patients_to_show]
)
plt.title(f"Zoomed-in SDI heatmap (first 5 patients × nodes {start_node+1}-{end_node})")
plt.xlabel("Brain node")
plt.ylabel("Patient ID")
plt.tight_layout()
plt.show()



### Visualization of SDI data for all sites together 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import os

# List of all CSV paths
csv_paths = [
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_ip.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu1.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu2.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_sdsu.csv"
]

all_patient_ids = []
all_sdi_data = []

# Load data from all CSVs; a missing file is reported and skipped.
for path in csv_paths:
    if not os.path.exists(path):
        print(f"File not found: {path}")
        continue
    with open(path, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader)  # discard the header row
        for row in reader:
            all_patient_ids.append(row[0])
            # Empty cells become NaN.
            all_sdi_data.append([float(x) if x != '' else np.nan for x in row[1:]])

# Without any rows the plots below would fail later with an opaque
# indexing error, so stop here with an explicit message.
if not all_sdi_data:
    raise ValueError("No SDI rows were loaded from csv_paths.")

sdi_array = np.array(all_sdi_data)  # one row per patient, one column per node

# --- Plot 1: Histogram of all SDI values ---
# (boolean indexing already returns a flat array of the non-NaN values)
plt.figure(figsize=(12,6))
plt.hist(sdi_array[~np.isnan(sdi_array)], bins=90, color='c', alpha=0.7)
plt.title("Histogram of all SDI values across all ABIDE II patients and nodes")
plt.xlabel("SDI value")
plt.ylabel("Frequency")
plt.tight_layout()
plt.show()

# --- Plot 2: Mean ± Std Dev per brain node ---
mean_sdi = np.nanmean(sdi_array, axis=0)
std_sdi = np.nanstd(sdi_array, axis=0)
nodes = np.arange(1, len(mean_sdi)+1)

plt.figure(figsize=(14,6))
plt.errorbar(nodes, mean_sdi, yerr=std_sdi, fmt='-o', ecolor='r', capsize=5)
plt.title("Mean ± Std Dev of SDI per brain node across ABIDE II patients")
plt.xlabel("Brain node")
plt.ylabel("SDI")
plt.grid(True)
plt.tight_layout()
plt.show()

# --- Plot 3: Distribution of total SDI per patient ---
total_sdi_per_patient = np.nansum(sdi_array, axis=1)

plt.figure(figsize=(10,5))
plt.hist(total_sdi_per_patient, bins=30, color='m', alpha=0.7)
plt.title("Distribution of total SDI scores across ABIDE II patients")
plt.xlabel("Total SDI")
plt.ylabel("Number of patients")
plt.tight_layout()
plt.show()

# --- Plot 4a: Zoomed-in heatmap (first 5 patients × first 50 brain nodes) ---
num_patients_to_show = min(5, sdi_array.shape[0])
num_nodes_to_show = min(50, sdi_array.shape[1])

plt.figure(figsize=(12, 4))
sns.heatmap(
    sdi_array[:num_patients_to_show, :num_nodes_to_show],
    cmap='viridis',
    xticklabels=np.arange(1, num_nodes_to_show + 1),
    yticklabels=all_patient_ids[:num_patients_to_show]
)
plt.title("Zoomed-in SDI heatmap (first 5 patients × first 50 brain nodes)")
plt.xlabel("Brain node")
plt.ylabel("Patient ID")
plt.tight_layout()
plt.show()

# --- Plot 4b: Zoomed-in heatmap (first 5 patients × nodes 51–150) ---
# end_node is clamped to the number of columns so the tick labels always
# match the slice that is drawn.
start_node = 50
end_node = min(150, sdi_array.shape[1])

plt.figure(figsize=(12, 4))
sns.heatmap(
    sdi_array[:num_patients_to_show, start_node:end_node],
    cmap='viridis',
    xticklabels=np.arange(start_node + 1, end_node + 1),
    yticklabels=all_patient_ids[:num_patients_to_show]
)
plt.title(f"Zoomed-in SDI heatmap (first 5 patients × nodes {start_node+1}-{end_node})")
plt.xlabel("Brain node")
plt.ylabel("Patient ID")
plt.tight_layout()
plt.show()


### Visualization of the brain atlas for one patient at one site using Schaefer 400 (change the patient ID and site as needed)

In [None]:
import pandas as pd
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from nilearn import plotting, datasets

# Paint one patient's SDI values onto the atlas volume and show the result
# over the MNI152 template.

# -----------------------------
# Step 1: Load SDI from CSV
# -----------------------------
csv_path = "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv"
patient_id = "29006"

df = pd.read_csv(csv_path)
# IDs are compared as strings, so cast whatever dtype pandas inferred.
df["PatientID"] = df["PatientID"].astype(str)
row_match = df[df["PatientID"] == patient_id]

if row_match.empty:
    raise ValueError(f"Patient ID {patient_id} not found in CSV.")

# First matching row; columns 1..400 are taken as the per-parcel SDI values.
sdi_row = row_match.iloc[0]
sdi_values = sdi_row.iloc[1:401].to_numpy(dtype=float)  # 400 parcels

# -----------------------------
# Step 2: Load Schaefer 400 atlas volume
# -----------------------------
atlas_path = "/Users/arnavkarnik/Documents/Classification/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii"
atlas_img = nib.load(atlas_path)
atlas_data = atlas_img.get_fdata()

# -----------------------------
# Step 3: Build SDI volume
# -----------------------------
# Voxels whose label matches no parcel below stay at 0.
sdi_volume = np.zeros_like(atlas_data)

for i in range(400):
    # Assumes atlas voxel labels run 1..400 in the same order as the CSV
    # columns — TODO confirm against the atlas label file.
    region_label = i + 1
    sdi_volume[atlas_data == region_label] = sdi_values[i]

sdi_img = nib.Nifti1Image(sdi_volume, affine=atlas_img.affine)

# -----------------------------
# Step 4: Plot on MNI template
# -----------------------------
template = datasets.load_mni152_template()

display = plotting.plot_stat_map(
    sdi_img,
    bg_img=template,
    title=f"SDI Map - Patient {patient_id}",
    display_mode="ortho",
    # The 20th percentile of this patient's SDI values is the display threshold.
    threshold=np.percentile(sdi_values, 20),
    cmap="viridis",
)
plotting.show()


In [None]:
import pandas as pd
import numpy as np
import nibabel as nib

# -------- Step 1: Load SDI values from CSV --------
csv_path = "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv"
patient_id = "29006"

df = pd.read_csv(csv_path)
# IDs are compared as strings, so cast whatever dtype pandas inferred.
df["PatientID"] = df["PatientID"].astype(str)
row = df[df["PatientID"] == patient_id]

if row.empty:
    raise ValueError(f"Patient {patient_id} not found in CSV")

sdi_values = row.iloc[0, 1:401].to_numpy(dtype=float)  # 400 regions

# -------- Step 2: Load the Schaefer400 atlas --------
# Not used by the ranking below; kept so atlas_img / atlas_data stay defined.
atlas_path = "/Users/arnavkarnik/Documents/Classification/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii"
atlas_img = nib.load(atlas_path)
atlas_data = atlas_img.get_fdata().astype(int)

# -------- Step 3: Identify top N decoupled/coupled regions --------
N = 10

# The SDI in this notebook is the norm of the high-graph-frequency part
# divided by the norm of the low-graph-frequency part, and the later log2 and
# min-max scaling are monotonic. Large values therefore mean decoupled and
# small values mean coupled.
order = np.argsort(sdi_values)
top_indices = order[-N:][::-1]     # N highest SDI values, descending
bottom_indices = order[:N]         # N lowest SDI values, ascending

print("Top Decoupled Regions (High SDI):")
for i in top_indices:
    print(f"Region {i+1} — SDI: {sdi_values[i]:.4f}")

print("\nTop Coupled Regions (Low SDI):")
for i in bottom_indices:
    print(f"Region {i+1} — SDI: {sdi_values[i]:.4f}")


In [None]:
from nilearn import datasets

# Fetch the Schaefer 2018 atlas with 400 regions and 7-network solution
# (1 mm resolution) through nilearn.
schaefer = datasets.fetch_atlas_schaefer_2018(n_rois=400, yeo_networks=7, resolution_mm=1)

# Extract the region names shipped with the atlas.
# NOTE(review): element type (bytes vs str) and whether a background entry
# is included presumably depend on the nilearn version — confirm before use.
region_labels = schaefer['labels']  # This is a list of region names


In [None]:
import pandas as pd
import numpy as np
import nibabel as nib

from nilearn import datasets

# Build the parcel-name table from the atlas's own label list. The Schaefer
# "7Networks_order" files number all left-hemisphere parcels first and all
# right-hemisphere parcels after them, so a table that assigns contiguous
# blocks of IDs 1..400 to whole networks does not match the atlas labels.
schaefer = datasets.fetch_atlas_schaefer_2018(n_rois=400, yeo_networks=7, resolution_mm=1)
atlas_labels = [
    label.decode() if isinstance(label, bytes) else str(label)
    for label in schaefer['labels']
]
# Drop a 'Background' entry (label 0) if this nilearn release includes one.
atlas_labels = [label for label in atlas_labels if label != 'Background']
if len(atlas_labels) != 400:
    raise ValueError(f"Expected 400 parcel labels, got {len(atlas_labels)}")

# Keys are 1-based parcel numbers, matching the atlas voxel labels.
parcel_names = dict(enumerate(atlas_labels, start=1))

# === New code to verify and save all 400 parcel names ===
print(f"Total parcels: {len(parcel_names)}")

for i in range(1, 401):
    print(f"Region {i}: {parcel_names[i]}")

with open("schaefer400_parcel_names.txt", "w") as f:
    for i in range(1, 401):
        f.write(f"Region {i}: {parcel_names[i]}\n")

print("Saved all parcel names to schaefer400_parcel_names.txt")

# Load SDI values
csv_path = "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv"
patient_id = "29006"

df = pd.read_csv(csv_path)
# IDs are compared as strings, so cast whatever dtype pandas inferred.
df["PatientID"] = df["PatientID"].astype(str)
row = df[df["PatientID"] == patient_id]

if row.empty:
    raise ValueError(f"Patient {patient_id} not found in CSV")

sdi_values = row.iloc[0, 1:401].to_numpy(dtype=float)  # 400 regions

# Load Schaefer400 atlas (not used directly here but useful if needed)
atlas_path = "/Users/arnavkarnik/Documents/Classification/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii"
atlas_img = nib.load(atlas_path)
atlas_data = atlas_img.get_fdata().astype(int)

# The SDI in this notebook is the norm of the high-graph-frequency part
# divided by the norm of the low-graph-frequency part, so large values mean
# decoupled and small values mean coupled.
N = 10
order = np.argsort(sdi_values)
top_indices = order[-N:][::-1]  # N highest SDI values
bottom_indices = order[:N]      # N lowest SDI values

print("Top Decoupled Regions (High SDI):")
for i in top_indices:
    region_index = i + 1  # Convert zero-based to 1-based indexing
    region_name = parcel_names.get(region_index, f"Region_{region_index}")
    print(f"{region_name} (Region {region_index}) — SDI: {sdi_values[i]:.4f}")

print("\nTop Coupled Regions (Low SDI):")
for i in bottom_indices:
    region_index = i + 1
    region_name = parcel_names.get(region_index, f"Region_{region_index}")
    print(f"{region_name} (Region {region_index}) — SDI: {sdi_values[i]:.4f}")


In [None]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from nilearn import plotting
import matplotlib.patches as mpatches

from matplotlib.colors import ListedColormap

# Paths
atlas_path = "/Users/arnavkarnik/Documents/Classification/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii"

# Your top coupled/decoupled region indices (1-based)
# NOTE(review): these lists are hard-coded. If they were copied from a ranking
# where the highest SDI values were labeled "coupled", the two lists are
# swapped, because this notebook's SDI grows with decoupling. Confirm their
# origin.
top_coupled = [90, 70, 373, 79, 133, 60, 255, 239, 107, 154]
top_decoupled = [15, 168, 27, 399, 254, 249, 252, 217, 128, 18]

# Load atlas
atlas_img = nib.load(atlas_path)
atlas_data = atlas_img.get_fdata()

# Create masks for top coupled and decoupled parcels
top_coupled_mask = np.isin(atlas_data, top_coupled).astype(int)
top_decoupled_mask = np.isin(atlas_data, top_decoupled).astype(int)

# Combine masks: 0=background, 1=top coupled, 2=top decoupled
# (a parcel listed in both would get 3 and be drawn with the color of 2).
combined_mask = top_coupled_mask + (top_decoupled_mask * 2)

# Convert to uint8 to avoid nibabel dtype error
combined_img = nib.Nifti1Image(combined_mask.astype(np.uint8), affine=atlas_img.affine)

# Plotting combined mask. A two-color map with fixed limits ties value 1 to
# red and value 2 to blue, matching the title and legend. The former
# 3-level 'coolwarm' map placed value 2 at its red end, and plt.cm.get_cmap
# is no longer available in recent matplotlib releases.
display = plotting.plot_roi(
    combined_img,
    title="Top Coupled (red) & Decoupled (blue) Regions",
    cmap=ListedColormap(['red', 'blue']),
    vmin=1,
    vmax=2,
    alpha=0.7
)

# Add legend for color meaning
red_patch = mpatches.Patch(color='red', label='Top Coupled')
blue_patch = mpatches.Patch(color='blue', label='Top Decoupled')
plt.legend(handles=[red_patch, blue_patch], loc='lower left')

plotting.show()


### Visualization of the brain atlas for all patients across all sites using Schaefer 400

In [None]:
import pandas as pd
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from nilearn import plotting, datasets
import os

# Average the SDI of every patient across all sites, paint the per-parcel
# mean onto the atlas volume and show it over the MNI152 template.

# -----------------------------
# Step 1: File Paths
# -----------------------------
atlas_path = "/Users/arnavkarnik/Documents/Classification/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii"
template = datasets.load_mni152_template()

# One SDI results CSV per site.
site_csvs = {
    "BNI": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv",
    "IP": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_ip.csv",
    "NYU1": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu1.csv",
    "NYU2": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu2.csv",
    "SDSU": "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_sdsu.csv",
}

# -----------------------------
# Step 2: Aggregate SDI from All Sites
# -----------------------------
all_sdi = []

for site_name, csv_path in site_csvs.items():
    df = pd.read_csv(csv_path)
    # A CSV that holds only the header contributes nothing.
    if df.empty:
        continue
    # Columns 1..400 are taken as the SDI values (column 0 is the patient ID).
    sdi_values = df.iloc[:, 1:401].to_numpy(dtype=float)
    all_sdi.append(sdi_values)

# Concatenate all patients (rows) and average each column, ignoring NaNs.
all_sdi_array = np.vstack(all_sdi)  # shape: (total_patients, up to 400)
mean_sdi = np.nanmean(all_sdi_array, axis=0)  # one mean per SDI column

# -----------------------------
# Step 3: Load Atlas and Create Volume
# -----------------------------
atlas_img = nib.load(atlas_path)
atlas_data = atlas_img.get_fdata()
# Voxels whose label matches no parcel below stay at 0.
sdi_volume = np.zeros_like(atlas_data)

for i in range(400):
    # Assumes atlas voxel labels run 1..400 in the same order as the CSV
    # columns — TODO confirm against the atlas label file.
    region_label = i + 1
    sdi_volume[atlas_data == region_label] = mean_sdi[i]

sdi_img = nib.Nifti1Image(sdi_volume, affine=atlas_img.affine)

# -----------------------------
# Step 4: Plot the Common SDI Atlas
# -----------------------------
plotting.plot_stat_map(
    sdi_img,
    bg_img=template,
    title="Mean SDI Atlas (All ABIDE II Sites)",
    display_mode="ortho",
    # The 20th percentile of the per-parcel means is the display threshold.
    threshold=np.percentile(mean_sdi, 20),
    cmap="viridis"
)
plotting.show()


In [None]:
import pandas as pd
import numpy as np

# --- List all your ABIDE II SDI CSVs across sites here ---
csv_files = [
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_ip.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu1.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu2.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_sdsu.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv"
]

all_sdi = []

# --- Aggregate SDI data from all patients ---
for csv_path in csv_files:
    df = pd.read_csv(csv_path)
    sdi_data = df.iloc[:, 1:401].to_numpy(dtype=float)  # Only SDI columns
    all_sdi.append(sdi_data)

# Combine all patients across all sites (one row per patient)
sdi_matrix = np.vstack(all_sdi)

# --- Compute average SDI across patients for each region (NaNs ignored) ---
mean_sdi_per_region = np.nanmean(sdi_matrix, axis=0)

# --- Identify top/bottom N regions ---
# The SDI in this notebook is the norm of the high-graph-frequency part
# divided by the norm of the low-graph-frequency part, so large values mean
# decoupled and small values mean coupled.
N = 10
order = np.argsort(mean_sdi_per_region)
top_indices = order[-N:][::-1]   # N highest means, descending
bottom_indices = order[:N]       # N lowest means, ascending

# --- Print results ---
print(f"\nTop {N} Decoupled Regions (Highest Mean SDI):")
for i in top_indices:
    print(f"Region {i+1} — Mean SDI: {mean_sdi_per_region[i]:.4f}")

print(f"\nTop {N} Coupled Regions (Lowest Mean SDI):")
for i in bottom_indices:
    print(f"Region {i+1} — Mean SDI: {mean_sdi_per_region[i]:.4f}")


In [None]:
from nilearn import datasets

# Fetch the Schaefer 2018 atlas with 400 regions and 7-network solution
# (1 mm resolution) through nilearn.
schaefer = datasets.fetch_atlas_schaefer_2018(n_rois=400, yeo_networks=7, resolution_mm=1)

# Extract the region names shipped with the atlas.
# NOTE(review): element type (bytes vs str) and whether a background entry
# is included presumably depend on the nilearn version — confirm before use.
region_labels = schaefer['labels']  # This is a list of region names


In [None]:
import pandas as pd
import numpy as np
import os

from nilearn import datasets

# -------- Define Schaefer-400 Parcel Names --------
# Names are read from the atlas's own label list. The Schaefer
# "7Networks_order" files number all left-hemisphere parcels first and all
# right-hemisphere parcels after them, so a table that assigns contiguous
# blocks of IDs 1..400 to whole networks does not match the atlas labels.
schaefer = datasets.fetch_atlas_schaefer_2018(n_rois=400, yeo_networks=7, resolution_mm=1)
atlas_labels = [
    label.decode() if isinstance(label, bytes) else str(label)
    for label in schaefer['labels']
]
# Drop a 'Background' entry (label 0) if this nilearn release includes one.
atlas_labels = [label for label in atlas_labels if label != 'Background']
if len(atlas_labels) != 400:
    raise ValueError(f"Expected 400 parcel labels, got {len(atlas_labels)}")

# Keys are 1-based parcel numbers, matching the atlas voxel labels.
parcel_names = dict(enumerate(atlas_labels, start=1))

# -------- All CSV Files Across Sites --------
csv_paths = [
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_ip.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu1.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu2.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_sdsu.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv"
]

all_sdi = []

# -------- Load and Stack SDI Data from All Sites --------
for path in csv_paths:
    df = pd.read_csv(path)
    sdi_vals = df.iloc[:, 1:401].to_numpy(dtype=float)
    all_sdi.append(sdi_vals)

sdi_matrix = np.vstack(all_sdi)  # one row per patient

# -------- Compute Mean SDI Per Region Across Patients (NaNs ignored) --------
mean_sdi = np.nanmean(sdi_matrix, axis=0)

# -------- Identify Top Decoupled and Coupled Regions --------
# The SDI in this notebook is the norm of the high-graph-frequency part
# divided by the norm of the low-graph-frequency part, so large values mean
# decoupled and small values mean coupled.
N = 10
order = np.argsort(mean_sdi)
top_indices = order[-N:][::-1]   # N highest means, descending
bottom_indices = order[:N]       # N lowest means, ascending

# -------- Print Results with Parcel Names --------
print(f"Top {N} Decoupled Regions (Highest Mean SDI):")
for i in top_indices:
    parcel_id = i + 1
    name = parcel_names.get(parcel_id, f"Region_{parcel_id}")
    print(f"{name} (Region {parcel_id}) — Mean SDI: {mean_sdi[i]:.4f}")

print(f"\nTop {N} Coupled Regions (Lowest Mean SDI):")
for i in bottom_indices:
    parcel_id = i + 1
    name = parcel_names.get(parcel_id, f"Region_{parcel_id}")
    print(f"{name} (Region {parcel_id}) — Mean SDI: {mean_sdi[i]:.4f}")


In [None]:
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from nilearn import plotting
import matplotlib.patches as mpatches

# Local import: only this cell needs an explicit two-colour colormap.
from matplotlib.colors import ListedColormap

# --- Step 1: Define Paths ---
atlas_path = "/Users/arnavkarnik/Documents/Classification/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii"
csv_paths = [
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_ip.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu1.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_nyu2.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_sdsu.csv",
    "/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/sdi_informed_energy_normalized_bni.csv"
]

# --- Step 2: Load and Combine SDI Data ---
# Columns 1..400 of each CSV are read as the per-region values (column 0 is
# skipped) and all sites are stacked row-wise.
all_sdi = []
for path in csv_paths:
    df = pd.read_csv(path)
    sdi_vals = df.iloc[:, 1:401].to_numpy(dtype=float)
    all_sdi.append(sdi_vals)

sdi_matrix = np.vstack(all_sdi)  # (total_patients, 400 regions)

# --- Step 3: Compute Mean SDI and Identify Top/Bottom Regions ---
mean_sdi = np.nanmean(sdi_matrix, axis=0)
# argsort puts NaN means last; drop them so a region without valid data can
# never be selected as one of the "highest" regions.
ranked = np.argsort(mean_sdi)
ranked = ranked[~np.isnan(mean_sdi[ranked])]
top_coupled = ranked[::-1][:10] + 1     # Convert to 1-based atlas labels
top_decoupled = ranked[:10] + 1

# --- Step 4: Load Atlas and Generate Region Masks ---
atlas_img = nib.load(atlas_path)
atlas_data = atlas_img.get_fdata()

top_coupled_mask = np.isin(atlas_data, top_coupled).astype(int)
top_decoupled_mask = np.isin(atlas_data, top_decoupled).astype(int)

# 0: background, 1: coupled (red), 2: decoupled (blue)
combined_mask = top_coupled_mask + (top_decoupled_mask * 2)
combined_img = nib.Nifti1Image(combined_mask.astype(np.uint8), affine=atlas_img.affine)

# --- Step 5: Plotting ---
# A two-entry colormap pinned to vmin=1 / vmax=2 guarantees that mask value 1
# is drawn red and mask value 2 blue, i.e. exactly what the title and legend
# state. (With a blue-to-red colormap such as 'coolwarm', the lower value 1
# could never be the red one.)
display = plotting.plot_roi(
    combined_img,
    title="Top Coupled (red) & Decoupled (blue) Regions (All Sites)",
    cmap=ListedColormap(['red', 'blue']),
    vmin=1,
    vmax=2,
    alpha=0.7
)

# Add legend
red_patch = mpatches.Patch(color='red', label='Top Coupled')
blue_patch = mpatches.Patch(color='blue', label='Top Decoupled')
plt.legend(handles=[red_patch, blue_patch], loc='lower left')

plotting.show()


In [None]:
import numpy as np
import pandas as pd
import nibabel as nib
from nilearn import plotting, image, datasets
import matplotlib.pyplot as plt

# Define sites: for every site label, the SDI CSV and the phenotype CSV to use.
# The SDI file name carries the lower-cased site label; the phenotype file
# name carries the cohort tag paired with it below.
sites = {
    site: {
        'sdi': (
            '/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/'
            f'sdi_informed_energy_normalized_{site.lower()}.csv'
        ),
        'phenotype': (
            '/Users/arnavkarnik/Documents/Classification/Phenotypes_ABIDE2/'
            f'{cohort}_phenotypes.csv'
        ),
    }
    for site, cohort in [
        ('IP', 'IP_1'),
        ('BNI', 'BNI_1'),
        ('NYU1', 'NYU_1'),
        ('NYU2', 'NYU_2'),
        ('SDSU', 'SDSU_1'),
    ]
}

# --- Step 1: Load atlas and compute coordinates (do this once) ---
atlas_path = "/Users/arnavkarnik/Documents/Classification/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii"
atlas_img = nib.load(atlas_path)
atlas_data = atlas_img.get_fdata()
affine = atlas_img.affine

# Compute the centroid (mean voxel index mapped through the affine) of every
# atlas label 1..400. Labels without any voxel get no coordinate, so the SDI
# column index of each ROI that *does* have one is recorded alongside it; this
# keeps coordinates and values aligned even if a label is missing.
coords = []
roi_columns = []
for label in range(1, 401):  # Atlas labels are 1-based (1-400)
    voxel_indices = np.argwhere(atlas_data == label)
    if voxel_indices.size == 0:
        print(f"Warning: atlas label {label} has no voxels; ROI skipped")
        continue
    center_voxel = np.mean(voxel_indices, axis=0)
    coords.append(nib.affines.apply_affine(affine, center_voxel))
    roi_columns.append(label - 1)  # 0-based column within the SDI block
coords = np.array(coords)  # shape (n_present_rois, 3)
roi_columns = np.array(roi_columns, dtype=int)

# --- Step 2: Process each site ---
for site_name, paths in sites.items():
    print(f"Processing site: {site_name}")

    # Load SDI values for this site (columns 1..400; column 0 is skipped)
    df = pd.read_csv(paths['sdi'])
    sdi_matrix = df.iloc[:, 1:401].to_numpy(dtype=float)  # (subjects, 400 ROIs)
    mean_sdi = np.nanmean(sdi_matrix, axis=0)

    # Keep only ROIs that have a coordinate, and of those only the ones with
    # a finite mean: NaN values cannot be normalised or drawn.
    roi_mean = mean_sdi[roi_columns]
    drawable = ~np.isnan(roi_mean)
    if not drawable.any():
        print("  No ROI with valid SDI data; skipping plot")
        print()
        continue

    # Normalize SDI values to 0-1 range (NaN-aware; a constant map would
    # divide by zero, so it is mapped to all zeros instead)
    lowest = np.nanmin(roi_mean)
    span = np.nanmax(roi_mean) - lowest
    if span > 0:
        sdi_normalized = (roi_mean - lowest) / span
    else:
        sdi_normalized = np.zeros_like(roi_mean)
    sdi_scaled = 20 + sdi_normalized * 80  # marker size from 20 to 100

    # Create visualization for this site
    fig = plotting.plot_markers(
        node_coords=coords[drawable],
        node_values=sdi_normalized[drawable],
        node_size=sdi_scaled[drawable],
        display_mode='lyrz',
        title=f"Schaefer-400 ROIs – Mean SDI for {site_name}"
    )

    plt.show()

    # Optional: Print some statistics
    print(f"  Mean SDI range: {np.nanmin(mean_sdi):.4f} to {np.nanmax(mean_sdi):.4f}")
    print(f"  Number of subjects: {sdi_matrix.shape[0]}")
    print(f"  Number of ROIs with valid data: {np.sum(~np.isnan(mean_sdi))}")
    print()

### Classification of Healthy vs Autistic patients in ABIDE 2

In [None]:
import pandas as pd
import numpy as np
import os
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import (train_test_split, cross_val_score, 
                                   StratifiedKFold, GridSearchCV)
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

# File paths - add or remove (site, cohort) pairs below to include/exclude
# sites as needed. Each site maps to its SDI CSV (named after the lower-cased
# site label) and its phenotype CSV (named after the cohort tag).
sites = {
    site: {
        'sdi': (
            '/Users/arnavkarnik/Documents/Classification/results_ABIDE2SC/'
            f'sdi_informed_energy_normalized_{site.lower()}.csv'
        ),
        'phenotype': (
            '/Users/arnavkarnik/Documents/Classification/Phenotypes_ABIDE2/'
            f'{cohort}_phenotypes.csv'
        ),
    }
    for site, cohort in [
        ('IP', 'IP_1'),
        ('BNI', 'BNI_1'),
        ('NYU1', 'NYU_1'),
        ('NYU2', 'NYU_2'),
        ('SDSU', 'SDSU_1'),
    ]
}

def load_site_data(sites, exclude_sites=None):
    """Load SDI features and diagnosis labels for every requested site.

    For each site the SDI CSV and the phenotype CSV are read, their subject
    ids ('PatientID' / 'SUB_ID') are compared as strings, and the two tables
    are inner-joined. Feature columns are all columns whose name contains
    'SDI_Node'; labels are the raw 'DX_GROUP' codes. Per the ABIDE phenotypic
    legend these codes are 1 = Autism (ASD) and 2 = Control (TD), which is
    how the printed summaries name them.

    A site that cannot be read, or for which no subject id matches between
    the two files, is reported on stdout and left out (best effort: one bad
    site does not abort the others).

    Args:
        sites: Mapping of site name to a dict with the keys 'sdi' and
            'phenotype', each holding a CSV path.
        exclude_sites: Optional collection of site names to skip.

    Returns:
        A tuple ``(site_data, X_combined, y_combined)`` where ``site_data``
        maps each loaded site to ``{'features': ndarray, 'labels': ndarray}``,
        and the other two are the row-wise concatenation over all loaded
        sites, or ``None`` when nothing was loaded.
    """
    if exclude_sites is None:
        exclude_sites = []

    site_data = {}
    all_features = []
    all_labels = []

    print("Loading data from sites:")
    print("-" * 40)

    for site, paths in sites.items():
        if site in exclude_sites:
            print(f"{site}: EXCLUDED")
            continue

        try:
            sdi_df = pd.read_csv(paths['sdi'])
            phen_df = pd.read_csv(paths['phenotype'])

            # Compare ids as strings so int/str dtype differences between the
            # two files do not prevent matching.
            sdi_df['PatientID'] = sdi_df['PatientID'].astype(str)
            phen_df['SUB_ID'] = phen_df['SUB_ID'].astype(str)

            # Inner join: keeps only subjects present in both files.
            merged = pd.merge(sdi_df, phen_df, left_on='PatientID', right_on='SUB_ID')
        except Exception as e:  # deliberate best effort: report and move on
            print(f"{site}: Failed to load - {e}")
            continue

        if merged.empty:
            # Registering a zero-subject site would only break later steps
            # (e.g. a leave-one-site-out fold with an empty test set).
            print(f"{site}: Failed to load - no PatientID matched a phenotype SUB_ID")
            continue

        try:
            features = merged.filter(like='SDI_Node').values
            labels = merged['DX_GROUP'].values  # 1 = ASD, 2 = TD
        except Exception as e:
            print(f"{site}: Failed to load - {e}")
            continue

        site_data[site] = {
            'features': features,
            'labels': labels
        }

        # For combined analysis
        all_features.append(features)
        all_labels.append(labels)

        # Print basic stats
        label_counts = Counter(labels)
        print(f"{site}: {len(labels)} subjects (ASD: {label_counts.get(1, 0)}, TD: {label_counts.get(2, 0)}) - {features.shape[1]} features")

    # Combined dataset
    if all_features:
        X_combined = np.vstack(all_features)
        y_combined = np.concatenate(all_labels)

        print(f"\nCombined dataset: {X_combined.shape}")
        print(f"Label distribution: ASD={np.sum(y_combined==1)}, TD={np.sum(y_combined==2)}")
    else:
        X_combined, y_combined = None, None

    return site_data, X_combined, y_combined

def basic_cross_validation(X, y, n_splits=5):
    """Score an RBF-kernel SVM with shuffled, stratified k-fold CV.

    Args:
        X: Feature matrix passed straight to ``cross_val_score``.
        y: Class labels used both for stratification and scoring.
        n_splits: Number of folds.

    Returns:
        The array of per-fold scores returned by ``cross_val_score``.
    """
    print(f"\nBasic {n_splits}-Fold Cross-Validation:")
    print("-" * 40)

    # Fixed seeds keep the fold assignment reproducible between runs.
    folds = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    model = SVC(kernel='rbf', C=1, gamma='scale', random_state=42)
    scores = cross_val_score(model, X, y, cv=folds)

    average = np.mean(scores)
    spread = np.std(scores)
    print(f"Cross-validated Accuracy Scores: {scores}")
    print(f"Mean Accuracy: {average:.4f} ± {spread:.4f}")

    return scores

def train_test_evaluation(X, y, test_size=0.2):
    """Fit an RBF SVM on a stratified train split and report test metrics.

    Labels are expected to be raw DX_GROUP codes; per the ABIDE phenotypic
    legend 1 = Autism (ASD) and 2 = Control (TD). The class codes are passed
    explicitly to the report and the confusion matrix so the printed names
    are tied to the codes rather than to whatever order happens to occur.

    Args:
        X: Feature matrix.
        y: DX_GROUP labels (1 = ASD, 2 = TD).
        test_size: Fraction of samples held out for testing.

    Returns:
        Tuple ``(y_test, y_pred, final_accuracy)``.
    """
    print(f"\nTrain-Test Split Evaluation (test_size={test_size}):")
    print("-" * 40)

    class_codes = [1, 2]
    class_names = ['ASD', 'TD']

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=42
    )

    clf = SVC(kernel='rbf', C=1, gamma='scale', random_state=42)
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    final_accuracy = accuracy_score(y_test, y_pred)

    print(f"Training set: {len(y_train)} samples")
    print(f"Test set: {len(y_test)} samples")
    print(f"\n=== FINAL TEST ACCURACY: {final_accuracy:.4f} ({final_accuracy*100:.2f}%) ===")

    print("\nClassification Report:")
    print(classification_report(y_test, y_pred,
                                labels=class_codes, target_names=class_names))

    # Rows/columns follow class_codes: first ASD (1), then TD (2).
    print("\nConfusion Matrix:")
    cm = confusion_matrix(y_test, y_pred, labels=class_codes)
    print(cm)

    return y_test, y_pred, final_accuracy

def leave_one_site_out_cv(site_data):
    """Run leave-one-site-out cross-validation with an inner SVM grid search.

    Each site is held out once as the test set. The remaining sites are
    stacked into the training set, standardised (the scaler is fitted on the
    training rows only and then applied to the test rows), and used for a
    3-fold stratified grid search over SVM hyper-parameters. The refitted
    best estimator predicts the held-out site.

    Args:
        site_data: Mapping of site name to ``{'features': ..., 'labels': ...}``.

    Returns:
        A list with one dict per held-out site (keys: 'test_site',
        'train_sites', 'accuracy', 'n_train', 'n_test', 'y_true', 'y_pred',
        'best_params'), or ``None`` when fewer than two sites are available.
    """
    site_names = list(site_data.keys())
    if len(site_names) < 2:
        print("Need at least 2 sites for LOSO-CV")
        return None

    # Loop-invariant search configuration, built once.
    param_grid = {
        'C': [0.1, 1, 10],
        'gamma': ['scale', 0.01, 0.1],
        'kernel': ['rbf', 'linear']
    }
    inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

    results = []

    print(f"\nLeave-One-Site-Out CV ({len(site_names)} folds):")
    print("-" * 50)

    for test_site in site_names:
        # Every other site contributes to the training set.
        train_sites = [s for s in site_names if s != test_site]
        X_train = np.vstack([site_data[s]['features'] for s in train_sites])
        y_train = np.concatenate([site_data[s]['labels'] for s in train_sites])

        # Test data
        X_test = site_data[test_site]['features']
        y_test = site_data[test_site]['labels']

        # Preprocessing: statistics come from the training sites only.
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        # Hyperparameter tuning. Only predict() is used below, so probability
        # calibration (an extra internal cross-validation on every fit) is
        # not requested.
        svm = SVC(random_state=42)
        grid_search = GridSearchCV(svm, param_grid, cv=inner_cv, scoring='accuracy')
        grid_search.fit(X_train_scaled, y_train)

        # Best configuration, refitted on the whole training set by the search.
        best_model = grid_search.best_estimator_

        # Predict
        y_pred = best_model.predict(X_test_scaled)
        accuracy = accuracy_score(y_test, y_pred)

        # Store results
        results.append({
            'test_site': test_site,
            'train_sites': train_sites,
            'accuracy': accuracy,
            'n_train': len(y_train),
            'n_test': len(y_test),
            'y_true': y_test,
            'y_pred': y_pred,
            'best_params': grid_search.best_params_
        })

        print(f"Test site: {test_site:5} | Accuracy: {accuracy:.3f} | Train: {len(y_train)} | Test: {len(y_test)} | Best params: {grid_search.best_params_}")

    return results

def summarize_loso_results(results):
    """Print a summary of leave-one-site-out results and pool the predictions.

    Reports the mean/std/range of the per-site accuracies, each site's
    accuracy, and a classification report over the predictions pooled across
    all held-out sites. Labels are raw DX_GROUP codes; per the ABIDE
    phenotypic legend 1 = Autism (ASD) and 2 = Control (TD), and the report
    names are bound to those codes explicitly.

    Args:
        results: List of per-site dicts with at least the keys 'test_site',
            'accuracy', 'y_true' and 'y_pred'. May be ``None`` or empty.

    Returns:
        Tuple ``(all_true, all_pred, mean_acc)``; ``(None, None, None)`` when
        there are no results to summarise.
    """
    if not results:
        # Nothing to pool (e.g. LOSO-CV was skipped for lack of sites).
        print("\nNo LOSO-CV results to summarize")
        return None, None, None

    accuracies = [r['accuracy'] for r in results]
    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)

    print(f"\nLOSO-CV Summary:")
    print("-" * 30)
    print(f"Mean accuracy: {mean_acc:.3f} ± {std_acc:.3f}")
    print(f"Range: {np.min(accuracies):.3f} - {np.max(accuracies):.3f}")

    # Per-site results
    print(f"\nPer-site results:")
    for r in results:
        print(f"{r['test_site']}: {r['accuracy']:.3f}")

    # Pool the held-out predictions of every fold.
    all_true = np.concatenate([r['y_true'] for r in results])
    all_pred = np.concatenate([r['y_pred'] for r in results])
    overall_accuracy = accuracy_score(all_true, all_pred)

    print(f"\nOverall LOSO accuracy: {overall_accuracy:.3f}")
    print("\nOverall LOSO classification report:")
    print(classification_report(all_true, all_pred,
                                labels=[1, 2], target_names=['ASD', 'TD']))

    # Interpretation (thresholds on the mean per-site accuracy)
    if mean_acc >= 0.75:
        print(f"\n✓ Good cross-site generalization (accuracy: {mean_acc:.1%})")
    elif mean_acc >= 0.65:
        print(f"\n~ Moderate cross-site generalization (accuracy: {mean_acc:.1%})")
        print("  Consider site harmonization techniques")
    else:
        print(f"\n✗ Poor cross-site generalization (accuracy: {mean_acc:.1%})")
        print("  Strong site effects detected")

    return all_true, all_pred, mean_acc

def _plot_confusion_heatmap(ax, y_true, y_pred, title):
    """Draw an annotated confusion-matrix heatmap for DX_GROUP labels on *ax*.

    Rows/columns are fixed to the codes [1, 2], named 'ASD' and 'TD' per the
    ABIDE phenotypic legend (1 = Autism, 2 = Control).
    """
    cm = confusion_matrix(y_true, y_pred, labels=[1, 2])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['ASD', 'TD'], yticklabels=['ASD', 'TD'], ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')


def create_visualizations(basic_scores=None, y_test=None, y_pred=None,
                         loso_results=None, all_true_loso=None, all_pred_loso=None):
    """Plot whichever evaluation results were supplied in one figure.

    One panel is drawn per available result, at most three per row:
    basic CV scores, the train-test confusion matrix, LOSO accuracy per site
    plus LOSO sample sizes, and the pooled LOSO confusion matrix.

    Args:
        basic_scores: Per-fold accuracies from the basic cross-validation.
        y_test: True labels of the train-test split's test set.
        y_pred: Predicted labels for ``y_test``.
        loso_results: List of per-site LOSO result dicts (keys used:
            'test_site', 'accuracy', 'n_train', 'n_test').
        all_true_loso: True labels pooled over all LOSO folds.
        all_pred_loso: Predictions pooled over all LOSO folds.

    Returns:
        None. Prints a message instead of plotting when nothing was supplied.
    """
    show_basic = basic_scores is not None
    show_split = y_test is not None and y_pred is not None
    show_loso = bool(loso_results)  # None and an empty list both mean "skip"
    show_loso_cm = all_true_loso is not None and all_pred_loso is not None

    # LOSO contributes two panels: accuracy per site + sample sizes.
    n_plots = show_basic + show_split + 2 * show_loso + show_loso_cm

    if n_plots == 0:
        print("No data available for visualization")
        return

    # Create subplots; squeeze=False always yields a 2-D array, so a single
    # flatten covers the 1-panel, 1-row and multi-row layouts alike.
    cols = min(3, n_plots)
    rows = (n_plots + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(5*cols, 4*rows), squeeze=False)
    axes = axes.ravel()

    plot_idx = 0

    # 1. Basic CV scores
    if show_basic:
        ax = axes[plot_idx]
        ax.bar(range(len(basic_scores)), basic_scores, alpha=0.7)
        ax.set_title('Basic Cross-Validation Scores')
        ax.set_xlabel('Fold')
        ax.set_ylabel('Accuracy')
        ax.set_ylim([0, 1])
        ax.axhline(y=np.mean(basic_scores), color='red', linestyle='--',
                  label=f'Mean: {np.mean(basic_scores):.3f}')
        ax.legend()
        plot_idx += 1

    # 2. Train-test confusion matrix
    if show_split:
        _plot_confusion_heatmap(axes[plot_idx], y_test, y_pred,
                                'Train-Test Confusion Matrix')
        plot_idx += 1

    if show_loso:
        site_labels = [r['test_site'] for r in loso_results]
        accuracies = [r['accuracy'] for r in loso_results]

        # 3. LOSO accuracy per site
        ax = axes[plot_idx]
        ax.bar(site_labels, accuracies, alpha=0.7)
        ax.set_title('LOSO-CV: Accuracy per Test Site')
        ax.set_ylabel('Accuracy')
        ax.set_ylim([0, 1])
        ax.axhline(y=np.mean(accuracies), color='red', linestyle='--',
                  label=f'Mean: {np.mean(accuracies):.3f}')
        ax.legend()

        # Add value labels just above each bar
        for i, acc in enumerate(accuracies):
            ax.text(i, acc + 0.02, f'{acc:.3f}', ha='center', va='bottom')
        plot_idx += 1

        # 4. LOSO sample sizes (train vs. test bars side by side)
        ax = axes[plot_idx]
        n_train = [r['n_train'] for r in loso_results]
        n_test = [r['n_test'] for r in loso_results]

        x = np.arange(len(site_labels))
        width = 0.35

        ax.bar(x - width/2, n_train, width, label='Train', alpha=0.7)
        ax.bar(x + width/2, n_test, width, label='Test', alpha=0.7)
        ax.set_title('LOSO-CV: Sample Sizes')
        ax.set_ylabel('Number of Subjects')
        ax.set_xticks(x)
        ax.set_xticklabels(site_labels)
        ax.legend()
        plot_idx += 1

    # 5. LOSO overall confusion matrix
    if show_loso_cm:
        _plot_confusion_heatmap(axes[plot_idx], all_true_loso, all_pred_loso,
                                'LOSO-CV: Overall Confusion Matrix')
        plot_idx += 1

    # Hide unused subplots (grid cells beyond the last drawn panel)
    for unused_ax in axes[plot_idx:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    plt.show()

def main():
    """Run the classification pipeline according to the flags set below.

    Loads every non-excluded site, then runs each analysis whose ``RUN_*``
    flag is True and, if ``CREATE_PLOTS`` is True, plots whatever results
    were produced. With the default flags only the train-test split
    evaluation runs.
    """
    print("ASD Classification Pipeline")
    print("=" * 50)

    # Configuration
    EXCLUDE_SITES = []  # Add site names here to exclude them, e.g., ['NYU2']
    RUN_BASIC_CV = False
    RUN_TRAIN_TEST = True
    RUN_LOSO_CV = False
    CREATE_PLOTS = False

    # Load data
    site_data, X_combined, y_combined = load_site_data(sites, exclude_sites=EXCLUDE_SITES)

    if X_combined is None or len(site_data) == 0:
        print("No data loaded successfully. Check file paths.")
        return

    # Results stay None for analyses that are switched off, which is how the
    # plotting step knows which panels to draw.
    basic_scores = None
    y_test, y_pred = None, None
    loso_results = None
    all_true_loso, all_pred_loso = None, None

    # 1. Basic stratified cross-validation
    if RUN_BASIC_CV:
        basic_scores = basic_cross_validation(X_combined, y_combined)

    # 2. Train-Test Split Evaluation
    if RUN_TRAIN_TEST:
        y_test, y_pred, final_accuracy = train_test_evaluation(X_combined, y_combined)

    # 3. Leave-one-site-out cross-validation (returns None with < 2 sites)
    if RUN_LOSO_CV:
        loso_results = leave_one_site_out_cv(site_data)
        if loso_results:
            all_true_loso, all_pred_loso, _ = summarize_loso_results(loso_results)

    # 4. Plots of whichever results exist
    if CREATE_PLOTS:
        create_visualizations(
            basic_scores=basic_scores,
            y_test=y_test,
            y_pred=y_pred,
            loso_results=loso_results,
            all_true_loso=all_true_loso,
            all_pred_loso=all_pred_loso,
        )

    print("\nAnalysis complete!")
    if EXCLUDE_SITES:
        print(f"Excluded sites: {EXCLUDE_SITES}")
    print(f"Included sites: {list(site_data.keys())}")


if __name__ == "__main__":
    main()