In [1]:
import os
import nibabel as nib
import sys
import pandas as pd 
import numpy as np
import subprocess
from glob import glob
from skimage.morphology import dilation
from skimage.morphology import ball
from scipy.ndimage import gaussian_filter
import re
import csv
from skimage.measure import label
import mne

## Functions

In [2]:
def read_nifti(img_path):
    test_img = nib.load(img_path)
    test_data = np.asarray(test_img.dataobj)
    test_aff = test_img.affine
    test_shape = test_img.shape

    return test_img, test_data, test_aff, test_shape

In [3]:
def determineFCSVCoordSystem(input_fcsv,overwrite_fcsv=False):
    # need to determine if file is in RAS or LPS
    # loop through header to find coordinate system
    coordFlag = re.compile('# CoordinateSystem')
    verFlag = re.compile('# Markups fiducial file version')
    headFlag = re.compile('# columns')
    coord_sys=None
    headFin=None
    ver_fin=None

    with open(input_fcsv, 'r') as myfile:
        firstNlines=myfile.readlines()[0:3]
    for row in firstNlines:
        row=re.sub("[\s\,]+[\,]","",row).replace("\n","")
        cleaned_dict={row.split('=')[0].strip():row.split('=')[1].strip()}
        if None in list(cleaned_dict):
            cleaned_dict['# columns'] = cleaned_dict.pop(None)
        if any(coordFlag.match(x) for x in list(cleaned_dict)):
            coord_sys = list(cleaned_dict.values())[0]
        if any(verFlag.match(x) for x in list(cleaned_dict)):
            verString = list(filter(verFlag.match,  list(cleaned_dict)))
            assert len(verString)==1
            ver_fin = verString[0].split('=')[-1].strip()
        if any(headFlag.match(x) for x in list(cleaned_dict)):
            headFin=list(cleaned_dict.values())[0].split(',')
    return coord_sys
def extract_coords(file_path):
    coord_sys=determineFCSVCoordSystem(file_path)
    df = pd.read_csv(file_path, skiprows = 3, header = None)
    coord_arr = df[[1,2,3,11]].to_numpy()
    if any(x in coord_sys for x in {'LPS','1'}):
        coord_arr = coord_arr[0:3] * np.array([-1,-1,1])
    return coord_arr

In [4]:
def transform_points_ct_space(coord_path, ct_t1_path):
    t1_coords = extract_coords(coord_path)
    ct_coords = np.zeros(t1_coords.shape)
    transform = np.loadtxt(ct_t1_path)
    transform = np.linalg.inv(transform)

    M = transform[:3,:3]
    abc = transform[:3,3]


    for i in range(len(ct_coords)):
        vec = t1_coords[i,:]
        tvec = M.dot(vec) + abc
        ct_coords[i,:] = tvec[:3]

    # ct_coords = np.round(ct_coords).astype(int) #DELETE

    return ct_coords

In [5]:
def transform_coords_vox_space(planned_coords, img_aff):
    inv_affine = np.linalg.inv(img_aff)
    M = inv_affine[:3,:3]
    abc = inv_affine[:3,3]

    transform_coords = np.zeros(planned_coords.shape)

    for i in range(len(transform_coords)):
        vec = planned_coords[i,:]
        tvec = M.dot(vec) + abc
        transform_coords[i,:] = tvec[:3]
    
    transform_coords = np.round(transform_coords).astype(int)

    return transform_coords

def check_coord_dims(pointa, pointb, img_size):
    for i in range(3):
        if pointa[i] > img_size[i] or pointb[i] > img_size[i]:
            print('review planned_fcsv - dim mismatch')
            return False
    return True

In [6]:
def create_line_mask(point1, point2, shape):
    # Create an empty mask with the specified shape
    mask = np.zeros(shape, dtype=bool)
    x1, y1, z1 = point1
    x2, y2, z2 = point2
    
    mask[x1, y1, z1] = True
    # Get the directions of each axis
    dx = abs(x2 - x1)
    dy = abs(y2 - y1)
    dz = abs(z2 - z1)
    if x2 > x1:
        xs = 1
    else:
        xs = -1
    if y2 > y1:
        ys = 1
    else:
        ys = -1
    if z2 > z1:
        zs = 1
    else:
        zs = -1

    # Driving axis is X-axis
    if dx >= dy and dx >= dz:
        p1 = 2 * dy - dx
        p2 = 2 * dz - dx
        while x1 != x2:
            x1 += xs
            if p1 >= 0:
                y1 += ys
                p1 -= 2 * dx
            if p2 >= 0:
                z1 += zs
                p2 -= 2 * dx
            p1 += 2 * dy
            p2 += 2 * dz
            mask[x1, y1, z1] = True
        mask[x1, y1, z1] = True

    # Driving axis is Y-axis"
    elif dy >= dx and dy >= dz:
        p1 = 2 * dx - dy
        p2 = 2 * dz - dy
        while y1 != y2:
            y1 += ys
            if p1 >= 0:
                x1 += xs
                p1 -= 2 * dy
            if p2 >= 0:
                z1 += zs
                p2 -= 2 * dy
            p1 += 2 * dx
            p2 += 2 * dz
            mask[x1, y1, z1] = True
        mask[x1, y1, z1] = True

    # Driving axis is Z-axis"
    else:
        p1 = 2 * dy - dz
        p2 = 2 * dx - dz
        while z1 != z2:
            z1 += zs
            if p1 >= 0:
                y1 += ys
                p1 -= 2 * dz
            if p2 >= 0:
                x1 += xs
                p2 -= 2 * dz
            p1 += 2 * dy
            p2 += 2 * dx
            mask[x1, y1, z1] = True
        mask[x1, y1, z1] = True
    return mask

