In [None]:
# static_cca.py
from sklearn.cross_decomposition import CCA
from datetime import datetime, timedelta
from collections import Counter
import pandas as pd
import numpy as np
import mne
import os

# Parameters
data_folder = "data"  # folder scanned for .edf recordings and their .annot files
eeg_channels = [
    'C3_M2',
    'C4_M1',
    'O1_M2',
    'O2_M1',
]
eog_channels = ['LOC', 'ROC']
valid_stages = ['W', 'N1', 'N2', 'N3', 'R']  # annotation labels kept for analysis
fmt = "%H:%M:%S"  # clock-time format used to parse annotation start/stop fields

# Collects one summary dict per processed (recording, stage)
summary_results = []

# Pair every EDF file with the annotation file of the same base name
file_pairs = [
    (edf_name, edf_name.replace(".edf", ".annot"))
    for edf_name in os.listdir(data_folder)
    if edf_name.endswith(".edf")
]

import gc  # only the per-recording cleanup below needs it

# The analysis and report scripts in this project load the static-CCA outputs
# from <data_folder>/static_cca, so the projection CSVs are written there.
static_cca_folder = os.path.join(data_folder, "static_cca")
os.makedirs(static_cca_folder, exist_ok=True)

# For each EDF/.annot pair: cut the recording into the annotated sleep-stage
# epochs, fit one CCA (EEG vs. EOG) per stage on the concatenated epochs, save
# the downsampled canonical projections and collect summary statistics.
for edf_file, annot_file in file_pairs:
    edf_path = os.path.join(data_folder, edf_file)
    annot_path = os.path.join(data_folder, annot_file)

    if not os.path.exists(annot_path):
        continue

    # Bind these before `try` so the cleanup in `finally` never touches an
    # unbound name when reading the EDF fails.
    raw = None
    eeg_data = None
    eog_data = None

    try:
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        sfreq = int(raw.info['sfreq'])
        meas_date = raw.info['meas_date']
        if meas_date is None:
            # Annotation times are wall-clock strings; without a recording
            # start they cannot be converted into offsets.
            raise ValueError("EDF has no measurement date")
        start_datetime = meas_date.replace(tzinfo=None)
        print(f"EDF {edf_path} Start:", start_datetime)

        # Read annotations; the first line is skipped and only rows with
        # exactly six tab-separated fields are used.
        with open(annot_path, "r", encoding="utf-8", errors="ignore") as f:
            lines = f.readlines()[1:]
        annotations = []
        for line in lines:
            parts = line.strip().split("\t")
            if len(parts) == 6:
                stage, _, _, start, stop, _ = parts
                annotations.append((stage, start, stop))

        # Convert clock times into seconds relative to the recording start
        parsed_epochs = []
        for stage, start_str, stop_str in annotations:
            if stage not in valid_stages:
                continue
            try:
                start_clock = datetime.strptime(start_str, fmt).time()
                stop_clock = datetime.strptime(stop_str, fmt).time()

                start_time = datetime.combine(start_datetime.date(), start_clock)
                stop_time = datetime.combine(start_datetime.date(), stop_clock)

                # Handle midnight crossing: an epoch that ends "before" it
                # starts wraps into the next day, and an epoch that starts
                # "before" the recording belongs to the following day.
                if stop_time < start_time:
                    stop_time += timedelta(days=1)
                if start_time < start_datetime:
                    start_time += timedelta(days=1)
                    stop_time += timedelta(days=1)

                start_sec = (start_time - start_datetime).total_seconds()
                stop_sec = (stop_time - start_datetime).total_seconds()

                parsed_epochs.append((stage, start_sec, stop_sec))
            except Exception as e:
                print(f"Annotation parse error: {e}")

        print('epochs parsed: ', parsed_epochs[:5])
        stage_counts = Counter(stage for stage, _, _ in parsed_epochs)
        print(f'Stage count for subject {edf_file}: {stage_counts}')

        # Organize EEG and EOG segments per stage
        eeg_data = {stage: [] for stage in valid_stages}
        eog_data = {stage: [] for stage in valid_stages}
        for s, start, stop in parsed_epochs:
            start_sample = round(start * sfreq)
            # Clip the epoch end to the length of the recording
            stop_sample = min(round(stop * sfreq), raw.n_times)

            if start_sample < 0 or start_sample >= stop_sample:
                print(f"Invalid sample range for stage {s}: start={start_sample}, stop={stop_sample}, total={raw.n_times}")
                continue

            try:
                eeg = raw.get_data(picks=eeg_channels, start=start_sample, stop=stop_sample)
                eog = raw.get_data(picks=eog_channels, start=start_sample, stop=stop_sample)
                # Append only after both reads succeeded so the two lists
                # always stay aligned.
                eeg_data[s].append(eeg)
                eog_data[s].append(eog)
            except Exception as e:
                print(f"Failed to extract data: stage={s}, start={start_sample}, stop={stop_sample}, error: {e}")
                continue

        # Define downsampling factor (e.g. 1 Hz)
        target_fs = 1
        factor = int(sfreq / target_fs)

        # Perform CCA per stage
        for stage in valid_stages:
            if not eeg_data[stage] or not eog_data[stage]:
                print(f"Skipping stage {stage}: no data available.")
                continue
            eeg_agg = np.hstack(eeg_data[stage])
            eog_agg = np.hstack(eog_data[stage])

            try:
                # Transpose to (n_samples, n_features)
                X = eeg_agg.T
                Y = eog_agg.T

                # Ensure same number of samples
                min_len = min(len(X), len(Y))
                X = X[:min_len]
                Y = Y[:min_len]

                cca = CCA(n_components=2)
                X_c, Y_c = cca.fit_transform(X, Y)

                # ---- Save downsampled canonical projections ----
                if len(X_c) > factor:
                    X_c_ds = X_c[::factor, :]
                    Y_c_ds = Y_c[::factor, :]
                else:
                    X_c_ds = X_c
                    Y_c_ds = Y_c

                file_prefix = f"{edf_file.replace('.edf','')}_{stage}"

                pd.DataFrame(X_c_ds, columns=["Xc_1", "Xc_2"]).to_csv(
                    os.path.join(static_cca_folder, f"{file_prefix}_Xc_downsampled.csv"), index=False
                )
                pd.DataFrame(Y_c_ds, columns=["Yc_1", "Yc_2"]).to_csv(
                    os.path.join(static_cca_folder, f"{file_prefix}_Yc_downsampled.csv"), index=False
                )

                # Correlation between matching canonical variates
                corr_coeffs = [np.corrcoef(X_c[:, i], Y_c[:, i])[0, 1] for i in range(X_c.shape[1])]

                # ---- Compute summary statistics ----
                summary = {
                    "subject": edf_file,
                    "stage": stage,
                    "cca_corr1": corr_coeffs[0],
                    "cca_corr2": corr_coeffs[1] if len(corr_coeffs) > 1 else np.nan
                }

                # Same statistics for the EEG-side (Xc) and EOG-side (Yc) scores
                for label, scores in (("Xc", X_c), ("Yc", Y_c)):
                    for idx in range(scores.shape[1]):
                        vals = scores[:, idx]
                        summary[f"{label}{idx+1}_mean"] = np.mean(vals)
                        summary[f"{label}{idx+1}_std"] = np.std(vals)
                        summary[f"{label}{idx+1}_25p"] = np.percentile(vals, 25)
                        summary[f"{label}{idx+1}_median"] = np.median(vals)
                        summary[f"{label}{idx+1}_75p"] = np.percentile(vals, 75)

                summary_results.append(summary)
                print(f'Summary resulst for edf {edf_file} for stage {stage} written')

            except Exception as e:
                print(f"CCA failed for {edf_file}, stage {stage}: {e}")
    except Exception as e:
        print(f"Failed on {edf_file}: {e}")
    finally:
        # Runs after every recording (success or failure): drop the references
        # to the preloaded recording and the per-stage segment lists before
        # the next file is read, then force a collection.
        raw = None
        eeg_data = None
        eog_data = None
        gc.collect()

