# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
import ipywidgets as widgets
from IPython.display import display

# Columns used to order each parquet export (and moved to the front of it).
index_cols = ['subject', 'original_row', 'vertex_id', 'vertex_region']

def load_and_prepare_data(filepath):
    """Load a parquet export and return it ordered by ``index_cols``.

    The ``index_cols`` columns are moved to the front of the frame (via
    ``set_index``/``reset_index``), rows are sorted lexicographically by
    them, and the result carries a fresh default RangeIndex.
    """
    frame = pd.read_parquet(filepath)
    frame = frame.set_index(index_cols)
    frame = frame.sort_index()
    return frame.reset_index()

# Load both treatment conditions with the same preparation (sorted by index_cols).
plac = load_and_prepare_data("aggregate_fixed_placebo_iterparams.parquet.gzip")
mem = load_and_prepare_data("aggregate_fixed_memantine_iterparams.parquet.gzip")

# Add group identifiers. This mutates plac/mem in place, so the 'group'
# column is also present in the frames used later for plotting.
plac['group'] = 'placebo'
mem['group'] = 'memantine'

# Stack both conditions into one long frame; 'source' ('plac' / 'mem') becomes
# the pivot column so each condition's parameters end up side by side.
combined = pd.concat([
    plac.assign(source='plac'),
    mem.assign(source='mem')
])

# First filter: drop rows where column '2' or '6' is not strictly positive.
# These columns are used as sigma1/sigma2 by the DoG profile code.
valid_mask = (combined['2'] > 0) & (combined['6'] > 0)
combined = combined[valid_mask]

# Pivot to one row per (subject, vertex_region, original_row) with
# (parameter, source) MultiIndex columns, e.g. ('0', 'plac').
# dropna() keeps only keys that have all seven parameters in BOTH conditions,
# which is what aligns placebo and memantine row-for-row.
# NOTE(review): pivot_table's default aggfunc is 'mean', so rows sharing the
# same key (e.g. several vertex_id values for one original_row) would be
# silently averaged — confirm these keys are unique.
pivoted = combined.pivot_table(
    index=['subject', 'vertex_region', 'original_row'],
    columns='source',
    values=['0','1','2','3','5','6','9']
).dropna()

# Inclusion filter over the aligned placebo/memantine table (tunable thresholds)
def calculate_filter_mask(pivoted, min_ecc=0.5, max_ecc=4.5, min_r=0.3, max_r_diff=0.3):
    """Return a boolean Series marking rows of *pivoted* that pass the inclusion filter.

    *pivoted* has (parameter, source) MultiIndex columns with source in
    {'plac', 'mem'}. A row is kept only if, for BOTH sources:

    * column '9' (``r``; presumably a fit-quality measure — confirm) >= min_r
    * eccentricity ``sqrt('0'**2 + '1'**2)`` lies strictly inside (min_ecc, max_ecc)
    * column '3' > 0
    * column '2' < column '6' (first sigma smaller than second sigma)

    and, additionally, ``|r_mem - r_plac| <= max_r_diff``.

    Parameters
    ----------
    pivoted : pandas.DataFrame
        Aligned table with columns ('0'|'1'|'2'|'3'|'6'|'9', 'plac'|'mem').
    min_ecc, max_ecc : float
        Exclusive eccentricity bounds applied to both conditions.
    min_r : float
        Minimum column-'9' value required in each condition (default 0.3,
        the previously hard-coded value).
    max_r_diff : float
        Maximum allowed absolute difference of column '9' between conditions
        (default 0.3, the previously hard-coded value).

    Returns
    -------
    pandas.Series of bool, aligned with ``pivoted.index``.
    """
    def _passes(source):
        # Per-condition criteria; identical for placebo and memantine.
        ecc = np.sqrt(pivoted[('0', source)]**2 + pivoted[('1', source)]**2)
        return (
            (pivoted[('9', source)] >= min_r)
            & (ecc > min_ecc) & (ecc < max_ecc)
            & (pivoted[('3', source)] > 0)
            & (pivoted[('2', source)] < pivoted[('6', source)])
        )

    # Cross-condition criterion: the two fits must be of comparable quality.
    r_diff = np.abs(pivoted[('9', 'mem')] - pivoted[('9', 'plac')])
    return _passes('mem') & _passes('plac') & (r_diff <= max_r_diff)

