# Step 1: Calculating the longest consecutive run of EC epochs for each subject

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import mne
from mne.time_frequency import psd_array_welch
import numpy as np

In [None]:
# Load the ensemble prediction CSV
df = pd.read_csv("E:/ChristianMusaeus/label_predictions.csv")

# Length (in epochs) of the longest run of label-1 rows for each subject.
# Subjects with no label-1 row at all get no entry.
longest_runs = {}

for subject_id in df["Test subject ID"].unique():
    subj_df = df[df["Test subject ID"] == subject_id].sort_values(by="Epoch number").reset_index(drop=True)
    labels = subj_df["Label"].astype(int).values

    # A "run" is a stretch of adjacent rows (after sorting by epoch number)
    # that all carry label 1.
    # NOTE(review): adjacency is by row position, so a gap in the epoch
    # numbers would not break a run - confirm epoch numbers are contiguous.
    best_length = 0
    current_length = 0
    for label in labels:
        if label == 1:
            current_length += 1
            best_length = max(best_length, current_length)
        else:
            current_length = 0

    if best_length > 0:
        longest_runs[subject_id] = best_length


### Plotting the longest consecutive run against the number of subjects

In [None]:
# Per-subject longest run lengths, then how many subjects share each length
run_lengths = pd.Series(longest_runs)
length_counts = run_lengths.value_counts().sort_index()

# One bar per run length; bar height is the number of subjects
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(length_counts.index, length_counts.values)
ax.set_xlabel("Length of Longest Consecutive Eyes-Closed Run (Epochs)")
ax.set_ylabel("Number of Test Subjects")
ax.set_title("Distribution of Longest Eyes-Closed Runs Across Subjects")
ax.grid(True, axis='y')
fig.tight_layout()
plt.show()


# Step 2: Extracting 60 epochs from each subject

### For each subject, the consecutive runs of eyes-closed (EC) epochs are identified, starting with the longest. The epochs in these consecutive runs are added to a list of epochs for that subject. If there are no consecutive runs left, the EC epochs with the highest probability are added to the list, to make sure there are 60 eyes-closed epochs for each subject. If the longest consecutive run of eyes-closed epochs is shorter than 6, the subject is ignored. The subjects are stored as keys in a dictionary, and the list of the 60 chosen epochs is stored as the value.