# Save results
# The analysis and report scripts load this summary from
# <data_folder>/static_cca/eeg_eog_cca_summary_stats.csv, so it is written there.
results_df = pd.DataFrame(summary_results)
results_folder = os.path.join(data_folder, "static_cca")
os.makedirs(results_folder, exist_ok=True)
results_csv_path = os.path.join(results_folder, "eeg_eog_cca_summary_stats.csv")
results_df.to_csv(results_csv_path, index=False)

print('Results')
print(results_df)


In [None]:
# time_resolved_cca.py
import os
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from collections import Counter
from sklearn.cross_decomposition import CCA

import mne

# Parameters
data_folder = "data"
# Per-stage CCA time series are written below the data folder
output_folder = os.path.join(data_folder, "time_resolved_cca")
os.makedirs(output_folder, exist_ok=True)

eeg_channels = [
    'C3_M2',
    'C4_M1',
    'O1_M2',
    'O2_M1',
]
eog_channels = ['LOC', 'ROC']
valid_stages = ['W', 'N1', 'N2', 'N3', 'R']  # annotation labels kept for analysis
fmt = "%H:%M:%S"  # clock-time format used to parse annotation start/stop fields

win_len = 30  # seconds
step_len = 15  # seconds

# Pair every EDF file with the annotation file of the same base name
file_pairs = [
    (edf_name, edf_name.replace(".edf", ".annot"))
    for edf_name in os.listdir(data_folder)
    if edf_name.endswith(".edf")
]

import gc  # only the per-recording cleanup below needs it