## Test code

In [7]:
img_path = './data/sub-P020_res-0p4mm_ct.nii.gz'
coord_path = './data/sub-P020_actual.fcsv'
ct_t1_trans = './data/sub-P020_desc-rigid_from-ct_to-T1w_type-ras_ses-post_xfm.txt'
final_path = 'test.nii.gz'

In [8]:
# First, read nifti file
img, img_data, img_aff, img_shape = read_nifti(img_path)

In [9]:
img_aff

array([[  -0.39936084,    0.        ,    0.        ,  125.04445648],
       [   0.        ,    0.39936084,    0.        , -137.95605469],
       [   0.        ,    0.        ,    0.39911309,  -10.36294365],
       [   0.        ,    0.        ,    0.        ,    1.        ]])

In [10]:
img_shape

(626, 626, 451)

In [11]:
# Now get coordinates in CT space
ct_coord_arr = transform_points_ct_space(coord_path, ct_t1_trans)
ct_coord_arr[0:3,:]

ValueError: shapes (3,3) and (4,) not aligned: 3 (dim 1) != 4 (dim 0)

In [None]:
# Transform to voxels
transformed_coords = transform_coords_vox_space(ct_coord_arr, img_aff)
transformed_coords[0:3,:]

In [None]:
# Create mask
final_mask = np.zeros(img_shape).astype(bool)
    
for i in range(0,len(transformed_coords),2):
    pointa = transformed_coords[i, :]
    pointb = transformed_coords[i+1, :]


    test_mask = create_line_mask(pointa, pointb, img_shape)

    footprint = ball(4)
    dilated = dilation(test_mask, footprint)
    result = gaussian_filter(dilated.astype(float), sigma=0.6)
    final_mask += result>0

clipped_img = nib.Nifti1Image(final_mask, img_aff, img.header)
nib.save(clipped_img, final_path)

In [None]:
img, img_data, img_aff, img_shape = read_nifti('./data/sub-P020_res-0p4mm_desc-contacts_mask.nii.gz')
img_aff, img_shape

In [None]:
img, img_data, img_aff, img_shape = read_nifti('./data/sub-P020_res-0p4mm_desc-z_norm_pred_ct.nii.gz')
img_aff, img_shape

## Post-process

### Functions

#### Old algorithm

In [7]:
import numpy as np
from scipy.spatial import distance
from sklearn.linear_model import LinearRegression

def trace_line(entry_point, target_point):
    # Vector from target to entry
    return (entry_point - target_point)/np.linalg.norm(target_point - entry_point)

def find_next_contact_orig(current_point,
                      direction,
                      contacts,
                      max_distance_target,
                      min_distance_prior,
                      avg_distance,
                      entry_point,
                      distance_from_entry=1.5):
    next_contact = None
    next_contact_id = None
    min_distance = float('inf')
    # print(current_point, avg_distance, direction)
    theory_point = current_point + avg_distance*direction
    # Projection of entry on direction vector
    entry_proj = np.dot(entry_point, direction) / np.linalg.norm(direction)
    
    for ind, contact_array in enumerate(contacts):
        contact = contact_array[0:3]
#         vector_to_contact = contact - current_point
#         angle = calculate_angle(vector_to_contact, direction)

#         # If the angle is lower than 30 degrees (vectors heading the same direction)
#         if dot_product > 0:
#             allowed_distance = radius + avg_distance
#         else:
#             allowed_distance = radius

        distance_target = np.linalg.norm(theory_point - contact)
        distance_prior = np.linalg.norm(current_point - contact)
        distance_entry = np.linalg.norm(entry_point - contact)
        contact_proj = np.dot(contact, direction) / np.linalg.norm(direction)
        print(ind, distance_target, distance_prior, distance_entry)
        print(0, max_distance_target*avg_distance, min_distance_prior*avg_distance, distance_from_entry, '\n')
        distance_metric = np.sqrt(distance_target**2+distance_prior**2) # balance between distance to target and prior
        # Check that it's the closest point to the target, 
        # that it's below a specific distance from the target
        # and above a certain distance from the current point,
        # that it is not too close to the entry
        # and that is not beyond the entry point
        if (distance_metric < min_distance and
            distance_target< max_distance_target*avg_distance and
            distance_prior >= min_distance_prior*avg_distance and
            (distance_entry >= distance_from_entry and contact_proj < entry_proj)):
            min_distance = distance_metric
            next_contact = contact
            next_contact_id = ind
    if next_contact_id is not None:
        print(next_contact_id, '\n')
        contacts = np.delete(contacts, next_contact_id, axis=0) 
    return next_contact, contacts

def adjust_direction_using_all_contacts(contacts, original_direction):
    contacts = np.array(contacts)[-3:,:] # only use last 3 coords
    X = contacts[:, 0].reshape(-1, 1)
    y1 = contacts[:, 1]
    y2 = contacts[:, 2]

    reg1 = LinearRegression().fit(X, y1)
    reg2 = LinearRegression().fit(X, y2)

    direction_vector = np.array([1, reg1.coef_[0], reg2.coef_[0]])
    direction_vector = direction_vector / np.linalg.norm(direction_vector)

    # Ensure the direction vector points toward the initial direction
    if np.dot(direction_vector, original_direction) < 0:
        direction_vector = -direction_vector
        
    return direction_vector

