# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
import ipywidgets as widgets
from IPython.display import display

# Columns that load_and_prepare_data() temporarily uses as the index so each
# loaded table can be sorted on this key.
index_cols = ['subject', 'original_row', 'vertex_id', 'vertex_region']

def load_and_prepare_data(filepath):
    """Read a parquet file and return its rows ordered by the ``index_cols`` key.

    The table is temporarily indexed by ``index_cols`` so it can be sorted on
    that key; the key is then turned back into ordinary columns.
    """
    frame = pd.read_parquet(filepath)
    keyed = frame.set_index(index_cols)
    ordered = keyed.sort_index()
    return ordered.reset_index()

# Read the placebo and memantine parameter tables
plac = load_and_prepare_data("aggregate_fixed_placebo_iterparams_final.parquet.gzip")
mem = load_and_prepare_data("aggregate_fixed_memantine_iterparams_final.parquet.gzip")

# Tag every row with the treatment arm it belongs to
plac['group'] = 'placebo'
mem['group'] = 'memantine'

# Stack both arms; the short 'source' label becomes the pivot column below
tagged_frames = [plac.assign(source='plac'), mem.assign(source='mem')]
combined = pd.concat(tagged_frames)

# Keep only rows whose columns '2' and '6' (the two sigmas) are strictly positive
valid_mask = (combined['2'] > 0) & (combined['6'] > 0)
combined = combined[valid_mask]

# One row per (subject, region, original_row) with the plac/mem values side by
# side; rows missing a value in either arm are dropped
pivot_keys = ['subject', 'vertex_region', 'original_row']
pivot_values = ['0', '1', '2', '3', '5', '6', '9']
pivoted = combined.pivot_table(index=pivot_keys, columns='source', values=pivot_values)
pivoted = pivoted.dropna()

# Define filter function with adjustable eccentricity
def calculate_filter_mask(pivoted, min_ecc=0.5, max_ecc=4.5):
    """Build the boolean inclusion mask over the rows of ``pivoted``.

    A row passes only if, for both the 'plac' and the 'mem' columns:
      * column '9' is at least 0.3,
      * the eccentricity sqrt('0'**2 + '1'**2) lies strictly between
        ``min_ecc`` and ``max_ecc``,
      * column '3' is strictly positive,
      * column '2' is strictly smaller than column '6',
    and, in addition, the two '9' values differ by no more than 0.3.

    Returns a boolean Series aligned with ``pivoted``'s index.
    """
    def arm_passes(source):
        # Criteria that are evaluated independently for each source.
        eccentricity = np.sqrt(pivoted[('0', source)] ** 2 + pivoted[('1', source)] ** 2)
        return (
            (pivoted[('9', source)] >= 0.3)
            & (eccentricity > min_ecc)
            & (eccentricity < max_ecc)
            & (pivoted[('3', source)] > 0)
            & (pivoted[('2', source)] < pivoted[('6', source)])
        )

    r_gap = np.abs(pivoted[('9', 'mem')] - pivoted[('9', 'plac')])
    return arm_passes('mem') & arm_passes('plac') & (r_gap <= 0.3)

# Apply the inclusion criteria once with the default eccentricity window and
# keep, in each arm, only the rows whose key survived the filter
initial_filter_mask = calculate_filter_mask(pivoted)
valid_indices = pivoted.loc[initial_filter_mask].index
_row_key = ['subject', 'vertex_region', 'original_row']
plac_filtered = plac.set_index(_row_key).loc[valid_indices].reset_index()
mem_filtered = mem.set_index(_row_key).loc[valid_indices].reset_index()

