In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets, fixed
from IPython.display import display
%matplotlib inline

# === Step 1: Load and Preprocess Data ===
# The aggregated parquet files name their parameter columns by position
# ('0' .. '9'); map each position to a readable parameter name, in order.
rename_map = {
    str(position): name
    for position, name in enumerate((
        'mu_x',
        'mu_y',
        'S1',
        'B1',
        'bold_baseline',
        'B2',
        'S2',
        'HRF1',
        'HRF2',
        'rsq',
    ))
}

# Columns that jointly identify one vertex row; used as the sort key on load.
index_cols = ['subject', 'original_row', 'vertex_id', 'vertex_region']

def load_and_prepare_data(filepath):
    """Read one aggregated parquet file and add derived columns.

    Rows are ordered by the ``index_cols`` key. The set_index/reset_index
    round trip also moves those key columns to the front and gives the frame
    a fresh RangeIndex. Positional parameter columns are then renamed via
    ``rename_map``, and two columns are derived:

    * ``eccentricity``   = sqrt(mu_x**2 + mu_y**2)
    * ``surround_index`` = (B2 * S2**2) / (B1 * S1**2)

    Returns the prepared DataFrame.
    """
    raw = pd.read_parquet(filepath)
    ordered = raw.set_index(index_cols).sort_index().reset_index()
    df = ordered.rename(columns=rename_map)

    df['eccentricity'] = np.sqrt(df['mu_x'] ** 2 + df['mu_y'] ** 2)

    numerator = df['B2'] * df['S2'] ** 2
    denominator = df['B1'] * df['S1'] ** 2
    df['surround_index'] = numerator / denominator
    return df

# Load the memantine and placebo parameter tables from the "fixed" aggregates.
df_mem_fixed = load_and_prepare_data(
    filepath="aggregate_fixed_memantine_iterparams_final.parquet.gzip"
)
df_plac_fixed = load_and_prepare_data(
    filepath="aggregate_fixed_placebo_iterparams_final.parquet.gzip"
)
# The "free" aggregates are currently not loaded; re-enable these if the
# free-dataset comparison is needed again.
# df_mem_free = load_and_prepare_data("aggregate_free_memantine_iterparams.parquet.gzip")
# df_plac_free = load_and_prepare_data("aggregate_free_placebo_iterparams.parquet.gzip")

# === Step 2: Helper - Remove Outliers ===
def remove_upper_outliers(x, y, lower_percentile=2.3, upper_percentile=97.7):
    """Drop paired samples that fall outside a percentile band in x or y.

    Despite the name, BOTH tails are trimmed. Pair ``(x[i], y[i])`` is kept
    only when ``x[i]`` lies in ``[P_lower(x), P_upper(x)]`` and ``y[i]`` lies
    in ``[P_lower(y), P_upper(y)]``. Percentiles are computed per variable.

    NaN entries are ignored when computing the percentile bounds, so a single
    NaN no longer wipes out every point. NaN entries themselves always fail
    the bound checks and are dropped.

    Parameters
    ----------
    x, y : array-like
        Paired samples of equal length.
    lower_percentile, upper_percentile : float
        Percentile bounds in [0, 100].

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The retained x and y values, still paired. Both are empty when the
        input is empty.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.
    """
    x = np.array(x)
    y = np.array(y)
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {x.shape} and {y.shape}"
        )
    # np.percentile cannot work on empty input; nothing to trim anyway.
    if x.size == 0:
        return x, y

    x_lower, x_upper = np.nanpercentile(x, [lower_percentile, upper_percentile])
    y_lower, y_upper = np.nanpercentile(y, [lower_percentile, upper_percentile])
    mask = (x >= x_lower) & (x <= x_upper) & (y >= y_lower) & (y <= y_upper)
    return x[mask], y[mask]

