In [None]:
# === Imports ===
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ks_2samp, mannwhitneyu, anderson_ksamp, wilcoxon, ttest_rel
from ipywidgets import interact, widgets
from IPython.display import display
%matplotlib inline

# === Setup ===
# Source columns are labelled with the strings '0'..'9'; each label is mapped
# to the parameter name listed at the same position below.  The mapping is
# applied via ``DataFrame.rename(columns=rename_map)`` when a file is loaded.
rename_map = {
    str(position): name
    for position, name in enumerate((
        'mu_x',
        'mu_y',
        'S1',
        'B1',
        'bold_baseline',
        'B2',
        'S2',
        'HRF1',
        'HRF2',
        'rsq',
    ))
}

# Columns used as the sort key when a parquet file is loaded.
index_cols = ['subject', 'original_row', 'vertex_id', 'vertex_region']

def load_and_prepare_data(filepath):
    """Read one parquet file and return it sorted, renamed and augmented.

    Steps, in order:
      1. read ``filepath`` with ``pd.read_parquet``;
      2. sort the rows by ``index_cols`` (set as index, sort, reset);
      3. rename columns through ``rename_map``;
      4. add ``eccentricity = sqrt(mu_x**2 + mu_y**2)``;
      5. add ``surround_index = (B2 * S2**2) / (B1 * S1**2)``.

    No guard is applied to the denominator of ``surround_index``; rows where
    ``B1 * S1**2`` is zero get whatever pandas division yields (inf/NaN).
    """
    raw = pd.read_parquet(filepath)
    ordered = raw.set_index(index_cols).sort_index().reset_index()
    df = ordered.rename(columns=rename_map)

    mu_x, mu_y = df['mu_x'], df['mu_y']
    df['eccentricity'] = np.sqrt(mu_x**2 + mu_y**2)

    numerator = df['B2'] * df['S2']**2
    denominator = df['B1'] * df['S1']**2
    df['surround_index'] = numerator / denominator
    return df

# === Load Data ===
# Each file is read, sorted, renamed and augmented by load_and_prepare_data.
# These two module-level frames are the data sources read by
# plot_distributions_with_stats; df_mem_fixed also supplies the subject and
# region choices offered by the widgets.
df_mem_fixed = load_and_prepare_data("aggregate_fixed_memantine_iterparams_final.parquet.gzip")
df_plac_fixed = load_and_prepare_data("aggregate_fixed_placebo_iterparams_final.parquet.gzip")
# Disabled loads of the "free" files; nothing in the visible code references
# df_mem_free / df_plac_free.
# df_mem_free = load_and_prepare_data("aggregate_free_memantine_iterparams.parquet.gzip")
# df_plac_free = load_and_prepare_data("aggregate_free_placebo_iterparams.parquet.gzip")