def run_segmentation_qc_orig(entry_point, target_point, contacts, max_distance_target, min_distance_prior, initial_spacing=4):
    original_direction = trace_line(entry_point, target_point)
    direction = original_direction
    current_point = target_point
    avg_distance = 1 # mm
    
    # Initialize with the closest contact to the entry
    current_point, contacts = find_next_contact_orig(current_point, direction, contacts, 2, 0, avg_distance, entry_point)
    found_contacts = [current_point]
    distances_between_contacts = []
    next_contact = current_point
    # Find next contact
    next_contact, contacts = find_next_contact_orig(current_point, direction, contacts, max_distance_target, min_distance_prior, initial_spacing, entry_point)
    # Calculate
    while next_contact is not None:   
        found_contacts.append(next_contact)
        distances_between_contacts.append(np.linalg.norm(next_contact - current_point))
        if len(found_contacts) >= 3:
            # print('Cambio direccion \n')
            direction = adjust_direction_using_all_contacts(found_contacts, original_direction)
            # print(direction)
        current_point = next_contact

        avg_distance = np.mean(distances_between_contacts)
        
        # Find next contact
        if len(contacts) > 0:
            next_contact, contacts = find_next_contact_orig(current_point, direction, contacts, max_distance_target, min_distance_prior, avg_distance, entry_point)
        else:
            next_contact = None
    print(f"Adjusted Mean Distance: {avg_distance}")
    return found_contacts, avg_distance

#### New algorithm

In [31]:
import numpy as np
from scipy.spatial import distance
from sklearn.linear_model import LinearRegression

# Function to calculate the angle between two vectors in degrees
def calculate_angle(v1, v2):
    dot_product = np.dot(v1, v2)
    magnitude_v1 = np.linalg.norm(v1)
    magnitude_v2 = np.linalg.norm(v2)
    cos_theta = dot_product / (magnitude_v1 * magnitude_v2)
    angle_rad = np.arccos(cos_theta)
    angle_deg = np.degrees(angle_rad)
    return angle_deg

def trace_line(entry_point, target_point):
    # Vector from target to entry
    return (entry_point - target_point)/np.linalg.norm(target_point - entry_point)

def find_next_contact_new(current_point,
                      direction,
                      contacts,
                      max_angle_target,
                      max_distance_prior, # percentage
                      min_distance_prior, # percentage
                      avg_distance,
                      entry_point,
                      distance_from_entry=1.5,
                        ):
    next_contact = None
    next_contact_id = None
    min_distance = float('inf')
    # print(current_point, avg_distance, direction)
    theory_point = current_point + avg_distance*direction
    # Projection of entry on direction vector
    entry_proj = np.dot(entry_point, direction) / np.linalg.norm(direction)
    
    for ind, contact_array in enumerate(contacts):
        contact = contact_array[0:3]
        
        # Vector last detected contact to possible next contact
        v_contact = trace_line(contact, current_point)
        # Angle between the predicted direction and the v_contact
        angle_contact = calculate_angle(v_contact, direction)
        distance_target = np.linalg.norm(theory_point - contact)
        distance_prior = np.linalg.norm(current_point - contact)
        distance_entry = np.linalg.norm(entry_point - contact)
        contact_proj = np.dot(contact, direction) / np.linalg.norm(direction)
        # print(ind, angle_contact, distance_prior, distance_entry)
        # print(0, max_distance_prior*avg_distance, min_distance_prior*avg_distance, distance_from_entry, '\n')
        distance_metric = np.sqrt(distance_target**2+distance_prior**2) # balance between distance to target and prior
        # Check that it's the closest point to the target, 
        # that it's below a specific distance from the target
        # and above a certain distance from the current point,
        # that it is not too close to the entry
        # and that is not beyond the entry point
        if (distance_metric < min_distance and
            angle_contact <= max_angle_target and
            distance_prior >= min_distance_prior*avg_distance and
            distance_prior <= max_distance_prior*avg_distance):
            if distance_from_entry is None or (distance_entry >= distance_from_entry and contact_proj < entry_proj):
                min_distance = distance_metric
                next_contact = contact
                next_contact_id = ind
    if next_contact_id is not None:
        # print(contacts[next_contact_id,-1], '\n')
        # print(next_contact_id, '\n')
        contacts = np.delete(contacts, next_contact_id, axis=0) 
    return next_contact, contacts

def adjust_direction_using_all_contacts(contacts, original_direction):
    n_elements_reg = 3 # use last 3 elements
    contacts = np.array(contacts)[-n_elements_reg:,:] # only use last 3 coords
    # print(contacts)
    
    # Check dimension where all vals are not equal
    dim = None
    i = 0
    dims = list(range(n_elements_reg))
    while dim is None and i<n_elements_reg:
        if not np.all(contacts[:,i] == contacts[0,i]):
            dim = dims.pop(i)
        i += 1
    assert dim is not None
    
    X = contacts[:, dim].reshape(-1, 1)
    y1 = contacts[:, dims[0]]
    y2 = contacts[:, dims[1]]

    reg1 = LinearRegression().fit(X, y1)
    reg2 = LinearRegression().fit(X, y2)

    direction_vector = np.ones(n_elements_reg)
    direction_vector[dims[0]] = reg1.coef_[0]
    direction_vector[dims[1]] = reg2.coef_[0]
    # Normalize
    direction_vector = direction_vector / np.linalg.norm(direction_vector)

    # Ensure the direction vector points toward the initial direction
    if np.dot(direction_vector, original_direction) < 0:
        direction_vector = -direction_vector
        
    return direction_vector

