## Setup

### Initial Imports

In [None]:
import os
import pandas as pd
import numpy as np
import sys
from pathlib import Path

# Ensure the parent directory is in the system path for module imports
sys.path.append(str(Path.cwd().parent))

from dataclasses import dataclass
from typing import List, Optional

from pytest import param
from zmq import has

import scipy
from scipy import integrate, interpolate

import numpy as np
import matplotlib.pyplot as plt
import matplotx
from matplotlib.colors import to_rgb

### Plot Setup

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotx

# Global matplotlib defaults applied to every figure of this notebook.
# plt.style.use('science')  # optional style sheet, currently disabled
plt.rcParams.update({
    'figure.figsize': (10, 6),    # default figure size (inches)
    'figure.dpi': 300,            # default figure resolution
    'font.size': 12,              # base font size
    'lines.linewidth': 2,         # default line width
    'axes.labelsize': 14,         # axis label size
    'axes.titlesize': 16,         # axes title size
    'xtick.labelsize': 12,        # x-tick label size
    'ytick.labelsize': 12,        # y-tick label size
    'legend.fontsize': 12,        # legend font size
    'figure.titlesize': 18,       # suptitle size
})

## Gathering subjects' data

### Data Structures

Data structures used to store each subject's data.

In [None]:
from dataclasses import dataclass
import pandas as pd
import numpy as np

@dataclass
class SubjectData:
    """Container for everything gathered about one subject/session.

    Every field defaults to ``None`` and is filled in progressively by the
    loading and feature cells of this notebook. ``_X`` / ``_Y`` suffixes refer
    to the horizontal and vertical meridian respectively.
    """
    # Identification
    name: Optional[str] = None        # e.g. 'Subject10'
    pid: Optional[str] = None         # e.g. 'AOHC_10'
    nb: Optional[int] = None          # subject number
    session: Optional[str] = None     # e.g. 'Session1'

    # Cone-density shape features per direction (read from densities.csv)
    width_nas: Optional[float] = None
    width_tem: Optional[float] = None
    width_inf: Optional[float] = None
    width_sup: Optional[float] = None
    max_slope_nas: Optional[float] = None
    max_slope_tem: Optional[float] = None
    max_slope_inf: Optional[float] = None
    max_slope_sup: Optional[float] = None

    # Foveal-pit shape features from the 3D OCT fit (fovea_3d_fitted_params.csv)
    oct_bump_X: Optional[float] = None       # fitted coefficient A20
    oct_bump_Y: Optional[float] = None       # fitted coefficient A02
    oct_width_X: Optional[float] = None      # in degrees
    oct_width_Y: Optional[float] = None      # in degrees
    oct_max_slope: Optional[float] = None
    oct_depth: Optional[float] = None        # in mm
    oct_flatness: Optional[float] = None

    # Clinical metadata (from the cohort spreadsheet, when available)
    age: Optional[float] = None              # in years
    axial_length: Optional[float] = None     # in mm
    spherical_equiv: Optional[float] = None
    sex: Optional[int] = None                # 1 for 'F', 0 otherwise

    # Eccentricity grid and cone-density profiles sampled on it
    eccs: Optional[np.ndarray] = None
    density_X: Optional[pd.Series] = None        # smoothed density
    density_Y: Optional[pd.Series] = None
    density_fit_X: Optional[pd.Series] = None    # fitted density model
    density_fit_Y: Optional[pd.Series] = None

    # Layer-thickness / CVI profiles (results.csv, restricted to |ecc| <= 10)
    cvi_X: Optional[pd.Series] = None
    cvi_Y: Optional[pd.Series] = None
    gcl_ipl_X: Optional[pd.Series] = None
    gcl_ipl_Y: Optional[pd.Series] = None
    onl_X: Optional[pd.Series] = None
    onl_Y: Optional[pd.Series] = None
    inl_opl_X: Optional[pd.Series] = None
    inl_opl_Y: Optional[pd.Series] = None
    rnfl_X: Optional[pd.Series] = None
    rnfl_Y: Optional[pd.Series] = None
    chrd_X: Optional[pd.Series] = None
    chrd_Y: Optional[pd.Series] = None
    pr_rpe_X: Optional[pd.Series] = None
    pr_rpe_Y: Optional[pd.Series] = None
    os_X: Optional[pd.Series] = None
    os_Y: Optional[pd.Series] = None

    # Derived: number of cones within a disk around the fovea
    nb_cones: Optional[float] = None         # from the smoothed profiles
    nb_cones_fit: Optional[float] = None     # from the fitted profiles

    # Derived: GCL+IPL pit width per meridian and minimum layer thickness
    width_gcl_X: Optional[float] = None
    width_gcl_Y: Optional[float] = None
    min_thick_gcl: Optional[float] = None



@dataclass
class FoveaParams:
    """Fitted 3D foveal-pit parameters of one subject/session.

    Instances are created by ``extract_fovea_data`` from a
    ``fovea_3d_fitted_params.csv`` file: the patient-information fields are
    passed to the constructor, and the numeric fields are then filled from the
    CSV rows. A row whose name matches a field (e.g. ``A00``) sets that field;
    a row such as ``depth`` or ``width_X`` sets the ``foveal_``-prefixed field.
    """
    # Patient information
    subject: str            # e.g. "Subject10"
    patient_id: str         # subject number as a string, e.g. "10"
    subject_folder: str     # name of the subject directory
    trial_name: str         # name of the session directory ("Session...")
    age: Optional[int] = None

    # Fitted parameters (None until read from the CSV)
    # A00..A11: coefficients of the fitted surface; A20 / A02 are used
    # elsewhere as the X / Y "bump" features. Exact model form: TODO confirm.
    A00: Optional[float] = None
    A10: Optional[float] = None
    A01: Optional[float] = None
    A20: Optional[float] = None
    A02: Optional[float] = None
    A11: Optional[float] = None
    foveal_depth: Optional[float] = None        # treated as mm by the loading cell
    foveal_center_X: Optional[float] = None
    foveal_width_X: Optional[float] = None
    foveal_center_Y: Optional[float] = None
    foveal_width_Y: Optional[float] = None
    foveal_max_slope: Optional[float] = None
    foveal_flatness: Optional[float] = None
    foveal_volume: Optional[float] = None



In [None]:

# Debug switch: restrict the run to a handful of subjects so that the model can
# be tested without re-running the whole pipeline for every subject.
#
# Subject folders are listed in string (lexicographic) order -- "Subject10",
# "Subject100", "Subject101", ... -- so the "first five" below follow that
# order rather than numeric order. The list is used later to pick these
# subjects out of `subject_data`.
take_first_five: bool = False
first_five_subjects: List[str] = [
    "Subject10",
    "Subject100",
    "Subject101",
    "Subject104",
    "Subject105",
]

### Function definitions

#### Foveal Data Extraction

In [None]:
def _set_fovea_param(fovea_obj, param_name, param_value, patient_id) -> None:
    """Store one CSV row on `fovea_obj` as a float.

    The value goes to the dataclass field named `param_name` if there is one,
    otherwise to `foveal_<param_name>` (e.g. 'depth' -> 'foveal_depth'). Only
    that single field is written, so no ad-hoc attributes are created. Unknown
    names and non-numeric values are reported and skipped.
    """
    from dataclasses import fields

    # Restrict to declared fields: hasattr() would also match methods/dunders.
    field_names = {f.name for f in fields(fovea_obj)}
    if param_name in field_names:
        target = param_name
    elif f"foveal_{param_name}" in field_names:
        target = f"foveal_{param_name}"
    else:
        print(f"Skipping parameter {param_name} for patient {patient_id} ")
        return

    try:
        value = float(param_value)
    except (TypeError, ValueError):
        print(f"Error setting {param_name} for patient {patient_id}: {param_value}")
        return
    setattr(fovea_obj, target, value)
    print(f"Successfully set {target} for patient {patient_id}")


def extract_fovea_data(base_path: str) -> List[FoveaParams]:
    """
    Extract fovea parameters from CSV files with known structure.

    Expected layout:
        base_path/Subject<N>/Session<...>/layer_new/fovea_3d_fitted_params.csv
    Each CSV is ';'-separated with two columns (parameter name, value). The row
    whose value is the literal 'params' (the file's header row) is skipped.

    Args:
        base_path: Path to the base directory containing the subject folders.

    Returns:
        One FoveaParams object per CSV file found, ordered by subject and
        session folder name; an empty list if `base_path` does not exist.
    """
    import re

    subject_pattern = re.compile(r'Subject(\d+)')
    fovea_data = []

    # Only subject directories (names starting with "Subject"), sorted so the
    # result does not depend on os.listdir's arbitrary order.
    try:
        subject_dirs = sorted(
            d for d in os.listdir(base_path)
            if os.path.isdir(os.path.join(base_path, d)) and d.startswith("Subject")
        )
    except FileNotFoundError:
        print(f"Base path not found: {base_path}")
        return []

    for subject_dir in subject_dirs:
        subject_match = subject_pattern.search(subject_dir)
        if not subject_match:
            continue
        patient_id = subject_match.group(1)
        subject_path = os.path.join(base_path, subject_dir)

        try:
            session_dirs = sorted(
                d for d in os.listdir(subject_path)
                if os.path.isdir(os.path.join(subject_path, d)) and d.startswith("Session")
            )
        except FileNotFoundError:
            continue

        for session_dir in session_dirs:
            csv_path = os.path.join(
                subject_path, session_dir, "layer_new", "fovea_3d_fitted_params.csv"
            )
            if not os.path.isfile(csv_path):
                continue

            df = pd.read_csv(csv_path, sep=';', header=None, names=['param', 'value'])
            fovea_obj = FoveaParams(
                patient_id=patient_id,
                subject=f"Subject{patient_id}",
                subject_folder=subject_dir,
                trial_name=session_dir,
            )
            print(f"Processing file: {csv_path} for patient {patient_id}")

            for param_name, param_value in zip(df['param'], df['value']):
                if param_value == "params":
                    # header row of the CSV
                    print(f"Skipping parameter {param_name} for patient {patient_id} as it is 'params'")
                    continue
                _set_fovea_param(fovea_obj, param_name, param_value, patient_id)

            fovea_data.append(fovea_obj)

    return fovea_data

def save_to_dataframe(fovea_data: List[FoveaParams], output_file: str = "fovea_parameters.csv") -> pd.DataFrame:
    """
    Convert FoveaParams objects to a DataFrame and save it to CSV.

    The DataFrame has one row per object and one column per declared dataclass
    field, in field order.

    Args:
        fovea_data: List of FoveaParams objects.
        output_file: Path of the CSV file to write (written without the index).

    Returns:
        DataFrame containing all fovea data.
    """
    from dataclasses import asdict

    # asdict() exports only the declared fields, unlike vars(): any ad-hoc
    # attribute set on an instance does not turn into a spurious column.
    df = pd.DataFrame([asdict(f) for f in fovea_data])
    df.to_csv(output_file, index=False)
    return df



## Loading data

### Imports


In [None]:
from pathlib import Path
from typing import List, Tuple, Dict




from src.cell.analysis.constants import MM_PER_DEGREE
from src.cell.layer.helpers import gaussian_filter_nan
from src.configs.parser import Parser

### Loading

In [None]:
Parser.initialize()

# processed.txt lists one "<subject number> <session number>" pair per line;
# blank lines are ignored.
with open('../src/processed.txt') as processed_file:
    subjects_sessions = [
        [int(n) for n in line.split()] for line in processed_file if line.strip()
    ]

# Subject metadata from the cohort spreadsheet. The workbook path is
# machine-specific, so loading is best-effort: on any failure all four
# dictionaries are left empty and the corresponding SubjectData fields stay None
# (instead of the loop below failing on undefined dictionaries).
try:
    with pd.ExcelFile(r'V:\Studies\AOSLO\data\cohorts\AOSLO healthy\DATA_HC+DM.xlsx') as workbook:
        sheet = workbook.parse('Healthy', header=0, nrows=45, index_col=0)
    sheet.index = sheet.index.map(lambda x: f'Subject{x}')
    age_dict = ((sheet['Date of visit'] - sheet['DDN']).dt.days / 365).to_dict()  # in years
    # Value of the examined eye: the '... D' column when Laterality is 'OD', the '... G' column otherwise.
    axial_dict = sheet['AL D (mm)'].where(sheet['Laterality'] == 'OD', sheet['AL G (mm)']).to_dict()
    spherical_dict = sheet['Equi Sph D'].where(sheet['Laterality'] == 'OD', sheet['Equi Sph G']).to_dict()
    sex_dict = sheet['Sexe'].map(lambda x: 1 if x == 'F' else 0).to_dict()  # 1 = 'F', 0 otherwise
except Exception as err:
    print(f'Could not load subject metadata ({err!r}); metadata fields will be left as None.')
    age_dict, axial_dict, spherical_dict, sex_dict = {}, {}, {}, {}

base_path = Path(r'P:\AOSLO\_automation\_PROCESSED\Photoreceptors\Healthy\_Results')

# Subjects whose OCTs are tilted (white dot not well aligned with the PR+RPE peak);
# see the explanation in `PRxRLT_expmanual.ipynb`.
oct_to_exclude = {
    13, 18, 20, 25, 26, 30, 35, 42, 46, 66, 100, 105,
}


