In [None]:
# Stop warnings
import warnings
warnings.filterwarnings("ignore")

# Import 
import os 
import time
import numpy as np
import pandas as pd
from scipy import stats
from scipy.spatial import ConvexHull
import neuropythy as ny
import matplotlib.pyplot as plt
import scipy.sparse.csgraph as cs
from scipy.spatial import cKDTree
from IPython.display import clear_output

In [None]:
def apply_wang(sub, h):
    """Interpolate the fsaverage Wang atlas onto one hemisphere of a subject.

    Parameters
    ----------
    sub : neuropythy subject
        Subject whose hemisphere (``sub.hemis[h]``) receives the atlas values.
    h : str
        Hemisphere name, ``'lh'`` or ``'rh'``.

    Returns
    -------
    The per-vertex atlas values on ``sub.hemis[h]``, as returned by the
    fsaverage hemisphere's ``interpolate`` method.

    Notes
    -----
    The fsaverage subject is loaded from ``apply_wang.fsaverage_path`` on
    every call.
    """
    atlas_values = apply_wang.atlas[h]
    target_hemi = sub.hemis[h]
    fsaverage_hemi = ny.freesurfer_subject(apply_wang.fsaverage_path).hemis[h]
    return fsaverage_hemi.interpolate(target_hemi, atlas_values)


# Wang atlas files shipped in neuropythy's library directory, keyed by hemisphere.
apply_wang.atlas = dict(
    (h, ny.load(os.path.join(ny.library_path(), 'data', 'fsaverage', 'surf',
                             f'{h}.wang15_mplbl.v1_0.mgz')))
    for h in ('lh', 'rh'))
# Machine-specific FreeSurfer fsaverage directory; adjust for your installation.
apply_wang.fsaverage_path = '/home/stone-ext1/freesurfer/subjects/fsaverage'

In [None]:
def compute_geodesic_area(vert_of_interest_idx, adjacency_matrix, max_distance, surfarea):
    """Sum the surface area of every vertex closer than `max_distance` to a source vertex.

    Parameters
    ----------
    vert_of_interest_idx : int
        Row/column index of the source vertex in `adjacency_matrix`.
    adjacency_matrix : scipy sparse matrix or ndarray, shape (N, N)
        Mesh graph; edge weights are used as path lengths by Dijkstra.
        NOTE(review): if this matrix is unweighted (every edge = 1), the
        distances are edge-hop counts rather than mm — TODO confirm for the
        mesh passed in.
    max_distance : float
        Strict upper bound on the shortest-path distance. Vertices at exactly
        `max_distance` are excluded.
    surfarea : ndarray, shape (N,)
        Per-vertex surface area, in the same order as the matrix rows.

    Returns
    -------
    surf_area_sum : scalar
        Sum of `surfarea` over the selected vertices (0 when none qualify).
    vert_dist_idx : ndarray of int
        Indices of the selected vertices. The source itself is included
        whenever `max_distance > 0`.
    """
    # `limit` stops the search once paths exceed max_distance; anything farther
    # is reported as inf, which the strict `<` test below rejects anyway. The
    # result is therefore identical to a full single-source search, but the
    # search no longer explores the whole mesh for every vertex.
    limit = max(float(max_distance), 0.0)  # scipy requires limit >= 0
    distances = cs.dijkstra(adjacency_matrix, directed=False,
                            indices=[vert_of_interest_idx], limit=limit)[0]
    vert_dist_idx = np.flatnonzero(distances < max_distance)
    surf_area_sum = np.sum(surfarea[vert_dist_idx])

    return surf_area_sum, vert_dist_idx

