In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets, fixed
from IPython.display import display
%matplotlib inline

# === Step 1: Load and Preprocess Data ===
# Parameter columns arrive positionally named '0'..'9'; map each position to a
# readable parameter name (dict order is kept, so the widget list follows it).
rename_map = {
    str(position): name
    for position, name in enumerate(
        ['mu_x', 'mu_y', 'S1', 'B1', 'bold_baseline', 'B2', 'S2', 'HRF1', 'HRF2', 'rsq']
    )
}

# Identifier columns; load_and_prepare_data sorts every dataset by these.
index_cols = ['subject', 'original_row', 'vertex_id', 'vertex_region']

def load_and_prepare_data(filepath):
    """Load one parquet dataset with rows ordered by ``index_cols`` and readable column names.

    The identifier columns are round-tripped through the index so rows are
    sorted by them (and they end up as the leading columns). The positional
    parameter columns are then renamed via ``rename_map``.

    Parameters
    ----------
    filepath : str
        Path to a parquet file readable by ``pandas.read_parquet``.

    Returns
    -------
    pandas.DataFrame
        The sorted, renamed frame with a fresh default integer index.
    """
    raw = pd.read_parquet(filepath)
    ordered = raw.set_index(index_cols).sort_index().reset_index()
    return ordered.rename(columns=rename_map)

# Load the four datasets: {fixed, free} fits x {memantine, placebo} sessions.
# Loading order is fixed-memantine, fixed-placebo, free-memantine, free-placebo.
df_mem_fixed, df_plac_fixed, df_mem_free, df_plac_free = (
    load_and_prepare_data(parquet_path)
    for parquet_path in (
        "aggregate_fixed_memantine_iterparams.parquet.gzip",
        "aggregate_fixed_placebo_iterparams.parquet.gzip",
        "aggregate_free_memantine_iterparams.parquet.gzip",
        "aggregate_free_placebo_iterparams.parquet.gzip",
    )
)

