In [None]:
# static_cca.py
from sklearn.cross_decomposition import CCA
from datetime import datetime, timedelta
from config_loader import load_config
from collections import Counter
from logger import logger
import pandas as pd
import numpy as np
import mne
import os
import gc

# Parameters
config = load_config()

DATA_FOLDER = config.data.data_dir # "data/apples"
OUTPUT_FOLDER = config.static_cca_params.output_dir # "data/static_cca"
EEG_CHANNELS = config.data.eeg_channels # ['C3_M2', 'C4_M1', 'O1_M2', 'O2_M1']
EOG_CHANNELS = config.data.eog_channels # ['LOC', 'ROC']
SLEEP_STAGES = config.data.sleep_stages # ['W', 'N1', 'N2', 'N3', 'R']
DOWNSAMPLING_FACTOR = config.static_cca_params.downsampling_factor # 1
fmt = "%H:%M:%S"

# Create the output directory up front so the CSV writes below cannot fail
# merely because it is missing.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Initialize results list (one summary dict per recording/stage)
summary_results = []

# Iterate through .edf/.annot files
file_pairs = [(f, f.replace(".edf", ".annot")) for f in os.listdir(DATA_FOLDER) if f.endswith(".edf")]

for edf_file, annot_file in file_pairs:
    edf_path = os.path.join(DATA_FOLDER, edf_file)
    annot_path = os.path.join(DATA_FOLDER, annot_file)

    # Recordings without a matching annotation file are skipped silently.
    if not os.path.exists(annot_path):
        continue

    # Bind every name released in `finally` before anything can fail, so the
    # cleanup never touches an unbound variable.
    raw = None
    eeg_data = None
    eog_data = None

    try:
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        sfreq = int(raw.info['sfreq'])
        start_datetime = raw.info['meas_date'].replace(tzinfo=None)

        # Read annotations; the first line of the file is skipped.
        with open(annot_path, "r", encoding="utf-8", errors="ignore") as f:
            lines = f.readlines()[1:]

        # Keep (stage, start, stop) from rows with exactly six tab-separated fields.
        annotations = []
        for line in lines:
            parts = line.strip().split("\t")
            if len(parts) == 6:
                stage, _, _, start, stop, _ = parts
                annotations.append((stage, start, stop))

        # Convert the HH:MM:SS clock times into seconds since the recording start.
        parsed_epochs = []
        for stage, start_str, stop_str in annotations:
            if stage not in SLEEP_STAGES:
                continue
            try:
                start_clock = datetime.strptime(start_str, fmt).time()
                stop_clock = datetime.strptime(stop_str, fmt).time()

                start_time = datetime.combine(start_datetime.date(), start_clock)
                stop_time = datetime.combine(start_datetime.date(), stop_clock)

                # Handle midnight crossing
                if stop_time < start_time:
                    stop_time += timedelta(days=1)
                if start_time < start_datetime:
                    start_time += timedelta(days=1)
                    stop_time += timedelta(days=1)

                start_sec = (start_time - start_datetime).total_seconds()
                stop_sec = (stop_time - start_datetime).total_seconds()

                parsed_epochs.append((stage, start_sec, stop_sec))
            except Exception as e:
                logger.error(f"Annotation parse error: {e}")

        # The data is interpolated into the message itself; passing it as a
        # second positional argument would not be rendered as part of the text.
        logger.info(f'Epochs parsed: {parsed_epochs[:5]}')
        stage_counts = Counter([stage for stage, _, _ in parsed_epochs])
        logger.info(f'Stage count for subject {edf_file}: {stage_counts}')

        # Organize EEG and EOG segments
        eeg_data = {stage: [] for stage in SLEEP_STAGES}
        eog_data = {stage: [] for stage in SLEEP_STAGES}
        for s, start, stop in parsed_epochs:
            start_sample = round(start * sfreq)
            # Clamp the end of the epoch to the end of the recording.
            stop_sample = min(round(stop * sfreq), raw.n_times)

            if start_sample < 0 or start_sample >= stop_sample:
                logger.warning(f"Invalid sample range for stage {s}: start={start_sample}, stop={stop_sample}, total={raw.n_times}")
                continue

            try:
                eeg = raw.get_data(picks=EEG_CHANNELS, start=start_sample, stop=stop_sample)
                eog = raw.get_data(picks=EOG_CHANNELS, start=start_sample, stop=stop_sample)
                eeg_data[s].append(eeg)
                eog_data[s].append(eog)
            except Exception as e:
                logger.error(f"Failed to extract data: stage={s}, start={start_sample}, stop={stop_sample}, error: {e}")
                continue

        # Set a downsampling factor. DOWNSAMPLING_FACTOR is used as the target
        # rate here; the stride is clamped to >= 1 because a slice step of 0
        # raises ValueError.
        target_fs = DOWNSAMPLING_FACTOR
        factor = max(1, int(sfreq / target_fs))

        # Perform CCA
        for stage in SLEEP_STAGES:
            if not eeg_data[stage] or not eog_data[stage]:
                logger.info(f"Skipping stage {stage}: no data available.")
                continue
            # Concatenate all segments of this stage along the sample axis.
            eeg_agg = np.hstack(eeg_data[stage])
            eog_agg = np.hstack(eog_data[stage])

            try:
                # Transpose to (n_samples, n_features)
                X = eeg_agg.T
                Y = eog_agg.T

                # Ensure same number of samples
                min_len = min(len(X), len(Y))
                X = X[:min_len]
                Y = Y[:min_len]

                cca = CCA(n_components=2)
                X_c, Y_c = cca.fit_transform(X, Y)

                # Save downsampled projections
                if len(X_c) > factor:
                    X_c_ds = X_c[::factor, :]
                    Y_c_ds = Y_c[::factor, :]
                else:
                    X_c_ds = X_c
                    Y_c_ds = Y_c

                file_prefix = f"{edf_file.replace('.edf','')}_{stage}"

                pd.DataFrame(X_c_ds, columns=["Xc_1", "Xc_2"]).to_csv(
                    os.path.join(OUTPUT_FOLDER, f"{file_prefix}_Xc_downsampled.csv"), index=False
                )
                pd.DataFrame(Y_c_ds, columns=["Yc_1", "Yc_2"]).to_csv(
                    os.path.join(OUTPUT_FOLDER, f"{file_prefix}_Yc_downsampled.csv"), index=False
                )

                # Compute canonical correlation coefficients on the full-rate scores
                corr_coeffs = [np.corrcoef(X_c[:, i], Y_c[:, i])[0, 1] for i in range(X_c.shape[1])]

                # Compute summary statistics
                summary = {
                    "subject": edf_file,
                    "stage": stage,
                    "cca_corr1": corr_coeffs[0],
                    "cca_corr2": corr_coeffs[1] if len(corr_coeffs) > 1 else np.nan
                }

                # Same statistics for every component of both score matrices;
                # key order is Xc1.., Xc2.., Yc1.., Yc2...
                for label, scores in (("Xc", X_c), ("Yc", Y_c)):
                    for idx in range(scores.shape[1]):
                        vals = scores[:, idx]
                        summary[f"{label}{idx+1}_mean"] = np.mean(vals)
                        summary[f"{label}{idx+1}_std"] = np.std(vals)
                        summary[f"{label}{idx+1}_25p"] = np.percentile(vals, 25)
                        summary[f"{label}{idx+1}_median"] = np.median(vals)
                        summary[f"{label}{idx+1}_75p"] = np.percentile(vals, 75)

                summary_results.append(summary)
                logger.info(f'Summary resulst for edf {edf_file} for stage {stage} written')

            except Exception as e:
                logger.error(f"CCA failed for {edf_file}, stage {stage}: {e}")
    except Exception as e:
        logger.error(f"Failed on {edf_file}: {e}")
    finally:
        # Release the preloaded signal and the per-stage segments after every
        # recording, whether it succeeded or failed. No explicit annotation
        # clearing: dropping the last reference to `raw` releases them too.
        if raw is not None:
            raw._data = None
        raw = eeg_data = eog_data = None
        gc.collect() # Clean up memory