In [None]:
def compute_pRF_area(vert_of_interest_idx, vert_dist_idx, prf_x, prf_y):
    """Area of the convex hull spanned by the pRF centres of a vertex neighbourhood.

    Parameters
    ----------
    vert_of_interest_idx : int
        Index of the central vertex. It is not used in the computation and is
        kept for call-site symmetry with `compute_geodesic_area`.
    vert_dist_idx : array of int
        Neighbourhood indices, e.g. the second value returned by
        `compute_geodesic_area`; they index into `prf_x` / `prf_y`.
    prf_x, prf_y : 1-D ndarrays
        pRF centre coordinates, indexed like `vert_dist_idx`.

    Returns
    -------
    float or None
        Hull area (``ConvexHull.volume`` is the enclosed area for 2-D input).
        Returns None when fewer than 3 distinct, NaN-free centres remain or
        when Qhull fails. A failure is reported with a printed message.

    Raises
    ------
    ValueError
        If the stacked centres are not an (N, 2) array.
    """
    xy = np.column_stack((prf_x[vert_dist_idx], prf_y[vert_dist_idx]))
    # Drop rows with a NaN in either coordinate, then collapse duplicate centres.
    finite_rows = ~np.isnan(xy).any(axis=1)
    points = np.unique(xy[finite_rows], axis=0)

    well_formed = isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 2
    if not well_formed:
        raise ValueError("Input 'points' must be a NumPy array with shape (N, 2).")

    if len(points) < 3:
        return None

    try:
        # "QJ" joggles the input so collinear/degenerate sets still yield a hull.
        return ConvexHull(points, qhull_options="QJ").volume
    except Exception as e:
        print(f"ConvexHull failed: {e}")
        return None

In [None]:
# ---- Analysis settings ------------------------------------------------------
# Geodesic radius passed to compute_geodesic_area (units follow the mesh
# adjacency weights — presumably mm; TODO confirm).
max_distance = 5
# Minimum pRF variance explained for a vertex to be kept in an ROI.
r2_th = 0.1
# Hemispheres to process, in processing order.
hemis = ['rh', 'lh']
# Wang-atlas ROI names, listed in label-code order (code = position + 1).
rois = [
    "V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4", "VO1", "VO2",
    "PHC1", "PHC2", "TO2", "TO1", "LO2", "LO1", "V3B", "V3A",
    "IPS0", "IPS1", "IPS2", "IPS3", "IPS4", "IPS5", "SPL1", "FEF",
]
# ROI name -> integer label in the Wang atlas: 1 for "V1v" through 25 for "FEF".
roi_code = {roi_name: code for code, roi_name in enumerate(rois, start=1)}

In [None]:
# All subject IDs available in neuropythy's 'hcp_lines' dataset.
_hcp_lines_dataset = ny.data['hcp_lines']
allsub = _hcp_lines_dataset.subject_list

In [None]:
# ROI names indexed by Wang-atlas code: label_names[k] is the ROI whose code
# is k (1 = "V1v" ... 25 = "FEF"). Slot 0 is None (presumably "unlabelled").
label_names = (
    None,
    "V1v", "V1d", "V2v", "V2d", "V3v", "V3d",
    "hV4", "VO1", "VO2", "PHC1", "PHC2",
    "TO2", "TO1", "LO2", "LO1",
    "V3B", "V3A",
    "IPS0", "IPS1", "IPS2", "IPS3", "IPS4", "IPS5",
    "SPL1", "FEF",
)

In [None]:
# Where the per-subject TSV files are written (machine-specific absolute path).
tsv_dir = '/home/naxos2-raid27/wong0876/pRF_project/tsv'


def _load_hemi_data(sub, hemi):
    """Gather the per-vertex arrays needed for one hemisphere ('lh'/'rh') of `sub`.

    Returns a dict with the Wang-atlas labels, pRF x/y, variance explained,
    eccentricity, the white surface and the white-surface vertex areas.
    """
    cortex = getattr(sub, hemi)  # i.e. sub.lh or sub.rh
    return {
        'labels': apply_wang(sub, hemi),
        'prf_x': cortex.prop('prf_x'),
        'prf_y': cortex.prop('prf_y'),
        'prf_r2': cortex.prop('prf_variance_explained'),
        'prf_ecc': cortex.prop('prf_eccentricity'),
        'white': cortex.surface('white'),
        'surfarea': cortex.prop('white_surface_area'),
    }