# Define the Mexican Hat function with new normalization option
def DoG_profile(x, sigma1, sigma2, A1=1.0, A2=1.0, normalization='mean_percent', group=None, maxs_avg=None):
    """Evaluate a difference-of-Gaussians ("Mexican hat") profile at ``x``.

    The profile is ``A1 * N(x; sigma1) - A2 * N(x; sigma2)`` where ``N`` is a
    zero-centred Gaussian density. Both sigmas are floored at 1e-6 and any
    NaN/inf in the result is replaced via ``np.nan_to_num``.

    normalization:
      * 'mean_percent' - divide by the mean absolute value of the non-zero
        samples and multiply by 100;
      * 'max'          - divide by the maximum, unless that maximum is <= 0;
      * 'maxs_avg'     - divide by ``maxs_avg`` (raises ValueError if None);
      * anything else  - return the raw profile.

    ``group`` is accepted but not used.
    """
    floor = 0.000001
    sigma1, sigma2 = max(sigma1, floor), max(sigma2, floor)

    center = A1 * np.exp(-x**2 / (2 * sigma1**2)) / (np.sqrt(2 * np.pi) * sigma1)
    surround = A2 * np.exp(-x**2 / (2 * sigma2**2)) / (np.sqrt(2 * np.pi) * sigma2)
    dog = np.nan_to_num(center - surround)

    if normalization == 'maxs_avg':
        if maxs_avg is None:
            raise ValueError("For 'maxs_avg' normalization, maxs_avg parameter must be provided")
        return dog / maxs_avg

    # An all-zero profile is normalized by 1, i.e. left as zeros.
    has_signal = np.any(dog != 0)

    if normalization == 'mean_percent':
        scale = np.mean(np.abs(dog[dog != 0])) if has_signal else 1
        return (dog / scale) * 100

    if normalization == 'max':
        peak = np.max(dog) if has_signal else 1
        # A non-positive peak cannot be used as a divisor; return unscaled.
        return dog / peak if peak > 0 else dog

    return dog

def calculate_width_metrics(profile, x):
    """Return ``(fwhm, fwmin)`` for a sampled profile.

    FWHM is the distance along ``x`` between the first and the last sample
    whose value is at least half of the profile maximum.

    FWMIN is the distance along ``x`` between the lowest sample on the left
    of the peak and the lowest sample on the right of the peak. It is only
    defined when the profile has a negative component.

    Parameters
    ----------
    profile : array-like of float
        Profile values.
    x : array-like of float
        Sample positions, same length as ``profile``.

    Returns
    -------
    (float, float)
        ``fwhm`` and ``fwmin``. Either is NaN when it cannot be determined:
        both for an empty or NaN-containing profile, ``fwhm`` when no sample
        reaches half-maximum, ``fwmin`` when no sample is negative.
    """
    profile = np.asarray(profile, dtype=float)
    x = np.asarray(x, dtype=float)
    if profile.size == 0 or np.isnan(profile).any():
        return np.nan, np.nan

    peak_idx = int(np.argmax(profile))
    half_max = profile[peak_idx] / 2

    # Take the real outermost samples at/above half-maximum instead of
    # mirroring the left index, which is only valid for symmetric profiles.
    above_hm = np.flatnonzero(profile >= half_max)
    if above_hm.size:
        fwhm = np.abs(x[above_hm[-1]] - x[above_hm[0]])
    else:
        fwhm = np.nan

    # Without a negative lobe the "minimum" is just the tail of the profile,
    # so a trough-to-trough width is meaningless.
    if np.min(profile) < 0:
        left_min_idx = int(np.argmin(profile[:peak_idx + 1]))
        right_min_idx = peak_idx + int(np.argmin(profile[peak_idx:]))
        fwmin = np.abs(x[right_min_idx] - x[left_min_idx])
    else:
        fwmin = np.nan

    return fwhm, fwmin

