In [None]:
import pickle
import numpy as np
from scipy.signal import hilbert
from pathlib import Path
import mne

# Channel renaming
# Electrode labels without the "EEG " prefix carried by the keys of channel_mapping.
_ELECTRODES = (
    "Fp1", "Fp2", "F3", "F4", "F7", "F8", "T3", "T4", "C3", "C4",
    "T5", "T6", "P3", "P4", "O1", "O2", "Fz", "Cz", "Pz",
)
channel_mapping = {f"EEG {name}": name for name in _ELECTRODES}
channels = [*channel_mapping.values()]

# Frequency bands: name -> (low, high) edges handed to the band-pass filter
bands = dict(
    delta=(1, 4),
    theta=(4, 8),
    alpha=(8, 12),
    beta=(12, 30),
    gamma=(30, 45),
    broadband=(1, 45),
)

# Envelope extractor per subject
def extract_band_envelopes(raw, bands=bands, channels=channels):
    """Return one channel-averaged Hilbert envelope per frequency band.

    Parameters
    ----------
    raw : object
        Recording exposing copy/rename_channels/pick_channels/
        set_eeg_reference/filter/get_data (presumably an MNE Raw — confirm).
        It is copied first, so the caller's object is not modified.
    bands : dict
        Band name -> (low, high) filter edges.
    channels : list of str
        Channel names kept after the "EEG " prefix has been stripped.

    Returns
    -------
    dict
        Band name -> envelope averaged over axis 0 of the filtered data
        (assumes get_data() is channels x samples — TODO confirm).
        A band whose processing raises is reported and left out.
    """
    def _strip_prefix(ch_name):
        if ch_name.startswith("EEG "):
            return ch_name.replace("EEG ", "")
        return ch_name

    # Clean up names, keep the requested channels, re-reference to the average
    recording = raw.copy()
    recording.rename_channels(_strip_prefix)
    recording.pick_channels(channels)
    recording.set_eeg_reference('average', projection=False)

    envelopes_by_band = {}
    for band_name, (f_low, f_high) in bands.items():
        try:
            filtered = recording.copy().filter(f_low, f_high, fir_design='firwin', verbose=False)
            samples = filtered.get_data()

            # Magnitude of the analytic signal = instantaneous amplitude
            amplitude = np.abs(hilbert(samples, axis=1))

            envelopes_by_band[band_name] = np.mean(amplitude, axis=0)
        except Exception as err:
            print(f"❌ Failed band {band_name}: {err}")
    return envelopes_by_band

# Batch processor
def process_all_subjects(eeg_dict, label="rest"):
    """Run `extract_band_envelopes` on every recording in `eeg_dict`.

    Parameters
    ----------
    eeg_dict : dict
        Subject id -> recording.
    label : str
        Only used (upper-cased) in the progress messages.

    Returns
    -------
    dict
        Subject id -> band-envelope dict. Subjects whose extraction raises
        are reported and omitted.
    """
    envelopes_by_subject = {}
    for subject, recording in eeg_dict.items():
        try:
            subject_envelopes = extract_band_envelopes(recording)
        except Exception as err:
            print(f"❌ {label.upper()} subject {subject} failed: {err}")
        else:
            envelopes_by_subject[subject] = subject_envelopes
            print(f"✅ {label.upper()} subject {subject} processed.")
    return envelopes_by_subject