# For each EDF/.annot pair: slide a win_len-second window (advancing by
# step_len seconds) through every annotated sleep-stage epoch, fit a
# 2-component CCA between EEG and EOG in each window, and save one
# correlation time series per stage.
for edf_file, annot_file in file_pairs:
    edf_path = os.path.join(data_folder, edf_file)
    annot_path = os.path.join(data_folder, annot_file)

    if not os.path.exists(annot_path):
        continue

    # Bind before `try` so `finally` is safe when reading the EDF fails
    raw = None

    try:
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        sfreq = int(raw.info['sfreq'])
        meas_date = raw.info['meas_date']
        if meas_date is None:
            # Annotation times are wall-clock strings; without a recording
            # start they cannot be converted into offsets.
            raise ValueError("EDF has no measurement date")
        start_datetime = meas_date.replace(tzinfo=None)
        print(f"EDF {edf_path} Start:", start_datetime)

        # Read annotations; the first line is skipped and only rows with
        # exactly six tab-separated fields are used.
        with open(annot_path, "r", encoding="utf-8", errors="ignore") as f:
            lines = f.readlines()[1:]
        annotations = []
        for line in lines:
            parts = line.strip().split("\t")
            if len(parts) == 6:
                stage, _, _, start, stop, _ = parts
                annotations.append((stage, start, stop))

        # Convert clock times into seconds relative to the recording start
        parsed_epochs = []
        for stage, start_str, stop_str in annotations:
            if stage not in valid_stages:
                continue
            try:
                start_clock = datetime.strptime(start_str, fmt).time()
                stop_clock = datetime.strptime(stop_str, fmt).time()

                start_time = datetime.combine(start_datetime.date(), start_clock)
                stop_time = datetime.combine(start_datetime.date(), stop_clock)

                # Midnight crossing: wrap the stop time, or the whole epoch,
                # into the next day when it falls before its reference.
                if stop_time < start_time:
                    stop_time += timedelta(days=1)
                if start_time < start_datetime:
                    start_time += timedelta(days=1)
                    stop_time += timedelta(days=1)

                start_sec = (start_time - start_datetime).total_seconds()
                stop_sec = (stop_time - start_datetime).total_seconds()

                parsed_epochs.append((stage, start_sec, stop_sec))
            except Exception as e:
                print(f"Annotation parse error: {e}")

        stage_to_results = {stage: [] for stage in valid_stages}
        for stage, start, stop in parsed_epochs:
            t = start
            # Only windows that fit completely inside the epoch are analysed
            while t + win_len <= stop:
                start_sample = round(t * sfreq)
                stop_sample = round((t + win_len) * sfreq)

                # Windows reaching outside the recording are skipped
                if start_sample >= 0 and stop_sample <= raw.n_times:
                    try:
                        eeg = raw.get_data(picks=eeg_channels, start=start_sample, stop=stop_sample)
                        eog = raw.get_data(picks=eog_channels, start=start_sample, stop=stop_sample)
                        # Transpose to (n_samples, n_features)
                        X = eeg.T
                        Y = eog.T

                        min_len = min(len(X), len(Y))
                        X = X[:min_len]
                        Y = Y[:min_len]

                        cca = CCA(n_components=2)
                        X_c, Y_c = cca.fit_transform(X, Y)
                        corr1 = np.corrcoef(X_c[:, 0], Y_c[:, 0])[0, 1]
                        corr2 = np.corrcoef(X_c[:, 1], Y_c[:, 1])[0, 1]

                        stage_to_results[stage].append({
                            "time_sec": t,
                            "cca_corr1": corr1,
                            "cca_corr2": corr2,
                            "subject": edf_file.replace(".edf", ""),
                            "stage": stage
                        })

                    except Exception as e:
                        print(f"CCA failed in {edf_file} stage {stage} at t={t}: {e}")
                t += step_len

        # Write each stage's series once, after all epochs of the recording
        # have been processed (not once per epoch).
        for result_stage, results in stage_to_results.items():
            if results:
                df_out = pd.DataFrame(results)
                file_prefix = f"{edf_file.replace('.edf','')}_{result_stage}_cca_timeseries.csv"
                df_out.to_csv(os.path.join(output_folder, file_prefix), index=False)
                print(f"Saved CCA timeseries for {edf_file} stage {result_stage}")

    except Exception as e:
        print(f"Failed on {edf_file}: {e}")

    finally:
        # Drop the reference to the preloaded recording before the next file
        # is read, then force a collection.
        raw = None
        gc.collect()


In [None]:
# static_cca_analyze_canonical_projections.py
from scipy.stats import skew, kurtosis, f_oneway
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import glob
import os

# === PARAMETERS ===
data_folder = "data"
output_folder = "static_cca_analysis"

# Figures and the summary CSV are saved under <data_folder>/<output_folder>.
# Create it up front: neither savefig nor to_csv creates a missing directory.
os.makedirs(os.path.join(data_folder, output_folder), exist_ok=True)