# Save results
results_df = pd.DataFrame(summary_results)
results_csv_path = os.path.join(OUTPUT_FOLDER, "eeg_eog_cca_summary_stats.csv")
results_df.to_csv(results_csv_path, index=False)

logger.info(f'The summary statistics saved to {results_csv_path}')

In [None]:
# static_cca_analyze_canonical_projections.py
from scipy.stats import skew, kurtosis, f_oneway
from config_loader import load_config
import matplotlib.pyplot as plt
from logger import logger
import seaborn as sns
import pandas as pd
import numpy as np
import glob
import os

# Configuration
config = load_config()

# Folder holding the per-recording projection CSVs, e.g. "data/static_cca"
OUTPUT_FOLDER = config.static_cca_params.output_dir
# Folder for the analysis outputs, e.g. "data/static_cca_analysis"
RESULTS_FOLDER = config.static_cca_params.results_dir

# Glob patterns selecting the Xc and Yc projection files
pattern_Xc, pattern_Yc = (
    os.path.join(OUTPUT_FOLDER, f"*_{kind}_downsampled.csv") for kind in ("Xc", "Yc")
)

# Projection loader
def load_projection_files(pattern, prefix):
    """Load projection CSVs matching *pattern* into one long-format frame.

    File names are expected to look like ``<subject>_<stage>_Xc_downsampled.csv``
    (or ``..._Yc_downsampled.csv``). The stage is the text after the last
    underscore of the remaining stem, so subject identifiers may themselves
    contain underscores. Names without any underscore are skipped with a warning.

    Args:
        pattern: Glob pattern selecting the CSV files to read.
        prefix: Label stored in the ``type`` column of every returned row.

    Returns:
        A DataFrame with one row per (file, column, sample) holding the columns
        ``projection``, ``value``, ``subject``, ``stage`` and ``type``. When no
        file is loaded, an empty frame with those column names is returned.
    """
    frames = []
    # Sorted so the row order does not depend on the directory listing order.
    for filepath in sorted(glob.glob(pattern)):
        filename = os.path.basename(filepath)

        stem = filename.replace("_Xc_downsampled.csv", "") \
                       .replace("_Yc_downsampled.csv", "")

        # Split on the LAST underscore only: "<subject>_<stage>".
        subject, sep, stage = stem.rpartition("_")
        if not sep:
            logger.warning(f"Skipping unexpected filename: {filename}")
            continue

        df = pd.read_csv(filepath)

        # Unpivot: one row per (projection column, value)
        df_long = df.melt(var_name="projection", value_name="value")
        df_long["subject"] = subject
        df_long["stage"] = stage
        df_long["type"] = prefix

        frames.append(df_long)

    if not frames:
        return pd.DataFrame(columns=["subject", "stage", "projection", "value", "type"])
    return pd.concat(frames, ignore_index=True)

# Load all downsampled files
df_Xc = load_projection_files(pattern_Xc, "Xc")
df_Yc = load_projection_files(pattern_Yc, "Yc")

all_data = pd.concat([df_Xc, df_Yc], ignore_index=True)
logger.info(f"Loaded downsampled projections: {all_data.shape}")

if all_data.empty:
    logger.warning("No data loaded! Check file paths and patterns.")
    # SystemExit stops the script without relying on the `site`-provided
    # exit() helper, which is not guaranteed to exist in every interpreter.
    raise SystemExit

# Plots and tables below are written to RESULTS_FOLDER; create it if missing.
os.makedirs(RESULTS_FOLDER, exist_ok=True)