In [None]:
# Function to extract top 60 eyes-closed epochs from long runs
def get_top_60_consecutive_closed_epochs(subj_df, n_epochs=60):
    """Select up to ``n_epochs`` eyes-closed epochs, longest runs first.

    Rows are sorted by "Epoch number"; maximal stretches of adjacent rows
    with ``Label == 1`` form runs. Runs are visited from longest to shortest
    (ties keep their original order), and within a run epochs are taken in
    order of decreasing "Probability". Collection stops as soon as
    ``n_epochs`` epochs have been gathered.

    Parameters
    ----------
    subj_df : pandas.DataFrame
        Rows for one subject with columns "Epoch number", "Label" and
        "Probability".
    n_epochs : int, default 60
        Maximum number of epochs to return.

    Returns
    -------
    numpy.ndarray
        The selected values of "Epoch number", in selection order. Shorter
        than ``n_epochs`` when the subject has fewer label-1 rows; empty when
        there are none.
    """
    subj_df = subj_df.sort_values(by="Epoch number").reset_index(drop=True)
    labels = subj_df["Label"].astype(int).values
    probs = subj_df["Probability"].values
    epoch_nums = subj_df["Epoch number"].values

    # Pad with zeros so runs touching either end still produce a rising and
    # a falling edge: +1 marks the first row of a run, -1 the row after it.
    closed = (labels == 1).astype(int)
    edges = np.diff(np.concatenate(([0], closed, [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1) - 1

    # sorted() is stable, so equally long runs stay in chronological order.
    runs = sorted(zip(run_starts, run_ends), key=lambda r: r[1] - r[0] + 1, reverse=True)

    selected_epochs = []
    for first, last in runs:
        remaining = n_epochs - len(selected_epochs)
        if remaining <= 0:
            break
        run_indices = np.arange(first, last + 1)
        # Highest-probability rows of this run first
        ranked = run_indices[np.argsort(probs[run_indices])[::-1]]
        selected_epochs.extend(epoch_nums[ranked[:remaining]])

    return np.array(selected_epochs)

# Subjects that pass both filters, mapped to their selected epoch numbers
top_epochs_per_subject = {}

for subject_id in df["Test subject ID"].unique():
    subj_df = (
        df[df["Test subject ID"] == subject_id]
        .sort_values(by="Epoch number")
        .reset_index(drop=True)
    )
    labels = subj_df["Label"].astype(int).values

    # Longest streak of adjacent label-1 rows
    longest_streak = 0
    streak = 0
    for label in labels:
        streak = streak + 1 if label == 1 else 0
        longest_streak = max(longest_streak, streak)

    # Filter 1: the subject needs at least one run of 6 or more epochs
    if longest_streak >= 6:
        top_epochs = get_top_60_consecutive_closed_epochs(subj_df)

        # Filter 2: a full set of 60 epochs must be available
        if len(top_epochs) >= 60:
            top_epochs_per_subject[subject_id] = top_epochs
        else:
            print(f"⚠️ Only {len(top_epochs)} valid epochs for {subject_id}, skipping...")


import pickle

# Persist the selection so later cells can reload it without recomputing
with open("top_epochs_per_subject.pkl", "wb") as f:
    pickle.dump(top_epochs_per_subject, f)

print("saved top_epochs_per_subject to top_epochs_per_subject.pkl")


# Step 3: Calculate the mean absolute alpha power for each of the subjects for all channels 

In [None]:
# Mean 8-13 Hz Welch PSD per subject, averaged over frequency bins, the
# selected epochs and all channels, then scaled by 1e12.
alpha_by_subject = {}

for subject_id in top_epochs_per_subject:
    print(f"Processing alpha for {subject_id}...")

    set_path = f"G:/ChristianMusaeus/Preprocessed_setfiles/{subject_id}_epoched.set"

    try:
        epochs = mne.io.read_epochs_eeglab(set_path, verbose='ERROR')
        # Stored epoch numbers are used directly as indices into the epoch axis
        selected_data = epochs.get_data()[top_epochs_per_subject[subject_id]]

        psds, freqs = psd_array_welch(
            selected_data,
            sfreq=epochs.info["sfreq"],
            fmin=8,
            fmax=13,
            n_fft=200,
            verbose=False,
        )

        # Average over the last axis (frequency), then axis 0 (epochs),
        # then whatever remains (channels).
        per_channel = psds.mean(axis=-1).mean(axis=0)
        # presumably converts V^2/Hz to uV^2/Hz - confirm the data are in volts
        alpha_by_subject[subject_id] = per_channel.mean() * 1e12

    except Exception as e:
        print(f"❌ Could not process {subject_id}: {e}")


# Calculating average relative alpha power for each subject for all channels

In [None]:
import mne
from mne.time_frequency import psd_array_welch
import numpy as np

# Relative alpha power per subject: summed 8-13 Hz PSD divided by summed
# 1-40 Hz PSD, each first averaged over the selected epochs and channels.
relative_alpha_by_subject = {}

for subject_id in list(top_epochs_per_subject.keys()):
    print(f"Processing relative alpha for {subject_id}...")

    set_path = f"G:/ChristianMusaeus/Preprocessed_setfiles/{subject_id}_epoched.set"

    try:
        epochs = mne.io.read_epochs_eeglab(set_path, verbose='ERROR')
        selected_data = epochs.get_data()[top_epochs_per_subject[subject_id]]  # (epochs, channels, timepoints)

        # --- Single Welch pass over the total band (1–40 Hz) ---
        # The alpha bins are a subset of these bins, so they are selected
        # with a mask instead of running Welch a second time.
        psds_total, freqs = psd_array_welch(
            selected_data,
            sfreq=epochs.info["sfreq"],
            fmin=1, fmax=40,
            n_fft=200,
            verbose=False
        )

        # --- Alpha power (8–13 Hz, inclusive on both ends) ---
        alpha_mask = (freqs >= 8) & (freqs <= 13)
        alpha_power = psds_total[..., alpha_mask].sum(axis=-1).mean()

        # --- Total power (1–40 Hz) ---
        total_power = psds_total.sum(axis=-1).mean()

        # --- Relative alpha power ---
        relative_alpha = alpha_power / total_power
        relative_alpha_by_subject[subject_id] = relative_alpha

    except Exception as e:
        print(f"❌ Could not process {subject_id}: {e}")


# Step 4: Plotting mean absolute alpha power for all subjects and channels across age

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Load subject metadata (must contain 'subject_id' and 'age')
metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")

# Create DataFrame of alpha powers
alpha_df = pd.DataFrame({
    "subject_id": list(alpha_by_subject.keys()),
    "Alpha Power": list(alpha_by_subject.values())
})

# Use string IDs on both sides so the merge cannot fail or silently drop
# rows when the CSV column and the dictionary keys have different dtypes.
metadata["subject_id"] = metadata["subject_id"].astype(str)
alpha_df["subject_id"] = alpha_df["subject_id"].astype(str)

# Merge with metadata to get ages
merged = pd.merge(alpha_df, metadata, on="subject_id")

# --- Filter out extreme values ---
merged = merged[merged["Alpha Power"] <= 40]

# Group by age and average
grouped = merged.groupby("age").agg(
    MeanAlpha=("Alpha Power", "mean"),
    N=("Alpha Power", "count")
).reset_index()

# Plot (dots only): one point per age, the mean over subjects of that age
plt.figure(figsize=(16, 5))
plt.scatter(grouped["age"], grouped["MeanAlpha"], s=40)
plt.xlabel("Age")
plt.ylabel("Mean Absolute Alpha Power (µV²/Hz)")
plt.title("Mean Absolute Alpha Power vs. Age")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.tight_layout()
plt.show()


### Distribution of age and sex of all subjects 

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Count subjects for every (age, sex) pair, then pivot sex into columns
# (missing combinations become 0) with age as a regular column.
subjects_per_age_sex = metadata.groupby(['age', 'sex']).size()
grouped = subjects_per_age_sex.unstack(fill_value=0).reset_index()


In [None]:
metadata_path = "metadata_time_filtered.csv"
metadata = pd.read_csv(metadata_path)
# Normalise IDs to stripped strings so they can match the stringified
# dictionary keys used in the isin() lookup.
metadata["subject_id"] = metadata["subject_id"].astype(str).str.strip()
metadata["sex"] = metadata["sex"].astype(str).str.strip()

with open("top_epochs_per_subject.pkl", "rb") as f:
    top_epochs_per_subject = pickle.load(f)

top_ids = set(str(k).strip() for k in top_epochs_per_subject.keys())
filtered_metadata = metadata[metadata["subject_id"].isin(top_ids)]
# NOTE(review): the counts are taken from the full metadata table, not from
# filtered_metadata - confirm that "all subjects" is what this plot should show.
grouped = metadata.groupby(['age', 'sex']).size().unstack(fill_value=0).reset_index()


bar_width = 0.4
ages = grouped["age"]

# Bars are positioned at the actual age values so that the age ticks set
# below line up with the bars they describe.
plt.figure(figsize=(16, 5))
plt.bar(ages - bar_width/2, grouped["Female"], width=bar_width, label='Female', color='skyblue', edgecolor='black')
plt.bar(ages + bar_width/2, grouped["Male"], width=bar_width, label='Male', color='green', edgecolor='black')
plt.xlabel("Age")
plt.ylabel("Number of Subjects")
plt.title("Number of Subjects per Age by Sex")
plt.legend()
plt.grid(axis='y')

# One tick every 5 years across the observed age range
xticks = np.arange(ages.min(), ages.max()+1, 5)
plt.xticks(ticks=xticks)

plt.tight_layout()
plt.show()



### Number of subjects per sex

In [None]:
metadata_path = "metadata_time_filtered.csv"
metadata = pd.read_csv(metadata_path)

# Compare IDs and sex labels as whitespace-stripped strings
for column in ("subject_id", "sex"):
    metadata[column] = metadata[column].astype(str).str.strip()

# Reload the subject -> selected-epochs mapping written earlier in the notebook
with open("top_epochs_per_subject.pkl", "rb") as f:
    top_epochs_per_subject = pickle.load(f)

# Keep only metadata rows for subjects that have a selected-epoch entry
top_ids = {str(key).strip() for key in top_epochs_per_subject}
is_selected = metadata["subject_id"].isin(top_ids)
filtered_metadata = metadata[is_selected]

gender_counts = filtered_metadata["sex"].value_counts()

print(gender_counts)

# Mean relative alpha power across age 

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Load subject metadata (must contain 'subject_id' and 'age')
metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")

# Create DataFrame of relative alpha powers
relative_df = pd.DataFrame({
    "subject_id": list(relative_alpha_by_subject.keys()),
    "Relative Alpha Power": list(relative_alpha_by_subject.values())
})

# Merge with metadata to get ages (string IDs on both sides)
metadata["subject_id"] = metadata["subject_id"].astype(str)
relative_df["subject_id"] = relative_df["subject_id"].astype(str)
merged = pd.merge(relative_df, metadata, on="subject_id")

# --- Filter out invalid values ---
# Relative power is a ratio of alpha-band to total power, so valid values
# cannot exceed 1; the comparison also drops NaN rows.
merged = merged[merged["Relative Alpha Power"] <= 1]

# Group by age and average
grouped = merged.groupby("age").agg(
    MeanRelAlpha=("Relative Alpha Power", "mean"),
    N=("Relative Alpha Power", "count")
).reset_index()

# Plot (dots only)
plt.figure(figsize=(16, 5))
plt.scatter(grouped["age"], grouped["MeanRelAlpha"], s=40)
plt.xlabel("Age")
# A ratio of two powers has no unit
plt.ylabel("Mean Relative Alpha Power (unitless)")
plt.title("Mean Relative Alpha Power vs. Age")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.tight_layout()
plt.show()


## Mean absolute alpha power for all channels across age with CI

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# --- Load and merge ---
metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")
alpha_df = pd.DataFrame({
    "subject_id": list(alpha_by_subject.keys()),
    "Alpha Power": list(alpha_by_subject.values())
})
# String IDs on both sides, matching the other merge cells in this notebook
metadata["subject_id"] = metadata["subject_id"].astype(str)
alpha_df["subject_id"] = alpha_df["subject_id"].astype(str)
merged = pd.merge(alpha_df, metadata, on="subject_id")
merged = merged[merged["Alpha Power"] <= 40]  # Filter out extreme values

# --- Group by age and compute stats ---
grouped = merged.groupby("age").agg(
    MeanAlpha=("Alpha Power", "mean"),
    StdAlpha=("Alpha Power", "std"),
    N=("Alpha Power", "count")
).reset_index()
# Normal-approximation 95% interval: mean ± 1.96 * std / sqrt(N).
# Ages with a single subject have an undefined std, so their bounds are NaN.
grouped["SEM"] = grouped["StdAlpha"] / np.sqrt(grouped["N"])
grouped["CI95"] = 1.96 * grouped["SEM"]
grouped["Lower"] = grouped["MeanAlpha"] - grouped["CI95"]
grouped["Upper"] = grouped["MeanAlpha"] + grouped["CI95"]

# --- Plot with shaded confidence bands ---
plt.figure(figsize=(16, 5))

# Scatter raw mean points
plt.scatter(grouped["age"], grouped["MeanAlpha"], color="blue", s=30, label="Mean Absolute Alpha Power")

# Line plot of the mean
plt.plot(grouped["age"], grouped["MeanAlpha"], color="red", label="Mean Trend")

# Shaded 95% CI area
plt.fill_between(grouped["age"], grouped["Lower"], grouped["Upper"],
                 color='skyblue', alpha=0.4, label="95% Confidence Interval")

plt.xlabel("Age")
plt.ylabel("Mean Absolute Alpha Power (µV²/Hz)")
plt.title("Mean Absolute Alpha Power vs. Age with 95% CI (Shaded)")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.legend()
plt.tight_layout()
plt.show()


## Mean relative alpha power for all channels across age with CI

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# --- Load and merge ---
metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")
relative_df = pd.DataFrame({
    "subject_id": list(relative_alpha_by_subject.keys()),
    "Relative Alpha Power": list(relative_alpha_by_subject.values())
})
# String IDs on both sides, matching the other merge cells in this notebook
metadata["subject_id"] = metadata["subject_id"].astype(str)
relative_df["subject_id"] = relative_df["subject_id"].astype(str)
merged = pd.merge(relative_df, metadata, on="subject_id")
# Relative power is a ratio bounded by 1; the comparison also drops NaN rows
merged = merged[merged["Relative Alpha Power"] <= 1]

# --- Group by age and compute stats ---
grouped = merged.groupby("age").agg(
    MeanRelAlpha=("Relative Alpha Power", "mean"),
    StdAlpha=("Relative Alpha Power", "std"),
    N=("Relative Alpha Power", "count")
).reset_index()
# Normal-approximation 95% interval: mean ± 1.96 * std / sqrt(N)
grouped["SEM"] = grouped["StdAlpha"] / np.sqrt(grouped["N"])
grouped["CI95"] = 1.96 * grouped["SEM"]
grouped["Lower"] = grouped["MeanRelAlpha"] - grouped["CI95"]
grouped["Upper"] = grouped["MeanRelAlpha"] + grouped["CI95"]

# --- Plot with shaded confidence bands ---
plt.figure(figsize=(16, 5))

# Scatter raw mean points
plt.scatter(grouped["age"], grouped["MeanRelAlpha"], color="blue", s=30, label="Mean Relative Alpha Power")

# Line plot of the mean
plt.plot(grouped["age"], grouped["MeanRelAlpha"], color="red", label="Mean Trend")

# Shaded 95% CI area
plt.fill_between(grouped["age"], grouped["Lower"], grouped["Upper"],
                 color='skyblue', alpha=0.4, label="95% Confidence Interval")

plt.xlabel("Age")
# A ratio of two powers has no unit
plt.ylabel("Mean Relative Alpha Power (unitless)")
plt.title("Mean Relative Alpha Power vs. Age with 95% CI (Shaded)")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.legend()
plt.tight_layout()
plt.show()


## Plotting absolute alpha power for all channels with CI (Smoothed)

In [None]:
from scipy.interpolate import make_interp_spline
import numpy as np
import matplotlib.pyplot as plt

# Build the absolute-alpha per-age table inside this cell. The global
# `grouped` left by the preceding cell holds relative alpha and has no
# "MeanAlpha" column, so it cannot be reused here. Separate names are used
# so that the global `merged` / `grouped` stay untouched for later cells.
abs_metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")
abs_alpha_df = pd.DataFrame({
    "subject_id": list(alpha_by_subject.keys()),
    "Alpha Power": list(alpha_by_subject.values())
})
abs_metadata["subject_id"] = abs_metadata["subject_id"].astype(str)
abs_alpha_df["subject_id"] = abs_alpha_df["subject_id"].astype(str)
abs_merged = pd.merge(abs_alpha_df, abs_metadata, on="subject_id")
abs_merged = abs_merged[abs_merged["Alpha Power"] <= 40]  # Filter out extreme values

abs_grouped = abs_merged.groupby("age").agg(
    MeanAlpha=("Alpha Power", "mean"),
    StdAlpha=("Alpha Power", "std"),
    N=("Alpha Power", "count")
).reset_index()
abs_grouped["SEM"] = abs_grouped["StdAlpha"] / np.sqrt(abs_grouped["N"])
abs_grouped["CI95"] = 1.96 * abs_grouped["SEM"]
abs_grouped["Lower"] = abs_grouped["MeanAlpha"] - abs_grouped["CI95"]
abs_grouped["Upper"] = abs_grouped["MeanAlpha"] + abs_grouped["CI95"]

# Drop ages whose mean or CI bounds are missing or infinite
grouped_clean = abs_grouped.replace([np.inf, -np.inf], np.nan).dropna(subset=["age", "MeanAlpha","Lower", "Upper"])

# Original values
x = grouped_clean["age"].values
y = grouped_clean["MeanAlpha"].values
y_lower = grouped_clean["Lower"].values
y_upper = grouped_clean["Upper"].values

# Smooth x-axis
x_smooth = np.linspace(x.min(), x.max(), 1000)

# Degree-5 interpolating splines: each curve passes through every per-age
# value; the mean and the two CI bounds are fitted independently.
spline_mean = make_interp_spline(x, y, k=5)
spline_lower = make_interp_spline(x, y_lower, k=5)
spline_upper = make_interp_spline(x, y_upper, k=5)

y_smooth = spline_mean(x_smooth)
y_lower_smooth = spline_lower(x_smooth)
y_upper_smooth = spline_upper(x_smooth)

# --- Plot ---
plt.figure(figsize=(12, 5))
plt.scatter(x, y, color='blue', s=20, label="Mean Alpha Power")
plt.plot(x_smooth, y_smooth, color='blue', label="Smoothed Mean")
plt.fill_between(x_smooth, y_lower_smooth, y_upper_smooth,
                 color='skyblue', alpha=0.30, label="95% CI")

plt.xlabel("Age")
plt.ylabel("Mean Absolute Alpha Power (µV²/Hz)")
plt.title("Mean Absolute Alpha Power vs. Age with 95% Confidence Interval")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.legend()
plt.tight_layout()
plt.show()


## Plotting mean relative alpha power for all channels with CI (Smoothed)

In [None]:
from scipy.interpolate import make_interp_spline
import numpy as np
import matplotlib.pyplot as plt

# Expects the global `grouped` to be the per-age relative-alpha table with
# "MeanRelAlpha", "Lower" and "Upper" columns. Rows with missing or
# infinite values in those columns are dropped.
grouped_clean = grouped.replace([np.inf, -np.inf], np.nan).dropna(subset=["age", "MeanRelAlpha","Lower", "Upper"])

# Original values
x = grouped_clean["age"].values
y = grouped_clean["MeanRelAlpha"].values
y_lower = grouped_clean["Lower"].values
y_upper = grouped_clean["Upper"].values

# Smooth x-axis
x_smooth = np.linspace(x.min(), x.max(), 1000)

# Degree-5 interpolating splines: each curve passes through every per-age
# value; the mean and the two CI bounds are fitted independently.
spline_mean = make_interp_spline(x, y, k=5)
spline_lower = make_interp_spline(x, y_lower, k=5)
spline_upper = make_interp_spline(x, y_upper, k=5)

y_smooth = spline_mean(x_smooth)
y_lower_smooth = spline_lower(x_smooth)
y_upper_smooth = spline_upper(x_smooth)

# --- Plot ---
plt.figure(figsize=(12, 5))
plt.scatter(x, y, color='blue', s=20, label="Mean Relative Alpha Power")
plt.plot(x_smooth, y_smooth, color='blue', label="Smoothed Mean")
plt.fill_between(x_smooth, y_lower_smooth, y_upper_smooth,
                 color='skyblue', alpha=0.30, label="95% CI")

plt.xlabel("Age")
# A ratio of two powers has no unit
plt.ylabel("Mean Relative Alpha Power (unitless)")
plt.title("Mean Relative Alpha Power vs. Age with 95% Confidence Interval")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.legend()
plt.tight_layout()
plt.show()


# Plotting mean absolute alpha power across age by sex 

In [None]:
# Build the absolute-alpha table inside this cell. The global `merged` left
# by the preceding cells holds relative alpha and has no "Alpha Power" column.
abs_metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")
abs_alpha_df = pd.DataFrame({
    "subject_id": list(alpha_by_subject.keys()),
    "Alpha Power": list(alpha_by_subject.values())
})
abs_metadata["subject_id"] = abs_metadata["subject_id"].astype(str)
abs_alpha_df["subject_id"] = abs_alpha_df["subject_id"].astype(str)
abs_merged = pd.merge(abs_alpha_df, abs_metadata, on="subject_id")
abs_merged = abs_merged[abs_merged["Alpha Power"] <= 40]  # Filter out extreme values

# Mean absolute alpha power and subject count per (age, sex) pair
grouped = abs_merged.groupby(["age", "sex"]).agg(
    MeanAlpha=("Alpha Power", "mean"),
    N=("Alpha Power", "count")
).reset_index()

female_data = grouped[grouped["sex"] == "Female"]
male_data = grouped[grouped["sex"] == "Male"]

# Plot (dots only)
plt.figure(figsize=(16, 5))
plt.scatter(female_data["age"], female_data["MeanAlpha"], marker='o', label="Female", color="blue")
plt.scatter(male_data["age"], male_data["MeanAlpha"], marker='o', label="Male", color="green")

plt.xlabel("Age")
plt.ylabel("Mean Absolute Alpha Power (µV²/Hz)")
plt.title("Mean Absolute Alpha Power vs. Age by sex")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.legend()
plt.tight_layout()
plt.show()


# Plotting mean relative alpha power across age by sex

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# --- Metadata, keyed by string subject IDs ---
metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str)

# --- Per-subject relative alpha power, keyed the same way ---
relative_df = pd.DataFrame({
    "subject_id": list(relative_alpha_by_subject.keys()),
    "Relative Alpha Power": list(relative_alpha_by_subject.values())
})
relative_df["subject_id"] = relative_df["subject_id"].astype(str)

# --- Attach age and sex to every subject ---
merged = pd.merge(relative_df, metadata, on="subject_id")

# --- Keep only rows whose relative power does not exceed 1 ---
within_range = merged["Relative Alpha Power"] <= 1
merged = merged[within_range]

# --- Mean and subject count for every (age, sex) pair ---
grouped = merged.groupby(["age", "sex"]).agg(
    MeanRelAlpha=("Relative Alpha Power", "mean"),
    N=("Relative Alpha Power", "count")
).reset_index()

female_data = grouped[grouped["sex"] == "Female"]
male_data = grouped[grouped["sex"] == "Male"]

# --- One scatter series per sex on a shared axis ---
fig, ax = plt.subplots(figsize=(16, 5))
for sex_rows, sex_label, sex_color in (
    (female_data, "Female", "blue"),
    (male_data, "Male", "green"),
):
    ax.scatter(sex_rows["age"], sex_rows["MeanRelAlpha"], label=sex_label, color=sex_color, s=40)

ax.set_xlabel("Age")
ax.set_ylabel("Mean Relative Alpha Power (unitless)")
ax.set_title("Mean Relative Alpha Power vs. Age by Sex - All Channels")
ax.set_xticks(np.arange(10, 100, 10))
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()


## Mean absolute alpha power vs. age by sex with smoothed CI

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline

# --- Per-subject absolute alpha joined with metadata on string IDs ---
alpha_df = pd.DataFrame({
    "subject_id": list(alpha_by_subject.keys()),
    "Alpha Power": list(alpha_by_subject.values())
})
alpha_df["subject_id"] = alpha_df["subject_id"].astype(str)

metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str)

merged = pd.merge(alpha_df, metadata, on="subject_id")

# --- Drop subjects above the cut-off of 20 ---
merged = merged.loc[merged["Alpha Power"] <= 20]

# --- Mean, std and count per (age, sex), then a 1.96 * SEM interval ---
grouped = merged.groupby(["age", "sex"]).agg(
    MeanAlpha=("Alpha Power", "mean"),
    StdAlpha=("Alpha Power", "std"),
    N=("Alpha Power", "count")
).reset_index()
grouped = grouped.assign(
    SEM=lambda g: g["StdAlpha"] / np.sqrt(g["N"]),
    CI95=lambda g: 1.96 * g["SEM"],
    Lower=lambda g: g["MeanAlpha"] - g["CI95"],
    Upper=lambda g: g["MeanAlpha"] + g["CI95"],
)

# --- Smoothing function for one sex ---
def plot_smooth_ci(data, label_text, line_color):
    """Draw a cubic-spline mean curve, its CI band and the raw points.

    ``data`` must have "age", "MeanAlpha", "Lower" and "Upper" columns. Rows
    with NaN or infinite values in those columns are ignored. Nothing is
    drawn (and None is returned) when fewer than 4 usable rows remain.
    Everything is drawn on the current matplotlib axes in ``line_color``;
    only the mean curve carries ``label_text`` as its legend label.
    """
    finite = data.replace([np.inf, -np.inf], np.nan)
    finite = finite.dropna(subset=["age", "MeanAlpha", "Lower", "Upper"])

    ages = finite["age"].values
    means = finite["MeanAlpha"].values
    if len(ages) < 4:
        return

    # Evaluate one interpolating cubic spline per curve on a dense age grid
    age_grid = np.linspace(ages.min(), ages.max(), 500)
    curves = {
        column: make_interp_spline(ages, finite[column].values, k=3)(age_grid)
        for column in ("MeanAlpha", "Lower", "Upper")
    }

    plt.plot(age_grid, curves["MeanAlpha"], color=line_color, label=label_text)
    plt.fill_between(age_grid, curves["Lower"], curves["Upper"], color=line_color, alpha=0.15)
    plt.scatter(ages, means, color=line_color, s=20, marker='o')

# --- One smoothed curve with CI band per sex, on a shared figure ---
plt.figure(figsize=(14, 5))

female_data = grouped[grouped["sex"] == "Female"]
male_data = grouped[grouped["sex"] == "Male"]

for sex_rows, sex_name, sex_color in (
    (female_data, "Female", "blue"),
    (male_data, "Male", "green"),
):
    plot_smooth_ci(sex_rows, label_text=sex_name, line_color=sex_color)

plt.title("Mean Absolute Alpha Power vs. Age by Sex with 95% Confidence Bands")
plt.xlabel("Age")
plt.ylabel("Mean Absolute Alpha Power (µV²/Hz)")
plt.xticks(np.arange(10, 100, 10))
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()


## Relative alpha power vs. age by sex with smoothed CI

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline

# --- Load and merge ---
metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")
# Build the relative-power frame from relative_alpha_by_subject in this cell,
# so the merge does not depend on a `relative_df` left over from earlier cells.
relative_df = pd.DataFrame({
    "subject_id": list(relative_alpha_by_subject.keys()),
    "Relative Alpha Power": list(relative_alpha_by_subject.values())
})
metadata["subject_id"] = metadata["subject_id"].astype(str)
relative_df["subject_id"] = relative_df["subject_id"].astype(str)
merged = pd.merge(relative_df, metadata, on="subject_id")

# --- Filter ---
# Relative power is a ratio bounded by 1; the comparison also drops NaN rows
merged = merged[merged["Relative Alpha Power"] <= 1]

# --- Group by age and sex ---
# The mean column is named "MeanAlpha" because plot_smooth_ci reads that name.
grouped = merged.groupby(["age", "sex"]).agg(
    MeanAlpha=("Relative Alpha Power", "mean"),
    StdAlpha=("Relative Alpha Power", "std"),
    N=("Relative Alpha Power", "count")
).reset_index()
# Normal-approximation 95% interval: mean ± 1.96 * std / sqrt(N)
grouped["SEM"] = grouped["StdAlpha"] / np.sqrt(grouped["N"])
grouped["CI95"] = 1.96 * grouped["SEM"]
grouped["Lower"] = grouped["MeanAlpha"] - grouped["CI95"]
grouped["Upper"] = grouped["MeanAlpha"] + grouped["CI95"]

# --- Smoothing function for one sex ---
def plot_smooth_ci(data, label_text, line_color):
    """Plot a spline-interpolated mean with a shaded band for one group.

    Reads the "age", "MeanAlpha", "Lower" and "Upper" columns of ``data``,
    skipping rows where any of them is NaN or infinite. Returns without
    drawing when fewer than 4 rows are left. The curve, band and points are
    added to the current matplotlib axes in ``line_color``; ``label_text``
    labels the curve.
    """
    needed = ["age", "MeanAlpha", "Lower", "Upper"]
    usable = data.replace([np.inf, -np.inf], np.nan).dropna(subset=needed)
    x, y, y_lower, y_upper = (usable[name].values for name in needed)

    has_enough_points = len(x) >= 4
    if not has_enough_points:
        return

    # Dense grid across the observed age range for the interpolated curves
    x_smooth = np.linspace(x.min(), x.max(), 500)
    mean_curve = make_interp_spline(x, y, k=3)(x_smooth)
    lower_curve = make_interp_spline(x, y_lower, k=3)(x_smooth)
    upper_curve = make_interp_spline(x, y_upper, k=3)(x_smooth)

    plt.plot(x_smooth, mean_curve, color=line_color, label=label_text)
    plt.fill_between(x_smooth, lower_curve, upper_curve, color=line_color, alpha=0.15)
    plt.scatter(x, y, color=line_color, s=20, marker='o')

# --- Plot both sexes ---
plt.figure(figsize=(14, 5))

female_data = grouped[grouped["sex"] == "Female"]
male_data = grouped[grouped["sex"] == "Male"]

# Each call adds a curve, a shaded band and the per-age points for one sex
plot_smooth_ci(female_data, label_text="Female", line_color="blue")
plot_smooth_ci(male_data, label_text="Male", line_color="green")

plt.xlabel("Age")
# A ratio of two powers has no unit
plt.ylabel("Mean Relative Alpha Power (unitless)")
plt.title("Mean Relative Alpha Power vs. Age by Sex with 95% Confidence Bands")
plt.xticks(np.arange(10, 100, 10))
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()


# Step 5: Calculating alpha power in just the occipital channels O1 and O2 

In [None]:
import mne
from mne.time_frequency import psd_array_welch
import numpy as np

# Mean 8-13 Hz Welch PSD per subject, restricted to the channels at
# positions 8 and 9 of the channel axis, scaled by 1e12.
alpha_by_subject_occipital = {}

for subject_id in top_epochs_per_subject:
    print(f"Processing alpha for {subject_id}...")

    set_path = f"G:/ChristianMusaeus/Preprocessed_setfiles/{subject_id}_epoched.set"

    try:
        epochs = mne.io.read_epochs_eeglab(set_path, verbose='ERROR')
        # Stored epoch numbers are used directly as indices into the epoch axis
        selected_data = epochs.get_data()[top_epochs_per_subject[subject_id]]

        # NOTE(review): n_fft is 128 here but 200 in the all-channel cells -
        # confirm the difference in frequency resolution is intended.
        psds, freqs = psd_array_welch(
            selected_data,
            sfreq=epochs.info["sfreq"],
            fmin=8,
            fmax=13,
            n_fft=128,
            verbose=False,
        )

        # Average over frequency (last axis), then over epochs (axis 0)
        per_channel = psds.mean(axis=-1).mean(axis=0)
        # NOTE(review): assumes channels 8 and 9 are O1 and O2 in every
        # file - verify against epochs.ch_names.
        occipital_mean = per_channel[8:10].mean()
        alpha_by_subject_occipital[subject_id] = occipital_mean * 1e12

    except Exception as e:
        print(f"❌ Could not process {subject_id}: {e}")


## Plotting alpha power for only O1 and O2 across age

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Load subject metadata (must contain 'subject_id' and 'age')
metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")

# Create DataFrame of occipital alpha powers
alpha_df = pd.DataFrame({
    "subject_id": list(alpha_by_subject_occipital.keys()),
    "Alpha Power": list(alpha_by_subject_occipital.values())
})

# Merge with metadata to get ages. This merge is required: without it the
# occipital values in alpha_df never reach the plot, and the filter below
# would act on whatever `merged` an earlier cell left behind.
metadata["subject_id"] = metadata["subject_id"].astype(str)
alpha_df["subject_id"] = alpha_df["subject_id"].astype(str)
merged = pd.merge(alpha_df, metadata, on="subject_id")

# --- Filter out extreme values ---
merged = merged[merged["Alpha Power"] <= 20]

# Group by age and average
grouped = merged.groupby("age").agg(
    MeanAlpha=("Alpha Power", "mean"),
    N=("Alpha Power", "count")
).reset_index()

# Plot (dots only)
plt.figure(figsize=(12, 5))
plt.scatter(grouped["age"], grouped["MeanAlpha"], s=40)
plt.xlabel("Age")
plt.ylabel("Mean Absolute Alpha Power (µV²/Hz)")
plt.title("Mean Absolute Alpha Power vs. Age (Only Occipital Channels)")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.tight_layout()
plt.show()


### Plotting alpha power for only O1 and O2 across age with CI

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# --- Load and merge ---
metadata = pd.read_csv("E:/ChristianMusaeus/metadata_time_filtered.csv")
alpha_df = pd.DataFrame({
    "subject_id": list(alpha_by_subject_occipital.keys()),
    "Alpha Power": list(alpha_by_subject_occipital.values())
})
# String IDs on both sides, matching the other merge cells in this notebook
metadata["subject_id"] = metadata["subject_id"].astype(str)
alpha_df["subject_id"] = alpha_df["subject_id"].astype(str)
merged = pd.merge(alpha_df, metadata, on="subject_id")
merged = merged[merged["Alpha Power"] <= 20]  # Filter out extreme values

# --- Group by age and compute stats ---
grouped = merged.groupby("age").agg(
    MeanAlpha=("Alpha Power", "mean"),
    StdAlpha=("Alpha Power", "std"),
    N=("Alpha Power", "count")
).reset_index()
# Normal-approximation 95% interval: mean ± 1.96 * std / sqrt(N).
# Ages with a single subject have an undefined std, so their bounds are NaN.
grouped["SEM"] = grouped["StdAlpha"] / np.sqrt(grouped["N"])
grouped["CI95"] = 1.96 * grouped["SEM"]
grouped["Lower"] = grouped["MeanAlpha"] - grouped["CI95"]
grouped["Upper"] = grouped["MeanAlpha"] + grouped["CI95"]

# --- Plot with shaded confidence bands ---
plt.figure(figsize=(16, 5))

# Scatter raw mean points
plt.scatter(grouped["age"], grouped["MeanAlpha"], color="blue", s=30, label="Mean Alpha Power")

# Line plot of the mean
plt.plot(grouped["age"], grouped["MeanAlpha"], color="red", label="Mean Trend")

# Shaded 95% CI area
plt.fill_between(grouped["age"], grouped["Lower"], grouped["Upper"],
                 color='skyblue', alpha=0.4, label="95% Confidence Interval")

plt.xlabel("Age")
plt.ylabel("Mean Alpha Power (µV²/Hz)")
plt.title("Mean Alpha Power vs. Age Only Occipital channels with 95% CI (Shaded)")
plt.grid(True)
plt.xticks(np.arange(10, 100, 10))
plt.legend()
plt.tight_layout()
plt.show()


### Plotting alpha power for only O1 and O2 with CI (Smoothed)

In [None]:
from scipy.interpolate import make_interp_spline
import numpy as np
import matplotlib.pyplot as plt

# Works on the global `grouped` table; rows with missing or infinite values
# in the plotted columns are removed first.
plotted_columns = ["age", "MeanAlpha", "Lower", "Upper"]
grouped_clean = grouped.replace([np.inf, -np.inf], np.nan).dropna(subset=plotted_columns)

# Per-age values
x, y, y_lower, y_upper = (grouped_clean[name].values for name in plotted_columns)

# Dense age grid for the interpolated curves
x_smooth = np.linspace(x.min(), x.max(), 1000)

# One degree-5 interpolating spline each for the mean and the two CI bounds
spline_mean, spline_lower, spline_upper = (
    make_interp_spline(x, values, k=5) for values in (y, y_lower, y_upper)
)

y_smooth = spline_mean(x_smooth)
y_lower_smooth = spline_lower(x_smooth)
y_upper_smooth = spline_upper(x_smooth)

# --- Plot ---
fig, ax = plt.subplots(figsize=(12, 5))
ax.scatter(x, y, color='blue', s=30, label="Mean Alpha Power")
ax.plot(x_smooth, y_smooth, color='blue', label="Smoothed Mean")
ax.fill_between(x_smooth, y_lower_smooth, y_upper_smooth,
                color='skyblue', alpha=0.4, label="95% CI")

ax.set_xlabel("Age")
ax.set_ylabel("Mean Absolute Alpha Power (µV²/Hz)")
ax.set_title("Mean Absolute Alpha Power vs. Age (Only Occipital channels) with 95% Confidence Interval")
ax.grid(True)
ax.set_xticks(np.arange(10, 100, 10))
ax.legend()
fig.tight_layout()
plt.show()