# === Step 3: Helper - Apply Inclusion Criteria ===
def get_filtered_indices(my_df_mem, my_df_plac, min_ecc=0.5, max_ecc=4.5,
                         min_rsq=0.3, max_rsq_diff=0.3):
    """Return positions of rows that pass the memantine/placebo inclusion criteria.

    Row ``i`` of ``my_df_mem`` is compared with row ``i`` of ``my_df_plac``.
    The comparison is positional, not by index label, so the result can be
    fed directly to ``.iloc`` on both frames. NOTE(review): this assumes
    row ``i`` of both frames describes the same vertex. Both are sorted by
    ``index_cols`` on load, but confirm the two files hold the same vertex set.

    A row is kept when all of these hold:
      * rsq >= ``min_rsq`` in both conditions
      * |rsq_mem - rsq_plac| <= ``max_rsq_diff``
      * ``min_ecc`` < sqrt(mu_x**2 + mu_y**2) < ``max_ecc`` in both conditions
      * B1 > 0 in both conditions
      * S1 < S2 in both conditions

    Rows with NaN in any tested column fail the comparisons and are excluded.

    Parameters
    ----------
    my_df_mem, my_df_plac : pandas.DataFrame
        Frames with columns mu_x, mu_y, rsq, B1, S1 and S2, of equal length.
    min_ecc, max_ecc : float
        Exclusive eccentricity bounds.
    min_rsq, max_rsq_diff : float
        R² threshold and maximum allowed R² difference between conditions.
        The defaults reproduce the original hard-coded 0.3 / 0.3.

    Returns
    -------
    numpy.ndarray
        Integer positions of the rows that pass.

    Raises
    ------
    ValueError
        If the two frames have different numbers of rows.
    """
    if len(my_df_mem) != len(my_df_plac):
        raise ValueError(
            "memantine and placebo frames must have the same number of rows, "
            f"got {len(my_df_mem)} and {len(my_df_plac)}"
        )

    def _col(df, name):
        # Plain numpy array: stops pandas aligning the two frames by label.
        return df[name].to_numpy()

    ecc_mem = np.sqrt(_col(my_df_mem, 'mu_x') ** 2 + _col(my_df_mem, 'mu_y') ** 2)
    ecc_plac = np.sqrt(_col(my_df_plac, 'mu_x') ** 2 + _col(my_df_plac, 'mu_y') ** 2)
    r_mem = _col(my_df_mem, 'rsq')
    r_plac = _col(my_df_plac, 'rsq')

    mask = (
        (r_mem >= min_rsq)
        & (r_plac >= min_rsq)
        & (np.abs(r_mem - r_plac) <= max_rsq_diff)
        & (ecc_mem > min_ecc) & (ecc_mem < max_ecc)
        & (ecc_plac > min_ecc) & (ecc_plac < max_ecc)
        & (_col(my_df_mem, 'B1') > 0) & (_col(my_df_plac, 'B1') > 0)
        & (_col(my_df_mem, 'S1') < _col(my_df_mem, 'S2'))
        & (_col(my_df_plac, 'S1') < _col(my_df_plac, 'S2'))
    )

    return np.where(mask)[0]

# === Step 4: Plotting Logic ===
def filter_df(df_mem, df_plac, subjects, regions, var_mem, var_plac, apply_filter, min_ecc, max_ecc):
    """Select the memantine (x) and placebo (y) values to plot.

    Both frames are first restricted to ``subjects`` and ``regions``. An
    empty or falsy selection means "no restriction" for that dimension.
    Without ``apply_filter``, the raw ``var_mem`` / ``var_plac`` columns are
    returned as arrays. With ``apply_filter``, rows must also pass
    ``get_filtered_indices`` (positional, shared across both frames), and
    the resulting pairs are trimmed with ``remove_upper_outliers``.

    Returns
    -------
    (x, y) : tuple of numpy arrays
    """
    def _restrict(frame):
        # Same subject/region selection applied independently to each frame.
        if subjects:
            frame = frame[frame["subject"].isin(subjects)]
        if regions:
            frame = frame[frame["vertex_region"].isin(regions)]
        return frame

    mem_sub = _restrict(df_mem)
    plac_sub = _restrict(df_plac)

    if not apply_filter:
        return mem_sub[var_mem].values, plac_sub[var_plac].values

    keep = get_filtered_indices(mem_sub, plac_sub, min_ecc, max_ecc)
    mem_kept = mem_sub.iloc[keep]
    plac_kept = plac_sub.iloc[keep]
    return remove_upper_outliers(mem_kept[var_mem].values, plac_kept[var_plac].values)

