# Training and evaluating multinomial logistic regression classifiers on EC, EO and random epochs

In [None]:
import os
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import mne
from mne.time_frequency import psd_array_welch

from scipy.interpolate import make_interp_spline

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay
)

### EC epochs 

In [None]:
# ------------ LOAD DATA ------------
# Selected eyes-closed (EC) epoch indices per subject. Keys are normalized to
# stripped strings so they compare equal to metadata["subject_id"] below.
with open("top_epochs_per_subject.pkl", "rb") as f:
    top_epochs_per_subject = {str(key).strip(): idx for key, idx in pickle.load(f).items()}

# Held-out test subjects. The file name suggests this is the split used for the
# random-forest EC model; reusing it keeps the test set identical across models.
with open("test_subjects_EC_rf.pkl", "rb") as f:
    test_subjects_EC_rf = pickle.load(f)
test_subjects_EC_mr = test_subjects_EC_rf

metadata = pd.read_csv("metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str).str.strip()

# Keep only subjects that have both selected epochs and a metadata row; every
# subject not in the test split is used for training.
known_ids = set(metadata["subject_id"].values)
all_subjects_grid = [s for s in top_epochs_per_subject if s in known_ids]
train_subjects_grid = [s for s in all_subjects_grid if s not in test_subjects_EC_mr]

def assign_age_group(age):
    """Map an age to one of three class labels.

    Returns 0 when age < 21, 1 when 21 <= age < 71, and 2 otherwise
    (this includes NaN, because every comparison against NaN is False).
    """
    for group, upper_bound in enumerate((21, 71)):
        if age < upper_bound:
            return group
    return 2

# Attach the class label to every row, then keep only subjects with EC epochs.
metadata["age_group"] = metadata["age"].map(assign_age_group)
metadata = metadata.loc[metadata["subject_id"].isin(all_subjects_grid)]

# ------------ FEATURE EXTRACTION ------------
def extract_psd_features(subject_id, epoch_indices, set_folder):
    """Return the 1–45 Hz Welch PSD averaged over selected epochs and all channels.

    Reads ``{set_folder}/{subject_id}_epoched.set`` and keeps only the epochs at
    ``epoch_indices``. The result is a 1-D array with one value per frequency bin.
    """
    epochs = mne.io.read_epochs_eeglab(f"{set_folder}/{subject_id}_epoched.set", verbose='ERROR')
    selected = epochs.get_data()[epoch_indices]  # (n_epochs, n_channels, n_times)
    spectra, _ = mne.time_frequency.psd_array_welch(
        selected,
        sfreq=epochs.info["sfreq"],
        fmin=1,
        fmax=45,
        n_fft=200,
        verbose=False,
    )
    # Collapse the epoch and channel axes, keeping the frequency axis.
    return spectra.mean(axis=(0, 1))

set_folder = "G:/ChristianMusaeus/Preprocessed_setfiles"
X_EC_mr, y_EC_mr = [], []

# Build one PSD feature vector and one age-group label per training subject.
# Subjects that fail (missing file, missing metadata row, ...) are reported and skipped.
for subj_id in train_subjects_grid:
    try:
        features = extract_psd_features(subj_id, top_epochs_per_subject[subj_id], set_folder)
        age_group = metadata.loc[metadata["subject_id"] == subj_id, "age_group"].values[0]
    except Exception as e:
        print(f" Error processing {subj_id}: {e}")
    else:
        X_EC_mr.append(features)
        y_EC_mr.append(age_group)

X_EC_mr, y_EC_mr = np.array(X_EC_mr), np.array(y_EC_mr)

# ------------ CV + HYPERPARAMETER TUNING ------------
# Nested CV: an outer stratified 5-fold split estimates accuracy, while an
# inner 3-fold grid search (scored by log loss) picks C inside each outer fold.
# NOTE(review): the scaler is fit on the whole outer-training fold before the
# inner search, so inner validation folds are standardized with statistics
# that include them; a Pipeline(StandardScaler, LogisticRegression) would avoid that.
param_grid = dict(
    C=[0.01, 0.1, 1, 10],
    penalty=['l2'],
    solver=['lbfgs'],
    max_iter=[500],
)

skf_grid = StratifiedKFold(n_splits=5, shuffle=True, random_state=13)
accuracies_grid = []

for fold, (train_idx, val_idx) in enumerate(skf_grid.split(X_EC_mr, y_EC_mr), start=1):
    print(f"\n Fold {fold} running...")
    X_train, y_train = X_EC_mr[train_idx], y_EC_mr[train_idx]
    X_val, y_val = X_EC_mr[val_idx], y_EC_mr[val_idx]

    # Standardize with statistics from the training fold only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    grid_search = GridSearchCV(
        LogisticRegression(multi_class='multinomial', random_state=13),
        param_grid,
        cv=3,
        n_jobs=-1,
        scoring='neg_log_loss',
        verbose=1,
    )
    grid_search.fit(X_train_scaled, y_train)

    acc = accuracy_score(y_val, grid_search.best_estimator_.predict(X_val_scaled))
    accuracies_grid.append(acc)

    print(f"Fold {fold} Accuracy: {acc:.3f}")
    print(f"Best params: {grid_search.best_params_}")

print(f"\n Mean CV Accuracy: {np.mean(accuracies_grid):.3f}")

# ------------ FINAL MODEL + TEST SET EVALUATION ------------
# The CV loop above tunes C separately in every outer fold, so
# `grid_search.best_params_` only reflects whichever fold happened to run last.
# Re-run the same inner grid search on the complete training set instead.
# With refit=True (the GridSearchCV default), best_estimator_ is already refit
# on all training subjects using the winning hyperparameters.
scaler_final = StandardScaler()
X_scaled_EC_mr = scaler_final.fit_transform(X_EC_mr)
grid_search_final = GridSearchCV(
    LogisticRegression(multi_class='multinomial', random_state=13),
    param_grid, cv=3, n_jobs=-1, scoring='neg_log_loss', verbose=1
)
grid_search_final.fit(X_scaled_EC_mr, y_EC_mr)
print(f"Final model params: {grid_search_final.best_params_}")
final_model = grid_search_final.best_estimator_