def remove_upper_outliers(x, y, lower_percentile=2.3, upper_percentile=97.7):
    """Trim paired samples to the pairs lying inside both percentile bands.

    Despite the name, BOTH tails are trimmed: a pair ``(x[i], y[i])`` is kept
    only if ``x[i]`` lies within the [lower, upper] percentile band of ``x``
    and ``y[i]`` lies within the corresponding band of ``y``.  Pairing is
    preserved, so the two returned arrays always have equal length.

    Pairs in which either value is NaN are discarded before the percentiles
    are computed; otherwise a single NaN would make every percentile NaN and
    silently drop all data.

    Parameters
    ----------
    x, y : array-like
        Equal-length numeric samples, paired by position.
    lower_percentile, upper_percentile : float
        Percentiles (0-100) bounding the band that is kept, inclusive.

    Returns
    -------
    tuple of numpy.ndarray
        The retained ``x`` and ``y`` values.  Both are empty when the input is
        empty or contains no NaN-free pair.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Only NaN-free pairs take part in the percentile computation.
    valid = ~(np.isnan(x) | np.isnan(y))
    if not valid.any():
        # Nothing usable: np.percentile is not defined for an empty sample.
        return x[:0], y[:0]

    bounds = [lower_percentile, upper_percentile]
    x_lower, x_upper = np.percentile(x[valid], bounds)
    y_lower, y_upper = np.percentile(y[valid], bounds)

    mask = (
        valid
        & (x >= x_lower) & (x <= x_upper)
        & (y >= y_lower) & (y <= y_upper)
    )
    return x[mask], y[mask]

# === Inclusion Criteria ===
def get_filtered_indices(df_mem, df_plac, min_ecc=0.5, max_ecc=4.5,
                         min_rsq=0.3, max_rsq_diff=0.3):
    """Build the boolean inclusion mask for row-paired frames.

    A row is included only when, in BOTH frames,

    * ``rsq >= min_rsq``,
    * ``min_ecc < sqrt(mu_x**2 + mu_y**2) < max_ecc`` (strict bounds),
    * ``B1 > 0``,
    * ``S1 < S2``,

    and additionally ``|rsq_mem - rsq_plac| <= max_rsq_diff``.

    Parameters
    ----------
    df_mem, df_plac : pandas.DataFrame
        Frames holding the columns ``mu_x``, ``mu_y``, ``rsq``, ``B1``,
        ``S1`` and ``S2``.  Rows are compared label-by-label, so both frames
        are expected to share the same index.
    min_ecc, max_ecc : float
        Exclusive eccentricity bounds.
    min_rsq : float
        Minimum ``rsq`` required in each frame (previously hard-coded 0.3).
    max_rsq_diff : float
        Largest allowed absolute ``rsq`` difference between the frames
        (previously hard-coded 0.3).

    Returns
    -------
    pandas.Series of bool
        The inclusion mask.  If the mask cannot be built (missing column,
        non-numeric data, incompatible frames) the problem is printed and an
        all-False ``numpy.ndarray`` of length ``len(df_mem)`` is returned so
        callers simply end up with no rows.
    """
    def _criteria(df):
        # Per-frame criteria; applied identically to both frames.
        ecc = np.sqrt(df['mu_x']**2 + df['mu_y']**2)
        return (
            (df['rsq'] >= min_rsq)
            & (ecc > min_ecc) & (ecc < max_ecc)
            & (df['B1'] > 0)
            & (df['S1'] < df['S2'])
        )

    try:
        return (
            _criteria(df_mem)
            & _criteria(df_plac)
            & (np.abs(df_mem['rsq'] - df_plac['rsq']) <= max_rsq_diff)
        )
    except (KeyError, TypeError, ValueError) as e:
        # Best-effort fallback kept from the original: report and select nothing.
        print("Filter error:", e)
        return np.full(len(df_mem), False)

# === Plot + Stats ===
def plot_distributions_with_stats(
    var_mem, var_plac, subjects, regions, apply_filter,
    V1_df, V2_df, min_ecc, max_ecc
):
    """Plot two selected variables and print statistical tests comparing them.

    The left panel shows a KDE of both variables, the two right panels show
    each variable's values against their position, and the test results are
    printed after the figure.

    Parameters
    ----------
    var_mem, var_plac : str
        Column read from the first / second dataset.
    subjects : collection
        Subject ids to keep; no subject filtering if it contains ``"All"``.
    regions : collection
        ``vertex_region`` values to keep; an empty collection disables the
        region filter.
    apply_filter : bool
        When true, rows are restricted with ``get_filtered_indices`` and the
        remaining pairs are trimmed with ``remove_upper_outliers``.
    V1_df, V2_df : str
        ``"memantine data"`` selects ``df_mem_fixed``; any other value selects
        ``df_plac_fixed``.
    min_ecc, max_ecc : float
        Eccentricity bounds forwarded to ``get_filtered_indices``.

    Notes
    -----
    The two samples are always returned as row-aligned pairs of equal length
    with NaN pairs removed, which the paired tests (``ttest_rel``,
    ``wilcoxon``) require.
    """
    fig = plt.figure(figsize=(18, 6))
    gs = fig.add_gridspec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1])
    ax1 = fig.add_subplot(gs[:, 0])  # Distribution plot
    ax2 = fig.add_subplot(gs[0, 1])  # Scatter plot 1
    ax3 = fig.add_subplot(gs[1, 1])  # Scatter plot 2

    datasets = []
    labels = []

    def extract_data():
        """Return the two value series, paired by row position and NaN-free."""
        V1_data = df_mem_fixed if V1_df == "memantine data" else df_plac_fixed
        V2_data = df_mem_fixed if V2_df == "memantine data" else df_plac_fixed

        # Filter by selected subjects
        if "All" not in subjects:
            V1_data = V1_data[V1_data["subject"].isin(subjects)]
            V2_data = V2_data[V2_data["subject"].isin(subjects)]

        # Filter by selected regions
        if regions:  # regions will never contain "All" since we removed it
            V1_data = V1_data[V1_data["vertex_region"].isin(regions)]
            V2_data = V2_data[V2_data["vertex_region"].isin(regions)]

        # Rows are paired purely by position after truncating to the shorter
        # frame.
        # NOTE(review): this assumes both frames list the same vertices in the
        # same order - confirm, or pair on the identifying columns instead.
        V1_data = V1_data.reset_index(drop=True)
        V2_data = V2_data.reset_index(drop=True)
        df_len = min(len(V1_data), len(V2_data))
        V1_data = V1_data.iloc[:df_len].copy()
        V2_data = V2_data.iloc[:df_len].copy()

        if apply_filter:
            mask = get_filtered_indices(V1_data, V2_data, min_ecc, max_ecc)
            V1_data = V1_data[mask]
            V2_data = V2_data[mask]

        x = V1_data[var_mem].to_numpy()
        y = V2_data[var_plac].to_numpy()

        # Drop a pair when EITHER value is missing.  Dropping NaN from each
        # series independently would shift the pairing and could leave the
        # series with different lengths, which makes the paired tests fail.
        keep = ~(pd.isna(x) | pd.isna(y))
        x, y = x[keep], y[keep]

        # Percentile trimming is undefined for an empty sample, so skip it
        # when the inclusion criteria removed every row.
        if apply_filter and len(x) > 0:
            x, y = remove_upper_outliers(x, y)

        return pd.Series(x), pd.Series(y)

    mem_data, plac_data = extract_data()
    if len(mem_data) > 0 and len(plac_data) > 0:
        datasets.append((mem_data, plac_data))
        labels.append((
            f"{var_mem} ({V1_df}, Fixed, μ={mem_data.mean():.4f})",
            f"{var_plac} ({V2_df}, Fixed, μ={plac_data.mean():.4f})"
        ))

    # Plot distributions
    for (mem, plac), (mem_label, plac_label) in zip(datasets, labels):
        sns.kdeplot(mem, ax=ax1, label=mem_label, fill=True, alpha=0.5)
        sns.kdeplot(plac, ax=ax1, label=plac_label, fill=True, alpha=0.5)

    ax1.set_title("Distribution Comparison")
    ax1.set_xlabel("Value")
    ax1.legend()
    ax1.grid(True)

    # Plot scatter plots if we have data
    if datasets:
        x = datasets[0][0].values
        y = datasets[0][1].values

        # First scatter plot (Variable 1)
        ax2.scatter(range(len(x)), x, alpha=0.5, label=f'{var_mem} ({V1_df})')
        ax2.set_title(f"{var_mem} Values")
        ax2.set_xlabel("Index")
        ax2.set_ylabel("Value")
        ax2.legend()
        ax2.grid(True)

        # Second scatter plot (Variable 2)
        ax3.scatter(range(len(y)), y, alpha=0.5, color='orange', label=f'{var_plac} ({V2_df})')
        ax3.set_title(f"{var_plac} Values")
        ax3.set_xlabel("Index")
        ax3.set_ylabel("Value")
        ax3.legend()
        ax3.grid(True)

    plt.tight_layout()
    plt.show()

    # === Stats on first pair only ===
    if datasets:
        x = datasets[0][0].values
        y = datasets[0][1].values

        print("=== Statistical Tests Between First MEM and PLAC Distributions ===")
        print(f"Paired t-test (Compared mean difference using parametric dists):\n p = {ttest_rel(x, y).pvalue:.4f}\n")
        print(f"Kolmogorov–Smirnov (captures whole shape difference with non parametric dists):\n p = {ks_2samp(x, y).pvalue:.4f}\n")
        print(f"Mann–Whitney U (captures difference with medians in non parametric dists):\n p = {mannwhitneyu(x, y, alternative='two-sided').pvalue:.4f}\n")

        try:
            ad_result = anderson_ksamp([x, y])
            print(f"Anderson–Darling (similar to KS but more sensitive to tails):\n p = {ad_result.pvalue:.4f}\n")
        except Exception as e:
            print(f"Anderson–Darling:   Error - {e}")

        try:
            mask = ~np.isnan(x) & ~np.isnan(y)
            x_clean = x[mask]
            y_clean = y[mask]
            # Wilcoxon is undefined when every paired difference is zero.
            if len(x_clean) > 0 and np.any(x_clean != y_clean):
                w_stat, w_p = wilcoxon(x_clean, y_clean, zero_method="wilcox", alternative="two-sided")
                print(f"Wilcoxon signed-rank (non parametric version for paired t-test):\n p = {w_p:.4f}\n")
            else:
                print("Wilcoxon signed-rank: Not applicable (identical or zero-difference pairs)")
        except Exception as e:
            print(f"Wilcoxon signed-rank: Error - {e}")
    else:
        print("Not enough data to compute statistics.")

# === Widgets ===
# Choices offered in the UI: subjects and regions found in the memantine
# frame, plus every renamed parameter column and the two derived columns.
all_subjects = sorted(df_mem_fixed["subject"].unique())
all_regions = sorted(df_mem_fixed["vertex_region"].unique())
variables = [*rename_map.values(), 'eccentricity', 'surround_index']
condition_options = ["memantine data", "placebo data"]

# One (variable, dataset) dropdown pair per compared series.
var_selector_mem = widgets.Dropdown(description="Var 1:", options=variables)
var_selector_plac = widgets.Dropdown(description="Var 2:", options=variables)
V1_dataset_selector = widgets.Dropdown(description="Var 1 From:", options=condition_options)
V2_dataset_selector = widgets.Dropdown(description="Var 2 From:", options=condition_options)

# Multi-select lists, each starting with only its first entry selected.
subject_selector = widgets.SelectMultiple(
    description="Subjects:",
    options=all_subjects,
    value=[all_subjects[0]],
    disabled=False,
)
region_selector = widgets.SelectMultiple(
    description="Regions:",
    options=all_regions,
    value=[all_regions[0]],
    disabled=False,
)

filter_checkbox = widgets.Checkbox(description="Apply Inclusion Criteria", value=False)

# Both eccentricity sliders share range and step, and start out disabled
# whenever the checkbox is unticked at creation time.
_ecc_slider_options = {
    "min": 0.1,
    "max": 10.0,
    "step": 0.1,
    "disabled": not filter_checkbox.value,
}
min_ecc_slider = widgets.FloatSlider(
    value=0.5, description='Min Eccentricity:', **_ecc_slider_options
)
max_ecc_slider = widgets.FloatSlider(
    value=4.5, description='Max Eccentricity:', **_ecc_slider_options
)

# Enable/disable sliders based on filter checkbox
def update_sliders(change):
    """Mirror the checkbox state onto both eccentricity sliders.

    ``change.new`` carries the observed widget's new value; the sliders are
    disabled exactly when that value is falsy.
    """
    sliders_disabled = not change.new
    for slider in (min_ecc_slider, max_ecc_slider):
        slider.disabled = sliders_disabled

# Re-run update_sliders whenever the checkbox's 'value' trait changes.
filter_checkbox.observe(update_sliders, names='value')

# Control panel, top to bottom: title, the two (variable, dataset) rows,
# subject and region pickers, the filter checkbox, and the slider row.
ui = widgets.VBox([
    widgets.HTML("<b>Compare Selected Variable Distributions (Flexible Condition Assignment)</b>"),
    widgets.HBox([var_selector_mem, V1_dataset_selector]),
    widgets.HBox([var_selector_plac, V2_dataset_selector]),
    subject_selector,
    region_selector,
    filter_checkbox,
    widgets.HBox([min_ecc_slider, max_ecc_slider])
])

# Bind each keyword argument of plot_distributions_with_stats to the widget
# that supplies its value; the dict keys must match the function's parameter
# names.  The plot and printed statistics are rendered into `out`.
out = widgets.interactive_output(
    plot_distributions_with_stats,
    {
        "var_mem": var_selector_mem,
        "var_plac": var_selector_plac,
        "subjects": subject_selector,
        "regions": region_selector,
        "apply_filter": filter_checkbox,
        "V1_df": V1_dataset_selector,
        "V2_df": V2_dataset_selector,
        "min_ecc": min_ecc_slider,
        "max_ecc": max_ecc_slider
    }
)

# Show the controls followed by the output area.
display(ui, out)