In [None]:
# === Imports ===
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ks_2samp, mannwhitneyu, anderson_ksamp, wilcoxon
from ipywidgets import interact, widgets
from IPython.display import display
%matplotlib inline

# === Setup ===
# Maps the numeric column labels "0".."9" of the loaded parquet files
# (assumed to be positional fit-parameter columns — confirm against the
# pipeline that writes them) to readable parameter names, in column order.
rename_map = {
    str(position): name
    for position, name in enumerate((
        "mu_x",
        "mu_y",
        "S1",
        "B1",
        "bold_baseline",
        "B2",
        "S2",
        "HRF1",
        "HRF2",
        "rsq",
    ))
}

# Key columns used to order rows on load. Must stay a list: DataFrame.set_index
# would treat a tuple as a single (tuple-valued) column label.
index_cols = ["subject", "original_row", "vertex_id", "vertex_region"]

def load_and_prepare_data(filepath):
    """Load one parquet file, sort its rows by ``index_cols`` and rename columns.

    Args:
        filepath: Path of the parquet file to read.

    Returns:
        A new DataFrame whose rows are ordered lexicographically by the
        ``index_cols`` key columns (kept as ordinary columns) and whose
        columns are renamed through ``rename_map``.
    """
    frame = pd.read_parquet(filepath)
    # Round-trip through a MultiIndex purely to sort by the key columns;
    # reset_index turns the keys back into regular columns.
    ordered = frame.set_index(index_cols).sort_index().reset_index()
    return ordered.rename(columns=rename_map)

# === Load Data ===
# One file per (fit mode, drug condition) pair, loaded in this order and all
# prepared through the same helper.
df_mem_fixed, df_plac_fixed, df_mem_free, df_plac_free = (
    load_and_prepare_data(parquet_path)
    for parquet_path in (
        "aggregate_fixed_memantine_iterparams.parquet.gzip",
        "aggregate_fixed_placebo_iterparams.parquet.gzip",
        "aggregate_free_memantine_iterparams.parquet.gzip",
        "aggregate_free_placebo_iterparams.parquet.gzip",
    )
)

def remove_upper_outliers(x, y, lower_percentile=2.3, upper_percentile=97.7):
    """Trim paired samples to the central percentile band of both variables.

    Despite the name, BOTH tails are trimmed: pair ``(x[i], y[i])`` is kept
    only when ``x[i]`` lies in ``[P_lower(x), P_upper(x)]`` and ``y[i]`` lies
    in ``[P_lower(y), P_upper(y)]`` (inclusive bounds). The 2.3 / 97.7
    defaults sit at roughly +/-2 SD of a normal distribution.

    Percentiles are computed ignoring NaN, and any pair containing NaN is
    dropped (NaN fails every bound comparison). Previously a single NaN made
    every percentile NaN and silently discarded all pairs.

    Args:
        x, y: Paired 1-D array-likes of equal length.
        lower_percentile, upper_percentile: Percentile cut-offs in [0, 100].

    Returns:
        ``(x_kept, y_kept)`` as new numpy arrays of equal length.

    Raises:
        ValueError: if ``x`` and ``y`` do not have the same shape.
    """
    x = np.array(x)
    y = np.array(y)
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must be paired (same shape); got {x.shape} and {y.shape}"
        )
    if x.size == 0:
        # Percentiles are undefined on empty input; nothing to trim.
        return x, y

    x_lower, x_upper = np.nanpercentile(x, [lower_percentile, upper_percentile])
    y_lower, y_upper = np.nanpercentile(y, [lower_percentile, upper_percentile])
    mask = (x >= x_lower) & (x <= x_upper) & (y >= y_lower) & (y <= y_upper)
    return x[mask], y[mask]

# === Inclusion Criteria ===
def get_filtered_indices(
    df_mem,
    df_plac,
    *,
    rsq_min=0.3,
    max_rsq_diff=0.3,
    ecc_min=0.5,
    ecc_max=4.5,
):
    """Return a boolean row mask of vertices passing the inclusion criteria.

    A row is kept only when, in BOTH ``df_mem`` and ``df_plac``:

    * ``rsq >= rsq_min``;
    * eccentricity ``sqrt(mu_x**2 + mu_y**2)`` is strictly between
      ``ecc_min`` and ``ecc_max``;
    * ``B1 > 0``;
    * ``S1 < S2``;

    and, across the two frames, ``|rsq_mem - rsq_plac| <= max_rsq_diff``.
    Comparisons against NaN are False, so rows with missing values drop out.

    The frames are combined element-wise by index, so they are expected to
    share one index where row i of each describes the same vertex.

    Args:
        df_mem, df_plac: Frames with columns ``mu_x``, ``mu_y``, ``rsq``,
            ``B1``, ``S1`` and ``S2``.
        rsq_min, max_rsq_diff, ecc_min, ecc_max: Thresholds (keyword-only);
            the defaults reproduce the original hard-coded criteria.

    Returns:
        A boolean Series indexed like the inputs. On any error the message is
        printed and an all-False numpy array of ``len(df_mem)`` is returned,
        so the interactive UI keeps working.
    """
    try:
        def passes_per_frame(df):
            # Criteria that are evaluated independently on each condition.
            ecc = np.sqrt(df["mu_x"] ** 2 + df["mu_y"] ** 2)
            return (
                (df["rsq"] >= rsq_min)
                & (ecc > ecc_min) & (ecc < ecc_max)
                & (df["B1"] > 0)
                & (df["S1"] < df["S2"])
            )

        # Fit quality must also be comparable between the two conditions.
        rsq_consistent = np.abs(df_mem["rsq"] - df_plac["rsq"]) <= max_rsq_diff
        return passes_per_frame(df_mem) & passes_per_frame(df_plac) & rsq_consistent
    except Exception as e:
        # Best-effort at the UI boundary: report and select nothing.
        print("Filter error:", e)
        return np.full(len(df_mem), False)