# Extract test features. The IDs of successfully processed subjects are tracked
# in lockstep with X/y. Otherwise a skipped subject would shift every later ID
# against its label/prediction in the subject-level export below.
X_test_EC_mr, y_test_EC_mr, evaluated_subjects_EC_mr = [], [], []
for subj_id in test_subjects_EC_mr:
    try:
        features = extract_psd_features(subj_id, top_epochs_per_subject[subj_id], set_folder)
        age_group = metadata.loc[metadata["subject_id"] == subj_id, "age_group"].values[0]
    except Exception as e:
        print(f" Error processing test subject {subj_id}: {e}")
        continue
    X_test_EC_mr.append(features)
    y_test_EC_mr.append(age_group)
    evaluated_subjects_EC_mr.append(subj_id)

n_skipped = len(test_subjects_EC_mr) - len(evaluated_subjects_EC_mr)
if n_skipped:
    print(f" Warning: {n_skipped} test subject(s) skipped; metrics use {len(evaluated_subjects_EC_mr)} subjects.")
if not evaluated_subjects_EC_mr:
    raise RuntimeError("No EC test subjects could be processed; cannot evaluate the model.")

X_test_EC_mr = np.array(X_test_EC_mr)
y_test_EC_mr = np.array(y_test_EC_mr)
X_test_scaled_EC_mr = scaler_final.transform(X_test_EC_mr)

test_preds_EC_mr = final_model.predict(X_test_scaled_EC_mr)
test_acc_EC_mr = accuracy_score(y_test_EC_mr, test_preds_EC_mr)
print(f"\n Final Test Accuracy: {test_acc_EC_mr:.3f}")
print(classification_report(y_test_EC_mr, test_preds_EC_mr))

# ------------ SUBJECT-LEVEL ACCURACY EXPORT FOR ANOVA ------------
# One row per evaluated test subject: score is 1 if the prediction was correct, else 0.
subject_level_scores_mlr_ec = [{
    "subject_id": subj_id,
    "model_type": "MLR",
    "data_type": "EC",
    "score": int(true == pred)
} for subj_id, true, pred in zip(evaluated_subjects_EC_mr, y_test_EC_mr, test_preds_EC_mr)]

df_subject_scores_mlr_ec = pd.DataFrame(subject_level_scores_mlr_ec)
df_subject_scores_mlr_ec.sort_values(by="subject_id").to_csv("subject_scores_mlr_ec.csv", index=False)
print(" Saved subject-level scores to 'subject_scores_mlr_ec.csv'")

# ------------ SAVE FOR LATER ------------
# The saved subject list is the one aligned element-wise with the saved labels/predictions.
with open('y_test_EC_mr.pkl', 'wb') as f:
    pickle.dump(y_test_EC_mr, f)
with open('test_preds_EC_mr.pkl', 'wb') as f:
    pickle.dump(test_preds_EC_mr, f)
with open('test_subjects_EC_mr.pkl', 'wb') as f:
    pickle.dump(evaluated_subjects_EC_mr, f)

### Confusion matrix for EC

In [None]:
# Compute the confusion matrix over all three age groups produced by
# assign_age_group (0, 1, 2). Passing labels explicitly keeps the matrix 3x3
# even if a group is absent from both the true labels and the predictions.
# Without it the matrix shrinks, and the default 0..n-1 tick labels would
# name the wrong groups.
age_group_labels = [0, 1, 2]
age_group_names = ["age < 21", "21 ≤ age < 71", "age ≥ 71"]
cm_mlr_ec = confusion_matrix(y_test_EC_mr, test_preds_EC_mr, labels=age_group_labels)