subjects_data: List[SubjectData] = []
for subject_n, session_n in subjects_sessions:
    if subject_n in oct_to_exclude:
        continue

    sd = SubjectData()
    sd.name = f'Subject{subject_n}'
    sd.pid = f'AOHC_{subject_n}'
    sd.nb = subject_n
    sd.session = f'Session{session_n}'

    path = base_path / sd.name / sd.session
    print(f'Loading {sd.name} {sd.session}...')

    # Subject metadata (None when the spreadsheet is unavailable or lacks the subject)
    sd.age = age_dict.get(sd.name)
    sd.axial_length = axial_dict.get(sd.name)
    sd.spherical_equiv = spherical_dict.get(sd.name)
    sd.sex = sex_dict.get(sd.name)

    # record foveal shape parameters (populated by `src/save_layer_features.ipynb`)
    df_oct = pd.read_csv(path / Parser.get_layer_thickness_dir() / 'fovea_3d_fitted_params.csv', sep=';', index_col=0)
    sd.oct_bump_X = df_oct.loc['A20', 'params']
    sd.oct_bump_Y = df_oct.loc['A02', 'params']
    # NOTE(review): origin of the sqrt(2 * 2.8) width scaling factor is not documented here -- confirm.
    sd.oct_width_X = df_oct.loc['width_X', 'params'] * np.sqrt(2 * 2.8) / MM_PER_DEGREE # in °
    sd.oct_width_Y = df_oct.loc['width_Y', 'params'] * np.sqrt(2 * 2.8) / MM_PER_DEGREE # in °
    sd.oct_max_slope = df_oct.loc['max_slope', 'params']
    sd.oct_depth = df_oct.loc['depth', 'params'] # in mm
    sd.oct_flatness = df_oct.loc['flatness', 'params']

    # record cone density and fitted parameters (populated by `src/cell/analysis/density_analysis_pipeline_manager.py`)
    df_density = pd.read_csv(path / Parser.get_density_analysis_dir() / 'densities.csv', sep=';', index_col=0)
    df_raw_density_x = pd.read_csv(path / Parser.get_density_analysis_dir() / 'densities_raw_x.csv', sep=';', index_col=0)
    df_raw_density_y = pd.read_csv(path / Parser.get_density_analysis_dir() / 'densities_raw_y.csv', sep=';', index_col=0)

    sd.width_nas = df_density['width_nasal'].iloc[0]
    sd.width_tem = df_density['width_temporal'].iloc[0]
    sd.width_inf = df_density['width_inferior'].iloc[0]
    sd.width_sup = df_density['width_superior'].iloc[0]
    sd.max_slope_nas = df_density['max_slope_nasal'].iloc[0]
    sd.max_slope_tem = df_density['max_slope_temporal'].iloc[0]
    sd.max_slope_inf = df_density['max_slope_inferior'].iloc[0]
    sd.max_slope_sup = df_density['max_slope_superior'].iloc[0]
    sd.density_X = df_density['dens_smthd_X']
    sd.density_Y = df_density['dens_smthd_Y']
    sd.density_fit_X = df_density['dens_fit_X']
    sd.density_fit_Y = df_density['dens_fit_Y']

    sd.eccs = df_density.index.to_numpy()

    # record layer thicknesses (populated by `src/save_layer_features.ipynb`), limited to |ecc| <= 10
    df_thick = pd.read_csv(path / Parser.get_density_analysis_dir() / 'results.csv', sep=',', index_col=0, skiprows=1).query('-10 <= index <= 10')
    sd.cvi_X = df_thick['CVI_X']
    sd.cvi_Y = df_thick['CVI_Y']
    sd.gcl_ipl_X = df_thick['GCL+IPL_X']
    sd.gcl_ipl_Y = df_thick['GCL+IPL_Y']
    sd.onl_X = df_thick['ONL_X']
    sd.onl_Y = df_thick['ONL_Y']
    sd.inl_opl_X = df_thick['INL+OPL_X']
    sd.inl_opl_Y = df_thick['INL+OPL_Y']
    sd.rnfl_X = df_thick['RNFL_X']
    sd.rnfl_Y = df_thick['RNFL_Y']
    sd.chrd_X = df_thick['Choroid_X']
    sd.chrd_Y = df_thick['Choroid_Y']
    sd.pr_rpe_X = df_thick['PhotoR+RPE_X']
    sd.pr_rpe_Y = df_thick['PhotoR+RPE_Y']
    sd.os_X = df_thick['OS_X']
    sd.os_Y = df_thick['OS_Y']

    subjects_data.append(sd)

#### Populating additional fields from the previously gathered data

In [None]:
def _refine_peak(ecc: np.ndarray, profile: np.ndarray, half_window: int = 2) -> float:
    '''
    Return the eccentricity of the maximum of `profile`, refined to sub-sample
    precision by fitting a parabola to the samples around the argmax.

    Falls back to the argmax sample itself when the fit is unusable: fewer than
    3 finite samples, a non-concave fit, or a vertex outside the fitted window
    (e.g. when the maximum lies at the edge of the profile).
    '''
    i_max = int(np.nanargmax(profile))
    lo = max(i_max - half_window, 0)
    hi = min(i_max + half_window + 1, len(profile))
    x, y = ecc[lo:hi], profile[lo:hi]
    finite = np.isfinite(y)
    if np.count_nonzero(finite) < 3:
        return float(ecc[i_max])
    a, b, _ = np.polyfit(x[finite], y[finite], 2)
    vertex = -b / (2 * a) if a < 0 else np.nan
    if not (x[0] <= vertex <= x[-1]):
        return float(ecc[i_max])
    return float(vertex)


def get_nb_cones(ecc: np.ndarray, dens_X: pd.Series, dens_Y: pd.Series, radius: float,
                 smoothen: bool = True, mm_per_degree: Optional[float] = None) -> float:
    '''
    Given the cone density profiles along the X and Y axes, compute the total number of cones
    within a disk of radius `radius` (in degree) centered at the fovea.

    The fovea center along each axis is the density peak (refined by a local parabolic fit).
    At each radius r, the four half-meridian profiles (+X, -X, +Y, -Y) are linearly
    interpolated and averaged; the resulting radial profile rho(r) is integrated over the
    disk: N = 2*pi * integral_0^radius rho(r) r dr.

    Args:
        ecc: increasing eccentricities (degrees) at which the profiles are sampled.
        dens_X, dens_Y: cone density profiles along X and Y, sampled at `ecc`.
        radius: disk radius in degrees.
        smoothen: apply `gaussian_filter_nan` (sigma=4) to the profiles first.
        mm_per_degree: degree-to-mm conversion factor applied to the disk area;
            defaults to the module constant MM_PER_DEGREE.

    Returns:
        Estimated number of cones within the disk.
    '''
    if mm_per_degree is None:
        mm_per_degree = MM_PER_DEGREE

    ecc = np.asarray(ecc, dtype=float)
    if smoothen:
        smthd_x = np.asarray(gaussian_filter_nan(dens_X, sigma=4), dtype=float)
        smthd_y = np.asarray(gaussian_filter_nan(dens_Y, sigma=4), dtype=float)
    else:
        smthd_x = np.asarray(dens_X, dtype=float)
        smthd_y = np.asarray(dens_Y, dtype=float)

    x_center = _refine_peak(ecc, smthd_x)
    y_center = _refine_peak(ecc, smthd_y)

    R = np.linspace(0.0001, radius, 500)  # radii in degrees
    # Stack the four half-meridians as ROWS -> shape (4, len(R)). (np.r_ would
    # concatenate them into one 1-D array, and the mean over axis 0 would then
    # collapse to a single scalar instead of a radial profile.)
    disk = np.vstack([
        np.interp(x_center + R, ecc, smthd_x),
        np.interp(x_center - R, ecc, smthd_x),
        np.interp(y_center + R, ecc, smthd_y),
        np.interp(y_center - R, ecc, smthd_y),
    ])
    radial_profile = np.nanmean(disk, axis=0)  # azimuthal average at each radius

    norm_coef = mm_per_degree**2 * 2 * np.pi
    # integrate cone density over the disk (polar coordinates) to get the total nb of cones
    return float(norm_coef * integrate.trapezoid(radial_profile * R, R))

RADIUS = 3.33  # radius of the integration disk, in degrees

for sd in subjects_data:
    # smoothed (raw) density profiles
    sd.nb_cones = get_nb_cones(sd.eccs, sd.density_X, sd.density_Y, radius=RADIUS)
    # fitted profiles: used as-is, without Gaussian smoothing
    sd.nb_cones_fit = get_nb_cones(
        sd.eccs,
        sd.density_fit_X,
        sd.density_fit_Y,
        radius=RADIUS,
        smoothen=False,
    )

In [None]:
from scipy.signal import find_peaks

def adjust_flat(gcl_data: np.ndarray, peak_left: int, peak_right: int) -> np.ndarray:
    '''
    Remove the linear trend joining the samples at `peak_left` and `peak_right`.

    The line through (peak_left, gcl_data[peak_left]) and
    (peak_right, gcl_data[peak_right]) is subtracted, anchored at `peak_left`,
    so both rim samples end up at the value of the left rim.
    '''
    rise = gcl_data[peak_right] - gcl_data[peak_left]
    run = peak_right - peak_left
    offsets = np.arange(len(gcl_data)) - peak_left
    return gcl_data - (rise / run) * offsets

def get_gcl_width(gcl: pd.Series) -> Tuple[float, float]:
    '''
    Given the GCL+IPL thickness profile, compute the width of the pit as well as the minimum
    thickness of the layer.

    The width of the pit is the distance between the two points where the (tilt-corrected,
    smoothed) thickness is 20% of the pit depth above its bottom. The depth of the pit is the
    difference between the thickness at the pit rims and the thickness at the pit's bottom.
    The minimum thickness is the vertex of a parabola fitted to the 10 thinnest samples
    (or the thinnest of those samples when the fit is not convex).

    Args:
        gcl: thickness profile indexed by eccentricity; only |ecc| <= 6 is used.

    Returns:
        (width_pit_gcl, min_thickness_gcl); the width is in the units of the index.

    Raises:
        AssertionError: if no rim peak is found on one side, or if the 20% level is
            never crossed between the rims.
    '''
    name = gcl.name  # read now: `gcl` is rebound to a NumPy array below

    eccs = gcl[np.abs(gcl.index) <= 6].index.to_numpy()
    gcl = gcl.interpolate(method='polynomial', order=1)[eccs].to_numpy()

    # Pit rims: peaks in the left and right thirds of the profile. Increase the
    # smoothing until at least one peak is found on each side.
    smooth_param = 3
    peak_left, peak_right = [], []
    while not (peak_left and peak_right) and smooth_param < 10:
        smoothed_gcl = gaussian_filter_nan(gcl, smooth_param)
        peaks = find_peaks(smoothed_gcl)[0]
        n_samples = len(smoothed_gcl)
        peak_left = [peak for peak in peaks if peak < n_samples / 3]
        peak_right = [peak for peak in peaks if peak > 2 * n_samples / 3]
        smooth_param += 1
    assert peak_left and peak_right, f'No peaks found for {name}'
    peak_left = round(np.mean(peak_left))
    peak_right = round(np.mean(peak_right))

    # Remove the tilt between the rims, then smooth lightly.
    smoothed_aj_gcl = gaussian_filter_nan(adjust_flat(gcl, peak_left, peak_right), 2)

    y_min = np.nanmin(smoothed_aj_gcl[peak_left:peak_right])
    y_target = y_min + (smoothed_aj_gcl[peak_left] - y_min) / 5

    # Search for crossings of y_target only between the rims: outside the pit,
    # NaN edges make np.diff(np.sign(...)) non-zero and would be taken as crossings.
    pit = smoothed_aj_gcl[peak_left:peak_right + 1]
    intercepts = np.where(np.diff(np.sign(pit - y_target)))[0] + peak_left
    assert intercepts.size > 0, f'Pit level never crossed for {name}'
    leftmost = eccs[intercepts[0]]
    rightmost = eccs[intercepts[-1] + 1]
    width_pit_gcl = rightmost - leftmost

    thinnest = np.argpartition(gcl, 10)[:10]
    p = np.polyfit(eccs[thinnest], gcl[thinnest], 2)
    if p[0] > 0:
        min_thickness_gcl = np.polyval(p, -p[1] / (2 * p[0]))
    else:
        # Flat or concave fit: the vertex is not a minimum (or is undefined).
        min_thickness_gcl = np.nanmin(gcl[thinnest])
    return width_pit_gcl, min_thickness_gcl

for sd in subjects_data:
    # pit width and minimum thickness along each meridian
    width_gcl_x, min_thick_x = get_gcl_width(sd.gcl_ipl_X)
    width_gcl_y, min_thick_y = get_gcl_width(sd.gcl_ipl_Y)
    sd.width_gcl_X, sd.width_gcl_Y = width_gcl_x, width_gcl_y
    # keep the thinner of the two meridian estimates
    sd.min_thick_gcl = min(min_thick_x, min_thick_y)

In [None]:
# Eccentricity grid taken from the first subject (assumed identical for all subjects).
eccs = subjects_data[0].eccs

# Attribute stems of the layer profiles stored on SubjectData (`<stem>_X` / `<stem>_Y`).
layer_names = ['rnfl', 'gcl_ipl', 'inl_opl', 'onl', 'pr_rpe', 'os', 'chrd']

# Display names for the layers and for cone density.
names_r = {
    'rnfl': 'RNFL',
    'gcl_ipl': 'GCL+IPL',
    'inl_opl': 'INL+OPL',
    'onl': 'ONL',
    'pr_rpe': 'PhotoR+RPE',
    'os': 'OS',
    'chrd': 'Choroid',
    'cones': 'Cone Density',
}

## General Function Definitions

In [None]:
from src.shared.helpers.direction import Direction


def preprocess_functional_feature(data: np.ndarray, standardization: str = 'inter') -> np.ndarray:
    '''
    Preprocess a functional feature (a feature that is a function of eccentricity, such as
    cone density or layer thickness) by Z-standardizing it.
    Given `data` matrix should have shape (n_subjects, n_eccentricities).

    - `standardization='intra'`: standardize within subjects (i.e. within each row). This
      removes inter-subject variability, for an intra-individual analysis.
    - `standardization='inter'`: standardize across subjects, eccentricity-wise (i.e. within
      each column). Removes eccentricity-level variability and focuses on between-subject
      trends, for an inter-individual analysis.
    - any other value (e.g. 'none'): `data` is returned unchanged.

    In both standardizing modes NaN entries are ignored when computing the mean and standard
    deviation (they stay NaN in the output), so one missing value no longer turns a whole
    column into NaN. A row/column with zero spread yields NaN/inf.
    '''
    if standardization == 'inter':
        axis = 0
    elif standardization == 'intra':
        axis = 1
    else:
        return data

    data = np.asarray(data)
    mean = np.nanmean(data, axis=axis, keepdims=True)
    std = np.nanstd(data, axis=axis, keepdims=True)
    return (data - mean) / std