# File patterns
pattern_Xc = os.path.join(data_folder, "static_cca", "*_Xc_downsampled.csv")
pattern_Yc = os.path.join(data_folder, "static_cca", "*_Yc_downsampled.csv")

# Helper to load projection CSVs
def load_projection_files(pattern, prefix):
    """Load every projection CSV matching *pattern* into one long-form DataFrame.

    File names are expected to look like
    ``<subject>_<stage>_Xc_downsampled.csv`` (or ``..._Yc_downsampled.csv``),
    e.g. ``apples-170368_N1_Xc_downsampled.csv``. The stage is the text after
    the last underscore of the remaining stem, so subject identifiers may
    themselves contain underscores.

    Args:
        pattern: Glob pattern selecting the CSV files to load.
        prefix: Label stored in the ``type`` column of every returned row.

    Returns:
        A DataFrame with one row per (file, projection column, sample) and the
        columns ``projection``, ``value``, ``subject``, ``stage`` and ``type``.
        When no file could be loaded, an empty DataFrame with the columns
        ``subject``, ``stage``, ``projection``, ``value``, ``type`` is returned.
    """
    known_suffixes = ("_Xc_downsampled.csv", "_Yc_downsampled.csv")
    frames = []

    # sorted() makes the row order reproducible across runs and platforms
    for filepath in sorted(glob.glob(pattern)):
        filename = os.path.basename(filepath)

        stem = filename
        for suffix in known_suffixes:
            if stem.endswith(suffix):
                stem = stem[:-len(suffix)]
                break

        # Split on the LAST underscore only: "<subject>_<stage>"
        subject, separator, stage = stem.rpartition("_")
        if not separator or not subject or not stage:
            print(f"Skipping unexpected filename: {filename}")
            continue

        df = pd.read_csv(filepath)

        # Convert wide -> long: one row per (projection column, sample)
        df_long = df.melt(var_name="projection", value_name="value")
        df_long["subject"] = subject
        df_long["stage"] = stage
        df_long["type"] = prefix

        frames.append(df_long)

    if not frames:
        return pd.DataFrame(columns=["subject", "stage", "projection", "value", "type"])
    return pd.concat(frames, ignore_index=True)

# Load all downsampled files
df_Xc = load_projection_files(pattern_Xc, "Xc")
df_Yc = load_projection_files(pattern_Yc, "Yc")

all_data = pd.concat([df_Xc, df_Yc], ignore_index=True)
print(f"Loaded downsampled projections: {all_data.shape}")

if all_data.empty:
    print("No data loaded! Check file paths and patterns.")
    # Raise SystemExit directly instead of calling the interactive-only
    # exit() helper; the non-zero status marks the run as failed.
    raise SystemExit(1)

# Save combined file for convenience
all_data.to_csv(os.path.join(data_folder, "static_cca", "all_downsampled_projections.csv"), index=False)
print("Saved combined downsampled data.")

# --- Plot distributions per stage and projection
# One KDE plot and one boxplot is saved for each of Xc_1, Xc_2, Yc_1, Yc_2.
for proj_type in ("Xc", "Yc"):
    for comp in ("1", "2"):
        projection_name = f"{proj_type}_{comp}"

        selected = (all_data["type"] == proj_type) & (all_data["projection"] == projection_name)
        df_plot = all_data[selected]

        if df_plot.empty:
            print(f"No data for {projection_name}")
            continue

        # KDE plot: one density curve per stage, each normalised on its own
        kde_path = os.path.join(data_folder, output_folder, f"{projection_name}_kde.png")
        plt.figure(figsize=(10,5))
        sns.kdeplot(data=df_plot, x="value", hue="stage",
                    common_norm=False, fill=True, alpha=0.3, linewidth=1.5)
        plt.title(f"Density of {projection_name} across stages")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(kde_path)
        plt.close()
        print(f"Saved density plot for {projection_name}")

        # Boxplot: value distribution per stage, with the mean marked
        box_path = os.path.join(data_folder, output_folder, f"{projection_name}_boxplot.png")
        plt.figure(figsize=(8,5))
        sns.boxplot(x="stage", y="value", data=df_plot, showmeans=True)
        plt.title(f"Boxplot of {projection_name} across stages")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(box_path)
        plt.close()
        print(f"Saved boxplot for {projection_name}")

# --- Compute summary statistics by stage and projection
# One output row per (projection, stage) with moments of the pooled values.
summary_rows = []
for proj_type in ("Xc", "Yc"):
    for comp in ("1", "2"):
        projection_name = f"{proj_type}_{comp}"
        selected = (all_data["type"] == proj_type) & (all_data["projection"] == projection_name)
        df_proj = all_data[selected]
        if df_proj.empty:
            continue

        summary_rows.extend(
            {
                "projection": projection_name,
                "stage": stage_name,
                "mean": np.mean(stage_values),
                "std": np.std(stage_values),
                "skewness": skew(stage_values),
                "kurtosis": kurtosis(stage_values),
                "count": len(stage_values),
            }
            for stage_name, stage_values in df_proj.groupby("stage")["value"]
        )

