In [None]:
import os
import sys
import warnings
import tempfile

import numpy as np
import torch
import torch.nn as nn  # if used inside your imported code
import gradio as gr

import matplotlib
matplotlib.use("Agg")  # for web apps
from matplotlib import pyplot as plt

# Optional: only if you really use seaborn somewhere
# import seaborn as sns

# --- Project root discovery ---------------------------------------------
# __file__ exists when this runs as a script; in a notebook/REPL it does not,
# so fall back to the current working directory.
try:
    CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    CURRENT_DIR = os.getcwd()

# The project root is taken to be two directory levels above CURRENT_DIR.
PROJECT_ROOT = os.path.abspath(
    os.path.join(CURRENT_DIR, os.pardir, os.pardir)
)
# Make the project's top-level modules importable (only added once).
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
print("Project root added to sys.path:", PROJECT_ROOT)

# --- Project-local imports (must come after the sys.path tweak above) ----
from envs import *           # expected to provide IGTEnv
from basal_ganglia import *  # expected to provide train(...)
from stn_gpe import load_yaml   # YAML loader used for igt_params.yaml

# Silence every warning category for the rest of the process.
warnings.filterwarnings("ignore")

# --- Parameter file locations ---------------------------------------------
# YAML file holding the IGT task settings.
IGT_PARAMS_PATH = os.path.join(PROJECT_ROOT, "params", "decision_making_task_params", "igt_params.yaml")

# Directory holding one STN-GPe parameter YAML per condition.
STN_PARAM_DIR = os.path.join(PROJECT_ROOT, "params", "stn_gpe_params")