def preprocess_functional_data(direction: Direction, standardization: str = 'inter', toLog: bool = True) -> Dict[str, np.ndarray]:
    '''
    Build the functional data matrices (one row per subject of `subjects_data`) for one
    direction (X or Y), each passed through `preprocess_functional_feature` with the given
    `standardization`.

    Returns a dict with:
        'cones'  : fitted cone density (`density_fit_<dir>`), log-transformed when `toLog`;
        'nonfit' : smoothed cone density (`density_<dir>`), never log-transformed;
        <layer>  : one entry per name in `layer_names` (the `<layer>_<dir>` profiles).
    '''
    suffix = direction.value

    def stack(attr_name):
        # one row per subject
        return np.array([getattr(subject, attr_name) for subject in subjects_data])

    layer_fds = {}
    for layer in layer_names:
        layer_fds[layer] = preprocess_functional_feature(stack(f'{layer}_{suffix}'), standardization)

    fitted = [getattr(subject, f'density_fit_{suffix}') for subject in subjects_data]
    if toLog:
        fitted = [np.log(profile) for profile in fitted]
    cone_density_fd = preprocess_functional_feature(np.array(fitted), standardization)

    cone_density_nonfit = preprocess_functional_feature(stack(f'density_{suffix}'), standardization)

    return {'cones': cone_density_fd, "nonfit": cone_density_nonfit, **layer_fds}

from scipy.stats import kendalltau, pearsonr, spearmanr
import seaborn as sns


def kendall_pval(x, y):
    """Return the p-value of Kendall's tau between `x` and `y` (scipy defaults)."""
    _, p_value = kendalltau(x, y)
    return p_value


def pearsonr_pval(x, y):
    """Return the p-value of Pearson's correlation between `x` and `y` (scipy defaults)."""
    _, p_value = pearsonr(x, y)
    return p_value


def spearmanr_pval(x, y):
    """Return the p-value of Spearman's correlation, ignoring NaN pairs."""
    _, p_value = spearmanr(x, y, nan_policy="omit")
    return p_value


## Density Analysis

In [None]:
# Fitted cone-density profiles on a linear scale and without standardization,
# one row per subject, for the X and Y meridians.
cone_density_fd_X, cone_density_fd_Y = (
    preprocess_functional_data(axis_dir, standardization='none', toLog=False)['cones']
    for axis_dir in (Direction.X, Direction.Y)
)

# Eccentricities converted from degrees to millimetres.
eccs_in_MM = eccs * MM_PER_DEGREE

In [None]:
# Normalize each subject's profile by its own maximum (row-wise, axis=1),
# so that every subject's curve has a maximum of 1.
normalized_data_x = cone_density_fd_X / np.max(cone_density_fd_X, axis=1, keepdims=True)
normalized_data_y = cone_density_fd_Y / np.max(cone_density_fd_Y, axis=1, keepdims=True)

# Average across subjects (axis=0), eccentricity by eccentricity.
mean_data_x, mean_data_y = (
    np.mean(profiles, axis=0) for profiles in (normalized_data_x, normalized_data_y)
)

# Same average on the un-normalized profiles.
mean_data_x_nonnorm, mean_data_y_nonnorm = (
    np.mean(profiles, axis=0) for profiles in (cone_density_fd_X, cone_density_fd_Y)
)

import numpy as np
import matplotlib.pyplot as plt
import matplotx
from matplotlib.colors import to_rgb

# Axes background colour of the pacoty style, read from a throwaway figure.
with plt.style.context(matplotx.styles.pacoty):
    fig, ax = plt.subplots()
    bg_color = ax.get_facecolor()
    plt.close(fig)

# Keep RGB only: convert a colour string, or drop the alpha of an RGBA tuple.
if isinstance(bg_color, str):
    bg_rgb = to_rgb(bg_color)
else:
    bg_rgb = bg_color[:3]

# Region boundaries (mm from the foveal centre).
breakpoints = [0.175, 0.25, 0.75, 1.25, 2.75]
num_steps = len(breakpoints)

# Shades from pure white (index 0) to the full background colour (index 5).
# Intermediate shade = w_white + w_bg * bg, i.e. 90/10, 70/30, 50/50 and 30/70
# white/background blends.
colors = [(1, 1, 1)]
colors += [
    tuple(w_white + w_bg * channel for channel in bg_rgb)
    for w_white, w_bg in ((0.9, 0.1), (0.7, 0.3), (0.5, 0.5), (0.3, 0.7))
]
colors.append(tuple(bg_rgb))

# Retinal regions, symmetric about the foveal centre: (name, start, end, shade index).
retinal_regions = [
    {'name': region_name, 'start': start, 'end': end, 'color': colors[shade]}
    for region_name, start, end, shade in (
        ('Perifovea', 1.25, np.inf, 5),
        ('Parafovea', 0.75, 1.25, 3),
        ('Fovea', 0.175, 0.75, 2),
        ('Foveola', -0.175, 0.175, 0),
        ('Fovea', -0.75, -0.175, 2),
        ('Parafovea', -1.25, -0.75, 3),
        ('Perifovea', -np.inf, -1.25, 5),
    )
]

# (start, end, colour) triples consumed by the plotting helpers.
regions_with_colors = [(reg['start'], reg['end'], reg['color']) for reg in retinal_regions]

def plot_with_regions(data_x, data_y,
                      ylabel: str = "Normalized Value",
                      title: str = "Normalized Data Across Patients with Mean Line"):
    '''
    Plot every subject's X and Y profiles over a shaded retinal-region background,
    together with the grand mean over all subjects and both meridians.

    Args:
        data_x, data_y: arrays of shape (n_subjects, n_eccentricities) sampled at the
            global `eccs_in_MM`; all rows are plotted, whatever the number of subjects.
        ylabel: y-axis label.
        title: figure title.
    '''
    plt.figure(figsize=(12, 6))

    # Shaded background, one band per retinal region.
    for left, right, color in regions_with_colors:
        plt.axvspan(left, right, facecolor=color, alpha=1, zorder=0)

    # Dashed separator at every finite region boundary.
    for region in retinal_regions:
        for bound in (region['start'], region['end']):
            if not np.isinf(bound):
                plt.axvline(bound, color='lightgrey', linestyle='--', alpha=1,
                            linewidth=0.5, zorder=1)

    # One faint line per subject and meridian.
    for row_x, row_y in zip(data_x, data_y):
        plt.plot(eccs_in_MM, row_x, color='blue', alpha=0.2, zorder=2)
        plt.plot(eccs_in_MM, row_y, color='blue', alpha=0.2, zorder=2)

    # Grand mean over all subjects and both meridians.
    mean_data = np.mean(np.concatenate([data_x, data_y], axis=0), axis=0)
    plt.plot(eccs_in_MM, mean_data, color='red')

    plt.xlabel("Eccentricity (mm)")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(False)
    plt.show()
    plt.close()


# Raw fitted cone density (not normalized)
plot_with_regions(cone_density_fd_X, cone_density_fd_Y,
                  ylabel="Cone Density (cones/mm²)",
                  title="Cone Density Across Patients with Mean Line")
# Per-subject normalized cone density
plot_with_regions(normalized_data_x, normalized_data_y)

In [None]:
# Scatter plots relating an extreme value (maximum or minimum) of each subject's
# fitted X cone-density profile to the subject's total number of cones.


def plot_max_vs_total_cones(subjects_data, title_prefix="Maximum"):
    """
    Scatter the maximum (or minimum) fitted X cone density of each subject against its
    total number of cones (`nb_cones_fit`), with a least-squares regression line.

    Parameters:
    subjects_data (list): SubjectData objects providing `density_fit_X` and `nb_cones_fit`
    title_prefix (str): "Maximum" uses the maximum density, any other value the minimum;
        also used in the axis label and the title
    """
    pick = np.nanmax if title_prefix == "Maximum" else np.nanmin
    density_values = [pick(sd.density_fit_X) for sd in subjects_data]
    nb_cones = [sd.nb_cones_fit for sd in subjects_data]

    with plt.style.context(matplotx.styles.pacoty):
        plt.figure(figsize=(10, 6))
        plt.scatter(density_values, nb_cones, s=80, alpha=0.7, edgecolor='black')

        # least-squares line, with R² and p-value in the legend
        fit = scipy.stats.linregress(density_values, nb_cones)
        x_line = np.linspace(min(density_values), max(density_values), 100)
        plt.plot(
            x_line,
            fit.slope * x_line + fit.intercept,
            'r-',
            alpha=0.8,
            label=f'y = {fit.slope:.2f}x + {fit.intercept:.2f}\nR² = {fit.rvalue**2:.3f}, p = {fit.pvalue:.4f}',
        )

        plt.xlabel(f"{title_prefix} Cone Density (cones/mm²)", fontsize=12)
        plt.ylabel("Total Number of Cones", fontsize=12)
        plt.title(f"{title_prefix} Cone Density vs Total Number of Cones", fontsize=14)
        plt.grid(False)
        plt.legend()
        plt.tight_layout()
        plt.show()


# Maximum cone density vs total cones
plot_max_vs_total_cones(subjects_data, "Maximum")

# Minimum cone density vs total cones
plot_max_vs_total_cones(subjects_data, "Minimum")


In [None]:
# Mean normalized cone density at a few radii (mm), averaged over the four
# half-meridians (+X, +Y, -X, -Y).
radii = [0.150, 0.300, 1.2]
indices = []
ratio = []

for i, radius in enumerate(radii):
    # insertion indices: first sample whose eccentricity is >= +radius / >= -radius
    index_pos = np.searchsorted(eccs_in_MM, radius)
    index_neg = np.searchsorted(eccs_in_MM, -radius)
    ratio_values = [
        mean_data_x[index_pos],
        mean_data_y[index_pos],
        mean_data_x[index_neg],
        mean_data_y[index_neg],
    ]
    print(ratio_values)
    ratio.append(np.mean(ratio_values))
    print(f"Ratio at {radii[i]} mm, for eccentricity of radius {eccs_in_MM[index_pos]} mm is {ratio[i]}")

### Comparison with Zhang's 2015 Integration Results

In [None]:



def integrate_cone_density_circle(eccs_in_MM: np.ndarray,
                                  cone_density_fd_X: np.ndarray,
                                  cone_density_fd_Y: np.ndarray,
                                  radius: float = 1.0,
                                  num_r: int = 400,
                                  num_theta: int = 400,
                                  exclude_center: bool = False,
                                  center_radius: float = 0.3):
    """
    Integrate cone density over a disc (or an annulus) around the centre, per subject.

    The 2D density is reconstructed from the horizontal (X) and vertical (Y) meridian profiles as
        density(r, theta) = density_x(r) * cos^2(theta) + density_y(r) * sin^2(theta)
    and integrated in polar coordinates (area element r dr dtheta) with Simpson's rule.
    If the densities are in cones/mm^2, the integral is a number of cones.

    Also computes:
      - the maximum interpolated density along each meridian within the integration range;
      - the cumulative integral as a function of the outer radius (cumulative trapezoidal rule,
        so its last value can differ slightly from the Simpson total).

    Parameters:
        eccs_in_MM: 1D array of eccentricities (mm) at which the profiles are sampled.
        cone_density_fd_X: densities along the horizontal meridian, shape (n_subjects, len(eccs_in_MM)).
            A single 1D profile is treated as one subject.
        cone_density_fd_Y: densities along the vertical meridian, same shape as cone_density_fd_X.
        radius: outer integration radius (mm).
        num_r: number of radial grid points.
        num_theta: number of angular grid points.
        exclude_center: if True, integrate over the annulus center_radius <= r <= radius.
        center_radius: inner radius (mm) used when exclude_center is True (default 0.3).

    Returns:
        mean_int, std_int, min_int, max_int: statistics of the per-subject integrals
            (std_int is the population standard deviation).
        cov_int: coefficient of variation, std/mean * 100 (NaN when the mean is 0).
        int_results: per-subject integrals, shape (n_subjects,).
        max_x_results, max_y_results: per-subject maxima along X / Y, shape (n_subjects,).
        r_grid: the radial grid, shape (num_r,).
        cumulative_integrations: shape (n_subjects, num_r); entry [s, k] is the integral for
            subject s from the inner bound up to r_grid[k].

    Raises:
        ValueError: if the X and Y arrays differ in shape, no subject is given, or
            radius does not exceed the inner integration bound.
    """
    eccs_in_MM = np.asarray(eccs_in_MM, dtype=float)
    cone_density_fd_X = np.atleast_2d(np.asarray(cone_density_fd_X, dtype=float))
    cone_density_fd_Y = np.atleast_2d(np.asarray(cone_density_fd_Y, dtype=float))
    if cone_density_fd_X.shape != cone_density_fd_Y.shape:
        raise ValueError(
            "X and Y density arrays must have the same shape, got "
            f"{cone_density_fd_X.shape} and {cone_density_fd_Y.shape}")
    if cone_density_fd_X.shape[0] == 0:
        raise ValueError("at least one subject profile is required")

    # Lower bound of the radial integration.
    r_min = center_radius if exclude_center else 0.0
    if radius <= r_min:
        # A reversed or empty range would silently yield negative or zero integrals.
        raise ValueError(
            f"radius ({radius}) must be greater than the inner integration bound ({r_min})")

    # Polar integration grid.
    r = np.linspace(r_min, radius, num_r)
    theta = np.linspace(0, 2 * np.pi, num_theta)
    # The angular weights do not depend on the subject: compute them once.
    cos2 = np.cos(theta) ** 2
    sin2 = np.sin(theta) ** 2

    int_results = []
    max_x_results = []
    max_y_results = []
    cumulative_integrations_list = []

    for profile_x, profile_y in zip(cone_density_fd_X, cone_density_fd_Y):
        # Linear interpolation; outside the range of eccs_in_MM values are linearly extrapolated.
        f_x = interpolate.interp1d(eccs_in_MM, profile_x,
                                   bounds_error=False, fill_value="extrapolate")
        f_y = interpolate.interp1d(eccs_in_MM, profile_y,
                                   bounds_error=False, fill_value="extrapolate")

        # NOTE(review): r only spans [r_min, radius] (r_min is 0 or center_radius), so only the
        # non-negative side of each profile is evaluated; if eccs_in_MM also contains negative
        # eccentricities, those samples do not contribute. Confirm this is intended.
        density_x = f_x(r)  # shape: (num_r,)
        density_y = f_y(r)  # shape: (num_r,)

        # Maximum density along each meridian within the integration region.
        max_x_results.append(np.max(density_x))
        max_y_results.append(np.max(density_y))

        # Density on the polar grid, shape (num_r, num_theta).
        density_field = density_x[:, None] * cos2 + density_y[:, None] * sin2

        # Multiply by the polar Jacobian r (area element r dr dtheta).
        density_field_weighted = density_field * r[:, None]

        # Integrate over theta first, then over r, with Simpson's rule.
        integral_theta = integrate.simpson(density_field_weighted, x=theta, axis=1)
        int_results.append(integrate.simpson(integral_theta, x=r))

        # Integral from r_min up to each r value.
        cumulative_integrations_list.append(
            integrate.cumulative_trapezoid(integral_theta, r, initial=0))

    int_results = np.array(int_results)
    max_x_results = np.array(max_x_results)
    max_y_results = np.array(max_y_results)
    cumulative_integrations = np.array(cumulative_integrations_list)

    # Summary statistics of the per-subject totals.
    mean_int = np.mean(int_results)
    std_int = np.std(int_results)
    min_int = np.min(int_results)
    max_int = np.max(int_results)
    cov_int = (std_int / mean_int * 100) if mean_int != 0 else np.nan

    return (mean_int, std_int, min_int, max_int, cov_int, int_results,
            max_x_results, max_y_results, r, cumulative_integrations)