# Plot
fig, ax = plt.subplots(figsize=(8, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm_mlr_ec, display_labels=age_group_names)
disp.plot(ax=ax, cmap="Blues", colorbar=True)

plt.title("Confusion Matrix for EC MLR Model")
plt.xlabel("Predicted Age Group")
plt.ylabel("True Age Group")
plt.tight_layout()
plt.show()

### EO epochs 

In [None]:
# ------------ LOAD DATA ------------
# Selected eyes-open (EO) epoch indices per subject, keys normalized to stripped strings.
with open("top_60_EO_epochs_per_subject.pkl", "rb") as f:
    top_60_EO_epochs_per_subject = {str(key).strip(): idx for key, idx in pickle.load(f).items()}

# The EC test-split file is deliberately reused, so EO is evaluated on the same held-out subjects.
with open("test_subjects_EC_rf.pkl", "rb") as f:
    test_subjects_EO_mr = pickle.load(f)

metadata = pd.read_csv("metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str).str.strip()

known_ids = set(metadata["subject_id"].values)
all_subjects_EO_mr = [s for s in top_60_EO_epochs_per_subject if s in known_ids]
train_subjects_EO_mr = [s for s in all_subjects_EO_mr if s not in test_subjects_EO_mr]

def assign_age_group(age):
    """Map an age to one of three class labels.

    Returns 0 when age < 21, 1 when 21 <= age < 71, and 2 otherwise
    (this includes NaN, because every comparison against NaN is False).
    """
    for group, upper_bound in enumerate((21, 71)):
        if age < upper_bound:
            return group
    return 2

# Attach the class label to every row, then keep only subjects with EO epochs.
metadata["age_group"] = metadata["age"].map(assign_age_group)
metadata = metadata.loc[metadata["subject_id"].isin(all_subjects_EO_mr)]

# ------------ FEATURE EXTRACTION ------------
def extract_psd_features(subject_id, epoch_indices, set_folder):
    """Return the 1–45 Hz Welch PSD averaged over selected epochs and all channels.

    Reads ``{set_folder}/{subject_id}_epoched.set`` and keeps only the epochs at
    ``epoch_indices``. The result is a 1-D array with one value per frequency bin.
    """
    epochs = mne.io.read_epochs_eeglab(f"{set_folder}/{subject_id}_epoched.set", verbose='ERROR')
    selected = epochs.get_data()[epoch_indices]  # (n_epochs, n_channels, n_times)
    spectra, _ = mne.time_frequency.psd_array_welch(
        selected,
        sfreq=epochs.info["sfreq"],
        fmin=1,
        fmax=45,
        n_fft=200,
        verbose=False,
    )
    # Collapse the epoch and channel axes, keeping the frequency axis.
    return spectra.mean(axis=(0, 1))

set_folder = "G:/ChristianMusaeus/Preprocessed_setfiles"
X_EO_mr, y_EO_mr = [], []

# One PSD feature vector and one age-group label per training subject; failures are reported and skipped.
for subj_id in train_subjects_EO_mr:
    try:
        features = extract_psd_features(subj_id, top_60_EO_epochs_per_subject[subj_id], set_folder)
        age_group = metadata.loc[metadata["subject_id"] == subj_id, "age_group"].values[0]
    except Exception as e:
        print(f"Error processing train subject {subj_id}: {e}")
    else:
        X_EO_mr.append(features)
        y_EO_mr.append(age_group)

X_EO_mr, y_EO_mr = np.array(X_EO_mr), np.array(y_EO_mr)

# ------------ MODEL TRAINING ------------
# Nested CV: an outer stratified 5-fold split estimates accuracy, while an
# inner 3-fold grid search (scored by log loss) picks C inside each outer fold.
# NOTE(review): the scaler sees the whole outer-training fold before the inner
# search; a Pipeline(StandardScaler, LogisticRegression) would avoid that leak.
param_grid = dict(
    C=[0.01, 0.1, 1, 10],
    penalty=['l2'],
    solver=['lbfgs'],
    max_iter=[500],
    multi_class=['multinomial'],
)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=13)
accuracies = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X_EO_mr, y_EO_mr), start=1):
    print(f"\n Fold {fold} running...")
    X_train, y_train = X_EO_mr[train_idx], y_EO_mr[train_idx]
    X_val, y_val = X_EO_mr[val_idx], y_EO_mr[val_idx]

    # Standardize with statistics from the training fold only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    grid = GridSearchCV(
        LogisticRegression(random_state=13),
        param_grid,
        cv=3,
        scoring='neg_log_loss',
        verbose=1,
        n_jobs=-1,
    )
    grid.fit(X_train_scaled, y_train)

    acc = accuracy_score(y_val, grid.best_estimator_.predict(X_val_scaled))
    accuracies.append(acc)

    print(f"Fold {fold} Accuracy: {acc:.3f}")
    print(f"Best params: {grid.best_params_}")

print(f"\n Mean CV Accuracy: {np.mean(accuracies):.3f}")

# ------------ FINAL TRAINING + EVAL ------------
# `grid.best_params_` from the CV loop only reflects the last outer fold.
# Re-run the same grid search on the full training set to choose the final
# hyperparameters. best_estimator_ is refit on all training subjects
# (GridSearchCV default refit=True).
scaler_final = StandardScaler()
X_scaled = scaler_final.fit_transform(X_EO_mr)
grid_final = GridSearchCV(
    LogisticRegression(random_state=13), param_grid,
    cv=3, scoring='neg_log_loss', verbose=1, n_jobs=-1
)
grid_final.fit(X_scaled, y_EO_mr)
print(f"Final model params: {grid_final.best_params_}")
final_model = grid_final.best_estimator_

# Extract test features, tracking which subjects succeeded. The test split comes
# from the EC file, so some subjects may have no EO epochs or metadata row here.
# Skipping them without tracking would misalign IDs with labels in the export.
X_test, y_test_eo_mlr, evaluated_subjects_EO_mr = [], [], []
for subj_id in test_subjects_EO_mr:
    try:
        features = extract_psd_features(subj_id, top_60_EO_epochs_per_subject[subj_id], set_folder)
        age_group = metadata.loc[metadata["subject_id"] == subj_id, "age_group"].values[0]
    except Exception as e:
        print(f"Error processing test subject {subj_id}: {e}")
        continue
    X_test.append(features)
    y_test_eo_mlr.append(age_group)
    evaluated_subjects_EO_mr.append(subj_id)

n_skipped = len(test_subjects_EO_mr) - len(evaluated_subjects_EO_mr)
if n_skipped:
    print(f"Warning: {n_skipped} test subject(s) skipped; metrics use {len(evaluated_subjects_EO_mr)} subjects.")
if not evaluated_subjects_EO_mr:
    raise RuntimeError("No EO test subjects could be processed; cannot evaluate the model.")

X_test = np.array(X_test)
y_test_eo_mlr = np.array(y_test_eo_mlr)
X_test_scaled = scaler_final.transform(X_test)

test_preds_EO_mlr = final_model.predict(X_test_scaled)
acc = accuracy_score(y_test_eo_mlr, test_preds_EO_mlr)
print(f"\n Final Test Accuracy: {acc:.3f}")
print(classification_report(y_test_eo_mlr, test_preds_EO_mlr))


# ------------ SAVE SUBJECT-LEVEL SCORES ------------
# One row per evaluated test subject: score is 1 if the prediction was correct, else 0.
subject_scores_mlr_eo = [{
    "subject_id": subj_id,
    "model_type": "MLR",
    "data_type": "EO",
    "score": int(true == pred)
} for subj_id, true, pred in zip(evaluated_subjects_EO_mr, y_test_eo_mlr, test_preds_EO_mlr)]