summary_df = pd.DataFrame(summary_rows)
summary_csv = os.path.join(data_folder, output_folder, "canonical_projection_summary_by_stage.csv")
summary_df.to_csv(summary_csv, index=False)
print(f"Saved summary stats to {summary_csv}")

# --- Run ANOVA across stages
# One-way ANOVA per projection, with the sleep stage as grouping factor.
for proj_type in ("Xc", "Yc"):
    for comp in ("1", "2"):
        projection_name = f"{proj_type}_{comp}"
        selected = (all_data["type"] == proj_type) & (all_data["projection"] == projection_name)
        df_proj = all_data[selected]
        if df_proj.empty:
            continue

        groups = [stage_values.values for _, stage_values in df_proj.groupby("stage")["value"]]
        # A comparison needs at least two stages
        if len(groups) <= 1:
            continue
        fval, pval = f_oneway(*groups)
        print(f"ANOVA for {projection_name}: F = {fval:.3f}, p = {pval:.3e}")


In [None]:
# stataic_cca_analyze_summary_stats.py
from scipy.stats import f_oneway
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os

# === PARAMETERS ===
data_folder = "data"
output_folder = "static_cca_analysis"

# Figures and the aggregated CSV are saved under <data_folder>/<output_folder>.
# Create it up front: neither savefig nor to_csv creates a missing directory.
os.makedirs(os.path.join(data_folder, output_folder), exist_ok=True)

# --- Load summary CSV
summary_path = os.path.join(data_folder, "static_cca", "eeg_eog_cca_summary_stats.csv")
summary_df = pd.read_csv(summary_path)

print("Summary stats loaded:", summary_df.shape)

# --- Plot distributions of CCA correlations
# One boxplot per correlation column, grouped by sleep stage.
for var in ("cca_corr1", "cca_corr2"):
    plot_path = os.path.join(data_folder, output_folder, f"{var}_boxplot_by_stage.png")

    plt.figure(figsize=(8,5))
    sns.boxplot(data=summary_df, x="stage", y=var, showmeans=True)
    plt.title(f"Boxplot of {var} across sleep stages")
    plt.ylabel(var)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close()

    print(f"Saved boxplot for {var}")

# --- Run ANOVA on CCA correlations
# Stages with fewer than two non-missing values are left out of the test.
for var in ["cca_corr1", "cca_corr2"]:
    per_stage_values = [
        summary_df.loc[summary_df["stage"] == stage_name, var].dropna().values
        for stage_name in summary_df["stage"].unique()
    ]
    groups = [vals for vals in per_stage_values if len(vals) > 1]

    if len(groups) <= 1:
        print(f"Not enough groups for ANOVA on {var}")
        continue

    stat, pval = f_oneway(*groups)
    print(f"ANOVA for {var}: F={stat:.3f}, p={pval:.3e}")

# --- Aggregate mean ± std for CCA correlations
# Per-stage mean, std and count of both correlation columns, with the
# two-level column index flattened to names such as "cca_corr1_mean".
agg = summary_df.groupby("stage")[["cca_corr1", "cca_corr2"]].agg(["mean", "std", "count"])
agg.columns = ['_'.join(level_names) for level_names in agg.columns]
agg = agg.reset_index()
agg.to_csv(
    os.path.join(data_folder, output_folder, "cca_correlation_summary.csv"),
    index=False,
)
print("Saved aggregated summary CSV for cca_corr1 and cca_corr2.")

# --- Optional: Analyze canonical projection means
# Every summary column that holds the mean of an Xc/Yc projection
projection_cols = [
    column
    for column in summary_df.columns
    if "_mean" in column and ("Xc" in column or "Yc" in column)
]

for var in projection_cols:
    plot_path = os.path.join(data_folder, output_folder, f"{var}_boxplot_by_stage.png")

    plt.figure(figsize=(8,5))
    sns.boxplot(data=summary_df, x="stage", y=var, showmeans=True)
    plt.title(f"Boxplot of {var} across sleep stages")
    plt.ylabel(var)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close()

    print(f"Saved boxplot for {var}")


In [None]:
# time_resolved_cca_analysis.py
from scipy.stats import entropy, skew, kurtosis
from glob import glob
import pandas as pd
import numpy as np
import os

# Define path to CCA timeseries files
data_folder = "data/time_resolved_cca"
output_folder = "data/time_resolved_cca_analysis"
os.makedirs(output_folder, exist_ok=True)