In [None]:


# Integrate over a disc of radius 1 mm for every subject.
(mean_int, std_int, min_int, max_int, cov_int, int_results, max_x_results,
 max_y_results, r_grid, cumulative_integrations) = integrate_cone_density_circle(
    eccs_in_MM, cone_density_fd_X, cone_density_fd_Y,
    radius=1.0, exclude_center=False,
)

# Group-level summary of the per-subject integrals.
for summary_line in (
    f"Mean integrated cone density: {mean_int}",
    f"Standard deviation: {std_int}",
    f"Minimum integrated density: {min_int}",
    f"Maximum integrated density: {max_int}",
    f"Coefficient of Variation (COV): {cov_int:.2f}%\n",
):
    print(summary_line)

# Per-subject table (1-based subject index, tab separated).
print("Subject\tIntegrated Density\tMax (X direction)\tMax (Y direction)")
for i, (subject_total, subject_max_x, subject_max_y) in enumerate(
        zip(int_results, max_x_results, max_y_results)):
    print(f"{i+1}\t{subject_total:.3f}\t\t\t{subject_max_x:.3f}\t\t\t{subject_max_y:.3f}")

# Distribution of the integrated values across subjects.
plt.figure(figsize=(8, 6))
plt.hist(int_results, bins=100, edgecolor='black', alpha=0.7)
plt.title("Histogram of Integrated Cone Density Values")
plt.xlabel("Integrated Cone Density")
plt.ylabel("Frequency")
plt.grid(False)
plt.show()


# Mean and standard deviation of the cumulative integral across subjects, at every radius in r_grid.
mean_cum = np.mean(cumulative_integrations, axis=0)
std_cum = np.std(cumulative_integrations, axis=0)

# Region breakpoints (mm), mirrored to negative eccentricities; not used by the plot below.
breakpoints = [0.175, 0.25, 0.75, 1.25, 2.75]
all_breakpoints = sorted([-x for x in breakpoints] + breakpoints)

# Read the axes facecolour of the currently active style; it is the end colour of the gradient.
# NOTE(review): the pacoty style context used by the other plots is commented out here, so if the
# active style has a white axes facecolour, every gradient step below is white.
# with plt.style.context(matplotx.styles.pacoty):
fig, ax = plt.subplots()
bg_color = ax.get_facecolor()
plt.close(fig)  # the figure was only needed to read the colour

# Convert to RGB if it's not already
bg_rgb = to_rgb(bg_color) if isinstance(bg_color, str) else bg_color[:3]

# Blend from white towards the background colour; comments state the actual mixing weights.
colors = [
    (1, 1, 1),                              # white
    tuple(0.9 + 0.1 * x for x in bg_rgb),   # 90% white + 10% bg
    tuple(0.7 + 0.3 * x for x in bg_rgb),   # 70% white + 30% bg
    tuple(0.5 + 0.5 * x for x in bg_rgb),   # 50% white + 50% bg
    tuple(0.3 + 0.7 * x for x in bg_rgb),   # 30% white + 70% bg
    tuple(x for x in bg_rgb),               # full background colour
]

# Retinal regions (mm) for the background shading. They are contiguous and non-overlapping
# (Foveola | FAZ | Fovea | Parafovea), with the same Fovea/FAZ boundaries as plot_correlations_eccwise.
retinal_regions = [
    {'name': 'Parafovea', 'start': 0.75, 'end': 1, 'color': colors[5]},
    {'name': 'Fovea', 'start': 0.25, 'end': 0.75, 'color': colors[4]},
    {'name': 'Faz', 'start': 0.175, 'end': 0.25, 'color': colors[3]},
    {'name': 'Foveola', 'start': 0, 'end': 0.175, 'color': colors[0]},
]

# (start, end, colour) triples for axvspan
regions_with_colors = [(r['start'], r['end'], r['color']) for r in retinal_regions]

plt.figure(figsize=(8, 6))

# Background shading, drawn below the data (zorder=0).
for left, right, color in regions_with_colors:
    plt.axvspan(left, right, facecolor=color, alpha=1.0, zorder=0)

# Mean curve with a ±1 standard-deviation band.
plt.plot(r_grid, mean_cum, label="Mean Cumulative Cone Density", color='blue')
plt.fill_between(r_grid, mean_cum - std_cum, mean_cum + std_cum, alpha=0.1, label="±1 Std Dev", color='blue')

# Draw each distinct finite region boundary exactly once, so overlapping semi-transparent
# lines do not make shared boundaries look darker than the others.
boundaries = sorted({
    edge
    for reg in retinal_regions
    for edge in (reg['start'], reg['end'])
    if np.isfinite(edge)
})
for edge in boundaries:
    plt.axvline(edge, color='black', linestyle='--', alpha=0.3, linewidth=0.5)

# Final labels and title
plt.xlabel("Eccentricity (mm)")
plt.ylabel("Cumulative Integrated Cone Density")
plt.title("Cumulative Cone Density vs. Eccentricity")
plt.legend()
plt.grid(False)

plt.show()


## Intra-/Inter-Individual Analysis

### Functions for intra-/inter-individual analysis

In [None]:
from typing import Iterable
import warnings
from scipy.stats import spearmanr, pearsonr, binomtest
import statsmodels.formula.api as smf

def mixedlm(cd: np.ndarray, lt: np.ndarray, pids: np.ndarray, eccs: np.ndarray, standardization: str = 'inter') -> tuple[float, float]:
    """
    Fit the linear mixed model ``LayerThickness ~ ConeDensity`` with observations grouped by subject.

    The return annotation uses the builtin ``tuple`` so that defining this function does not
    depend on ``typing.Tuple`` having been imported by an earlier cell.

    Parameters:
        cd: cone-density values (fixed-effect predictor), one per observation.
        lt: layer-thickness values (response), aligned with ``cd``.
        pids: subject identifier per observation (grouping factor).
        eccs: eccentricity per observation; only used, as a random slope, when
            ``standardization == 'intra'``.
        standardization: 'intra' adds ``re_formula="~Eccentricity"`` (per-subject random slope on
            eccentricity); any other value uses statsmodels' default random structure (a random
            intercept per subject).

    Returns:
        (fixed-effect coefficient of ConeDensity, its p-value)
    """
    data = pd.DataFrame({'Subject': pids, 'Eccentricity': eccs, 'LayerThickness': lt, 'ConeDensity': cd})
    if standardization == 'intra':
        model = smf.mixedlm("LayerThickness ~ ConeDensity", data, groups="Subject", re_formula="~Eccentricity")
    else:
        model = smf.mixedlm("LayerThickness ~ ConeDensity", data, groups="Subject")
    # Every warning raised while fitting (e.g. convergence warnings) is silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        result = model.fit()
    return result.fe_params['ConeDensity'], result.pvalues['ConeDensity']

def plot_slice_correlation(layer_name: str, degree: float | Iterable | None, direction: Direction, eccs: np.ndarray, comp_to = 'cones', restrict_to_os = False, standardization: str = 'inter'):
    """
    Scatter standardized `layer_name` values against `comp_to`, pooled over subjects and the
    selected eccentricities, and annotate Spearman, Pearson, a one-sided binomial sign test and
    (when more than one eccentricity is pooled) a mixed-model slope.

    Parameters:
        layer_name: key of the y-axis feature in the output of `preprocess_functional_data`.
        degree: an eccentricity in degrees, an iterable of eccentricities, or None for all of
            `eccs`. Values are rounded to 0.1° and matched to the nearest entry of `eccs`.
        direction: meridian, forwarded to `preprocess_functional_data` (its `.value` is used in the title).
        eccs: eccentricity grid (degrees) matching the columns of the preprocessed data.
        comp_to: key of the x-axis feature ('cones' = cone density, otherwise a layer name).
        restrict_to_os: if True, keep only points where the 'os' feature is not NaN.
        standardization: 'inter' or 'intra', forwarded to `preprocess_functional_data`.
    """
    fd = preprocess_functional_data(direction, standardization)

    n_subject = fd['cones'].shape[0]
    eccs = np.asarray(eccs, dtype=float)

    # Resolve the requested eccentricities to column indices of the preprocessed data.
    if degree is None:
        indices = np.arange(len(eccs))
        ecc_str = 'across all eccs'
    else:
        if isinstance(degree, Iterable):
            degree = np.round(np.array(degree), 1)
            ecc_str = f'on {np.min(degree)}° to {np.max(degree)}°'
        else:
            degree = np.round(float(degree), 1)
            ecc_str = f'at {degree}°'
        # Nearest grid point instead of np.searchsorted: searchsorted can land one index too
        # high when the float grid value is slightly below the requested value, and returns
        # len(eccs) (out of range) past the last entry. np.atleast_1d also turns a scalar
        # request into a 1-element index array, so len(indices) below is always defined.
        indices = np.abs(eccs[None, :] - np.atleast_1d(degree)[:, None]).argmin(axis=1)

    # Pool subjects x selected eccentricities into aligned flat vectors (subject by subject).
    cd = fd[comp_to][:, indices].flatten()
    lt = fd[layer_name][:, indices].flatten()
    pids = np.repeat(np.arange(n_subject), len(indices))
    eccentricities = np.tile(eccs[indices], n_subject)
    if restrict_to_os:
        os_mask = ~np.isnan(fd['os'][:, indices].flatten())
        cd, lt = cd[os_mask], lt[os_mask]
        pids, eccentricities = pids[os_mask], eccentricities[os_mask]
    valid = ~np.isnan(cd) & ~np.isnan(lt)
    if not valid.any():
        print(f'No valid data for {layer_name} {ecc_str}.')
        return
    cd, lt = cd[valid], lt[valid]
    pids, eccentricities = pids[valid], eccentricities[valid]

    # Sign agreement of the two z-scores: opposite signs lie in the upper-left / lower-right
    # quadrants (plotted blue), equal signs in the lower-left / upper-right quadrants (red).
    opposite_sign = (lt > 0) != (cd > 0)
    same_sign = ~opposite_sign

    spearman_corr = spearmanr(cd, lt)
    pearson_corr = pearsonr(cd, lt)
    # The mixed model is only fitted when more than one eccentricity is pooled.
    perform_mlm = len(indices) > 1
    # NOTE(review): mixedlm is called with its default standardization ('inter') even when this
    # function runs with standardization='intra'; confirm whether its 'intra' model was intended.
    mixedlm_corr = mixedlm(cd, lt, pids, eccentricities) if perform_mlm else None
    # One-sided binomial test: do more points than chance fall in the quadrants that agree with
    # the sign of the Pearson correlation? Indexing ([0]) works for both the tuple returned by old
    # SciPy versions and the PearsonRResult object returned by newer ones.
    n_agree = opposite_sign.sum() if pearson_corr[0] < 0 else same_sign.sum()
    binom_corr = binomtest(
        int(n_agree),
        int(opposite_sign.sum() + same_sign.sum()),
        p=0.5, alternative='greater'
    )

    with plt.style.context(matplotx.styles.pacoty):
        # Symmetric axis limit: at least 3, otherwise just beyond the largest |value|.
        plot_limit = max(3, 0.1 + np.ceil(np.max(np.abs([lt, cd])) * 10) / 10)

        plt.scatter(cd[opposite_sign], lt[opposite_sign], 5, color='blue', alpha=0.6, label=f'n = {opposite_sign.sum()}')
        plt.scatter(cd[same_sign], lt[same_sign], 5, color='red', alpha=0.6, label=f'n = {same_sign.sum()}')

        # Axes through the origin and lightly shaded quadrants.
        plt.axhline(0, color='black', linewidth=0.5)
        plt.axvline(0, color='black', linewidth=0.5)
        plt.fill_between([-plot_limit, 0], -plot_limit, 0, color='red', alpha=0.05)
        plt.fill_between([0, plot_limit], 0, plot_limit, color='red', alpha=0.05)
        plt.fill_between([-plot_limit, 0], 0, plot_limit, color='blue', alpha=0.05)
        plt.fill_between([0, plot_limit], -plot_limit, 0, color='blue', alpha=0.05)

        x_p = np.linspace(-plot_limit, plot_limit, 100)
        if perform_mlm:
            # Mixed-model fixed-effect slope (only available when the model was fitted).
            slope_mlm = mixedlm_corr[0]
            plt.plot(x_p, slope_mlm * x_p, color='olive', label=f'MLM: y = {slope_mlm:.4g}x')
        else:
            # Single eccentricity: ordinary least-squares line.
            p = np.polyfit(cd, lt, 1)
            plt.plot(x_p, np.polyval(p, x_p), color='olive', label=f'fit: y = {p[0]:.4g}x')

        plt.ylim(-plot_limit, plot_limit)
        plt.xlim(-plot_limit, plot_limit)
        plt.gca().set_aspect('equal', adjustable='box')
        comp_to_str = 'Cone density' if comp_to=='cones' else f'{names_r[comp_to]} thickness'
        plt.xlabel(f'{comp_to_str} (Z-Score)')
        plt.ylabel(f'{names_r[layer_name]} thickness, (Z-Score)')
        # Legend on the right side, outside the axes.
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        # Statistics text box below the legend. Built explicitly so that only the MixedLM line
        # depends on perform_mlm (the other statistics are always shown).
        stats_text = (
            f'Spearman: {spearman_corr[0]:.3f}, p={spearman_corr[1]:.2g}'
            f'\nPearson: {pearson_corr[0]:.3f}, p={pearson_corr[1]:.2g}'
            f'\nBinomial test: p={binom_corr.pvalue:.2g}'
        )
        if perform_mlm:
            stats_text += f'\nMixedLM: {mixedlm_corr[0]:.3f}, p={mixedlm_corr[1]:.2g}'
        plt.text(1.02, 0.7, stats_text, fontsize=12, ha='left', transform=plt.gca().transAxes)

        plt.title(f'{names_r[layer_name]} thickness vs {comp_to_str} across subjects & {ecc_str}, {direction.value}-axis\nStandardized {standardization}-individually')
        plt.show()