def calculate_profiles_voxel_first(group_data, x_range=(-10, 10, 100), 
                                 normalization='mean_percent', ci_method='se', maxs_avg=None):
    """Compute one DoG profile per row, then average within each (subject, region).

    For every row of ``group_data`` a profile is evaluated with
    ``DoG_profile(x, row['2'], row['6'], row['3'], row['5'])`` (sigma1, sigma2,
    A1, A2). Rows whose profile cannot be evaluated are skipped.

    Parameters
    ----------
    group_data : DataFrame with columns 'subject', 'vertex_region', '2', '3',
        '5' and '6'.
    x_range : arguments forwarded to ``np.linspace`` to build the x axis.
    normalization : forwarded to ``DoG_profile``. With 'max' the averaged
        profile is additionally rescaled to peak 1, and the same factor is
        applied to the per-row profiles so the interval stays on the same scale.
    ci_method : 'percentile' for the 2.5/97.5 percentiles of the per-row
        profiles; anything else for mean +/- 1.96 * standard error.
    maxs_avg : forwarded to ``DoG_profile``.

    Returns
    -------
    DataFrame with one row per (subject, region) holding 'profile', 'ci_low',
    'ci_high', 'x', 'n_voxels', 'peak_values', 'fwhm' and 'fwmin'. Empty if
    no profile could be evaluated.
    """
    x = np.linspace(*x_range)
    profiles = []

    for (subject, region), sub_df in group_data.groupby(['subject', 'vertex_region']):
        voxel_profiles = []
        for _, row in sub_df.iterrows():
            try:
                profile = DoG_profile(
                    x,
                    row['2'],  # sigma1
                    row['6'],  # sigma2
                    row['3'],  # A1
                    row['5'],  # A2
                    normalization=normalization,
                    maxs_avg=maxs_avg
                )
            except Exception:
                # Best effort: skip rows that cannot be evaluated, but let
                # KeyboardInterrupt / SystemExit propagate.
                continue
            voxel_profiles.append(profile)

        if not voxel_profiles:
            continue

        voxel_profiles = np.asarray(voxel_profiles)
        avg_profile = np.nanmean(voxel_profiles, axis=0)

        # Re-normalize the averaged profile to peak 1 if normalization is max,
        # scaling the individual profiles identically so that the interval
        # computed from them matches the rescaled mean.
        if normalization == 'max':
            max_val = np.max(avg_profile) if np.any(avg_profile != 0) else 1
            if max_val > 0:
                avg_profile = avg_profile / max_val
                voxel_profiles = voxel_profiles / max_val

        if ci_method == 'percentile':
            ci_low = np.percentile(voxel_profiles, 2.5, axis=0)
            ci_high = np.percentile(voxel_profiles, 97.5, axis=0)
        else:  # standard error
            stderr = stats.sem(voxel_profiles, axis=0, nan_policy='omit')
            ci_low = avg_profile - 1.96 * stderr
            ci_high = avg_profile + 1.96 * stderr

        peak_idx = np.argmax(avg_profile)
        fwhm, fwmin = calculate_width_metrics(avg_profile, x)

        profiles.append({
            'subject': subject,
            'region': region,
            'profile': avg_profile,
            'ci_low': ci_low,
            'ci_high': ci_high,
            'x': x,
            'n_voxels': len(voxel_profiles),
            'peak_values': (ci_low[peak_idx], avg_profile[peak_idx], ci_high[peak_idx]),
            'fwhm': fwhm,
            'fwmin': fwmin
        })

    return pd.DataFrame(profiles)

def _stack_dog_profiles(rows_df, x, normalization, maxs_avg):
    """Evaluate one DoG profile per row of ``rows_df``; rows that fail are skipped."""
    stack = []
    for _, row in rows_df.iterrows():
        try:
            profile = DoG_profile(
                x,
                row['2'],  # sigma1
                row['6'],  # sigma2
                row['3'],  # A1
                row['5'],  # A2
                normalization=normalization,
                maxs_avg=maxs_avg
            )
        except Exception:
            # Best effort: skip rows that cannot be evaluated, but let
            # KeyboardInterrupt / SystemExit propagate.
            continue
        stack.append(profile)
    return stack