df_subject_scores_mlr_eo = pd.DataFrame(subject_scores_mlr_eo)
df_subject_scores_mlr_eo.sort_values(by="subject_id").to_csv("subject_scores_mlr_eo.csv", index=False)
print(" Saved subject-level scores to 'subject_scores_mlr_eo.csv'")


# ------------ SAVE OUTPUTS ------------
# The saved subject list is aligned element-wise with the saved labels/predictions.
with open("y_test_EO_mr.pkl", "wb") as f:
    pickle.dump(y_test_eo_mlr, f)
with open("test_preds_EO_mr.pkl", "wb") as f:
    pickle.dump(test_preds_EO_mlr, f)
with open("test_subjects_EO_mr.pkl", "wb") as f:
    pickle.dump(evaluated_subjects_EO_mr, f)


### Confusion matrix EO

In [None]:
# Compute the confusion matrix over all three age groups (0, 1, 2) from
# assign_age_group. Explicit labels keep it 3x3 even if a group never appears
# in y or the predictions. Without them, the default 0..n-1 tick labels would
# name the wrong groups.
age_group_labels = [0, 1, 2]
age_group_names = ["age < 21", "21 ≤ age < 71", "age ≥ 71"]
cm_mlr_eo = confusion_matrix(y_test_eo_mlr, test_preds_EO_mlr, labels=age_group_labels)

# Plot
fig, ax = plt.subplots(figsize=(8, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm_mlr_eo, display_labels=age_group_names)
disp.plot(ax=ax, cmap="Blues", colorbar=True)

plt.title("Confusion Matrix for EO MLR Model")
plt.xlabel("Predicted Age Group")
plt.ylabel("True Age Group")
plt.tight_layout()
plt.show()

### Random epochs 

In [None]:
# ------------ LOAD DATA ------------
# Randomly selected epoch indices per subject, keys normalized to stripped strings.
with open("random_epochs_per_subject.pkl", "rb") as f:
    random_epochs_per_subject = {str(key).strip(): idx for key, idx in pickle.load(f).items()}

# The EC test-split file is deliberately reused so every condition shares one test set.
with open("test_subjects_EC_rf.pkl", "rb") as f:
    test_subjects_random_mr = pickle.load(f)

metadata = pd.read_csv("metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str).str.strip()

known_ids = set(metadata["subject_id"].values)
all_subjects_random = [s for s in random_epochs_per_subject if s in known_ids]
train_subjects_random_mr = [s for s in all_subjects_random if s not in test_subjects_random_mr]

def assign_age_group(age):
    """Map an age to one of three class labels.

    Returns 0 when age < 21, 1 when 21 <= age < 71, and 2 otherwise
    (this includes NaN, because every comparison against NaN is False).
    """
    for group, upper_bound in enumerate((21, 71)):
        if age < upper_bound:
            return group
    return 2

# Attach the class label to every row, then keep only subjects with random epochs.
metadata["age_group"] = metadata["age"].map(assign_age_group)
metadata = metadata.loc[metadata["subject_id"].isin(all_subjects_random)]

# ------------ FEATURE EXTRACTION ------------
def extract_psd_features(subject_id, epoch_indices, set_folder):
    """Return the 1–45 Hz Welch PSD averaged over selected epochs and all channels.

    Reads ``{set_folder}/{subject_id}_epoched.set`` and keeps only the epochs at
    ``epoch_indices``. The result is a 1-D array with one value per frequency bin.
    """
    epochs = mne.io.read_epochs_eeglab(f"{set_folder}/{subject_id}_epoched.set", verbose='ERROR')
    selected = epochs.get_data()[epoch_indices]  # (n_epochs, n_channels, n_times)
    spectra, _ = mne.time_frequency.psd_array_welch(
        selected,
        sfreq=epochs.info["sfreq"],
        fmin=1,
        fmax=45,
        n_fft=200,
        verbose=False,
    )
    # Collapse the epoch and channel axes, keeping the frequency axis.
    return spectra.mean(axis=(0, 1))

set_folder = "G:/ChristianMusaeus/Preprocessed_setfiles"

X_random_mr, y_random_mr = [], []
# One PSD feature vector and one age-group label per training subject; failures are reported and skipped.
for subj_id in train_subjects_random_mr:
    try:
        features = extract_psd_features(subj_id, random_epochs_per_subject[subj_id], set_folder)
        age_group = metadata.loc[metadata["subject_id"] == subj_id, "age_group"].values[0]
    except Exception as e:
        print(f"Error processing train subject {subj_id}: {e}")
    else:
        X_random_mr.append(features)
        y_random_mr.append(age_group)

X_random_mr, y_random_mr = np.array(X_random_mr), np.array(y_random_mr)

# ------------ MODEL TRAINING ------------
# Nested CV: an outer stratified 5-fold split estimates accuracy, while an
# inner 3-fold grid search (scored by log loss) picks C inside each outer fold.
# NOTE(review): the scaler sees the whole outer-training fold before the inner
# search; a Pipeline(StandardScaler, LogisticRegression) would avoid that leak.
param_grid = dict(
    C=[0.01, 0.1, 1, 10],
    penalty=['l2'],
    solver=['lbfgs'],
    multi_class=['multinomial'],
    max_iter=[500],
)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=13)
accuracies = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X_random_mr, y_random_mr), start=1):
    print(f"\n Fold {fold} running...")
    X_train, y_train = X_random_mr[train_idx], y_random_mr[train_idx]
    X_val, y_val = X_random_mr[val_idx], y_random_mr[val_idx]

    # Standardize with statistics from the training fold only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    grid = GridSearchCV(
        LogisticRegression(random_state=13),
        param_grid,
        cv=3,
        scoring='neg_log_loss',
        verbose=1,
        n_jobs=-1,
    )
    grid.fit(X_train_scaled, y_train)

    acc = accuracy_score(y_val, grid.best_estimator_.predict(X_val_scaled))
    accuracies.append(acc)

    print(f"Fold {fold} Accuracy: {acc:.3f}")
    print(f"Best params: {grid.best_params_}")

