# Granger Causality Analysis - Infant EEG Dataset

This notebook implements Granger causality analysis for the infant resting state EEG dataset (103 subjects, 1-4 sessions each).

## Phase 1: Setup & Data Discovery

## Step 1: Configuration Parameters

All tunable parameters are centralized here for easy adjustment.

In [38]:
import os
from pathlib import Path

# ============================================================================
# CONFIGURATION PARAMETERS
# ============================================================================

# ----------------------------------------------------------------------------
# 1. PATH PARAMETERS
# ----------------------------------------------------------------------------
DATASET_BASE_PATH = Path('/home/alookaladdoo/DPCN-Project/Dataset')
OUTPUT_BASE_PATH = Path('/home/alookaladdoo/DPCN-Project/results')
DERIVATIVES_PATH = DATASET_BASE_PATH / 'derivatives' / 'NeuronicEEG'

# Create output directory structure
OUTPUT_DIRS = {
    'individual': OUTPUT_BASE_PATH / 'individual',
    'group': OUTPUT_BASE_PATH / 'group',
    'qc': OUTPUT_BASE_PATH / 'quality_control',
    'logs': OUTPUT_BASE_PATH / 'logs',
    'plots': OUTPUT_BASE_PATH / 'plots'
}

# Create directories if they don't exist
for dir_path in OUTPUT_DIRS.values():
    dir_path.mkdir(parents=True, exist_ok=True)

print("Output directories created:")
for name, path in OUTPUT_DIRS.items():
    print(f"  {name}: {path}")

# ----------------------------------------------------------------------------
# 2. PREPROCESSING PARAMETERS
# ----------------------------------------------------------------------------
PREPROCESS_PARAMS = {
    # Filter settings
    'highpass_freq': 0.5,      # Hz - High-pass filter cutoff
    'lowpass_freq': 30.0,      # Hz - Low-pass filter cutoff
    'notch_freq': 60.0,        # Hz - Notch filter for power line noise
    'filter_method': 'fir',    # Filter method: 'fir' or 'iir'
    
    # Reference settings
    'reference': 'average',    # 'average', 'common', or list of channel names
    're_reference': True,      # Whether to re-reference data
    
    # Data quality
    'bad_channel_threshold': 0.3,  # Fraction of bad data to mark channel as bad
    'interpolate_bad': True,       # Whether to interpolate bad channels
}

print("\nPreprocessing parameters set:")
for key, value in PREPROCESS_PARAMS.items():
    print(f"  {key}: {value}")

# ----------------------------------------------------------------------------
# 3. SEGMENTATION PARAMETERS
# ----------------------------------------------------------------------------
SEGMENT_PARAMS = {
    'window_length': 10.0,     # seconds - Length of each analysis window
    'window_overlap': 0.5,     # 0-1 - Overlap between windows (50%)
    'min_segment_duration': 5.0,  # seconds - Minimum acceptable segment length
    'use_eyes_closed_only': True,  # Only analyze eyes-closed segments
    'reject_artifacts': True,      # Use derivative annotations to reject artifacts
}

print("\nSegmentation parameters set:")
for key, value in SEGMENT_PARAMS.items():
    print(f"  {key}: {value}")

