In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets, fixed
from IPython.display import display
%matplotlib inline

# === Step 1: Load Data ===
# Positional parameter columns stored in the parquet files -> readable names.
rename_map = {
    '0': 'mu_x',
    '1': 'mu_y',
    '2': 'S1',
    '3': 'B1',
    '4': 'bold_baseline',
    '5': 'B2',
    '6': 'S2',
    '7': 'HRF1',
    '8': 'HRF2',
    '9': 'rsq'
}

# Key columns; both conditions are sorted on these so the two frames can be
# compared row-by-row further down (assumes both files hold the same keys).
index_cols = ['subject', 'original_row', 'vertex_id', 'vertex_region']


def _load_condition(path):
    """Read one condition's parquet file, sort it by ``index_cols``, and rename parameters.

    The sort goes through set_index/sort_index/reset_index, so the key columns
    end up first and the result has a fresh RangeIndex.
    """
    frame = pd.read_parquet(path)
    frame = frame.set_index(index_cols).sort_index().reset_index()
    return frame.rename(columns=rename_map)


# Load data once
df_mem = _load_condition("aggregate_fixed_memantine_iterparams.parquet.gzip")
df_plac = _load_condition("aggregate_fixed_placebo_iterparams.parquet.gzip")

# === Step 2: Helper - Remove Outliers ===
def remove_upper_outliers(x, y, lower_percentile=2.3, upper_percentile=97.7):
    """Drop paired points where either coordinate lies outside its percentile band.

    Despite the name, BOTH tails are trimmed: for each of ``x`` and ``y`` the
    ``lower_percentile`` and ``upper_percentile`` values are computed
    independently, and a point ``(x[i], y[i])`` is kept only if both
    coordinates fall inside their own inclusive band.

    NaN values are ignored when computing the percentiles and points with a
    NaN coordinate are dropped (previously one NaN made every bound NaN and
    removed all points).

    Parameters
    ----------
    x, y : array-like of equal length
        Paired samples.
    lower_percentile, upper_percentile : float
        Percentiles in [0, 100] defining the kept band.

    Returns
    -------
    tuple of numpy.ndarray
        The filtered ``x`` and ``y``. Empty inputs are returned unchanged
        rather than raising inside the percentile computation.
    """
    x = np.array(x)
    y = np.array(y)
    if x.size == 0 and y.size == 0:
        return x, y

    x_lower, x_upper = np.nanpercentile(x, [lower_percentile, upper_percentile])
    y_lower, y_upper = np.nanpercentile(y, [lower_percentile, upper_percentile])
    # Comparisons against NaN are False, so NaN points are excluded here.
    mask = (x >= x_lower) & (x <= x_upper) & (y >= y_lower) & (y <= y_upper)
    return x[mask], y[mask]

# === Step 3: Helper - Apply Inclusion Criteria ===
def get_filtered_indices(my_df_mem, my_df_plac, *, rsq_min=0.3, rsq_diff_max=0.3,
                         ecc_min=0.5, ecc_max=4.5):
    """Return positional indices of rows that pass the inclusion criteria in both conditions.

    Row ``i`` of ``my_df_mem`` is compared with row ``i`` of ``my_df_plac`` by
    POSITION (index labels are ignored), matching how callers use the result
    with ``.iloc``. A row is kept when all of the following hold:

    * ``rsq >= rsq_min`` in both conditions;
    * ``|rsq_mem - rsq_plac| <= rsq_diff_max``;
    * ``ecc_min < sqrt(mu_x**2 + mu_y**2) < ecc_max`` in both conditions;
    * ``B1 > 0`` in both conditions;
    * ``S1 < S2`` in both conditions.

    Column names are the renamed ones produced by ``rename_map`` at load time.

    Parameters
    ----------
    my_df_mem, my_df_plac : pandas.DataFrame
        Memantine and placebo fits with the same number of rows, in the same order.
    rsq_min, rsq_diff_max, ecc_min, ecc_max : float
        Thresholds; the defaults reproduce the original hard-coded criteria.

    Returns
    -------
    numpy.ndarray
        Integer positions of the rows passing every criterion.

    Raises
    ------
    ValueError
        If the two frames have different numbers of rows.
    """
    if len(my_df_mem) != len(my_df_plac):
        raise ValueError(
            "memantine and placebo frames must be row-aligned; got "
            f"{len(my_df_mem)} and {len(my_df_plac)} rows"
        )

    cols = ('mu_x', 'mu_y', 'rsq', 'B1', 'S1', 'S2')
    # Plain arrays: combining pandas Series would align on index labels, which
    # breaks the positional pairing the returned indices rely on.
    mem = {c: my_df_mem[c].to_numpy(dtype=float, na_value=np.nan) for c in cols}
    plac = {c: my_df_plac[c].to_numpy(dtype=float, na_value=np.nan) for c in cols}

    ecc_mem = np.sqrt(mem['mu_x']**2 + mem['mu_y']**2)
    ecc_plac = np.sqrt(plac['mu_x']**2 + plac['mu_y']**2)

    mask = (
        (mem['rsq'] >= rsq_min) &
        (plac['rsq'] >= rsq_min) &
        (np.abs(mem['rsq'] - plac['rsq']) <= rsq_diff_max) &
        (ecc_mem > ecc_min) & (ecc_mem < ecc_max) &
        (ecc_plac > ecc_min) & (ecc_plac < ecc_max) &
        (mem['B1'] > 0) & (plac['B1'] > 0) &
        (mem['S1'] < mem['S2']) & (plac['S1'] < plac['S2'])
    )

    return np.where(mask)[0]