# Initialize filtered datasets with the default eccentricity range of
# calculate_filter_mask. valid_indices are the (subject, vertex_region,
# original_row) keys that pass the filter in BOTH conditions; the same keys
# subset each condition's full rows so placebo and memantine stay matched.
# update_plot rebinds plac_filtered / mem_filtered whenever the filter is applied.
initial_filter_mask = calculate_filter_mask(pivoted)
valid_indices = pivoted[initial_filter_mask].index
plac_filtered = plac.set_index(['subject', 'vertex_region', 'original_row']).loc[valid_indices].reset_index()
mem_filtered = mem.set_index(['subject', 'vertex_region', 'original_row']).loc[valid_indices].reset_index()

# Difference-of-Gaussians ("Mexican hat"-style) profile
def DoG_profile(x, sigma1, sigma2, A1=1.0, A2=1.0, normalization='mean_percent'):
    """Evaluate ``A1*N(x; 0, sigma1) - A2*N(x; 0, sigma2)`` on the grid *x*.

    Each term is a zero-mean, area-normalised Gaussian density scaled by its
    amplitude. Widths below 1e-6 are clamped to 1e-6 so the densities stay
    finite, and NaN/inf in the difference are replaced by ``np.nan_to_num``.

    Parameters
    ----------
    x : numpy.ndarray
        Sample positions.
    sigma1, sigma2 : float
        Widths of the added and the subtracted Gaussian.
    A1, A2 : float
        Amplitudes of the added and the subtracted Gaussian.
    normalization : str
        'mean_percent' rescales so the mean absolute non-zero value is 100;
        'max' rescales so the largest absolute value is 1; any other value
        returns the raw difference. An all-zero profile is never rescaled.
    """
    floor = 0.000001
    width_a = max(sigma1, floor)
    width_b = max(sigma2, floor)
    root_two_pi = np.sqrt(2 * np.pi)

    with np.errstate(divide='ignore', invalid='ignore'):
        plus = A1 * np.exp(-x**2 / (2 * width_a**2)) / (root_two_pi * width_a)
        minus = A2 * np.exp(-x**2 / (2 * width_b**2)) / (root_two_pi * width_b)
        dog = np.nan_to_num(plus - minus)

        nonzero = dog != 0
        has_signal = np.any(nonzero)

        if normalization == 'max':
            scale = np.max(np.abs(dog)) if has_signal else 1
            return dog / scale
        if normalization == 'mean_percent':
            scale = np.mean(np.abs(dog[nonzero])) if has_signal else 1
            return (dog / scale) * 100
        return dog

def calculate_profiles_voxel_first(group_data, x_range=(-10, 10, 100), normalization='mean_percent'):
    """Average per-vertex DoG profiles within each (subject, vertex_region) group.

    Every row of ``group_data`` is evaluated on its own with ``DoG_profile``
    (column '2' -> sigma1, '6' -> sigma2, '3' -> A1, '5' -> A2), and the
    resulting curves are averaged point-wise per group.

    Parameters
    ----------
    group_data : pandas.DataFrame
        Rows to profile; needs 'subject', 'vertex_region', '2', '3', '5', '6'.
    x_range : tuple
        ``(start, stop, num)`` forwarded to ``np.linspace`` for the x grid.
    normalization : str
        Forwarded unchanged to ``DoG_profile``.

    Returns
    -------
    pandas.DataFrame
        One row per group that produced at least one profile, with columns
        'subject', 'region', 'profile', 'ci_low', 'ci_high', 'x', 'n_voxels'
        and 'peak_values' (ci_low, mean, ci_high at the argmax of the mean).
        Empty when no group produced a profile.

    Raises
    ------
    KeyError
        If a required column is missing (a schema error is no longer
        silently turned into an empty result).

    Notes
    -----
    'ci_low'/'ci_high' are the 2.5th/97.5th percentiles of the individual
    vertex profiles, i.e. the spread across vertices — not a confidence
    interval of the mean profile.
    """
    x = np.linspace(*x_range)
    # Parameter columns in DoG_profile's positional order: sigma1, sigma2, A1, A2.
    param_cols = ['2', '6', '3', '5']
    profiles = []

    for (subject, region), sub_df in group_data.groupby(['subject', 'vertex_region']):
        voxel_profiles = []
        # Pull the four parameter columns once instead of building a Series per row.
        for sigma1, sigma2, a1, a2 in sub_df[param_cols].itertuples(index=False, name=None):
            try:
                profile = DoG_profile(x, sigma1, sigma2, a1, a2, normalization=normalization)
            except (TypeError, ValueError, ArithmeticError):
                # Skip vertices whose parameters cannot be evaluated
                # (e.g. None / non-numeric); keep the rest of the group.
                continue
            voxel_profiles.append(profile)

        if not voxel_profiles:
            continue

        stacked = np.vstack(voxel_profiles)  # (n_vertices, len(x))
        avg_profile = np.nanmean(stacked, axis=0)
        ci_low = np.nanpercentile(stacked, 2.5, axis=0)
        ci_high = np.nanpercentile(stacked, 97.5, axis=0)
        peak_idx = np.argmax(avg_profile)

        profiles.append({
            'subject': subject,
            'region': region,
            'profile': avg_profile,
            'ci_low': ci_low,
            'ci_high': ci_high,
            'x': x,
            'n_voxels': len(voxel_profiles),
            'peak_values': (ci_low[peak_idx], avg_profile[peak_idx], ci_high[peak_idx])
        })

    return pd.DataFrame(profiles)