# ---------------------------------------------------------
# Core function used by Gradio
# ---------------------------------------------------------
def run_igt_task(
    use_stn_gpe,     # checkbox
    stn_mode,        # dropdown: 'normal', 'PD', 'std_DBS'
    d1_amp,
    d2_amp,
    var,
    gpi_mean,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the IGT decision-making task and summarise the outcome.

    The task is driven either by noise or by an STN-GPe parameter file,
    depending on ``use_stn_gpe``.

    Args:
        use_stn_gpe: If truthy, pass the STN-GPe parameter YAML selected by
            ``stn_mode`` to ``train`` as ``STN_data``; otherwise pass ``None``.
        stn_mode: One of ``'normal'``, ``'PD'`` or ``'std_DBS'``. Only
            consulted when ``use_stn_gpe`` is truthy.
        d1_amp: Forwarded to ``train`` as ``d1_amp`` (converted to float).
        d2_amp: Forwarded to ``train`` as ``d2_amp`` (converted to float).
        var: Forwarded to ``train`` as ``ep_0`` (converted to float).
            NOTE(review): this was originally meant to be the reward/loss
            standard deviation, but the environment's reward and loss
            standard deviations are currently fixed at 0 - confirm which
            behaviour is intended.
        gpi_mean: Forwarded to ``train`` as ``gpi_mean`` (converted to float).
        progress: Gradio progress tracker; called with fractions in [0, 1].

    Returns:
        A ``(figure, metrics_text)`` tuple: a Matplotlib figure with the
        average picks per deck and the per-bin IGT score, and a text summary.

    Raises:
        gr.Error: If a numeric input is missing or ``stn_mode`` is unknown
            while ``use_stn_gpe`` is enabled.
    """
    # A cleared numeric field may arrive as None; surface that to the user as
    # a readable UI error rather than a bare TypeError from float(None).
    numeric_inputs = {
        "d1_amp": d1_amp,
        "d2_amp": d2_amp,
        "var": var,
        "gpi_mean": gpi_mean,
    }
    missing = [name for name, value in numeric_inputs.items() if value is None]
    if missing:
        raise gr.Error("Please provide a value for: " + ", ".join(missing))

    # Convert to proper numeric types
    d1_amp = float(d1_amp)
    d2_amp = float(d2_amp)
    var = float(var)
    gpi_mean = float(gpi_mean)

    # 1) Load IGT task parameters from yaml
    params = load_yaml(IGT_PARAMS_PATH)
    TRIALS = int(params["TRIALS"])
    EPOCHS = int(params["EPOCHS"])
    NUM_BINS = int(params["NUM_BINS"])
    LR = float(params["LR"])
    SCALING_FACTOR = float(params["SCALING_FACTOR"])

    # Reward/loss standard deviations are pinned to 0 here; `var` is NOT used
    # for them (it is forwarded to train() as ep_0 below).
    REW_STD = 0
    LOSS_STD = 0

    # 2) Build environment (four decks: A, B, C, D), values scaled down by
    #    SCALING_FACTOR.
    mean_reward = np.array([100, 100, 50, 50]) / SCALING_FACTOR
    std_reward = np.array([REW_STD] * 4) / SCALING_FACTOR
    mean_loss = np.array([-250, -1250, -50, -250]) / SCALING_FACTOR
    std_loss = np.array([LOSS_STD] * 4) / SCALING_FACTOR

    env = IGTEnv(
        mean_reward=mean_reward,
        std_reward=std_reward,
        mean_loss=mean_loss,
        std_loss=std_loss,
    )

    # 3) Decide STN_data file path (or None for noise case)
    STN_DATA_path = None
    if use_stn_gpe:
        mode_to_file = {
            "normal": "params_Normal.yaml",
            "PD": "params_PD.yaml",
            "std_DBS": "params_std_DBS.yaml",
        }
        try:
            fname = mode_to_file[stn_mode]
        except KeyError:
            raise gr.Error(
                f"Unknown STN-GPe condition {stn_mode!r}; expected one of "
                f"{sorted(mode_to_file)}."
            ) from None
        STN_DATA_path = os.path.join(STN_PARAM_DIR, fname)

    # gr.Progress takes a completion fraction in [0, 1], not a percentage.
    progress(0.05, desc="Environment ready. Starting training...")

    # 4) Run training
    # NOTE: train() is assumed to be imported from basal_ganglia via wildcard.
    # It returns seven values; only the third (avg_counts) is used here.
    # `trails` (sic) is the keyword name this call has always used.
    _, _, avg_counts, _, _, _, _ = train(
        env,
        trails=TRIALS,
        epochs=EPOCHS,
        lr=LR,
        bins=NUM_BINS,
        STN_data=STN_DATA_path,  # None -> noise; path -> STN-GPe sim
        d1_amp=d1_amp,
        d2_amp=d2_amp,
        gpi_threshold=0.15,
        max_gpi_iters=50,
        del_lim=None,
        train_IP=False,
        del_med=None,
        printing=False,
        gpi_mean=gpi_mean,
        ep_0=var,
    )

    progress(0.7, desc="Training done. Computing metrics and building plots...")

    # 5) Analyze results
    A_picks = avg_counts[0]
    B_picks = avg_counts[1]
    C_picks = avg_counts[2]
    D_picks = avg_counts[3]

    Avg_A_picks = torch.mean(A_picks, dim=0)
    Avg_B_picks = torch.mean(B_picks, dim=0)
    Avg_C_picks = torch.mean(C_picks, dim=0)
    Avg_D_picks = torch.mean(D_picks, dim=0)

    # IGT score: (C + D) - (A + B), computed once and reused for both the
    # mean and the standard error (both reduced over dim 0).
    score_per_run = torch.add(C_picks, D_picks) - torch.add(A_picks, B_picks)
    IGT_score = torch.mean(score_per_run, dim=0)
    IGT_dev = torch.std(score_per_run, dim=0) / torch.sqrt(
        torch.tensor(EPOCHS, dtype=torch.float32)
    )

    # Convert to numpy for plotting/printing
    IGT_score_np = IGT_score.detach().cpu().numpy().reshape(-1)
    IGT_dev_np = IGT_dev.detach().cpu().numpy().reshape(-1)

    BINS = np.arange(NUM_BINS) + 1

    # 6) Build plot: avg picks + IGT score
    fig = plt.figure(figsize=(10, 3))
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2)

    # Left: avg picks
    ax1.plot(BINS, Avg_A_picks.detach().cpu().numpy(), label="A", color="red", marker="o")
    ax1.plot(BINS, Avg_B_picks.detach().cpu().numpy(), label="B", color="red", marker="s")
    ax1.plot(BINS, Avg_C_picks.detach().cpu().numpy(), label="C", color="green", marker="D")
    ax1.plot(BINS, Avg_D_picks.detach().cpu().numpy(), label="D", color="green", marker="^")
    ax1.set_title("Avg card picks per bin")
    ax1.set_xlabel("Bins")
    ax1.set_ylabel("Avg picks")
    ax1.legend(loc="upper right")

    # Right: IGT score per bin with error bars
    ax2.bar(BINS, IGT_score_np)
    ax2.errorbar(BINS, IGT_score_np, yerr=IGT_dev_np, fmt="o", color="black")
    ax2.set_title("IGT score")
    ax2.set_xlabel("Bins")
    ax2.set_ylabel("IGT Score")

    fig.tight_layout()

    # Unregister the figure from pyplot's global state so repeated runs in a
    # long-lived server do not accumulate open figures. The Figure object
    # itself remains usable and can still be rendered by the caller.
    plt.close(fig)

    # 7) Build metrics text
    overall_mean_score = float(IGT_score_np.mean())
    overall_se = float(np.sqrt((IGT_dev_np**2).mean()))

    metrics_text = (
        f"Use STN-GPe: {use_stn_gpe}, Mode: {stn_mode if use_stn_gpe else 'Noise only'}\n\n"
        f"IGT score per bin: {IGT_score_np}\n"
        f"IGT SE per bin:    {IGT_dev_np}\n\n"
        f"Overall mean IGT score: {overall_mean_score:.3f}\n"
        f"Approx. overall SE:     {overall_se:.3f}\n"
    )

    progress(1.0, desc="Done.")

    return fig, metrics_text


# ---------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------
# Page layout: a title, one row with two columns of inputs, a run button,
# then the plot and the metrics text underneath.
with gr.Blocks() as demo:
    gr.Markdown("# Iowa Gambling Task simulation (IGT)")

    with gr.Row():
        # Left column: choose between the noise drive and an STN-GPe condition.
        with gr.Column():
            use_stn_gpe = gr.Checkbox(
                label="Use STN–GPe-based simulation (instead of noise)?",
                value=False,
            )
            stn_mode = gr.Dropdown(
                label="STN–GPe condition (if enabled)",
                choices=["normal", "PD", "std_DBS"],
                value="normal",
            )

        # Right column: numeric parameters handed to run_igt_task.
        with gr.Column():
            d1_amp = gr.Number(label="d1_amp", value=0.5)
            d2_amp = gr.Number(label="d2_amp", value=0.01)
            var = gr.Number(label="var (reward/loss std)", value=0.0)
            gpi_mean = gr.Number(label="gpi_mean", value=1.0)

    run_button = gr.Button("Run IGT Task")

    # Output widgets, filled from run_igt_task's (figure, text) return value.
    plot_out = gr.Plot(label="IGT Results (Avg Picks & IGT Score)")
    metrics_out = gr.Textbox(label="IGT Metrics", lines=8, interactive=False)

    # Input order must match run_igt_task's positional parameters.
    run_button.click(
        fn=run_igt_task,
        inputs=[use_stn_gpe, stn_mode, d1_amp, d2_amp, var, gpi_mean],
        outputs=[plot_out, metrics_out],
    )

# Start the local web server only when executed directly.
if __name__ == "__main__":
    demo.launch()


Project root added to sys.path: c:\Myfiles\Workshop_comp_modelling\FromSTNSpikesToChoices
* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


Running STN-GPe system


  0%|          | 0/50000 [00:00<?, ?steps/s]

tensor(0.0313, dtype=torch.float64)
tensor(1., dtype=torch.float64) tensor(0.0274, dtype=torch.float64)
tensor(0.9996) tensor(0.0250)


  0%|          | 0/100 [00:00<?, ?steps/s]

Running STN-GPe system


  0%|          | 0/50000 [00:00<?, ?steps/s]

tensor(0.0022, dtype=torch.float64)
tensor(1., dtype=torch.float64) tensor(0.0022, dtype=torch.float64)
tensor(1.0000) tensor(0.0020)


  0%|          | 0/100 [00:00<?, ?steps/s]