### Intra-Individual


Have a look at the violin plots as well: `P:\AOSLO\_automation\_PROCESSED\Photoreceptors\Healthy\_Results\all_stats_new\spearman_correlation_for_*.png`.
The following plots are simply another way to visualize the same results, since both are intra-individual analyses.

In [None]:
# Intra-individually standardized layer thickness vs cone density, pooled over
# all eccentricities (degree=None) along the X meridian.
features = ['rnfl', 'pr_rpe', 'onl', 'gcl_ipl']
comp_to = 'cones'
for i, layer in enumerate(features):
    plot_slice_correlation(layer, None, Direction.X, eccs, comp_to=comp_to, standardization='intra')

### Inter-individual

In [None]:
# Inter-individual analysis: every unordered pair of features, pooled over the
# eccentricities -1°..1° in 0.1° steps along the X meridian.
features = ['rnfl', 'pr_rpe', 'os', 'onl', 'cones']

step = 0.1
deg = np.arange(-1, 1 + step, step)  # stop padded by one step, intended to include 1°

n_features = len(features)
for i in range(n_features):
    for j in range(i + 1, n_features):
        layer = features[i]
        comp_to = features[j]
        plot_slice_correlation(layer, deg, Direction.X, eccs, comp_to=comp_to, standardization='inter')

## ECC-wise analysis

### General function definition

#### Computing

In [None]:
from typing import Callable

def compute_correlations_eccwise(direction: Direction, eccs: np.ndarray, correlation_fun: Callable[[np.ndarray, np.ndarray], tuple[float, float]] = lambda cd, lt: spearmanr(cd, lt, nan_policy='omit')
) -> dict[str, tuple[np.ndarray, np.ndarray]]:
    """
    Correlate inter-individually standardized cone density with each standardized layer,
    one eccentricity at a time.

    For every eccentricity column, `correlation_fun` is applied to the subjects' cone-density
    values and layer values (Spearman with nan_policy='omit' by default). In a nutshell, this
    asks whether subjects who deviate from the group mean in cone density also deviate in layer
    thickness.

    Builtin `dict`/`tuple` are used in the annotations so the definition does not depend on
    `typing.Dict`/`typing.Tuple` having been imported by another cell.

    Parameters:
        direction: meridian, forwarded to `preprocess_functional_data`.
        eccs: eccentricity grid; column i of the preprocessed data is assumed to belong to eccs[i].
        correlation_fun: callable (x, y) -> (statistic, p-value).

    Returns:
        dict mapping every preprocessed feature except 'cones' to
        (correlations, p-values), two arrays of length len(eccs).
    """
    fd = preprocess_functional_data(direction, standardization='inter')
    cones = fd['cones']
    results = {}
    for layer, layer_fd in fd.items():
        if layer == 'cones':
            continue
        pointwise_corr = np.zeros(len(eccs))
        pointwise_pvalues = np.zeros(len(eccs))
        # Correlate the subjects' values column by column (one eccentricity at a time).
        for i in range(len(eccs)):
            pointwise_corr[i], pointwise_pvalues[i] = correlation_fun(cones[:, i], layer_fd[:, i])
        results[layer] = (pointwise_corr, pointwise_pvalues)
    return results




#### Plotting

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List
from matplotlib.colors import to_rgb

def plot_correlations_eccwise(results: Dict[str, Tuple[np.ndarray, np.ndarray]], eccs: np.ndarray, direction: str, layers_to_plot: List[str] | None = None, abs_: bool = False, pv_threshold: float | None = None, corr_name: str = 'Spearman'):
    """
    Plot, for each layer, the pointwise correlation with cone density as a function of eccentricity.

    Lines and markers fade with the p-value: alpha = min(1, 0.15 + 0.85 * (1 - p)**4), and NaN
    p-values are drawn fully transparent. If `pv_threshold` is given, points with p > pv_threshold
    are hidden. The background is shaded by retinal region; region boundaries are given in mm and
    converted with MM_PER_DEGREE.

    Parameters:
        results: layer name -> (correlations, p-values), each aligned with `eccs`
            (e.g. the output of `compute_correlations_eccwise`).
        eccs: eccentricities (degrees) used as x-coordinates.
        direction: label used in the title.
        layers_to_plot: subset/order of layers to draw; defaults to all keys of `results`.
        abs_: plot |correlation| instead of the signed value (the y-range then starts at 0).
        pv_threshold: if set, hide points whose p-value exceeds it.
        corr_name: correlation name used in the title and y-label.
    """
    f = lambda x: np.abs(x) if abs_ else x
    min_alpha = 0.15

    def _plot_smooth_alpha(xs, ys, alphas, color, linewidth_fun):
        # Draw one segment as many short pieces whose alpha/linewidth interpolate linearly
        # between the two endpoint alphas (about one piece per 0.01 alpha change).
        assert len(xs) == len(ys) == len(alphas) == 2
        n_steps = 2 + int(np.abs(np.diff(alphas))[0] / 0.01)
        x = np.linspace(xs[0], xs[1], n_steps)
        y = np.interp(x, xs, ys)
        alpha = np.linspace(alphas[0], alphas[1], n_steps)
        for j in range(n_steps - 1):
            plt.plot(
                x[j:j+2],
                y[j:j+2],
                color=color,
                alpha=alpha[j],
                linewidth=linewidth_fun(alpha[j]),
                label=None
            )

    with plt.style.context(matplotx.styles.pacoty):
        colors = plt.get_cmap('Accent_r', len(results))
        plt.figure(figsize=(10, 6), dpi=300)

        # End colour of the background gradient: the style's axes facecolour, read directly from
        # rcParams (no throwaway figure needed).
        bg_rgb = to_rgb(plt.rcParams['axes.facecolor'])
        gradient_colors = [
            (1, 1, 1),                              # white
            tuple(0.8 + 0.2 * x for x in bg_rgb),   # 80% white + 20% bg
            tuple(0.6 + 0.4 * x for x in bg_rgb),   # 60% white + 40% bg
            tuple(0.4 + 0.6 * x for x in bg_rgb),   # 40% white + 60% bg
            tuple(0.2 + 0.8 * x for x in bg_rgb),   # 20% white + 80% bg
            bg_rgb                                  # full background colour
        ]

        # Retinal regions in mm, symmetric about the centre.
        retinal_regions = [
            {'name': 'Perifovea', 'start': 1.25, 'end': np.inf, 'color': gradient_colors[5]},
            {'name': 'Parafovea', 'start': 0.75, 'end': 1.25, 'color': gradient_colors[3]},
            {'name': 'Fovea', 'start': 0.25, 'end': 0.75, 'color': gradient_colors[2]},
            {'name': 'Faz', 'start': 0.175, 'end': 0.25, 'color': gradient_colors[1]},
            {'name': 'Foveola', 'start': -0.175, 'end': 0.175, 'color': gradient_colors[0]},
            {'name': 'Faz', 'start': -0.25, 'end': -0.175, 'color': gradient_colors[1]},
            {'name': 'Fovea', 'start': -0.75, 'end': -0.25, 'color': gradient_colors[2]},
            {'name': 'Parafovea', 'start': -1.25, 'end': -0.75, 'color': gradient_colors[3]},
            {'name': 'Perifovea', 'start': -np.inf, 'end': -1.25, 'color': gradient_colors[5]}
        ]

        # Background first (zorder=0), with boundaries converted from mm to degrees.
        for region in retinal_regions:
            plt.axvspan(region['start'] / MM_PER_DEGREE, region['end'] / MM_PER_DEGREE,
                        facecolor=region['color'], alpha=1.0, zorder=0)

        for i, layer in enumerate(layers_to_plot or results.keys()):
            correlations, pvalues = results[layer]
            # Higher alpha for smaller p-values, lower alpha for larger ones; NaN p-values -> 0.
            alphas = np.minimum(
                1,
                np.where(np.isnan(pvalues), 0, min_alpha + (1 - min_alpha) * (1 - pvalues) ** 4)
            )
            if pv_threshold is not None:
                alphas = np.where(pvalues <= pv_threshold, alphas, 0)

            for j in range(len(eccs) - 1):
                _plot_smooth_alpha(
                    eccs[j:j+2],
                    f(correlations[j:j+2]),
                    color=colors(i),
                    alphas=alphas[j:j+2],
                    linewidth_fun=lambda a: 0.3 + a * 0.7,
                )

            label = names_r[layer] if layer != 'nonfit' else "nonfit"
            plt.scatter(
                eccs, f(correlations),
                label=label,
                color=colors(i),
                alpha=alphas,
                edgecolors='none',
                s=20 + alphas * 30
            )

        legend = plt.legend(loc='best')
        # Legend markers are shown fully opaque, independent of the per-point alphas.
        for handle in legend.legend_handles:
            handle.set_alpha(1)

        lim = 0.8
        plt.ylim(0 if abs_ else -lim, lim)
        plt.xlim(-10, 10)

        plt.title(f'Pointwise {corr_name} Correlation of Cone Density with Layers, {direction}-Axis')
        plt.xlabel('Eccentricity [°]')
        ylabel = f'{corr_name} correlation coefficient'
        plt.ylabel(f'|{ylabel}|' if abs_ else ylabel)

        plt.grid(False)
        plt.show()



#### Covariate Removal

In [None]:
import numpy as np
from typing import Dict, Callable, Tuple, Optional
from scipy import stats  # used for potential further diagnostics (if needed)
import matplotlib.pyplot as plt


def density_to_spacing(cd: np.ndarray) -> np.ndarray:
    """
    Convert cone density [cells/mm²] to cone spacing [arcmin].
    """
    return np.sqrt(2 / cd / np.sqrt(3)) / MM_PER_DEGREE * 60 # in arcmin

def get_data_from_range(layer_name, left, right, only_valid: bool = True, flatten:bool = True, direction='X'):
    """
    Collect fitted cone density and a layer's values for all subjects over eccentricities in [left, right).

    Reads the module-level `eccs` and `subjects_data`.

    Parameters:
        layer_name: attribute prefix of the layer; values come from `{layer_name}_{direction}`.
        left, right: eccentricity bounds (same units as `eccs`), half-open interval [left, right).
        only_valid: drop positions where either value is NaN. Boolean masking always yields a
            1D result, even when flatten=False.
        flatten: flatten the per-subject arrays into 1D.
        direction: meridian suffix of the attributes ('X' by default). An object with a `.value`
            attribute (e.g. a Direction member) is also accepted.

    Returns:
        (cone density from `density_fit_{direction}`, layer values)
    """
    suffix = getattr(direction, 'value', direction)
    in_range = np.flatnonzero((left <= eccs) & (eccs < right))

    # Local names avoid shadowing the imported `os` module.
    cone_density = np.array([getattr(s, f'density_fit_{suffix}') for s in subjects_data])[:, in_range]
    layer_values = np.array([getattr(s, f'{layer_name}_{suffix}') for s in subjects_data])[:, in_range]

    if flatten:
        cone_density = cone_density.flatten()
        layer_values = layer_values.flatten()
    if only_valid:
        valid = ~np.isnan(cone_density) & ~np.isnan(layer_values)
        cone_density = cone_density[valid]
        layer_values = layer_values[valid]
    return cone_density, layer_values

def preprocess_functional_feature(data: np.ndarray, 
                                  covariates: Optional[np.ndarray] = None, 
                                  standardization: str = 'inter', 
                                  perform_regression: bool = False,
                                  verbose: bool = False) -> np.ndarray:
    """
    Preprocess a functional feature (e.g. cone density or layer thickness as a function of eccentricity)
    by optionally removing the effects of one or more covariates via linear regression and then
    (Z-)standardizing it.

    Parameters:
      data: array of shape (n_subjects, n_features). It is converted to float, so integer input
        does not truncate the regression residuals.
      covariates: 1D array of shape (n_subjects,) or 2D array of shape (n_subjects, n_covariates).
      standardization: 'inter' standardizes each column across subjects, 'intra' each row within
        a subject; any other value leaves the data unstandardized.
      perform_regression: if True and covariates is not None, each column is replaced by the
        residuals of an ordinary least-squares fit on the covariates (with intercept).
      verbose: if True, prints regression coefficients and R², and shows diagnostic plots.

    Returns:
      Float array with the shape of `data`. After regression, entries whose value or covariates
      are non-finite are NaN, as are whole columns with fewer than n_covariates + 1 usable rows.
      Standardization ignores NaNs; zero-variance columns/rows become NaN.

    Raises:
      ValueError: if covariates does not have one row per subject.
    """
    data = np.asarray(data, dtype=float)

    if perform_regression and covariates is not None:
        # Ensure covariates is 2D (n_subjects, n_covariates)
        covariates = np.asarray(covariates, dtype=float)
        if covariates.ndim == 1:
            covariates = covariates.reshape(-1, 1)
        if covariates.shape[0] != data.shape[0]:
            raise ValueError(
                f"covariates has {covariates.shape[0]} rows but data has {data.shape[0]} subjects")

        # Float residual array pre-filled with NaN (invalid rows / unfittable columns stay NaN).
        residuals = np.full(data.shape, np.nan)
        n_cov = covariates.shape[1]
        # Rows with finite covariates are the same for every column: compute once.
        cov_finite = np.all(np.isfinite(covariates), axis=1)

        # For each feature column (e.g. each eccentricity), a multiple linear regression with intercept.
        for i in range(data.shape[1]):
            y = data[:, i]
            valid_mask = np.isfinite(y) & cov_finite
            if np.sum(valid_mask) < n_cov + 1:
                if verbose:
                    print(f"Feature {i}: Not enough valid data points for regression.")
                continue

            # Design matrix with intercept column.
            X = covariates[valid_mask, :]
            X = np.column_stack((np.ones(X.shape[0]), X))
            y_valid = y[valid_mask]
            # Least squares: y_valid ≈ X @ beta
            beta = np.linalg.lstsq(X, y_valid, rcond=None)[0]
            y_hat = X @ beta
            res = y_valid - y_hat
            residuals[valid_mask, i] = res

            if verbose:
                # R-squared for this regression (if defined)
                ss_res = np.sum(res**2)
                ss_tot = np.sum((y_valid - np.mean(y_valid))**2)
                r_squared = 1 - ss_res / ss_tot if ss_tot != 0 else np.nan
                print(f"Feature {i} regression coefficients: {beta}")
                print(f"Feature {i} R-squared: {r_squared}")

                # Diagnostic plot: Predicted vs. Actual
                plt.figure()
                plt.scatter(y_valid, y_hat, c='blue', label='Fitted values')
                plt.plot(y_valid, y_valid, 'r--', label='Ideal')
                plt.xlabel('Actual values')
                plt.ylabel('Predicted values')
                plt.title(f'Feature {i}: Predicted vs. Actual')
                plt.legend()
                plt.show()

                # Diagnostic plot: Residuals vs. Predicted
                plt.figure()
                plt.scatter(y_hat, res, c='green', label='Residuals')
                plt.axhline(0, color='red', linestyle='--')
                plt.xlabel('Predicted values')
                plt.ylabel('Residuals')
                plt.title(f'Feature {i}: Residuals vs. Predicted')
                plt.legend()
                plt.show()
        data = residuals

        # Additional diagnostic: correlation matrix of the residualized features (across columns).
        if verbose:
            try:
                corr_matrix = np.corrcoef(data, rowvar=False)
                plt.figure()
                plt.imshow(corr_matrix, interpolation='nearest', cmap='viridis')
                plt.title("Correlation Matrix among Residualized Features")
                plt.colorbar()
                plt.show()
            except Exception as e:
                # Best-effort diagnostic only: report and continue.
                print("Could not compute correlation matrix for diagnostics:", e)

    # Standardize using only valid (non-NaN) values.
    if standardization == 'inter':
        mean = np.nanmean(data, axis=0, keepdims=True)
        std = np.nanstd(data, axis=0, keepdims=True)
        data = (data - mean) / std
    elif standardization == 'intra':
        mean = np.nanmean(data, axis=1, keepdims=True)
        std = np.nanstd(data, axis=1, keepdims=True)
        data = (data - mean) / std

    return data

