In [None]:
import os
import sys
import tempfile

import numpy as np
import yaml
import gradio as gr

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for web apps
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D

from numpy.fft import fft
from scipy import signal

# ---------------------------------------------------------
# Project imports (adjust path if needed)
# ---------------------------------------------------------
# Directory containing this script. Notebooks (e.g. Jupyter) do not define
# ``__file__``, in which case the current working directory is used instead.
CURRENT_DIR = (
    os.path.dirname(os.path.abspath(__file__))
    if "__file__" in globals()
    else os.getcwd()
)

# Two levels up from CURRENT_DIR; appended to sys.path (once) so the project
# import below can be resolved — presumably where ``stn_gpe`` lives.
PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir, os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Your model code must expose these
from stn_gpe import STN_GPe_loop, Analysis

# ---------------------------------------------------------
# Matplotlib style
# ---------------------------------------------------------
# Global figure style: bold serif (Times New Roman) text, compact tick and
# legend fonts, white figure background, no grid lines.
plt.rc("font", family="serif", serif=["Times New Roman"], size=12, weight="bold")
plt.rc(
    "axes",
    titlesize=12,
    titleweight="bold",
    labelsize=11,
    labelweight="bold",
    grid=False,
)
plt.rc(("xtick", "ytick"), labelsize=10)
plt.rc("legend", fontsize=10)
plt.rc("figure", facecolor="white")