# Load all *_cca_timeseries.csv files (sorted so the row order is reproducible)
all_files = sorted(glob(os.path.join(data_folder, "*_cca_timeseries.csv")))
if not all_files:
    # pd.concat([]) would fail with an error that does not name the real cause
    raise SystemExit(f"No *_cca_timeseries.csv files found in {data_folder}")
aggregated_data = pd.concat([pd.read_csv(f) for f in all_files], ignore_index=True)

# Function 1: Stagewise mean and std of CCA1 and CCA2
# Per-stage mean, std and count; the two-level column index is flattened to
# names such as "cca_corr1_mean" before the table is saved.
stagewise_stats = (
    aggregated_data
    .groupby("stage")[["cca_corr1", "cca_corr2"]]
    .agg(["mean", "std", "count"])
)
stagewise_stats.columns = ['_'.join(level_names).strip() for level_names in stagewise_stats.columns.values]
stagewise_stats = stagewise_stats.reset_index()
stagewise_stats.to_csv(os.path.join(output_folder, "stagewise_summary.csv"), index=False)

# Function 2: Compute temporal mean trajectories (binned over time)
bin_width = 600  # seconds per time bin
# Always build at least one bin so a maximum time of 0 still yields two edges
n_bins = max(1, int(np.ceil(aggregated_data["time_sec"].max() / bin_width)))
bin_edges = np.arange(n_bins + 1, dtype=float) * bin_width
# include_lowest=True keeps rows with time_sec == 0 in the first bin; with the
# default right-closed bins starting at 0 they would get no bin and be dropped
# by the groupby below.
aggregated_data["time_bin"] = pd.cut(aggregated_data["time_sec"], bins=bin_edges, include_lowest=True)
# observed=False: keep every (stage, time_bin) combination, including empty
# ones (NaN means), independent of the pandas default for categorical keys.
trajectory = (
    aggregated_data
    .groupby(["stage", "time_bin"], observed=False)[["cca_corr1", "cca_corr2"]]
    .mean()
    .reset_index()
)
trajectory.to_csv(os.path.join(output_folder, "mean_cca_trajectory_by_stage.csv"), index=False)

# Function 3: Compute subjectwise entropy of CCA1 and CCA2 per stage
def compute_entropy(x):
    """Return the Shannon entropy of *x* binned into 20 equal bins on [0, 1].

    Values outside [0, 1] are ignored by the histogram. A tiny constant is
    added to every bin so empty bins do not produce log(0).
    """
    bin_density = np.histogram(x, bins=20, range=(0, 1), density=True)[0]
    smoothed = bin_density + 1e-12  # add small value to avoid log(0)
    return entropy(smoothed)

# Per (subject, stage): entropy, mean, std, skewness and kurtosis of both
# correlation series. "mean" and "std" are requested by name: pandas resolves
# np.mean/np.std passed to groupby.agg to its own reducers and deprecates
# passing the NumPy callables, so the names keep the results and the
# "<column>_mean"/"<column>_std" labels while avoiding the deprecation.
stat_funcs = [compute_entropy, "mean", "std", skew, kurtosis]
entropy_stats = (
    aggregated_data.groupby(["subject", "stage"])
    .agg({
        "cca_corr1": stat_funcs,
        "cca_corr2": stat_funcs
    })
)
# Flatten the two-level columns to e.g. "cca_corr1_compute_entropy"
entropy_stats.columns = ['_'.join(col).strip() for col in entropy_stats.columns.values]
entropy_stats.reset_index(inplace=True)
entropy_stats.to_csv(os.path.join(output_folder, "entropy_by_subject_stage.csv"), index=False)

# Function 4: Save a few representative trajectories (handpicked or automatic later)
unique_subjects = aggregated_data["subject"].drop_duplicates()
# sample() raises when asked for more items than exist, so cap the request at
# the number of available subjects (still 3 whenever at least 3 exist).
n_sampled = min(3, len(unique_subjects))
sampled_subjects = unique_subjects.sample(n_sampled, random_state=42).tolist()
subset = aggregated_data[aggregated_data["subject"].isin(sampled_subjects)]
subset.to_csv(os.path.join(output_folder, "subset_trajectories.csv"), index=False)

In [None]:
# time_resolved_cca_plotting_groupped.py
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os

# Folder that receives every figure produced below
figures_folder = "data/time_resolved_cca_analysis/figures"
os.makedirs(figures_folder, exist_ok=True)

# Load the four analysis tables, in the same order as the names on the left
(
    stagewise_summary,
    mean_cca_trajectory_by_stage,
    entropy_by_subject_stage,
    subset_trajectories,
) = (
    pd.read_csv(csv_path)
    for csv_path in (
        "data/time_resolved_cca_analysis/stagewise_summary.csv",
        "data/time_resolved_cca_analysis/mean_cca_trajectory_by_stage.csv",
        "data/time_resolved_cca_analysis/entropy_by_subject_stage.csv",
        "data/time_resolved_cca_analysis/subset_trajectories.csv",
    )
)