# === Load EEG dicts ===
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _write_pickle(obj, path):
    """Pickle `obj` to `path`, overwriting any existing file."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)


rest_eeg_dict = _read_pickle("/home/donaf-strange/LAB_WORK/eeg_arithmetic_project/data/all_rest_eeg_by_participant.pkl")
task_eeg_dict = _read_pickle("/home/donaf-strange/LAB_WORK/eeg_arithmetic_project/data/all_task_eeg_by_participant.pkl")

# === Run ===
band_env_rest = process_all_subjects(rest_eeg_dict, label="rest")
band_env_task = process_all_subjects(task_eeg_dict, label="task")

# === Save ===
Path("band_envelopes").mkdir(exist_ok=True)
_write_pickle(band_env_rest, "band_envelopes/band_envelopes_rest.pkl")
_write_pickle(band_env_task, "band_envelopes/band_envelopes_task.pkl")


In [None]:
def jansen_rit_simulate(A=3.25, B=22, C=135, mu=120, dt=0.0001, duration=1.0, seed=None):
    """Simulates the Jansen-Rit model with forward Euler and returns the output signal.

    Parameters
    ----------
    A, B : float
        Gains of the excitatory (A) and inhibitory (B) second-order blocks.
    C : float
        Connectivity constant.
        NOTE(review): the same C is used in all four coupling terms; the
        commonly cited formulation scales them (C, 0.8C, 0.25C, 0.25C) —
        confirm this simplification is intended.
    mu : float
        Mean of the external input; Gaussian noise with standard deviation
        2.0 is added at every step.
    dt, duration : float
        Step size and total simulated time. The output has
        ``len(np.arange(0, duration, dt))`` samples.
    seed : int or None, optional
        When given, noise comes from ``np.random.default_rng(seed)`` so the
        run is reproducible. ``None`` keeps using NumPy's global generator.

    Returns
    -------
    numpy.ndarray
        ``y1 - y2`` at every time step, starting from an all-zero state.
    """
    t = np.arange(0, duration, dt)
    n = len(t)

    # seed=None preserves the original behaviour (global NumPy random state).
    rng = np.random if seed is None else np.random.default_rng(seed)

    # State variables
    y0 = np.zeros(n)
    y1 = np.zeros(n)
    y2 = np.zeros(n)
    y3 = np.zeros(n)
    y4 = np.zeros(n)
    y5 = np.zeros(n)

    # Parameters
    a = 100.0  # rate constant of the excitatory blocks
    b = 50.0   # rate constant of the inhibitory block
    e0 = 2.5
    v0 = 6.0
    r = 0.56

    def sigmoid(v):
        return 2 * e0 / (1 + np.exp(r * (v0 - v)))

    # Index 0 is the initial state, so the first Euler step fills index 1.
    # (Starting at 2 left index 1 un-integrated and delayed the run by a sample.)
    for i in range(1, n):
        p = mu + rng.standard_normal() * 2.0  # white noise input

        y0[i] = y0[i-1] + dt * y3[i-1]
        y3[i] = y3[i-1] + dt * (A * a * sigmoid(y1[i-1] - y2[i-1]) - 2 * a * y3[i-1] - a**2 * y0[i-1])
        y1[i] = y1[i-1] + dt * y4[i-1]
        y4[i] = y4[i-1] + dt * (A * a * (p + C * sigmoid(C * y0[i-1])) - 2 * a * y4[i-1] - a**2 * y1[i-1])
        y2[i] = y2[i-1] + dt * y5[i-1]
        y5[i] = y5[i-1] + dt * (B * b * C * sigmoid(C * y0[i-1]) - 2 * b * y5[i-1] - b**2 * y2[i-1])

    return y1 - y2  # pyramidal output


In [None]:
import pickle

# Load the pickle file
with open("band_envelopes/band_envelopes_rest.pkl", "rb") as handle:
    band_envelopes_rest = pickle.load(handle)

# Top-level keys are the subject IDs; show a few of them
subject_ids = list(band_envelopes_rest)
print("Subject IDs:", subject_ids[:5])

# Use the first subject as the example
example_subject = subject_ids[0]
example_bands = band_envelopes_rest[example_subject]

# Bands available for that subject
print(f"\nAvailable bands for subject {example_subject}:")
print(example_bands.keys())

# Number of samples in each band signal
for band, signal in example_bands.items():
    print(f"  {band}: shape = {len(signal)}")


In [None]:
import numpy as np
import pickle
from scipy.optimize import differential_evolution
from pathlib import Path
import os

# === Jansen-Rit model ===
def jansen_rit_simulate(A=3.25, B=22, C=135, mu=120, dt=0.002, duration=1.0, noise_std=2.0, seed=None):
    """Simulate the Jansen-Rit model with forward Euler and return ``y1 - y2``.

    Parameters
    ----------
    A, B, C, mu : float
        Model gains (A, B), connectivity constant (C) and mean external input (mu).
        NOTE(review): one C is used for all four coupling terms — confirm
        this simplification of the model is intended.
    dt, duration : float
        Step size and total time; the output has
        ``len(np.arange(0, duration, dt))`` samples.
    noise_std : float
        Standard deviation of the Gaussian noise added to `mu` at each step.
    seed : int or None, optional
        When given, noise comes from ``np.random.default_rng(seed)``
        (reproducible). ``None`` keeps using NumPy's global generator.

    Returns
    -------
    numpy.ndarray
        Output signal, starting from an all-zero state.
    """
    t = np.arange(0, duration, dt)
    n = len(t)

    # seed=None preserves the original behaviour (global NumPy random state).
    rng = np.random if seed is None else np.random.default_rng(seed)

    y0, y1, y2 = np.zeros(n), np.zeros(n), np.zeros(n)
    y3, y4, y5 = np.zeros(n), np.zeros(n), np.zeros(n)

    a, b = 100.0, 50.0
    e0, v0, r = 2.5, 6.0, 0.56

    def sigmoid(v):
        return 2 * e0 / (1 + np.exp(r * (v0 - v)))

    # Index 0 is the initial state, so integration starts at index 1
    # (starting at 2 left index 1 un-integrated).
    for i in range(1, n):
        p = mu + rng.normal(0, noise_std)

        y0[i] = y0[i-1] + dt * y3[i-1]
        y3[i] = y3[i-1] + dt * (A * a * sigmoid(y1[i-1] - y2[i-1]) - 2 * a * y3[i-1] - a**2 * y0[i-1])

        y1[i] = y1[i-1] + dt * y4[i-1]
        y4[i] = y4[i-1] + dt * (A * a * (p + C * sigmoid(C * y0[i-1])) - 2 * a * y4[i-1] - a**2 * y1[i-1])

        y2[i] = y2[i-1] + dt * y5[i-1]
        y5[i] = y5[i-1] + dt * (B * b * C * sigmoid(C * y0[i-1]) - 2 * b * y5[i-1] - b**2 * y2[i-1])

    return y1 - y2  # Output

# === Cost function ===
def cost_function(params, target_signal, dt, duration):
    """Objective for the optimiser: MSE between the z-scored simulation and target.

    Parameters
    ----------
    params : sequence of 4 floats
        (A, B, C, mu) forwarded to `jansen_rit_simulate`.
    target_signal : array-like
        Signal to match. It is linearly resampled to the simulation's number
        of samples, whatever its original length.
    dt, duration : float
        Forwarded to `jansen_rit_simulate`.

    Returns
    -------
    float
        Mean squared difference of the two z-scored signals, or ``np.inf``
        when either signal has zero or non-finite spread (z-scoring would
        otherwise produce NaN and poison the optimiser's comparisons).

    NOTE(review): the simulation draws fresh random noise on each call, so
    the same `params` can score differently twice.
    """
    A, B, C, mu = params
    sim = jansen_rit_simulate(A=A, B=B, C=C, mu=mu, dt=dt, duration=duration)

    sim_len = len(sim)
    target_resampled = np.interp(
        np.linspace(0, len(target_signal) - 1, sim_len),
        np.arange(len(target_signal)),
        target_signal
    )

    sim_std = np.std(sim)
    tgt_std = np.std(target_resampled)
    # A flat or diverged signal cannot be z-scored; rank it as the worst candidate.
    if not (np.isfinite(sim_std) and np.isfinite(tgt_std)) or sim_std == 0 or tgt_std == 0:
        return np.inf

    sim_z = (sim - np.mean(sim)) / sim_std
    tgt_z = (target_resampled - np.mean(target_resampled)) / tgt_std

    return np.mean((sim_z - tgt_z) ** 2)

# === Fitting loop for any dataset ===
def fit_all_subjects(band_envelopes_dict, label="rest"):
    """Fit (A, B, C, mu) to each subject's alpha envelope and pickle the results.

    Parameters
    ----------
    band_envelopes_dict : dict
        Subject id -> dict of band signals; subjects without an "alpha"
        entry are skipped.
    label : str
        Used in progress messages and in the output file name
        ``jansen_rit/fitted_params_{label}.pkl``.

    Returns
    -------
    dict
        Subject id -> {"A", "B", "C", "mu", "loss"}. Subjects whose fit
        raises are reported and omitted.
    """
    step = 0.002
    sim_seconds = 1.0
    search_box = [(2, 6), (10, 30), (100, 150), (80, 140)]  # A, B, C, mu
    names = ("A", "B", "C", "mu")

    fitted_params = {}
    print(f"\n🚀 Starting parameter fitting for {label.upper()} data...")

    for subj_id, subject_bands in band_envelopes_dict.items():
        if "alpha" not in subject_bands:
            print(f"⚠️ Skipping subject {subj_id} (no alpha band)")
            continue
        try:
            alpha_envelope = subject_bands["alpha"]
            print(f"🔍 Fitting subject {subj_id}...")

            result = differential_evolution(
                cost_function,
                search_box,
                args=(alpha_envelope, step, sim_seconds),
                strategy='best1bin',
                maxiter=50,
                popsize=10,
                tol=1e-5,
                polish=True,
                disp=False,
            )

            entry = dict(zip(names, result.x))
            entry["loss"] = result.fun
            fitted_params[subj_id] = entry

            print(f"✅ {label.capitalize()} | Subject {subj_id} | Loss = {result.fun:.5f}")

        except Exception as err:
            print(f"❌ Failed for subject {subj_id}: {err}")

    # Persist whatever was fitted, even when some subjects failed
    Path("jansen_rit").mkdir(exist_ok=True)
    save_path = f"jansen_rit/fitted_params_{label}.pkl"
    with open(save_path, "wb") as handle:
        pickle.dump(fitted_params, handle)

    print(f"💾 Saved fitted parameters to {save_path}")
    print(f"📊 Total successful fits: {len(fitted_params)} / {len(band_envelopes_dict)}\n")
    return fitted_params

# === Load both datasets ===
band_env_rest = pickle.loads(Path("band_envelopes/band_envelopes_rest.pkl").read_bytes())
band_env_task = pickle.loads(Path("band_envelopes/band_envelopes_task.pkl").read_bytes())

# === Run fitting (rest first, then task)
fitted_rest = fit_all_subjects(band_env_rest, label="rest")
fitted_task = fit_all_subjects(band_env_task, label="task")




In [None]:


# === Function to load and save RIT features ===
def save_rit_features_for_ml(condition_label):
    """Copy the fitted parameters of one condition into ``all_features/``.

    Reads ``jansen_rit/fitted_params_{condition_label}.pkl`` and writes the
    unpickled object unchanged to
    ``all_features/rit_features_{condition_label}.pkl``. Any failure is
    reported with a message rather than raised.

    Parameters
    ----------
    condition_label : str
        Condition name inserted into both file names (e.g. "rest", "task").
    """
    param_file = f"jansen_rit/fitted_params_{condition_label}.pkl"
    output_file = f"all_features/rit_features_{condition_label}.pkl"

    try:
        with open(param_file, "rb") as f:
            param_dict = pickle.load(f)

        # The output directory is not created anywhere else; without this
        # the write below failed with FileNotFoundError on a fresh checkout.
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)

        # Just save the dict — it's already subject-wise and model parameter-wise
        with open(output_file, "wb") as f:
            pickle.dump(param_dict, f)

        print(f"✅ Saved RIT features to {output_file}")
    except Exception as e:
        print(f"❌ Failed to save features for {condition_label}: {e}")

# === Save both conditions (rest first, then task)
for _condition in ("rest", "task"):
    save_rit_features_for_ml(_condition)


In [None]:
import pickle
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import os

# === Group Labels ===
# Subject ids are strings, matching the keys of the fitted-parameter dicts.
good_ids = [str(i) for i in (1, 2, 3, 5, 7, 8,
                             11, 12, 13, 15, 16, 17,
                             18, 20, 23, 24, 25, 26,
                             27, 28, 29, 31, 32, 33,
                             34, 35)]

bad_ids = [str(i) for i in (0, 4, 6, 9, 10,
                            14, 19, 21, 22, 30)]

# === Bayesian comparison function ===
def run_bayesian_parameter_comparison(param_name, param_dict, good_ids, bad_ids, title_label="REST"):
    """Estimate the posterior of the good-minus-bad difference in group means.

    Both groups are modelled as Normal with separate means and one shared
    standard deviation; `diff` is ``mu_good - mu_bad``.

    The priors are scaled to the data: the means are centred on the pooled
    mean with twice the pooled standard deviation, and the shared sigma
    uses the same scale. The earlier ``Normal(0, 10)`` priors pulled
    parameters that live around 100-150 (C, mu) strongly towards zero.

    Parameters
    ----------
    param_name : str
        Key looked up in each subject's dict.
    param_dict : dict
        Subject id -> dict of values.
    good_ids, bad_ids : list
        Subject ids of the two groups; ids missing from `param_dict` or
        lacking `param_name` are ignored.
    title_label : str
        Only used in the printed messages.

    Returns
    -------
    The ``az.summary`` table for `diff`, or ``None`` when either group has
    fewer than 3 values.
    """
    # Extract data
    good_vals = np.asarray(
        [param_dict[sid][param_name] for sid in good_ids if sid in param_dict and param_name in param_dict[sid]],
        dtype=float,
    )
    bad_vals = np.asarray(
        [param_dict[sid][param_name] for sid in bad_ids if sid in param_dict and param_name in param_dict[sid]],
        dtype=float,
    )

    if len(good_vals) < 3 or len(bad_vals) < 3:
        print(f"⚠️ Not enough data for parameter {param_name} in {title_label}. Skipping.")
        return None

    print(f"\n📊 Comparing parameter '{param_name}' in {title_label}")
    print(f"  Good mean: {np.mean(good_vals):.3f}, Bad mean: {np.mean(bad_vals):.3f}")

    # Data-scaled prior location and width
    pooled = np.concatenate([good_vals, bad_vals])
    prior_center = float(np.mean(pooled))
    pooled_sd = float(np.std(pooled))
    # Constant data has zero spread; fall back to a positive scale so the prior stays valid.
    prior_scale = 2.0 * pooled_sd if pooled_sd > 0 else max(abs(prior_center), 1.0)

    with pm.Model():
        mu_good = pm.Normal("mu_good", mu=prior_center, sigma=prior_scale)
        mu_bad = pm.Normal("mu_bad", mu=prior_center, sigma=prior_scale)
        sigma = pm.HalfNormal("sigma", sigma=prior_scale)

        pm.Normal("obs_good", mu=mu_good, sigma=sigma, observed=good_vals)
        pm.Normal("obs_bad", mu=mu_bad, sigma=sigma, observed=bad_vals)

        pm.Deterministic("diff", mu_good - mu_bad)

        trace = pm.sample(2000, tune=1000, chains=4, target_accept=0.95, progressbar=True)

    return az.summary(trace, var_names=["diff"])

# === Function to process a condition (rest/task)
def process_condition(condition_label):
    """Run the Bayesian good-vs-bad comparison for A, B, C and mu of one condition.

    Loads ``jansen_rit/fitted_params_{condition_label}.pkl`` and prints one
    summary table per parameter. Returns ``None`` in every case; a missing
    file is reported and skipped.

    The file may hold either the flat ``{subject: params}`` dict or the
    combined ``{"fitted_params": ..., "extended_features": ...}`` layout
    that later cells write under the same name. Both are accepted, since
    with the combined layout no subject id matched and every parameter was
    skipped.
    """
    param_file = f"jansen_rit/fitted_params_{condition_label}.pkl"
    if not os.path.exists(param_file):
        print(f"❌ Missing file: {param_file}")
        return

    print(f"\n=== 🔍 Analyzing {condition_label.upper()} parameters ===")

    with open(param_file, "rb") as f:
        loaded = pickle.load(f)

    # Unwrap the combined layout when present
    if isinstance(loaded, dict) and "fitted_params" in loaded:
        param_dict = loaded["fitted_params"]
    else:
        param_dict = loaded

    # The group lists hold string ids; normalise the keys to match
    param_dict = {str(sid): values for sid, values in param_dict.items()}

    summaries = {}
    for param in ["A", "B", "C", "mu"]:
        summary = run_bayesian_parameter_comparison(param, param_dict, good_ids, bad_ids, title_label=condition_label.upper())
        if summary is not None:
            summaries[param] = summary

    # Print all results
    print(f"\n📌 SUMMARY: {condition_label.upper()}")
    for param, table in summaries.items():
        print(f"\n🔹 {param}:\n{table}\n")

# === Run for both conditions (rest first, then task) ===
for _condition in ("rest", "task"):
    process_condition(_condition)

In [None]:
import numpy as np
import pickle
from scipy.signal import welch
from scipy.stats import entropy
from pathlib import Path
import nolds
import os

# === Jansen–Rit Simulator with all parameters ===
def jansen_rit_simulate(A=3.25, B=22, C=135, mu=120, a=100.0, b=50.0, e0=2.5, v0=6.0, r=0.56,
                        dt=0.002, duration=1.0, noise_std=2.0, pulse=None, seed=None):
    """Simulate the Jansen–Rit model with forward Euler and return ``y1 - y2``.

    Parameters
    ----------
    A, B, C, mu : float
        Model gains (A, B), connectivity constant (C) and mean external input (mu).
        NOTE(review): one C is used for all four coupling terms — confirm
        this simplification of the model is intended.
    a, b : float
        Rate constants of the blocks driven by A and B respectively.
    e0, v0, r : float
        Sigmoid parameters: ``2*e0 / (1 + exp(r*(v0 - v)))``.
    dt, duration : float
        Step size and total time; the output has
        ``len(np.arange(0, duration, dt))`` samples.
    noise_std : float
        Standard deviation of the Gaussian noise added to the input.
    pulse : sequence or None
        ``(start, stop, amplitude)``; `amplitude` is added to `mu` while
        ``start < i*dt < stop``.
    seed : int or None, optional
        When given, noise comes from ``np.random.default_rng(seed)``
        (reproducible). ``None`` keeps using NumPy's global generator.

    Returns
    -------
    numpy.ndarray
        Output signal, starting from an all-zero state.
    """
    t = np.arange(0, duration, dt)
    n = len(t)

    # seed=None preserves the original behaviour (global NumPy random state).
    rng = np.random if seed is None else np.random.default_rng(seed)

    y0, y1, y2 = np.zeros(n), np.zeros(n), np.zeros(n)
    y3, y4, y5 = np.zeros(n), np.zeros(n), np.zeros(n)

    def sigmoid(v):
        return 2 * e0 / (1 + np.exp(r * (v0 - v)))

    # Index 0 is the initial state, so integration starts at index 1
    # (starting at 2 left index 1 un-integrated).
    for i in range(1, n):
        input_mu = mu
        if pulse and pulse[0] < i * dt < pulse[1]:
            input_mu += pulse[2]
        p = input_mu + rng.normal(0, noise_std)

        y0[i] = y0[i-1] + dt * y3[i-1]
        y3[i] = y3[i-1] + dt * (A * a * sigmoid(y1[i-1] - y2[i-1]) - 2 * a * y3[i-1] - a**2 * y0[i-1])

        y1[i] = y1[i-1] + dt * y4[i-1]
        y4[i] = y4[i-1] + dt * (A * a * (p + C * sigmoid(C * y0[i-1])) - 2 * a * y4[i-1] - a**2 * y1[i-1])

        y2[i] = y2[i-1] + dt * y5[i-1]
        y5[i] = y5[i-1] + dt * (B * b * C * sigmoid(C * y0[i-1]) - 2 * b * y5[i-1] - b**2 * y2[i-1])

    return y1 - y2  # Output from pyramidal neurons

# === Feature extractors ===
def dominant_frequency(signal, sfreq=500):
    """Return the frequency at which the Welch power spectrum of `signal` peaks.

    `sfreq` is the sampling frequency handed to `scipy.signal.welch`; all
    other Welch settings are the library defaults.
    """
    frequency_axis, power = welch(signal, fs=sfreq)
    peak_bin = np.argmax(power)
    return frequency_axis[peak_bin]

def signal_entropy(signal, bins=50):
    """Return the Shannon entropy of the amplitude histogram of `signal`.

    The histogram uses `bins` bins with density normalisation; 1e-8 is added
    to every bin before `scipy.stats.entropy` normalises and evaluates it.
    """
    density = np.histogram(signal, bins=bins, density=True)[0]
    return entropy(density + 1e-8)

def lyapunov_exponent(signal):
    """Return ``nolds.lyap_r(signal)`` computed with nolds' default settings.

    NOTE(review): presumably the largest Lyapunov exponent (Rosenstein
    estimator) — confirm against the nolds documentation.
    """
    return nolds.lyap_r(signal)

def erp_features(signal, dt=0.002):
    """Summarise the largest peak of `signal`.

    Returns a dict with:
      - "erp_amp":     maximum value of the signal,
      - "erp_latency": index of that maximum times `dt`,
      - "erp_width":   number of samples above half the maximum, times `dt`.
    """
    peak_index = np.argmax(signal)
    peak_amp = np.max(signal)
    half_height = 0.5 * peak_amp
    samples_above = np.sum(signal > half_height)
    return dict(
        erp_amp=peak_amp,
        erp_latency=peak_index * dt,
        erp_width=samples_above * dt,
    )

# === Combine all features per subject ===
def extract_extended_features(params, duration=1.0):
    """Simulate one subject's fitted model and derive signal features from it.

    Parameters
    ----------
    params : dict
        Must contain "A", "B", "C" and "mu"; other keys are ignored.
    duration : float
        Simulated time passed to `jansen_rit_simulate`.

    Returns
    -------
    dict
        The four fitted parameters, the fixed constants (a, b, e0, v0, r),
        then "dominant_freq", "entropy", "lyapunov" and the ERP features of
        a single simulated run.
    """
    fitted = {name: params[name] for name in ('A', 'B', 'C', 'mu')}
    constants = {"a": 100.0, "b": 50.0, "e0": 2.5, "v0": 6.0, "r": 0.56}

    sim = jansen_rit_simulate(**fitted, **constants, duration=duration)

    features = dict(fitted)
    features.update(constants)
    features["dominant_freq"] = dominant_frequency(sim)
    features["entropy"] = signal_entropy(sim)
    features["lyapunov"] = lyapunov_exponent(sim)
    features.update(erp_features(sim))
    return features

# === Process a .pkl file and save extended features ===
def extend_param_file(in_path, out_path, duration=1.0):
    """Add simulation-derived features to every subject in a fitted-parameter file.

    Parameters
    ----------
    in_path : str
        Pickle with either the flat ``{subject: params}`` dict or the
        combined ``{"fitted_params": ..., "extended_features": ...}`` layout
        (the `fitted_params` part is used).
    out_path : str
        Destination pickle; missing parent directories are created.
    duration : float
        Simulated time forwarded to `extract_extended_features`.

    Subjects whose feature extraction raises are reported and left out of
    the saved dict.
    """
    with open(in_path, "rb") as f:
        loaded = pickle.load(f)

    # Later cells write a combined layout under the same file name; without
    # unwrapping it, its two top-level keys were treated as subjects.
    if isinstance(loaded, dict) and "fitted_params" in loaded:
        fitted_params = loaded["fitted_params"]
    else:
        fitted_params = loaded

    extended = {}
    for subj_id, params in fitted_params.items():
        try:
            feats = extract_extended_features(params, duration=duration)
            extended[subj_id] = feats
            print(f"✅ Extended subject {subj_id}")
        except Exception as e:
            print(f"❌ Failed subject {subj_id}: {e}")

    # parents=True so a nested output directory does not fail.
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(extended, f)
    print(f"\n💾 Saved extended features to: {out_path}")

# === Run for both REST and TASK (in that order) ===
_EXTEND_JOBS = (
    ("jansen_rit/fitted_params_rest.pkl", "jansen_rit/fitted_params_rest_extended.pkl"),
    ("jansen_rit/fitted_params_task.pkl", "jansen_rit/fitted_params_task_extended.pkl"),
)
for _source, _target in _EXTEND_JOBS:
    extend_param_file(_source, _target)


In [None]:
# combined

import numpy as np
import pickle
from scipy.signal import welch
from scipy.stats import entropy
from scipy.optimize import differential_evolution
from pathlib import Path
import nolds
import os

# === Jansen-Rit Simulator ===
def jansen_rit_simulate(A=3.25, B=22, C=135, mu=120, a=100.0, b=50.0, e0=2.5, v0=6.0, r=0.56,
                        dt=0.002, duration=1.0, noise_std=2.0, seed=None):
    """Simulate the Jansen-Rit model with forward Euler and return ``y1 - y2``.

    Parameters
    ----------
    A, B, C, mu : float
        Model gains (A, B), connectivity constant (C) and mean external input (mu).
        NOTE(review): one C is used for all four coupling terms — confirm
        this simplification of the model is intended.
    a, b : float
        Rate constants of the blocks driven by A and B respectively.
    e0, v0, r : float
        Sigmoid parameters: ``2*e0 / (1 + exp(r*(v0 - v)))``.
    dt, duration : float
        Step size and total time; the output has
        ``len(np.arange(0, duration, dt))`` samples.
    noise_std : float
        Standard deviation of the Gaussian noise added to `mu` at each step.
    seed : int or None, optional
        When given, noise comes from ``np.random.default_rng(seed)``
        (reproducible). ``None`` keeps using NumPy's global generator.

    Returns
    -------
    numpy.ndarray
        Output signal, starting from an all-zero state.
    """
    t = np.arange(0, duration, dt)
    n = len(t)

    # seed=None preserves the original behaviour (global NumPy random state).
    rng = np.random if seed is None else np.random.default_rng(seed)

    y0, y1, y2 = np.zeros(n), np.zeros(n), np.zeros(n)
    y3, y4, y5 = np.zeros(n), np.zeros(n), np.zeros(n)

    def sigmoid(v):
        return 2 * e0 / (1 + np.exp(r * (v0 - v)))

    # Index 0 is the initial state, so integration starts at index 1
    # (starting at 2 left index 1 un-integrated).
    for i in range(1, n):
        p = mu + rng.normal(0, noise_std)

        y0[i] = y0[i-1] + dt * y3[i-1]
        y3[i] = y3[i-1] + dt * (A * a * sigmoid(y1[i-1] - y2[i-1]) - 2 * a * y3[i-1] - a**2 * y0[i-1])

        y1[i] = y1[i-1] + dt * y4[i-1]
        y4[i] = y4[i-1] + dt * (A * a * (p + C * sigmoid(C * y0[i-1])) - 2 * a * y4[i-1] - a**2 * y1[i-1])

        y2[i] = y2[i-1] + dt * y5[i-1]
        y5[i] = y5[i-1] + dt * (B * b * C * sigmoid(C * y0[i-1]) - 2 * b * y5[i-1] - b**2 * y2[i-1])

    return y1 - y2  # Output signal

# === Feature Extractors ===
def dominant_frequency(signal, sfreq=500):
    """Frequency of the largest bin in the Welch spectrum of `signal`.

    `sfreq` is passed to `scipy.signal.welch` as the sampling frequency.
    """
    spectrum_freqs, spectrum_power = welch(signal, fs=sfreq)
    return spectrum_freqs[spectrum_power.argmax()]

def signal_entropy(signal, bins=50):
    """Shannon entropy of the density histogram of `signal` (`bins` bins).

    A constant 1e-8 is added to each bin before `scipy.stats.entropy` is applied.
    """
    bin_density, _bin_edges = np.histogram(signal, bins=bins, density=True)
    smoothed = bin_density + 1e-8
    return entropy(smoothed)

def lyapunov_exponent(signal):
    """Return ``nolds.lyap_r(signal)`` computed with nolds' default settings.

    NOTE(review): presumably the largest Lyapunov exponent (Rosenstein
    estimator) — confirm against the nolds documentation.
    """
    return nolds.lyap_r(signal)

def erp_features(signal, dt=0.002):
    """Peak amplitude, latency and half-height width of `signal`.

    Latency is the index of the maximum times `dt`; width is the count of
    samples above half the maximum, times `dt`.
    """
    features = {}
    features["erp_amp"] = np.max(signal)
    features["erp_latency"] = np.argmax(signal) * dt
    above_half = signal > (0.5 * features["erp_amp"])
    features["erp_width"] = np.sum(above_half) * dt
    return features

def extract_extended_features(params, duration=1.0):
    """Simulate the model described by `params` and append signal features.

    Every key of `params` is forwarded to `jansen_rit_simulate` as a keyword
    argument. The result is a new dict holding `params` followed by
    "dominant_freq", "entropy", "lyapunov" and the ERP features of one
    simulated run.
    """
    sim = jansen_rit_simulate(**params, duration=duration)

    features = dict(params)
    features["dominant_freq"] = dominant_frequency(sim)
    features["entropy"] = signal_entropy(sim)
    features["lyapunov"] = lyapunov_exponent(sim)
    features.update(erp_features(sim))
    return features

# === Cost function for optimization ===
def cost_function(params, target_signal, dt, duration):
    """Objective for the optimiser: MSE between the z-scored simulation and target.

    Parameters
    ----------
    params : sequence of 4 floats
        (A, B, C, mu) forwarded to `jansen_rit_simulate`.
    target_signal : array-like
        Signal to match. It is linearly resampled to the simulation's number
        of samples, whatever its original length.
    dt, duration : float
        Forwarded to `jansen_rit_simulate`.

    Returns
    -------
    float
        Mean squared difference of the two z-scored signals, or ``np.inf``
        when either signal has zero or non-finite spread (z-scoring would
        otherwise produce NaN and poison the optimiser's comparisons).

    NOTE(review): the simulation draws fresh random noise on each call, so
    the same `params` can score differently twice.
    """
    A, B, C, mu = params
    sim = jansen_rit_simulate(A=A, B=B, C=C, mu=mu, dt=dt, duration=duration)

    sim_len = len(sim)
    target_resampled = np.interp(
        np.linspace(0, len(target_signal) - 1, sim_len),
        np.arange(len(target_signal)),
        target_signal
    )

    sim_std = np.std(sim)
    tgt_std = np.std(target_resampled)
    # A flat or diverged signal cannot be z-scored; rank it as the worst candidate.
    if not (np.isfinite(sim_std) and np.isfinite(tgt_std)) or sim_std == 0 or tgt_std == 0:
        return np.inf

    sim_z = (sim - np.mean(sim)) / sim_std
    tgt_z = (target_resampled - np.mean(target_resampled)) / tgt_std

    return np.mean((sim_z - tgt_z) ** 2)

# === Full Pipeline per Dataset (Rest/Task) ===
def full_jansen_pipeline(band_envelopes_dict, label="rest", duration=1.0):
    """Fit the model to each subject's alpha envelope, derive features, save both.

    Parameters
    ----------
    band_envelopes_dict : dict
        Subject id -> dict of band signals; subjects without "alpha" are skipped.
    label : str
        Used in messages and in the output name
        ``jansen_rit/fitted_params_{label}.pkl``.
    duration : float
        Simulated time used for both fitting and feature extraction.

    Returns
    -------
    (dict, dict)
        ``fitted_params`` and ``extended_features``, both keyed by subject
        id. The pickle stores them together under those two names. A subject
        whose fit or feature extraction raises is reported; it may already
        have an entry in ``fitted_params`` if only the features failed.
    """
    step = 0.002
    search_box = [(2, 6), (10, 30), (100, 150), (80, 140)]  # A, B, C, mu

    fitted_params, extended_features = {}, {}

    print(f"\n🚀 Starting full Jansen–Rit pipeline for {label.upper()}...")

    for subj_id, subject_bands in band_envelopes_dict.items():
        if "alpha" not in subject_bands:
            print(f"⚠️ Skipping subject {subj_id} (no alpha band)")
            continue

        try:
            alpha_envelope = subject_bands["alpha"]
            print(f"🔍 Fitting subject {subj_id}...")

            result = differential_evolution(
                cost_function,
                search_box,
                args=(alpha_envelope, step, duration),
                strategy='best1bin',
                maxiter=50,
                popsize=10,
                tol=1e-5,
                polish=True,
                disp=False,
            )

            # Optimised values first, then the constants held fixed during the fit
            params = dict(zip(("A", "B", "C", "mu"), result.x))
            params.update(a=100.0, b=50.0, e0=2.5, v0=6.0, r=0.56)

            fitted_params[subj_id] = params
            extended_features[subj_id] = extract_extended_features(params, duration=duration)

            print(f"✅ {label.capitalize()} | Subject {subj_id} | Loss = {result.fun:.5f}")

        except Exception as err:
            print(f"❌ Failed for subject {subj_id}: {err}")

    # Save combined results
    Path("jansen_rit").mkdir(exist_ok=True)

    combined_data = dict(fitted_params=fitted_params, extended_features=extended_features)

    with open(f"jansen_rit/fitted_params_{label}.pkl", "wb") as handle:
        pickle.dump(combined_data, handle)

    print(f"\n💾 Saved all results to jansen_rit/fitted_params_{label}.pkl")
    print(f"📊 Total subjects processed: {len(fitted_params)} / {len(band_envelopes_dict)}")

    return fitted_params, extended_features

# === Load EEG Band Envelope Data ===
band_env_rest = pickle.loads(Path("band_envelopes/band_envelopes_rest.pkl").read_bytes())
band_env_task = pickle.loads(Path("band_envelopes/band_envelopes_task.pkl").read_bytes())

# === Run Combined Pipeline (rest first, then task)
fitted_rest, extended_rest = full_jansen_pipeline(band_env_rest, label="rest")
fitted_task, extended_task = full_jansen_pipeline(band_env_task, label="task")


In [None]:
#improved

import numpy as np
import pickle
from scipy.signal import welch
from scipy.stats import entropy
from scipy.optimize import differential_evolution
from scipy.integrate import solve_ivp
from pathlib import Path
import nolds
import os

# === Ornstein-Uhlenbeck noise generator ===
def ou_process(n, dt, tau=0.05, sigma=1.0, seed=None):
    """Generate `n` samples of a discretised Ornstein-Uhlenbeck process.

    Uses the update
    ``x[i] = x[i-1]*exp(-dt/tau) + sigma*sqrt(1 - exp(-2*dt/tau))*z`` with
    standard-normal `z`, starting from ``x[0] = 0``.

    Parameters
    ----------
    n : int
        Number of samples (0 gives an empty array).
    dt : float
        Time step between samples.
    tau : float
        Correlation time of the process.
    sigma : float
        Scale of the noise term.
    seed : int or None, optional
        When given, draws come from ``np.random.default_rng(seed)``
        (reproducible). ``None`` keeps using NumPy's global generator and
        gives the same stream as before.

    Returns
    -------
    numpy.ndarray of length `n`.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Both coefficients are loop-invariant; compute them once.
    decay = np.exp(-dt / tau)
    kick = sigma * np.sqrt(1 - np.exp(-2 * dt / tau))

    x = np.zeros(n)
    for i in range(1, n):
        x[i] = x[i-1] * decay + kick * rng.standard_normal()
    return x

# === Jansen–Rit ODE ===
def jansen_rit_ode(t, state, A, B, C, mu, a, b, e0, v0, r, noise_t, t_array):
    """Right-hand side of the six-state Jansen–Rit system at time `t`.

    `state` is ordered (y0, y3, y1, y4, y2, y5) and the returned derivatives
    follow the same order. The input is ``mu + noise_t[k]``, where `k` is
    `t` divided by the spacing of `t_array` (truncated), clamped to the last
    noise sample.
    """
    y0, y3, y1, y4, y2, y5 = state

    # Pick the noise sample that covers time t
    grid_step = t_array[1] - t_array[0]
    sample = int(t / grid_step)
    last_sample = len(noise_t) - 1
    if sample > last_sample:
        sample = last_sample
    p = mu + noise_t[sample]

    def firing_rate(v):
        return 2 * e0 / (1 + np.exp(r * (v0 - v)))

    output_rate = firing_rate(y1 - y2)
    feedback_rate = firing_rate(C * y0)

    return [
        y3,
        A * a * output_rate - 2 * a * y3 - a ** 2 * y0,
        y4,
        A * a * (p + C * feedback_rate) - 2 * a * y4 - a ** 2 * y1,
        y5,
        B * b * C * feedback_rate - 2 * b * y5 - b ** 2 * y2,
    ]

# === Simulator ===
def jansen_rit_simulate(A=3.25, B=22, C=135, mu=120, a=100.0, b=50.0,
                        e0=2.5, v0=6.0, r=0.56, dt=0.002, duration=1.0, noise_std=2.0):
    """Integrate the Jansen–Rit ODE with RK45 and return a peak-normalised output.

    The system starts from an all-zero state and is driven by
    ``mu + ou_process(...)`` with noise scale `noise_std`. The solution is
    evaluated on ``np.arange(0, duration, dt)``; the returned signal is
    ``y1 - y2`` divided by its largest absolute value (plus 1e-8).

    NOTE(review): the solver's success flag is not checked — presumably a
    failed integration would return fewer samples than requested; confirm.
    """
    times = np.arange(0, duration, dt)
    initial_state = [0.0] * 6
    noise_t = ou_process(len(times), dt, sigma=noise_std)
    model_args = (A, B, C, mu, a, b, e0, v0, r, noise_t, times)

    sol = solve_ivp(
        fun=lambda ti, state: jansen_rit_ode(ti, state, *model_args),
        t_span=(0, duration),
        y0=initial_state,
        t_eval=times,
        method='RK45',
    )

    # Rows follow the state order (y0, y3, y1, y4, y2, y5)
    y1_trace, y2_trace = sol.y[2], sol.y[4]
    signal = y1_trace - y2_trace
    peak = np.max(np.abs(signal))
    return signal / (peak + 1e-8)

# === Feature Extractors ===
def dominant_frequency(signal, sfreq=500):
    """Return the Welch-spectrum frequency carrying the most power.

    `sfreq` is the sampling frequency given to `scipy.signal.welch`.
    """
    welch_freqs, welch_power = welch(signal, fs=sfreq)
    strongest = int(np.argmax(welch_power))
    return welch_freqs[strongest]

def signal_entropy(signal, bins=50):
    """Entropy of the amplitude distribution of `signal`.

    Builds a `bins`-bin density histogram, adds 1e-8 to every bin and
    returns `scipy.stats.entropy` of the result.
    """
    histogram = np.histogram(signal, bins=bins, density=True)
    return entropy(histogram[0] + 1e-8)

def lyapunov_exponent(signal):
    """Return ``nolds.lyap_r(signal)`` computed with nolds' default settings.

    NOTE(review): presumably the largest Lyapunov exponent (Rosenstein
    estimator) — confirm against the nolds documentation.
    """
    return nolds.lyap_r(signal)

def erp_features(signal, dt=0.002):
    """Describe the largest peak of `signal` as amplitude, latency and width.

    Latency is the index of the maximum times `dt`; width is the number of
    samples above half the maximum, times `dt`.
    """
    peak_amp = np.max(signal)
    peak_position = np.argmax(signal)
    n_above_half = np.sum(signal > (0.5 * peak_amp))
    return {
        "erp_amp": peak_amp,
        "erp_latency": peak_position * dt,
        "erp_width": n_above_half * dt,
    }

def extract_extended_features(params, duration=1.0):
    """Run one simulation with `params` and return params plus signal features.

    All keys of `params` are passed to `jansen_rit_simulate` as keyword
    arguments. The returned dict starts with a copy of `params` and adds
    "dominant_freq", "entropy", "lyapunov" and the ERP features.

    NOTE(review): the feature helpers are called with their default
    sampling settings (sfreq=500, dt=0.002) even if `params` carries a
    different "dt" — confirm that is intended.
    """
    sim = jansen_rit_simulate(**params, duration=duration)

    derived = {
        "dominant_freq": dominant_frequency(sim),
        "entropy": signal_entropy(sim),
        "lyapunov": lyapunov_exponent(sim),
    }
    derived.update(erp_features(sim))
    return {**params, **derived}

# === Cost function ===
def cost_function(params_vec, target_signal, dt, duration):
    """Objective for the optimiser: MSE between the z-scored simulation and target.

    Parameters
    ----------
    params_vec : sequence of 4 floats
        (A, B, C, mu); the remaining model constants are fixed here.
    target_signal : array-like
        Signal to match. It is linearly resampled to the simulation's number
        of samples, whatever its original length.
    dt, duration : float
        Forwarded to `jansen_rit_simulate`.

    Returns
    -------
    float
        Mean squared difference of the two z-scored signals, or ``np.inf``
        when either signal has zero or non-finite spread (z-scoring would
        otherwise produce NaN and poison the optimiser's comparisons).

    NOTE(review): the simulation is driven by freshly drawn noise on each
    call, so the same `params_vec` can score differently twice.
    """
    A, B, C, mu = params_vec
    params = {
        "A": A, "B": B, "C": C, "mu": mu,
        "a": 100.0, "b": 50.0,
        "e0": 2.5, "v0": 6.0, "r": 0.56,
        "dt": dt, "noise_std": 2.0
    }
    sim = jansen_rit_simulate(**params, duration=duration)
    sim_len = len(sim)
    target_resampled = np.interp(np.linspace(0, len(target_signal)-1, sim_len),
                                 np.arange(len(target_signal)), target_signal)

    sim_std = np.std(sim)
    tgt_std = np.std(target_resampled)
    # A flat or diverged signal cannot be z-scored; rank it as the worst candidate.
    if not (np.isfinite(sim_std) and np.isfinite(tgt_std)) or sim_std == 0 or tgt_std == 0:
        return np.inf

    sim_z = (sim - np.mean(sim)) / sim_std
    tgt_z = (target_resampled - np.mean(target_resampled)) / tgt_std
    return np.mean((sim_z - tgt_z) ** 2)

# === Full Pipeline ===
def full_jansen_pipeline(band_envelopes_dict, label="rest", duration=1.0):
    """Fit the model to each subject's alpha envelope, derive features, save both.

    Parameters
    ----------
    band_envelopes_dict : dict
        Subject id -> dict of band signals; subjects without "alpha" are skipped.
    label : str
        Used in messages and in the output name
        ``jansen_rit/fitted_params_{label}.pkl``.
    duration : float
        Simulated time used for both fitting and feature extraction.

    Returns
    -------
    (dict, dict)
        ``fitted_params`` and ``extended_features``, both keyed by subject
        id; the pickle stores them together under those two names. Subjects
        that raise are reported, not re-raised.
    """
    step = 0.002
    search_box = [(2, 6), (10, 30), (100, 150), (80, 140)]  # A, B, C, mu
    fitted_params = {}
    extended_features = {}

    print(f"\n🚀 Starting Jansen–Rit pipeline for {label.upper()}")

    for subj_id, subject_bands in band_envelopes_dict.items():
        if "alpha" not in subject_bands:
            print(f"⚠️ Skipping subject {subj_id} (no alpha band)")
            continue
        try:
            alpha_envelope = subject_bands["alpha"]
            print(f"🔍 Fitting subject {subj_id}...")

            result = differential_evolution(
                cost_function,
                search_box,
                args=(alpha_envelope, step, duration),
                strategy='best1bin',
                maxiter=50,
                popsize=10,
                tol=1e-5,
                polish=True,
                disp=False,
            )

            # Optimised values first, then the settings held fixed during the fit
            params = dict(zip(("A", "B", "C", "mu"), result.x))
            params.update(a=100.0, b=50.0, e0=2.5, v0=6.0, r=0.56, dt=step, noise_std=2.0)

            fitted_params[subj_id] = params
            extended_features[subj_id] = extract_extended_features(params, duration=duration)

            print(f"✅ {label.capitalize()} | Subject {subj_id} | Loss = {result.fun:.5f}")

        except Exception as err:
            print(f"❌ Failed for subject {subj_id}: {err}")

    # Save combined results
    Path("jansen_rit").mkdir(exist_ok=True)

    combined_data = dict(fitted_params=fitted_params, extended_features=extended_features)

    with open(f"jansen_rit/fitted_params_{label}.pkl", "wb") as handle:
        pickle.dump(combined_data, handle)

    print(f"\n💾 Saved all results to jansen_rit/fitted_params_{label}.pkl")
    print(f"📊 Total subjects processed: {len(fitted_params)} / {len(band_envelopes_dict)}")
    return fitted_params, extended_features

# === Load Data and Run ===
band_env_rest = pickle.loads(Path("band_envelopes/band_envelopes_rest.pkl").read_bytes())
band_env_task = pickle.loads(Path("band_envelopes/band_envelopes_task.pkl").read_bytes())

# Rest is processed before task
fitted_rest, extended_rest = full_jansen_pipeline(band_env_rest, label="rest")
fitted_task, extended_task = full_jansen_pipeline(band_env_task, label="task")


In [None]:
# Load the saved results here so this cell does not depend on `data` being
# created by a later cell (it raised NameError on a fresh Restart & Run All).
with open("jansen_rit/fitted_params_rest.pkl", "rb") as f:
    data = pickle.load(f)

print("Subjects in fitted_params:", list(data['fitted_params'].keys()))
print("Subjects in extended_features:", list(data['extended_features'].keys()))
# Subject ids are stored as strings.

In [None]:
import pickle

# Load the saved file
with open("jansen_rit/fitted_params_rest.pkl", "rb") as handle:
    data = pickle.load(handle)

# Subject ids are string keys
subject_id = '1'

# Fitted parameters of that subject
print(f"📌 Fitted Parameters for subject {subject_id}:")
print(data['fitted_params'][subject_id])

# Extended features of that subject, one per line
print(f"\n📊 Extended Features for subject {subject_id}:")
subject_features = data['extended_features'][subject_id]
for k in subject_features:
    v = subject_features[k]
    print(f"{k}: {v}")


In [None]:
import pickle
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import os

# === Define subject groups ===
# Ids are strings because the saved dicts are keyed by string subject ids.
good_ids = list(map(str, (1, 2, 3, 5, 7, 8,
                          11, 12, 13, 15, 16, 17,
                          18, 20, 23, 24, 25, 26,
                          27, 28, 29, 31, 32, 33,
                          34, 35)))

bad_ids = list(map(str, (0, 4, 6, 9, 10,
                         14, 19, 21, 22, 30)))

# === Bayesian group comparison ===
def bayesian_compare_feature(param_name, data_dict, good_ids, bad_ids, label="REST"):
    """Estimate the posterior of the good-minus-bad difference in group means.

    Both groups are modelled as Normal with separate means and one shared
    standard deviation; `diff` is ``mu_good - mu_bad``.

    The priors are scaled to the data: the means are centred on the pooled
    mean with twice the pooled standard deviation, and the shared sigma
    uses the same scale. The earlier ``Normal(0, 10)`` priors pulled
    features that live far from zero (e.g. C, mu around 100-150) strongly
    towards zero.

    Parameters
    ----------
    param_name : str
        Key looked up in each subject's dict.
    data_dict : dict
        Subject id -> dict of values.
    good_ids, bad_ids : list
        Subject ids of the two groups; ids missing from `data_dict` or
        lacking `param_name` are ignored.
    label : str
        Only used in the printed messages.

    Returns
    -------
    The ``az.summary`` table for `diff` with a 95% HDI, or ``None`` when
    either group has fewer than 3 values.
    """
    good_vals = np.asarray(
        [data_dict[sid][param_name] for sid in good_ids if sid in data_dict and param_name in data_dict[sid]],
        dtype=float,
    )
    bad_vals = np.asarray(
        [data_dict[sid][param_name] for sid in bad_ids if sid in data_dict and param_name in data_dict[sid]],
        dtype=float,
    )

    if len(good_vals) < 3 or len(bad_vals) < 3:
        print(f"⚠️ Not enough data for {param_name} in {label}")
        return None

    print(f"\n📊 {label}: Comparing '{param_name}'")
    print(f"  Good mean: {np.mean(good_vals):.3f}, Bad mean: {np.mean(bad_vals):.3f}")

    # Data-scaled prior location and width
    pooled = np.concatenate([good_vals, bad_vals])
    prior_center = float(np.mean(pooled))
    pooled_sd = float(np.std(pooled))
    # Constant data has zero spread; fall back to a positive scale so the prior stays valid.
    prior_scale = 2.0 * pooled_sd if pooled_sd > 0 else max(abs(prior_center), 1.0)

    with pm.Model():
        mu_good = pm.Normal("mu_good", mu=prior_center, sigma=prior_scale)
        mu_bad = pm.Normal("mu_bad", mu=prior_center, sigma=prior_scale)
        sigma = pm.HalfNormal("sigma", sigma=prior_scale)

        pm.Normal("obs_good", mu=mu_good, sigma=sigma, observed=good_vals)
        pm.Normal("obs_bad", mu=mu_bad, sigma=sigma, observed=bad_vals)

        pm.Deterministic("diff", mu_good - mu_bad)

        trace = pm.sample(2000, tune=1000, chains=4, target_accept=0.95, progressbar=True)

    return az.summary(trace, var_names=["diff"], hdi_prob=0.95)

# === Function to process one .pkl file ===
def process_feature_file(pkl_path, label="REST"):
    """Run the Bayesian good-vs-bad comparison for every scalar feature in a file.

    Parameters
    ----------
    pkl_path : str
        Pickle holding either ``{"extended_features": {subject: features}}``
        or a flat ``{subject: features}`` dict.
    label : str
        Condition name used in the printed output.

    Returns
    -------
    dict or None
        Feature name -> summary table of `diff`; ``None`` when the file is
        missing, ``{}`` when it contains no subjects.

    The candidate features are the numeric (non-bool) entries of the first
    subject. Features with the same value for every subject (for example
    fixed simulation constants) are skipped, because a group comparison on
    constant data is meaningless and the sampler degenerates on it.
    """
    if not os.path.exists(pkl_path):
        print(f"❌ Missing file: {pkl_path}")
        return

    with open(pkl_path, "rb") as f:
        raw = pickle.load(f)

    # Extract correct layer
    if isinstance(raw, dict) and 'extended_features' in raw:
        feature_dict = raw['extended_features']
    else:
        feature_dict = raw  # fallback (in case only flat dict is saved)

    # Ensure all keys are strings
    feature_dict = {str(k): v for k, v in feature_dict.items()}

    # Nothing to compare; previously this raised StopIteration below.
    if not feature_dict:
        print(f"⚠️ No subjects found in {pkl_path}")
        return {}

    def _is_scalar(value):
        return isinstance(value, (int, float, np.number)) and not isinstance(value, bool)

    # Detect scalar keys to compare, dropping those that never vary
    feature_keys = []
    for key, value in next(iter(feature_dict.values())).items():
        if not _is_scalar(value):
            continue
        distinct = {subject[key] for subject in feature_dict.values()
                    if key in subject and _is_scalar(subject[key])}
        if len(distinct) < 2:
            print(f"⚠️ Skipping '{key}' in {label}: identical for every subject")
            continue
        feature_keys.append(key)

    results = {}

    print(f"\n=== 🔍 Running Bayesian comparisons for {label.upper()} ===")
    for param in feature_keys:
        summary = bayesian_compare_feature(param, feature_dict, good_ids, bad_ids, label=label)
        if summary is not None:
            results[param] = summary

    # === Summary Table ===
    print(f"\n📌 SUMMARY TABLE for {label.upper()}:")
    print(f"{'Parameter':<20} {'Diff Mean':>10}   {'HDI (low - high)':>25}")
    print("-" * 60)
    for param, df in results.items():
        mean = df.loc['diff', 'mean']

        # The HDI column names depend on the interval width, so find them by name
        hdi_cols = [col for col in df.columns if 'hdi' in col]
        if len(hdi_cols) >= 2:
            hdi_low = df.loc['diff', hdi_cols[0]]
            hdi_high = df.loc['diff', hdi_cols[1]]
            print(f"{param:<20} {mean:>10.3f}   [{hdi_low:.3f}, {hdi_high:.3f}]")
        else:
            print(f"{param:<20} {mean:>10.3f}   [HDI N/A]")

    return results

# === Run for REST and TASK (in that order) ===
_FEATURE_FILES = (
    ("jansen_rit/fitted_params_rest.pkl", "REST"),
    ("jansen_rit/fitted_params_task.pkl", "TASK"),
)
results_rest, results_task = [
    process_feature_file(_path, label=_label) for _path, _label in _FEATURE_FILES
]


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

# === Input means and HDIs for C and mu ===
# Format: {Condition: {Parameter: {'Good': val, 'Bad': val, 'HDI_low': val, 'HDI_high': val}}}
# 'Good' and 'Bad' are placed symmetrically around a fixed centre (135 for C,
# 120 for mu): half of the hard-coded difference above and below it.
# NOTE(review): every number is typed in by hand — presumably copied from the
# Bayesian summary tables (mean and HDI of `diff`); confirm the source and
# regenerate from the saved results. The bar heights are therefore not
# measured group means.
summary = {
    'Rest': {
        'C': {'Good': 135 + 24.901/2, 'Bad': 135 - 24.901/2, 'HDI_low': 3.402, 'HDI_high': 47.227},
        'mu': {'Good': 120 + 23.861/2, 'Bad': 120 - 23.861/2, 'HDI_low': 0.859, 'HDI_high': 47.100}
    },
    'Task': {
        'C': {'Good': 135 + 21.335/2, 'Bad': 135 - 21.335/2, 'HDI_low': 0.171, 'HDI_high': 44.246},
        'mu': {'Good': 120 + 27.078/2, 'Bad': 120 - 27.078/2, 'HDI_low': 4.246, 'HDI_high': 49.699}
    }
}

# === Convert to long format DataFrame ===
# One row per (condition, parameter, group); both groups share the HDI width.
rows = []
for cond, per_param in summary.items():
    for param, stats in per_param.items():
        hdi_range = stats['HDI_high'] - stats['HDI_low']
        for group in ('Good', 'Bad'):
            rows.append([cond, param, group, stats[group], hdi_range])
df = pd.DataFrame(rows, columns=['Condition', 'Parameter', 'Group', 'Value', 'HDI'])

# === Plot ===
# One panel per parameter (C, mu); inside a panel the Good/Bad bars are
# grouped by condition and share the y axis with the other panel.
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)

for i, param in enumerate(['C', 'mu']):
    ax = axes[i]
    plot_df = df[df['Parameter'] == param]
    # NOTE(review): `ci` is deprecated in recent seaborn releases in favour of
    # `errorbar=None` — confirm against the installed version.
    sns.barplot(data=plot_df, x='Condition', y='Value', hue='Group', ax=ax,
                palette=['#66c2a5', '#fc8d62'], ci=None)
    
    # Add error bars (HDI ranges)
    # The -0.2 / +0.2 offsets presumably match seaborn's dodged bar centres for
    # two hue levels — verify visually.
    # NOTE(review): yerr is half the HDI width of the Good-minus-Bad difference,
    # drawn on each group bar; it is not an interval for the group value itself.
    for idx, row in plot_df.iterrows():
        ax.errorbar(x=['Rest', 'Task'].index(row['Condition']) + (-0.2 if row['Group'] == 'Good' else 0.2),
                    y=row['Value'], yerr=row['HDI'] / 2,
                    fmt='none', ecolor='black', capsize=5)

    ax.set_title(f"Parameter: {param}", fontsize=14)
    ax.set_ylabel("Value")
    ax.set_xlabel("Condition")
    ax.set_ylim(bottom=100 if param == 'C' else 100)  # both branches are 100 at present; adjust as needed

# Only the first panel's legend is given a title here
axes[0].legend(title='Group')
#plt.suptitle("🔬 Group Differences in Jansen–Rit Parameters", fontsize=16)
# rect leaves the top 5% of the figure free (room for the disabled suptitle)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()


In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import arviz as az

# === Load fitted parameters ===
def _load_extended(path):
    """Return the object pickled at `path`."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


rest_params = _load_extended("jansen_rit/fitted_params_rest_extended.pkl")
task_params = _load_extended("jansen_rit/fitted_params_task_extended.pkl")

# === Group definitions ===
# Subject ids are kept as strings to match the keys of the parameter dicts.
good_ids = [str(i) for i in (1, 2, 3, 5, 7, 8, 11, 12, 13, 15, 16, 17,
                             18, 20, 23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35)]

bad_ids = [str(i) for i in (0, 4, 6, 9, 10, 14, 19, 21, 22, 30)]

# === Function to extract group-wise values ===
def extract_group_data(param_dict, param_name):
    """Collect the values of `param_name` for the good and the bad group.

    Uses the module-level `good_ids` and `bad_ids`. Subjects missing from
    `param_dict`, or lacking `param_name`, are ignored. Good values follow
    the order of `good_ids`; bad values follow the order of `param_dict`.

    Returns
    -------
    (list, list)
        ``(good_vals, bad_vals)``.
    """
    good_vals = []
    for sid in good_ids:
        if sid in param_dict and param_name in param_dict[sid]:
            good_vals.append(param_dict[sid][param_name])

    bad_vals = []
    for sid in param_dict:
        if sid in bad_ids and param_name in param_dict[sid]:
            bad_vals.append(param_dict[sid][param_name])

    return good_vals, bad_vals

# === Plot function ===
def plot_violin_and_hdi(param_dict, param_list, condition="REST"):
    """Plot, per parameter, the group distributions and the posterior of their mean difference.

    For each name in `param_list` this shows a violin plot of the good and
    bad group values, then the posterior of ``mu_good - mu_bad`` from a
    two-group Normal model with a shared standard deviation. Parameters
    with fewer than 3 values in either group are skipped with a message.

    The priors are scaled to the data: the means are centred on the pooled
    mean with twice the pooled standard deviation, and the shared sigma
    uses the same scale. The earlier ``Normal(0, 10)`` priors pulled
    parameters that live far from zero (e.g. C, mu) strongly towards zero.

    Parameters
    ----------
    param_dict : dict
        Subject id -> dict of values.
    param_list : list of str
        Names to plot.
    condition : str
        Label used in the plot titles.
    """
    for param in param_list:
        good_vals, bad_vals = extract_group_data(param_dict, param)
        if len(good_vals) < 3 or len(bad_vals) < 3:
            print(f"Skipping {param} (insufficient data)")
            continue

        # Plot violin
        plt.figure(figsize=(6, 4))
        sns.violinplot(data=[good_vals, bad_vals], palette="pastel", inner="box")
        plt.xticks([0, 1], ['Good', 'Bad'])
        plt.title(f"{condition} – {param}")
        plt.ylabel(param)
        plt.grid(True)
        plt.tight_layout()
        plt.show()

        # Data-scaled prior location and width
        pooled = np.concatenate([np.asarray(good_vals, dtype=float), np.asarray(bad_vals, dtype=float)])
        prior_center = float(np.mean(pooled))
        pooled_sd = float(np.std(pooled))
        # Constant data has zero spread; fall back to a positive scale so the prior stays valid.
        prior_scale = 2.0 * pooled_sd if pooled_sd > 0 else max(abs(prior_center), 1.0)

        # Plot posterior difference (HDI)
        with pm.Model():
            mu_good = pm.Normal("mu_good", mu=prior_center, sigma=prior_scale)
            mu_bad = pm.Normal("mu_bad", mu=prior_center, sigma=prior_scale)
            sigma = pm.HalfNormal("sigma", sigma=prior_scale)

            pm.Normal("obs_good", mu=mu_good, sigma=sigma, observed=good_vals)
            pm.Normal("obs_bad", mu=mu_bad, sigma=sigma, observed=bad_vals)

            pm.Deterministic("diff", mu_good - mu_bad)
            trace = pm.sample(2000, tune=1000, chains=4, target_accept=0.95, progressbar=False)

        az.plot_posterior(trace, var_names=["diff"], ref_val=0)
        plt.title(f"{condition} – {param} (HDI of Group Diff)")
        plt.tight_layout()
        plt.show()

# === Parameters to plot ===
core_params = ["A", "B", "C", "mu"]
dyn_features = ["dominant_freq", "entropy", "lyapunov", "erp_amp", "erp_latency", "erp_width"]

# === Run plots (rest first, then task) ===
for _params, _condition in ((rest_params, "REST"), (task_params, "TASK")):
    plot_violin_and_hdi(_params, core_params + dyn_features, condition=_condition)