print(f"\n Mean CV Accuracy (Random MLR): {np.mean(accuracies):.3f}")

# ------------ FINAL TRAINING + TESTING ------------
# `grid.best_params_` from the CV loop only reflects the last outer fold.
# Re-run the same grid search on the full training set to choose the final
# hyperparameters. best_estimator_ is refit on all training subjects
# (GridSearchCV default refit=True).
scaler_final = StandardScaler()
X_scaled = scaler_final.fit_transform(X_random_mr)
grid_final = GridSearchCV(
    LogisticRegression(random_state=13), param_grid,
    cv=3, scoring='neg_log_loss', verbose=1, n_jobs=-1
)
grid_final.fit(X_scaled, y_random_mr)
print(f"Final model params: {grid_final.best_params_}")
final_model = grid_final.best_estimator_

# Extract test features, tracking which subjects succeeded. Skipped subjects
# would otherwise misalign subject IDs with labels/predictions in the export.
X_test, y_test_random_mlr, evaluated_subjects_random_mr = [], [], []
for subj_id in test_subjects_random_mr:
    try:
        features = extract_psd_features(subj_id, random_epochs_per_subject[subj_id], set_folder)
        age_group = metadata.loc[metadata["subject_id"] == subj_id, "age_group"].values[0]
    except Exception as e:
        print(f"Error processing test subject {subj_id}: {e}")
        continue
    X_test.append(features)
    y_test_random_mlr.append(age_group)
    evaluated_subjects_random_mr.append(subj_id)

n_skipped = len(test_subjects_random_mr) - len(evaluated_subjects_random_mr)
if n_skipped:
    print(f"Warning: {n_skipped} test subject(s) skipped; metrics use {len(evaluated_subjects_random_mr)} subjects.")
if not evaluated_subjects_random_mr:
    raise RuntimeError("No random-epoch test subjects could be processed; cannot evaluate the model.")

X_test = np.array(X_test)
y_test_random_mlr = np.array(y_test_random_mlr)
X_test_scaled = scaler_final.transform(X_test)

preds_random_mlr = final_model.predict(X_test_scaled)
acc = accuracy_score(y_test_random_mlr, preds_random_mlr)
print(f"\n Final Test Accuracy: {acc:.3f}")
print(classification_report(y_test_random_mlr, preds_random_mlr))

# ------------ SAVE SUBJECT-LEVEL SCORES ------------
# One row per evaluated test subject: score is 1 if the prediction was correct, else 0.
subject_scores_mlr_random = [{
    "subject_id": subj_id,
    "model_type": "MLR",
    "data_type": "Random",
    "score": int(true == pred)
} for subj_id, true, pred in zip(evaluated_subjects_random_mr, y_test_random_mlr, preds_random_mlr)]

df_subject_scores_mlr_random = pd.DataFrame(subject_scores_mlr_random)
df_subject_scores_mlr_random.sort_values(by="subject_id").to_csv("subject_scores_mlr_random.csv", index=False)
print(" Saved subject-level scores to 'subject_scores_mlr_random.csv'")

# ------------ SAVE OUTPUTS ------------
# The saved subject list is aligned element-wise with the saved labels/predictions.
with open("y_test_random_mr.pkl", "wb") as f:
    pickle.dump(y_test_random_mlr, f)
with open("test_preds_random_mr.pkl", "wb") as f:
    pickle.dump(preds_random_mlr, f)
with open("test_subjects_random_mr.pkl", "wb") as f:
    pickle.dump(evaluated_subjects_random_mr, f)


### Confusion matrix random epochs 

In [None]:
# Compute the confusion matrix over all three age groups (0, 1, 2) from
# assign_age_group. Explicit labels keep it 3x3 even if a group never appears
# in y or the predictions. Without them, the default 0..n-1 tick labels would
# name the wrong groups.
age_group_labels = [0, 1, 2]
age_group_names = ["age < 21", "21 ≤ age < 71", "age ≥ 71"]
cm_mlr_random = confusion_matrix(y_test_random_mlr, preds_random_mlr, labels=age_group_labels)