# ---------------------------------------------------------
# Helper: create a YAML file directly from UI params
# ---------------------------------------------------------
def create_yaml_from_inputs(params: dict) -> str:
    """
    Write ``params`` to a new temporary ``.yaml`` file and return its path.

    The file is created with :func:`tempfile.mkstemp` and is *not* deleted
    here: the caller owns it and should remove it once it is no longer
    needed. If serialization fails, the partially written file is removed
    before the exception propagates, so no temp file is leaked.

    Parameters
    ----------
    params : dict
        Mapping of plain Python values. ``yaml.safe_dump`` only accepts
        built-in types (it rejects e.g. NumPy scalars).

    Returns
    -------
    str
        Absolute path of the written YAML file.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".yaml")
    try:
        # os.fdopen takes ownership of fd and closes it when the block exits.
        with os.fdopen(fd, "w", encoding="utf-8") as tmp_file:
            yaml.safe_dump(params, tmp_file)
    except BaseException:
        # Don't leave an orphaned, half-written file in the temp directory.
        os.remove(tmp_path)
        raise
    return tmp_path


# ---------------------------------------------------------
# Core simulation + analysis function (used by Gradio)
# ---------------------------------------------------------
def _smooth_lfp(lfp, window_length=11, polyorder=5):
    """
    Return a Savitzky–Golay-smoothed copy of a 1-D LFP trace.

    ``lfp`` must contain at least ``window_length`` samples, which
    ``mode="interp"`` requires.
    """
    return signal.savgol_filter(
        lfp,
        window_length=window_length,
        polyorder=polyorder,
        deriv=0,
        delta=1.0,
        axis=-1,
        mode="interp",
        cval=0.0,
    )


def _spectrogram_colormaps():
    """Return the (STN, GPe) white-to-blue / white-to-red spectrogram colormaps."""
    colors_blue = [
        (1, 1, 1),
        (0.98, 0.98, 1),
        (0.95, 0.95, 1),
        (0.9, 0.9, 0.95),
        (0.89, 0.89, 0.94),
        (0.85, 0.85, 0.9),
        (0.8, 0.8, 0.85),
        (0.05, 0.3, 0.7),
        (0, 0.1, 0.5),
    ]
    colors_red = [
        (1, 1, 1),
        (1, 0.98, 0.98),
        (1, 0.95, 0.95),
        (1, 0.92, 0.92),
        (1, 0.9, 0.9),
        (1, 0.85, 0.85),
        (1, 0.81, 0.81),
        (0.7, 0.05, 0.05),
        (0.5, 0, 0),
    ]
    return (
        LinearSegmentedColormap.from_list("custom_blues", colors_blue),
        LinearSegmentedColormap.from_list("custom_reds", colors_red),
    )


def _plot_raster(ax, spikes, h, color):
    """
    Draw a spike raster of ``spikes`` on ``ax``.

    Assumes ``spikes`` is a (time_steps, neurons) array of 0/1 (or boolean)
    spike flags — TODO confirm against STN_GPe_loop. Neuron ``k`` is drawn
    on row ``k + 1``. The x axis is time in seconds, with step ``h`` ms.
    """
    spikes = np.asarray(spikes)
    n_steps, n_neurons = spikes.shape
    # One scatter over the spike events only. Previously there was one call
    # per neuron over every time step, and the zeros were drawn at y=0,
    # hidden below the y-limit.
    t_idx, neuron_idx = np.nonzero(spikes)
    ax.scatter(t_idx * h / 1000.0, neuron_idx + 1, color=color, s=0.5)
    ax.set_xlim(0, n_steps * h / 1000.0)
    ax.set_ylim(0.5, n_neurons + 0.5)
    ax.set_xlabel("t (s)")
    ax.set_ylabel("Neuron")


def _plot_spectrogram(ax, t, f, Sxx, cmap, fmax=40):
    """Plot power ``Sxx`` in dB (offset 1e-12 avoids log(0)), showing 0..``fmax`` Hz."""
    ax.pcolormesh(t, f, 10 * np.log10(Sxx + 1e-12), cmap=cmap, shading="gouraud")
    ax.set_ylabel("Frequency (Hz)")
    ax.set_xlabel("Time (s)")
    ax.set_ylim(0, fmax)


def _style_axes(ax):
    """Apply an open-box style: white background, only left/bottom spines visible."""
    ax.set_facecolor("white")
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_visible(True)
        ax.spines[side].set_linewidth(0.8)
        ax.spines[side].set_color("black")
    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("left")


def run_stn_gpe_sim(
    stn_gpe_units,
    time_steps,
    dt,
    lat_sparse,
    inter_sparse,
    I_strd2_gpe,
    lat_strength_stn,
    lat_strength_gpe,
    wsg_strength,
    wgs_strength,
    I_gpe_ext,
    I_stn_ext,
    binsize,
    stn_gpe_noise,
    DBS,
    DBS_func,
    DBS_freq,
    DBS_duty,
    DBS_A1,
    DBS_A2,
    pulseinterval,
    center,
    spread_amplitude,
    sigma,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Run the STN–GPe simulation with parameters from the UI and analyse it.

    The parameters are written to a temporary YAML file, which is passed to
    ``STN_GPe_loop`` and deleted afterwards. Analysis uses the last second
    of the simulation, or all of it if shorter. The spectrograms
    additionally include up to one preceding second when available.

    Returns
    -------
    fig : matplotlib.figure.Figure
        4x2 grid, with STN in the left column and GPe in the right. Rows:
        voltage of one example neuron, LFP, spike raster, LFP spectrogram.
    metrics_table : str
        Text table with spectral entropy, synchrony, and the STN spike-rate
        ``mean_std`` value.

    Raises
    ------
    gr.Error
        If a parameter is out of range. Gradio shows the message in the UI.
    """
    # Convert UI values to appropriate types
    stn_gpe_units = int(stn_gpe_units)
    time_steps = int(time_steps)
    dt = float(dt)
    lat_sparse = float(lat_sparse)
    inter_sparse = float(inter_sparse)
    I_strd2_gpe = float(I_strd2_gpe)
    lat_strength_stn = float(lat_strength_stn)
    lat_strength_gpe = float(lat_strength_gpe)
    wsg_strength = float(wsg_strength)
    wgs_strength = float(wgs_strength)
    I_gpe_ext = float(I_gpe_ext)
    I_stn_ext = float(I_stn_ext)
    binsize = int(binsize)
    stn_gpe_noise = float(stn_gpe_noise)
    DBS = bool(DBS)
    DBS_func = str(DBS_func)
    DBS_freq = float(DBS_freq)
    DBS_duty = float(DBS_duty)
    DBS_A1 = float(DBS_A1)
    DBS_A2 = float(DBS_A2)
    pulseinterval = int(pulseinterval)
    center = int(center)
    spread_amplitude = float(spread_amplitude)
    sigma = float(sigma)

    # Validate before launching a potentially long simulation.
    savgol_window = 11
    if stn_gpe_units < 1:
        raise gr.Error("stn_gpe_units must be at least 1.")
    if dt <= 0:
        raise gr.Error("dt must be positive.")
    if time_steps < savgol_window:
        raise gr.Error(f"time (steps) must be at least {savgol_window}.")
    if binsize < 1:
        raise gr.Error("binsize must be at least 1.")

    # 1) Create YAML config entirely from inputs
    config = {
        "stn_gpe_units": stn_gpe_units,
        "time": time_steps,
        "dt": dt,
        "lat_sparse": lat_sparse,
        "inter_sparse": inter_sparse,
        "I_strd2_gpe": I_strd2_gpe,
        "lat_strength_stn": lat_strength_stn,
        "lat_strength_gpe": lat_strength_gpe,
        "wsg_strength": wsg_strength,
        "wgs_strength": wgs_strength,
        "I_gpe_ext": I_gpe_ext,
        "I_stn_ext": I_stn_ext,
        "binsize": binsize,
        "stn_gpe_noise": stn_gpe_noise,
        "DBS": DBS,
        "DBS_func": DBS_func,
        "DBS_freq": DBS_freq,
        "DBS_duty": DBS_duty,
        "DBS_A1": DBS_A1,
        "DBS_A2": DBS_A2,
        "pulseinterval": pulseinterval,
        # DBS spread
        "center": center,
        "spread_amplitude": spread_amplitude,
        "sigma": sigma,
    }

    # 2) Run simulation. gr.Progress expects a fraction in [0, 1].
    progress(0.0, desc="Starting STN–GPe simulation...")
    yaml_path = create_yaml_from_inputs(config)
    try:
        results = STN_GPe_loop(yaml_path)
    finally:
        # Assumes STN_GPe_loop has finished reading the file once it returns
        # (TODO confirm). Cleanup is best-effort; otherwise every run leaks
        # a temp file.
        try:
            os.remove(yaml_path)
        except OSError:
            pass
    progress(0.6, desc="Simulation complete, running analysis...")

    # 3) Extract results
    V_stn_time_all = np.array(results["v_stn"])
    V_gpe_time_all = np.array(results["v_gpe"])
    spike_monitor_stn = results["spike_stn"]
    spike_monitor_gpe = results["spike_gpe"]
    lfp_stn = np.array(results["lfp_stn"])
    lfp_gpe = np.array(results["lfp_gpe"])

    h = dt  # integration step, ms
    sr = int(round(1000.0 / h))  # sampling rate, Hz (= steps per second)

    # Analysis window: last 1 s, or the whole run if it is shorter.
    window_steps = min(sr, time_steps)
    t_high = time_steps
    t_low = t_high - window_steps
    # Exact dt spacing. linspace with an included endpoint stretched the step.
    t_chunk = np.arange(window_steps) * h / 1000.0

    analysis_STN = Analysis(spike_monitor_stn[t_low:t_high])
    analysis_GPe = Analysis(spike_monitor_gpe[t_low:t_high])

    # LFP smoothing; prepend up to one extra second for the spectrogram.
    pad = min(sr, t_low)
    lfp_smooth_stn = _smooth_lfp(lfp_stn[t_low - pad:t_high], window_length=savgol_window)
    lfp_smooth_gpe = _smooth_lfp(lfp_gpe[t_low - pad:t_high], window_length=savgol_window)

    # Spectrograms: ~1 s Hamming window with 95 % overlap. For runs shorter
    # than one segment, SciPy shrinks nperseg to the signal length but keeps
    # noverlap, and then raises "noverlap must be less than nperseg". So
    # clamp both here.
    nperseg = sr
    spec_nperseg = min(nperseg, len(lfp_smooth_stn))
    spec_noverlap = int(0.95 * spec_nperseg)
    f_stn, t_spec_stn, Sxx_stn = signal.spectrogram(
        lfp_smooth_stn,
        fs=sr,
        nperseg=spec_nperseg,
        noverlap=spec_noverlap,
        window="hamming",
    )
    f_gpe, t_spec_gpe, Sxx_gpe = signal.spectrogram(
        lfp_smooth_gpe,
        fs=sr,
        nperseg=spec_nperseg,
        noverlap=spec_noverlap,
        window="hamming",
    )

    # Spectral entropy (up to 35 Hz) on the raw LFP analysis window.
    avg_entropy_stn = analysis_STN.spectral_entropy(
        signal=lfp_stn[t_low:t_high],
        fs=sr,
        nperseg=nperseg,
        fmax=35,
        normalize=True,
    )
    avg_entropy_gpe = analysis_GPe.spectral_entropy(
        signal=lfp_gpe[t_low:t_high],
        fs=sr,
        nperseg=nperseg,
        fmax=35,
        normalize=True,
    )

    # Only the averaged synchrony value is reported.
    _, Ravg_stn = analysis_STN.synchrony()
    _, Ravg_gpe = analysis_GPe.synchrony()

    mean_std = analysis_STN.spike_rate(binsize=binsize)["mean_std"]

    progress(0.8, desc="Generating plots...")

    # ---------------------------------------------------------
    # Plotting
    # ---------------------------------------------------------
    fig, axs = plt.subplots(4, 2, figsize=(8, 6), facecolor="white")
    fig.patch.set_facecolor("white")
    fig.patch.set_edgecolor("white")
    fig.patch.set_linewidth(0)

    colors = {"STN": "#054b7c", "GPe": "#9f0d03ce"}

    # V arrays are indexed (time, i, j); clamp the example neuron for small grids.
    stn_i = min(15, V_stn_time_all.shape[1] - 1)
    stn_j = min(15, V_stn_time_all.shape[2] - 1)
    gpe_i = min(7, V_gpe_time_all.shape[1] - 1)
    gpe_j = min(7, V_gpe_time_all.shape[2] - 1)

    # Row 0: example-neuron voltage
    axs[0, 0].plot(t_chunk, V_stn_time_all[t_low:t_high, stn_i, stn_j], color=colors["STN"])
    axs[0, 0].set_xlabel("Time (s)")
    axs[0, 0].set_ylabel("V (mV)")
    axs[0, 1].plot(t_chunk, V_gpe_time_all[t_low:t_high, gpe_i, gpe_j], color=colors["GPe"])
    axs[0, 1].set_xlabel("Time (s)")

    # Row 1: LFP
    axs[1, 0].plot(t_chunk, lfp_stn[t_low:t_high], color=colors["STN"])
    axs[1, 0].set_xlabel("Time (s)")
    axs[1, 0].set_ylabel("V (mV)")
    axs[1, 1].plot(t_chunk, lfp_gpe[t_low:t_high], color=colors["GPe"])
    axs[1, 1].set_xlabel("Time (s)")

    # Row 2: spike rasters
    _plot_raster(axs[2, 0], spike_monitor_stn[t_low:t_high], h, colors["STN"])
    _plot_raster(axs[2, 1], spike_monitor_gpe[t_low:t_high], h, colors["GPe"])

    # Row 3: spectrograms
    cmap_stn, cmap_gpe = _spectrogram_colormaps()
    _plot_spectrogram(axs[3, 0], t_spec_stn, f_stn, Sxx_stn, cmap_stn)
    _plot_spectrogram(axs[3, 1], t_spec_gpe, f_gpe, Sxx_gpe, cmap_gpe)

    for ax in axs.flat:
        _style_axes(ax)

    legend_elements = [
        Line2D([0], [0], color=colors["STN"], label="STN", linewidth=3),
        Line2D([0], [0], color=colors["GPe"], label="GPe", linewidth=3),
    ]
    fig.legend(
        handles=legend_elements,
        loc="center",
        bbox_to_anchor=(0.5, 0.02),
        ncol=2,
        frameon=False,
        handlelength=1.5,
    )

    # Reserve a strip at the bottom for the figure legend. A plain
    # tight_layout() would push the x-labels into it, and it overrides any
    # earlier subplots_adjust() anyway.
    fig.tight_layout(rect=(0, 0.04, 1, 1))
    # Unregister from pyplot so figures don't accumulate across requests.
    # The Figure object itself stays usable for Gradio to render.
    plt.close(fig)

    # ---------------------------------------------------------
    # Metrics text
    # ---------------------------------------------------------
    name_w = 18  # wide enough for "Standard deviation"
    metrics_table = (
        f"|{'Metric':<{name_w}}|{'STN':^10}|{'GPe':^10}|\n"
        f"|{'-' * name_w}|{'-' * 10}|{'-' * 10}|\n"
        f"|{'Entropy':<{name_w}}|{avg_entropy_stn:^10.3f}|{avg_entropy_gpe:^10.3f}|\n"
        f"|{'Synchrony':<{name_w}}|{Ravg_stn:^10.3f}|{Ravg_gpe:^10.3f}|\n"
        f"|{'Standard deviation':<{name_w}}|{mean_std:^10.2f}|{'':^10}|\n"
    )

    progress(1.0, desc="Done.")

    return fig, metrics_table