def run_segmentation_qc_new(entry_point,
                             target_point,
                             contacts, 
                             max_angle_target, 
                             max_distance_prior, 
                             min_distance_prior, 
                             initial_spacing=4,
                             max_n_contacts = -1 # set to -1 to apply distance threshold to entry
                           ):
    original_direction = trace_line(entry_point, target_point)
    direction = original_direction
    current_point = target_point
    avg_distance = 2 # mm
    
    # Initialize with the closest contact to the entry
    current_point, contacts = find_next_contact_new(current_point,
                                                    direction, 
                                                    contacts, 
                                                    90, 
                                                    2, 
                                                    0, 
                                                    avg_distance, 
                                                    entry_point)
    found_contacts = [current_point]
    distances_between_contacts = []
    next_contact = current_point
    # Find next contact
    next_contact, contacts = find_next_contact_new(current_point, 
                                                   direction, 
                                                   contacts, 
                                                   max_angle_target, 
                                                   max_distance_prior, 
                                                   min_distance_prior, 
                                                   initial_spacing, 
                                                   entry_point)
    # Calculate
    while next_contact is not None:   
        found_contacts.append(next_contact)
        distances_between_contacts.append(np.linalg.norm(next_contact - current_point))
        if len(found_contacts) >= 3:
            # print('Cambio direccion \n')
            # print(direction)
            direction = adjust_direction_using_all_contacts(found_contacts, original_direction)
            # print(direction)
        current_point = next_contact

        avg_distance = np.mean(distances_between_contacts)
        
        # Find next contact
        if len(contacts) > 0:
            next_contact, contacts = find_next_contact_new(current_point,
                                                           direction, 
                                                           contacts, 
                                                           max_angle_target, 
                                                           max_distance_prior, 
                                                           min_distance_prior, 
                                                           avg_distance, 
                                                           entry_point,
                                                          distance_from_entry=1.5 if len(found_contacts)>max_n_contacts else None)
        else:
            next_contact = None
    print(f"Adjusted Mean Distance: {avg_distance}")
    return found_contacts, avg_distance
    

#### Others

In [47]:
def df_to_fcsv(input_df, output_fcsv):
	with open(output_fcsv, 'w') as fid:
		fid.write("# Markups fiducial file version = 4.11\n")
		fid.write("# CoordinateSystem = 0\n")
		fid.write("# columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n")
	
	out_df={'node_id':[],'x':[],'y':[],'z':[],'ow':[],'ox':[],'oy':[],'oz':[],
		'vis':[],'sel':[],'lock':[],'label':[],'description':[],'associatedNodeID':[]
	}
	
	for idx,ifid in input_df.iterrows():
		out_df['node_id'].append(idx+1)
		out_df['x'].append(ifid.iloc[0])
		out_df['y'].append(ifid.iloc[1])
		out_df['z'].append(ifid.iloc[2])
		out_df['ow'].append(0)
		out_df['ox'].append(0)
		out_df['oy'].append(0)
		out_df['oz'].append(0)
		out_df['vis'].append(1)
		out_df['sel'].append(1)
		out_df['lock'].append(1)
		out_df['label'].append(str(ifid.iloc[3]))
		out_df['description'].append('')
		out_df['associatedNodeID'].append('')

	out_df=pd.DataFrame(out_df)
	out_df.to_csv(output_fcsv, sep=',', index=False, lineterminator="", mode='a', header=False, float_format = '%.3f')

def filter_points(points_df, mask_data, mask_affine):
    filtered_points = []
    for _, row in points_df.iterrows():
        point = row[[1, 2, 3]].to_numpy()
        if inside_elec_mask(point, mask_data, mask_affine):
            filtered_points.append(row)
    return pd.DataFrame(filtered_points)

### First stage: get contacts position

In [8]:
pnms_dir = f'./data/prob_nms'
gt_dir = './data'

In [9]:
subjects = []

for root, dirs, files in os.walk(pnms_dir):
    for file in files:
        subject = file.split('_')[0]
        if subject not in subjects:
            subjects.append(subject)

print(subjects)

['sub-D128', 'sub-D112', 'sub-D129']


In [10]:
sub = subjects[1]
orig_pnms = pd.read_csv(f'{pnms_dir}/{sub}_prob_nms.fcsv', skiprows=3, header = None)
ct_t1_trans = np.loadtxt(f'{gt_dir}/{sub}/{sub}_desc-rigid_from-ct_to-T1w_type-ras_ses-post_xfm.txt')
ct = nib.load(f'./data/{sub}/{sub}_res-0p4mm_ct.nii.gz')
entry_target_path = f'./data/{sub}/{sub}_actual.fcsv'

### Second stage: apply masking filtering