# 1. Boxplot of cca_corr1 and cca_corr2 by stage
# Both figures are drawn from the subset_trajectories table.
for corr_column, figure_title, figure_file in [
    ("cca_corr1", 'CCA Corr1 Distribution by Sleep Stage (Time-resolved)', "boxplot_cca_corr1_by_stage.png"),
    ("cca_corr2", 'CCA Corr2 Distribution by Sleep Stage (Time-resolved)', "boxplot_cca_corr2_by_stage.png"),
]:
    plt.figure(figsize=(10, 6))
    sns.boxplot(data=subset_trajectories, x='stage', y=corr_column)
    plt.title(figure_title)
    plt.savefig(os.path.join(figures_folder, figure_file))
    plt.close()

# 2. Lineplot of mean trajectories per stage over time
def _bin_left_edge(label):
    """Return the left edge of an interval label such as "(0.0, 600.0]" as a float."""
    return float(str(label).strip("([ ").split(",")[0])


# Rank the time-bin labels by their left edge so that every stage is drawn on
# the same x axis. The DataFrame row index is different for each stage and
# would place the stages one after another instead of overlaying them.
_ordered_bins = sorted(
    mean_cca_trajectory_by_stage['time_bin'].dropna().unique(),
    key=_bin_left_edge,
)
_bin_index = {label: position for position, label in enumerate(_ordered_bins)}

for corr_column, corr_name, short_name, figure_file in [
    ("cca_corr1", "CCA Corr1", "CCA1", "trajectory_cca_corr1_by_stage.png"),
    ("cca_corr2", "CCA Corr2", "CCA2", "trajectory_cca_corr2_by_stage.png"),
]:
    plt.figure(figsize=(12, 6))
    for stage in mean_cca_trajectory_by_stage['stage'].unique():
        stage_data = mean_cca_trajectory_by_stage[mean_cca_trajectory_by_stage['stage'] == stage]
        plt.plot(
            stage_data['time_bin'].map(_bin_index),
            stage_data[corr_column],
            label=f"{stage} - {short_name}",
        )
    plt.title(f"Mean {corr_name} Trajectory Over Time by Stage")
    plt.xlabel("Time Bin Index")
    plt.ylabel(corr_name)
    plt.legend()
    plt.savefig(os.path.join(figures_folder, figure_file))
    plt.close()

# 3. Entropy values per stage for CCA1 and CCA2
# One boxplot per entropy column of the per-subject, per-stage table.
for entropy_column, figure_title, figure_file in [
    ('cca_corr1_compute_entropy', 'CCA Corr1 Entropy by Sleep Stage', "entropy_cca_corr1_by_stage.png"),
    ('cca_corr2_compute_entropy', 'CCA Corr2 Entropy by Sleep Stage', "entropy_cca_corr2_by_stage.png"),
]:
    plt.figure(figsize=(10, 6))
    sns.boxplot(data=entropy_by_subject_stage, x='stage', y=entropy_column)
    plt.title(figure_title)
    plt.savefig(os.path.join(figures_folder, figure_file))
    plt.close()


In [None]:
# generate_figures_report.py
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# === Parameters ===
data_folder = "data/time_resolved_cca_analysis"  # time-resolved analysis tables
summary_folder = "data"                          # parent of the static_cca folder
output_folder = "report/figs"                    # destination of the report figures
os.makedirs(output_folder, exist_ok=True)

# === Load Data ===
# Static summary first, then the three time-resolved tables in the order of
# the names on the left.
summary_df = pd.read_csv(
    os.path.join(summary_folder, "static_cca", "eeg_eog_cca_summary_stats.csv")
)
subset_trajectories, mean_cca_trajectory_by_stage, entropy_by_subject_stage = (
    pd.read_csv(os.path.join(data_folder, csv_name))
    for csv_name in (
        "subset_trajectories.csv",
        "mean_cca_trajectory_by_stage.csv",
        "entropy_by_subject_stage.csv",
    )
)

# === Figure 1: Static CCA Boxplots ===
# Two side-by-side panels; only the left panel shows a y axis and label.
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
panels = [
    ("cca_corr1", "Static CCA: cca_corr1", 'Panel A', 0.05),
    ("cca_corr2", "Static CCA: cca_corr2", 'Panel B', 0.55),
]
for panel_pos, (corr_column, panel_title, panel_tag, tag_x) in enumerate(panels):
    ax = axs[panel_pos]
    sns.boxplot(x="stage", y=corr_column, data=summary_df, showmeans=True, ax=ax)
    ax.set_title(panel_title)
    if panel_pos == 0:
        ax.set_ylabel("Correlation")
    ax.grid(alpha=0.3)
    ax.set_ylim(0,1)
    if panel_pos == 1:
        ax.yaxis.set_visible(False)
    fig.text(tag_x, 0.95, panel_tag, ha='left', va='center', rotation='horizontal', fontsize=12)