def calculate_profiles_voxel_first_subject_average(group_data, x_range=(-10, 10, 100), 
                                                 normalization='mean_percent', ci_method='se', maxs_avg=None):
    """Average per-row DoG profiles within region, then subject, then across subjects.

    Per-row profiles are averaged inside every (subject, region); the region
    means are averaged per subject; the subject means are averaged into the
    final profile.

    The interval is derived from the subject means when more than one subject
    contributed. With a single subject it is derived from the per-row
    profiles of the whole of ``group_data`` instead.

    Parameters
    ----------
    group_data : DataFrame with columns 'subject', 'vertex_region', '2', '3',
        '5' and '6'.
    x_range : arguments forwarded to ``np.linspace`` to build the x axis.
    normalization : forwarded to ``DoG_profile``. With 'max' the final profile
        is additionally rescaled to peak 1, and the same factor is applied to
        the profiles used for the interval.
    ci_method : 'percentile' for 2.5/97.5 percentiles; anything else for
        mean +/- 1.96 * standard error.
    maxs_avg : forwarded to ``DoG_profile``.

    Returns
    -------
    A one-row DataFrame with 'profile', 'ci_low', 'ci_high', 'x', 'n_voxels'
    (``len(group_data)``), 'peak_values', 'fwhm' and 'fwmin'; an empty
    DataFrame if no profile could be evaluated.
    """
    x = np.linspace(*x_range)
    subject_profiles = []

    for _, subject_df in group_data.groupby('subject'):
        region_profiles = []
        for _, region_df in subject_df.groupby('vertex_region'):
            voxel_profiles = _stack_dog_profiles(region_df, x, normalization, maxs_avg)
            if voxel_profiles:
                region_profiles.append(np.nanmean(voxel_profiles, axis=0))

        if region_profiles:
            # Average across regions for this subject
            subject_profiles.append(np.nanmean(region_profiles, axis=0))

    if not subject_profiles:
        return pd.DataFrame()

    # Now average across subjects
    final_profile = np.nanmean(subject_profiles, axis=0)

    # Pick the sample the interval is computed from: subject means when there
    # are several subjects, otherwise the single subject's per-row profiles.
    if len(subject_profiles) > 1:
        ci_samples = np.asarray(subject_profiles)
    else:
        ci_samples = np.asarray(_stack_dog_profiles(group_data, x, normalization, maxs_avg))

    # Re-normalize to peak 1 if normalization is max, scaling the interval
    # sample by the same factor so the band matches the rescaled mean.
    if normalization == 'max':
        max_val = np.max(final_profile) if np.any(final_profile != 0) else 1
        if max_val > 0:
            final_profile = final_profile / max_val
            ci_samples = ci_samples / max_val

    if ci_method == 'percentile':
        ci_low = np.percentile(ci_samples, 2.5, axis=0)
        ci_high = np.percentile(ci_samples, 97.5, axis=0)
    else:  # standard error
        stderr = stats.sem(ci_samples, axis=0, nan_policy='omit')
        ci_low = final_profile - 1.96 * stderr
        ci_high = final_profile + 1.96 * stderr

    peak_idx = np.argmax(final_profile)
    fwhm, fwmin = calculate_width_metrics(final_profile, x)

    return pd.DataFrame([{
        'profile': final_profile,
        'ci_low': ci_low,
        'ci_high': ci_high,
        'x': x,
        'n_voxels': len(group_data),
        'peak_values': (ci_low[peak_idx], final_profile[peak_idx], ci_high[peak_idx]),
        'fwhm': fwhm,
        'fwmin': fwmin
    }])

def calculate_profiles_param_first(group_data, x_range=(-10, 10, 100), normalization='mean_percent', maxs_avg=None):
    """Average the parameters within each (subject, region), then build one profile each.

    For every (subject, region) group the means of columns '2', '6', '3' and
    '5' are passed to ``DoG_profile`` as sigma1, sigma2, A1 and A2. Groups
    whose profile cannot be evaluated are skipped.

    Parameters
    ----------
    group_data : DataFrame with columns 'subject', 'vertex_region', '2', '3',
        '5' and '6'.
    x_range : arguments forwarded to ``np.linspace`` to build the x axis.
    normalization, maxs_avg : forwarded to ``DoG_profile``. With 'max', an
        average over several groups is rescaled to peak 1.

    Returns
    -------
    * no valid group  -> empty DataFrame;
    * one valid group -> one row with 'subject', 'region', 'profile', 'x',
      'n_voxels', 'fwhm', 'fwmin';
    * several groups  -> one row with the mean profile, 'x', the summed
      'n_voxels', and 'fwhm'/'fwmin' of the mean profile.
    """
    x = np.linspace(*x_range)
    profiles = []

    for (subject, region), sub_df in group_data.groupby(['subject', 'vertex_region']):
        if len(sub_df) == 0:
            continue

        sigma1 = np.mean(sub_df['2'])
        sigma2 = np.mean(sub_df['6'])
        amp1 = np.mean(sub_df['3'])
        amp2 = np.mean(sub_df['5'])

        try:
            profile = DoG_profile(
                x, sigma1, sigma2, amp1, amp2,
                normalization=normalization,
                maxs_avg=maxs_avg
            )
            fwhm, fwmin = calculate_width_metrics(profile, x)
        except Exception:
            # Best effort: skip groups that cannot be evaluated, but let
            # KeyboardInterrupt / SystemExit propagate.
            continue

        profiles.append({
            'subject': subject,
            'region': region,
            'profile': profile,
            'x': x,
            'n_voxels': len(sub_df),
            'fwhm': fwhm,
            'fwmin': fwmin
        })

    if not profiles:
        return pd.DataFrame()
    if len(profiles) == 1:
        return pd.DataFrame(profiles)

    # Several subjects/regions selected: average their profiles
    avg_profile = np.mean([p['profile'] for p in profiles], axis=0)

    # Re-normalize averaged profile to peak 1 if normalization is max
    if normalization == 'max':
        max_val = np.max(avg_profile) if np.any(avg_profile != 0) else 1
        if max_val > 0:
            avg_profile = avg_profile / max_val

    fwhm, fwmin = calculate_width_metrics(avg_profile, x)

    return pd.DataFrame([{
        'profile': avg_profile,
        'x': x,
        'n_voxels': sum(p['n_voxels'] for p in profiles),
        'fwhm': fwhm,
        'fwmin': fwmin
    }])