# Save combined file
all_data.to_csv(os.path.join(OUTPUT_FOLDER, "all_downsampled_projections.csv"), index=False)
logger.info("Saved combined downsampled data.")

# (type, projection column) pairs analysed below, e.g. ("Xc", "Xc_1")
projection_keys = [
    (proj_type, f"{proj_type}_{comp}")
    for proj_type in ["Xc", "Yc"]
    for comp in ["1", "2"]
]


def _rows_for(proj_type, projection_name):
    """Return the rows of all_data that belong to one canonical projection."""
    return all_data[
        (all_data["type"] == proj_type) &
        (all_data["projection"] == projection_name)
    ]


# Plot distributions per stage and projection
for proj_type, projection_name in projection_keys:
    df_plot = _rows_for(proj_type, projection_name)

    if df_plot.empty:
        logger.warning(f"No data for {projection_name}")
        continue

    # KDE Plot (each stage normalised separately via common_norm=False)
    plt.figure(figsize=(10,5))
    sns.kdeplot(
        data=df_plot,
        x="value",
        hue="stage",
        common_norm=False,
        fill=True,
        alpha=0.3,
        linewidth=1.5
    )
    plt.title(f"Density of {projection_name} across stages")
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_FOLDER, f"{projection_name}_kde.png"))
    plt.close()
    logger.info(f"Saved density plot for {projection_name}")

    # Boxplot
    plt.figure(figsize=(8,5))
    sns.boxplot(
        x="stage",
        y="value",
        data=df_plot,
        showmeans=True
    )
    plt.title(f"Boxplot of {projection_name} across stages")
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_FOLDER, f"{projection_name}_boxplot.png"))
    plt.close()
    logger.info(f"Saved boxplot for {projection_name}")

# Compute summary statistics by stage and projection
summary_rows = []
for proj_type, projection_name in projection_keys:
    df_proj = _rows_for(proj_type, projection_name)
    if df_proj.empty:
        continue

    for stage, vals in df_proj.groupby("stage")["value"]:
        summary_rows.append({
            "projection": projection_name,
            "stage": stage,
            "mean": np.mean(vals),
            "std": np.std(vals),
            "skewness": skew(vals),
            "kurtosis": kurtosis(vals),
            "count": len(vals)
        })

summary_df = pd.DataFrame(summary_rows)
summary_csv = os.path.join(RESULTS_FOLDER, "canonical_projection_summary_by_stage.csv")
summary_df.to_csv(summary_csv, index=False)
logger.info(f"Saved summary stats to {summary_csv}")

# Run ANOVA across stages
for proj_type, projection_name in projection_keys:
    df_proj = _rows_for(proj_type, projection_name)
    if df_proj.empty:
        continue

    groups = [g["value"].values for _, g in df_proj.groupby("stage")]
    # A one-way ANOVA needs at least two groups to compare.
    if len(groups) > 1:
        fval, pval = f_oneway(*groups)
        logger.info(f"ANOVA for {projection_name}: F = {fval:.3f}, p = {pval:.3e}")


In [None]:
# stataic_cca_analyze_summary_stats.py
from config_loader import load_config
from scipy.stats import f_oneway
import matplotlib.pyplot as plt
from logger import logger
import seaborn as sns
import pandas as pd
import os

# Parameters
config = load_config()

OUTPUT_FOLDER = config.static_cca_params.output_dir # "data/static_cca"
RESULTS_FOLDER = config.static_cca_params.results_dir # "data/static_cca_analysis"

# Figures and tables are written to RESULTS_FOLDER; create it if missing.
os.makedirs(RESULTS_FOLDER, exist_ok=True)

# Load summary CSV
summary_path = os.path.join(OUTPUT_FOLDER, "eeg_eog_cca_summary_stats.csv")
summary_df = pd.read_csv(summary_path)

# The shape is interpolated into the message; passing it as an extra
# positional argument would not be rendered as part of the text.
logger.info(f"Summary stats loaded: {summary_df.shape}")


def _save_stage_boxplot(var):
    """Save a per-stage boxplot of summary_df[var] into RESULTS_FOLDER."""
    plt.figure(figsize=(8,5))
    sns.boxplot(x="stage", y=var, data=summary_df, showmeans=True)
    plt.title(f"Boxplot of {var} across sleep stages")
    plt.ylabel(var)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plot_path = os.path.join(RESULTS_FOLDER, f"{var}_boxplot_by_stage.png")
    plt.savefig(plot_path)
    plt.close()
    logger.info(f"Saved boxplot for {var}")


# Plot distributions of CCA correlations
for var in ["cca_corr1", "cca_corr2"]:
    _save_stage_boxplot(var)

# Run ANOVA on CCA correlations
for var in ["cca_corr1", "cca_corr2"]:
    groups = []
    for stage in summary_df["stage"].unique():
        vals = summary_df.loc[summary_df["stage"]==stage, var].dropna().values
        # Stages with fewer than two observations are left out of the test.
        if len(vals) > 1:
            groups.append(vals)
    if len(groups) > 1:
        stat, pval = f_oneway(*groups)
        logger.info(f"ANOVA for {var}: F={stat:.3f}, p={pval:.3e}")
    else:
        logger.warning(f"Not enough groups for ANOVA on {var}")

# Aggregate mean +- std for CCA correlations
agg = summary_df.groupby("stage")[["cca_corr1", "cca_corr2"]].agg(["mean", "std", "count"])
# Flatten the (variable, statistic) column index into "variable_statistic".
agg.columns = ['_'.join(col) for col in agg.columns]
agg.reset_index(inplace=True)
agg.to_csv(os.path.join(RESULTS_FOLDER, "cca_correlation_summary.csv"), index=False)
logger.info("Saved aggregated summary CSV for cca_corr1 and cca_corr2.")

# Analyze canonical projection means
projection_cols = [c for c in summary_df.columns if ("Xc" in c or "Yc" in c) and "_mean" in c]