def plot_comparison(x, y, title, ax):
    """Scatter memantine (x) against placebo (y) values on ``ax`` with an identity line.

    The dashed y = x reference line spans the smallest to largest finite
    value across both arrays. NaN and inf values are ignored when sizing it,
    so one bad value no longer yields a NaN or infinite line. When there is
    no finite data (e.g. every vertex was filtered out), a "No data" label is
    drawn instead. Previously ``x.min()`` raised ``ValueError`` on empty
    input.

    Parameters
    ----------
    x, y : array-like
        Paired values plotted on the horizontal and vertical axes.
    title : str
        Axes title.
    ax : matplotlib Axes
        Target axes; drawn into in place.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    ax.scatter(x, y, alpha=0.6)
    ax.set_xlabel("Memantine")
    ax.set_ylabel("Placebo")
    ax.set_title(title)
    ax.grid(True)

    finite = np.concatenate([x[np.isfinite(x)], y[np.isfinite(y)]])
    if finite.size == 0:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        return

    lo, hi = finite.min(), finite.max()
    ax.plot([lo, hi], [lo, hi], 'k--')

def plot_data(subjects, regions, var_mem, var_plac, apply_filter, min_ecc, max_ecc):
    """Draw one memantine-vs-placebo scatter for the current widget state.

    Used as the callback of ``widgets.interactive_output``. Only the "Fixed"
    datasets (``df_mem_fixed`` / ``df_plac_fixed``) are plotted. All
    arguments are forwarded unchanged to ``filter_df``.
    """
    x, y = filter_df(df_mem_fixed, df_plac_fixed, subjects, regions,
                     var_mem, var_plac, apply_filter, min_ecc, max_ecc)

    _, ax = plt.subplots(figsize=(6, 6))
    plot_comparison(x, y, f"Fixed | Mem: {var_mem} vs Plac: {var_plac}", ax)

    # TODO: the "Free" dataset comparison and the "Fixed"/"Free"/"Both"
    # selector are disabled. Re-enabling them needs df_mem_free/df_plac_free
    # to be loaded, a `dataset_choice` argument wired through
    # interactive_output, and a side-by-side subplot when "Both" is selected.

    plt.tight_layout()
    plt.show()

# === Step 5: Create Widgets ===
all_subjects = sorted(df_mem_fixed["subject"].unique())
all_regions = sorted(df_mem_fixed["vertex_region"].unique())
# Selectable columns: every renamed model parameter plus the two derived
# columns (eccentricity, surround_index) added in load_and_prepare_data.
variables = list(rename_map.values()) + ['eccentricity'] + ['surround_index']
# dataset_options = ["Fixed", "Free", "Both"]

# Subject and region widgets (SelectMultiple). Default to the first entry.
# Slicing instead of indexing [0] avoids an IndexError when the loaded
# data has no subjects/regions.
subject_selector = widgets.SelectMultiple(
    options=all_subjects,
    value=all_subjects[:1],
    description="Subjects:",
    disabled=False
)

region_selector = widgets.SelectMultiple(
    options=all_regions,
    value=all_regions[:1],
    description="Regions:",
    disabled=False
)

# Variable selection (Dropdown)
var_mem_dropdown = widgets.Dropdown(options=variables, description="Mem Var:")
var_plac_dropdown = widgets.Dropdown(options=variables, description="Plac Var:")

# Other controls
filter_toggle = widgets.Checkbox(value=False, description="Apply Filter")
# dataset_dropdown = widgets.Dropdown(options=dataset_options, description="Dataset:")

# Eccentricity sliders; enabled only while "Apply Filter" is checked.
min_ecc_slider = widgets.FloatSlider(
    value=0.5,
    min=0.1,
    max=10.0,
    step=0.1,
    description='Min Ecc:',
    disabled=not filter_toggle.value
)

max_ecc_slider = widgets.FloatSlider(
    value=4.5,
    min=0.1,
    max=10.0,
    step=0.1,
    description='Max Ecc:',
    disabled=not filter_toggle.value
)

# Enable/disable sliders based on filter checkbox
def update_sliders(change):
    min_ecc_slider.disabled = not change.new
    max_ecc_slider.disabled = not change.new

filter_toggle.observe(update_sliders, names='value')


def _keep_ecc_range_ordered(change):
    """Keep Min Ecc <= Max Ecc by dragging the other slider along.

    An inverted range (min > max) can never satisfy the eccentricity
    criterion in get_filtered_indices, so it would always plot nothing.
    """
    if min_ecc_slider.value <= max_ecc_slider.value:
        return
    if change.owner is min_ecc_slider:
        max_ecc_slider.value = change.new
    else:
        min_ecc_slider.value = change.new


min_ecc_slider.observe(_keep_ecc_range_ordered, names='value')
max_ecc_slider.observe(_keep_ecc_range_ordered, names='value')

ui = widgets.VBox([
    widgets.HTML("<b>Compare Memantine vs Placebo Variables</b>"),
    widgets.HBox([var_mem_dropdown, var_plac_dropdown]),
    subject_selector,
    region_selector,
    filter_toggle,
    widgets.HBox([min_ecc_slider, max_ecc_slider])#,
    # dataset_dropdown
])

out = widgets.interactive_output(
    plot_data,
    {
        "subjects": subject_selector,
        "regions": region_selector,
        "var_mem": var_mem_dropdown,
        "var_plac": var_plac_dropdown,
        "apply_filter": filter_toggle,
        # "dataset_choice": dataset_dropdown,
        "min_ecc": min_ecc_slider,
        "max_ecc": max_ecc_slider
    },
)

# NOTE(review): this warning refers to a 'Both' dataset option, but the
# dataset dropdown is currently commented out. Update or remove the note if
# that selector stays disabled.
note = widgets.HTML("<b style='color: darkred;'>⚠️ If you select 'Both' from the dataset list, be careful when assessing the plots, since the scale of units might not be the same.</b>")

# Display the controls, then the note above the live plot output.
display(ui, widgets.VBox([note, out]))