def calculate_profiles_param_first(group_data, x_range=(-10, 10, 100), normalization='mean_percent'):
    """Build one DoG profile per (subject, vertex_region) from averaged parameters.

    Within each group the columns '2' (sigma1), '6' (sigma2), '3' (A1) and
    '5' (A2) are averaged (NaN-skipping, pandas default), and a single
    ``DoG_profile`` is evaluated from those means.

    Parameters
    ----------
    group_data : pandas.DataFrame
        Rows to profile; needs 'subject', 'vertex_region', '2', '3', '5', '6'.
    x_range : tuple
        ``(start, stop, num)`` forwarded to ``np.linspace`` for the x grid.
    normalization : str
        Forwarded unchanged to ``DoG_profile``.

    Returns
    -------
    pandas.DataFrame
        One row per group with columns 'subject', 'region', 'profile', 'x'
        and 'n_voxels' (rows in the group). No percentile band is computed.
        Empty when no group produced a profile.
    """
    x = np.linspace(*x_range)
    profiles = []

    for (subject, region), sub_df in group_data.groupby(['subject', 'vertex_region']):
        if sub_df.empty:
            # groupby can yield empty groups, e.g. unobserved categorical levels.
            continue

        # Means in DoG_profile's positional order: sigma1, sigma2, A1, A2.
        sigma1, sigma2, a1, a2 = (sub_df[col].mean() for col in ('2', '6', '3', '5'))

        try:
            profile = DoG_profile(x, sigma1, sigma2, a1, a2, normalization=normalization)
        except (TypeError, ValueError, ArithmeticError):
            # Parameters that cannot be evaluated: skip this group only.
            continue

        profiles.append({
            'subject': subject,
            'region': region,
            'profile': profile,
            'x': x,
            'n_voxels': len(sub_df)
        })

    return pd.DataFrame(profiles)

# Create widgets
subject_options = ['All'] + sorted(plac['subject'].unique().tolist())
region_options = ['All'] + sorted(plac['vertex_region'].unique().tolist())
normalization_options = ['mean_percent', 'max', 'none']
average_method_options = [
    ('Vertex profiles first, then average pRFs', 'voxel_first'),
    ('Average parameters first, then PRF', 'param_first')
]


def _float_slider(value, lo, hi, step, description):
    """Build an enabled FloatSlider that only reports its value on release."""
    return widgets.FloatSlider(
        value=value,
        min=lo,
        max=hi,
        step=step,
        description=description,
        disabled=False,
        continuous_update=False,  # avoid re-plotting on every drag tick
    )


# Eccentricity filter widgets (bounds passed to calculate_filter_mask)
min_ecc_widget = _float_slider(0.5, 0.0, 4.5, 0.5, 'Min Eccentricity:')
max_ecc_widget = _float_slider(4.5, 0.0, 4.5, 0.5, 'Max Eccentricity:')

# Zoom control widgets (axis limits of the right-hand detail panel)
xmin_widget = _float_slider(-7.5, -10, 0, 0.5, 'X-min:')
xmax_widget = _float_slider(7.5, 0, 10, 0.5, 'X-max:')
ymin_widget = _float_slider(-30, -100, 0, 5, 'Y-min:')
ymax_widget = _float_slider(20, 0, 100, 5, 'Y-max:')

# Other widgets.
# Preferred initial selections fall back to 'All' when the data has no such
# subject/region, because a Dropdown rejects a value not among its options.
subject_widget = widgets.Dropdown(
    options=subject_options,
    value=1 if 1 in subject_options else 'All',
    description='Subject:',
    disabled=False
)

region_widget = widgets.Dropdown(
    options=region_options,
    value='V1' if 'V1' in region_options else 'All',
    description='Region:',
    disabled=False
)

normalization_widget = widgets.RadioButtons(
    options=normalization_options,
    value='mean_percent',
    description='Normalization:',
    disabled=False
)

show_placebo_widget = widgets.Checkbox(
    value=True,
    description='Show Placebo',
    disabled=False
)