# Choices offered by the selection widgets
subject_options = sorted(plac['subject'].unique().tolist())
region_options = sorted(plac['vertex_region'].unique().tolist())
normalization_options = ['mean_percent', 'max', 'maxs_avg', 'none']
average_method_options = [
    ('Vertex pRFs, then avg pRFs', 'voxel_first'),
    ('Avg parameters, then pRF', 'param_first'),
    ('Vertex pRFs, then avg subject, then avg subjects', 'voxel_first_subject_avg'),
]
ci_method_options = [
    ('Standard Error', 'se'),
    ('Percentile (95% CI)', 'percentile'),
]


def _float_slider(description, value, lower, upper):
    """Float slider with 0.5 steps that only reports a value once released."""
    return widgets.FloatSlider(
        value=value, min=lower, max=upper, step=0.5,
        description=description, disabled=False, continuous_update=False,
    )


def _bounded_float(description, value, lower, upper):
    """Bounded numeric text box with 0.01 steps."""
    return widgets.BoundedFloatText(
        value=value, min=lower, max=upper, step=0.01,
        description=description, disabled=False, continuous_update=False,
    )


def _multi_select(description, options):
    """Multi-selection list with the first option pre-selected."""
    return widgets.SelectMultiple(
        options=options, value=[options[0]], description=description, disabled=False,
    )


def _radio(description, options, value):
    """Radio-button group with the given initial value."""
    return widgets.RadioButtons(
        options=options, value=value, description=description, disabled=False,
    )


def _checkbox(description):
    """Checkbox that starts ticked."""
    return widgets.Checkbox(value=True, description=description, disabled=False)


# Eccentricity filter widgets
min_ecc_widget = _float_slider('Min Eccentricity:', 0.5, 0.0, 4.5)
max_ecc_widget = _float_slider('Max Eccentricity:', 4.5, 0.0, 4.5)

# Zoom control widgets
xmin_widget = _float_slider('X-min:', -7.5, -10, 0)
xmax_widget = _float_slider('X-max:', 7.5, 0, 10)
ymin_widget = _bounded_float('Y-min:', -0.02, -0.3, 0)
ymax_widget = _bounded_float('Y-max:', 0.01, 0, 0.3)

# Data selection widgets
subject_widget = _multi_select('Subject:', subject_options)
region_widget = _multi_select('Region:', region_options)

# Analysis option widgets
normalization_widget = _radio('Normalization:', normalization_options, 'maxs_avg')
average_method_widget = _radio('Averaging Method:', average_method_options, 'voxel_first_subject_avg')
ci_method_widget = _radio('CI Method:', ci_method_options, 'se')

# Display toggles
show_placebo_widget = _checkbox('Show Placebo')
show_memantine_widget = _checkbox('Show Memantine')
filter_widget = _checkbox('Apply Inclusion Filter')
verbose_widget = _checkbox('Verbose pRF information')
show_fwhm_widget = _checkbox('Show FWHM lines')
show_fwmin_widget = _checkbox('Show FWMIN lines')