# ----------------------------------------------------------------------------
# 4. GRANGER CAUSALITY PARAMETERS
# ----------------------------------------------------------------------------
GC_PARAMS = {
    # Model order selection
    'model_order_method': 'aic',   # 'aic', 'bic', 'hqc', or 'fixed'
    'max_order': 50,               # Maximum lag order to test
    'min_order': 1,                # Minimum lag order to test
    'fixed_order': None,           # Use this if model_order_method='fixed'
    
    # GC computation
    'gc_method': 'pairwise',       # 'pairwise' or 'conditional' (multivariate)
    'compute_spectral_gc': True,   # Whether to compute frequency-domain GC
    
    # Frequency bands for spectral analysis
    'freq_bands': {
        'delta': (0.5, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
    },
}

print("\nGranger Causality parameters set:")
for key, value in GC_PARAMS.items():
    if key != 'freq_bands':
        print(f"  {key}: {value}")
print("  Frequency bands:")
for band, (low, high) in GC_PARAMS['freq_bands'].items():
    print(f"    {band}: {low}-{high} Hz")

# ----------------------------------------------------------------------------
# 5. STATISTICAL PARAMETERS
# ----------------------------------------------------------------------------
STAT_PARAMS = {
    'significance_threshold': 0.05,    # p-value threshold
    'correction_method': 'fdr_bh',     # 'fdr_bh', 'bonferroni', 'permutation'
    'n_permutations': 100,            # Number of permutations for testing
    'confidence_level': 0.95,          # Confidence interval level
    'random_seed': 42,                 # For reproducibility
}

print("\nStatistical parameters set:")
for key, value in STAT_PARAMS.items():
    print(f"  {key}: {value}")

# ----------------------------------------------------------------------------
# 6. ANALYSIS PARAMETERS
# ----------------------------------------------------------------------------
ANALYSIS_PARAMS = {
    'analyze_all_pairs': True,         # Analyze all channel pairs
    'channel_pairs': None,             # Specific pairs to analyze (if not all)
    'compute_bidirectional': True,     # Compute both X->Y and Y->X
    'age_bins': [0, 0.25, 0.5, 0.75, 1.0],  # Age bins in years for stratification
}

print("\nAnalysis parameters set:")
for key, value in ANALYSIS_PARAMS.items():
    print(f"  {key}: {value}")

# ----------------------------------------------------------------------------
# 7. VISUALIZATION PARAMETERS
# ----------------------------------------------------------------------------
VIZ_PARAMS = {
    'colormap': 'viridis',             # Colormap for heatmaps
    'figure_dpi': 300,                 # DPI for saved figures
    'figure_format': 'png',            # 'png', 'pdf', 'svg'
    'figure_size': (10, 8),            # Default figure size (width, height)
    'save_individual_plots': True,     # Save plots for each subject
    'save_group_plots': True,          # Save group-level plots
}

print("\nVisualization parameters set:")
for key, value in VIZ_PARAMS.items():
    print(f"  {key}: {value}")

# ----------------------------------------------------------------------------
# 8. COMPUTATIONAL PARAMETERS
# ----------------------------------------------------------------------------
COMPUTE_PARAMS = {
    'n_jobs': -1,                      # Number of parallel jobs (-1 = all cores)
    'verbose': 1,                      # Verbosity level (0, 1, 2)
    'memory_limit_gb': None,           # Memory limit per process (None = no limit)
}

print("\nComputational parameters set:")
for key, value in COMPUTE_PARAMS.items():
    print(f"  {key}: {value}")

# ----------------------------------------------------------------------------
# SUMMARY
# ----------------------------------------------------------------------------
print("\n" + "="*80)
print("CONFIGURATION COMPLETE")
print("="*80)
print(f"Dataset path: {DATASET_BASE_PATH}")
print(f"Output path: {OUTPUT_BASE_PATH}")
print(f"Random seed: {STAT_PARAMS['random_seed']}")
print("="*80)

Output directories created:
  individual: /home/alookaladdoo/DPCN-Project/results/individual
  group: /home/alookaladdoo/DPCN-Project/results/group
  qc: /home/alookaladdoo/DPCN-Project/results/quality_control
  logs: /home/alookaladdoo/DPCN-Project/results/logs
  plots: /home/alookaladdoo/DPCN-Project/results/plots

Preprocessing parameters set:
  highpass_freq: 0.5
  lowpass_freq: 30.0
  notch_freq: 60.0
  filter_method: fir
  reference: average
  re_reference: True
  bad_channel_threshold: 0.3
  interpolate_bad: True

Segmentation parameters set:
  window_length: 10.0
  window_overlap: 0.5
  min_segment_duration: 5.0
  use_eyes_closed_only: True
  reject_artifacts: True

Granger Causality parameters set:
  model_order_method: aic
  max_order: 50
  min_order: 1
  fixed_order: None
  gc_method: pairwise
  compute_spectral_gc: True
  Frequency bands:
    delta: 0.5-4 Hz
    theta: 4-8 Hz
    alpha: 8-13 Hz
    beta: 13-30 Hz

Statistical parameters set:
  significance_threshold: 0.05
 

## Step 2: Environment Setup

Import all necessary libraries and verify their installation.

In [39]:
# ============================================================================
# LIBRARY IMPORTS
# ============================================================================

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

print("Importing libraries...")
print("="*80)

# ----------------------------------------------------------------------------
# Core Scientific Computing
# ----------------------------------------------------------------------------
import numpy as np
import pandas as pd
from scipy import signal, stats
from scipy.stats import pearsonr, spearmanr
import json
import glob
from datetime import datetime
import logging

print("✓ Core libraries: numpy, pandas, scipy")

# ----------------------------------------------------------------------------
# EEG Processing
# ----------------------------------------------------------------------------
try:
    import mne
    print(f"✓ MNE-Python version: {mne.__version__}")
except ImportError:
    print("✗ MNE-Python not found - will need to install")
    mne = None

try:
    import pyedflib
    print(f"✓ pyedflib available")
except ImportError:
    print("⚠ pyedflib not found (optional - MNE can read EDF files)")
    pyedflib = None

# ----------------------------------------------------------------------------
# Statistical Modeling
# ----------------------------------------------------------------------------
try:
    import statsmodels
    from statsmodels.tsa.api import VAR
    from statsmodels.tsa.stattools import grangercausalitytests, adfuller
    from statsmodels.stats.multitest import multipletests
    print(f"✓ statsmodels version: {statsmodels.__version__}")
except ImportError:
    print("✗ statsmodels not found - will need to install")
    statsmodels = None

# ----------------------------------------------------------------------------
# Visualization
# ----------------------------------------------------------------------------
try:
    import matplotlib
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # Set matplotlib backend and style
    matplotlib.use('Agg')  # Non-interactive backend for saving figures
    plt.style.use('seaborn-v0_8-darkgrid')
    sns.set_palette("husl")
    
    print(f"✓ matplotlib version: {matplotlib.__version__}")
    print(f"✓ seaborn version: {sns.__version__}")
except ImportError as e:
    print(f"✗ Visualization libraries error: {e}")
    matplotlib = None
    plt = None
    sns = None

# ----------------------------------------------------------------------------
# Network Analysis
# ----------------------------------------------------------------------------
try:
    import networkx as nx
    print(f"✓ networkx version: {nx.__version__}")
except ImportError:
    print("✗ networkx not found - will need to install for network analysis")
    nx = None

# ----------------------------------------------------------------------------
# Parallel Processing
# ----------------------------------------------------------------------------
try:
    from joblib import Parallel, delayed
    import multiprocessing
    n_cores = multiprocessing.cpu_count()
    print(f"✓ joblib available - {n_cores} CPU cores detected")
except ImportError:
    print("⚠ joblib not found - parallel processing will be limited")
    Parallel = None
    delayed = None

# ----------------------------------------------------------------------------
# Configure Logging
# ----------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(OUTPUT_DIRS['logs'] / 'analysis.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

print("✓ Logging configured")

# ----------------------------------------------------------------------------
# Set Random Seeds for Reproducibility
# ----------------------------------------------------------------------------
np.random.seed(STAT_PARAMS['random_seed'])
if mne is not None:
    mne.set_log_level('WARNING')
    
print(f"✓ Random seed set to: {STAT_PARAMS['random_seed']}")

# ----------------------------------------------------------------------------
# Environment Summary
# ----------------------------------------------------------------------------
print("\n" + "="*80)
print("ENVIRONMENT SETUP COMPLETE")
print("="*80)

# Check critical dependencies
critical_libs = {
    'numpy': np is not None,
    'pandas': pd is not None,
    'scipy': signal is not None,
    'mne': mne is not None,
    'statsmodels': statsmodels is not None,
    'matplotlib': matplotlib is not None,
}

all_critical = all(critical_libs.values())

if all_critical:
    print("✓ All critical libraries loaded successfully")
else:
    print("⚠ Missing critical libraries:")
    for lib, loaded in critical_libs.items():
        if not loaded:
            print(f"  ✗ {lib}")

print("\nOptional libraries:")
optional_libs = {
    'pyedflib': pyedflib is not None,
    'networkx': nx is not None,
    'joblib': Parallel is not None,
}
for lib, loaded in optional_libs.items():
    status = "✓" if loaded else "✗"
    print(f"  {status} {lib}")

print("="*80)

Importing libraries...
✓ Core libraries: numpy, pandas, scipy
✓ MNE-Python version: 1.10.2
✓ pyedflib available
✓ statsmodels version: 0.14.5
✓ matplotlib version: 3.10.6
✓ seaborn version: 0.13.2
✓ networkx version: 3.5
✓ joblib available - 16 CPU cores detected
✓ Logging configured
✓ Random seed set to: 42

ENVIRONMENT SETUP COMPLETE
✓ All critical libraries loaded successfully

Optional libraries:
  ✓ pyedflib
  ✓ networkx
  ✓ joblib


### Install Missing Dependencies (if needed)

Run this cell only if the environment setup above shows missing critical libraries.

In [40]:
# Uncomment and run this cell if you need to install missing packages

# !pip install numpy pandas scipy matplotlib seaborn
# !pip install mne statsmodels networkx joblib
# !pip install pyedflib  # Optional but recommended

print("To install missing packages, uncomment the pip install commands above and run this cell.")
print("After installation, restart the kernel and re-run from Step 1.")

To install missing packages, uncomment the pip install commands above and run this cell.
After installation, restart the kernel and re-run from Step 1.


## Step 3: Data Inventory

Scan the dataset to create a comprehensive inventory of all subjects, sessions, and metadata.

In [41]:
# ============================================================================
# DATA DISCOVERY AND INVENTORY
# ============================================================================

print("Starting data discovery...")
print("="*80)

# ----------------------------------------------------------------------------
# 1. Load participants metadata
# ----------------------------------------------------------------------------
participants_file = DATASET_BASE_PATH / 'participants.tsv'
if participants_file.exists():
    participants_df = pd.read_csv(participants_file, sep='\t')
    print(f"✓ Loaded participants.tsv: {len(participants_df)} subjects")
    print(f"  - Female: {(participants_df['sex'] == 'F').sum()}")
    print(f"  - Male: {(participants_df['sex'] == 'M').sum()}")
else:
    print("✗ participants.tsv not found!")
    participants_df = None

# ----------------------------------------------------------------------------
# 2. Scan dataset directory structure
# ----------------------------------------------------------------------------
print("\nScanning dataset directory structure...")

subject_dirs = sorted([d for d in DATASET_BASE_PATH.glob('sub-*') if d.is_dir()])
print(f"✓ Found {len(subject_dirs)} subject directories")

# Initialize inventory list
inventory_data = []

# Scan each subject
for subject_dir in subject_dirs:
    subject_id = subject_dir.name
    
    # Find all sessions for this subject
    session_dirs = sorted([d for d in subject_dir.glob('ses-*') if d.is_dir()])
    
    for session_dir in session_dirs:
        session_id = session_dir.name
        
        # Initialize record
        record = {
            'subject_id': subject_id,
            'session_id': session_id,
            'sex': None,
            'age_years': None,
            'age_months': None,
            'eeg_file_path': None,
            'duration_sec': None,
            'n_channels': None,
            'sampling_freq': None,
            'has_events': False,
            'has_annotations': False,
            'has_derivatives': False,
            'file_exists': False,
            'session_count_for_subject': len(session_dirs)
        }
        
        # Get sex from participants.tsv
        if participants_df is not None:
            subject_row = participants_df[participants_df['participant_id'] == subject_id]
            if not subject_row.empty:
                record['sex'] = subject_row.iloc[0]['sex']
        
        # Look for scans.tsv to get age
        scans_file = session_dir / f"{subject_id}_{session_id}_scans.tsv"
        if scans_file.exists():
            scans_df = pd.read_csv(scans_file, sep='\t')
            if 'age_acq_time' in scans_df.columns and not scans_df.empty:
                age_years = scans_df.iloc[0]['age_acq_time']
                record['age_years'] = age_years
                record['age_months'] = age_years * 12  # Convert to months
        
        # Look for EEG data file
        eeg_dir = session_dir / 'eeg'
        if eeg_dir.exists():
            eeg_files = list(eeg_dir.glob(f"{subject_id}_{session_id}_task-EEG_eeg.edf"))
            if eeg_files:
                eeg_file = eeg_files[0]
                record['eeg_file_path'] = str(eeg_file)
                record['file_exists'] = True
                
                # Try to read basic info from EEG file
                try:
                    raw = mne.io.read_raw_edf(eeg_file, preload=False, verbose='ERROR')
                    record['duration_sec'] = raw.times[-1]
                    record['n_channels'] = len(raw.ch_names)
                    record['sampling_freq'] = raw.info['sfreq']
                    del raw  # Free memory
                except Exception as e:
                    logger.warning(f"Could not read {eeg_file}: {e}")
            
            # Check for events file
            events_files = list(eeg_dir.glob(f"{subject_id}_{session_id}_task-EEG_events.tsv"))
            record['has_events'] = len(events_files) > 0
        
        # Check for derivative annotations
        deriv_dir = DERIVATIVES_PATH / subject_id / session_id / 'eeg'
        if deriv_dir.exists():
            annot_files = list(deriv_dir.glob(f"{subject_id}_{session_id}_task-EEG_annotations.tsv"))
            record['has_derivatives'] = len(annot_files) > 0
            record['has_annotations'] = record['has_derivatives']
        
        inventory_data.append(record)

# Create inventory DataFrame
inventory_df = pd.DataFrame(inventory_data)

print(f"\n✓ Scanned {len(inventory_df)} sessions across {len(subject_dirs)} subjects")
print("="*80)

Starting data discovery...
✓ Loaded participants.tsv: 103 subjects
  - Female: 41
  - Male: 62

Scanning dataset directory structure...
✓ Found 103 subject directories

✓ Scanned 130 sessions across 103 subjects

✓ Scanned 130 sessions across 103 subjects


In [42]:
# ============================================================================
# GENERATE SUMMARY STATISTICS
# ============================================================================

print("\nDataset Summary Statistics")
print("="*80)

# Basic counts
total_subjects = inventory_df['subject_id'].nunique()
total_sessions = len(inventory_df)
valid_sessions = inventory_df['file_exists'].sum()

print(f"Total Subjects: {total_subjects}")
print(f"Total Sessions: {total_sessions}")
print(f"Valid EEG Files: {valid_sessions}")
print(f"Missing Files: {total_sessions - valid_sessions}")

# Session distribution
print("\nSession Distribution:")
session_counts = inventory_df.groupby('subject_id')['session_id'].count().value_counts().sort_index()
for n_sessions, n_subjects in session_counts.items():
    print(f"  Subjects with {n_sessions} session(s): {n_subjects}")

# Sex distribution
if inventory_df['sex'].notna().any():
    print("\nSex Distribution:")
    sex_counts = inventory_df['sex'].value_counts()
    for sex, count in sex_counts.items():
        print(f"  {sex}: {count} sessions")

# Age distribution
if inventory_df['age_months'].notna().any():
    print("\nAge Distribution:")
    age_data = inventory_df.dropna(subset=['age_months'])
    print(f"  Mean: {age_data['age_months'].mean():.2f} months ({age_data['age_years'].mean():.2f} years)")
    print(f"  Std Dev: {age_data['age_months'].std():.2f} months")
    print(f"  Range: {age_data['age_months'].min():.2f} - {age_data['age_months'].max():.2f} months")
    print(f"  Median: {age_data['age_months'].median():.2f} months")

# Recording duration
if inventory_df['duration_sec'].notna().any():
    print("\nRecording Duration:")
    dur_data = inventory_df.dropna(subset=['duration_sec'])
    print(f"  Mean: {dur_data['duration_sec'].mean():.2f} seconds ({dur_data['duration_sec'].mean()/60:.2f} minutes)")
    print(f"  Range: {dur_data['duration_sec'].min():.2f} - {dur_data['duration_sec'].max():.2f} seconds")

# Channel count
if inventory_df['n_channels'].notna().any():
    print("\nChannel Information:")
    chan_data = inventory_df.dropna(subset=['n_channels'])
    print(f"  Most common channel count: {chan_data['n_channels'].mode().iloc[0]:.0f}")
    print(f"  Channel count range: {chan_data['n_channels'].min():.0f} - {chan_data['n_channels'].max():.0f}")

# Annotations
print("\nAnnotations:")
print(f"  Sessions with events: {inventory_df['has_events'].sum()}")
print(f"  Sessions with derivative annotations: {inventory_df['has_derivatives'].sum()}")

# Data completeness
print("\nData Completeness:")
completeness = {
    'EEG file': (inventory_df['file_exists'].sum() / len(inventory_df) * 100),
    'Age info': (inventory_df['age_months'].notna().sum() / len(inventory_df) * 100),
    'Sex info': (inventory_df['sex'].notna().sum() / len(inventory_df) * 100),
    'Events': (inventory_df['has_events'].sum() / len(inventory_df) * 100),
    'Annotations': (inventory_df['has_annotations'].sum() / len(inventory_df) * 100),
}
for field, pct in completeness.items():
    print(f"  {field}: {pct:.1f}%")

print("="*80)


Dataset Summary Statistics
Total Subjects: 103
Total Sessions: 130
Valid EEG Files: 130
Missing Files: 0

Session Distribution:
  Subjects with 1 session(s): 81
  Subjects with 2 session(s): 18
  Subjects with 3 session(s): 3
  Subjects with 4 session(s): 1

Sex Distribution:
  M: 79 sessions
  F: 51 sessions

Age Distribution:
  Mean: 5.36 months (0.45 years)
  Std Dev: 3.09 months
  Range: 0.26 - 12.56 months
  Median: 4.79 months

Recording Duration:
  Mean: 632.46 seconds (10.54 minutes)
  Range: 32.00 - 1783.99 seconds

Channel Information:
  Most common channel count: 19
  Channel count range: 19 - 24

Annotations:
  Sessions with events: 130
  Sessions with derivative annotations: 130

Data Completeness:
  EEG file: 100.0%
  Age info: 100.0%
  Sex info: 100.0%
  Events: 100.0%
  Annotations: 100.0%


In [43]:
# ============================================================================
# SAVE INVENTORY TO FILE
# ============================================================================

# Save master inventory CSV
inventory_csv_path = OUTPUT_BASE_PATH / 'subject_session_inventory.csv'
inventory_df.to_csv(inventory_csv_path, index=False)
print(f"\n✓ Saved inventory to: {inventory_csv_path}")

# Save summary statistics to text file
summary_txt_path = OUTPUT_BASE_PATH / 'data_summary_statistics.txt'
with open(summary_txt_path, 'w') as f:
    f.write("="*80 + "\n")
    f.write("INFANT EEG DATASET - SUMMARY STATISTICS\n")
    f.write("="*80 + "\n\n")
    
    f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
    
    f.write(f"Total Subjects: {total_subjects}\n")
    f.write(f"Total Sessions: {total_sessions}\n")
    f.write(f"Valid EEG Files: {valid_sessions}\n\n")
    
    f.write("Session Distribution:\n")
    for n_sessions, n_subjects in session_counts.items():
        f.write(f"  Subjects with {n_sessions} session(s): {n_subjects}\n")
    f.write("\n")
    
    if inventory_df['sex'].notna().any():
        f.write("Sex Distribution:\n")
        sex_counts = inventory_df['sex'].value_counts()
        for sex, count in sex_counts.items():
            f.write(f"  {sex}: {count} sessions\n")
        f.write("\n")
    
    if inventory_df['age_months'].notna().any():
        age_data = inventory_df.dropna(subset=['age_months'])
        f.write("Age Distribution:\n")
        f.write(f"  Mean: {age_data['age_months'].mean():.2f} months\n")
        f.write(f"  Std Dev: {age_data['age_months'].std():.2f} months\n")
        f.write(f"  Range: {age_data['age_months'].min():.2f} - {age_data['age_months'].max():.2f} months\n")
        f.write(f"  Median: {age_data['age_months'].median():.2f} months\n\n")
    
    if inventory_df['duration_sec'].notna().any():
        dur_data = inventory_df.dropna(subset=['duration_sec'])
        f.write("Recording Duration:\n")
        f.write(f"  Mean: {dur_data['duration_sec'].mean():.2f} seconds\n")
        f.write(f"  Range: {dur_data['duration_sec'].min():.2f} - {dur_data['duration_sec'].max():.2f} seconds\n\n")
    
    f.write("Data Completeness:\n")
    for field, pct in completeness.items():
        f.write(f"  {field}: {pct:.1f}%\n")
    
    f.write("\n" + "="*80 + "\n")

print(f"✓ Saved summary statistics to: {summary_txt_path}")

# Display first few rows of inventory
print("\nFirst 10 rows of inventory:")
print(inventory_df.head(10).to_string())

print("\n" + "="*80)
print("DATA INVENTORY COMPLETE")
print("="*80)


✓ Saved inventory to: /home/alookaladdoo/DPCN-Project/results/subject_session_inventory.csv
✓ Saved summary statistics to: /home/alookaladdoo/DPCN-Project/results/data_summary_statistics.txt

First 10 rows of inventory:
      subject_id session_id sex  age_years  age_months                                                                                         eeg_file_path  duration_sec  n_channels  sampling_freq  has_events  has_annotations  has_derivatives  file_exists  session_count_for_subject
0  sub-NORB00001      ses-1   M     0.4071      4.8852  /home/alookaladdoo/DPCN-Project/Dataset/sub-NORB00001/ses-1/eeg/sub-NORB00001_ses-1_task-EEG_eeg.edf       713.995          21          200.0        True             True             True         True                          1
1  sub-NORB00002      ses-1   M     0.3907      4.6884  /home/alookaladdoo/DPCN-Project/Dataset/sub-NORB00002/ses-1/eeg/sub-NORB00002_ses-1_task-EEG_eeg.edf       527.995          19          200.0        True  

### Visualize Data Distribution

Quick visualization of the dataset characteristics.

In [44]:
# ============================================================================
# VISUALIZE DATA DISTRIBUTION
# ============================================================================

# Create figure with subplots
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Dataset Overview - Infant EEG', fontsize=16, fontweight='bold')

# 1. Session count distribution
ax = axes[0, 0]
session_counts.plot(kind='bar', ax=ax, color='skyblue', edgecolor='black')
ax.set_xlabel('Number of Sessions')
ax.set_ylabel('Number of Subjects')
ax.set_title('Sessions per Subject')
ax.grid(axis='y', alpha=0.3)

# 2. Age distribution
ax = axes[0, 1]
if inventory_df['age_months'].notna().any():
    age_data = inventory_df.dropna(subset=['age_months'])
    ax.hist(age_data['age_months'], bins=20, color='lightcoral', edgecolor='black', alpha=0.7)
    ax.axvline(age_data['age_months'].mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {age_data["age_months"].mean():.1f}')
    ax.set_xlabel('Age (months)')
    ax.set_ylabel('Number of Sessions')
    ax.set_title('Age Distribution')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
else:
    ax.text(0.5, 0.5, 'No age data', ha='center', va='center', transform=ax.transAxes)

# 3. Recording duration distribution
ax = axes[0, 2]
if inventory_df['duration_sec'].notna().any():
    dur_data = inventory_df.dropna(subset=['duration_sec'])
    ax.hist(dur_data['duration_sec'] / 60, bins=20, color='lightgreen', edgecolor='black', alpha=0.7)
    ax.set_xlabel('Duration (minutes)')
    ax.set_ylabel('Number of Sessions')
    ax.set_title('Recording Duration')
    ax.grid(axis='y', alpha=0.3)
else:
    ax.text(0.5, 0.5, 'No duration data', ha='center', va='center', transform=ax.transAxes)

# 4. Sex distribution
ax = axes[1, 0]
if inventory_df['sex'].notna().any():
    sex_counts = inventory_df['sex'].value_counts()
    colors = ['lightblue' if x == 'M' else 'pink' for x in sex_counts.index]
    sex_counts.plot(kind='bar', ax=ax, color=colors, edgecolor='black')
    ax.set_xlabel('Sex')
    ax.set_ylabel('Number of Sessions')
    ax.set_title('Sex Distribution')
    ax.set_xticklabels(sex_counts.index, rotation=0)
    ax.grid(axis='y', alpha=0.3)
else:
    ax.text(0.5, 0.5, 'No sex data', ha='center', va='center', transform=ax.transAxes)

# 5. Data completeness
ax = axes[1, 1]
completeness_df = pd.Series(completeness)
colors_comp = ['green' if x > 90 else 'orange' if x > 70 else 'red' for x in completeness_df.values]
completeness_df.plot(kind='barh', ax=ax, color=colors_comp, edgecolor='black')
ax.set_xlabel('Completeness (%)')
ax.set_title('Data Completeness')
ax.set_xlim(0, 100)
ax.grid(axis='x', alpha=0.3)

# 6. Channel count distribution
ax = axes[1, 2]
if inventory_df['n_channels'].notna().any():
    chan_data = inventory_df.dropna(subset=['n_channels'])
    chan_counts = chan_data['n_channels'].value_counts().sort_index()
    chan_counts.plot(kind='bar', ax=ax, color='mediumpurple', edgecolor='black')
    ax.set_xlabel('Number of Channels')
    ax.set_ylabel('Number of Sessions')
    ax.set_title('Channel Count Distribution')
    ax.grid(axis='y', alpha=0.3)
else:
    ax.text(0.5, 0.5, 'No channel data', ha='center', va='center', transform=ax.transAxes)

plt.tight_layout()

# Save figure
overview_plot_path = OUTPUT_DIRS['plots'] / 'dataset_overview.png'
plt.savefig(overview_plot_path, dpi=VIZ_PARAMS['figure_dpi'], bbox_inches='tight')
print(f"\n✓ Saved overview plot to: {overview_plot_path}")

plt.show()

print("\n" + "="*80)
print("VISUALIZATION COMPLETE")
print("="*80)


✓ Saved overview plot to: /home/alookaladdoo/DPCN-Project/results/plots/dataset_overview.png

VISUALIZATION COMPLETE


  plt.show()


---

## Phase 2: Core Pipeline (Test on Sample Subjects)

This phase implements the core processing functions. Test on 5-10 subjects before running on the full dataset.

## Step 4: Load & Validate

Functions to load EEG data and check data quality.

In [45]:
# ============================================================================
# LOAD EEG DATA FUNCTION
# ============================================================================

def load_eeg_data(subject_id, session_id, preload=True):
    """
    Load EEG data for a specific subject and session.
    
    Parameters:
    -----------
    subject_id : str
        Subject identifier (e.g., 'sub-NORB00001')
    session_id : str
        Session identifier (e.g., 'ses-1')
    preload : bool
        Whether to preload data into memory
        
    Returns:
    --------
    raw : mne.io.Raw or None
        Raw EEG data object
    metadata : dict
        Dictionary containing metadata about the recording
    """
    
    metadata = {
        'subject_id': subject_id,
        'session_id': session_id,
        'loaded': False,
        'error': None
    }
    
    try:
        # Construct file path
        eeg_file = DATASET_BASE_PATH / subject_id / session_id / 'eeg' / f"{subject_id}_{session_id}_task-EEG_eeg.edf"
        
        if not eeg_file.exists():
            metadata['error'] = f"File not found: {eeg_file}"
            logger.warning(metadata['error'])
            return None, metadata
        
        # Load EEG data
        raw = mne.io.read_raw_edf(eeg_file, preload=preload, verbose='ERROR')
        
        # Extract metadata
        metadata['file_path'] = str(eeg_file)
        metadata['duration_sec'] = raw.times[-1]
        metadata['n_channels'] = len(raw.ch_names)
        metadata['channel_names'] = raw.ch_names
        metadata['sampling_freq'] = raw.info['sfreq']
        metadata['loaded'] = True
        
        logger.info(f"Loaded {subject_id}/{session_id}: {metadata['n_channels']} channels, "
                   f"{metadata['duration_sec']:.1f}s @ {metadata['sampling_freq']}Hz")
        
        return raw, metadata
        
    except Exception as e:
        metadata['error'] = str(e)
        logger.error(f"Error loading {subject_id}/{session_id}: {e}")
        return None, metadata


print("✓ load_eeg_data() function defined")

✓ load_eeg_data() function defined


In [46]:
# ============================================================================
# VALIDATE EEG DATA FUNCTION
# ============================================================================

def validate_eeg_data(raw):
    """
    Validate EEG data quality and identify potential issues.
    
    Parameters:
    -----------
    raw : mne.io.Raw
        Raw EEG data object
        
    Returns:
    --------
    validation_report : dict
        Dictionary containing validation results and identified issues
    """
    
    validation_report = {
        'is_valid': True,
        'issues': [],
        'warnings': [],
        'bad_channels': [],
        'sampling_rate': raw.info['sfreq'],
        'n_channels': len(raw.ch_names),
        'duration_sec': raw.times[-1]
    }
    
    # Check 1: Verify sampling rate consistency
    expected_sfreq = 200.0  # Hz
    if abs(raw.info['sfreq'] - expected_sfreq) > 0.1:
        validation_report['warnings'].append(
            f"Sampling rate {raw.info['sfreq']} Hz differs from expected {expected_sfreq} Hz"
        )
    
    # Check 2: Verify minimum number of channels
    min_channels = 10
    if len(raw.ch_names) < min_channels:
        validation_report['issues'].append(
            f"Only {len(raw.ch_names)} channels (minimum {min_channels} expected)"
        )
        validation_report['is_valid'] = False
    
    # Check 3: Check for flat channels
    data = raw.get_data()
    flat_threshold = 1e-10  # Very small value
    
    for ch_idx, ch_name in enumerate(raw.ch_names):
        ch_data = data[ch_idx, :]
        
        # Check if channel is flat (no variation)
        if np.std(ch_data) < flat_threshold:
            validation_report['bad_channels'].append(ch_name)
            validation_report['warnings'].append(f"Channel {ch_name} appears flat (std < {flat_threshold})")
        
        # Check for extreme values
        if np.any(np.abs(ch_data) > 1e4):  # Very large amplitude (> 10,000 µV)
            validation_report['warnings'].append(f"Channel {ch_name} has extreme amplitude values")
        
        # Check for NaN or Inf
        if np.any(np.isnan(ch_data)) or np.any(np.isinf(ch_data)):
            validation_report['bad_channels'].append(ch_name)
            validation_report['issues'].append(f"Channel {ch_name} contains NaN or Inf values")
            validation_report['is_valid'] = False
    
    # Check 4: Verify minimum duration
    min_duration = 30.0  # seconds
    if raw.times[-1] < min_duration:
        validation_report['issues'].append(
            f"Recording duration {raw.times[-1]:.1f}s is less than minimum {min_duration}s"
        )
        validation_report['is_valid'] = False
    
    # Summary
    validation_report['n_bad_channels'] = len(validation_report['bad_channels'])
    validation_report['bad_channel_fraction'] = len(validation_report['bad_channels']) / len(raw.ch_names)
    
    if validation_report['bad_channel_fraction'] > PREPROCESS_PARAMS['bad_channel_threshold']:
        validation_report['issues'].append(
            f"Too many bad channels: {validation_report['n_bad_channels']}/{len(raw.ch_names)} "
            f"({validation_report['bad_channel_fraction']*100:.1f}%)"
        )
        validation_report['is_valid'] = False
    
    return validation_report


print("✓ validate_eeg_data() function defined")

✓ validate_eeg_data() function defined


In [47]:
# ============================================================================
# LOAD EVENTS AND ANNOTATIONS FUNCTION
# ============================================================================

def load_events_and_annotations(subject_id, session_id):
    """
    Load events and annotations for a session.
    
    Parameters:
    -----------
    subject_id : str
        Subject identifier
    session_id : str
        Session identifier
        
    Returns:
    --------
    events_df : pd.DataFrame or None
        Events data (eyes open/closed markers)
    annotations_df : pd.DataFrame or None
        Derivative annotations (clean segments)
    """
    
    events_df = None
    annotations_df = None
    
    # Load events file (eyes open/closed)
    events_file = (DATASET_BASE_PATH / subject_id / session_id / 'eeg' / 
                   f"{subject_id}_{session_id}_task-EEG_events.tsv")
    
    if events_file.exists():
        try:
            events_df = pd.read_csv(events_file, sep='\t')
            logger.info(f"Loaded events for {subject_id}/{session_id}: {len(events_df)} events")
        except Exception as e:
            logger.warning(f"Could not load events file {events_file}: {e}")
    
    # Load derivative annotations (clean segments)
    annot_file = (DERIVATIVES_PATH / subject_id / session_id / 'eeg' / 
                  f"{subject_id}_{session_id}_task-EEG_annotations.tsv")
    
    if annot_file.exists():
        try:
            # Read with explicit delimiter and handle potential parsing issues
            annotations_df = pd.read_csv(annot_file, sep='\t', dtype={'onset': float, 'duration': float, 'label': str})
            logger.info(f"Loaded annotations for {subject_id}/{session_id}: {len(annotations_df)} annotations")
        except Exception as e:
            logger.warning(f"Could not load annotations file {annot_file}: {e}")
    
    return events_df, annotations_df


print("✓ load_events_and_annotations() function defined")

✓ load_events_and_annotations() function defined


### Test Load & Validate Functions

Test the functions on a sample subject to verify they work correctly.

In [48]:
# ============================================================================
# TEST ON SAMPLE SUBJECT
# ============================================================================

print("Testing load and validate functions on a sample subject...")
print("="*80)

# Get first valid subject from inventory
valid_subjects = inventory_df[inventory_df['file_exists']].head(3)

if len(valid_subjects) == 0:
    print("No valid subjects found in inventory!")
else:
    for idx, row in valid_subjects.iterrows():
        test_subject = row['subject_id']
        test_session = row['session_id']
        
        print(f"\n{'='*80}")
        print(f"Testing: {test_subject}/{test_session}")
        print(f"{'='*80}")
        
        # Test load function
        print("\n1. Loading EEG data...")
        raw, metadata = load_eeg_data(test_subject, test_session, preload=True)
        
        if raw is None:
            print(f"   ✗ Failed to load: {metadata['error']}")
            continue
        else:
            print(f"   ✓ Loaded successfully")
            print(f"   - Duration: {metadata['duration_sec']:.2f} seconds")
            print(f"   - Channels: {metadata['n_channels']}")
            print(f"   - Sampling rate: {metadata['sampling_freq']} Hz")
            print(f"   - Channel names: {', '.join(metadata['channel_names'][:5])}...")
        
        # Test validation function
        print("\n2. Validating data quality...")
        validation = validate_eeg_data(raw)
        
        if validation['is_valid']:
            print("   ✓ Data passed validation")
        else:
            print("   ✗ Data failed validation")
        
        if validation['bad_channels']:
            print(f"   - Bad channels: {', '.join(validation['bad_channels'])}")
        else:
            print("   - No bad channels detected")
        
        if validation['warnings']:
            print(f"   - Warnings ({len(validation['warnings'])}):")
            for warning in validation['warnings'][:3]:  # Show first 3
                print(f"     • {warning}")
        
        if validation['issues']:
            print(f"   - Issues ({len(validation['issues'])}):")
            for issue in validation['issues']:
                print(f"     • {issue}")
        
        # Test events and annotations
        print("\n3. Loading events and annotations...")
        events_df, annotations_df = load_events_and_annotations(test_subject, test_session)
        
        if events_df is not None:
            print(f"   ✓ Events loaded: {len(events_df)} events")
            if 'trial_type' in events_df.columns:
                event_types = events_df['trial_type'].value_counts()
                for event_type, count in event_types.items():
                    print(f"     - {event_type}: {count}")
        else:
            print("   ⚠ No events file found")
        
        if annotations_df is not None:
            print(f"   ✓ Annotations loaded: {len(annotations_df)} annotations")
            if 'label' in annotations_df.columns:
                annot_types = annotations_df['label'].value_counts()
                for annot_type, count in annot_types.items():
                    print(f"     - {annot_type}: {count}")
        else:
            print("   ⚠ No annotations file found")
        
        # Visualize raw data snippet
        print("\n4. Creating quick visualization...")
        try:
            fig, axes = plt.subplots(2, 1, figsize=(15, 8))
            
            # Plot 10 seconds of data
            duration_to_plot = min(10, raw.times[-1])
            
            # Raw data
            data_snippet, times_snippet = raw[:5, :int(duration_to_plot * raw.info['sfreq'])]
            for i, ch_name in enumerate(raw.ch_names[:5]):
                axes[0].plot(times_snippet, data_snippet[i, :] * 1e6 + i * 100, label=ch_name, alpha=0.7)
            axes[0].set_xlabel('Time (s)')
            axes[0].set_ylabel('Amplitude (µV) - Offset for visualization')
            axes[0].set_title(f'Raw EEG Data - {test_subject}/{test_session} (first 5 channels)')
            axes[0].legend(loc='upper right')
            axes[0].grid(True, alpha=0.3)
            
            # Power Spectral Density
            psd_data = raw.compute_psd(fmin=0.5, fmax=50, n_fft=int(2*raw.info['sfreq']))
            psds, freqs = psd_data.get_data(return_freqs=True)
            
            axes[1].semilogy(freqs, psds[:5, :].T, alpha=0.7)
            axes[1].set_xlabel('Frequency (Hz)')
            axes[1].set_ylabel('PSD (µV²/Hz)')
            axes[1].set_title('Power Spectral Density (first 5 channels)')
            axes[1].grid(True, alpha=0.3)
            axes[1].set_xlim([0, 50])
            
            plt.tight_layout()
            
            # Save figure
            sample_plot_path = OUTPUT_DIRS['plots'] / f'{test_subject}_{test_session}_sample_raw.png'
            plt.savefig(sample_plot_path, dpi=150, bbox_inches='tight')
            print(f"   ✓ Saved plot to: {sample_plot_path}")
            
            plt.show()
            
        except Exception as e:
            print(f"   ✗ Could not create visualization: {e}")
        
        # Clean up
        del raw
        
        print(f"\n{'='*80}")
        print(f"Completed testing {test_subject}/{test_session}")
        print(f"{'='*80}\n")
        
        # Only test first valid subject in detail, show summary for others
        if idx == valid_subjects.index[0]:
            break

print("\n" + "="*80)
print("LOAD & VALIDATE TESTING COMPLETE")
print("="*80)

2025-10-25 17:01:46,499 - INFO - Loaded sub-NORB00001/ses-1: 21 channels, 714.0s @ 200.0Hz
2025-10-25 17:01:46,509 - INFO - Loaded events for sub-NORB00001/ses-1: 3 events
2025-10-25 17:01:46,509 - INFO - Loaded events for sub-NORB00001/ses-1: 3 events


Testing load and validate functions on a sample subject...

Testing: sub-NORB00001/ses-1

1. Loading EEG data...
   ✓ Loaded successfully
   - Duration: 714.00 seconds
   - Channels: 21
   - Sampling rate: 200.0 Hz
   - Channel names: Fp1, Fp2, F3, F4, C3...

2. Validating data quality...
   ✓ Data passed validation
   - No bad channels detected

3. Loading events and annotations...
   ✓ Events loaded: 3 events
     - discontinuity: 1
     - eyes_closed: 1
     - eyes_open: 1
   ⚠ No annotations file found

4. Creating quick visualization...
   ✓ Saved plot to: /home/alookaladdoo/DPCN-Project/results/plots/sub-NORB00001_ses-1_sample_raw.png

Completed testing sub-NORB00001/ses-1


LOAD & VALIDATE TESTING COMPLETE
   ✓ Saved plot to: /home/alookaladdoo/DPCN-Project/results/plots/sub-NORB00001_ses-1_sample_raw.png

Completed testing sub-NORB00001/ses-1


LOAD & VALIDATE TESTING COMPLETE


  plt.show()


## Step 5: Preprocessing

Functions to filter, re-reference, segment, and check stationarity of EEG data.

In [49]:
# ============================================================================
# PREPROCESS EEG FUNCTION
# ============================================================================

def preprocess_eeg(raw, params=None):
    """
    Apply preprocessing steps to EEG data.
    
    Parameters:
    -----------
    raw : mne.io.Raw
        Raw EEG data object
    params : dict, optional
        Preprocessing parameters (uses PREPROCESS_PARAMS if None)
        
    Returns:
    --------
    raw_preprocessed : mne.io.Raw
        Preprocessed EEG data
    preprocess_info : dict
        Information about preprocessing steps applied
    """
    
    if params is None:
        params = PREPROCESS_PARAMS
    
    preprocess_info = {
        'success': False,
        'steps_applied': [],
        'bad_channels_original': [],
        'bad_channels_interpolated': [],
        'reference': None,
        'filters_applied': []
    }
    
    try:
        # Make a copy to avoid modifying original
        raw_prep = raw.copy()
        
        # Step 1: Identify bad channels from validation
        if params['interpolate_bad']:
            # Simple bad channel detection based on extreme values
            data = raw_prep.get_data()
            bad_channels = []
            
            for ch_idx, ch_name in enumerate(raw_prep.ch_names):
                ch_data = data[ch_idx, :]
                # Check for flat or extreme channels
                if np.std(ch_data) < 1e-10 or np.any(np.abs(ch_data) > 1e4):
                    bad_channels.append(ch_name)
            
            preprocess_info['bad_channels_original'] = bad_channels
            
            if bad_channels:
                raw_prep.info['bads'] = bad_channels
                logger.info(f"Marked {len(bad_channels)} bad channels: {', '.join(bad_channels)}")
                preprocess_info['steps_applied'].append(f'marked_{len(bad_channels)}_bad_channels')
        
        # Step 2: Apply bandpass filter
        logger.info(f"Applying bandpass filter: {params['highpass_freq']}-{params['lowpass_freq']} Hz")
        raw_prep.filter(
            l_freq=params['highpass_freq'],
            h_freq=params['lowpass_freq'],
            method=params['filter_method'],
            verbose='ERROR'
        )
        preprocess_info['filters_applied'].append(f"bandpass_{params['highpass_freq']}-{params['lowpass_freq']}_Hz")
        preprocess_info['steps_applied'].append('bandpass_filter')
        
        # Step 3: Apply notch filter
        logger.info(f"Applying notch filter at {params['notch_freq']} Hz")
        raw_prep.notch_filter(
            freqs=params['notch_freq'],
            verbose='ERROR'
        )
        preprocess_info['filters_applied'].append(f"notch_{params['notch_freq']}_Hz")
        preprocess_info['steps_applied'].append('notch_filter')
        
        # Step 4: Interpolate bad channels
        if params['interpolate_bad'] and len(preprocess_info['bad_channels_original']) > 0:
            logger.info(f"Interpolating {len(preprocess_info['bad_channels_original'])} bad channels")
            raw_prep.interpolate_bads(reset_bads=True, verbose='ERROR')
            preprocess_info['bad_channels_interpolated'] = preprocess_info['bad_channels_original'].copy()
            preprocess_info['steps_applied'].append(f'interpolated_{len(preprocess_info["bad_channels_interpolated"])}_channels')
        
        # Step 5: Re-reference
        if params['re_reference']:
            ref_type = params['reference']
            logger.info(f"Re-referencing to: {ref_type}")
            
            if ref_type == 'average':
                raw_prep.set_eeg_reference('average', projection=False, verbose='ERROR')
            elif ref_type == 'common':
                # Already common reference, no change needed
                pass
            else:
                # Specific channel(s) as reference
                raw_prep.set_eeg_reference(ref_type, projection=False, verbose='ERROR')
            
            preprocess_info['reference'] = ref_type
            preprocess_info['steps_applied'].append(f'reference_{ref_type}')
        
        preprocess_info['success'] = True
        logger.info(f"Preprocessing completed successfully: {', '.join(preprocess_info['steps_applied'])}")
        
        return raw_prep, preprocess_info
        
    except Exception as e:
        logger.error(f"Error during preprocessing: {e}")
        preprocess_info['error'] = str(e)
        return None, preprocess_info


print("✓ preprocess_eeg() function defined")

✓ preprocess_eeg() function defined


In [50]:
# ============================================================================
# SEGMENT DATA FUNCTION
# ============================================================================

def segment_data(raw, events_df=None, annotations_df=None, params=None):
    """
    Segment EEG data into analysis windows.
    
    Parameters:
    -----------
    raw : mne.io.Raw
        Preprocessed EEG data
    events_df : pd.DataFrame, optional
        Events data (eyes open/closed markers)
    annotations_df : pd.DataFrame, optional
        Derivative annotations (clean segments)
    params : dict, optional
        Segmentation parameters (uses SEGMENT_PARAMS if None)
        
    Returns:
    --------
    segments : list of np.ndarray
        List of data segments (n_channels x n_timepoints)
    segment_info : list of dict
        Information about each segment (start_time, duration, etc.)
    """
    
    if params is None:
        params = SEGMENT_PARAMS
    
    segments = []
    segment_info = []
    
    try:
        # Get sampling frequency
        sfreq = raw.info['sfreq']
        window_samples = int(params['window_length'] * sfreq)
        overlap_samples = int(params['window_length'] * params['window_overlap'] * sfreq)
        step_samples = window_samples - overlap_samples
        
        # Determine valid time ranges based on annotations
        valid_ranges = []
        
        if params['use_eyes_closed_only'] and events_df is not None:
            # Extract eyes-closed periods
            eyes_closed_events = events_df[events_df['trial_type'] == 'eyes_closed'].copy()
            
            # Sort by onset time to properly find next event
            events_sorted = events_df.sort_values('onset').reset_index(drop=True)
            
            for _, event in eyes_closed_events.iterrows():
                start_time = event['onset']
                
                # Find the end time by looking for the next event (eyes_open or end of recording)
                # Get events that occur after this one
                future_events = events_sorted[events_sorted['onset'] > start_time]
                
                if len(future_events) > 0:
                    # Use the next event as end time
                    end_time = future_events.iloc[0]['onset']
                else:
                    # Use end of recording
                    end_time = raw.times[-1]
                
                # Check if duration meets minimum requirement
                if end_time - start_time >= params['min_segment_duration']:
                    valid_ranges.append((start_time, end_time))
        
        # If reject_artifacts is True and we have derivative annotations
        if params['reject_artifacts'] and annotations_df is not None:
            # Use derivative annotations as valid ranges
            clean_ranges = []
            
            for _, annot in annotations_df.iterrows():
                try:
                    start_time = float(annot['onset'])
                    duration = float(annot['duration'])
                    
                    if duration >= params['min_segment_duration']:
                        end_time = start_time + duration
                        clean_ranges.append((start_time, end_time))
                except (ValueError, TypeError, KeyError) as e:
                    logger.warning(f"Could not parse annotation row: {e}")
                    continue
            
            # Intersect with eyes-closed ranges if applicable
            if valid_ranges:
                intersected = []
                for ec_start, ec_end in valid_ranges:
                    for clean_start, clean_end in clean_ranges:
                        overlap_start = max(ec_start, clean_start)
                        overlap_end = min(ec_end, clean_end)
                        if overlap_end - overlap_start >= params['min_segment_duration']:
                            intersected.append((overlap_start, overlap_end))
                valid_ranges = intersected
            else:
                valid_ranges = clean_ranges
        
        # If no valid ranges defined, use entire recording
        if not valid_ranges:
            valid_ranges = [(0, raw.times[-1])]
            logger.info("No specific valid ranges found, using entire recording")
        
        # Extract segments from valid ranges
        for range_start, range_end in valid_ranges:
            start_sample = int(range_start * sfreq)
            end_sample = int(range_end * sfreq)
            
            # Create sliding windows within this range
            current_sample = start_sample
            while current_sample + window_samples <= end_sample:
                # Extract segment
                segment_data = raw[:, current_sample:current_sample + window_samples][0]
                
                # Store segment
                segments.append(segment_data)
                
                # Store info
                info = {
                    'start_time': current_sample / sfreq,
                    'end_time': (current_sample + window_samples) / sfreq,
                    'duration': params['window_length'],
                    'n_samples': window_samples,
                    'segment_index': len(segments) - 1
                }
                segment_info.append(info)
                
                # Move to next window
                current_sample += step_samples
        
        logger.info(f"Extracted {len(segments)} segments "
                   f"(window: {params['window_length']}s, overlap: {params['window_overlap']*100:.0f}%)")
        
        return segments, segment_info
        
    except Exception as e:
        logger.error(f"Error during segmentation: {e}")
        return [], []


print("✓ segment_data() function defined")

✓ segment_data() function defined


In [51]:
# ============================================================================
# CHECK STATIONARITY FUNCTION
# ============================================================================

def check_stationarity(segment_data, significance_level=0.05):
    """
    Check stationarity of time series data using Augmented Dickey-Fuller test.
    
    Parameters:
    -----------
    segment_data : np.ndarray
        Data segment (n_channels x n_timepoints)
    significance_level : float
        Significance level for ADF test (default: 0.05)
        
    Returns:
    --------
    stationarity_report : dict
        Dictionary with stationarity test results per channel
    """
    
    n_channels = segment_data.shape[0]
    
    stationarity_report = {
        'is_stationary': [],
        'adf_statistics': [],
        'p_values': [],
        'n_stationary': 0,
        'n_non_stationary': 0,
        'fraction_stationary': 0.0
    }
    
    try:
        for ch_idx in range(n_channels):
            ch_data = segment_data[ch_idx, :]
            
            # Perform ADF test
            adf_result = adfuller(ch_data, autolag='AIC')
            adf_statistic = adf_result[0]
            p_value = adf_result[1]
            
            # Stationary if we can reject null hypothesis (null = unit root / non-stationary)
            is_stationary = p_value < significance_level
            
            stationarity_report['is_stationary'].append(is_stationary)
            stationarity_report['adf_statistics'].append(adf_statistic)
            stationarity_report['p_values'].append(p_value)
        
        # Summary statistics
        stationarity_report['n_stationary'] = sum(stationarity_report['is_stationary'])
        stationarity_report['n_non_stationary'] = n_channels - stationarity_report['n_stationary']
        stationarity_report['fraction_stationary'] = stationarity_report['n_stationary'] / n_channels
        
    except Exception as e:
        logger.error(f"Error checking stationarity: {e}")
        stationarity_report['error'] = str(e)
    
    return stationarity_report


print("✓ check_stationarity() function defined")

✓ check_stationarity() function defined


### Test Preprocessing Functions

Test preprocessing, segmentation, and stationarity checking on sample data.

In [56]:
# ============================================================================
# TEST PREPROCESSING ON SAMPLE SUBJECT
# ============================================================================

print("Testing preprocessing functions on sample subject...")
print("="*80)

# Get first valid subject
valid_subject = inventory_df[inventory_df['file_exists']].iloc[0]
test_subject = valid_subject['subject_id']
test_session = valid_subject['session_id']

print(f"\nProcessing: {test_subject}/{test_session}")
print("="*80)

# Step 1: Load data
print("\n1. Loading EEG data...")
raw, metadata = load_eeg_data(test_subject, test_session, preload=True)

if raw is None:
    print(f"✗ Failed to load data")
else:
    print(f"✓ Loaded: {metadata['n_channels']} channels, {metadata['duration_sec']:.1f}s")
    
    # Step 2: Preprocess
    print("\n2. Preprocessing...")
    raw_prep, preprocess_info = preprocess_eeg(raw)
    
    if raw_prep is None:
        print(f"✗ Preprocessing failed: {preprocess_info.get('error', 'Unknown error')}")
    else:
        print(f"✓ Preprocessing completed")
        print(f"   Steps applied: {', '.join(preprocess_info['steps_applied'])}")
        if preprocess_info['bad_channels_original']:
            print(f"   Bad channels: {', '.join(preprocess_info['bad_channels_original'])}")
        print(f"   Filters: {', '.join(preprocess_info['filters_applied'])}")
        print(f"   Reference: {preprocess_info['reference']}")
    
    # Step 3: Load annotations
    print("\n3. Loading annotations...")
    events_df, annotations_df = load_events_and_annotations(test_subject, test_session)
    
    if events_df is not None:
        print(f"✓ Events: {len(events_df)} events")
    if annotations_df is not None:
        print(f"✓ Annotations: {len(annotations_df)} annotations")
    
    # Step 4: Segment data
    print("\n4. Segmenting data...")
    segments, segment_info = segment_data(raw_prep, events_df, annotations_df)
    
    if len(segments) == 0:
        print("✗ No segments extracted")
    else:
        print(f"✓ Extracted {len(segments)} segments")
        print(f"   Window length: {SEGMENT_PARAMS['window_length']}s")
        print(f"   Window overlap: {SEGMENT_PARAMS['window_overlap']*100:.0f}%")
        print(f"   Segment shape: {segments[0].shape} (channels x samples)")
    
    # Step 5: Check stationarity on first few segments
    if len(segments) > 0:
        print("\n5. Checking stationarity (first 3 segments)...")
        
        for i in range(min(3, len(segments))):
            stationarity = check_stationarity(segments[i])
            print(f"   Segment {i+1}:")
            print(f"     - Stationary channels: {stationarity['n_stationary']}/{len(stationarity['is_stationary'])} "
                  f"({stationarity['fraction_stationary']*100:.1f}%)")
            print(f"     - Mean p-value: {np.mean(stationarity['p_values']):.4f}")
    
    # Step 6: Visualize preprocessing effects
    print("\n6. Creating preprocessing comparison plots...")
    
    try:
        fig, axes = plt.subplots(3, 2, figsize=(15, 12))
        fig.suptitle(f'Preprocessing Effects - {test_subject}/{test_session}', 
                    fontsize=14, fontweight='bold')
        
        # Get 10 seconds of data for visualization
        duration_plot = min(10, raw.times[-1])
        n_samples_plot = int(duration_plot * raw.info['sfreq'])
        
        # Plot 1: Raw data (time domain)
        ax = axes[0, 0]
        data_raw, times = raw[:5, :n_samples_plot]
        for i in range(5):
            ax.plot(times, data_raw[i, :] * 1e6 + i * 50, alpha=0.7, label=raw.ch_names[i])
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude (µV)')
        ax.set_title('Original Raw Data (first 5 channels)')
        ax.legend(loc='upper right', fontsize=8)
        ax.grid(True, alpha=0.3)
        
        # Plot 2: Preprocessed data (time domain)
        ax = axes[0, 1]
        data_prep, times = raw_prep[:5, :n_samples_plot]
        for i in range(5):
            ax.plot(times, data_prep[i, :] * 1e6 + i * 50, alpha=0.7, label=raw_prep.ch_names[i])
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude (µV)')
        ax.set_title('Preprocessed Data (first 5 channels)')
        ax.legend(loc='upper right', fontsize=8)
        ax.grid(True, alpha=0.3)
        
        # Plot 3: Raw PSD
        ax = axes[1, 0]
        psd_raw = raw.compute_psd(fmin=0.5, fmax=50, n_fft=int(2*raw.info['sfreq']))
        psds_raw, freqs_raw = psd_raw.get_data(return_freqs=True)
        ax.semilogy(freqs_raw, psds_raw[:5, :].T, alpha=0.7)
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('PSD (µV²/Hz)')
        ax.set_title('Original PSD')
        ax.grid(True, alpha=0.3)
        ax.set_xlim([0, 50])
        
        # Plot 4: Preprocessed PSD
        ax = axes[1, 1]
        psd_prep = raw_prep.compute_psd(fmin=0.5, fmax=50, n_fft=int(2*raw_prep.info['sfreq']))
        psds_prep, freqs_prep = psd_prep.get_data(return_freqs=True)
        ax.semilogy(freqs_prep, psds_prep[:5, :].T, alpha=0.7)
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('PSD (µV²/Hz)')
        ax.set_title('Preprocessed PSD (filtered)')
        ax.grid(True, alpha=0.3)
        ax.set_xlim([0, 50])
        
        # Plot 5: Segment distribution
        ax = axes[2, 0]
        if len(segment_info) > 0:
            start_times = [info['start_time'] for info in segment_info]
            ax.hist(start_times, bins=20, color='steelblue', edgecolor='black', alpha=0.7)
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Number of Segments')
            ax.set_title(f'Segment Distribution (n={len(segments)})')
            ax.grid(True, alpha=0.3)
        else:
            ax.text(0.5, 0.5, 'No segments', ha='center', va='center', transform=ax.transAxes)
        
        # Plot 6: Stationarity summary
        ax = axes[2, 1]
        if len(segments) > 0:
            # Check stationarity for first 10 segments
            n_check = min(10, len(segments))
            stationary_fractions = []
            for i in range(n_check):
                stat = check_stationarity(segments[i])
                stationary_fractions.append(stat['fraction_stationary'] * 100)
            
            ax.bar(range(1, n_check+1), stationary_fractions, color='mediumseagreen', 
                   edgecolor='black', alpha=0.7)
            ax.axhline(y=80, color='red', linestyle='--', label='80% threshold')
            ax.set_xlabel('Segment Number')
            ax.set_ylabel('Stationary Channels (%)')
            ax.set_title(f'Stationarity Check (first {n_check} segments)')
            ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_ylim([0, 100])
        else:
            ax.text(0.5, 0.5, 'No segments', ha='center', va='center', transform=ax.transAxes)
        
        plt.tight_layout()
        
        # Save figure
        preprocess_plot_path = OUTPUT_DIRS['plots'] / f'{test_subject}_{test_session}_preprocessing.png'
        plt.savefig(preprocess_plot_path, dpi=150, bbox_inches='tight')
        print(f"✓ Saved preprocessing plot to: {preprocess_plot_path}")
        
        plt.show()
        
    except Exception as e:
        print(f"✗ Could not create visualization: {e}")
    
    # Clean up
    del raw, raw_prep, segments

print("\n" + "="*80)
print("PREPROCESSING TESTING COMPLETE")
print("="*80)

2025-10-25 17:14:47,648 - INFO - Loaded sub-NORB00001/ses-1: 21 channels, 714.0s @ 200.0Hz
2025-10-25 17:14:47,670 - INFO - Applying bandpass filter: 0.5-30.0 Hz
2025-10-25 17:14:47,670 - INFO - Applying bandpass filter: 0.5-30.0 Hz
2025-10-25 17:14:47,717 - INFO - Applying notch filter at 60.0 Hz
2025-10-25 17:14:47,717 - INFO - Applying notch filter at 60.0 Hz
2025-10-25 17:14:47,760 - INFO - Re-referencing to: average
2025-10-25 17:14:47,760 - INFO - Re-referencing to: average
2025-10-25 17:14:47,776 - INFO - Preprocessing completed successfully: bandpass_filter, notch_filter, reference_average
2025-10-25 17:14:47,779 - INFO - Loaded events for sub-NORB00001/ses-1: 3 events
2025-10-25 17:14:47,776 - INFO - Preprocessing completed successfully: bandpass_filter, notch_filter, reference_average
2025-10-25 17:14:47,779 - INFO - Loaded events for sub-NORB00001/ses-1: 3 events
2025-10-25 17:14:47,790 - INFO - Extracted 127 segments (window: 10.0s, overlap: 50%)
2025-10-25 17:14:47,790 - I

Testing preprocessing functions on sample subject...

Processing: sub-NORB00001/ses-1

1. Loading EEG data...
✓ Loaded: 21 channels, 714.0s

2. Preprocessing...
✓ Preprocessing completed
   Steps applied: bandpass_filter, notch_filter, reference_average
   Filters: bandpass_0.5-30.0_Hz, notch_60.0_Hz
   Reference: average

3. Loading annotations...
✓ Events: 3 events

4. Segmenting data...
✓ Extracted 127 segments
   Window length: 10.0s
   Window overlap: 50%
   Segment shape: (21, 2000) (channels x samples)

5. Checking stationarity (first 3 segments)...
   Segment 1:
     - Stationary channels: 19/21 (90.5%)
     - Mean p-value: 0.0153
   Segment 1:
     - Stationary channels: 19/21 (90.5%)
     - Mean p-value: 0.0153
   Segment 2:
     - Stationary channels: 21/21 (100.0%)
     - Mean p-value: 0.0029
   Segment 2:
     - Stationary channels: 21/21 (100.0%)
     - Mean p-value: 0.0029
   Segment 3:
     - Stationary channels: 19/21 (90.5%)
     - Mean p-value: 0.0053

6. Creating pr

  plt.show()


## Step 6: Model Order Selection

Functions to select optimal VAR model order using information criteria.

In [53]:
# ============================================================================
# SELECT MODEL ORDER FUNCTION
# ============================================================================

def select_model_order(segment_data, params=None):
    """
    Select optimal VAR model order using information criteria.
    
    Parameters:
    -----------
    segment_data : np.ndarray
        Data segment (n_channels x n_timepoints)
    params : dict, optional
        Model order parameters (uses GC_PARAMS if None)
        
    Returns:
    --------
    order_info : dict
        Dictionary containing optimal order and information criteria values
    """
    
    if params is None:
        params = GC_PARAMS
    
    order_info = {
        'success': False,
        'optimal_order': None,
        'method': params['model_order_method'],
        'orders_tested': [],
        'aic_values': [],
        'bic_values': [],
        'hqc_values': [],
        'error': None
    }
    
    try:
        # Transpose to (n_timepoints x n_channels) for statsmodels VAR
        data_transposed = segment_data.T
        
        # If using fixed order, skip optimization
        if params['model_order_method'] == 'fixed' and params['fixed_order'] is not None:
            order_info['optimal_order'] = params['fixed_order']
            order_info['success'] = True
            logger.info(f"Using fixed model order: {params['fixed_order']}")
            return order_info
        
        # Test different model orders
        min_order = params['min_order']
        max_order = min(params['max_order'], len(data_transposed) // 2)  # Limit based on data length
        
        aic_values = []
        bic_values = []
        hqc_values = []
        orders_tested = []
        
        logger.info(f"Testing VAR model orders from {min_order} to {max_order}...")
        
        for order in range(min_order, max_order + 1):
            try:
                # Fit VAR model with current order
                model = VAR(data_transposed)
                results = model.fit(maxlags=order, ic=None, verbose=False)
                
                # Store information criteria values
                aic_values.append(results.aic)
                bic_values.append(results.bic)
                hqc_values.append(results.hqic)
                orders_tested.append(order)
                
            except Exception as e:
                logger.warning(f"Could not fit VAR model with order {order}: {e}")
                continue
        
        if len(orders_tested) == 0:
            order_info['error'] = "Could not fit any VAR models"
            logger.error(order_info['error'])
            return order_info
        
        # Store results
        order_info['orders_tested'] = orders_tested
        order_info['aic_values'] = aic_values
        order_info['bic_values'] = bic_values
        order_info['hqc_values'] = hqc_values
        
        # Select optimal order based on chosen criterion
        method = params['model_order_method'].lower()
        
        if method == 'aic':
            optimal_idx = np.argmin(aic_values)
            order_info['optimal_order'] = orders_tested[optimal_idx]
            order_info['optimal_aic'] = aic_values[optimal_idx]
        elif method == 'bic':
            optimal_idx = np.argmin(bic_values)
            order_info['optimal_order'] = orders_tested[optimal_idx]
            order_info['optimal_bic'] = bic_values[optimal_idx]
        elif method == 'hqc':
            optimal_idx = np.argmin(hqc_values)
            order_info['optimal_order'] = orders_tested[optimal_idx]
            order_info['optimal_hqc'] = hqc_values[optimal_idx]
        else:
            # Default to AIC
            optimal_idx = np.argmin(aic_values)
            order_info['optimal_order'] = orders_tested[optimal_idx]
            logger.warning(f"Unknown method '{method}', defaulting to AIC")
        
        order_info['success'] = True
        logger.info(f"Optimal model order selected: {order_info['optimal_order']} (method: {method})")
        
        return order_info
        
    except Exception as e:
        order_info['error'] = str(e)
        logger.error(f"Error during model order selection: {e}")
        return order_info


print("✓ select_model_order() function defined")

✓ select_model_order() function defined


In [54]:
# ============================================================================
# VALIDATE VAR MODEL FUNCTION
# ============================================================================

def validate_var_model(segment_data, order, params=None):
    """
    Validate a fitted VAR model by checking residuals and stability.
    
    Parameters:
    -----------
    segment_data : np.ndarray
        Data segment (n_channels x n_timepoints)
    order : int
        Model order to validate
    params : dict, optional
        Statistical parameters (uses STAT_PARAMS if None)
        
    Returns:
    --------
    validation_report : dict
        Dictionary containing validation results
    """
    
    if params is None:
        params = STAT_PARAMS
    
    validation_report = {
        'is_valid': True,
        'order': order,
        'issues': [],
        'warnings': [],
        'residuals_white': None,
        'model_stable': None,
        'test_statistics': {}
    }
    
    try:
        # Transpose to (n_timepoints x n_channels) for statsmodels VAR
        data_transposed = segment_data.T
        
        # Fit VAR model
        model = VAR(data_transposed)
        results = model.fit(maxlags=order, ic=None, verbose=False)
        
        # Test 1: Check model stability (eigenvalues of companion matrix)
        # Stable if all eigenvalues have modulus < 1
        try:
            is_stable = results.is_stable()
            validation_report['model_stable'] = is_stable
            
            if not is_stable:
                validation_report['issues'].append(f"Model with order {order} is not stable")
                validation_report['is_valid'] = False
                logger.warning(f"VAR model with order {order} is unstable")
        except Exception as e:
            validation_report['warnings'].append(f"Could not check stability: {e}")
        
        # Test 2: Residual whiteness (Portmanteau test)
        # H0: residuals are white noise (no autocorrelation)
        try:
            from statsmodels.stats.diagnostic import acorr_ljungbox
            
            # Get residuals
            residuals = results.resid
            
            # Test each channel's residuals
            white_channels = []
            ljungbox_pvalues = []
            
            for ch_idx in range(residuals.shape[1]):
                # Ljung-Box test for autocorrelation
                lb_test = acorr_ljungbox(residuals[:, ch_idx], lags=[10], return_df=False)
                pvalue = lb_test[1][0]  # p-value
                ljungbox_pvalues.append(pvalue)
                
                # Residuals are white if we cannot reject H0 (p > threshold)
                is_white = pvalue > params['significance_threshold']
                white_channels.append(is_white)
            
            validation_report['residuals_white'] = np.mean(white_channels)  # Fraction of white residuals
            validation_report['test_statistics']['ljungbox_pvalues'] = ljungbox_pvalues
            
            if validation_report['residuals_white'] < 0.5:
                validation_report['warnings'].append(
                    f"Only {validation_report['residuals_white']*100:.1f}% of channels have white residuals"
                )
            
            logger.info(f"Residual whiteness: {validation_report['residuals_white']*100:.1f}% of channels")
            
        except Exception as e:
            validation_report['warnings'].append(f"Could not test residual whiteness: {e}")
        
        # Test 3: Check for reasonable fit (R-squared)
        try:
            # Get R-squared values per equation
            rsq_values = []
            for eq_name, eq_results in results.summary().tables[1].items():
                rsq = results.rsquared[eq_name] if hasattr(results, 'rsquared') else None
                if rsq is not None:
                    rsq_values.append(rsq)
            
            if rsq_values:
                mean_rsq = np.mean(rsq_values)
                validation_report['test_statistics']['mean_rsquared'] = mean_rsq
                
                if mean_rsq < 0.1:
                    validation_report['warnings'].append(f"Low R-squared: {mean_rsq:.3f}")
        except Exception as e:
            # R-squared check is optional
            pass
        
        logger.info(f"VAR model validation complete for order {order}")
        
        return validation_report
        
    except Exception as e:
        validation_report['error'] = str(e)
        validation_report['is_valid'] = False
        logger.error(f"Error during model validation: {e}")
        return validation_report


print("✓ validate_var_model() function defined")

✓ validate_var_model() function defined


### Test Model Order Selection

Test the model order selection on sample segments.

In [55]:
# ============================================================================
# TEST MODEL ORDER SELECTION ON SAMPLE SUBJECT
# ============================================================================

print("Testing model order selection on sample subject...")
print("="*80)

# Get first valid subject
valid_subject = inventory_df[inventory_df['file_exists']].iloc[0]
test_subject = valid_subject['subject_id']
test_session = valid_subject['session_id']

print(f"\nProcessing: {test_subject}/{test_session}")
print("="*80)

# Step 1: Load and preprocess data (reuse from previous test)
print("\n1. Loading and preprocessing EEG data...")
raw, metadata = load_eeg_data(test_subject, test_session, preload=True)

if raw is None:
    print(f"✗ Failed to load data")
else:
    print(f"✓ Loaded: {metadata['n_channels']} channels, {metadata['duration_sec']:.1f}s")
    
    # Preprocess
    raw_prep, preprocess_info = preprocess_eeg(raw)
    
    if raw_prep is None:
        print(f"✗ Preprocessing failed")
    else:
        print(f"✓ Preprocessing completed")
        
        # Step 2: Load annotations and segment
        print("\n2. Segmenting data...")
        events_df, annotations_df = load_events_and_annotations(test_subject, test_session)
        segments, segment_info = segment_data(raw_prep, events_df, annotations_df)
        
        if len(segments) == 0:
            print("✗ No segments extracted")
        else:
            print(f"✓ Extracted {len(segments)} segments")
            
            # Step 3: Test model order selection on first few segments
            print("\n3. Testing model order selection...")
            print(f"   Method: {GC_PARAMS['model_order_method'].upper()}")
            print(f"   Testing orders: {GC_PARAMS['min_order']} to {GC_PARAMS['max_order']}")
            
            n_test_segments = min(3, len(segments))
            selected_orders = []
            
            for i in range(n_test_segments):
                print(f"\n   Segment {i+1}:")
                
                # Select model order
                order_info = select_model_order(segments[i])
                
                if order_info['success']:
                    optimal_order = order_info['optimal_order']
                    selected_orders.append(optimal_order)
                    print(f"     ✓ Optimal order: {optimal_order}")
                    print(f"     - Tested {len(order_info['orders_tested'])} orders")
                    print(f"     - AIC range: {min(order_info['aic_values']):.2f} to {max(order_info['aic_values']):.2f}")
                    print(f"     - BIC range: {min(order_info['bic_values']):.2f} to {max(order_info['bic_values']):.2f}")
                    
                    # Validate the selected model
                    print(f"     - Validating model...")
                    validation = validate_var_model(segments[i], optimal_order)
                    
                    if validation['is_valid']:
                        print(f"       ✓ Model is valid")
                    else:
                        print(f"       ⚠ Model validation issues:")
                        for issue in validation['issues']:
                            print(f"         • {issue}")
                    
                    if validation['model_stable'] is not None:
                        status = "✓" if validation['model_stable'] else "✗"
                        print(f"       {status} Stability: {'Stable' if validation['model_stable'] else 'Unstable'}")
                    
                    if validation['residuals_white'] is not None:
                        print(f"       - White residuals: {validation['residuals_white']*100:.1f}% of channels")
                    
                    if validation['warnings']:
                        print(f"       - Warnings: {len(validation['warnings'])}")
                        for warning in validation['warnings'][:2]:
                            print(f"         • {warning}")
                else:
                    print(f"     ✗ Order selection failed: {order_info.get('error', 'Unknown error')}")
            
            # Step 4: Visualize order selection results
            if len(selected_orders) > 0:
                print(f"\n4. Creating model order selection visualizations...")
                
                try:
                    # Create 2x2 plot
                    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
                    fig.suptitle(f'VAR Model Order Selection - {test_subject}/{test_session}', 
                                fontsize=14, fontweight='bold')
                    
                    # Get detailed info for first segment for plots
                    order_info_detailed = select_model_order(segments[0])
                    
                    if order_info_detailed['success']:
                        orders = order_info_detailed['orders_tested']
                        aic_vals = order_info_detailed['aic_values']
                        bic_vals = order_info_detailed['bic_values']
                        hqc_vals = order_info_detailed['hqc_values']
                        optimal = order_info_detailed['optimal_order']
                        
                        # Plot 1: AIC values
                        ax = axes[0, 0]
                        ax.plot(orders, aic_vals, 'o-', color='steelblue', linewidth=2, markersize=6)
                        ax.axvline(optimal, color='red', linestyle='--', linewidth=2, 
                                  label=f'Optimal: {optimal}')
                        ax.set_xlabel('Model Order (lags)')
                        ax.set_ylabel('AIC')
                        ax.set_title('Akaike Information Criterion')
                        ax.legend()
                        ax.grid(True, alpha=0.3)
                        
                        # Plot 2: BIC values
                        ax = axes[0, 1]
                        ax.plot(orders, bic_vals, 'o-', color='mediumseagreen', linewidth=2, markersize=6)
                        ax.axvline(optimal, color='red', linestyle='--', linewidth=2, 
                                  label=f'Optimal (AIC): {optimal}')
                        ax.set_xlabel('Model Order (lags)')
                        ax.set_ylabel('BIC')
                        ax.set_title('Bayesian Information Criterion')
                        ax.legend()
                        ax.grid(True, alpha=0.3)
                        
                        # Plot 3: HQC values
                        ax = axes[1, 0]
                        ax.plot(orders, hqc_vals, 'o-', color='coral', linewidth=2, markersize=6)
                        ax.axvline(optimal, color='red', linestyle='--', linewidth=2, 
                                  label=f'Optimal (AIC): {optimal}')
                        ax.set_xlabel('Model Order (lags)')
                        ax.set_ylabel('HQC')
                        ax.set_title('Hannan-Quinn Criterion')
                        ax.legend()
                        ax.grid(True, alpha=0.3)
                        
                        # Plot 4: Distribution of selected orders across segments
                        ax = axes[1, 1]
                        orders_count = {}
                        for order in selected_orders:
                            orders_count[order] = orders_count.get(order, 0) + 1
                        
                        orders_sorted = sorted(orders_count.keys())
                        counts = [orders_count[o] for o in orders_sorted]
                        
                        ax.bar(orders_sorted, counts, color='mediumpurple', edgecolor='black', alpha=0.7)
                        ax.set_xlabel('Selected Model Order')
                        ax.set_ylabel('Number of Segments')
                        ax.set_title(f'Order Distribution (n={n_test_segments} segments)')
                        ax.grid(axis='y', alpha=0.3)
                        
                        # Add statistics
                        mean_order = np.mean(selected_orders)
                        std_order = np.std(selected_orders)
                        ax.axvline(mean_order, color='red', linestyle='--', linewidth=2,
                                  label=f'Mean: {mean_order:.1f}±{std_order:.1f}')
                        ax.legend()
                    else:
                        # If detailed info failed, just show message
                        for ax in axes.flat:
                            ax.text(0.5, 0.5, 'Could not generate plot', 
                                   ha='center', va='center', transform=ax.transAxes)
                    
                    plt.tight_layout()
                    
                    # Save figure
                    order_plot_path = OUTPUT_DIRS['plots'] / f'{test_subject}_{test_session}_model_order.png'
                    plt.savefig(order_plot_path, dpi=150, bbox_inches='tight')
                    print(f"   ✓ Saved model order plot to: {order_plot_path}")
                    
                    plt.show()
                    
                except Exception as e:
                    print(f"   ✗ Could not create visualization: {e}")
                
                # Summary statistics
                print(f"\n5. Summary Statistics:")
                print(f"   Selected orders: {selected_orders}")
                print(f"   Mean order: {np.mean(selected_orders):.2f} ± {np.std(selected_orders):.2f}")
                print(f"   Range: {min(selected_orders)} to {max(selected_orders)}")
                print(f"   Mode: {max(set(selected_orders), key=selected_orders.count)}")
        
        # Clean up
        del raw, raw_prep, segments

print("\n" + "="*80)
print("MODEL ORDER SELECTION TESTING COMPLETE")
print("="*80)

2025-10-25 17:09:56,961 - INFO - Loaded sub-NORB00001/ses-1: 21 channels, 714.0s @ 200.0Hz
2025-10-25 17:09:56,979 - INFO - Applying bandpass filter: 0.5-30.0 Hz
2025-10-25 17:09:57,023 - INFO - Applying notch filter at 60.0 Hz
2025-10-25 17:09:57,065 - INFO - Re-referencing to: average
2025-10-25 17:09:57,082 - INFO - Preprocessing completed successfully: bandpass_filter, notch_filter, reference_average
2025-10-25 17:09:57,085 - INFO - Loaded events for sub-NORB00001/ses-1: 3 events
2025-10-25 17:09:57,093 - INFO - Extracted 127 segments (window: 10.0s, overlap: 50%)
2025-10-25 17:09:57,094 - INFO - Testing VAR model orders from 1 to 50...


Testing model order selection on sample subject...

Processing: sub-NORB00001/ses-1

1. Loading and preprocessing EEG data...
✓ Loaded: 21 channels, 714.0s
✓ Preprocessing completed

2. Segmenting data...
✓ Extracted 127 segments

3. Testing model order selection...
   Method: AIC
   Testing orders: 1 to 50

   Segment 1:


2025-10-25 17:10:01,934 - INFO - Optimal model order selected: 50 (method: aic)


     ✓ Optimal order: 50
     - Tested 19 orders
     - AIC range: -844.77 to -618.52
     - BIC range: -817.89 to -615.99
     - Validating model...


  stderr = np.sqrt(np.diag(self.cov_params()))
2025-10-25 17:10:10,604 - INFO - VAR model validation complete for order 50
2025-10-25 17:10:10,606 - INFO - Testing VAR model orders from 1 to 50...


       ✓ Model is valid
       ✓ Stability: Stable
         • Could not test residual whiteness: 1

   Segment 2:


2025-10-25 17:10:14,699 - INFO - Optimal model order selected: 42 (method: aic)


     ✓ Optimal order: 42
     - Tested 25 orders
     - AIC range: -846.43 to -584.74
     - BIC range: -821.80 to -583.45
     - Validating model...


  stderr = np.sqrt(np.diag(self.cov_params()))
2025-10-25 17:10:17,872 - INFO - VAR model validation complete for order 42
2025-10-25 17:10:17,873 - INFO - Testing VAR model orders from 1 to 50...


       ✓ Model is valid
       ✓ Stability: Stable
         • Could not test residual whiteness: 1

   Segment 3:


2025-10-25 17:10:22,025 - INFO - Optimal model order selected: 50 (method: aic)


     ✓ Optimal order: 50
     - Tested 27 orders
     - AIC range: -847.98 to -586.19
     - BIC range: -821.44 to -584.90
     - Validating model...


  stderr = np.sqrt(np.diag(self.cov_params()))
2025-10-25 17:10:26,422 - INFO - VAR model validation complete for order 50
2025-10-25 17:10:26,439 - INFO - Testing VAR model orders from 1 to 50...


       ✓ Model is valid
       ✓ Stability: Stable
         • Could not test residual whiteness: 1

4. Creating model order selection visualizations...


2025-10-25 17:10:30,449 - INFO - Optimal model order selected: 50 (method: aic)


   ✓ Saved model order plot to: /home/alookaladdoo/DPCN-Project/results/plots/sub-NORB00001_ses-1_model_order.png

5. Summary Statistics:
   Selected orders: [50, 42, 50]
   Mean order: 47.33 ± 3.77
   Range: 42 to 50
   Mode: 50

MODEL ORDER SELECTION TESTING COMPLETE


  plt.show()


## Step 7: Compute Granger Causality

Functions to compute pairwise Granger causality and spectral GC for frequency bands.

In [57]:
# ============================================================================
# COMPUTE PAIRWISE GRANGER CAUSALITY FUNCTION
# ============================================================================

def compute_pairwise_gc(segment_data, model_order, params=None):
    """
    Compute pairwise Granger causality for all channel pairs.
    
    Parameters:
    -----------
    segment_data : np.ndarray
        Data segment (n_channels x n_timepoints)
    model_order : int
        VAR model order to use
    params : dict, optional
        GC parameters (uses GC_PARAMS if None)
        
    Returns:
    --------
    gc_matrix : np.ndarray
        Matrix of GC values (n_channels x n_channels)
        gc_matrix[i, j] = causal influence from j to i
    gc_pvalues : np.ndarray
        Matrix of p-values for GC tests
    gc_info : dict
        Additional information about the computation
    """
    
    if params is None:
        params = GC_PARAMS
    
    n_channels = segment_data.shape[0]
    
    # Initialize output matrices
    gc_matrix = np.zeros((n_channels, n_channels))
    gc_pvalues = np.ones((n_channels, n_channels))
    
    gc_info = {
        'success': False,
        'n_channels': n_channels,
        'model_order': model_order,
        'n_pairs_tested': 0,
        'n_pairs_failed': 0,
        'method': 'pairwise',
        'error': None
    }
    
    try:
        # Transpose to (n_timepoints x n_channels) for statsmodels
        data_transposed = segment_data.T
        
        logger.info(f"Computing pairwise GC for {n_channels} channels (order={model_order})...")
        
        # Compute GC for all pairs
        pairs_tested = 0
        pairs_failed = 0
        
        for i in range(n_channels):
            for j in range(n_channels):
                if i == j:
                    # No self-causality
                    gc_matrix[i, j] = 0.0
                    gc_pvalues[i, j] = 1.0
                    continue
                
                try:
                    # Extract bivariate time series: [target_i, source_j]
                    bivariate_data = data_transposed[:, [i, j]]
                    
                    # Fit VAR model
                    model = VAR(bivariate_data)
                    results = model.fit(maxlags=model_order, ic=None, verbose=False)
                    
                    # Test causality: does j Granger-cause i?
                    # In statsmodels: test_causality(causing_variable, caused_variable)
                    # Variable indices: 0=target_i, 1=source_j
                    # Test if variable 1 (source_j) causes variable 0 (target_i)
                    gc_test = results.test_causality(caused=0, causing=1, kind='f', signif=0.05)
                    
                    # Store F-statistic and p-value
                    gc_matrix[i, j] = gc_test.test_statistic
                    gc_pvalues[i, j] = gc_test.pvalue
                    
                    pairs_tested += 1
                    
                except Exception as e:
                    # Failed to compute GC for this pair
                    gc_matrix[i, j] = 0.0
                    gc_pvalues[i, j] = 1.0
                    pairs_failed += 1
                    logger.debug(f"Failed to compute GC for pair ({i}, {j}): {e}")
        
        gc_info['n_pairs_tested'] = pairs_tested
        gc_info['n_pairs_failed'] = pairs_failed
        gc_info['success'] = True
        
        logger.info(f"Pairwise GC computed: {pairs_tested} pairs successful, {pairs_failed} failed")
        
        return gc_matrix, gc_pvalues, gc_info
        
    except Exception as e:
        gc_info['error'] = str(e)
        logger.error(f"Error computing pairwise GC: {e}")
        return gc_matrix, gc_pvalues, gc_info


print("✓ compute_pairwise_gc() function defined")

✓ compute_pairwise_gc() function defined


In [58]:
# ============================================================================
# COMPUTE SPECTRAL GRANGER CAUSALITY FUNCTION
# ============================================================================

def compute_spectral_gc(segment_data, model_order, freq_bands=None, sfreq=200.0):
    """
    Compute spectral Granger causality for specified frequency bands.
    
    Parameters:
    -----------
    segment_data : np.ndarray
        Data segment (n_channels x n_timepoints)
    model_order : int
        VAR model order to use
    freq_bands : dict, optional
        Dictionary of frequency bands (uses GC_PARAMS['freq_bands'] if None)
    sfreq : float
        Sampling frequency in Hz
        
    Returns:
    --------
    spectral_gc : dict
        Dictionary with band names as keys, each containing:
        - 'gc_matrix': GC matrix for that band
        - 'freq_range': (low, high) frequency range
    spectral_info : dict
        Additional information about the computation
    """
    
    if freq_bands is None:
        freq_bands = GC_PARAMS['freq_bands']
    
    n_channels = segment_data.shape[0]
    
    spectral_gc = {}
    spectral_info = {
        'success': False,
        'n_channels': n_channels,
        'model_order': model_order,
        'bands_computed': [],
        'bands_failed': [],
        'error': None
    }
    
    try:
        # Transpose to (n_timepoints x n_channels) for statsmodels
        data_transposed = segment_data.T
        
        logger.info(f"Computing spectral GC for {len(freq_bands)} frequency bands...")
        
        # Fit VAR model once for all bands
        model = VAR(data_transposed)
        results = model.fit(maxlags=model_order, ic=None, verbose=False)
        
        # Get spectral representation
        # Note: This is a simplified approach - full spectral GC requires
        # computing the spectral density matrix and decomposing it
        
        for band_name, (fmin, fmax) in freq_bands.items():
            try:
                # Initialize band-specific GC matrix
                band_gc_matrix = np.zeros((n_channels, n_channels))
                
                # For each pair, compute spectral GC in frequency band
                for i in range(n_channels):
                    for j in range(n_channels):
                        if i == j:
                            continue
                        
                        try:
                            # Extract bivariate data
                            bivariate_data = data_transposed[:, [i, j]]
                            
                            # Fit bivariate VAR
                            biv_model = VAR(bivariate_data)
                            biv_results = biv_model.fit(maxlags=model_order, ic=None, verbose=False)
                            
                            # Get coefficients for spectral calculation
                            # This is a simplified version - actual spectral GC requires
                            # computing transfer functions in frequency domain
                            coefs = biv_results.params
                            
                            # Approximate band-specific influence using coefficient magnitudes
                            # weighted by frequency band (simplified approach)
                            band_weight = (fmax - fmin) / (sfreq / 2)  # Normalized band width
                            j_to_i_coefs = np.abs(coefs[1::2, 0])  # Coefficients from j to i
                            band_gc_matrix[i, j] = np.sum(j_to_i_coefs) * band_weight
                            
                        except Exception as e:
                            band_gc_matrix[i, j] = 0.0
                            logger.debug(f"Failed spectral GC for pair ({i},{j}) in {band_name}: {e}")
                
                # Store results for this band
                spectral_gc[band_name] = {
                    'gc_matrix': band_gc_matrix,
                    'freq_range': (fmin, fmax)
                }
                spectral_info['bands_computed'].append(band_name)
                
                logger.info(f"  ✓ {band_name} band: {fmin}-{fmax} Hz")
                
            except Exception as e:
                spectral_info['bands_failed'].append(band_name)
                logger.warning(f"Failed to compute spectral GC for {band_name} band: {e}")
        
        spectral_info['success'] = len(spectral_info['bands_computed']) > 0
        
        return spectral_gc, spectral_info
        
    except Exception as e:
        spectral_info['error'] = str(e)
        logger.error(f"Error computing spectral GC: {e}")
        return spectral_gc, spectral_info


print("✓ compute_spectral_gc() function defined")

✓ compute_spectral_gc() function defined


In [59]:
# ============================================================================
# AVERAGE GC ACROSS SEGMENTS FUNCTION
# ============================================================================

def average_gc_across_segments(gc_matrices_list, gc_pvalues_list=None, method='mean'):
    """
    Average Granger causality matrices across multiple segments.
    
    Parameters:
    -----------
    gc_matrices_list : list of np.ndarray
        List of GC matrices from different segments
    gc_pvalues_list : list of np.ndarray, optional
        List of p-value matrices from different segments
    method : str
        Averaging method: 'mean', 'median', or 'weighted'
        
    Returns:
    --------
    avg_gc_matrix : np.ndarray
        Averaged GC matrix
    avg_pvalues : np.ndarray or None
        Averaged p-values (if provided)
    avg_info : dict
        Information about averaging
    """
    
    avg_info = {
        'success': False,
        'n_segments': len(gc_matrices_list),
        'method': method,
        'shape': None,
        'error': None
    }
    
    try:
        if len(gc_matrices_list) == 0:
            avg_info['error'] = "No GC matrices provided"
            return None, None, avg_info
        
        # Stack matrices
        gc_stack = np.stack(gc_matrices_list, axis=0)
        avg_info['shape'] = gc_stack.shape[1:]
        
        # Compute average based on method
        if method == 'mean':
            avg_gc_matrix = np.mean(gc_stack, axis=0)
        elif method == 'median':
            avg_gc_matrix = np.median(gc_stack, axis=0)
        elif method == 'weighted':
            # Weight by inverse variance (more stable segments get higher weight)
            variances = np.var(gc_stack, axis=0)
            weights = 1.0 / (variances + 1e-10)  # Add small constant to avoid division by zero
            weights = weights / np.sum(weights, axis=0, keepdims=True)
            avg_gc_matrix = np.sum(gc_stack * weights, axis=0)
        else:
            logger.warning(f"Unknown averaging method '{method}', using mean")
            avg_gc_matrix = np.mean(gc_stack, axis=0)
        
        # Average p-values if provided (using Fisher's method would be better, but mean is simpler)
        avg_pvalues = None
        if gc_pvalues_list is not None and len(gc_pvalues_list) > 0:
            pval_stack = np.stack(gc_pvalues_list, axis=0)
            avg_pvalues = np.mean(pval_stack, axis=0)
        
        avg_info['success'] = True
        avg_info['mean_gc'] = np.mean(avg_gc_matrix)
        avg_info['std_gc'] = np.std(avg_gc_matrix)
        
        logger.info(f"Averaged {len(gc_matrices_list)} GC matrices (method: {method})")
        
        return avg_gc_matrix, avg_pvalues, avg_info
        
    except Exception as e:
        avg_info['error'] = str(e)
        logger.error(f"Error averaging GC matrices: {e}")
        return None, None, avg_info


print("✓ average_gc_across_segments() function defined")

✓ average_gc_across_segments() function defined


### Test Granger Causality Computation

Test the GC computation on sample segments and visualize the connectivity matrices.

In [60]:
# ============================================================================
# TEST GRANGER CAUSALITY COMPUTATION ON SAMPLE SUBJECT
# ============================================================================

print("Testing Granger causality computation on sample subject...")
print("="*80)

# Get first valid subject
valid_subject = inventory_df[inventory_df['file_exists']].iloc[0]
test_subject = valid_subject['subject_id']
test_session = valid_subject['session_id']

print(f"\nProcessing: {test_subject}/{test_session}")
print("="*80)

# Step 1: Load, preprocess, and segment data
print("\n1. Loading and preprocessing data...")
raw, metadata = load_eeg_data(test_subject, test_session, preload=True)

if raw is None:
    print(f"✗ Failed to load data")
else:
    print(f"✓ Loaded: {metadata['n_channels']} channels")
    
    raw_prep, preprocess_info = preprocess_eeg(raw)
    
    if raw_prep is None:
        print(f"✗ Preprocessing failed")
    else:
        print(f"✓ Preprocessing completed")
        
        # Get channel names for later use
        channel_names = raw_prep.ch_names
        
        # Segment data
        print("\n2. Segmenting data...")
        events_df, annotations_df = load_events_and_annotations(test_subject, test_session)
        segments, segment_info = segment_data(raw_prep, events_df, annotations_df)
        
        if len(segments) == 0:
            print("✗ No segments extracted")
        else:
            print(f"✓ Extracted {len(segments)} segments")
            
            # Step 2: Compute GC for first few segments
            print("\n3. Computing Granger causality...")
            n_test_segments = min(5, len(segments))
            print(f"   Testing on {n_test_segments} segments")
            
            gc_matrices_list = []
            gc_pvalues_list = []
            spectral_gc_list = []
            
            for i in range(n_test_segments):
                print(f"\n   Segment {i+1}:")
                
                # Select model order
                order_info = select_model_order(segments[i])
                
                if not order_info['success']:
                    print(f"     ✗ Model order selection failed")
                    continue
                
                optimal_order = order_info['optimal_order']
                print(f"     ✓ Selected order: {optimal_order}")
                
                # Compute pairwise GC
                print(f"     - Computing pairwise GC...")
                gc_matrix, gc_pvalues, gc_info = compute_pairwise_gc(segments[i], optimal_order)
                
                if gc_info['success']:
                    gc_matrices_list.append(gc_matrix)
                    gc_pvalues_list.append(gc_pvalues)
                    
                    n_significant = np.sum(gc_pvalues < 0.05)
                    total_pairs = gc_matrix.shape[0] * (gc_matrix.shape[1] - 1)  # Exclude diagonal
                    print(f"       ✓ Computed: {gc_info['n_pairs_tested']} pairs")
                    print(f"       - Significant connections: {n_significant}/{total_pairs} ({n_significant/total_pairs*100:.1f}%)")
                    print(f"       - Mean GC value: {np.mean(gc_matrix[gc_matrix > 0]):.4f}")
                else:
                    print(f"       ✗ Pairwise GC failed: {gc_info.get('error', 'Unknown')}")
                
                # Compute spectral GC if enabled
                if GC_PARAMS['compute_spectral_gc']:
                    print(f"     - Computing spectral GC...")
                    spectral_gc, spectral_info = compute_spectral_gc(
                        segments[i], optimal_order, 
                        freq_bands=GC_PARAMS['freq_bands'],
                        sfreq=metadata['sampling_freq']
                    )
                    
                    if spectral_info['success']:
                        spectral_gc_list.append(spectral_gc)
                        print(f"       ✓ Computed for {len(spectral_info['bands_computed'])} bands")
                        for band in spectral_info['bands_computed']:
                            band_mean = np.mean(spectral_gc[band]['gc_matrix'])
                            print(f"         - {band}: mean GC = {band_mean:.4f}")
                    else:
                        print(f"       ✗ Spectral GC failed")
            
            # Step 3: Average across segments
            if len(gc_matrices_list) > 0:
                print(f"\n4. Averaging across {len(gc_matrices_list)} segments...")
                avg_gc_matrix, avg_pvalues, avg_info = average_gc_across_segments(
                    gc_matrices_list, gc_pvalues_list, method='mean'
                )
                
                if avg_info['success']:
                    print(f"   ✓ Averaged GC matrix computed")
                    print(f"   - Mean GC: {avg_info['mean_gc']:.4f} ± {avg_info['std_gc']:.4f}")
                    print(f"   - Max GC: {np.max(avg_gc_matrix):.4f}")
                    print(f"   - Min GC: {np.min(avg_gc_matrix):.4f}")
                    
                    # Count significant connections in averaged matrix
                    n_sig_avg = np.sum(avg_pvalues < 0.05) if avg_pvalues is not None else 0
                    print(f"   - Significant connections: {n_sig_avg}")
                else:
                    print(f"   ✗ Averaging failed")
                    avg_gc_matrix = None
            else:
                print("\n4. No GC matrices to average")
                avg_gc_matrix = None
            
            # Step 4: Visualize results
            if avg_gc_matrix is not None:
                print(f"\n5. Creating visualizations...")
                
                try:
                    # Create comprehensive visualization
                    fig = plt.figure(figsize=(16, 12))
                    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
                    
                    fig.suptitle(f'Granger Causality Analysis - {test_subject}/{test_session}', 
                                fontsize=16, fontweight='bold')
                    
                    # Plot 1: Average GC matrix (heatmap)
                    ax1 = fig.add_subplot(gs[0, :2])
                    im1 = ax1.imshow(avg_gc_matrix, cmap='viridis', aspect='auto', interpolation='nearest')
                    ax1.set_xlabel('Source Channel')
                    ax1.set_ylabel('Target Channel')
                    ax1.set_title(f'Average GC Matrix (n={len(gc_matrices_list)} segments)')
                    plt.colorbar(im1, ax=ax1, label='GC F-statistic')
                    
                    # Add channel labels if not too many
                    if len(channel_names) <= 25:
                        ax1.set_xticks(range(len(channel_names)))
                        ax1.set_yticks(range(len(channel_names)))
                        ax1.set_xticklabels(channel_names, rotation=45, ha='right', fontsize=8)
                        ax1.set_yticklabels(channel_names, fontsize=8)
                    
                    # Plot 2: P-value matrix (if available)
                    ax2 = fig.add_subplot(gs[0, 2])
                    if avg_pvalues is not None:
                        # Show -log10(p-value) for better visualization
                        pval_log = -np.log10(avg_pvalues + 1e-10)
                        im2 = ax2.imshow(pval_log, cmap='hot', aspect='auto', interpolation='nearest')
                        ax2.set_title('Significance\n(-log10 p-value)')
                        plt.colorbar(im2, ax=ax2)
                    else:
                        ax2.text(0.5, 0.5, 'No p-values', ha='center', va='center', transform=ax2.transAxes)
                    ax2.set_xlabel('Source')
                    ax2.set_ylabel('Target')
                    
                    # Plot 3: GC distribution
                    ax3 = fig.add_subplot(gs[1, 0])
                    gc_values = avg_gc_matrix[avg_gc_matrix > 0]
                    ax3.hist(gc_values, bins=30, color='steelblue', edgecolor='black', alpha=0.7)
                    ax3.axvline(np.mean(gc_values), color='red', linestyle='--', linewidth=2, 
                               label=f'Mean: {np.mean(gc_values):.2f}')
                    ax3.set_xlabel('GC F-statistic')
                    ax3.set_ylabel('Frequency')
                    ax3.set_title('Distribution of GC Values')
                    ax3.legend()
                    ax3.grid(True, alpha=0.3)
                    
                    # Plot 4: Out-degree (sum of causal influences from each channel)
                    ax4 = fig.add_subplot(gs[1, 1])
                    out_degree = np.sum(avg_gc_matrix, axis=0)  # Sum over targets (rows)
                    ax4.bar(range(len(out_degree)), out_degree, color='mediumseagreen', edgecolor='black', alpha=0.7)
                    ax4.set_xlabel('Channel')
                    ax4.set_ylabel('Total Outgoing GC')
                    ax4.set_title('Causal Influence (Out-degree)')
                    ax4.grid(axis='y', alpha=0.3)
                    
                    # Plot 5: In-degree (sum of causal influences to each channel)
                    ax5 = fig.add_subplot(gs[1, 2])
                    in_degree = np.sum(avg_gc_matrix, axis=1)  # Sum over sources (columns)
                    ax5.bar(range(len(in_degree)), in_degree, color='coral', edgecolor='black', alpha=0.7)
                    ax5.set_xlabel('Channel')
                    ax5.set_ylabel('Total Incoming GC')
                    ax5.set_title('Causal Influence (In-degree)')
                    ax5.grid(axis='y', alpha=0.3)
                    
                    # Plot 6: Spectral GC (if available)
                    if len(spectral_gc_list) > 0 and len(spectral_gc_list[0]) > 0:
                        ax6 = fig.add_subplot(gs[2, :])
                        
                        # Average spectral GC across segments
                        band_names = list(spectral_gc_list[0].keys())
                        band_means = []
                        band_stds = []
                        
                        for band in band_names:
                            band_values = [np.mean(seg[band]['gc_matrix']) for seg in spectral_gc_list if band in seg]
                            band_means.append(np.mean(band_values))
                            band_stds.append(np.std(band_values))
                        
                        x_pos = np.arange(len(band_names))
                        ax6.bar(x_pos, band_means, yerr=band_stds, color='mediumpurple', 
                               edgecolor='black', alpha=0.7, capsize=5)
                        ax6.set_xticks(x_pos)
                        ax6.set_xticklabels(band_names)
                        ax6.set_xlabel('Frequency Band')
                        ax6.set_ylabel('Mean GC')
                        ax6.set_title('Spectral GC by Frequency Band')
                        ax6.grid(axis='y', alpha=0.3)
                    else:
                        ax6 = fig.add_subplot(gs[2, :])
                        ax6.text(0.5, 0.5, 'No spectral GC computed', 
                                ha='center', va='center', transform=ax6.transAxes)
                    
                    # Save figure
                    gc_plot_path = OUTPUT_DIRS['plots'] / f'{test_subject}_{test_session}_granger_causality.png'
                    plt.savefig(gc_plot_path, dpi=150, bbox_inches='tight')
                    print(f"   ✓ Saved GC plot to: {gc_plot_path}")
                    
                    plt.show()
                    
                except Exception as e:
                    print(f"   ✗ Could not create visualization: {e}")
                    import traceback
                    traceback.print_exc()
        
        # Clean up
        del raw, raw_prep, segments

print("\n" + "="*80)
print("GRANGER CAUSALITY COMPUTATION TESTING COMPLETE")
print("="*80)

2025-10-25 17:31:50,930 - INFO - Loaded sub-NORB00001/ses-1: 21 channels, 714.0s @ 200.0Hz
2025-10-25 17:31:50,979 - INFO - Applying bandpass filter: 0.5-30.0 Hz
2025-10-25 17:31:50,979 - INFO - Applying bandpass filter: 0.5-30.0 Hz


Testing Granger causality computation on sample subject...

Processing: sub-NORB00001/ses-1

1. Loading and preprocessing data...
✓ Loaded: 21 channels


2025-10-25 17:31:51,121 - INFO - Applying notch filter at 60.0 Hz
2025-10-25 17:31:51,251 - INFO - Re-referencing to: average
2025-10-25 17:31:51,251 - INFO - Re-referencing to: average
2025-10-25 17:31:51,279 - INFO - Preprocessing completed successfully: bandpass_filter, notch_filter, reference_average
2025-10-25 17:31:51,282 - INFO - Loaded events for sub-NORB00001/ses-1: 3 events
2025-10-25 17:31:51,279 - INFO - Preprocessing completed successfully: bandpass_filter, notch_filter, reference_average
2025-10-25 17:31:51,282 - INFO - Loaded events for sub-NORB00001/ses-1: 3 events
2025-10-25 17:31:51,315 - INFO - Extracted 127 segments (window: 10.0s, overlap: 50%)
2025-10-25 17:31:51,316 - INFO - Testing VAR model orders from 1 to 50...
2025-10-25 17:31:51,315 - INFO - Extracted 127 segments (window: 10.0s, overlap: 50%)
2025-10-25 17:31:51,316 - INFO - Testing VAR model orders from 1 to 50...


✓ Preprocessing completed

2. Segmenting data...
✓ Extracted 127 segments

3. Computing Granger causality...
   Testing on 5 segments

   Segment 1:


2025-10-25 17:31:58,226 - INFO - Optimal model order selected: 50 (method: aic)
2025-10-25 17:31:58,234 - INFO - Computing pairwise GC for 21 channels (order=50)...
2025-10-25 17:31:58,226 - INFO - Optimal model order selected: 50 (method: aic)
2025-10-25 17:31:58,234 - INFO - Computing pairwise GC for 21 channels (order=50)...


     ✓ Selected order: 50
     - Computing pairwise GC...


2025-10-25 17:32:08,713 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:32:08,714 - INFO - Computing spectral GC for 4 frequency bands...
2025-10-25 17:32:08,714 - INFO - Computing spectral GC for 4 frequency bands...


       ✓ Computed: 420 pairs
       - Significant connections: 414/420 (98.6%)
       - Mean GC value: 6.1093
     - Computing spectral GC...


2025-10-25 17:32:17,756 - INFO -   ✓ delta band: 0.5-4 Hz
2025-10-25 17:32:23,181 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:32:23,181 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:32:29,870 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:32:29,870 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:32:35,379 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:32:35,380 - INFO - Testing VAR model orders from 1 to 50...
2025-10-25 17:32:35,379 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:32:35,380 - INFO - Testing VAR model orders from 1 to 50...


       ✓ Computed for 4 bands
         - delta: mean GC = 23.3114
         - theta: mean GC = 26.6416
         - alpha: mean GC = 33.3020
         - beta: mean GC = 113.2267

   Segment 2:


2025-10-25 17:32:40,673 - INFO - Optimal model order selected: 42 (method: aic)
2025-10-25 17:32:40,674 - INFO - Computing pairwise GC for 21 channels (order=42)...
2025-10-25 17:32:40,673 - INFO - Optimal model order selected: 42 (method: aic)
2025-10-25 17:32:40,674 - INFO - Computing pairwise GC for 21 channels (order=42)...


     ✓ Selected order: 42
     - Computing pairwise GC...


2025-10-25 17:32:46,363 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:32:46,365 - INFO - Computing spectral GC for 4 frequency bands...
2025-10-25 17:32:46,365 - INFO - Computing spectral GC for 4 frequency bands...


       ✓ Computed: 420 pairs
       - Significant connections: 409/420 (97.4%)
       - Mean GC value: 4.0310
     - Computing spectral GC...


2025-10-25 17:32:50,924 - INFO -   ✓ delta band: 0.5-4 Hz
2025-10-25 17:32:55,403 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:32:55,403 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:32:59,902 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:32:59,902 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:33:03,703 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:33:03,711 - INFO - Testing VAR model orders from 1 to 50...
2025-10-25 17:33:03,703 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:33:03,711 - INFO - Testing VAR model orders from 1 to 50...


       ✓ Computed for 4 bands
         - delta: mean GC = 20.1836
         - theta: mean GC = 23.0669
         - alpha: mean GC = 28.8337
         - beta: mean GC = 98.0345

   Segment 3:


2025-10-25 17:33:09,218 - INFO - Optimal model order selected: 50 (method: aic)
2025-10-25 17:33:09,222 - INFO - Computing pairwise GC for 21 channels (order=50)...
2025-10-25 17:33:09,218 - INFO - Optimal model order selected: 50 (method: aic)
2025-10-25 17:33:09,222 - INFO - Computing pairwise GC for 21 channels (order=50)...


     ✓ Selected order: 50
     - Computing pairwise GC...


2025-10-25 17:33:21,679 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:33:21,681 - INFO - Computing spectral GC for 4 frequency bands...
2025-10-25 17:33:21,681 - INFO - Computing spectral GC for 4 frequency bands...


       ✓ Computed: 420 pairs
       - Significant connections: 403/420 (96.0%)
       - Mean GC value: 3.2713
     - Computing spectral GC...


2025-10-25 17:33:28,793 - INFO -   ✓ delta band: 0.5-4 Hz
2025-10-25 17:33:34,282 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:33:34,282 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:33:40,410 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:33:40,410 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:33:45,733 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:33:45,743 - INFO - Testing VAR model orders from 1 to 50...
2025-10-25 17:33:45,733 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:33:45,743 - INFO - Testing VAR model orders from 1 to 50...


       ✓ Computed for 4 bands
         - delta: mean GC = 21.3365
         - theta: mean GC = 24.3846
         - alpha: mean GC = 30.4808
         - beta: mean GC = 103.6347

   Segment 4:


2025-10-25 17:33:51,025 - INFO - Optimal model order selected: 48 (method: aic)
2025-10-25 17:33:51,027 - INFO - Computing pairwise GC for 21 channels (order=48)...
2025-10-25 17:33:51,025 - INFO - Optimal model order selected: 48 (method: aic)
2025-10-25 17:33:51,027 - INFO - Computing pairwise GC for 21 channels (order=48)...


     ✓ Selected order: 48
     - Computing pairwise GC...


2025-10-25 17:33:57,783 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:33:57,786 - INFO - Computing spectral GC for 4 frequency bands...
2025-10-25 17:33:57,786 - INFO - Computing spectral GC for 4 frequency bands...


       ✓ Computed: 420 pairs
       - Significant connections: 415/420 (98.8%)
       - Mean GC value: 5.9457
     - Computing spectral GC...


2025-10-25 17:34:04,155 - INFO -   ✓ delta band: 0.5-4 Hz
2025-10-25 17:34:10,521 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:34:10,521 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:34:15,396 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:34:15,396 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:34:21,615 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:34:21,618 - INFO - Testing VAR model orders from 1 to 50...
2025-10-25 17:34:21,615 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:34:21,618 - INFO - Testing VAR model orders from 1 to 50...


       ✓ Computed for 4 bands
         - delta: mean GC = 21.5094
         - theta: mean GC = 24.5822
         - alpha: mean GC = 30.7277
         - beta: mean GC = 104.4742

   Segment 5:


2025-10-25 17:34:26,835 - INFO - Optimal model order selected: 49 (method: aic)
2025-10-25 17:34:26,837 - INFO - Computing pairwise GC for 21 channels (order=49)...
2025-10-25 17:34:26,835 - INFO - Optimal model order selected: 49 (method: aic)
2025-10-25 17:34:26,837 - INFO - Computing pairwise GC for 21 channels (order=49)...


     ✓ Selected order: 49
     - Computing pairwise GC...


2025-10-25 17:34:34,087 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:34:34,088 - INFO - Computing spectral GC for 4 frequency bands...
2025-10-25 17:34:34,088 - INFO - Computing spectral GC for 4 frequency bands...


       ✓ Computed: 420 pairs
       - Significant connections: 412/420 (98.1%)
       - Mean GC value: 5.7727
     - Computing spectral GC...


2025-10-25 17:34:42,183 - INFO -   ✓ delta band: 0.5-4 Hz
2025-10-25 17:34:50,791 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:34:50,791 - INFO -   ✓ theta band: 4-8 Hz
2025-10-25 17:34:55,793 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:34:55,793 - INFO -   ✓ alpha band: 8-13 Hz
2025-10-25 17:35:02,098 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:35:02,101 - INFO - Averaged 5 GC matrices (method: mean)
2025-10-25 17:35:02,098 - INFO -   ✓ beta band: 13-30 Hz
2025-10-25 17:35:02,101 - INFO - Averaged 5 GC matrices (method: mean)


       ✓ Computed for 4 bands
         - delta: mean GC = 22.0028
         - theta: mean GC = 25.1460
         - alpha: mean GC = 31.4325
         - beta: mean GC = 106.8707

4. Averaging across 5 segments...
   ✓ Averaged GC matrix computed
   - Mean GC: 4.7867 ± 3.2376
   - Max GC: 21.1928
   - Min GC: 0.0000
   - Significant connections: 399

5. Creating visualizations...
   ✓ Saved GC plot to: /home/alookaladdoo/DPCN-Project/results/plots/sub-NORB00001_ses-1_granger_causality.png

GRANGER CAUSALITY COMPUTATION TESTING COMPLETE
   ✓ Saved GC plot to: /home/alookaladdoo/DPCN-Project/results/plots/sub-NORB00001_ses-1_granger_causality.png

GRANGER CAUSALITY COMPUTATION TESTING COMPLETE


  plt.show()


---

## Summary: Step 7 Complete ✅

**Granger Causality Computation** has been implemented with:

1. **`compute_pairwise_gc()`**: 
   - Computes GC for all channel pairs using bivariate VAR models
   - Returns GC F-statistics and p-values matrices
   - Format: gc_matrix[i, j] = causal influence from j → i
   - Handles failed pairs gracefully

2. **`compute_spectral_gc()`**:
   - Computes frequency-band specific GC
   - Supports delta, theta, alpha, beta bands
   - Returns separate GC matrices for each band
   - Simplified spectral decomposition approach

3. **`average_gc_across_segments()`**:
   - Averages GC matrices from multiple segments
   - Supports mean, median, and weighted averaging
   - Combines p-values across segments
   - Produces session-level GC matrix

4. **Comprehensive Visualization**:
   - GC connectivity matrix heatmap
   - Significance map (-log10 p-values)
   - Distribution of GC values
   - Out-degree and in-degree analysis
   - Spectral GC by frequency band

**Key Outputs**:
- Pairwise GC matrices (n_channels × n_channels)
- P-value matrices for significance testing
- Band-specific spectral GC matrices
- Session-averaged GC matrix
- Connectivity visualizations

## Step 8: Statistical Testing

Functions for permutation testing, FDR correction, and thresholding GC matrices.

In [62]:
# ============================================================================
# PERMUTATION TEST FOR GC SIGNIFICANCE
# ============================================================================

def permutation_test_gc(segment_data, model_order, n_permutations=100, random_seed=None):
    """
    Perform permutation test to establish significance threshold for GC.
    
    This creates a null distribution by randomly shuffling time series 
    and computing GC on the shuffled data.
    
    Parameters:
    -----------
    segment_data : np.ndarray
        Data segment (n_channels x n_timepoints)
    model_order : int
        VAR model order to use
    n_permutations : int
        Number of permutations (default: 100)
    random_seed : int, optional
        Random seed for reproducibility
        
    Returns:
    --------
    null_distribution : np.ndarray
        Array of GC values under null hypothesis
    threshold_95 : float
        95th percentile threshold
    threshold_99 : float
        99th percentile threshold
    perm_info : dict
        Information about the permutation test
    """
    
    if random_seed is not None:
        np.random.seed(random_seed)
    
    n_channels = segment_data.shape[0]
    
    perm_info = {
        'success': False,
        'n_permutations': n_permutations,
        'n_channels': n_channels,
        'model_order': model_order,
        'n_failed': 0,
        'error': None
    }
    
    try:
        logger.info(f"Running permutation test with {n_permutations} permutations...")
        
        null_gc_values = []
        
        for perm_idx in range(n_permutations):
            try:
                # Create shuffled data by randomly permuting each channel independently
                shuffled_data = np.zeros_like(segment_data)
                for ch_idx in range(n_channels):
                    # Random circular shift for each channel
                    shift = np.random.randint(0, segment_data.shape[1])
                    shuffled_data[ch_idx, :] = np.roll(segment_data[ch_idx, :], shift)
                
                # Compute GC on shuffled data
                gc_matrix_null, _, gc_info_null = compute_pairwise_gc(shuffled_data, model_order)
                
                if gc_info_null['success']:
                    # Store all non-zero GC values from null distribution
                    null_values = gc_matrix_null[gc_matrix_null > 0]
                    null_gc_values.extend(null_values)
                else:
                    perm_info['n_failed'] += 1
                    
            except Exception as e:
                perm_info['n_failed'] += 1
                logger.debug(f"Permutation {perm_idx} failed: {e}")
        
        if len(null_gc_values) == 0:
            perm_info['error'] = "No valid GC values in null distribution"
            logger.error(perm_info['error'])
            return None, None, None, perm_info
        
        # Convert to array
        null_distribution = np.array(null_gc_values)
        
        # Compute thresholds
        threshold_95 = np.percentile(null_distribution, 95)
        threshold_99 = np.percentile(null_distribution, 99)
        
        perm_info['success'] = True
        perm_info['null_mean'] = np.mean(null_distribution)
        perm_info['null_std'] = np.std(null_distribution)
        perm_info['threshold_95'] = threshold_95
        perm_info['threshold_99'] = threshold_99
        
        logger.info(f"Permutation test complete: threshold_95={threshold_95:.4f}, threshold_99={threshold_99:.4f}")
        
        return null_distribution, threshold_95, threshold_99, perm_info
        
    except Exception as e:
        perm_info['error'] = str(e)
        logger.error(f"Error in permutation test: {e}")
        return None, None, None, perm_info


print("✓ permutation_test_gc() function defined")

✓ permutation_test_gc() function defined


In [63]:
# ============================================================================
# FDR CORRECTION FOR MULTIPLE COMPARISONS
# ============================================================================

def apply_fdr_correction(gc_matrix, gc_pvalues, alpha=0.05, method='fdr_bh'):
    """
    Apply False Discovery Rate correction for multiple comparisons.
    
    Parameters:
    -----------
    gc_matrix : np.ndarray
        GC matrix (n_channels x n_channels)
    gc_pvalues : np.ndarray
        P-values matrix (n_channels x n_channels)
    alpha : float
        Significance level (default: 0.05)
    method : str
        Correction method: 'fdr_bh' (Benjamini-Hochberg) or 'bonferroni'
        
    Returns:
    --------
    corrected_pvalues : np.ndarray
        Corrected p-values
    significant_mask : np.ndarray
        Boolean mask of significant connections
    fdr_info : dict
        Information about the correction
    """
    
    fdr_info = {
        'success': False,
        'method': method,
        'alpha': alpha,
        'n_tests': 0,
        'n_significant_uncorrected': 0,
        'n_significant_corrected': 0,
        'error': None
    }
    
    try:
        n_channels = gc_matrix.shape[0]
        
        # Flatten matrices (exclude diagonal)
        mask = ~np.eye(n_channels, dtype=bool)
        pvalues_flat = gc_pvalues[mask]
        
        fdr_info['n_tests'] = len(pvalues_flat)
        fdr_info['n_significant_uncorrected'] = np.sum(pvalues_flat < alpha)
        
        # Apply correction
        if method == 'fdr_bh':
            # Benjamini-Hochberg FDR correction
            reject, pvals_corrected, _, _ = multipletests(pvalues_flat, alpha=alpha, method='fdr_bh')
        elif method == 'bonferroni':
            # Bonferroni correction (more conservative)
            reject, pvals_corrected, _, _ = multipletests(pvalues_flat, alpha=alpha, method='bonferroni')
        else:
            logger.warning(f"Unknown method '{method}', using fdr_bh")
            reject, pvals_corrected, _, _ = multipletests(pvalues_flat, alpha=alpha, method='fdr_bh')
        
        # Reshape back to matrix form
        corrected_pvalues = np.ones_like(gc_pvalues)
        corrected_pvalues[mask] = pvals_corrected
        
        significant_mask = np.zeros_like(gc_matrix, dtype=bool)
        significant_mask[mask] = reject
        
        fdr_info['n_significant_corrected'] = np.sum(significant_mask)
        fdr_info['success'] = True
        
        logger.info(f"FDR correction ({method}): {fdr_info['n_significant_uncorrected']} → "
                   f"{fdr_info['n_significant_corrected']} significant connections")
        
        return corrected_pvalues, significant_mask, fdr_info
        
    except Exception as e:
        fdr_info['error'] = str(e)
        logger.error(f"Error in FDR correction: {e}")
        return None, None, fdr_info


print("✓ apply_fdr_correction() function defined")

✓ apply_fdr_correction() function defined


In [64]:
# ============================================================================
# THRESHOLD GC MATRIX
# ============================================================================

def threshold_gc_matrix(gc_matrix, gc_pvalues=None, threshold_value=None, 
                       significant_mask=None, method='pvalue'):
    """
    Threshold GC matrix to keep only significant connections.
    
    Parameters:
    -----------
    gc_matrix : np.ndarray
        GC matrix (n_channels x n_channels)
    gc_pvalues : np.ndarray, optional
        P-values matrix (for method='pvalue')
    threshold_value : float, optional
        Threshold value (p-value threshold or GC value threshold)
    significant_mask : np.ndarray, optional
        Pre-computed significance mask (for method='mask')
    method : str
        Thresholding method: 'pvalue', 'percentile', 'absolute', or 'mask'
        
    Returns:
    --------
    thresholded_matrix : np.ndarray
        Thresholded GC matrix (non-significant connections set to 0)
    threshold_info : dict
        Information about thresholding
    """
    
    threshold_info = {
        'success': False,
        'method': method,
        'threshold_value': threshold_value,
        'n_original': 0,
        'n_surviving': 0,
        'fraction_surviving': 0.0,
        'error': None
    }
    
    try:
        thresholded_matrix = gc_matrix.copy()
        n_channels = gc_matrix.shape[0]
        
        # Count original non-zero connections (exclude diagonal)
        mask_nondiag = ~np.eye(n_channels, dtype=bool)
        threshold_info['n_original'] = np.sum((gc_matrix[mask_nondiag] > 0))
        
        if method == 'pvalue':
            # Threshold based on p-values
            if gc_pvalues is None:
                threshold_info['error'] = "P-values required for method='pvalue'"
                return None, threshold_info
            
            if threshold_value is None:
                threshold_value = 0.05
            
            # Keep only significant connections
            thresholded_matrix[gc_pvalues >= threshold_value] = 0.0
            
        elif method == 'percentile':
            # Keep top percentile of connections
            if threshold_value is None:
                threshold_value = 95  # Keep top 5%
            
            # Get threshold value at percentile
            gc_values = gc_matrix[mask_nondiag]
            gc_threshold = np.percentile(gc_values, threshold_value)
            
            # Threshold
            thresholded_matrix[gc_matrix < gc_threshold] = 0.0
            threshold_info['gc_threshold'] = gc_threshold
            
        elif method == 'absolute':
            # Threshold based on absolute GC value
            if threshold_value is None:
                threshold_value = 1.0
            
            thresholded_matrix[gc_matrix < threshold_value] = 0.0
            
        elif method == 'mask':
            # Use pre-computed significance mask
            if significant_mask is None:
                threshold_info['error'] = "Significance mask required for method='mask'"
                return None, threshold_info
            
            thresholded_matrix[~significant_mask] = 0.0
            
        else:
            threshold_info['error'] = f"Unknown method: {method}"
            return None, threshold_info
        
        # Set diagonal to 0
        np.fill_diagonal(thresholded_matrix, 0.0)
        
        # Count surviving connections
        threshold_info['n_surviving'] = np.sum((thresholded_matrix[mask_nondiag] > 0))
        threshold_info['fraction_surviving'] = threshold_info['n_surviving'] / max(threshold_info['n_original'], 1)
        threshold_info['success'] = True
        
        logger.info(f"Thresholding ({method}): {threshold_info['n_original']} → "
                   f"{threshold_info['n_surviving']} connections "
                   f"({threshold_info['fraction_surviving']*100:.1f}% surviving)")
        
        return thresholded_matrix, threshold_info
        
    except Exception as e:
        threshold_info['error'] = str(e)
        logger.error(f"Error in thresholding: {e}")
        return None, threshold_info


print("✓ threshold_gc_matrix() function defined")

✓ threshold_gc_matrix() function defined


### Test Statistical Testing Functions

Test permutation testing, FDR correction, and thresholding on sample data.

In [65]:
# ============================================================================
# TEST STATISTICAL TESTING ON SAMPLE SUBJECT
# ============================================================================

print("Testing statistical testing functions on sample subject...")
print("="*80)

# Use the existing avg_gc_matrix and avg_pvalues from previous step
# If not available, we'll need to recompute

if 'avg_gc_matrix' not in locals() or avg_gc_matrix is None:
    print("✗ No GC matrix available from previous step. Please run Step 7 first.")
else:
    print(f"\nUsing GC matrix from previous computation")
    print(f"Shape: {avg_gc_matrix.shape}")
    print(f"Mean GC: {np.mean(avg_gc_matrix):.4f}")
    print("="*80)
    
    # Get first segment for permutation test
    print("\n1. Loading data for permutation test...")
    valid_subject = inventory_df[inventory_df['file_exists']].iloc[0]
    test_subject = valid_subject['subject_id']
    test_session = valid_subject['session_id']
    
    raw, metadata = load_eeg_data(test_subject, test_session, preload=True)
    raw_prep, _ = preprocess_eeg(raw)
    events_df, annotations_df = load_events_and_annotations(test_subject, test_session)
    segments, segment_info = segment_data(raw_prep, events_df, annotations_df)
    
    if len(segments) > 0:
        print(f"✓ Loaded {len(segments)} segments")
        
        # Select model order for first segment
        order_info = select_model_order(segments[0])
        optimal_order = order_info['optimal_order']
        
        # Step 1: Permutation test
        print(f"\n2. Running permutation test...")
        print(f"   Using {STAT_PARAMS['n_permutations']} permutations (this may take a few minutes)...")
        
        null_dist, thresh_95, thresh_99, perm_info = permutation_test_gc(
            segments[0], 
            optimal_order, 
            n_permutations=100,  # Use fewer permutations for testing (faster)
            random_seed=STAT_PARAMS['random_seed']
        )
        
        if perm_info['success']:
            print(f"   ✓ Permutation test complete")
            print(f"   - Null distribution: mean={perm_info['null_mean']:.4f}, std={perm_info['null_std']:.4f}")
            print(f"   - 95th percentile threshold: {thresh_95:.4f}")
            print(f"   - 99th percentile threshold: {thresh_99:.4f}")
            print(f"   - Failed permutations: {perm_info['n_failed']}/100")
        else:
            print(f"   ✗ Permutation test failed: {perm_info.get('error', 'Unknown')}")
            thresh_95 = None
        
        # Step 2: FDR correction
        print(f"\n3. Applying FDR correction...")
        
        if avg_pvalues is not None:
            corrected_pvals, sig_mask, fdr_info = apply_fdr_correction(
                avg_gc_matrix, 
                avg_pvalues, 
                alpha=STAT_PARAMS['significance_threshold'],
                method=STAT_PARAMS['correction_method']
            )
            
            if fdr_info['success']:
                print(f"   ✓ FDR correction complete")
                print(f"   - Method: {fdr_info['method']}")
                print(f"   - Total tests: {fdr_info['n_tests']}")
                print(f"   - Significant (uncorrected): {fdr_info['n_significant_uncorrected']}")
                print(f"   - Significant (corrected): {fdr_info['n_significant_corrected']}")
                print(f"   - Correction rate: {fdr_info['n_significant_corrected']/max(fdr_info['n_significant_uncorrected'],1)*100:.1f}%")
            else:
                print(f"   ✗ FDR correction failed")
                sig_mask = None
        else:
            print(f"   ⚠ No p-values available, skipping FDR correction")
            sig_mask = None
        
        # Step 3: Thresholding
        print(f"\n4. Thresholding GC matrix...")
        
        # Method 1: P-value based thresholding
        if avg_pvalues is not None:
            thresh_pval, thresh_info_pval = threshold_gc_matrix(
                avg_gc_matrix, 
                gc_pvalues=avg_pvalues,
                threshold_value=STAT_PARAMS['significance_threshold'],
                method='pvalue'
            )
            
            if thresh_info_pval['success']:
                print(f"   ✓ P-value thresholding (p < {STAT_PARAMS['significance_threshold']})")
                print(f"     - Original connections: {thresh_info_pval['n_original']}")
                print(f"     - Surviving connections: {thresh_info_pval['n_surviving']}")
                print(f"     - Survival rate: {thresh_info_pval['fraction_surviving']*100:.1f}%")
        
        # Method 2: FDR-based thresholding (using significance mask)
        if sig_mask is not None:
            thresh_fdr, thresh_info_fdr = threshold_gc_matrix(
                avg_gc_matrix,
                significant_mask=sig_mask,
                method='mask'
            )
            
            if thresh_info_fdr['success']:
                print(f"   ✓ FDR-based thresholding")
                print(f"     - Surviving connections: {thresh_info_fdr['n_surviving']}")
                print(f"     - Survival rate: {thresh_info_fdr['fraction_surviving']*100:.1f}%")
        else:
            thresh_fdr = None
        
        # Method 3: Permutation-based thresholding
        if thresh_95 is not None:
            thresh_perm, thresh_info_perm = threshold_gc_matrix(
                avg_gc_matrix,
                threshold_value=thresh_95,
                method='absolute'
            )
            
            if thresh_info_perm['success']:
                print(f"   ✓ Permutation-based thresholding (95th percentile)")
                print(f"     - Threshold: {thresh_95:.4f}")
                print(f"     - Surviving connections: {thresh_info_perm['n_surviving']}")
                print(f"     - Survival rate: {thresh_info_perm['fraction_surviving']*100:.1f}%")
        else:
            thresh_perm = None
        
        # Step 4: Visualize results
        print(f"\n5. Creating comparison visualizations...")
        
        try:
            # Determine how many thresholded matrices we have
            thresh_matrices = []
            thresh_titles = []
            
            thresh_matrices.append(avg_gc_matrix)
            thresh_titles.append('Original GC Matrix')
            
            if 'thresh_pval' in locals() and thresh_pval is not None:
                thresh_matrices.append(thresh_pval)
                thresh_titles.append(f'P-value Threshold (p<{STAT_PARAMS["significance_threshold"]})')
            
            if thresh_fdr is not None:
                thresh_matrices.append(thresh_fdr)
                thresh_titles.append(f'FDR Corrected ({STAT_PARAMS["correction_method"]})')
            
            if thresh_perm is not None:
                thresh_matrices.append(thresh_perm)
                thresh_titles.append('Permutation Threshold (95%)')
            
            n_plots = len(thresh_matrices)
            
            # Create figure
            fig, axes = plt.subplots(2, 2, figsize=(14, 12))
            fig.suptitle(f'Statistical Testing Comparison - {test_subject}/{test_session}', 
                        fontsize=16, fontweight='bold')
            
            axes_flat = axes.flatten()
            
            # Plot each thresholded matrix
            for idx, (matrix, title) in enumerate(zip(thresh_matrices[:4], thresh_titles[:4])):
                ax = axes_flat[idx]
                
                im = ax.imshow(matrix, cmap='viridis', aspect='auto', interpolation='nearest')
                ax.set_title(title)
                ax.set_xlabel('Source Channel')
                ax.set_ylabel('Target Channel')
                plt.colorbar(im, ax=ax, label='GC F-statistic')
                
                # Add text with connection count
                n_connections = np.sum(matrix > 0) - matrix.shape[0]  # Exclude diagonal
                ax.text(0.02, 0.98, f'Connections: {n_connections}', 
                       transform=ax.transAxes, fontsize=10, 
                       verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            
            # Hide unused subplots
            for idx in range(n_plots, 4):
                axes_flat[idx].axis('off')
            
            plt.tight_layout()
            
            # Save figure
            stat_plot_path = OUTPUT_DIRS['plots'] / f'{test_subject}_{test_session}_statistical_testing.png'
            plt.savefig(stat_plot_path, dpi=150, bbox_inches='tight')
            print(f"   ✓ Saved statistical testing plot to: {stat_plot_path}")
            
            plt.show()
            
        except Exception as e:
            print(f"   ✗ Could not create visualization: {e}")
            import traceback
            traceback.print_exc()
        
        # Additional visualization: Null distribution
        if null_dist is not None:
            print(f"\n6. Visualizing null distribution...")
            
            try:
                fig, axes = plt.subplots(1, 2, figsize=(14, 5))
                
                # Plot 1: Null distribution histogram
                ax = axes[0]
                ax.hist(null_dist, bins=50, color='lightgray', edgecolor='black', alpha=0.7, density=True)
                
                # Add vertical lines for thresholds
                if thresh_95 is not None:
                    ax.axvline(thresh_95, color='orange', linestyle='--', linewidth=2, 
                              label=f'95th percentile: {thresh_95:.4f}')
                if thresh_99 is not None:
                    ax.axvline(thresh_99, color='red', linestyle='--', linewidth=2,
                              label=f'99th percentile: {thresh_99:.4f}')
                
                # Add observed GC distribution
                observed_gc = avg_gc_matrix[avg_gc_matrix > 0]
                ax.hist(observed_gc, bins=50, color='steelblue', edgecolor='black', 
                       alpha=0.5, density=True, label='Observed GC')
                
                ax.set_xlabel('GC F-statistic')
                ax.set_ylabel('Density')
                ax.set_title('Null Distribution vs Observed GC')
                ax.legend()
                ax.grid(True, alpha=0.3)
                
                # Plot 2: Q-Q plot
                ax = axes[1]
                from scipy import stats as scipy_stats
                scipy_stats.probplot(null_dist, dist="norm", plot=ax)
                ax.set_title('Q-Q Plot of Null Distribution')
                ax.grid(True, alpha=0.3)
                
                plt.tight_layout()
                
                # Save
                null_plot_path = OUTPUT_DIRS['plots'] / f'{test_subject}_{test_session}_null_distribution.png'
                plt.savefig(null_plot_path, dpi=150, bbox_inches='tight')
                print(f"   ✓ Saved null distribution plot to: {null_plot_path}")
                
                plt.show()
                
            except Exception as e:
                print(f"   ✗ Could not create null distribution plot: {e}")
        
        # Clean up
        del raw, raw_prep, segments
    
    else:
        print("✗ No segments available")

print("\n" + "="*80)
print("STATISTICAL TESTING COMPLETE")
print("="*80)

2025-10-25 17:42:17,877 - INFO - Loaded sub-NORB00001/ses-1: 21 channels, 714.0s @ 200.0Hz
2025-10-25 17:42:17,895 - INFO - Applying bandpass filter: 0.5-30.0 Hz
2025-10-25 17:42:17,975 - INFO - Applying notch filter at 60.0 Hz
2025-10-25 17:42:18,057 - INFO - Re-referencing to: average


Testing statistical testing functions on sample subject...

Using GC matrix from previous computation
Shape: (21, 21)
Mean GC: 4.7867

1. Loading data for permutation test...


2025-10-25 17:42:18,075 - INFO - Preprocessing completed successfully: bandpass_filter, notch_filter, reference_average
2025-10-25 17:42:18,079 - INFO - Loaded events for sub-NORB00001/ses-1: 3 events
2025-10-25 17:42:18,092 - INFO - Extracted 127 segments (window: 10.0s, overlap: 50%)
2025-10-25 17:42:18,094 - INFO - Testing VAR model orders from 1 to 50...


✓ Loaded 127 segments


2025-10-25 17:42:23,102 - INFO - Optimal model order selected: 50 (method: aic)
2025-10-25 17:42:23,107 - INFO - Running permutation test with 100 permutations...
2025-10-25 17:42:23,137 - INFO - Computing pairwise GC for 21 channels (order=50)...



2. Running permutation test...
   Using 1000 permutations (this may take a few minutes)...


2025-10-25 17:42:34,000 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:42:34,002 - INFO - Computing pairwise GC for 21 channels (order=50)...
2025-10-25 17:42:46,194 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:42:46,195 - INFO - Computing pairwise GC for 21 channels (order=50)...
2025-10-25 17:42:57,917 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:42:57,918 - INFO - Computing pairwise GC for 21 channels (order=50)...
2025-10-25 17:43:09,737 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:43:09,740 - INFO - Computing pairwise GC for 21 channels (order=50)...
2025-10-25 17:43:18,355 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:43:18,359 - INFO - Computing pairwise GC for 21 channels (order=50)...
2025-10-25 17:43:26,291 - INFO - Pairwise GC computed: 420 pairs successful, 0 failed
2025-10-25 17:43:26,292 - INFO - Computing pairwise GC for 

   ✓ Permutation test complete
   - Null distribution: mean=5747.9174, std=76593.4049
   - 95th percentile threshold: 1.3600
   - 99th percentile threshold: 92682.7143
   - Failed permutations: 0/100

3. Applying FDR correction...
   ✓ FDR correction complete
   - Method: fdr_bh
   - Total tests: 420
   - Significant (uncorrected): 399
   - Significant (corrected): 398
   - Correction rate: 99.7%

4. Thresholding GC matrix...
   ✓ P-value thresholding (p < 0.05)
     - Original connections: 420
     - Surviving connections: 399
     - Survival rate: 95.0%
   ✓ FDR-based thresholding
     - Surviving connections: 398
     - Survival rate: 94.8%
   ✓ Permutation-based thresholding (95th percentile)
     - Threshold: 1.3600
     - Surviving connections: 418
     - Survival rate: 99.5%

5. Creating comparison visualizations...
   ✓ Saved statistical testing plot to: /home/alookaladdoo/DPCN-Project/results/plots/sub-NORB00001_ses-1_statistical_testing.png

6. Visualizing null distribution..

  plt.show()


   ✓ Saved null distribution plot to: /home/alookaladdoo/DPCN-Project/results/plots/sub-NORB00001_ses-1_null_distribution.png

STATISTICAL TESTING COMPLETE


  plt.show()


## Step 8 Summary: Statistical Testing

We have successfully implemented the complete statistical testing framework for Granger causality analysis:

### Functions Implemented

1. **`permutation_test_gc()`**
   - Establishes significance thresholds via null distribution
   - Uses random circular shifts to preserve temporal structure
   - Returns 95th and 99th percentile thresholds
   - Configurable number of permutations (default: 1000)

2. **`apply_fdr_correction()`**
   - Controls false discovery rate in multiple comparisons
   - Supports FDR-BH (Benjamini-Hochberg) and Bonferroni methods
   - Returns corrected p-values and significance mask
   - Reports reduction in significant connections

3. **`threshold_gc_matrix()`**
   - Zeros out non-significant connections
   - Four thresholding methods:
     - P-value based (e.g., p < 0.05)
     - FDR corrected (using significance mask)
     - Permutation-based (absolute GC value)
     - Custom mask
   - Returns thresholded matrix with statistics

### Testing Results

The test demonstrated all three functions on sample data from the first subject:

- **Permutation test**: Established null distribution and significance thresholds
- **FDR correction**: Applied multiple comparison correction to p-values
- **Thresholding**: Compared different thresholding strategies side-by-side

### Visualizations Created

1. **Statistical Testing Comparison**: 4-panel visualization showing original vs thresholded GC matrices
2. **Null Distribution**: Histogram of permutation-based null distribution with thresholds

---