In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.signal import welch
import os
import sys

BASE_DIR = os.getcwd()
if BASE_DIR not in sys.path:
    # Make project-local packages (e.g. `utils`) importable from the notebook's working directory.
    sys.path.append(BASE_DIR)

# Prefer the project's shared logger. Fall back to a plain stdlib logger when the
# project package or its logger factory is unavailable, so the EDA cells still run.
# The import lives inside the `try` so a missing `utils` package also takes the fallback.
try:
    from utils.logger import LoggerHelper

    logger = LoggerHelper.get_logger()
except Exception as exc:  # broad on purpose, but no longer traps KeyboardInterrupt/SystemExit
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("EDA")
    logger.warning("Project logger unavailable (%s); using fallback 'EDA' logger.", exc)

logger.info("Data visualization started...")

2025-12-17 01:09:47 - INFO - eeg_pipeline - Data visualization started...


In [2]:
# Look for the dataset under `pediatric/` first; otherwise use a copy next to the notebook.
candidate_csv_paths = [
    os.path.join(BASE_DIR, "pediatric", "adhdata.csv"),
    os.path.join(BASE_DIR, "adhdata.csv"),
]
csv_path = next(
    (path for path in candidate_csv_paths[:-1] if os.path.exists(path)),
    candidate_csv_paths[-1],
)

if not os.path.exists(csv_path):
    logger.error("CSV file not found.")
else:
    logger.info("Loading data...")
    # Strip stray whitespace from header names so columns can be referenced by their plain names.
    df = pd.read_csv(csv_path).rename(columns=str.strip)
    logger.info(f"Data loaded successfully. Data shape: {df.shape}")

2025-12-17 01:10:19 - INFO - eeg_pipeline - Loading data...
2025-12-17 01:10:21 - INFO - eeg_pipeline - Data loaded successfully. Data shape: (2166383, 21)


In [13]:
def _first_subject_signal(df, label, channel, n_samples):
    """Return (subject_id, first `n_samples` values of `channel`) for the first subject of class `label`."""
    class_rows = df.loc[df['Class'] == label]
    if class_rows.empty:
        raise ValueError(f"No rows with Class == {label!r}; cannot plot a raw signal for it.")
    # First ID in row order, matching `.unique()[0]` on the class subset.
    subject_id = class_rows['ID'].iloc[0]
    # Filter on class AND ID so an ID reused across classes cannot mix recordings.
    signal = class_rows.loc[class_rows['ID'] == subject_id, channel].to_numpy()[:n_samples]
    return subject_id, signal


def plot_raw_data(df, save_dir, channel='Fz', n_samples=500):
    """Save a side-by-side plot of the raw signal of one Control and one ADHD subject.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a 'Class' column ('ADHD' / 'Control'), an 'ID' column and one
        column per EEG channel.
    save_dir : str
        Existing directory; the figure is written to
        ``eda_raw_signal_comparison.png`` inside it.
    channel : str, default 'Fz'
        Channel column to plot.
    n_samples : int, default 500
        Number of leading rows of each subject's recording to plot.

    Raises
    ------
    ValueError
        If either class has no rows in ``df``.
    """
    control_id, control_signal = _first_subject_signal(df, 'Control', channel, n_samples)
    adhd_id, adhd_signal = _first_subject_signal(df, 'ADHD', channel, n_samples)

    # sharey=True puts both panels on one amplitude scale; independent autoscaling
    # would make amplitudes of different magnitude look alike.
    fig, (ax_control, ax_adhd) = plt.subplots(1, 2, figsize=(15, 6), sharey=True)
    try:
        ax_control.plot(control_signal, color='blue', alpha=0.8)
        ax_control.set_title(f"Healthy (Control) - ID: {control_id}\nChannel: {channel}")
        ax_control.set_xlabel("Time (Samples)")
        ax_control.set_ylabel("Amplitude (uV)")
        ax_control.grid(True, alpha=0.3)

        ax_adhd.plot(adhd_signal, color='hotpink', alpha=0.8)
        ax_adhd.set_title(f"ADHD - ID: {adhd_id}\nChannel: {channel}")
        ax_adhd.set_xlabel("Time (Samples)")
        ax_adhd.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, "eda_raw_signal_comparison.png"))
    finally:
        # Always release the figure, even if saving fails, to avoid leaking memory across runs.
        plt.close(fig)

def _mean_subject_psd(df, label, channel, fs, nperseg):
    """Return (freqs, mean PSD, n_subjects) for one class by averaging per-subject Welch PSDs.

    Each subject's rows are treated as one contiguous recording, in the order they
    appear in ``df``. This assumes rows are stored in time order per ID (TODO confirm
    against the dataset). Subjects shorter than ``nperseg`` samples are skipped so
    every PSD shares the same frequency grid and can be averaged.
    """
    class_rows = df.loc[df['Class'] == label]
    freqs = None
    subject_psds = []
    for _, subject_rows in class_rows.groupby('ID', sort=False):
        signal = subject_rows[channel].to_numpy(dtype=float)
        if signal.size < nperseg:
            continue
        freqs, psd = welch(signal, fs=fs, nperseg=nperseg)
        subject_psds.append(psd)
    if not subject_psds:
        raise ValueError(
            f"No {label!r} subject has at least {nperseg} samples of {channel!r}; cannot compute a PSD."
        )
    return freqs, np.mean(subject_psds, axis=0), len(subject_psds)