def update_plot(selected_subject, selected_region, normalization, show_placebo, show_memantine, 
                average_method, apply_filter, min_ecc, max_ecc,
                xmin, xmax, ymin, ymax, verbose, ci_method, show_fwhm, show_fwmin):
    """Recompute the selected profiles and redraw both panels.

    Steps:
      1. If ``apply_filter`` is set, rebuild the module-level
         ``plac_filtered`` / ``mem_filtered`` tables for the current
         eccentricity window; otherwise use the unfiltered tables.
      2. Restrict the data to ``selected_subject`` / ``selected_region``.
      3. For 'maxs_avg' normalization, compute the divisor as the mean of the
         un-normalized profile maxima of every shown group that yields a
         profile (this also works when only one group is shown).
      4. Compute the profile table per shown group with the chosen
         ``average_method`` and draw the FIRST row of that table on the main
         panel (left) and the zoomed panel (right), with the CI band when the
         method provides one and FWHM / FWMIN markers when ``show_fwhm`` /
         ``show_fwmin`` are ticked.
      5. If ``verbose`` is set, print a side-by-side text summary.

    ``xmin``/``xmax``/``ymin``/``ymax`` set the limits of the zoomed panel.
    Returns None; output is the printed text and the matplotlib figure.
    """
    global plac_filtered, mem_filtered

    def compute_profiles(data, norm, scale=None):
        # Dispatch to the profile calculator matching the averaging method.
        if average_method == 'voxel_first':
            return calculate_profiles_voxel_first(
                data, normalization=norm, ci_method=ci_method, maxs_avg=scale)
        if average_method == 'voxel_first_subject_avg':
            return calculate_profiles_voxel_first_subject_average(
                data, normalization=norm, ci_method=ci_method, maxs_avg=scale)
        return calculate_profiles_param_first(data, normalization=norm, maxs_avg=scale)

    # Recalculate the inclusion filter for the current eccentricity range
    if apply_filter:
        current_filter_mask = calculate_filter_mask(pivoted, min_ecc, max_ecc)
        valid_indices = pivoted[current_filter_mask].index
        row_key = ['subject', 'vertex_region', 'original_row']
        plac_filtered = plac.set_index(row_key).loc[valid_indices].reset_index()
        mem_filtered = mem.set_index(row_key).loc[valid_indices].reset_index()
        plac_data, mem_data = plac_filtered, mem_filtered
    else:
        plac_data, mem_data = plac, mem

    # Apply subject and region filters BEFORE calculating maxs_avg
    if selected_subject:
        plac_data = plac_data[plac_data['subject'].isin(selected_subject)]
        mem_data = mem_data[mem_data['subject'].isin(selected_subject)]
    if selected_region:
        plac_data = plac_data[plac_data['vertex_region'].isin(selected_region)]
        mem_data = mem_data[mem_data['vertex_region'].isin(selected_region)]

    groups_to_show = []
    if show_placebo:
        groups_to_show.append(('placebo', 'blue', plac_data))
    if show_memantine:
        groups_to_show.append(('memantine', 'red', mem_data))

    if not groups_to_show:
        print("Please select at least one group to show")
        return

    # Divisor for 'maxs_avg': mean of the raw profile maxima of the shown
    # groups. Computed from whichever shown groups produce a profile, so the
    # normalization keeps working when a single group is displayed.
    maxs_avg = None
    if normalization == 'maxs_avg':
        max_values = []
        for _, _, group_data in groups_to_show:
            raw_profiles = compute_profiles(group_data, 'None')
            if not raw_profiles.empty:
                max_values.append(np.max(raw_profiles['profile'].iloc[0]))
        if max_values:
            maxs_avg = np.mean(max_values)

    # Prepare verbose output
    verbose_output = {"placebo": [], "memantine": []}
    param_values = {"placebo": {}, "memantine": {}}  # To store parameter values

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

    has_ci = average_method in ['voxel_first', 'voxel_first_subject_avg']

    for group_name, group_color, group_data in groups_to_show:
        if group_data.empty:
            print(f"No data available for {group_name} with current filters")
            continue

        profiles_df = compute_profiles(group_data, normalization, maxs_avg)

        if profiles_df.empty:
            print(f"No valid profiles calculated for {group_name} with current filters")
            continue

        label = 'Placebo' if group_name == 'placebo' else 'Memantine'

        # Only the first row of the profile table is displayed
        x = profiles_df['x'].iloc[0]
        profile = profiles_df['profile'].iloc[0]
        fwhm = profiles_df['fwhm'].iloc[0]
        fwmin = profiles_df['fwmin'].iloc[0]

        profile_max = np.max(profile)
        profile_min = np.min(profile)
        max_min_ratio = profile_max / abs(profile_min) if profile_min != 0 else float('inf')

        # Parameter averages taken from the underlying rows
        param_values[group_name] = {
            'max': profile_max,
            'min': profile_min,
            'max_min_ratio': max_min_ratio,
            'beta1': group_data['3'].mean(),
            'beta2': group_data['5'].mean(),
            'sigma1': group_data['2'].mean(),
            'sigma2': group_data['6'].mean(),
            'fwhm': fwhm,
            'fwmin': fwmin
        }

        if verbose:
            lines = verbose_output[group_name]
            lines.append(f"--- {label} Profile Information ---")
            lines.append(f"Number of Vertices Plotted: {profiles_df['n_voxels'].sum()}")
            if average_method == 'voxel_first_subject_avg':
                lines.append(f"Number of Subjects: {len(selected_subject) if selected_subject else 'All'}")
            lines.append(f"FWHM: {fwhm:.2f} degrees")
            if not np.isnan(fwmin):
                lines.append(f"FWMIN: {fwmin:.2f} degrees")
            else:
                lines.append("FWMIN: Not applicable (no negative component)")

            if has_ci:
                peak_idx = np.argmax(profile)
                peak_x = x[peak_idx]
                peak_low = profiles_df['ci_low'].iloc[0][peak_idx]
                peak_high = profiles_df['ci_high'].iloc[0][peak_idx]
                lines.append(f"At peak response (x={peak_x:.1f}):")
                lines.append(f"  Lower CI: {peak_low:.4f}")
                lines.append(f"  Mean:     {profile[peak_idx]:.4f}")
                lines.append(f"  Upper CI: {peak_high:.4f}")
                lines.append(f"  CI Range: {peak_high - peak_low:.4f}")

        # Profile curves (and CI bands where available) on both panels
        ax1.plot(x, profile, color=group_color, label=label)
        ax2.plot(x, profile, color=group_color)
        if has_ci:
            ci_low = profiles_df['ci_low'].iloc[0]
            ci_high = profiles_df['ci_high'].iloc[0]
            ax1.fill_between(x, ci_low, ci_high, color=group_color, alpha=0.2, label=f'{label} 95% CI')
            ax2.fill_between(x, ci_low, ci_high, color=group_color, alpha=0.2)

        # FWHM marker: spans the outermost samples at/above half-maximum
        if show_fwhm and not np.isnan(fwhm):
            half_max = profile_max / 2
            above_hm = np.flatnonzero(profile >= half_max)
            left_x = x[above_hm[0]]
            right_x = x[above_hm[-1]]

            ax1.plot([left_x, right_x], [half_max, half_max], color=group_color, linestyle=':', alpha=0.5)
            ax1.text((left_x + right_x)/2, half_max, f'FWHM={fwhm:.1f}°', 
                    ha='center', va='bottom', color=group_color)

            ax2.plot([left_x, right_x], [half_max, half_max], color=group_color, linestyle=':', alpha=0.5)

        # FWMIN marker: drawn only when requested (show_fwmin) and defined
        if show_fwmin and not np.isnan(fwmin):
            min_val = profile_min

            # First sample at the minimum, and its mirror image on the other side
            left_min_idx = np.argmax(profile <= min_val)
            right_min_idx = len(profile) - left_min_idx - 1
            left_min_x = x[left_min_idx]
            right_min_x = x[right_min_idx]

            ax1.plot([left_min_x, right_min_x], [min_val, min_val], color=group_color, linestyle='--', alpha=0.5)
            ax1.text((left_min_x + right_min_x)/2, min_val, f'FWMIN={fwmin:.1f}°', 
                    ha='center', va='top', color=group_color)

            ax2.plot([left_min_x, right_min_x], [min_val, min_val], color=group_color, linestyle='--', alpha=0.5)
            ax2.text((left_min_x + right_min_x)/2, min_val, f'FWMIN={fwmin:.1f}°', ha='center', va='top', color=group_color)

    # Print verbose output side by side with additional parameters
    if verbose:
        for group in ['placebo', 'memantine']:
            params = param_values.get(group)
            if params:
                verbose_output[group].append(f"  Profile Max: {params['max']:.4f}")
                verbose_output[group].append(f"  Profile Min: {params['min']:.4f}")
                verbose_output[group].append(f"  Max/Min Ratio: {params['max_min_ratio']:.4f}")
                verbose_output[group].append(f"  Beta1 (Center): {params['beta1']:.4f}")
                verbose_output[group].append(f"  Beta2 (Surround): {params['beta2']:.4f}")
                verbose_output[group].append(f"  Sigma1 (Center): {params['sigma1']:.4f}")
                verbose_output[group].append(f"  Sigma2 (Surround): {params['sigma2']:.4f}")

        placebo_lines = verbose_output['placebo']
        memantine_lines = verbose_output['memantine']

        # Pad the shorter column so both can be zipped line by line
        max_lines = max(len(placebo_lines), len(memantine_lines))
        placebo_lines += [''] * (max_lines - len(placebo_lines))
        memantine_lines += [''] * (max_lines - len(memantine_lines))

        print("\n" + "="*80)
        print("COMPARISON OF PROFILE INFORMATION".center(80))
        print("="*80)
        print(f"{'PLACEBO':<40}{'MEMANTINE':<40}")
        print("-"*80)
        for p_line, m_line in zip(placebo_lines, memantine_lines):
            print(f"{p_line:<40}{m_line:<40}")

    # Configure main plot
    ylabel_by_normalization = {
        'mean_percent': 'BOLD Response (% of mean)',
        'max': 'BOLD Response (normalized)',
        'maxs_avg': 'BOLD Response (normalized by avg max)',
    }
    method_names = {
        'voxel_first': 'Vertex-then-average',
        'voxel_first_subject_avg': 'Vertex-then-subject-average',
    }
    ax1.set_xlabel('Distance from pRF Center (Degrees)')
    ax1.set_ylabel(ylabel_by_normalization.get(normalization, 'BOLD Response'))
    title_parts = []
    if selected_subject:
        title_parts.append(f"Subjects: {', '.join(map(str, selected_subject))}")
    if selected_region:
        title_parts.append(f"Regions: {', '.join(selected_region)}")
    title_parts.append(f"Method: {method_names.get(average_method, 'Parameter-average')}")
    if apply_filter:
        title_parts.append(f"Eccentricity: {min_ecc}-{max_ecc} deg")
    fig.suptitle('PRF Profiles: ' + (' vs '.join([g[0] for g in groups_to_show]) + (' (' + ', '.join(title_parts) + ')' if title_parts else '')))
    ax1.legend()
    ax1.grid(True)

    # Configure zoomed plot; its limits come straight from the zoom widgets
    ax2.set_xlabel('Distance from pRF Center (Degrees)')
    ax2.set_ylabel('BOLD Response')
    ax2.set_title('Surround Response Detail')
    ax2.grid(True)
    ax2.set_xlim(xmin, xmax)
    ax2.set_ylim(ymin, ymax)

    # Add zero lines for reference
    ax2.axhline(0, color='black', linestyle='--', linewidth=0.5)
    ax2.axvline(0, color='black', linestyle='--', linewidth=0.5)

    plt.tight_layout()
    plt.show()