for var in projection_cols:
    _save_stage_boxplot(var)

In [None]:
# static_cca_explained_variance.py
from config_loader import load_config
from logger import logger
import pandas as pd
import numpy as np
import os

# Configuration
config = load_config()

DATA_FOLDER = config.static_cca_params.output_dir  # "data/static_cca"
OUTPUT_PATH = os.path.join(DATA_FOLDER, "explained_variance_by_stage.csv")


def _list_with_suffix(suffix):
    """Return the entries of DATA_FOLDER whose names end with *suffix*."""
    return [name for name in os.listdir(DATA_FOLDER) if name.endswith(suffix)]


# Downsampled projection CSVs for the Xc and Yc canonical scores
xc_files = _list_with_suffix("_Xc_downsampled.csv")
yc_files = _list_with_suffix("_Yc_downsampled.csv")

# Get stage and subject
def parse_metadata(filename):
    """Return ``(subject, stage)`` parsed from a projection CSV file name.

    The ``_Xc_downsampled.csv`` / ``_Yc_downsampled.csv`` suffix is removed and
    the remaining stem is split on its LAST underscore, so subject identifiers
    may contain underscores themselves (``sub_01_N2`` -> ``("sub_01", "N2")``).

    Args:
        filename: Base name of a projection CSV file.

    Returns:
        A ``(subject, stage)`` tuple, or ``("unknown", "unknown")`` when the
        stem contains no underscore at all.
    """
    stem = filename.replace("_Xc_downsampled.csv", "").replace("_Yc_downsampled.csv", "")
    subject, sep, stage = stem.rpartition("_")
    if not sep:
        return "unknown", "unknown"
    return subject, stage

# One row per (recording, stage, component)
records = []

for xc_name in xc_files:
    # Derive the companion Yc file from the Xc file name.
    stem = xc_name.replace("_Xc_downsampled.csv", "")
    yc_name = stem + "_Yc_downsampled.csv"
    xc_path = os.path.join(DATA_FOLDER, xc_name)
    yc_path = os.path.join(DATA_FOLDER, yc_name)

    # Xc files without a Yc counterpart are ignored.
    if not os.path.exists(yc_path):
        continue

    subject, stage = parse_metadata(xc_name)
    try:
        x_scores = pd.read_csv(xc_path).values
        y_scores = pd.read_csv(yc_path).values

        # Both score tables must have the same number of rows.
        if x_scores.shape[0] != y_scores.shape[0]:
            continue

        x_var = np.var(x_scores, axis=0)
        y_var = np.var(y_scores, axis=0)

        x_total = x_var.sum()
        y_total = y_var.sum()

        for component in (1, 2):
            # Share of the summed per-column variance; NaN when the total is zero.
            if x_total:
                x_ratio = x_var[component - 1] / x_total
            else:
                x_ratio = np.nan
            if y_total:
                y_ratio = y_var[component - 1] / y_total
            else:
                y_ratio = np.nan

            records.append({
                "subject": subject,
                "stage": stage,
                "component": component,
                "explained_variance_Xc": x_ratio,
                "explained_variance_Yc": y_ratio
            })
    except Exception as e:
        logger.error(f"Error processing {xc_name}: {e}")
        continue

# Save results
explained_var_df = pd.DataFrame(records)
explained_var_df.to_csv(OUTPUT_PATH, index=False)

logger.info(f"Explained variance results saved to {OUTPUT_PATH}")

In [None]:
# static_cca_visualize_explained_variance.py
from config_loader import load_config
import matplotlib.pyplot as plt
from logger import logger
import seaborn as sns
import pandas as pd
import os

# Configuration
config = load_config()

DATA_FOLDER = config.static_cca_params.output_dir  # "data/static_cca"
CSV_PATH = os.path.join(DATA_FOLDER, "explained_variance_by_stage.csv")
REPORT_FIGURES_FOLDER = config.report.figures_folder  # "report/figs"

df = pd.read_csv(CSV_PATH)

# Set plot style
sns.set(style="whitegrid")


def _save_variance_boxplot(column, title, filename):
    """Draw a per-stage boxplot of df[column], split by component, and save it.

    Returns the path of the written figure inside REPORT_FIGURES_FOLDER.
    """
    plt.figure(figsize=(10, 6))
    sns.boxplot(data=df, x="stage", y=column, hue="component", showmeans=True)
    plt.title(title)
    plt.ylabel("Explained Variance Ratio")
    plt.xlabel("Sleep Stage")
    plt.legend(title="Component")
    plt.ylim(0, 1)
    plt.tight_layout()
    out_path = os.path.join(REPORT_FIGURES_FOLDER, filename)
    plt.savefig(out_path)
    plt.close()
    return out_path


# Boxplot for explained variance of EEG canonical projections (Xc)
fig_xc_path = _save_variance_boxplot(
    "explained_variance_Xc",
    "Explained Variance of EEG Canonical Projections (Xc)",
    "figure5a_explained_variance_Xc.png",
)

# Boxplot for explained variance of EOG canonical projections (Yc)
fig_yc_path = _save_variance_boxplot(
    "explained_variance_Yc",
    "Explained Variance of EOG Canonical Projections (Yc)",
    "figure5b_explained_variance_Yc.png",
)

logger.info(f"Saved explained variance plots to {REPORT_FIGURES_FOLDER}")

In [None]:
# time_resolved_cca.py
from sklearn.cross_decomposition import CCA
from datetime import datetime, timedelta
from config_loader import load_config
from logger import logger
import pandas as pd
import numpy as np
import mne
import os
import gc

# Parameters
config = load_config()