plt.tight_layout()
plt.savefig(os.path.join(output_folder, "figure1_static_cca_boxplots.png"))
plt.close()

# === Figure 2: Time-resolved CCA Boxplots ===
# Two side-by-side panels; only the left panel shows a y axis and label.
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
panels = [
    ('cca_corr1', 'Time-Resolved CCA: cca_corr1', 'Panel A', 0.05),
    ('cca_corr2', 'Time-Resolved CCA: cca_corr2', 'Panel B', 0.55),
]
for panel_pos, (corr_column, panel_title, panel_tag, tag_x) in enumerate(panels):
    ax = axs[panel_pos]
    sns.boxplot(x='stage', y=corr_column, data=subset_trajectories, ax=ax)
    ax.set_title(panel_title)
    if panel_pos == 0:
        ax.set_ylabel("Correlation")
    ax.grid(alpha=0.3)
    ax.set_ylim(0,1)
    if panel_pos == 1:
        ax.yaxis.set_visible(False)
    fig.text(tag_x, 0.95, panel_tag, ha='left', va='center', rotation='horizontal', fontsize=12)

plt.tight_layout()
plt.savefig(os.path.join(output_folder, "figure2_time_resolved_boxplots.png"))
plt.close()

# === Figure 3: Mean Trajectories by Stage ===
def _time_bin_left_edge(label):
    """Return the left edge of an interval label such as "(0.0, 600.0]" as a float."""
    return float(str(label).strip("([ ").split(",")[0])


# Rank the time-bin labels by their left edge so that every stage is drawn on
# the same x axis. The DataFrame row index is different for each stage and
# would place the stages one after another instead of overlaying them.
_ordered_time_bins = sorted(
    mean_cca_trajectory_by_stage['time_bin'].dropna().unique(),
    key=_time_bin_left_edge,
)
_time_bin_index = {label: position for position, label in enumerate(_ordered_time_bins)}

fig, axs = plt.subplots(1, 2, figsize=(14, 5))
for stage in mean_cca_trajectory_by_stage['stage'].unique():
    stage_rows = mean_cca_trajectory_by_stage[mean_cca_trajectory_by_stage['stage'] == stage]
    bin_positions = stage_rows['time_bin'].map(_time_bin_index)
    axs[0].plot(bin_positions, stage_rows['cca_corr1'], label=stage)
    axs[1].plot(bin_positions, stage_rows['cca_corr2'], label=stage)

axs[0].set_title("Mean Trajectory: cca_corr1")
axs[0].set_xlabel("Time Bin Index")
axs[0].set_ylabel("Correlation")
axs[0].grid(alpha=0.3)
axs[0].set_ylim(0,1)
fig.text(0.05, 0.9, 'Panel A', ha='left', va='center', rotation='horizontal', fontsize=12)

# The right panel uses the same y range as the left one, so its axis is hidden
axs[1].set_title("Mean Trajectory: cca_corr2")
axs[1].set_xlabel("Time Bin Index")
axs[1].grid(alpha=0.3)
axs[1].set_ylim(0,1)
axs[1].yaxis.set_visible(False)
fig.text(0.55, 0.9, 'Panel B', ha='left', va='center', rotation='horizontal', fontsize=12)

# One shared legend for both panels, built from the left panel's lines
handles, labels = axs[0].get_legend_handles_labels()

fig.legend(
    handles, labels,
    loc='lower center',
    ncol=len(labels),
    bbox_to_anchor=(0.5, 0.92),
    bbox_transform=fig.transFigure,
    frameon=False
)

# Leave extra space at the top of the figure for the shared legend
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig(os.path.join(output_folder, "figure3_cca_trajectories.png"))
plt.close()

# === Figure 4: Entropy Boxplots ===
# Two side-by-side panels; only the left panel shows a y axis and label.
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
panels = [
    ('cca_corr1_compute_entropy', "Entropy: cca_corr1", 'Panel A', 0.05),
    ('cca_corr2_compute_entropy', "Entropy: cca_corr2", 'Panel B', 0.55),
]
for panel_pos, (entropy_column, panel_title, panel_tag, tag_x) in enumerate(panels):
    ax = axs[panel_pos]
    sns.boxplot(x='stage', y=entropy_column, data=entropy_by_subject_stage, ax=ax)
    ax.set_title(panel_title)
    if panel_pos == 0:
        ax.set_ylabel("Entropy")
    ax.grid(alpha=0.3)
    ax.set_ylim(0,3)
    if panel_pos == 1:
        ax.yaxis.set_visible(False)
    fig.text(tag_x, 0.95, panel_tag, ha='left', va='center', rotation='horizontal', fontsize=12)

plt.tight_layout()
plt.savefig(os.path.join(output_folder, "figure4_entropy_boxplots.png"))
plt.close()