def _roi_cm_frame(hemi_data, hemi, roi, subject_id):
    """Build the per-vertex pRF-area / geodesic-area table for one ROI.

    Vertices are restricted to the ROI's Wang label and to variance explained
    >= `r2_th`. Neighbourhoods are measured on the white-surface submesh of
    those vertices. Returns None (after printing the reason) if the ROI is
    skipped.
    """
    try:
        roi_vertices = np.where(hemi_data['labels'] == roi_code[roi])[0]
        if len(roi_vertices) == 0:
            raise ValueError(f"No vertices found for ROI {roi} in {hemi}")
        kept = roi_vertices[hemi_data['prf_r2'][roi_vertices] >= r2_th]
        if len(kept) == 0:
            print(f"No vertices above R² threshold for ROI {roi} in {hemi}")
            return None
        submesh = hemi_data['white'].submesh(kept)
        adjacency_matrix = submesh.adjacency_matrix
    except Exception as e:
        print(f"Skipping ROI {roi} in {hemi}: {e}")
        return None

    # Hemisphere vertex index of each submesh vertex, in submesh (= adjacency
    # row) order. Slicing every per-vertex array by these labels keeps them
    # aligned with the adjacency matrix even if the submesh does not retain
    # every vertex of `kept`. The slices are loop-invariant, so take them once.
    labels = np.asarray(submesh.labels)
    surfarea_sub = hemi_data['surfarea'][labels]
    prf_x_sub = hemi_data['prf_x'][labels]
    prf_y_sub = hemi_data['prf_y'][labels]

    geo_areas, prf_areas = [], []
    for vert_num in range(len(labels)):
        geo_area, vert_dist_idx = compute_geodesic_area(
            vert_num, adjacency_matrix, max_distance, surfarea_sub)
        prf_area = compute_pRF_area(vert_num, vert_dist_idx, prf_x_sub, prf_y_sub)
        geo_areas.append(geo_area)
        # NaN rather than None keeps the column float even if every hull fails.
        prf_areas.append(np.nan if prf_area is None else prf_area)

    return pd.DataFrame({
        'prf_area': np.asarray(prf_areas, dtype=float),
        'geo_area': geo_areas,
        'prf_ecc': hemi_data['prf_ecc'][labels],
        'prf_r2': hemi_data['prf_r2'][labels],
        'hemi': hemi,
        'roi': roi,
        'subject': f'{subject_id}',
    })


def _subject_cm_frame(subject_id):
    """Return every ROI table of one HCP subject concatenated, plus a `pRF_CM` column."""
    sub = ny.data['hcp_lines'].subjects[subject_id]
    frames = []
    for hemi in hemis:
        try:
            hemi_data = _load_hemi_data(sub, hemi)
        except Exception as e:
            print(f"Skipping hemisphere {hemi}: {e}")
            continue
        for roi in rois:
            roi_frame = _roi_cm_frame(hemi_data, hemi, roi, subject_id)
            if roi_frame is not None:
                frames.append(roi_frame)

    if frames:
        # One concat at the end instead of a quadratic concat-in-a-loop.
        prf_CM_df = pd.concat(frames, ignore_index=True)
    else:
        # No ROI survived: keep the expected columns so the export still works.
        prf_CM_df = pd.DataFrame(
            columns=['prf_area', 'geo_area', 'prf_ecc', 'prf_r2', 'hemi', 'roi', 'subject'])
    # Cortical (geodesic) area divided by the convex-hull area of the
    # neighbourhood's pRF centres.
    prf_CM_df['pRF_CM'] = prf_CM_df['geo_area'] / prf_CM_df['prf_area']
    return prf_CM_df


os.makedirs(tsv_dir, exist_ok=True)
group_sub = allsub[0:]  # raise the start index to resume an interrupted run
for n_sub, subject_id in enumerate(group_sub):
    start_time = time.perf_counter()
    print(f"{n_sub + 1} / {len(group_sub)}")

    prf_CM_df = _subject_cm_frame(subject_id)

    tsv_fn = os.path.join(tsv_dir, f'{subject_id}_new_CM_area.tsv')
    prf_CM_df.to_csv(tsv_fn, sep="\t", na_rep='NaN', index=False)

    print("Execution time: {:.2f} seconds".format(time.perf_counter() - start_time))
    # Reload neuropythy between subjects, presumably to drop its cached
    # subject data and bound memory use — TODO confirm.
    ny = ny.reload_neuropythy()
    clear_output(wait=True)