DATA_FOLDER = config.data.data_dir # "data/apples"
OUTPUT_FOLDER = config.time_cca_params.output_dir # "data/time_resolved_cca"
EEG_CHANNELS = config.data.eeg_channels # ['C3_M2', 'C4_M1', 'O1_M2', 'O2_M1']
EOG_CHANNELS = config.data.eog_channels # ['LOC', 'ROC']
SLEEP_STAGES = config.data.sleep_stages # ['W', 'N1', 'N2', 'N3', 'R']
WINDOW_LENGTH = config.time_cca_params.window_length # 30 seconds
STEP_LENGTH = config.time_cca_params.step_length # 15 seconds
fmt = "%H:%M:%S"

# Create the output directory up front so the CSV writes below cannot fail
# merely because it is missing.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Iterate through .edf/.annot files
file_pairs = [(f, f.replace(".edf", ".annot")) for f in os.listdir(DATA_FOLDER) if f.endswith(".edf")]

for edf_file, annot_file in file_pairs:
    edf_path = os.path.join(DATA_FOLDER, edf_file)
    annot_path = os.path.join(DATA_FOLDER, annot_file)

    # Recordings without a matching annotation file are skipped silently.
    if not os.path.exists(annot_path):
        continue

    # Bound before the try so the cleanup in `finally` is safe even when
    # reading the EDF file itself fails.
    raw = None

    try:
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        sfreq = int(raw.info['sfreq'])
        start_datetime = raw.info['meas_date'].replace(tzinfo=None)
        # The start time is interpolated into the message; passing it as an
        # extra positional argument would not be rendered as part of the text.
        logger.info(f"EDF {edf_path} Start: {start_datetime}")

        # Read annotations; the first line of the file is skipped.
        with open(annot_path, "r", encoding="utf-8", errors="ignore") as f:
            lines = f.readlines()[1:]

        # Keep (stage, start, stop) from rows with exactly six tab-separated fields.
        annotations = []
        for line in lines:
            parts = line.strip().split("\t")
            if len(parts) == 6:
                stage, _, _, start, stop, _ = parts
                annotations.append((stage, start, stop))

        # Convert the HH:MM:SS clock times into seconds since the recording start.
        parsed_epochs = []
        for stage, start_str, stop_str in annotations:
            if stage not in SLEEP_STAGES:
                continue
            try:
                start_clock = datetime.strptime(start_str, fmt).time()
                stop_clock = datetime.strptime(stop_str, fmt).time()

                start_time = datetime.combine(start_datetime.date(), start_clock)
                stop_time = datetime.combine(start_datetime.date(), stop_clock)

                # Handle midnight crossing
                if stop_time < start_time:
                    stop_time += timedelta(days=1)
                if start_time < start_datetime:
                    start_time += timedelta(days=1)
                    stop_time += timedelta(days=1)

                start_sec = (start_time - start_datetime).total_seconds()
                stop_sec = (stop_time - start_datetime).total_seconds()

                parsed_epochs.append((stage, start_sec, stop_sec))
            except Exception as e:
                logger.error(f"Annotation parse error: {e}")

        # Perform time-resolved CCA: slide a WINDOW_LENGTH-second window in
        # STEP_LENGTH-second steps through every annotated epoch.
        stage_to_results = {stage: [] for stage in SLEEP_STAGES}
        for stage, start, stop in parsed_epochs:
            t = start
            while t + WINDOW_LENGTH <= stop:
                start_sample = round(t * sfreq)
                stop_sample = round((t + WINDOW_LENGTH) * sfreq)

                # Windows reaching outside the recording are skipped.
                if start_sample < 0 or stop_sample > raw.n_times:
                    t += STEP_LENGTH
                    continue

                try:
                    eeg = raw.get_data(picks=EEG_CHANNELS, start=start_sample, stop=stop_sample)
                    eog = raw.get_data(picks=EOG_CHANNELS, start=start_sample, stop=stop_sample)
                    # Transpose to (n_samples, n_features)
                    X = eeg.T
                    Y = eog.T

                    min_len = min(len(X), len(Y))
                    X = X[:min_len]
                    Y = Y[:min_len]

                    cca = CCA(n_components=2)
                    X_c, Y_c = cca.fit_transform(X, Y)

                    # Compute canonical correlation coefficients
                    corr1 = np.corrcoef(X_c[:, 0], Y_c[:, 0])[0, 1]
                    corr2 = np.corrcoef(X_c[:, 1], Y_c[:, 1])[0, 1]

                    # Record the window result under its sleep stage
                    stage_to_results[stage].append({
                        "time_sec": t,
                        "cca_corr1": corr1,
                        "cca_corr2": corr2,
                        "subject": edf_file.replace(".edf", ""),
                        "stage": stage
                    })

                except Exception as e:
                    logger.error(f"CCA failed in {edf_file} stage {stage} at t={t}: {e}")
                t += STEP_LENGTH

        # Save results once per recording, after ALL epochs were processed.
        # Writing inside the epoch loop would rewrite every stage file for each
        # epoch and rebind the loop's `stage` variable.
        for out_stage, results in stage_to_results.items():
            if results:
                df_out = pd.DataFrame(results)
                file_prefix = f"{edf_file.replace('.edf','')}_{out_stage}_cca_timeseries.csv"
                df_out.to_csv(os.path.join(OUTPUT_FOLDER, file_prefix), index=False)
                logger.info(f"Saved CCA timeseries for {edf_file} stage {out_stage}")

    except Exception as e:
        logger.error(f"Failed on {edf_file}: {e}")

    finally:
        # Release the preloaded signal before the next recording is read.
        # No explicit annotation clearing: dropping the last reference to
        # `raw` releases them too.
        if raw is not None:
            raw._data = None
        raw = None
        gc.collect() # Clean up memory

In [None]:
# time_resolved_cca_check_stationarity.py
from statsmodels.tsa.stattools import adfuller, kpss
from config_loader import load_config
from logger import logger
import pandas as pd
import os

# Parameters
config = load_config()

DATA_FOLDER = config.time_cca_params.output_dir  # "data/time_resolved_cca"
OUTPUT_PATH = os.path.join(DATA_FOLDER, "stationarity_results.csv")