# Plot
fig, ax = plt.subplots(figsize=(8, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm_mlr_random, display_labels=age_group_names)
disp.plot(ax=ax, cmap="Blues", colorbar=True)

plt.title("Confusion Matrix for Random MLR Model")
plt.xlabel("Predicted Age Group")
plt.ylabel("True Age Group")
plt.tight_layout()
plt.show()

# Plots 

## Eyes open: mean absolute alpha power across age

In [None]:
# ------------ LOAD DATA ------------
# Eyes-open epoch selections, keys normalized to stripped strings.
with open("top_60_EO_epochs_per_subject.pkl", "rb") as f:
    top_epochs_EO = {str(key).strip(): idx for key, idx in pickle.load(f).items()}

metadata = pd.read_csv("metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str).str.strip()

# ------------ COMPUTE ALPHA POWER PER SUBJECT ------------
# Absolute alpha = mean Welch PSD over the 8–13 Hz bins, all selected epochs and
# all channels, scaled by 1e12 to express it in μV²/Hz.
abs_alpha_power_by_subject_EO = {}
set_folder = "G:/ChristianMusaeus/Preprocessed_setfiles"  # Update path

for subject_id, epoch_indices in top_epochs_EO.items():
    try:
        epochs = mne.io.read_epochs_eeglab(f"{set_folder}/{subject_id}_epoched.set", verbose='ERROR')
        psds, freqs = psd_array_welch(
            epochs.get_data()[epoch_indices],  # (n_epochs, n_channels, n_times)
            sfreq=epochs.info["sfreq"],
            fmin=8, fmax=13, n_fft=200, verbose=False,
        )
        abs_alpha_power_by_subject_EO[subject_id] = psds.mean() * 1e12
    except Exception as e:
        print(f" Could not process {subject_id}: {e}")

# ------------ MERGE ALPHA POWER WITH METADATA ------------
# Both ID columns are already strings (keys were str()-ed above; metadata via astype(str)).
df_alpha = pd.DataFrame({
    "subject_id": list(abs_alpha_power_by_subject_EO),
    "alpha_power": list(abs_alpha_power_by_subject_EO.values()),
})
merged = df_alpha.merge(metadata, on="subject_id")

# ------------ FILTER BAD/EXTREME VALUES ------------
# Keep numeric values only, then drop implausibly large power (> 20 μV²/Hz; NaN is dropped too).
is_number = merged["alpha_power"].apply(lambda x: isinstance(x, (int, float)))
merged = merged[is_number]
merged = merged[merged["alpha_power"] <= 20]

# ------------ GROUP BY AGE AND COMPUTE STATS ------------
grouped = merged.groupby("age").agg(
    mean_alpha=("alpha_power", "mean"),
    std=("alpha_power", "std"),
    N=("alpha_power", "count"),
).reset_index()

# Normal-approximation 95% CI of the per-age mean.
grouped["sem"] = grouped["std"] / np.sqrt(grouped["N"])
grouped["ci_upper"] = grouped["mean_alpha"] + 1.96 * grouped["sem"]
grouped["ci_lower"] = grouped["mean_alpha"] - 1.96 * grouped["sem"]

# ------------ SMOOTHING FOR PLOT ------------
# Ages with a single subject have NaN std, hence NaN CI, and are left out.
# NOTE: make_interp_spline builds an *interpolating* cubic spline that passes
# through every point, so the curve is a dense resampling rather than a smoother.
valid = grouped["ci_upper"].notna() & grouped["ci_lower"].notna()
plot_df = grouped[valid]
ages = plot_df["age"].values
mean_alpha = plot_df["mean_alpha"].values
ci_upper = plot_df["ci_upper"].values
ci_lower = plot_df["ci_lower"].values

ages_smooth = np.linspace(ages.min(), ages.max(), 500)
mean_spline, upper_spline, lower_spline = (
    make_interp_spline(ages, series, k=3) for series in (mean_alpha, ci_upper, ci_lower)
)
mean_smooth, upper_smooth, lower_smooth = (
    spline(ages_smooth) for spline in (mean_spline, upper_spline, lower_spline)
)

# ------------ PLOT ------------
plt.figure(figsize=(14, 8))
plt.plot(ages_smooth, mean_smooth, label="Smoothed Mean Alpha", color="blue")
plt.fill_between(ages_smooth, lower_smooth, upper_smooth, color="skyblue", alpha=0.4, label="95% CI")
plt.scatter(ages, mean_alpha, s=10, color="blue", label="Mean per Age")

plt.xlabel("Age")
plt.ylabel("Absolute Alpha Power (μV²/Hz)")
plt.title("Absolute Alpha Power across Age for EO epochs")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.yticks(np.arange(0, 8.5, 1))
plt.show()

## Random epochs: mean absolute alpha power across age

In [None]:
# ------------ LOAD DATA ------------
# Random epoch selections, keys normalized to stripped strings.
with open("random_epochs_per_subject.pkl", "rb") as f:
    top_epochs_random = {str(key).strip(): idx for key, idx in pickle.load(f).items()}

metadata = pd.read_csv("metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str).str.strip()

# ------------ COMPUTE ALPHA POWER PER SUBJECT ------------
# Absolute alpha = mean Welch PSD over the 8–13 Hz bins, all selected epochs and
# all channels, scaled by 1e12 to express it in μV²/Hz.
abs_alpha_power_by_subject_random = {}
set_folder = "G:/ChristianMusaeus/Preprocessed_setfiles"  # Update path

for subject_id, epoch_indices in top_epochs_random.items():
    try:
        epochs = mne.io.read_epochs_eeglab(f"{set_folder}/{subject_id}_epoched.set", verbose='ERROR')
        psds, freqs = psd_array_welch(
            epochs.get_data()[epoch_indices],  # (n_epochs, n_channels, n_times)
            sfreq=epochs.info["sfreq"],
            fmin=8, fmax=13, n_fft=200, verbose=False,
        )
        abs_alpha_power_by_subject_random[subject_id] = psds.mean() * 1e12
    except Exception as e:
        print(f" Could not process {subject_id}: {e}")

# ------------ MERGE ALPHA POWER WITH METADATA ------------
# Both ID columns are already strings (keys were str()-ed above; metadata via astype(str)).
df_alpha = pd.DataFrame({
    "subject_id": list(abs_alpha_power_by_subject_random),
    "alpha_power": list(abs_alpha_power_by_subject_random.values()),
})
merged = df_alpha.merge(metadata, on="subject_id")

# ------------ FILTER BAD/EXTREME VALUES ------------
# Keep numeric values only, then drop implausibly large power (> 20 μV²/Hz; NaN is dropped too).
is_number = merged["alpha_power"].apply(lambda x: isinstance(x, (int, float)))
merged = merged[is_number]
merged = merged[merged["alpha_power"] <= 20]

# ------------ GROUP BY AGE AND COMPUTE STATS ------------
grouped = merged.groupby("age").agg(
    mean_alpha=("alpha_power", "mean"),
    std=("alpha_power", "std"),
    N=("alpha_power", "count"),
).reset_index()

# Normal-approximation 95% CI of the per-age mean.
grouped["sem"] = grouped["std"] / np.sqrt(grouped["N"])
grouped["ci_upper"] = grouped["mean_alpha"] + 1.96 * grouped["sem"]
grouped["ci_lower"] = grouped["mean_alpha"] - 1.96 * grouped["sem"]

# ------------ SMOOTHING FOR PLOT ------------
# Ages with a single subject have NaN std, hence NaN CI, and are left out.
# NOTE: make_interp_spline builds an *interpolating* cubic spline that passes
# through every point, so the curve is a dense resampling rather than a smoother.
valid = grouped["ci_upper"].notna() & grouped["ci_lower"].notna()
plot_df = grouped[valid]
ages = plot_df["age"].values
mean_alpha = plot_df["mean_alpha"].values
ci_upper = plot_df["ci_upper"].values
ci_lower = plot_df["ci_lower"].values

ages_smooth = np.linspace(ages.min(), ages.max(), 500)
mean_spline, upper_spline, lower_spline = (
    make_interp_spline(ages, series, k=3) for series in (mean_alpha, ci_upper, ci_lower)
)
mean_smooth, upper_smooth, lower_smooth = (
    spline(ages_smooth) for spline in (mean_spline, upper_spline, lower_spline)
)

# ------------ PLOT ------------
plt.figure(figsize=(14, 8))
plt.plot(ages_smooth, mean_smooth, label="Smoothed Mean Alpha", color="blue")
plt.fill_between(ages_smooth, lower_smooth, upper_smooth, color="skyblue", alpha=0.4, label="95% CI")
plt.scatter(ages, mean_alpha, s=10, color="blue", label="Mean per Age")

plt.xlabel("Age")
plt.ylabel("Absolute Alpha Power (μV²/Hz)")
plt.title(" Absolute Alpha Power across Age for Randomly Chosen Epochs")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.yticks(np.arange(0, 8.5, 1))
plt.show()


## Eyes open: mean relative alpha power across age

In [None]:
import mne
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
from mne.time_frequency import psd_array_welch
from scipy.interpolate import make_interp_spline

# ------------ LOAD DATA ------------
with open("top_60_EO_epochs_per_subject.pkl", "rb") as f:
    top_epochs = pickle.load(f)

top_epochs = {str(k).strip(): v for k, v in top_epochs.items()}
metadata = pd.read_csv("metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str).str.strip()

# ------------ COMPUTE RELATIVE ALPHA POWER PER SUBJECT ------------
# Relative alpha = share of the 1–45 Hz power that falls in 8–13 Hz, pooled over
# all selected epochs and channels. It must be a ratio of band *sums*.
# Dividing the mean PSD of the few alpha bins by the mean PSD of all 1–45 Hz
# bins (as before) inflates the ratio by n_total_bins / n_alpha_bins (e.g.
# ~7.5x if bins are 1 Hz apart). Many subjects then landed above 1 and were
# silently discarded by the [0, 1] filter below.
alpha_power_by_subject = {}
set_folder = "G:/ChristianMusaeus/Preprocessed_setfiles"  # <-- Update to your directory!

for subject_id, epoch_indices in top_epochs.items():
    try:
        path = f"{set_folder}/{subject_id}_epoched.set"
        epochs = mne.io.read_epochs_eeglab(path, verbose='ERROR')
        data = epochs.get_data()[epoch_indices]
        sfreq = epochs.info["sfreq"]

        # One Welch call over 1–45 Hz; the 8–13 Hz bins are a subset of its
        # frequency grid (psd_array_welch keeps fmin <= f <= fmax).
        psds_full, freqs = psd_array_welch(data, sfreq=sfreq, fmin=1, fmax=45, n_fft=200, verbose=False)
        alpha_mask = (freqs >= 8) & (freqs <= 13)

        total_power = psds_full.sum()
        alpha_power = psds_full[..., alpha_mask].sum()

        rel_alpha = alpha_power / total_power if total_power > 0 else np.nan
        alpha_power_by_subject[subject_id] = rel_alpha

    except Exception as e:
        print(f" Could not process {subject_id}: {e}")

# ------------ MERGE WITH METADATA ------------
df_alpha = pd.DataFrame({
    "subject_id": list(alpha_power_by_subject.keys()),
    "alpha_power": list(alpha_power_by_subject.values())
})
df_alpha["subject_id"] = df_alpha["subject_id"].astype(str)
metadata["subject_id"] = metadata["subject_id"].astype(str)
merged = pd.merge(df_alpha, metadata, on="subject_id")

# ------------ FILTER & GROUP ------------
# Relative power is a fraction, so anything outside [0, 1] (or NaN) is invalid.
merged = merged[merged["alpha_power"].apply(lambda x: isinstance(x, (int, float, np.floating)))]
merged = merged[(merged["alpha_power"] >= 0) & (merged["alpha_power"] <= 1)]

grouped = merged.groupby("age").agg(
    mean_alpha=("alpha_power", "mean"),
    std=("alpha_power", "std"),
    N=("alpha_power", "count")
).reset_index()

# Normal-approximation 95% CI of the per-age mean.
grouped["sem"] = grouped["std"] / np.sqrt(grouped["N"])
grouped["ci_upper"] = grouped["mean_alpha"] + 1.96 * grouped["sem"]
grouped["ci_lower"] = grouped["mean_alpha"] - 1.96 * grouped["sem"]

# ------------ SMOOTHING FOR PLOT ------------
# Ages with a single subject have NaN std (hence NaN CI) and are excluded.
# make_interp_spline interpolates through every point (it does not smooth).
valid = (~grouped["ci_upper"].isna()) & (~grouped["ci_lower"].isna())
ages = grouped.loc[valid, "age"].values
mean_alpha = grouped.loc[valid, "mean_alpha"].values
ci_upper = grouped.loc[valid, "ci_upper"].values
ci_lower = grouped.loc[valid, "ci_lower"].values

ages_smooth = np.linspace(ages.min(), ages.max(), 500)
mean_spline = make_interp_spline(ages, mean_alpha, k=3)
upper_spline = make_interp_spline(ages, ci_upper, k=3)
lower_spline = make_interp_spline(ages, ci_lower, k=3)

mean_smooth = mean_spline(ages_smooth)
upper_smooth = upper_spline(ages_smooth)
lower_smooth = lower_spline(ages_smooth)

# ------------ PLOT ------------
plt.figure(figsize=(14, 6))
plt.plot(ages_smooth, mean_smooth, label="Smoothed Mean Relative Alpha", color="blue")
plt.fill_between(ages_smooth, lower_smooth, upper_smooth, color="lightblue", alpha=0.4, label="95% CI")
plt.scatter(ages, mean_alpha, s=10, color="blue", label="Mean per Age")

plt.xlabel("Age")
plt.ylabel("Relative Alpha Power")
plt.title("Relative Alpha Power across Age (Eyes Open Epochs)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.yticks(np.arange(0, 1.05, 0.1))
plt.show()


## Random epochs: mean relative alpha power across age

In [None]:
import mne
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
from mne.time_frequency import psd_array_welch
from scipy.interpolate import make_interp_spline

# ------------ LOAD DATA ------------
with open("random_epochs_per_subject.pkl", "rb") as f:
    random_epochs = pickle.load(f)

random_epochs = {str(k).strip(): v for k, v in random_epochs.items()}
metadata = pd.read_csv("metadata_time_filtered.csv")
metadata["subject_id"] = metadata["subject_id"].astype(str).str.strip()

# ------------ COMPUTE RELATIVE ALPHA POWER PER SUBJECT ------------
# Relative alpha = share of the 1–45 Hz power that falls in 8–13 Hz, pooled over
# all selected epochs and channels. It must be a ratio of band *sums*.
# Dividing the mean PSD of the few alpha bins by the mean PSD of all 1–45 Hz
# bins (as before) inflates the ratio by n_total_bins / n_alpha_bins (e.g.
# ~7.5x if bins are 1 Hz apart). Many subjects then landed above 1 and were
# silently discarded by the [0, 1] filter below.
alpha_power_by_subject = {}
set_folder = "G:/ChristianMusaeus/Preprocessed_setfiles"  # <-- Update to your directory!

for subject_id, epoch_indices in random_epochs.items():
    try:
        path = f"{set_folder}/{subject_id}_epoched.set"
        epochs = mne.io.read_epochs_eeglab(path, verbose='ERROR')
        data = epochs.get_data()[epoch_indices]
        sfreq = epochs.info["sfreq"]

        # One Welch call over 1–45 Hz; the 8–13 Hz bins are a subset of its
        # frequency grid (psd_array_welch keeps fmin <= f <= fmax).
        psds_full, freqs = psd_array_welch(data, sfreq=sfreq, fmin=1, fmax=45, n_fft=200, verbose=False)
        alpha_mask = (freqs >= 8) & (freqs <= 13)

        total_power = psds_full.sum()
        alpha_power = psds_full[..., alpha_mask].sum()

        rel_alpha = alpha_power / total_power if total_power > 0 else np.nan
        alpha_power_by_subject[subject_id] = rel_alpha

    except Exception as e:
        print(f" Could not process {subject_id}: {e}")


# ------------ MERGE WITH METADATA ------------
df_alpha = pd.DataFrame({
    "subject_id": list(alpha_power_by_subject.keys()),
    "alpha_power": list(alpha_power_by_subject.values())
})
df_alpha["subject_id"] = df_alpha["subject_id"].astype(str)
metadata["subject_id"] = metadata["subject_id"].astype(str)
merged = pd.merge(df_alpha, metadata, on="subject_id")

# ------------ FILTER & GROUP ------------
# Relative power is a fraction, so anything outside [0, 1] (or NaN) is invalid.
merged = merged[merged["alpha_power"].apply(lambda x: isinstance(x, (int, float, np.floating)))]
merged = merged[(merged["alpha_power"] >= 0) & (merged["alpha_power"] <= 1)]

grouped = merged.groupby("age").agg(
    mean_alpha=("alpha_power", "mean"),
    std=("alpha_power", "std"),
    N=("alpha_power", "count")
).reset_index()

# Normal-approximation 95% CI of the per-age mean.
grouped["sem"] = grouped["std"] / np.sqrt(grouped["N"])
grouped["ci_upper"] = grouped["mean_alpha"] + 1.96 * grouped["sem"]
grouped["ci_lower"] = grouped["mean_alpha"] - 1.96 * grouped["sem"]

# ------------ SMOOTHING FOR PLOT ------------
# Ages with a single subject have NaN std (hence NaN CI) and are excluded.
# make_interp_spline interpolates through every point (it does not smooth).
valid = (~grouped["ci_upper"].isna()) & (~grouped["ci_lower"].isna())
ages = grouped.loc[valid, "age"].values
mean_alpha = grouped.loc[valid, "mean_alpha"].values
ci_upper = grouped.loc[valid, "ci_upper"].values
ci_lower = grouped.loc[valid, "ci_lower"].values

ages_smooth = np.linspace(ages.min(), ages.max(), 500)
mean_spline = make_interp_spline(ages, mean_alpha, k=3)
upper_spline = make_interp_spline(ages, ci_upper, k=3)
lower_spline = make_interp_spline(ages, ci_lower, k=3)

mean_smooth = mean_spline(ages_smooth)
upper_smooth = upper_spline(ages_smooth)
lower_smooth = lower_spline(ages_smooth)

# ------------ PLOT ------------
plt.figure(figsize=(14, 6))
plt.plot(ages_smooth, mean_smooth, label="Smoothed Mean Relative Alpha", color="blue")
plt.fill_between(ages_smooth, lower_smooth, upper_smooth, color="lightblue", alpha=0.4, label="95% CI")
plt.scatter(ages, mean_alpha, s=20, color="blue", label="Mean per Age")

plt.xlabel("Age")
plt.ylabel("Relative Alpha Power")
plt.title("Relative Alpha Power across Age (Randomly Chosen Epochs)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.yticks(np.arange(0, 1.05, 0.1))
plt.show()