In [12]:
# Get entry-target coords
entry_target = extract_coords(entry_target_path)
ind = 14-4
entry_target_interest = entry_target[ind:ind+2,:]
label = entry_target_interest[0,-1]
print(entry_target_interest)
# Transform to CT space
entry_target_interest = mne.transforms.apply_trans(np.linalg.inv(ct_t1_trans), entry_target_interest[:,:-1].astype(float))
entry_target_interest, label

[[-32.849 33.566 -12.016 'LPIn']
 [-13.9734 -17.8134 55.5648 'LPIn']]


(array([[-21.54201376,  25.53473941,  40.50927297],
        [ -8.25821108, -16.6060593 , 115.41583266]]),
 'LPIn')

In [13]:
inv_affine = np.linalg.inv(ct.affine)
data = np.asarray(ct.dataobj)
inv_affine.shape, data.shape

((4, 4), (626, 626, 401))

In [14]:
# Transform to voxels
transformed_coords = np.round(mne.transforms.apply_trans(inv_affine, entry_target_interest)).astype(int)
transformed_coords

array([[389, 372, 143],
       [356, 266, 330]])

In [15]:
# Create mask
img_shape = data.shape
final_mask = np.zeros(img_shape).astype(bool)
    
for i in range(0,len(transformed_coords),2):
    pointa = transformed_coords[i, :]
    pointb = transformed_coords[i+1, :]


    test_mask = create_line_mask(pointa, pointb, img_shape)

    footprint = ball(4)
    dilated = dilation(test_mask, footprint)
    result = gaussian_filter(dilated.astype(float), sigma=0.6)
    final_mask += result>0

In [16]:
# Mask based on CT intensity
mask_intensity = (data>2500).astype(bool)

# Merge the masks
merged_mask = final_mask*mask_intensity

# Dilate a bit
footprint = ball(2)
merged_mask = dilation(merged_mask, footprint)

np.unique(merged_mask), merged_mask.shape

(array([False,  True]), (626, 626, 401))

In [18]:
coords_interest = orig_pnms.iloc[:,1:4].to_numpy()
coords_interest.shape

(109, 3)

In [20]:
# Get contacts positions in voxel space
transformed_coords = np.round(mne.transforms.apply_trans(inv_affine, coords_interest)).astype(int)
transformed_coords[:3,:]

array([[343, 434, 140],
       [371, 270, 146],
       [399, 432, 166]])

In [22]:
# Mask the contacts inside the mask
filtered_coords = coords_interest[merged_mask[transformed_coords[:,0], transformed_coords[:,1], transformed_coords[:,2]]]
filtered_coords.shape, merged_mask[transformed_coords[:,0], transformed_coords[:,1], transformed_coords[:,2]]
# electrode_mask = nib.load(f'./data/{sub}/{sub}_res-0p4mm_desc-electrode_mask.nii.gz')
# electrode_mask = electrode_mask.get_fdata()

# filtered_coords = coords_interest[electrode_mask[transformed_coords[:,0], transformed_coords[:,1], transformed_coords[:,2]] > 0]
# filtered_coords.shape

((10, 3),
 array([False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False,  True, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False,  True,  True, False,  True, False,  True, False,
        False, False, False, False, False, False, False, False, False,
        False, False,  True,  True,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False]))

In [26]:
dict_df = {
    'x': filtered_coords[:,0],
    'y': filtered_coords[:,1],
    'z': filtered_coords[:,2],
    'label_order': [f'{label}-{idx+1:02d}' for idx in range(filtered_coords.shape[0])],

}
labeled_contacts = pd.DataFrame(dict_df)
print(labeled_contacts)

        x       y       z label_order
0 -20.708  22.811  45.021     LPIn-01
1 -21.507  24.408  42.627     LPIn-02
2 -17.513  14.025  60.582     LPIn-03
3 -16.715  11.229  65.370     LPIn-04
4 -18.312  15.622  58.188     LPIn-05
5 -19.909  21.213  47.415     LPIn-06
6 -17.513  13.226  62.976     LPIn-07
7 -19.909  20.414  50.208     LPIn-08
8 -18.312  17.220  54.996     LPIn-09
9 -19.111  18.817  52.602     LPIn-10


In [66]:
df_to_fcsv(labeled_contacts, f'./data/{sub}/test_fiducials.fcsv')

### Third stage: filter with line algorithm

In [27]:
# Example usage
target_point = entry_target_interest[0,:]  # Replace with actual entry point coordinates
entry_point = entry_target_interest[1,:]  # Replace with actual target point coordinates
contacts = filtered_coords  # Replace with your file path
min_distance_prior = 0.5 # at least half the average distance between contacts
max_angle_target = 35
max_distance_prior = 2 # not more than twice the average distance between contacts

# found_contacts, avg_distance = run_segmentation_qc(entry_point, target_point, contacts, max_distance_target, min_distance_prior)
found_contacts, avg_distance = run_segmentation_qc_new(entry_point,
                                                        target_point,
                                                        contacts, 
                                                        max_angle_target, 
                                                        max_distance_prior, 
                                                        min_distance_prior, 
                                                        initial_spacing=4)
print("Found Contacts:", found_contacts)
print("Average Distance:", avg_distance)


1 

0 

3 

4 

5 

4 

2 

0 

1 

0 