# Lay the controls out top to bottom, pairing related widgets on one row
_control_rows = [
    widgets.HBox([subject_widget, region_widget]),
    widgets.HBox([min_ecc_widget, max_ecc_widget]),
    normalization_widget,
    widgets.HBox([show_placebo_widget, show_memantine_widget]),
    average_method_widget,
    ci_method_widget,
    widgets.HBox([filter_widget, verbose_widget]),
    widgets.HBox([show_fwhm_widget, show_fwmin_widget]),
    widgets.HTML("<h3>Surround Zooming Controls</h3>"),
    widgets.HBox([xmin_widget, xmax_widget]),
    widgets.HBox([ymin_widget, ymax_widget]),
]
controls = widgets.VBox(_control_rows)

# Map every update_plot() argument name to the widget that supplies its value
_plot_inputs = dict(
    selected_subject=subject_widget,
    selected_region=region_widget,
    normalization=normalization_widget,
    show_placebo=show_placebo_widget,
    show_memantine=show_memantine_widget,
    average_method=average_method_widget,
    apply_filter=filter_widget,
    min_ecc=min_ecc_widget,
    max_ecc=max_ecc_widget,
    xmin=xmin_widget,
    xmax=xmax_widget,
    ymin=ymin_widget,
    ymax=ymax_widget,
    verbose=verbose_widget,
    ci_method=ci_method_widget,
    show_fwhm=show_fwhm_widget,
    show_fwmin=show_fwmin_widget,
)
output = widgets.interactive_output(update_plot, _plot_inputs)

display(controls, output)