# List CCA timeseries files
cca_files = [f for f in os.listdir(DATA_FOLDER) if f.endswith("_cca_timeseries.csv")]

results = []

for fname in cca_files:
    df = pd.read_csv(os.path.join(DATA_FOLDER, fname))
    # The label is the file name minus its suffix, i.e. everything that
    # precedes "_cca_timeseries.csv".
    subject = fname.replace("_cca_timeseries.csv", "")

    for comp in ["cca_corr1", "cca_corr2"]:
        values = df[comp].dropna().values

        if len(values) < 10:
            continue  # too few values to test

        # ADF Test. Only the p-value (second element of the result) is used.
        # A failure on one series is logged and skipped instead of aborting
        # the whole run.
        try:
            adf_pval = adfuller(values)[1]
        except Exception as e:
            logger.error(f"ADF test failed for {fname} ({comp}): {e}")
            continue

        # KPSS Test. A failure is recorded as a missing p-value.
        try:
            kpss_pval = kpss(values, regression='c')[1]
        except Exception as e:
            logger.warning(f"KPSS test failed for {fname} ({comp}): {e}")
            kpss_pval = None

        results.append({
            "subject": subject,
            "component": comp,
            "adf_pval": adf_pval,
            # ADF: a small p-value is recorded as stationary.
            "adf_stationary": adf_pval < 0.05,
            "kpss_pval": kpss_pval,
            # KPSS: a large p-value is recorded as stationary.
            "kpss_stationary": kpss_pval is not None and kpss_pval > 0.05
        })

results_df = pd.DataFrame(results)

# Save results to CSV
results_df.to_csv(OUTPUT_PATH, index=False)
logger.info(f"Stationarity results saved to {OUTPUT_PATH}")


In [None]:
# time_resolved_cca_analysis.py
from scipy.stats import entropy, skew, kurtosis
from config_loader import load_config
from logger import logger
from glob import glob
import pandas as pd
import numpy as np
import os

# Parameters
config = load_config()

OUTPUT_FOLDER = config.time_cca_params.output_dir  # "data/time_resolved_cca"
RESULTS_FOLDER = config.time_cca_params.results_dir  # "data/time_resolved_cca_analysis"

# All tables below are written to RESULTS_FOLDER; create it if missing.
os.makedirs(RESULTS_FOLDER, exist_ok=True)

# Load all *_cca_timeseries.csv files. Sorted so that row order (and anything
# derived from it later) does not depend on the directory listing order.
all_files = sorted(glob(os.path.join(OUTPUT_FOLDER, "*_cca_timeseries.csv")))
if not all_files:
    # pd.concat raises an opaque "No objects to concatenate" on an empty list.
    logger.error(f"No *_cca_timeseries.csv files found in {OUTPUT_FOLDER}")
    raise SystemExit(1)
aggregated_data = pd.concat([pd.read_csv(f) for f in all_files], ignore_index=True)

# Function 1: Stagewise mean and std of CCA1 and CCA2
stagewise_stats = aggregated_data.groupby("stage")[["cca_corr1", "cca_corr2"]].agg(["mean", "std", "count"])
# Flatten the (variable, statistic) column index into "variable_statistic".
stagewise_stats.columns = ['_'.join(col).strip() for col in stagewise_stats.columns.values]
stagewise_stats.reset_index(inplace=True)
stagewise_stats_path = os.path.join(RESULTS_FOLDER, "stagewise_summary.csv")
stagewise_stats.to_csv(stagewise_stats_path, index=False)
logger.info(f"Stagewise summary saved to {stagewise_stats_path}")

# Function 2: Compute temporal mean trajectories (binned over time)
# 600-second bins. include_lowest=True keeps windows with time_sec == 0,
# which would otherwise fall outside the first right-closed bin and be dropped.
aggregated_data["time_bin"] = pd.cut(
    aggregated_data["time_sec"],
    bins=np.arange(0, aggregated_data["time_sec"].max() + 600, 600),
    include_lowest=True,
)
# observed=False is passed explicitly so every stage/bin combination gets a
# row regardless of the installed pandas version's default.
trajectory = (
    aggregated_data.groupby(["stage", "time_bin"], observed=False)[["cca_corr1", "cca_corr2"]]
    .mean()
    .reset_index()
)
trajectory_path = os.path.join(RESULTS_FOLDER, "mean_cca_trajectory_by_stage.csv")
trajectory.to_csv(trajectory_path, index=False)
logger.info(f"Temporal mean trajectories saved to {trajectory_path}")

# Function 3: Compute subjectwise entropy of CCA1 and CCA2 per stage
def compute_entropy(x):
    """Return the entropy of a 20-bin histogram of *x* over the range [0, 1].

    A tiny constant is added to every bin before the entropy is computed so
    that empty bins do not produce log(0).
    """
    density, _edges = np.histogram(x, bins=20, range=(0, 1), density=True)
    smoothed = density + 1e-12
    return entropy(smoothed)

# Per subject and stage: histogram entropy plus moments of each correlation.
# "mean"/"std" are passed as pandas aggregation names rather than the NumPy
# callables, which pandas has treated as deprecated aliases for them; the
# resulting column names stay "<variable>_mean" / "<variable>_std".
_correlation_aggregations = [compute_entropy, "mean", "std", skew, kurtosis]
entropy_stats = (
    aggregated_data.groupby(["subject", "stage"])
    .agg({
        "cca_corr1": _correlation_aggregations,
        "cca_corr2": _correlation_aggregations
    })
)

# Flatten the (variable, statistic) column index into "variable_statistic".
entropy_stats.columns = ['_'.join(col).strip() for col in entropy_stats.columns.values]
entropy_stats.reset_index(inplace=True)
entropy_stats_path = os.path.join(RESULTS_FOLDER, "entropy_by_subject_stage.csv")
entropy_stats.to_csv(entropy_stats_path, index=False)
logger.info(f"Entropy statistics saved to {entropy_stats_path}")