Adjusted Mean Distance: 3.016676015126724
Found Contacts: [array([-21.507,  24.408,  42.627]), array([-20.708,  22.811,  45.021]), array([-19.909,  21.213,  47.415]), array([-19.909,  20.414,  50.208]), array([-19.111,  18.817,  52.602]), array([-18.312,  17.22 ,  54.996]), array([-18.312,  15.622,  58.188]), array([-17.513,  14.025,  60.582]), array([-17.513,  13.226,  62.976]), array([-16.715,  11.229,  65.37 ])]
Average Distance: 3.016676015126724


In [28]:
len(found_contacts), len(filtered_coords)

(10, 10)

In [74]:
found_contacts = np.array(found_contacts)

In [92]:
found_contacts.shape

(10, 3)

In [75]:
dict_df = {
    'x': found_contacts[:,0],
    'y': found_contacts[:,1],
    'z': found_contacts[:,2],
    'label_order': [f'{label}-{idx+1:02d}' for idx in range(found_contacts.shape[0])],

}
labeled_contacts = pd.DataFrame(dict_df)
df_to_fcsv(labeled_contacts, f'./data/{sub}/test_fiducials.fcsv')

### Use in all subjects

In [29]:
pnms_dir = f'./data/prob_nms'
gt_dir = './data'

# Find subjects
subjects = []

for root, dirs, files in os.walk(pnms_dir):
    for file in files:
        subject = file.split('_')[0]
        if subject not in subjects:
            subjects.append(subject)

print(subjects)

['sub-D128', 'sub-D112', 'sub-D129']