def plot_average_psd(df, save_dir, channel='Fz', fs=128, nperseg=256):
    """Save a log-scale plot of the subject-averaged Welch PSD for Control vs ADHD.

    The PSD is computed per subject on that subject's contiguous signal, then
    averaged across subjects in each class. The previous version drew random
    individual time points from all subjects pooled together. That destroys the
    temporal order Welch's method relies on and yields a near-flat spectrum.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with 'Class' ('ADHD' / 'Control'), 'ID' and one column per channel.
    save_dir : str
        Existing directory; the figure is written to ``eda_average_psd.png``.
    channel : str, default 'Fz'
        Channel column to analyse.
    fs : float, default 128
        Sampling rate in Hz passed to ``scipy.signal.welch`` (the original
        hard-coded value; confirm against the recording metadata).
    nperseg : int, default 256
        Welch segment length in samples.

    Raises
    ------
    ValueError
        If a class has no subject with at least ``nperseg`` samples.
    """
    f_cont, p_cont, n_cont = _mean_subject_psd(df, 'Control', channel, fs, nperseg)
    f_adhd, p_adhd, n_adhd = _mean_subject_psd(df, 'ADHD', channel, fs, nperseg)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.semilogy(f_cont, p_cont, label=f'Healthy (Control) Mean (n={n_cont} subjects)',
                    color='blue', linewidth=2)
        ax.semilogy(f_adhd, p_adhd, label=f'ADHD Mean (n={n_adhd} subjects)',
                    color='hotpink', linewidth=2, linestyle='--')

        ax.set_title("Group Comparison: Frequency Power Spectrum (PSD)", fontsize=14)
        ax.set_xlabel("Frequency (Hz) - [Delta, Theta, Alpha, Beta]")
        ax.set_ylabel("Power Density (Log Scale)")
        ax.legend()
        ax.grid(True, alpha=0.3)

        fig.savefig(os.path.join(save_dir, "eda_average_psd.png"))
    finally:
        plt.close(fig)

def plot_correlation_heatmap(df, save_dir, cmap='coolwarm'):
    """Save side-by-side channel-by-channel Pearson correlation heatmaps for Control and ADHD.

    Correlations are computed over all rows of each class pooled together (all
    subjects concatenated), not per subject.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with 'Class' ('ADHD' / 'Control'), 'ID' and one column per channel.
        Every column other than 'Class' and 'ID' is treated as a channel.
    save_dir : str
        Existing directory; the figure is written to ``eda_correlation_heatmap.png``.
    cmap : str, default 'coolwarm'
        Colormap name. Correlations are signed, so a diverging map is used by
        default. Pass 'Purples' to reproduce the previous look.
    """
    channels = [c for c in df.columns if c not in ('Class', 'ID')]

    corr_control = df.loc[df['Class'] == 'Control', channels].corr()
    corr_adhd = df.loc[df['Class'] == 'ADHD', channels].corr()

    fig, (ax_control, ax_adhd) = plt.subplots(1, 2, figsize=(20, 8))
    try:
        # center=0 anchors the diverging map's neutral colour at zero correlation,
        # keeping positive and negative coupling visually distinct.
        sns.heatmap(corr_control, cmap=cmap, vmin=-1, vmax=1, center=0, ax=ax_control)
        ax_control.set_title("Healthy Brain Connectivity (Correlation)")

        sns.heatmap(corr_adhd, cmap=cmap, vmin=-1, vmax=1, center=0, ax=ax_adhd)
        ax_adhd.set_title("ADHD Brain Connectivity (Correlation)")

        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, "eda_correlation_heatmap.png"))
    finally:
        plt.close(fig)

In [14]:
save_dir = os.path.join(BASE_DIR, "results", "eda")
os.makedirs(save_dir, exist_ok=True)

# `df` exists only if the loading cell found the CSV. At notebook top level,
# globals() is the cell namespace.
if 'df' in globals():
    logger.info("1. Plotting Raw Time Series...")
    plot_raw_data(df, save_dir)

    logger.info("2. Plotting Frequency Analysis (PSD)...")
    plot_average_psd(df, save_dir)

    logger.info("3. Plotting Correlation Heatmap...")
    plot_correlation_heatmap(df, save_dir)

    logger.info("All graphs saved to 'results/eda' folder.")
else:
    # Previously this cell silently did nothing; make the skip visible.
    logger.warning("No data loaded (`df` is undefined); skipping EDA plots.")

2025-12-17 01:37:32 - INFO - eeg_pipeline - 1. Plotting Raw Time Series...
2025-12-17 01:37:33 - INFO - eeg_pipeline - 2. Plotting Frequency Analysis (PSD)...
2025-12-17 01:37:33 - INFO - eeg_pipeline - 3. Plotting Correlation Heatmap...
2025-12-17 01:37:35 - INFO - eeg_pipeline - All graphs saved to 'results/eda' folder.