show_memantine_widget = widgets.Checkbox(
    value=True,
    description='Show Memantine',
    disabled=False
)

# Options are (label, value) pairs; value is what update_plot receives.
average_method_widget = widgets.RadioButtons(
    options=average_method_options,
    value='voxel_first',
    description='Averaging Method:',
    disabled=False
)

filter_widget = widgets.Checkbox(
    value=True,
    description='Apply Inclusion Filter',
    disabled=False
)

verbose_widget = widgets.Checkbox(
    value=True,  # print per-condition summaries by default
    description='Verbose pRF information',
    disabled=False
)

def _response_ylabel(normalization):
    """Return the y-axis label matching a DoG_profile normalization mode."""
    if normalization == 'mean_percent':
        return 'BOLD Response (% of mean)'
    if normalization == 'max':
        return 'BOLD Response (normalized)'
    return 'BOLD Response'


def _select_plot_subset(group_data, selected_subject, selected_region):
    """Restrict rows to the chosen subject/region, pooling them for 'All'.

    A concrete selection filters the rows. For 'All' the key column is
    overwritten with the constant 'All' (on a new frame), so the per-group
    profile calculators see one pooled group and return exactly the single
    profile that is plotted.
    """
    if selected_subject == 'All':
        group_data = group_data.assign(subject='All')
    else:
        group_data = group_data[group_data['subject'] == selected_subject]
    if selected_region == 'All':
        group_data = group_data.assign(vertex_region='All')
    else:
        group_data = group_data[group_data['vertex_region'] == selected_region]
    return group_data


def update_plot(selected_subject, selected_region, normalization, show_placebo, show_memantine, 
               average_method, apply_filter, min_ecc, max_ecc,
               xmin, xmax, ymin, ymax, verbose):
    """Redraw the placebo-vs-memantine DoG profile comparison.

    Called by ``widgets.interactive_output`` with the current widget values.

    Parameters
    ----------
    selected_subject, selected_region :
        A concrete subject / region, or 'All' to pool every matching vertex
        of a condition into a single profile.
    normalization : str
        Forwarded to the profile calculators and used for the y-labels.
    show_placebo, show_memantine : bool
        Which conditions to draw.
    average_method : str
        'voxel_first' averages per-vertex profiles (with a percentile band);
        anything else averages parameters first.
    apply_filter : bool
        Restrict to rows passing calculate_filter_mask with (min_ecc, max_ecc).
    xmin, xmax, ymin, ymax : float
        Axis limits of the right-hand "Surround Response Detail" panel.
    verbose : bool
        Print per-condition vertex counts and peak statistics.

    Side effects: when ``apply_filter`` is True, rebinds the module-level
    ``plac_filtered`` / ``mem_filtered``; prints status messages; shows a
    matplotlib figure.
    """
    global plac_filtered, mem_filtered

    # Recalculate the inclusion filter for the current eccentricity range.
    if apply_filter:
        key_cols = ['subject', 'vertex_region', 'original_row']
        current_filter_mask = calculate_filter_mask(pivoted, min_ecc, max_ecc)
        valid_indices = pivoted[current_filter_mask].index
        plac_filtered = plac.set_index(key_cols).loc[valid_indices].reset_index()
        mem_filtered = mem.set_index(key_cols).loc[valid_indices].reset_index()
        plac_data, mem_data = plac_filtered, mem_filtered
    else:
        # Downstream code only builds new frames, so no defensive copy is needed.
        plac_data, mem_data = plac, mem

    groups_to_show = []
    if show_placebo:
        groups_to_show.append(('placebo', 'blue', plac_data))
    if show_memantine:
        groups_to_show.append(('memantine', 'red', mem_data))

    # Check before creating the figure so no empty, never-shown figure is left open.
    if not groups_to_show:
        print("Please select at least one group to show")
        return

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
    ylabel = _response_ylabel(normalization)

    for group_name, group_color, group_data in groups_to_show:
        group_data = _select_plot_subset(group_data, selected_subject, selected_region)

        if group_data.empty:
            print(f"No data available for {group_name} with current filters")
            continue

        if average_method == 'voxel_first':
            profiles_df = calculate_profiles_voxel_first(group_data, normalization=normalization)
        else:
            profiles_df = calculate_profiles_param_first(group_data, normalization=normalization)

        if profiles_df.empty:
            print(f"No valid profiles calculated for {group_name} with current filters")
            continue

        label = 'Placebo' if group_name == 'placebo' else 'Memantine'

        if verbose:
            print(f"\n--- {label} Profile Information ---")
            print(f"Number of Vertices: {profiles_df['n_voxels'].sum()}")

        # _select_plot_subset leaves a single (subject, region) group, so the
        # first row is the whole selection rather than an arbitrary subset.
        x = profiles_df['x'].iloc[0]
        profile = profiles_df['profile'].iloc[0]

        if average_method == 'voxel_first':
            ci_low = profiles_df['ci_low'].iloc[0]
            ci_high = profiles_df['ci_high'].iloc[0]
            peak_idx = np.argmax(profile)

            if verbose:
                print(f"At peak response (x={x[peak_idx]:.1f}):")
                print(f"  Lower CI: {ci_low[peak_idx]:.4f}")
                print(f"  Mean:     {profile[peak_idx]:.4f}")
                print(f"  Upper CI: {ci_high[peak_idx]:.4f}")
                print(f"  CI Range: {ci_high[peak_idx] - ci_low[peak_idx]:.4f}")

            # NOTE(review): the band is the 2.5-97.5th percentile of individual
            # vertex profiles (spread across vertices), not a CI of the mean,
            # although it is labelled "95% CI".
            ax1.plot(x, profile, color=group_color, label=label)
            ax1.fill_between(x, ci_low, ci_high, color=group_color, alpha=0.2, label=f'{label} 95% CI')

            ax2.plot(x, profile, color=group_color, label=label)
            ax2.fill_between(x, ci_low, ci_high, color=group_color, alpha=0.2)
        else:
            if verbose:
                print("Parameter-average method - no CIs calculated")

            ax1.plot(x, profile, color=group_color, label=label)
            ax2.plot(x, profile, color=group_color, label=label)

    # Configure main plot
    ax1.set_xlabel('Distance from pRF Center (Degrees)')
    ax1.set_ylabel(ylabel)
    title_parts = []
    if selected_subject != 'All':
        title_parts.append(f"Subject: {selected_subject}")
    if selected_region != 'All':
        title_parts.append(f"Region: {selected_region}")
    title_parts.append(f"Method: {'Vertex-then-average' if average_method == 'voxel_first' else 'Parameter-average'}")
    if apply_filter:
        title_parts.append(f"Eccentricity: {min_ecc}-{max_ecc} deg")
    group_names = ' vs '.join(name for name, _, _ in groups_to_show)
    fig.suptitle(f"PRF Profiles: {group_names} ({', '.join(title_parts)})")
    # Only draw a legend when something was plotted (avoids an empty-legend warning).
    if ax1.get_legend_handles_labels()[0]:
        ax1.legend()
    ax1.grid(True)

    # Configure zoomed plot; its label follows the active normalization.
    ax2.set_xlabel('Distance from pRF Center (Degrees)')
    ax2.set_ylabel(ylabel)
    ax2.set_title('Surround Response Detail')
    ax2.grid(True)
    ax2.set_xlim(xmin, xmax)
    ax2.set_ylim(ymin, ymax)

    # Zero lines for reference
    ax2.axhline(0, color='black', linestyle='--', linewidth=0.5)
    ax2.axvline(0, color='black', linestyle='--', linewidth=0.5)

    plt.tight_layout()
    plt.show()