def preprocess_functional_data(direction, 
                               standardization: str = 'inter', 
                               toLog: bool = True, 
                               covariates: Optional[np.ndarray] = None, 
                               perform_regression: bool = False,
                               verbose: bool = False,
                               include_nonfit: bool = False) -> Dict[str, np.ndarray]:
    """
    Preprocess functional data (cone density and layer thicknesses) for a given direction (X or Y)
    by optionally removing covariate effects and (Z-)standardizing each feature with
    `preprocess_functional_feature`.

    Reads the module-level `subjects_data` and `layer_names`.

    Parameters:
      direction: object with a .value attribute giving the attribute suffix (e.g. 'X' or 'Y').
      standardization: 'inter' or 'intra' standardization mode.
      toLog: if True, take the natural log of the fitted cone density before processing.
      covariates: 1D or 2D array of covariate values to regress out.
      perform_regression: if True, regress out the covariates before standardizing.
      verbose: if True, print regression diagnostics and produce diagnostic plots.
      include_nonfit: if True, also process the raw `density_{direction}` values (not
        log-transformed) and return them under the key 'nonfit'. Off by default, so the raw
        density is not processed unless requested.

    Returns:
      dict with 'cones' (fitted cone density), one entry per name in `layer_names`, and
      'nonfit' only when include_nonfit is True.
    """
    suffix = direction.value

    def _collect(attribute: str) -> np.ndarray:
        # Stack one attribute across subjects (one row per subject).
        return np.array([getattr(s, attribute) for s in subjects_data])

    def _prepare(values: np.ndarray) -> np.ndarray:
        # Same covariate-removal and standardization settings for every feature.
        return preprocess_functional_feature(
            data=values,
            covariates=covariates,
            standardization=standardization,
            perform_regression=perform_regression,
            verbose=verbose,
        )

    layer_fds = {layer: _prepare(_collect(f'{layer}_{suffix}')) for layer in layer_names}

    cone_density = _collect(f'density_fit_{suffix}')
    if toLog:
        cone_density = np.log(cone_density)
    processed = {'cones': _prepare(cone_density), **layer_fds}

    if include_nonfit:
        processed['nonfit'] = _prepare(_collect(f'density_{suffix}'))

    return processed