# Function 4: Save a few representative trajectories
# Sample up to three subjects; asking for more than are available would raise.
unique_subjects = aggregated_data["subject"].drop_duplicates()
sampled_subjects = unique_subjects.sample(min(3, len(unique_subjects)), random_state=42).tolist()
subset = aggregated_data[aggregated_data["subject"].isin(sampled_subjects)]
subset_path = os.path.join(RESULTS_FOLDER, "subset_trajectories.csv")
subset.to_csv(subset_path, index=False)
logger.info(f"Sample trajectories saved to {subset_path}")


In [None]:
# time_resolved_cca_plotting_groupped.py
from config_loader import load_config
import matplotlib.pyplot as plt
from logger import logger
import seaborn as sns
import pandas as pd
import os

# Parameters
config = load_config()

RESULTS_FOLDER = config.time_cca_params.results_dir  # "data/time_resolved_cca_analysis"
FIGURES_FOLDER = os.path.join(RESULTS_FOLDER, "figures")

# savefig does not create missing directories, so make sure the folder exists.
os.makedirs(FIGURES_FOLDER, exist_ok=True)

# Load the files
stagewise_summary = pd.read_csv(os.path.join(RESULTS_FOLDER, "stagewise_summary.csv"))
mean_cca_trajectory_by_stage = pd.read_csv(os.path.join(RESULTS_FOLDER, "mean_cca_trajectory_by_stage.csv"))
entropy_by_subject_stage = pd.read_csv(os.path.join(RESULTS_FOLDER, "entropy_by_subject_stage.csv"))
subset_trajectories = pd.read_csv(os.path.join(RESULTS_FOLDER, "subset_trajectories.csv"))

# 1. Boxplot of cca_corr1 and cca_corr2 by stage
for k in (1, 2):
    plt.figure(figsize=(10, 6))
    sns.boxplot(x='stage', y=f'cca_corr{k}', data=subset_trajectories)
    plt.title(f'CCA Corr{k} Distribution by Sleep Stage (Time-resolved)')
    plt.savefig(os.path.join(FIGURES_FOLDER, f"boxplot_cca_corr{k}_by_stage.png"))
    plt.close()
    logger.info(f"Saved boxplot of cca_corr{k} by stage.")

# 2. Lineplot of mean trajectories per stage over time
for k in (1, 2):
    plt.figure(figsize=(12, 6))
    for stage in mean_cca_trajectory_by_stage['stage'].unique():
        stage_data = mean_cca_trajectory_by_stage[mean_cca_trajectory_by_stage['stage'] == stage]
        # Plot against the position within the stage's own rows. The frame's
        # global row index would shift each stage horizontally by the rows of
        # the stages before it. Assumes the rows of a stage are stored in
        # time-bin order - TODO confirm against the script that writes the CSV.
        plt.plot(range(len(stage_data)), stage_data[f'cca_corr{k}'], label=f"{stage} - CCA{k}")
    plt.title(f"Mean CCA Corr{k} Trajectory Over Time by Stage")
    plt.xlabel("Time Bin Index")
    plt.ylabel(f"CCA Corr{k}")
    plt.legend()
    plt.savefig(os.path.join(FIGURES_FOLDER, f"trajectory_cca_corr{k}_by_stage.png"))
    plt.close()
    logger.info(f"Saved mean CCA Corr{k} trajectory plot by stage.")

# 3. Entropy values per stage for CCA1 and CCA2
for k in (1, 2):
    plt.figure(figsize=(10, 6))
    sns.boxplot(x='stage', y=f'cca_corr{k}_compute_entropy', data=entropy_by_subject_stage)
    plt.title(f'CCA Corr{k} Entropy by Sleep Stage')
    plt.savefig(os.path.join(FIGURES_FOLDER, f"entropy_cca_corr{k}_by_stage.png"))
    plt.close()
    logger.info(f"Saved entropy plot for CCA Corr{k} by stage.")


In [None]:
# generate_figures_report.py
from config_loader import load_config
import matplotlib.pyplot as plt
from logger import logger
import seaborn as sns
import pandas as pd
import os

# Parameters
config = load_config()

# Folders holding the CSV inputs read below, and the folder that receives the
# report figures; all three come from the loaded configuration.
SUMMARY_FOLDER = config.static_cca_params.output_dir # "data/static_cca"
TIME_RESOLVED_RESULTS_FOLDER = config.time_cca_params.results_dir  # "data/time_resolved_cca_analysis"
REPORT_FIGURES_FOLDER = config.report.figures_folder  # "report/figs"

# Load data
summary_df = pd.read_csv(os.path.join(SUMMARY_FOLDER, "eeg_eog_cca_summary_stats.csv"))
subset_trajectories = pd.read_csv(os.path.join(TIME_RESOLVED_RESULTS_FOLDER, "subset_trajectories.csv"))
mean_cca_trajectory_by_stage = pd.read_csv(os.path.join(TIME_RESOLVED_RESULTS_FOLDER, "mean_cca_trajectory_by_stage.csv"))
entropy_by_subject_stage = pd.read_csv(os.path.join(TIME_RESOLVED_RESULTS_FOLDER, "entropy_by_subject_stage.csv"))

# savefig does not create missing directories, so make sure the figures folder
# exists before the first figure is written. Done after the loads so a missing
# input file fails without leaving an empty output folder behind.
os.makedirs(REPORT_FIGURES_FOLDER, exist_ok=True)