# === Step 2: Helper - Remove Outliers ===
def remove_upper_outliers(x, y, lower_percentile=2.3, upper_percentile=97.7):
    """Drop paired points where either coordinate falls outside a percentile band.

    Despite the name, BOTH tails are trimmed: pair ``(x[i], y[i])`` is kept only
    if ``x[i]`` lies inside the inclusive ``[lower_percentile, upper_percentile]``
    band of ``x`` AND ``y[i]`` lies inside the same band of ``y``.

    Percentiles are computed ignoring NaNs, so a single missing value no longer
    turns the bounds into NaN (which previously discarded every point). Pairs
    containing NaN are still dropped, because NaN fails the range comparisons.

    Parameters
    ----------
    x, y : array-like
        Paired samples of identical shape (e.g. memantine vs placebo values of
        the same vertices).
    lower_percentile, upper_percentile : float
        Percentiles in [0, 100] bounding the retained band.

    Returns
    -------
    tuple of numpy.ndarray
        The retained ``x`` and ``y`` values, still paired element-wise. Empty
        inputs are returned unchanged (as arrays).

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {x.shape} and {y.shape}"
        )
    if x.size == 0:
        # Percentiles of an empty sample are undefined; nothing to trim.
        return x, y

    band = [lower_percentile, upper_percentile]
    x_lower, x_upper = np.nanpercentile(x, band)
    y_lower, y_upper = np.nanpercentile(y, band)
    mask = (x >= x_lower) & (x <= x_upper) & (y >= y_lower) & (y <= y_upper)
    return x[mask], y[mask]

# === Step 3: Helper - Apply Inclusion Criteria ===
def get_filtered_indices(my_df_mem, my_df_plac, *, min_rsq=0.3, max_rsq_diff=0.3,
                         min_ecc=0.5, max_ecc=4.5):
    """Return positional indices of rows passing the inclusion criteria in both conditions.

    Row ``i`` of ``my_df_mem`` is paired with row ``i`` of ``my_df_plac`` by
    POSITION (the caller applies the result with ``.iloc``). All criteria are
    therefore evaluated on plain NumPy arrays, so differing pandas index labels
    (e.g. after subsetting) cannot cause label alignment to mis-pair rows.

    A row is kept when, in BOTH frames:
      * ``rsq >= min_rsq``;
      * eccentricity ``sqrt(mu_x**2 + mu_y**2)`` is strictly between
        ``min_ecc`` and ``max_ecc``;
      * ``B1 > 0``;
      * ``S1 < S2``;
    and additionally ``|rsq_mem - rsq_plac| <= max_rsq_diff``.
    Rows with NaN in any of these columns fail the comparisons and are dropped.

    Parameters
    ----------
    my_df_mem, my_df_plac : pandas.DataFrame
        Frames with columns ``mu_x``, ``mu_y``, ``rsq``, ``B1``, ``S1``, ``S2``
        and the same number of rows.
    min_rsq, max_rsq_diff, min_ecc, max_ecc : float, keyword-only
        Inclusion thresholds; defaults reproduce the original criteria.

    Returns
    -------
    numpy.ndarray
        Sorted integer positions of the retained rows.

    Raises
    ------
    ValueError
        If the two frames have different numbers of rows.
    """
    if len(my_df_mem) != len(my_df_plac):
        raise ValueError(
            "memantine and placebo frames must have the same number of rows, "
            f"got {len(my_df_mem)} and {len(my_df_plac)}"
        )

    def _condition_mask(df):
        # Per-condition criteria on positional arrays; also returns rsq for the
        # cross-condition goodness-of-fit comparison.
        mu_x = df['mu_x'].to_numpy()
        mu_y = df['mu_y'].to_numpy()
        ecc = np.sqrt(mu_x**2 + mu_y**2)
        rsq = df['rsq'].to_numpy()
        ok = (
            (rsq >= min_rsq)
            & (ecc > min_ecc) & (ecc < max_ecc)
            & (df['B1'].to_numpy() > 0)
            & (df['S1'].to_numpy() < df['S2'].to_numpy())
        )
        return ok, rsq

    ok_mem, rsq_mem = _condition_mask(my_df_mem)
    ok_plac, rsq_plac = _condition_mask(my_df_plac)
    mask = ok_mem & ok_plac & (np.abs(rsq_mem - rsq_plac) <= max_rsq_diff)
    return np.flatnonzero(mask)

# === Step 4: Plotting Logic ===
def filter_df(df_mem, df_plac, subject, region, variable, apply_filter):
    """Select paired memantine/placebo values of one parameter for plotting.

    Parameters
    ----------
    df_mem, df_plac : pandas.DataFrame
        Memantine and placebo frames whose rows are paired by position.
    subject, region : str
        Restrict to this ``subject`` / ``vertex_region``; ``"All"`` means no
        restriction.
    variable : str
        Column whose values are returned.
    apply_filter : bool
        When true, keep only rows passing ``get_filtered_indices`` and then
        trim percentile outliers with ``remove_upper_outliers``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, y)``: memantine values and placebo values. Both are empty when
        nothing survives the inclusion criteria.
    """
    mem_sub, plac_sub = df_mem, df_plac

    if subject != "All":
        mem_sub = mem_sub[mem_sub["subject"] == subject]
        plac_sub = plac_sub[plac_sub["subject"] == subject]

    if region != "All":
        mem_sub = mem_sub[mem_sub["vertex_region"] == region]
        plac_sub = plac_sub[plac_sub["vertex_region"] == region]

    if not apply_filter:
        return mem_sub[variable].to_numpy(), plac_sub[variable].to_numpy()

    # Positional indices, shared by both conditions.
    indices = get_filtered_indices(mem_sub, plac_sub)
    x = mem_sub[variable].to_numpy()[indices]
    y = plac_sub[variable].to_numpy()[indices]
    if len(indices) == 0:
        # Nothing passed the inclusion criteria; percentile-based outlier
        # trimming is undefined on an empty sample, so return the empty arrays.
        return x, y
    return remove_upper_outliers(x, y)

def plot_comparison(x, y, title, ax):
    """Scatter memantine (x) against placebo (y) values on ``ax`` with a y = x guide.

    Parameters
    ----------
    x, y : array-like
        Paired memantine and placebo values.
    title : str
        Axes title.
    ax : matplotlib.axes.Axes
        Target axes.

    The dashed identity line spans the joint finite range of ``x`` and ``y``.
    It is omitted when there is no finite value (empty selection or all-NaN data).
    Without that guard, ``min``/``max`` would raise on an empty selection or
    produce a NaN line.
    """
    ax.scatter(x, y, alpha=0.6)
    ax.set_xlabel("Memantine")
    ax.set_ylabel("Placebo")
    ax.set_title(title)
    ax.grid(True)

    values = np.concatenate([
        np.asarray(x, dtype=float).ravel(),
        np.asarray(y, dtype=float).ravel(),
    ])
    finite = values[np.isfinite(values)]
    if finite.size:
        lo, hi = finite.min(), finite.max()
        ax.plot([lo, hi], [lo, hi], 'k--')

def plot_data(subject, variable, region, apply_filter, dataset_choice):
    """Render memantine-vs-placebo scatter panel(s) for the current widget state.

    ``dataset_choice`` is "Fixed", "Free" or "Both". "Both" draws the two
    panels side by side (Fixed left, Free right) in a 12x6 figure; a single
    choice draws one panel in a 6x6 figure. Each panel's data comes from
    ``filter_df`` with the given subject/region/variable/filter settings.
    """
    panels = []
    if dataset_choice in ("Fixed", "Both"):
        panels.append(("Fixed", df_mem_fixed, df_plac_fixed))
    if dataset_choice in ("Free", "Both"):
        panels.append(("Free", df_mem_free, df_plac_free))

    n_panels = len(panels)
    grid = None
    for position, (label, mem_df, plac_df) in enumerate(panels):
        x, y = filter_df(mem_df, plac_df, subject, region, variable, apply_filter)
        if grid is None:
            # Create the figure only once the first panel's data is ready;
            # squeeze=False gives a 2-D axes array even for a single panel.
            _, grid = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
        plot_comparison(x, y, f"{label} | {variable}", grid[0, position])

    plt.tight_layout()
    plt.show()

# === Step 5: Create Widgets ===
# Dropdown choices. "All" is the value filter_df treats as "no restriction";
# subject and region choices are taken from the fixed/memantine dataset.
subjects = ["All", *sorted(df_mem_fixed["subject"].unique())]
variables = [*rename_map.values()]
regions = ["All", *sorted(df_mem_fixed["vertex_region"].unique())]
dataset_options = ["Fixed", "Free", "Both"]

# One control per plot_data argument.
subject_dropdown = widgets.Dropdown(options=subjects, description="Subject:")
variable_dropdown = widgets.Dropdown(options=variables, description="Variable:")
region_dropdown = widgets.Dropdown(options=regions, description="Region:")
filter_toggle = widgets.Checkbox(value=False, description="Apply Filter")
dataset_dropdown = widgets.Dropdown(options=dataset_options, description="Dataset:")

ui = widgets.VBox(
    [subject_dropdown, variable_dropdown, region_dropdown, filter_toggle, dataset_dropdown]
)

# Re-render plot_data whenever a control changes; keys must match its parameter names.
out = widgets.interactive_output(
    plot_data,
    dict(
        subject=subject_dropdown,
        variable=variable_dropdown,
        region=region_dropdown,
        apply_filter=filter_toggle,
        dataset_choice=dataset_dropdown,
    ),
)

# Warning shown above the plots: "Both" draws two panels whose axes are scaled separately.
note = widgets.HTML("<b style='color: darkred;'>⚠️ If you select 'Both' from the dataset list, be careful when assessing the plots, since the scale of units might not be the same.</b>")

# Controls first, then the warning stacked on top of the plot output.
display(ui, widgets.VBox([note, out]))