# ---------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## STN–GPe Network Simulation")

    with gr.Row():
        # Left column: network size, integration, connectivity and drive.
        with gr.Column():
            stn_gpe_units = gr.Slider(
                minimum=1, maximum=64, step=1, value=16, label="stn_gpe_units"
            )
            time_steps = gr.Number(label="time (steps)", value=50000)
            dt = gr.Number(label="dt (ms)", value=0.1)
            lat_sparse = gr.Number(label="lat_sparse", value=0.1)
            inter_sparse = gr.Number(label="inter_sparse", value=0.1)
            I_strd2_gpe = gr.Number(label="I_strd2_gpe", value=5)
            lat_strength_stn = gr.Number(label="lat_strength_stn", value=0.02)
            lat_strength_gpe = gr.Number(label="lat_strength_gpe", value=0.1)
            wsg_strength = gr.Number(label="wsg_strength", value=0.1)
            wgs_strength = gr.Number(label="wgs_strength", value=0.1)
            I_gpe_ext = gr.Number(label="I_gpe_ext", value=6)
            I_stn_ext = gr.Number(label="I_stn_ext", value=12)
            binsize = gr.Number(label="binsize", value=100)
            stn_gpe_noise = gr.Number(label="stn_gpe_noise", value=3)

        # Right column: deep-brain-stimulation settings and spatial spread.
        with gr.Column():
            DBS = gr.Checkbox(label="DBS", value=False)
            DBS_func = gr.Dropdown(
                label="DBS_func",
                value="biphasicDBS",
                choices=[
                    "monophasicDBS",
                    "biphasicDBS",
                    "biphasicDBS_uninoise",
                    "biphasicDBS_normalnoise",
                ],
            )
            DBS_freq = gr.Number(label="DBS_freq (Hz)", value=130)
            DBS_duty = gr.Number(label="DBS_duty", value=0.052)
            DBS_A1 = gr.Number(label="DBS_A1", value=250)
            DBS_A2 = gr.Number(label="DBS_A2", value=-250)
            pulseinterval = gr.Number(label="pulseinterval", value=10)
            center = gr.Number(label="center (DBS spread center)", value=7)
            spread_amplitude = gr.Number(label="spread_amplitude", value=1)
            sigma = gr.Number(label="sigma", value=3)

    run_button = gr.Button("Run Simulation")

    plot_out = gr.Plot(label="Simulation Results")
    metrics_out = gr.Textbox(
        label="Metrics (Entropy, Synchrony, Standard deviation)",
        interactive=False,
        lines=8,
    )

    # Component order must match run_stn_gpe_sim's positional parameters.
    run_button.click(
        run_stn_gpe_sim,
        [
            stn_gpe_units, time_steps, dt,
            lat_sparse, inter_sparse,
            I_strd2_gpe,
            lat_strength_stn, lat_strength_gpe,
            wsg_strength, wgs_strength,
            I_gpe_ext, I_stn_ext,
            binsize, stn_gpe_noise,
            DBS, DBS_func, DBS_freq, DBS_duty, DBS_A1, DBS_A2,
            pulseinterval,
            center, spread_amplitude, sigma,
        ],
        [plot_out, metrics_out],
    )

if __name__ == "__main__":
    demo.launch()


* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.


  0%|          | 0/50000 [00:00<?, ?steps/s]

  0%|          | 0/50000 [00:00<?, ?steps/s]

  0%|          | 0/50000 [00:00<?, ?steps/s]

  0%|          | 0/50000 [00:00<?, ?steps/s]

  0%|          | 0/50000 [00:00<?, ?steps/s]

  0%|          | 0/50000 [00:00<?, ?steps/s]

  0%|          | 0/50000 [00:00<?, ?steps/s]