# Figure 1: Static CCA Boxplots
# Two side-by-side panels (rho_1 left, rho_2 right) sharing the 0-1 range.
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
# Each entry: (data column, panel title, x position of the panel tag, tag text).
static_panels = (
    ("cca_corr1", r"Static CCA: $\rho_1$", 0.05, 'Panel A'),
    ("cca_corr2", r"Static CCA: $\rho_2$", 0.55, 'Panel B'),
)
for panel_ax, (corr_column, panel_title, tag_x, tag_text) in zip(axs, static_panels):
    sns.boxplot(x="stage", y=corr_column, data=summary_df, showmeans=True, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.grid(alpha=0.3)
    panel_ax.set_ylim(0, 1)
    fig.text(
        tag_x, 0.95, tag_text,
        ha='left', va='center', rotation='horizontal',
        fontsize=12, fontweight='bold',
    )

# Only the left panel carries a y-axis; the right one reuses the same scale.
axs[0].set_ylabel("Correlation")
axs[1].yaxis.set_visible(False)

plt.tight_layout()
figure1_path = os.path.join(REPORT_FIGURES_FOLDER, "figure1_static_cca_boxplots.png")
plt.savefig(figure1_path)
plt.close()
logger.info("Saved static CCA boxplots.")

# Figure 2: Time-resolved CCA Boxplots
# Same two-panel layout as the static figure, drawn from subset_trajectories.
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
# Each entry: (data column, panel title, x position of the panel tag, tag text).
time_resolved_panels = (
    ('cca_corr1', r'Time-Resolved CCA: $\rho_1$', 0.05, 'Panel A'),
    ('cca_corr2', r'Time-Resolved CCA: $\rho_2$', 0.55, 'Panel B'),
)
for panel_ax, (corr_column, panel_title, tag_x, tag_text) in zip(axs, time_resolved_panels):
    sns.boxplot(x='stage', y=corr_column, data=subset_trajectories, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.grid(alpha=0.3)
    panel_ax.set_ylim(0, 1)
    fig.text(
        tag_x, 0.95, tag_text,
        ha='left', va='center', rotation='horizontal',
        fontsize=12, fontweight='bold',
    )

# Only the left panel carries a y-axis; the right one reuses the same scale.
axs[0].set_ylabel("Correlation")
axs[1].yaxis.set_visible(False)

plt.tight_layout()
figure2_path = os.path.join(REPORT_FIGURES_FOLDER, "figure2_time_resolved_boxplots.png")
plt.savefig(figure2_path)
plt.close()
logger.info("Saved time-resolved CCA boxplots.")

# Figure 3: Mean Trajectories by Stage
# Left panel plots cca_corr1, right panel cca_corr2; one line per stage.
fig, axs = plt.subplots(1, 2, figsize=(14, 5))
stage_labels = mean_cca_trajectory_by_stage['stage']
for stage in stage_labels.unique():
    data = mean_cca_trajectory_by_stage.loc[stage_labels == stage]
    # NOTE(review): x is the row index of the loaded CSV, not a per-stage
    # counter — confirm this is the intended "Time Bin Index".
    for panel_ax, corr_column in zip(axs, ('cca_corr1', 'cca_corr2')):
        panel_ax.plot(data.index, data[corr_column], label=stage)

# Each entry: (panel title, x position of the panel tag, tag text).
trajectory_panels = (
    (r"Mean Trajectory: $\rho_1$", 0.05, 'Panel A'),
    (r"Mean Trajectory: $\rho_2$", 0.55, 'Panel B'),
)
for panel_ax, (panel_title, tag_x, tag_text) in zip(axs, trajectory_panels):
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Time Bin Index")
    panel_ax.grid(alpha=0.3)
    panel_ax.set_ylim(0, 1)
    fig.text(
        tag_x, 0.9, tag_text,
        ha='left', va='center', rotation='horizontal',
        fontsize=12, fontweight='bold',
    )

# Only the left panel carries a y-axis; the right one reuses the same scale.
axs[0].set_ylabel("Correlation")
axs[1].yaxis.set_visible(False)

# A single figure-level legend built from the left panel's lines replaces
# per-axes legends; both panels plot the same stages in the same order.
legend_handles, legend_labels = axs[0].get_legend_handles_labels()
legend_options = dict(
    loc='lower center',
    ncol=len(legend_labels),
    bbox_to_anchor=(0.5, 0.92),
    bbox_transform=fig.transFigure,
    frameon=False,
)
fig.legend(legend_handles, legend_labels, **legend_options)

# Reserve the top 5% of the figure so the layout leaves room above the axes.
plt.tight_layout(rect=[0, 0, 1, 0.95])
figure3_path = os.path.join(REPORT_FIGURES_FOLDER, "figure3_cca_trajectories.png")
plt.savefig(figure3_path)
plt.close()
logger.info("Saved mean CCA trajectories by stage.")

# Figure 4: Entropy Boxplots
# Two side-by-side panels of per-stage entropy, both fixed to the 0-3 range.
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
# Each entry: (entropy column, panel title, x position of the panel tag, tag text).
entropy_panels = (
    ('cca_corr1_compute_entropy', r"Entropy: $\rho_1$", 0.05, 'Panel A'),
    ('cca_corr2_compute_entropy', r"Entropy: $\rho_2$", 0.55, 'Panel B'),
)
for panel_ax, (entropy_column, panel_title, tag_x, tag_text) in zip(axs, entropy_panels):
    sns.boxplot(x='stage', y=entropy_column, data=entropy_by_subject_stage, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.grid(alpha=0.3)
    panel_ax.set_ylim(0, 3)
    fig.text(
        tag_x, 0.95, tag_text,
        ha='left', va='center', rotation='horizontal',
        fontsize=12, fontweight='bold',
    )

# Only the left panel carries a y-axis; the right one reuses the same scale.
axs[0].set_ylabel("Entropy")
axs[1].yaxis.set_visible(False)

plt.tight_layout()
figure4_path = os.path.join(REPORT_FIGURES_FOLDER, "figure4_entropy_boxplots.png")
plt.savefig(figure4_path)
plt.close()
logger.info("Saved entropy boxplots for CCA correlations.")