# === Plot + Stats ===
def _condition_is_memantine(condition):
    """Return True when ``condition`` selects the memantine data set.

    Accepts the bare value ``"memantine"`` and display labels such as
    ``"memantine data"`` (case-insensitive prefix match); every other label
    selects the placebo data set.
    """
    return str(condition).strip().lower().startswith("memantine")


def _frame_for(dataset_type, condition):
    """Return the loaded frame for ``dataset_type`` ("Fixed"/"Free") and ``condition``.

    Returns None for any other ``dataset_type``.
    """
    is_mem = _condition_is_memantine(condition)
    if dataset_type == "Fixed":
        return df_mem_fixed if is_mem else df_plac_fixed
    if dataset_type == "Free":
        return df_mem_free if is_mem else df_plac_free
    return None


def _empty_pair():
    """Return two empty float Series, the "no data" result of ``_extract_pair``."""
    return pd.Series(dtype=float), pd.Series(dtype=float)


def _extract_pair(dataset_type, var_mem, var_plac, subject, region,
                  apply_filter, mem_condition, plac_condition):
    """Return position-paired, equal-length Series ``(mem_values, plac_values)``.

    Arguments mirror ``plot_distributions_with_stats``; ``dataset_type`` is
    ``"Fixed"`` or ``"Free"`` (anything else yields two empty Series).

    Steps: select the frame for each condition, restrict to ``subject`` /
    ``region`` unless ``"All"``, truncate both to the shorter length and pair
    rows by position. With ``apply_filter`` the inclusion criteria and
    percentile trimming are applied. Otherwise a pair is dropped whenever
    either value is NaN, so the two Series stay aligned for paired tests.
    """
    mem_df = _frame_for(dataset_type, mem_condition)
    plac_df = _frame_for(dataset_type, plac_condition)
    if mem_df is None or plac_df is None:
        return _empty_pair()

    if subject != "All":
        mem_df = mem_df[mem_df["subject"] == subject]
        plac_df = plac_df[plac_df["subject"] == subject]
    if region != "All":
        mem_df = mem_df[mem_df["vertex_region"] == region]
        plac_df = plac_df[plac_df["vertex_region"] == region]

    # NOTE(review): rows are paired purely by position after sorting on load;
    # this assumes both frames list the same vertices in the same order —
    # confirm, or pair on the key columns instead.
    n_rows = min(len(mem_df), len(plac_df))
    mem_df = mem_df.iloc[:n_rows].reset_index(drop=True)
    plac_df = plac_df.iloc[:n_rows].reset_index(drop=True)

    if apply_filter:
        mask = get_filtered_indices(mem_df, plac_df)
        mem_df = mem_df[mask]
        plac_df = plac_df[mask]
        if mem_df.empty:
            # Nothing passed the criteria; percentile trimming is undefined.
            return _empty_pair()
        x, y = remove_upper_outliers(mem_df[var_mem].values, plac_df[var_plac].values)
        return pd.Series(x), pd.Series(y)

    # Drop pairs jointly: independent dropna() would shift one side and
    # break the positional pairing used by the Wilcoxon test.
    both_present = mem_df[var_mem].notna() & plac_df[var_plac].notna()
    return (
        mem_df.loc[both_present, var_mem].reset_index(drop=True),
        plac_df.loc[both_present, var_plac].reset_index(drop=True),
    )


def _print_paired_stats(x, y):
    """Print KS, Mann-Whitney U, Anderson-Darling and Wilcoxon p-values for ``x`` vs ``y``.

    Anderson-Darling and Wilcoxon failures are caught and printed rather than
    raised; Wilcoxon is skipped when every non-NaN pair is identical.
    """
    print("=== Statistical Tests Between First MEM and PLAC Distributions ===")
    print(f"Kolmogorov–Smirnov (captures whole shape difference with non parametric dists):\n p = {ks_2samp(x, y).pvalue:.4f}\n")
    print(f"Mann–Whitney U (captures difference with medians in non parametric dists):\n p = {mannwhitneyu(x, y, alternative='two-sided').pvalue:.4f}\n")

    try:
        ad_result = anderson_ksamp([x, y])
        print(f"Anderson–Darling (similar to KS but more sensitive to tails):\n p = {ad_result.pvalue:.4f}\n")
    except Exception as e:
        print(f"Anderson–Darling:   Error - {e}")

    try:
        mask = ~np.isnan(x) & ~np.isnan(y)
        x_clean = x[mask]
        y_clean = y[mask]
        if len(x_clean) > 0 and np.any(x_clean != y_clean):
            w_p = wilcoxon(x_clean, y_clean, zero_method="wilcox", alternative="two-sided").pvalue
            print(f"Wilcoxon signed-rank (non parametric version for paired t-test):\n p = {w_p:.4f}\n")
        else:
            print("Wilcoxon signed-rank: Not applicable (identical or zero-difference pairs)")
    except Exception as e:
        print(f"Wilcoxon signed-rank: Error - {e}")