def compute_correlations_eccwise(direction, 
                                 eccs: np.ndarray, 
                                 correlation_fun: Callable[[np.ndarray, np.ndarray], Tuple[float, float]] = lambda cd, lt: stats.spearmanr(cd, lt, nan_policy='omit'),
                                 covariates: Optional[np.ndarray] = None, 
                                 perform_regression: bool = False,
                                 verbose: bool = False) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Correlate cone density with every other layer, one eccentricity at a time.

    The features for ``direction`` come from ``preprocess_functional_data``
    with inter-individual standardization, optionally after regressing out
    ``covariates``. For every layer except ``'cones'`` and for every
    eccentricity index, ``correlation_fun`` is applied to the cone-density
    column and the layer column (presumably one value per subject). The
    default is Spearman's rank correlation with NaNs omitted.

    Parameters:
      direction: forwarded to ``preprocess_functional_data`` (an object with a
        ``.value`` attribute, e.g. 'X' or 'Y')
      eccs: eccentricity grid; only its length is used, and it must not exceed
        the number of columns of the preprocessed arrays
      correlation_fun: callable ``(a, b) -> (correlation, p_value)`` on two 1D arrays
      covariates: 1D or 2D covariate values forwarded to the preprocessing step
      perform_regression: if True, covariate effects are removed before correlating
      verbose: forwarded to the preprocessing step for diagnostic output

    Returns:
      ``{layer: (correlations, p_values)}`` for each layer other than
      ``'cones'``; both arrays are float arrays of length ``len(eccs)``.
    """
    features = preprocess_functional_data(
        direction=direction,
        standardization='inter',
        covariates=covariates,
        perform_regression=perform_regression,
        verbose=verbose,
    )

    n_points = len(eccs)
    correlations_by_layer = {}
    for name, layer_matrix in features.items():
        if name == 'cones':
            continue

        rho = np.zeros(n_points)
        p_vals = np.zeros(n_points)
        # Column ``col`` holds the values of every subject at one eccentricity.
        for col in range(n_points):
            rho[col], p_vals[col] = correlation_fun(features['cones'][:, col], layer_matrix[:, col])

        correlations_by_layer[name] = (rho, p_vals)

    return correlations_by_layer


### Retinal Scheme Heatmap

In [None]:
def create_circular_heatmap_multi_direction(results_dict, layer_name, 
                                           directions=None, 
                                           angular_sectors=None,
                                           show_significance=True,
                                           cmap=None):
    """
    Create a circular heatmap of mean slopes for one layer across directions.

    The unit disc is divided into angular sectors (one per direction) and
    three rings ("central", "inner", "outer"). Each wedge is coloured by the
    mean slope stored for that direction and ring. It is labelled with that
    value and, optionally, with a significance marker.

    Parameters
    -----------
    results_dict : dict
        Maps a direction name to its ``layer_results`` dict, e.g.
        ``{"Superior": layer_results_superior, "Nasal": layer_results_nasal, ...}``.
        ``layer_results[layer_name]`` is a list of region dicts with keys:
        ``"name"`` (ring name), ``"mean"`` (mean slope; ``-1`` means "no
        data") and, optionally, ``"slopes"`` (the individual slopes).
    layer_name : str
        Name of the layer to plot.
    directions : list of str, optional
        Not used; accepted for backward compatibility. The sectors that are
        drawn come from ``angular_sectors``.
    angular_sectors : list of dict, optional
        Sector definitions. Each has a ``"name"`` that must match a key of
        ``results_dict``, plus ``"start"`` and ``"end"`` in degrees,
        counter-clockwise from the +x axis. ``end < start`` wraps through 0.
        The default puts Superior at the top, Nasal on the right, Inferior
        at the bottom and Temporal on the left.
    show_significance : bool, default True
        If True, annotate wedges whose region has at least two ``"slopes"``.
        The marker (``*``/``**``/``***``/``ns``) comes from a one-sample
        t-test of those slopes against zero.
    cmap : matplotlib colormap or str, optional
        Colormap for the wedges and the colorbar. Defaults to ``plasma``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the heatmap.
    """
    from scipy.stats import ttest_1samp

    no_data = -1  # "mean" value the slope summaries store for an empty ring

    def significance_marker(p_value):
        """Map a p-value to the star notation explained in the legend."""
        for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
            if p_value < cutoff:
                return marker
        return 'ns'  # not significant; NaN p-values also end up here

    fig, ax = plt.subplots(figsize=(10, 10))

    if angular_sectors is None:
        angular_sectors = [
            {"name": "Superior", "start": 45, "end": 135},
            {"name": "Nasal", "start": 315, "end": 45},
            {"name": "Inferior", "start": 225, "end": 315},
            {"name": "Temporal", "start": 135, "end": 225},
        ]

    # Ring boundaries on a 0-3 scale, normalised to the unit disc.
    max_radius = 3.0
    norm_bounds = {
        "central": (0, 0.5 / max_radius),
        "inner": (0.5 / max_radius, 1.5 / max_radius),
        "outer": (1.5 / max_radius, 3.0 / max_radius),
    }

    # Collect every available mean slope for this layer to set the colour range.
    all_slopes = []
    for layer_results in results_dict.values():
        for region in layer_results.get(layer_name, []):
            if region["mean"] != no_data:
                all_slopes.append(region["mean"])

    # Use a fixed +/-0.3 range unless the data reach beyond +/-0.5. Values in
    # between fall outside the range and get the colormap's end colours.
    if all_slopes:
        lowest, highest = min(all_slopes), max(all_slopes)
        vmin = -0.3 if lowest > -0.5 else lowest
        vmax = 0.3 if highest < 0.5 else highest
    else:
        vmin, vmax = -1, 1
    norm = plt.Normalize(vmin=vmin, vmax=vmax)  # shared by wedges and colorbar
    mid_value = (vmin + vmax) / 2

    if cmap is None:
        cmap = plt.cm.plasma
    elif isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # Draw one wedge per (sector, ring).
    for sector in angular_sectors:
        direction_name = sector["name"]
        if direction_name not in results_dict:
            print(f"Warning: Direction '{direction_name}' not found in results_dict")
            continue

        layer_results = results_dict[direction_name]
        if layer_name not in layer_results:
            print(f"Warning: Layer '{layer_name}' not found in {direction_name} results")
            continue

        regions = layer_results[layer_name]

        # Sectors crossing 0 degrees (end < start) are unwrapped past 360.
        theta1 = sector["start"]
        theta2 = sector["end"] + 360 if sector["end"] < theta1 else sector["end"]
        angle_mid = np.radians((theta1 + theta2) / 2)

        for region_name, (r_inner, r_outer) in norm_bounds.items():
            region = next((r for r in regions if r["name"] == region_name), None)
            slope_value = no_data if region is None else region["mean"]
            has_data = slope_value != no_data

            ax.add_patch(patches.Wedge(
                center=(0, 0),
                r=r_outer,
                theta1=theta1,
                theta2=theta2,
                width=r_outer - r_inner,
                facecolor=cmap(norm(slope_value)) if has_data else 'lightgray',
                edgecolor='white',
                linewidth=2,
            ))

            if not has_data:
                continue

            # Value label, optionally followed by the significance marker.
            text = f'{slope_value:.2f}'
            if show_significance:
                region_slopes = region.get("slopes", [])
                if len(region_slopes) > 1:
                    p_value = ttest_1samp(region_slopes, 0).pvalue
                    text += '\n ' + significance_marker(p_value)

            r_mid = (r_inner + r_outer) / 2
            # White text on colours far from the middle of the range, black otherwise.
            ax.text(r_mid * np.cos(angle_mid), r_mid * np.sin(angle_mid), text,
                    ha='center', va='center',
                    fontsize=8, fontweight='bold',
                    color='white' if abs(slope_value - mid_value) > (vmax - vmin) * 0.3 else 'black')

    # Ring outlines
    for _, r_outer in norm_bounds.values():
        ax.add_patch(plt.Circle((0, 0), r_outer, fill=False, color='white', linewidth=2))

    # Radial lines separating the sectors
    for sector in angular_sectors:
        angle = np.radians(sector["start"])
        ax.plot([0, 1.1 * np.cos(angle)], [0, 1.1 * np.sin(angle)], 'white', linewidth=2)

    ax.set_xlim(-1.3, 1.3)
    ax.set_ylim(-1.35, 1.3)  # extra room at the bottom for the legend text
    ax.set_aspect('equal')
    ax.axis('off')

    ax.set_title(f'Mean Slopes Heatmap - {layer_name}', fontsize=16, pad=20)

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Mean Slope', rotation=270, labelpad=20)

    # Direction labels just outside the disc, at each sector's angular midpoint.
    label_radius = 1.15
    for sector in angular_sectors:
        end = sector["end"] + 360 if sector["end"] < sector["start"] else sector["end"]
        angle = np.radians((sector["start"] + end) / 2)
        ax.text(label_radius * np.cos(angle), label_radius * np.sin(angle), sector["name"],
                ha='center', va='center', fontsize=12, fontweight='bold')

    legend_text = "* p<0.05, ** p<0.01, *** p<0.001, ns = not significant"
    ax.text(0, -1.25, legend_text, ha='center', va='center',
            fontsize=9, style='italic', color='gray')

    plt.tight_layout()
    return fig


def create_comparison_grid(results_dict, layers_to_plot, 
                          directions=None, angular_sectors=None):
    """
    Create a grid of circular heatmaps, one panel per layer.

    Parameters
    -----------
    results_dict : dict
        Dictionary with direction names as keys and layer_results as values.
    layers_to_plot : iterable of str
        Layer names to plot, in panel order. Lists, dict keys and generators
        are all accepted.
    directions : list of str, optional
        Forwarded to ``create_circular_heatmap_single_axis``.
    angular_sectors : list of dict, optional
        Forwarded to ``create_circular_heatmap_single_axis``.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with at most 3 panels per row; unused panels are hidden.

    Raises
    ------
    ValueError
        If ``layers_to_plot`` is empty.
    """
    layers_to_plot = list(layers_to_plot)
    n_layers = len(layers_to_plot)
    if n_layers == 0:
        raise ValueError("layers_to_plot must contain at least one layer name")

    n_cols = min(3, n_layers)  # Maximum 3 columns
    n_rows = -(-n_layers // n_cols)  # ceiling division

    # squeeze=False always yields a 2D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, layer_name in zip(axes, layers_to_plot):
        create_circular_heatmap_single_axis(results_dict, layer_name, ax,
                                            directions, angular_sectors)

    # Hide the panels left over in the last row.
    for ax in axes[n_layers:]:
        ax.axis('off')

    plt.tight_layout()
    return fig


def create_circular_heatmap_single_axis(results_dict, layer_name, ax,
                                       directions=None, angular_sectors=None,
                                       cmap=None):
    """
    Draw a compact circular slope heatmap for one layer onto ``ax``.

    Only the coloured wedges, the ring outlines and a title are drawn. There
    is no colorbar, value label or direction label, which suits grids of
    small panels.

    Parameters
    -----------
    results_dict : dict
        Maps a direction name to a ``layer_results`` dict.
        ``layer_results[layer_name]`` is a list of region dicts with
        ``"name"`` (ring name) and ``"mean"`` (mean slope; ``-1`` means
        "no data").
    layer_name : str
        Layer to draw; also used as the panel title.
    ax : matplotlib Axes
        Axis to draw on.
    directions : list of str, optional
        Not used; accepted for backward compatibility.
    angular_sectors : list of dict, optional
        Sector definitions with ``"name"``, ``"start"`` and ``"end"`` in
        degrees (``end < start`` wraps through 0). The default matches
        ``create_circular_heatmap_multi_direction``: Superior top, Nasal
        right, Inferior bottom, Temporal left.
    cmap : matplotlib colormap or str, optional
        Colormap for the wedges. Defaults to ``RdBu_r``.
    """
    no_data = -1  # "mean" value the slope summaries store for an empty ring

    if angular_sectors is None:
        angular_sectors = [
            {"name": "Superior", "start": 45, "end": 135},
            {"name": "Nasal", "start": 315, "end": 45},
            {"name": "Inferior", "start": 225, "end": 315},
            {"name": "Temporal", "start": 135, "end": 225},
        ]

    if cmap is None:
        cmap = plt.cm.RdBu_r
    elif isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # Ring boundaries on a 0-3 scale, normalised to the unit disc.
    norm_bounds = {
        "central": (0, 0.5/3.0),
        "inner": (0.5/3.0, 1.5/3.0),
        "outer": (1.5/3.0, 3.0/3.0)
    }

    # Collect every available mean slope for this layer to set the colour range.
    all_slopes = []
    for layer_results in results_dict.values():
        for region in layer_results.get(layer_name, []):
            if region["mean"] != no_data:
                all_slopes.append(region["mean"])

    # Fixed +/-0.3 range unless the data reach beyond +/-0.5.
    if all_slopes:
        lowest, highest = min(all_slopes), max(all_slopes)
        vmin = -0.3 if lowest > -0.5 else lowest
        vmax = 0.3 if highest < 0.5 else highest
    else:
        vmin, vmax = -1, 1
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    for sector in angular_sectors:
        direction_name = sector["name"]
        if direction_name not in results_dict:
            continue

        layer_results = results_dict[direction_name]
        if layer_name not in layer_results:
            continue

        regions = layer_results[layer_name]

        # Sectors crossing 0 degrees (end < start) are unwrapped past 360.
        theta1 = sector["start"]
        theta2 = sector["end"] + 360 if sector["end"] < theta1 else sector["end"]

        for region_name, (r_inner, r_outer) in norm_bounds.items():
            region = next((r for r in regions if r["name"] == region_name), None)
            slope_value = no_data if region is None else region["mean"]

            ax.add_patch(patches.Wedge(
                center=(0, 0),
                r=r_outer,
                theta1=theta1,
                theta2=theta2,
                width=r_outer - r_inner,
                facecolor=cmap(norm(slope_value)) if slope_value != no_data else 'lightgray',
                edgecolor='white',
                linewidth=1.5,
            ))

    # Ring outlines
    for _, r_outer in norm_bounds.values():
        ax.add_patch(plt.Circle((0, 0), r_outer, fill=False, color='white', linewidth=1.5))

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(layer_name, fontsize=12, pad=10)

### Analysis

In [None]:
# Axial length of every subject, passed below as the covariate to regress out.
axial_lengths = np.array([s.axial_length for s in subjects_data], dtype=float)

# Covariate matrix with one column per covariate: shape (n_subjects, 1).
# The trailing comma matters: np.column_stack expects a sequence of arrays,
# and `(axial_lengths)` alone is not a tuple (it would yield a (1, n) row).
covariates = np.column_stack((axial_lengths,))

resultsX = compute_correlations_eccwise(Direction.X, eccs, covariates=axial_lengths, perform_regression=True, verbose=False)
resultsY = compute_correlations_eccwise(Direction.Y, eccs, covariates=axial_lengths, perform_regression=True, verbose=False)

In [None]:
# Y-direction correlation profile with a 0.05 p-value threshold.
plot_correlations_eccwise(
    resultsY, eccs, Direction.Y,
    abs_=False,
    pv_threshold=0.05,
)


In [None]:
# Same Y-direction plot with pv_threshold=1 (presumably no significance filtering).
plot_correlations_eccwise(
    resultsY, eccs, Direction.Y,
    abs_=False,
    pv_threshold=1,
)


In [None]:
# X-direction correlation profile with a 0.05 p-value threshold.
plot_correlations_eccwise(
    resultsX, eccs, Direction.X,
    abs_=False,
    pv_threshold=0.05,
)


In [None]:
# Same X-direction plot with pv_threshold=1 (presumably no significance filtering).
plot_correlations_eccwise(
    resultsX, eccs, Direction.X,
    abs_=False,
    pv_threshold=1,
)

### Circular Heatmap

In [None]:

from scipy.stats import linregress
from scipy.stats import zscore

# Superior sector: negative eccentricities of the Y results.
results = resultsY

# Keep eccentricities whose correlation p-value is strictly below this value;
# 1 keeps essentially all of them.
threshold = 1
layers_to_plot = results.keys()

# Range of eccentricity (same units as `eccs`) to include
range_start = -10.0
range_end = 0

# Ring boundaries on the ecc * MM_PER_DEGREE scale (start inclusive, end exclusive).
circles = [
    {"name": "central", "start": -0.5, "end": 0.0},
    {"name": "inner", "start": -1.5, "end": -0.5},
    {"name": "outer", "start": -3.0, "end": -1.5},
]

layer_results_superior = {}

for layer_name, (correlations, p_values) in results.items():
    if layer_name not in layers_to_plot:
        continue

    # Per-ring list of z-scored regression slopes, one per eccentricity.
    ring_slopes = {circle["name"]: [] for circle in circles}

    for ecc, corr, p_val in zip(eccs, correlations, p_values):
        # Written as `not <` so that NaN p-values are skipped as well.
        if ecc < range_start or ecc > range_end or not p_val < threshold:
            continue

        # Convert to the ring scale before matching, like the other sectors.
        ecc_scaled = ecc * MM_PER_DEGREE
        ring = next((c["name"] for c in circles if c["start"] <= ecc_scaled < c["end"]), None)
        if ring is None:
            continue  # outside the three rings: does not contribute to the heatmap

        # Slope of the fit on z-scored data (numerically Pearson's r of x_data vs y_data).
        x_data, y_data = get_data_from_range(layer_name, ecc - 0.05, ecc + 0.05)
        ring_slopes[ring].append(linregress(zscore(x_data), zscore(y_data)).slope)

    # Summaries in the format the heatmap functions read; -1 marks an empty ring.
    layer_results_superior[layer_name] = [
        {"name": name, "slopes": values, "mean": np.mean(values) if values else -1}
        for name, values in ring_slopes.items()
    ]




In [None]:
from scipy.stats import linregress
from scipy.stats import zscore

# Inferior sector: positive eccentricities of the Y results.
results = resultsY

# Keep eccentricities whose correlation p-value is strictly below this value;
# 1 keeps essentially all of them.
threshold = 1
layers_to_plot = results.keys()

# Range of eccentricity (same units as `eccs`) to include
range_start = 0.0
range_end = 10.0

# Ring boundaries on the ecc * MM_PER_DEGREE scale (start inclusive, end exclusive).
circles = [
    {"name": "central", "start": 0.0, "end": 0.5},
    {"name": "inner", "start": 0.5, "end": 1.5},
    {"name": "outer", "start": 1.5, "end": 3.0},
]

layer_results_inferior = {}

for layer_name, (correlations, p_values) in results.items():
    if layer_name not in layers_to_plot:
        continue

    # Per-ring list of z-scored regression slopes, one per eccentricity.
    ring_slopes = {circle["name"]: [] for circle in circles}

    for ecc, corr, p_val in zip(eccs, correlations, p_values):
        # Written as `not <` so that NaN p-values are skipped as well.
        if ecc < range_start or ecc > range_end or not p_val < threshold:
            continue

        ecc_scaled = ecc * MM_PER_DEGREE
        ring = next((c["name"] for c in circles if c["start"] <= ecc_scaled < c["end"]), None)
        if ring is None:
            continue  # outside the three rings: does not contribute to the heatmap

        # Slope of the fit on z-scored data (numerically Pearson's r of x_data vs y_data).
        x_data, y_data = get_data_from_range(layer_name, ecc - 0.05, ecc + 0.05)
        ring_slopes[ring].append(linregress(zscore(x_data), zscore(y_data)).slope)

    # Summaries in the format the heatmap functions read; -1 marks an empty ring.
    layer_results_inferior[layer_name] = [
        {"name": name, "slopes": values, "mean": np.mean(values) if values else -1}
        for name, values in ring_slopes.items()
    ]




In [None]:
from scipy.stats import linregress
from scipy.stats import zscore

# Temporal sector: negative eccentricities of the X results.
results = resultsX

# Keep eccentricities whose correlation p-value is strictly below this value;
# 1 keeps essentially all of them.
threshold = 1
layers_to_plot = results.keys()

# Range of eccentricity (same units as `eccs`) to include
range_start = -10.0
range_end = 0

# Ring boundaries on the ecc * MM_PER_DEGREE scale (start inclusive, end exclusive).
circles = [
    {"name": "central", "start": -0.5, "end": 0.0},
    {"name": "inner", "start": -1.5, "end": -0.5},
    {"name": "outer", "start": -3.0, "end": -1.5},
]

layer_results_temporal = {}

for layer_name, (correlations, p_values) in results.items():
    if layer_name not in layers_to_plot:
        continue

    # Per-ring list of z-scored regression slopes, one per eccentricity.
    ring_slopes = {circle["name"]: [] for circle in circles}

    for ecc, corr, p_val in zip(eccs, correlations, p_values):
        # Written as `not <` so that NaN p-values are skipped as well.
        if ecc < range_start or ecc > range_end or not p_val < threshold:
            continue

        ecc_scaled = ecc * MM_PER_DEGREE
        ring = next((c["name"] for c in circles if c["start"] <= ecc_scaled < c["end"]), None)
        if ring is None:
            continue  # outside the three rings: does not contribute to the heatmap

        # Slope of the fit on z-scored data (numerically Pearson's r of x_data vs y_data).
        x_data, y_data = get_data_from_range(layer_name, ecc - 0.05, ecc + 0.05)
        ring_slopes[ring].append(linregress(zscore(x_data), zscore(y_data)).slope)

    # Summaries in the format the heatmap functions read; -1 marks an empty ring.
    layer_results_temporal[layer_name] = [
        {"name": name, "slopes": values, "mean": np.mean(values) if values else -1}
        for name, values in ring_slopes.items()
    ]




In [None]:
from scipy.stats import linregress
from scipy.stats import zscore

# Nasal sector: positive eccentricities of the X results (the temporal sector
# uses the negative X side; superior/inferior use the Y results).
results = resultsX

# Keep eccentricities whose correlation p-value is strictly below this value;
# 1 keeps essentially all of them.
threshold = 1
layers_to_plot = results.keys()

# Range of eccentricity (same units as `eccs`) to include
range_start = 0
range_end = 10

# Ring boundaries on the ecc * MM_PER_DEGREE scale (start inclusive, end exclusive).
circles = [
    {"name": "central", "start": 0.0, "end": 0.5},
    {"name": "inner", "start": 0.5, "end": 1.5},
    {"name": "outer", "start": 1.5, "end": 3.0},
]

layer_results_nasal = {}

for layer_name, (correlations, p_values) in results.items():
    if layer_name not in layers_to_plot:
        continue

    # Per-ring list of z-scored regression slopes, one per eccentricity.
    ring_slopes = {circle["name"]: [] for circle in circles}

    for ecc, corr, p_val in zip(eccs, correlations, p_values):
        # Written as `not <` so that NaN p-values are skipped as well.
        if ecc < range_start or ecc > range_end or not p_val < threshold:
            continue

        ecc_scaled = ecc * MM_PER_DEGREE
        ring = next((c["name"] for c in circles if c["start"] <= ecc_scaled < c["end"]), None)
        if ring is None:
            continue  # outside the three rings: does not contribute to the heatmap

        # Slope of the fit on z-scored data (numerically Pearson's r of x_data vs y_data).
        x_data, y_data = get_data_from_range(layer_name, ecc - 0.05, ecc + 0.05)
        ring_slopes[ring].append(linregress(zscore(x_data), zscore(y_data)).slope)

    # Summaries in the format the heatmap functions read; -1 marks an empty ring.
    layer_results_nasal[layer_name] = [
        {"name": name, "slopes": values, "mean": np.mean(values) if values else -1}
        for name, values in ring_slopes.items()
    ]




### Extra: Heatmap Visualization

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Rectangle, Polygon
from matplotlib.collections import PolyCollection
from typing import Dict, Tuple, List, Optional, Union
import matplotx
from scipy.ndimage import gaussian_filter1d
from scipy.signal import savgol_filter

def _taper_width(distance_from_center: float, max_width: float,
                 taper_style: str = 'linear') -> float:
    """Return the cross-bar thickness at a given distance from the origin.

    Bars are ``max_width`` thick once they are at least ``max_width / 2`` away
    from the centre. Closer in, the thickness shrinks towards zero so the
    horizontal and vertical arms do not merge into a solid block at the origin.
    ``taper_style`` selects the profile ('linear', 'quadratic', 'exponential');
    any other value falls back to linear.
    """
    half_width = max_width / 2
    if distance_from_center >= half_width:
        return max_width
    t = distance_from_center / half_width  # 0 at the origin, approaching 1 at the taper edge
    if taper_style == 'quadratic':
        return max_width * t ** 2
    if taper_style == 'exponential':
        return max_width * (1 - np.exp(-3 * t))
    return max_width * t


def _ecc_bin_edges(eccs: np.ndarray) -> np.ndarray:
    """Return ``len(eccs) + 1`` bin edges that tile the eccentricity axis.

    Interior edges are midpoints between neighbouring samples. The two outer
    edges extend by half of the first/last spacing, so every sample gets a bin
    and adjacent bars touch without gaps. Requires at least two samples.
    """
    midpoints = (eccs[:-1] + eccs[1:]) / 2
    first = eccs[0] - (eccs[1] - eccs[0]) / 2
    last = eccs[-1] + (eccs[-1] - eccs[-2]) / 2
    return np.concatenate(([first], midpoints, [last]))


_SMOOTHING_METHODS = ('gaussian', 'savgol', 'moving_average')


def _resolve_smoothing(smoothing: Union[bool, str]) -> Optional[str]:
    """Normalise the ``smoothing`` argument to a method name or ``None``.

    ``True`` means Gaussian smoothing, and any falsy value disables smoothing.
    Raises ValueError for an unrecognised method instead of silently skipping
    the smoothing step.
    """
    if smoothing is True:
        return 'gaussian'
    if not smoothing:
        return None
    if smoothing in _SMOOTHING_METHODS:
        return smoothing
    raise ValueError(
        f"Unknown smoothing {smoothing!r}; expected False, True or one of {_SMOOTHING_METHODS}"
    )


def _smooth_correlation(values: np.ndarray, method: Optional[str],
                        sigma: float, window: int) -> np.ndarray:
    """Smooth a 1-D correlation profile and clip it back into [-1, 1].

    Returns ``values`` unchanged (and unclipped) when ``method`` is None.
    Otherwise the output always has the same length as the input, so it stays
    aligned with ``eccs``.

    Raises ValueError if ``window < 1`` for the moving average.
    """
    if method is None:
        return values
    # gaussian_filter1d returns the input dtype by default; force float so
    # integer inputs are not truncated.
    values = np.asarray(values, dtype=float)
    if method == 'gaussian':
        smoothed = gaussian_filter1d(values, sigma=sigma)
    elif method == 'savgol':
        # Use an odd window that is no longer than len - 1. Clamping can make an
        # odd window even, so step down to the nearest odd length afterwards.
        win = window if window % 2 == 1 else window + 1
        win = min(win, len(values) - 1)
        if win % 2 == 0:
            win -= 1
        smoothed = savgol_filter(values, win, min(3, win - 1)) if win >= 3 else values
    else:  # 'moving_average'
        if window < 1:
            raise ValueError(f"smoothing_window must be >= 1, got {window}")
        kernel = np.ones(window) / window
        # Pad (w//2, (w-1)//2) so the 'valid' convolution is exactly len(values)
        # long for even windows too. Symmetric w//2 padding yields len+1.
        padded = np.pad(values, (window // 2, (window - 1) // 2), mode='edge')
        smoothed = np.convolve(padded, kernel, mode='valid')
    # Smoothing may overshoot slightly; correlations must stay in [-1, 1].
    return np.clip(smoothed, -1, 1)


def _smoothing_labels(method: Optional[str], sigma: float,
                      window: int) -> Tuple[Optional[str], str]:
    """Return (title annotation or None, filename suffix) describing the smoothing."""
    if method == 'gaussian':
        return f'(Gaussian smoothing, σ={sigma})', f"_gaussian_s{sigma}"
    if method == 'savgol':
        return f'(Savitzky-Golay smoothing, w={window})', f"_savgol_w{window}"
    if method == 'moving_average':
        return f'(Moving average, w={window})', f"_movavg_w{window}"
    return None, ""


def _bar_rectangles(left_bounds: np.ndarray, lengths: np.ndarray,
                    thicknesses: np.ndarray, indices: np.ndarray,
                    horizontal: bool) -> list:
    """Build one Rectangle per selected index for one arm of the cross.

    Horizontal bars span ``[left, left + length]`` along x and are centred on
    y = 0. Vertical bars use the same geometry transposed onto the y axis.
    """
    rects = []
    for i in indices:
        half = thicknesses[i] / 2
        if horizontal:
            rects.append(Rectangle((left_bounds[i], -half), lengths[i], thicknesses[i]))
        else:
            rects.append(Rectangle((-half, left_bounds[i]), thicknesses[i], lengths[i]))
    return rects


def plot_cross_heatmap(results_x: Dict[str, Tuple[np.ndarray, np.ndarray]], 
                       results_y: Dict[str, Tuple[np.ndarray, np.ndarray]], 
                       eccs: np.ndarray,
                       layers_to_plot: List[str] | None = None,
                       corr_name: str = 'Spearman',
                       pv_threshold: float = 0.05,
                       cmap: str = 'RdBu_r',
                       vmin: float = -1.0,
                       vmax: float = 1.0,
                       figsize: Tuple[float, float] = (10, 10),
                       bar_width: float = 1.5,
                       save_path: str | None = None,
                       dpi: int = 300,
                       show_regions: bool = True,
                       apply_style: bool = True,
                       region_cmap: str = 'gray',
                       region_alpha: float = 0.3,
                       taper_style: str = 'linear',
                       smoothing: Union[bool, str] = False,
                       smoothing_sigma: float = 1.0,
                       smoothing_window: int = 5,
                       significance_threshold: float = 0.05):
    """
    Plot correlation results as cross-shaped heatmaps with a tapered centre, one figure per layer.

    The horizontal arm shows ``results_x[layer]`` along the x axis and the
    vertical arm shows ``results_y[layer]`` along the y axis. Each sample in
    ``eccs`` becomes one coloured bar segment.

    Parameters:
    -----------
    results_x : Dict mapping layer name -> (correlations, p-values) for the x axis.
        Both arrays are assumed to have one entry per element of ``eccs``.
    results_y : Same as ``results_x``, for the y axis
    eccs : Array of eccentricities (degrees); at least two values are required
    layers_to_plot : List of layers to plot (if None, plot every key of ``results_x``)
    corr_name : Name of the correlation method (used in the title and colorbar label)
    pv_threshold : Bars with p-value > pv_threshold are not drawn; None draws all bars
    cmap : Colormap name for correlation data (e.g., 'RdBu_r', 'coolwarm', 'seismic')
    vmin, vmax : Color scale limits (defaults to -1.0, 1.0)
    figsize : Figure size for each plot
    bar_width : Width of the cross bars in eccentricity units
    save_path : Base path to save figures (optional); layer and smoothing are appended
    dpi : Resolution for saved figures
    show_regions : Whether to draw dashed retinal-region boundary circles and labels
    apply_style : Whether to render the whole figure inside the matplotx ``pacoty`` style
    region_cmap : Currently unused; kept for backward compatibility
    region_alpha : Currently unused; kept for backward compatibility
    taper_style : Style of tapering ('linear', 'quadratic', 'exponential')
    smoothing : False, True (= 'gaussian'), 'gaussian', 'savgol', or 'moving_average'
    smoothing_sigma : Standard deviation for Gaussian smoothing
    smoothing_window : Window size for Savitzky-Golay or moving average smoothing
    significance_threshold : Drawn bars with p-value above this are dimmed (default 0.05)

    Raises:
    -------
    ValueError : if ``eccs`` has fewer than two values, ``smoothing`` is not a
        recognised method, or a moving average is requested with ``smoothing_window < 1``.
    """
    # Function-scope imports: the notebook's import cell does not provide these.
    from contextlib import nullcontext
    from matplotlib.collections import PatchCollection

    eccs = np.asarray(eccs, dtype=float)
    if eccs.size < 2:
        raise ValueError("eccs must contain at least two eccentricities to size the bars")

    method = _resolve_smoothing(smoothing)
    smoothing_title, smoothing_suffix = _smoothing_labels(method, smoothing_sigma, smoothing_window)

    layers = list(results_x.keys()) if layers_to_plot is None else layers_to_plot

    # Region boundaries in degrees, converted from mm assuming 0.3 mm per degree.
    MM_PER_DEGREE = 0.3
    region_radii = [
        0.175/MM_PER_DEGREE,   # Foveola boundary
        0.25/MM_PER_DEGREE,    # FAZ boundary
        0.75/MM_PER_DEGREE,    # Fovea boundary
        1.25/MM_PER_DEGREE,    # Parafovea boundary
    ]
    label_positions = [
        (0, 0.1/MM_PER_DEGREE, 'Foveola', 0.7),
        (0, 0.21/MM_PER_DEGREE, 'FAZ', 0.7),
        (0, 0.5/MM_PER_DEGREE, 'Fovea', 0.7),
        (0, 1.0/MM_PER_DEGREE, 'Parafovea', 0.7),
        (0, 3.0, 'Perifovea', 0.7),
    ]

    # Bar geometry depends only on eccs, bar_width and taper_style. Compute it
    # once and share it between every layer and the dimming overlay.
    edges = _ecc_bin_edges(eccs)
    left_bounds = edges[:-1]
    lengths = np.diff(edges)
    thicknesses = np.array([_taper_width(abs(e), bar_width, taper_style) for e in eccs])

    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap_obj = plt.get_cmap(cmap)
    max_ecc = np.abs(eccs).max()

    def _shown(pvals: np.ndarray) -> np.ndarray:
        # Bars whose p-value exceeds pv_threshold are hidden; None shows all.
        if pv_threshold is None:
            return np.ones(pvals.shape, dtype=bool)
        return ~(pvals > pv_threshold)

    for layer in layers:
        if layer not in results_x or layer not in results_y:
            print(f"Skipping layer {layer} - not found in results")
            continue

        # Keep the style active for the whole figure, not just subplot creation.
        # Otherwise the labels, ticks and colorbar are created without it.
        style_ctx = plt.style.context(matplotx.styles.pacoty) if apply_style else nullcontext()
        with style_ctx:
            fig, ax = plt.subplots(figsize=figsize, dpi=100)

            corr_x, pval_x = map(np.asarray, results_x[layer])
            corr_y, pval_y = map(np.asarray, results_y[layer])
            corr_x = _smooth_correlation(corr_x, method, smoothing_sigma, smoothing_window)
            corr_y = _smooth_correlation(corr_y, method, smoothing_sigma, smoothing_window)

            ax.set_facecolor('black')
            fig.patch.set_facecolor('black')

            # Horizontal arm first, then vertical, so equal-zorder draw order
            # matches: data bars at zorder 2, dimming overlays at zorder 3.
            for corr, pvals, horizontal in ((corr_x, pval_x, True), (corr_y, pval_y, False)):
                shown = _shown(pvals)
                drawn = np.flatnonzero(shown)
                ax.add_collection(PatchCollection(
                    _bar_rectangles(left_bounds, lengths, thicknesses, drawn, horizontal),
                    facecolors=cmap_obj(norm(corr[drawn])),
                    edgecolors='none', zorder=2))

                # Dim drawn bars that are not significant at significance_threshold.
                dimmed = np.flatnonzero(shown & ~(pvals <= significance_threshold))
                ax.add_collection(PatchCollection(
                    _bar_rectangles(left_bounds, lengths, thicknesses, dimmed, horizontal),
                    facecolors='black', alpha=0.4, edgecolors='none', zorder=3))

            if show_regions:
                for radius in region_radii:
                    ax.add_patch(plt.Circle((0, 0), radius,
                                            fill=False,
                                            edgecolor='white',
                                            linewidth=1.0,
                                            alpha=0.7,
                                            linestyle='--',
                                            zorder=1))
                for x, y, label, alpha in label_positions:
                    ax.text(x, y, label,
                            color='white',
                            fontsize=8,
                            ha='center',
                            va='bottom',
                            alpha=alpha,
                            zorder=1)

            ax.set_xlim(-max_ecc, max_ecc)
            ax.set_ylim(-max_ecc, max_ecc)

            ax.set_xlabel('Horizontal Meridian (X-axis) Eccentricity [°]', fontsize=12, color='white')
            ax.set_ylabel('Vertical Meridian (Y-axis) Eccentricity [°]', fontsize=12, color='white')

            layer_name = layer if layer != 'nonfit' else 'Non-fit'
            title_parts = [f'{layer_name} - {corr_name} Correlation with Cone Density']
            if smoothing_title:
                title_parts.append(smoothing_title)
            ax.set_title(' '.join(title_parts), fontsize=14, pad=20, color='white')

            ax.tick_params(colors='white', which='both')
            for spine in ax.spines.values():
                spine.set_color('white')
            ax.grid(True, alpha=0.2, linestyle=':', linewidth=0.5)

            # Patch collections are not mappables, so build one for the colorbar.
            sm = plt.cm.ScalarMappable(cmap=cmap_obj, norm=norm)
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.02, fraction=0.046)
            cbar.set_label(f'{corr_name} Correlation Coefficient', fontsize=11, color='white')
            cbar.ax.tick_params(labelsize=10, colors='white')

            ax.set_aspect('equal')
            fig.tight_layout()

            if save_path:
                filename = f"{save_path}_{layer}_cross_heatmap_tapered{smoothing_suffix}.png"
                fig.savefig(filename, dpi=dpi, bbox_inches='tight', facecolor='black')
                print(f"Saved: {filename}")

            plt.show()
            plt.close(fig)


# Utility function to plot a single layer
def plot_single_layer_cross(results_x: Dict[str, Tuple[np.ndarray]], 
                           results_y: Dict[str, Tuple[np.ndarray]], 
                           eccs: np.ndarray,
                           layer: str,
                           **kwargs):
    """
    Draw the cross heatmap for exactly one layer.

    Thin wrapper around ``plot_cross_heatmap`` that fixes ``layers_to_plot``
    to ``[layer]`` and forwards every other keyword argument unchanged.
    Supplying ``layers_to_plot`` in ``kwargs`` raises TypeError (duplicate keyword).
    """
    only_this_layer = [layer]
    return plot_cross_heatmap(
        results_x=results_x,
        results_y=results_y,
        eccs=eccs,
        layers_to_plot=only_this_layer,
        **kwargs,
    )

In [None]:
# Plot every layer with Gaussian-smoothed profiles. pv_threshold=1 keeps every
# bar visible; bars with p > 0.05 are dimmed rather than hidden.
plot_cross_heatmap(
    resultsX,
    resultsY,
    eccs,
    pv_threshold=1,
    smoothing='gaussian',
    smoothing_sigma=2.0,
)