# Lay out all controls in one vertical panel; related widgets share an HBox row.
controls = widgets.VBox([
    widgets.HBox([subject_widget, region_widget]),
    widgets.HBox([min_ecc_widget, max_ecc_widget]),
    normalization_widget,
    widgets.HBox([show_placebo_widget, show_memantine_widget]),
    average_method_widget,
    widgets.HBox([filter_widget, verbose_widget]),
    widgets.HTML("<h3>Surround Zooming Controls</h3>"),  # heading for the zoom sliders below
    widgets.HBox([xmin_widget, xmax_widget]),  # x-limits of the right-hand (zoomed) axes
    widgets.HBox([ymin_widget, ymax_widget])   # y-limits of the right-hand (zoomed) axes
])

# Map each update_plot keyword argument to the widget supplying its value;
# update_plot is re-run with the current values whenever a widget changes,
# and whatever it prints/plots is captured in `output`.
output = widgets.interactive_output(
    update_plot,
    {
        'selected_subject': subject_widget,
        'selected_region': region_widget,
        'normalization': normalization_widget,
        'show_placebo': show_placebo_widget,
        'show_memantine': show_memantine_widget,
        'average_method': average_method_widget,
        'apply_filter': filter_widget,
        'min_ecc': min_ecc_widget,
        'max_ecc': max_ecc_widget,
        'xmin': xmin_widget,
        'xmax': xmax_widget,
        'ymin': ymin_widget,
        'ymax': ymax_widget,
        'verbose': verbose_widget  # toggles the printed per-condition summary
    }
)

# Render the control panel above the plot output.
display(controls, output)