def plot_distributions_with_stats(
    var_mem, var_plac, subject, region, apply_filter, dataset_choice,
    mem_condition, plac_condition
):
    """Plot KDEs of two selected variables and print tests on the first pair.

    Args:
        var_mem, var_plac: Column names plotted for the first / second source.
        subject, region: Row filters; ``"All"`` disables the filter.
        apply_filter: Apply the inclusion criteria and percentile trimming.
        dataset_choice: ``"Fixed"``, ``"Free"`` or ``"Both"``.
        mem_condition, plac_condition: Which data set each variable comes
            from; any label starting with "memantine" (e.g. "memantine" or
            "memantine data") selects memantine, anything else placebo.

    Statistics are printed only for the first (Fixed before Free) pair that
    has data; otherwise a "Not enough data" message is printed.
    """
    _, ax = plt.subplots(figsize=(8, 6))

    pairs = []  # (mem_series, plac_series, mem_label, plac_label)
    for dataset_type in ("Fixed", "Free"):
        if dataset_choice not in (dataset_type, "Both"):
            continue
        mem_data, plac_data = _extract_pair(
            dataset_type, var_mem, var_plac, subject, region,
            apply_filter, mem_condition, plac_condition,
        )
        if len(mem_data) > 0 and len(plac_data) > 0:
            pairs.append((
                mem_data,
                plac_data,
                f"{var_mem} ({mem_condition}, {dataset_type}, μ={mem_data.mean():.4f})",
                f"{var_plac} ({plac_condition}, {dataset_type}, μ={plac_data.mean():.4f})",
            ))

    for mem_data, plac_data, mem_label, plac_label in pairs:
        sns.kdeplot(mem_data, ax=ax, label=mem_label, fill=True, alpha=0.5)
        sns.kdeplot(plac_data, ax=ax, label=plac_label, fill=True, alpha=0.5)

    ax.set_title("Distribution Comparison")
    ax.set_xlabel("Value")
    if pairs:
        # Avoid matplotlib's "No artists with labels" warning on empty plots.
        ax.legend()
    ax.grid(True)
    plt.tight_layout()
    plt.show()

    if pairs:
        _print_paired_stats(pairs[0][0].values, pairs[0][1].values)
    else:
        print("Not enough data to compute statistics.")

# === Widgets ===
subjects = ["All"] + sorted(df_mem_fixed["subject"].unique())
regions = ["All"] + sorted(df_mem_fixed["vertex_region"].unique())
variables = list(rename_map.values())
dataset_options = ["Fixed", "Free", "Both"]
# (label, value) pairs: the dropdown still shows "memantine data" /
# "placebo data", but the callback receives the bare values "memantine" /
# "placebo". The plotting callback matches on "memantine", so passing the
# display label made every selection resolve to the placebo frames.
condition_options = [("memantine data", "memantine"), ("placebo data", "placebo")]

var_selector_mem = widgets.Dropdown(options=variables, description="Var 1:")
var_selector_plac = widgets.Dropdown(options=variables, description="Var 2:")
mem_condition_selector = widgets.Dropdown(options=condition_options, description="Var 1 From:")
# Default the second source to placebo so the initial view compares the two
# conditions rather than a data set with itself.
plac_condition_selector = widgets.Dropdown(
    options=condition_options, value="placebo", description="Var 2 From:"
)
subject_selector = widgets.Dropdown(options=subjects, description="Subject:")
region_selector = widgets.Dropdown(options=regions, description="Region:")
filter_checkbox = widgets.Checkbox(value=False, description="Apply Inclusion Criteria")
dataset_selector = widgets.Dropdown(options=dataset_options, description="Dataset:")

ui = widgets.VBox([
    widgets.HTML("<b>Compare Selected Variable Distributions (Flexible Condition Assignment)</b>"),
    widgets.HBox([var_selector_mem, mem_condition_selector]),
    widgets.HBox([var_selector_plac, plac_condition_selector]),
    subject_selector,
    region_selector,
    filter_checkbox,
    dataset_selector
])

# Re-run the plot/stats callback whenever any control changes; each key is
# the callback parameter the widget's value is passed as.
out = widgets.interactive_output(
    plot_distributions_with_stats,
    {
        "var_mem": var_selector_mem,
        "var_plac": var_selector_plac,
        "subject": subject_selector,
        "region": region_selector,
        "apply_filter": filter_checkbox,
        "dataset_choice": dataset_selector,
        "mem_condition": mem_condition_selector,
        "plac_condition": plac_condition_selector
    }
)

display(ui, out)