In [32]:
for sub in subjects[1:2]: #.sort()
    print(sub)
    orig_pnms = pd.read_csv(f'{pnms_dir}/{sub}_prob_nms.fcsv', skiprows=3, header = None)
    ct_t1_trans = np.loadtxt(f'{gt_dir}/{sub}/{sub}_desc-rigid_from-ct_to-T1w_type-ras_ses-post_xfm.txt')
    ct = nib.load(f'./data/{sub}/{sub}_res-0p4mm_ct.nii.gz')
    entry_target_path = f'./data/{sub}/{sub}_actual.fcsv'
    
    # Get affine to transform to voxels
    inv_affine = np.linalg.inv(ct.affine)
    data = np.asarray(ct.dataobj)

    # Get entry-target coords
    entry_target = extract_coords(entry_target_path)
    # Transform to CT space
    entry_target_coords = mne.transforms.apply_trans(np.linalg.inv(ct_t1_trans), entry_target[:,:-1].astype(float))
    # Transform to voxels
    entry_target_vox = np.round(mne.transforms.apply_trans(inv_affine, entry_target_coords)).astype(int)
    
    # Coordinates outputted by the network
    coords_interest = orig_pnms.iloc[:,1:4].to_numpy()
    # Get contacts positions in voxel space
    transformed_coords = np.round(mne.transforms.apply_trans(inv_affine, coords_interest)).astype(int)
    
    # Empty final df
    df_pos = pd.DataFrame(columns=['x', 'y', 'z', 'label_order'])
    
    """
    First stage: Masking filtering
    """
    electrode_mask = nib.load(f'./data/{sub}/{sub}_res-0p4mm_desc-electrode_mask.nii.gz')
    electrode_mask = electrode_mask.get_fdata()

    filtered_coords = coords_interest[electrode_mask[transformed_coords[:,0], transformed_coords[:,1], transformed_coords[:,2]] > 0]

    """
    Second stage: Line filtering
    """
    # Run a first time to figure out the amount of contacts
    # Compute for each pair of entry-target
    n_contacts = []
    contacts_coords = []
    for i in range(0,entry_target.shape[0],2):  
        # Label for the contact
        label = entry_target[i,-1]
        print(label, flush=True)        
        found_contacts, _ = run_segmentation_qc_new(entry_point = entry_target_coords[i+1, :],
                                                               target_point = entry_target_coords[i, :],
                                                               contacts = filtered_coords, 
                                                               max_angle_target = 35, 
                                                               max_distance_prior = 2, 
                                                               min_distance_prior = 0.5, 
                                                               initial_spacing=4)
        n_contacts.append(len(found_contacts))
        contacts_coords.append(found_contacts)
    # Find the max_n_contacts
    print(n_contacts)
    elements, count = np.unique(n_contacts, return_counts=True)
    max_n_contacts = np.max(elements[count>1])
    print('Number of contacts for electrode: ', max_n_contacts)
    
    # Now repeat to find the actual points
    # Compute for each pair of entry-target
    for i in range(0,entry_target.shape[0],2):  
        # Label for the contact
        label = entry_target[i,-1]
        print(label, flush=True)        
        # Only re-execute if len(found_contacts)<max_n_contacts
        found_contacts = contacts_coords[i//2]
        if len(found_contacts) < max_n_contacts:
            found_contacts, avg_distance = run_segmentation_qc_new(entry_point = entry_target_coords[i+1, :],
                                                               target_point = entry_target_coords[i, :],
                                                               contacts = filtered_coords, 
                                                               max_angle_target = 35, 
                                                               max_distance_prior = 2, 
                                                               min_distance_prior = 0.5, 
                                                               initial_spacing=4,
                                                               max_n_contacts=max_n_contacts)
        """
        Third stage: Labelling
        """
        found_contacts = np.array(found_contacts)
        t1_coords = mne.transforms.apply_trans(ct_t1_trans, found_contacts)
        df_tmp = pd.DataFrame(t1_coords, columns=['x','y','z'])
        df_tmp['label_order'] = [f'{label}-{idx+1:02d}' for idx in range(found_contacts.shape[0])]
        df_pos = pd.concat([df_pos, df_tmp])
        print('Number of contacts found: ',  found_contacts.shape[0], '\n')
    


    # df_to_fcsv(df_pos, final_fname_t1)

sub-D112
LAm
Adjusted Mean Distance: 4.012370745962754
LAHc
Adjusted Mean Distance: 4.889503697160149
LPHc
Adjusted Mean Distance: 4.925406591290107
RAIn
Adjusted Mean Distance: 3.0548136896637557
LAIn
Adjusted Mean Distance: 2.9836174317785957
LPIn
Adjusted Mean Distance: 3.016676015126724
LOFr
Adjusted Mean Distance: 5.002209551373188
LInfTe
Adjusted Mean Distance: 4.063290593123217
LPTe
Adjusted Mean Distance: 4.96340957275441
LPTeOc
Adjusted Mean Distance: 4.022785432892366
[10, 8, 10, 10, 10, 10, 9, 10, 9, 10]
Number of contacts for electrode:  10
LAm
Number of contacts found:  10 

LAHc
Adjusted Mean Distance: 4.887619487260263
Number of contacts found:  9 

LPHc
Number of contacts found:  10 

RAIn
Number of contacts found:  10 

LAIn
Number of contacts found:  10 

LPIn
Number of contacts found:  10 

LOFr
Adjusted Mean Distance: 5.002209551373188
Number of contacts found:  9 

LInfTe
Number of contacts found:  10 

LPTe
Adjusted Mean Distance: 4.96340957275441
Number of contac

In [49]:
np.unique(merged_mask)

array([False,  True])

In [104]:
final_fname_t1 = f'./data/{sub}/test_fiducials_D129.fcsv'
df_to_fcsv(df_pos, final_fname_t1)

In [40]:
df_pos[48:59]

Unnamed: 0,x,y,z,label_order
9,-23.279147,42.097845,12.778976,LAIn-10
0,-32.655666,32.17115,-10.073695,LPIn-01
1,-31.676507,30.288086,-7.97242,LPIn-02
2,-30.697338,28.40403,-5.871277,LPIn-03
3,-29.516781,25.361386,-1.114054,LPIn-04
4,-28.537622,23.478321,0.987222,LPIn-05
5,-28.299812,21.474331,3.931696,LPIn-06
6,-27.320653,19.591266,6.032972,LPIn-07
7,-27.146087,18.484227,8.294304,LPIn-08
8,-26.164128,16.204695,10.342604,LPIn-09


### Example Test cases

In [11]:
entry_target_path = './data/sub-P020_actual.fcsv'
entry_target = extract_coords(entry_target_path)
ind =8-4
entry_target_interest = entry_target[ind:ind+2,:]
# RAIn: 18-4
# LAIn: 28-4
# RPIn: 20-4
# RPHc: 8-4
entry_target_interest 

array([[14.091, -7.65, 44.963, 'RPHc'],
       [62.119, -14.953, 45.604, 'RPHc']], dtype=object)

In [21]:
coords = './data/sub-P020_space-T1w_desc-unet_pnms.fcsv'
coords = extract_coords(coords)
ind = 123-4
coords_interest = coords[ind:ind+17,:]
# RAIn: 69
# LAIn: 0
# RPIn: 140-4
# RPHc: 123-4
coords_interest

array([[15.675, -7.837, 44.617, 'RPHc1'],
       [20.441, -8.295, 45.072, 'RPHc2'],
       [25.245, -9.685, 44.893, 'RPHc3'],
       [27.7, -7.364, 44.644, 'RPHc4'],
       [30.453, -8.579, 45.087, 'RPHc5'],
       [30.397, -10.541, 45.454, 'RPHc6'],
       [34.437, -10.744, 44.991, 'RPHc7'],
       [35.29, -8.793, 44.687, 'RPHc8'],
       [39.988, -11.605, 45.584, 'RPHc9'],
       [40.84, -9.654, 45.28, 'RPHc10'],
       [43.534, -14.082, 45.921, 'RPHc11'],
       [44.825, -11.819, 45.185, 'RPHc12'],
       [49.579, -12.669, 45.714, 'RPHc13'],
       [50.432, -10.718, 45.41, 'RPHc14'],
       [55.19, -13.678, 45.524, 'RPHc15'],
       [55.614, -11.647, 45.579, 'RPHc16'],
       [59.171, -13.732, 45.842, 'RPHc17']], dtype=object)

#### First apply masking


In [13]:
import mne

In [14]:
# ct = nib.load('./data/sub-D128_space-T1w_desc-rigid_ses-post_ct.nii.gz')
ct = nib.load('./data/sub-P020_desc-rigid_space-T1w_ct.nii.gz')
# ct = nib.load('./data/sub-P087_space-T1w_desc-rigid_ses-post_ct.nii.gz')
inv_affine = np.linalg.inv(ct.affine)
data = np.asarray(ct.dataobj)
inv_affine.shape, data.shape

((4, 4), (512, 512, 140))

In [23]:
# Transform to voxels
transformed_coords = np.round(mne.transforms.apply_trans(inv_affine, entry_target_interest[:,:-1].astype(float))).astype(int)
transformed_coords

array([[221, 293,  58],
       [119, 306,  58]])

In [24]:
# Create mask
img_shape = data.shape
final_mask = np.zeros(img_shape).astype(bool)
    
for i in range(0,len(transformed_coords),2):
    pointa = transformed_coords[i, :]
    pointb = transformed_coords[i+1, :]


    test_mask = create_line_mask(pointa, pointb, img_shape)

    footprint = ball(4)
    dilated = dilation(test_mask, footprint)
    result = gaussian_filter(dilated.astype(float), sigma=0.6)
    final_mask += result>0

In [25]:
# Mask based on CT intensity
mask_intensity = (data>2500).astype(bool)

# Merge the masks
merged_mask = final_mask*mask_intensity

# Dilate a bit
footprint = ball(2)
merged_mask = dilation(merged_mask, footprint)

np.unique(merged_mask), merged_mask.shape

(array([False,  True]), (512, 512, 140))

In [98]:
clipped_img = nib.Nifti1Image(merged_mask, ct.affine, ct.header)
nib.save(clipped_img, 'test_mask_sub-P020.nii.gz')

In [26]:
# Get contacts positions in voxel space
transformed_coords = np.round(mne.transforms.apply_trans(inv_affine, coords_interest[:,:-1].astype(float))).astype(int)
transformed_coords[:3,:]

array([[218, 293,  58],
       [208, 294,  58],
       [198, 297,  58]])

In [27]:
# Mask the contacts inside the mask
filtered_coords = coords_interest[merged_mask[transformed_coords[:,0], transformed_coords[:,1], transformed_coords[:,2]]]
filtered_coords.shape, merged_mask[transformed_coords[:,0], transformed_coords[:,1], transformed_coords[:,2]]

((13, 4),
 array([ True,  True,  True, False,  True,  True,  True, False,  True,
        False,  True,  True,  True, False,  True,  True,  True]))

#### Now line algorithm

In [28]:
entry_target_interest

array([[14.091, -7.65, 44.963, 'RPHc'],
       [62.119, -14.953, 45.604, 'RPHc']], dtype=object)

In [29]:
filtered_coords

array([[15.675, -7.837, 44.617, 'RPHc1'],
       [20.441, -8.295, 45.072, 'RPHc2'],
       [25.245, -9.685, 44.893, 'RPHc3'],
       [30.453, -8.579, 45.087, 'RPHc5'],
       [30.397, -10.541, 45.454, 'RPHc6'],
       [34.437, -10.744, 44.991, 'RPHc7'],
       [39.988, -11.605, 45.584, 'RPHc9'],
       [43.534, -14.082, 45.921, 'RPHc11'],
       [44.825, -11.819, 45.185, 'RPHc12'],
       [49.579, -12.669, 45.714, 'RPHc13'],
       [55.19, -13.678, 45.524, 'RPHc15'],
       [55.614, -11.647, 45.579, 'RPHc16'],
       [59.171, -13.732, 45.842, 'RPHc17']], dtype=object)

In [46]:
# Example usage
target_point = entry_target_interest[0,:-1]  # Replace with actual entry point coordinates
entry_point = entry_target_interest[1,:-1]  # Replace with actual target point coordinates
contacts = filtered_coords  # Replace with your file path
max_distance_target = 0.75  # Percentange from average 
min_distance_prior = 0.5 # at least half the average distance between contacts
max_angle_target = 20
max_distance_prior = 2 # not more than twice the average distance between contacts

# found_contacts, avg_distance = run_segmentation_qc(entry_point, target_point, contacts, max_distance_target, min_distance_prior)
found_contacts, avg_distance = run_segmentation_qc_new(entry_point,
                                                        target_point,
                                                        contacts, 
                                                        max_angle_target, 
                                                        max_distance_prior, 
                                                        min_distance_prior, 
                                                        initial_spacing=4)
print("Found Contacts:", found_contacts)
print("Average Distance:", avg_distance)


RPHc1 

RPHc2 

RPHc3 

RPHc6 

RPHc7 

RPHc9 

RPHc12 

RPHc13 

RPHc15 

RPHc17 

Adjusted Mean Distance: 4.911242544044423
Found Contacts: [array([15.675, -7.837, 44.617], dtype=object), array([20.441, -8.295, 45.072], dtype=object), array([25.245, -9.685, 44.893], dtype=object), array([30.397, -10.541, 45.454], dtype=object), array([34.437, -10.744, 44.991], dtype=object), array([39.988, -11.605, 45.584], dtype=object), array([44.825, -11.819, 45.185], dtype=object), array([49.579, -12.669, 45.714], dtype=object), array([55.19, -13.678, 45.524], dtype=object), array([59.171, -13.732, 45.842], dtype=object)]
Average Distance: 4.911242544044423


In [41]:
len(found_contacts), len(filtered_coords)

(10, 13)

In [42]:
found_contacts

[array([15.675, -7.837, 44.617], dtype=object),
 array([20.441, -8.295, 45.072], dtype=object),
 array([25.245, -9.685, 44.893], dtype=object),
 array([30.397, -10.541, 45.454], dtype=object),
 array([34.437, -10.744, 44.991], dtype=object),
 array([39.988, -11.605, 45.584], dtype=object),
 array([44.825, -11.819, 45.185], dtype=object),
 array([49.579, -12.669, 45.714], dtype=object),
 array([55.19, -13.678, 45.524], dtype=object),
 array([59.171, -13.732, 45.842], dtype=object)]