# === Step 4: Plot Function ===
def plot_data(subject, variable, region, apply_filter):
    """Scatter one fitted parameter, memantine (x-axis) against placebo (y-axis).

    Parameters
    ----------
    subject : "All", or a value of the ``subject`` column to restrict to.
    variable : name of the column to plot.
    region : "All", or a value of the ``vertex_region`` column to restrict to.
    apply_filter : bool
        If true, keep only rows returned by ``get_filtered_indices`` and then
        trim percentile outliers with ``remove_upper_outliers``.

    Notes
    -----
    Memantine and placebo values are paired by row position, so this relies on
    ``df_mem`` and ``df_plac`` containing the same rows in the same order.
    """
    col_idx = str(variable)

    mem_sub = df_mem
    plac_sub = df_plac

    # Filter by subject
    if subject != "All":
        mem_sub = mem_sub[mem_sub["subject"] == subject]
        plac_sub = plac_sub[plac_sub["subject"] == subject]

    # Filter by brain region
    if region != "All":
        mem_sub = mem_sub[mem_sub["vertex_region"] == region]
        plac_sub = plac_sub[plac_sub["vertex_region"] == region]

    if apply_filter:
        indices = get_filtered_indices(mem_sub, plac_sub)
        x = mem_sub.iloc[indices][col_idx].values
        y = plac_sub.iloc[indices][col_idx].values
        x, y = remove_upper_outliers(x, y)
    else:
        x = mem_sub[col_idx].values
        y = plac_sub[col_idx].values

    # Plotting
    plt.figure(figsize=(6, 6))
    plt.scatter(x, y, alpha=0.6)
    plt.xlabel("Memantine")
    plt.ylabel("Placebo")
    plt.title(f"Variable {col_idx} | Subject {subject} | Region {region}")
    plt.grid(True)

    # y = x reference line spanning the finite data on BOTH axes. Guarded
    # because the inclusion criteria can remove every row, and min()/max()
    # raise on an empty array (and return NaN if any value is NaN).
    finite = np.concatenate([np.asarray(x, dtype=float), np.asarray(y, dtype=float)])
    finite = finite[np.isfinite(finite)]
    if finite.size:
        lo, hi = finite.min(), finite.max()
        plt.plot([lo, hi], [lo, hi], 'k--')  # y=x line
    else:
        plt.text(0.5, 0.5, "No data to plot", ha="center", va="center",
                 transform=plt.gca().transAxes)
    plt.show()

# === Step 5: Create Widgets ===
# Dropdown choices: an "All" sentinel followed by the sorted unique values
# found in the memantine frame.
subjects = ["All", *sorted(df_mem["subject"].unique())]
variables = [*rename_map.values()]
regions = ["All", *sorted(df_mem["vertex_region"].unique())]

subject_dropdown = widgets.Dropdown(options=subjects, description="Subject:")
variable_dropdown = widgets.Dropdown(options=variables, description="Variable:")
region_dropdown = widgets.Dropdown(options=regions, description="Region:")
filter_toggle = widgets.Checkbox(value=False, description="Apply Inclusion Criteria")

# plot_data parameter name -> control that drives it; insertion order also
# fixes the top-to-bottom order of the controls in the layout.
controls = {
    "subject": subject_dropdown,
    "variable": variable_dropdown,
    "region": region_dropdown,
    "apply_filter": filter_toggle,
}

ui = widgets.VBox(list(controls.values()))
out = widgets.interactive_output(plot_data, controls)

# Display everything
display(ui, out)