In [1]:
import os
import nibabel as nib
import sys
import pandas as pd 
import numpy as np
import subprocess
from glob import glob
from skimage.morphology import dilation
from skimage.morphology import ball
from scipy.ndimage import gaussian_filter
import re
import csv

## Functions

In [2]:
def read_nifti(img_path):
    test_img = nib.load(img_path)
    test_data = np.asarray(test_img.dataobj)
    test_aff = test_img.affine
    test_shape = test_img.shape

    return test_img, test_data, test_aff, test_shape

In [3]:
def determineFCSVCoordSystem(input_fcsv,overwrite_fcsv=False):
    # need to determine if file is in RAS or LPS
    # loop through header to find coordinate system
    coordFlag = re.compile('# CoordinateSystem')
    verFlag = re.compile('# Markups fiducial file version')
    headFlag = re.compile('# columns')
    coord_sys=None
    headFin=None
    ver_fin=None

    with open(input_fcsv, 'r') as myfile:
        firstNlines=myfile.readlines()[0:3]
    for row in firstNlines:
        row=re.sub("[\s\,]+[\,]","",row).replace("\n","")
        cleaned_dict={row.split('=')[0].strip():row.split('=')[1].strip()}
        if None in list(cleaned_dict):
            cleaned_dict['# columns'] = cleaned_dict.pop(None)
        if any(coordFlag.match(x) for x in list(cleaned_dict)):
            coord_sys = list(cleaned_dict.values())[0]
        if any(verFlag.match(x) for x in list(cleaned_dict)):
            verString = list(filter(verFlag.match,  list(cleaned_dict)))
            assert len(verString)==1
            ver_fin = verString[0].split('=')[-1].strip()
        if any(headFlag.match(x) for x in list(cleaned_dict)):
            headFin=list(cleaned_dict.values())[0].split(',')
    return coord_sys
def extract_coords(file_path):
    coord_sys=determineFCSVCoordSystem(file_path)
    df = pd.read_csv(file_path, skiprows = 3, header = None)
    coord_arr = df[[1,2,3,11]].to_numpy()
    if any(x in coord_sys for x in {'LPS','1'}):
        coord_arr = coord_arr[0:3] * np.array([-1,-1,1])
    return coord_arr

In [4]:
def transform_points_ct_space(coord_path, ct_t1_path):
    t1_coords = extract_coords(coord_path)
    ct_coords = np.zeros(t1_coords.shape)
    transform = np.loadtxt(ct_t1_path)
    transform = np.linalg.inv(transform)

    M = transform[:3,:3]
    abc = transform[:3,3]


    for i in range(len(ct_coords)):
        vec = t1_coords[i,:]
        tvec = M.dot(vec) + abc
        ct_coords[i,:] = tvec[:3]

    # ct_coords = np.round(ct_coords).astype(int) #DELETE

    return ct_coords

In [5]:
def transform_coords_vox_space(planned_coords, img_aff):
    inv_affine = np.linalg.inv(img_aff)
    M = inv_affine[:3,:3]
    abc = inv_affine[:3,3]

    transform_coords = np.zeros(planned_coords.shape)

    for i in range(len(transform_coords)):
        vec = planned_coords[i,:]
        tvec = M.dot(vec) + abc
        transform_coords[i,:] = tvec[:3]
    
    transform_coords = np.round(transform_coords).astype(int)

    return transform_coords

def check_coord_dims(pointa, pointb, img_size):
    for i in range(3):
        if pointa[i] > img_size[i] or pointb[i] > img_size[i]:
            print('review planned_fcsv - dim mismatch')
            return False
    return True

In [13]:
def create_line_mask(point1, point2, shape):
    # Create an empty mask with the specified shape
    mask = np.zeros(shape, dtype=bool)
    x1, y1, z1 = point1
    x2, y2, z2 = point2
    
    mask[x1, y1, z1] = True
    # Get the directions of each axis
    dx = abs(x2 - x1)
    dy = abs(y2 - y1)
    dz = abs(z2 - z1)
    if x2 > x1:
        xs = 1
    else:
        xs = -1
    if y2 > y1:
        ys = 1
    else:
        ys = -1
    if z2 > z1:
        zs = 1
    else:
        zs = -1

    # Driving axis is X-axis
    if dx >= dy and dx >= dz:
        p1 = 2 * dy - dx
        p2 = 2 * dz - dx
        while x1 != x2:
            x1 += xs
            if p1 >= 0:
                y1 += ys
                p1 -= 2 * dx
            if p2 >= 0:
                z1 += zs
                p2 -= 2 * dx
            p1 += 2 * dy
            p2 += 2 * dz
            mask[x1, y1, z1] = True
        mask[x1, y1, z1] = True

    # Driving axis is Y-axis"
    elif dy >= dx and dy >= dz:
        p1 = 2 * dx - dy
        p2 = 2 * dz - dy
        while y1 != y2:
            y1 += ys
            if p1 >= 0:
                x1 += xs
                p1 -= 2 * dy
            if p2 >= 0:
                z1 += zs
                p2 -= 2 * dy
            p1 += 2 * dx
            p2 += 2 * dz
            mask[x1, y1, z1] = True
        mask[x1, y1, z1] = True

    # Driving axis is Z-axis"
    else:
        p1 = 2 * dy - dz
        p2 = 2 * dx - dz
        while z1 != z2:
            z1 += zs
            if p1 >= 0:
                y1 += ys
                p1 -= 2 * dz
            if p2 >= 0:
                x1 += xs
                p2 -= 2 * dz
            p1 += 2 * dy
            p2 += 2 * dx
            mask[x1, y1, z1] = True
        mask[x1, y1, z1] = True
    return mask

## Test code

In [57]:
img_path = './data/sub-P020_res-0p4mm_ct.nii.gz'
coord_path = './data/sub-P020_actual.fcsv'
ct_t1_trans = './data/sub-P020_desc-rigid_from-ct_to-T1w_type-ras_ses-post_xfm.txt'
final_path = 'test.nii.gz'

In [5]:
# First, read nifti file
img, img_data, img_aff, img_shape = read_nifti(img_path)

In [65]:
img_aff

array([[  -0.39936084,    0.        ,    0.        ,  125.04445648],
       [   0.        ,    0.39936084,    0.        , -137.95605469],
       [   0.        ,    0.        ,    0.39911309,  -10.36294365],
       [   0.        ,    0.        ,    0.        ,    1.        ]])

In [66]:
img_shape

(626, 626, 451)

In [63]:
# Now get coordinates in CT space
ct_coord_arr = transform_points_ct_space(coord_path, ct_t1_trans)
ct_coord_arr[0:3,:]

array([[ 13.54638589,  -9.58834212,  34.12763581],
       [ 58.14983149,  -7.03882306,  26.09974994],
       [ 18.62660867, -18.41509008,  34.01957583]])

In [64]:
# Transform to voxels
transformed_coords = transform_coords_vox_space(ct_coord_arr, img_aff)
transformed_coords[0:3,:]

array([[279, 321, 111],
       [168, 328,  91],
       [266, 299, 111]])

In [61]:
# Create mask
final_mask = np.zeros(img_shape).astype(bool)
    
for i in range(0,len(transformed_coords),2):
    pointa = transformed_coords[i, :]
    pointb = transformed_coords[i+1, :]


    test_mask = create_line_mask(pointa, pointb, img_shape)

    footprint = ball(4)
    dilated = dilation(test_mask, footprint)
    result = gaussian_filter(dilated.astype(float), sigma=0.6)
    final_mask += result>0

clipped_img = nib.Nifti1Image(final_mask, img_aff, img.header)
nib.save(clipped_img, final_path)

In [71]:
img, img_data, img_aff, img_shape = read_nifti('./data/sub-P020_res-0p4mm_desc-contacts_mask.nii.gz')
img_aff, img_shape

(array([[  -0.39936084,    0.        ,    0.        ,  125.04445648],
        [   0.        ,    0.39936084,    0.        , -137.95605469],
        [   0.        ,    0.        ,    0.39911309,  -10.36294365],
        [   0.        ,    0.        ,    0.        ,    1.        ]]),
 (626, 626, 451))

In [74]:
img, img_data, img_aff, img_shape = read_nifti('./data/sub-P020_res-0p4mm_desc-z_norm_pred_ct.nii.gz')
img_aff, img_shape

(array([[-3.99116344e-01, -1.40737596e-02, -3.17369709e-03,
          1.23237984e+02],
        [ 1.36259463e-02, -3.96592552e-01,  4.50026941e-02,
          1.18835777e+02],
        [-4.74382046e-03,  4.49255352e-02,  3.96279473e-01,
         -2.73388290e+01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          1.00000000e+00]]),
 (601, 601, 351))

## Post-process

### Functions

In [76]:
import numpy as np
from scipy.spatial import distance
from sklearn.linear_model import LinearRegression


def trace_line(entry_point, target_point):
    return (entry_point - target_point)/np.linalg.norm(target_point - entry_point)

def find_next_contact(current_point,
                      direction,
                      contacts,
                      max_distance_target,
                      min_distance_prior,
                      avg_distance,
                      entry_point,
                      distance_from_entry=2):
    next_contact = None
    next_contact_id = None
    min_distance = float('inf')
    # print(current_point, avg_distance, direction)
    theory_point = current_point + avg_distance*direction
    # Projection of entry on direction vector
    entry_proj = np.dot(entry_point, direction) / np.linalg.norm(direction)
    
    for ind, contact_array in enumerate(contacts):
        contact = contact_array[0:3]
#         vector_to_contact = contact - current_point
#         angle = calculate_angle(vector_to_contact, direction)

#         # If the angle is lower than 30 degrees (vectors heading the same direction)
#         if dot_product > 0:
#             allowed_distance = radius + avg_distance
#         else:
#             allowed_distance = radius

        distance_target = np.linalg.norm(theory_point - contact)
        distance_prior = np.linalg.norm(current_point - contact)
        distance_entry = np.linalg.norm(entry_point - contact)
        contact_proj = np.dot(contact, direction) / np.linalg.norm(direction)
        # print(contacts[ind,-1], distance_target, distance_prior, distance_entry)
        distance_metric = np.sqrt(distance_target**2+distance_prior**2) # balance between distance to target and prior
        # Check that it's the closest point to the target, 
        # that it's below a specific distance from the target
        # and above a certain distance from the current point,
        # that it is not too close to the entry
        # and that is not beyond the entry point
        if (distance_metric < min_distance and
            distance_target< max_distance_target*avg_distance and
            distance_prior >= min_distance_prior*avg_distance and
            (distance_entry >= distance_from_entry and contact_proj < entry_proj)):
            min_distance = distance_metric
            next_contact = contact
            next_contact_id = ind
    if next_contact_id is not None:
        print(contacts[next_contact_id,-1], '\n')
        contacts = np.delete(contacts, next_contact_id, axis=0) 
    return next_contact, contacts

def adjust_direction_using_all_contacts(contacts, original_direction):
    contacts = np.array(contacts)[-3:,:] # only use last 3 coords
    X = contacts[:, 0].reshape(-1, 1)
    y1 = contacts[:, 1]
    y2 = contacts[:, 2]

    reg1 = LinearRegression().fit(X, y1)
    reg2 = LinearRegression().fit(X, y2)

    direction_vector = np.array([1, reg1.coef_[0], reg2.coef_[0]])
    direction_vector = direction_vector / np.linalg.norm(direction_vector)

    # Ensure the direction vector points toward the initial direction
    if np.dot(direction_vector, original_direction) < 0:
        direction_vector = -direction_vector
        
    return direction_vector

def run_segmentation_qc(entry_point, target_point, contacts, max_distance_target, min_distance_prior):
    original_direction = trace_line(entry_point, target_point)
    direction = original_direction
    current_point = target_point
    avg_distance = 1 # mm
    
    # Initialize with the closest contact to the entry
    current_point, contacts = find_next_contact(current_point, direction, contacts, 2, 0, avg_distance, entry_point)
    found_contacts = [current_point]
    distances_between_contacts = []
    next_contact = current_point
    # Find next contact
    avg_distance = 3
    next_contact, contacts = find_next_contact(current_point, direction, contacts, max_distance_target, min_distance_prior, avg_distance, entry_point)
    # Calculate
    while next_contact is not None:   
        found_contacts.append(next_contact)
        distances_between_contacts.append(np.linalg.norm(next_contact - current_point))
        if len(found_contacts) >= 3:
            print('Cambio direccion \n')
            direction = adjust_direction_using_all_contacts(found_contacts, original_direction)
            print(direction)
        current_point = next_contact

        avg_distance = np.mean(distances_between_contacts)
        print(f"Adjusted Mean Distance: {avg_distance}")
        
        # Find next contact
        if len(contacts) > 0:
            next_contact, contacts = find_next_contact(current_point, direction, contacts, max_distance_target, min_distance_prior, avg_distance, entry_point)
        else:
            next_contact = None
    return found_contacts, avg_distance

### Usage

In [129]:
entry_target_path = './data/sub-P087_actual.fcsv'
entry_target = extract_coords(entry_target_path)
ind =4-4
entry_target_interest = entry_target[ind:ind+2,:]
# RAIn: 18-4
# LAIn: 28-4
# RPIn: 20-4
# RPHc: 8-4
entry_target_interest 

array([[-9.9107, 42.8761, 24.1209, 'LAm'],
       [-52.4179, 45.4588, 23.6232, 'LAm']], dtype=object)

In [128]:
coords = './data/sub-P087_space-T1w_desc-unet_pnms.fcsv'
coords = extract_coords(coords)
ind = 25-4
coords_interest = coords[ind:ind+11,:]
# RAIn: 69
# LAIn: 0
# RPIn: 140-4
# RPHc: 123-4
coords_interest

array([[-11.744, 43.34, 23.99, 'LAm1'],
       [-15.808, 43.595, 23.798, 'LAm2'],
       [-18.969, 43.687, 24.266, 'LAm3'],
       [-23.845, 43.572, 24.123, 'LAm4'],
       [-27.096, 44.198, 23.881, 'LAm5'],
       [-31.069, 43.919, 24.399, 'LAm6'],
       [-35.528, 44.186, 24.265, 'LAm7'],
       [-38.801, 44.418, 23.956, 'LAm8'],
       [-42.864, 44.674, 23.764, 'LAm9'],
       [-46.927, 44.929, 23.571, 'LAm10'],
       [-54.153, 45.277, 23.847, 'LAm11']], dtype=object)

### First apply masking


In [9]:
import mne

In [118]:
# ct = nib.load('./data/sub-D128_space-T1w_desc-rigid_ses-post_ct.nii.gz')
# ct = nib.load('./data/sub-P020_desc-rigid_space-T1w_ct.nii.gz')
ct = nib.load('./data/sub-P087_space-T1w_desc-rigid_ses-post_ct.nii.gz')
inv_affine = np.linalg.inv(ct.affine)
data = np.asarray(ct.dataobj)
inv_affine.shape, data.shape

((4, 4), (512, 512, 256))

In [130]:
# Transform to voxels
transformed_coords = np.round(mne.transforms.apply_trans(inv_affine, entry_target_interest[:,:-1].astype(float))).astype(int)
transformed_coords

array([[279, 257,  67],
       [361, 257,  56]])

In [131]:
# Create mask
img_shape = data.shape
final_mask = np.zeros(img_shape).astype(bool)
    
for i in range(0,len(transformed_coords),2):
    pointa = transformed_coords[i, :]
    pointb = transformed_coords[i+1, :]


    test_mask = create_line_mask(pointa, pointb, img_shape)

    footprint = ball(4)
    dilated = dilation(test_mask, footprint)
    result = gaussian_filter(dilated.astype(float), sigma=0.6)
    final_mask += result>0

In [132]:
# Mask based on CT intensity
mask_intensity = (data>2500).astype(bool)

# Merge the masks
merged_mask = final_mask*mask_intensity

# Dilate a bit
footprint = ball(2)
merged_mask = dilation(merged_mask, footprint)

np.unique(merged_mask), merged_mask.shape

(array([False,  True]), (512, 512, 256))

In [98]:
clipped_img = nib.Nifti1Image(merged_mask, ct.affine, ct.header)
nib.save(clipped_img, 'test_mask_sub-P020.nii.gz')

In [133]:
# Get contacts positions in voxel space
transformed_coords = np.round(mne.transforms.apply_trans(inv_affine, coords_interest[:,:-1].astype(float))).astype(int)
transformed_coords[:3,:]

array([[282, 257,  66],
       [290, 257,  65],
       [296, 257,  65]])

In [134]:
# Mask the contacts inside the mask
filtered_coords = coords_interest[merged_mask[transformed_coords[:,0], transformed_coords[:,1], transformed_coords[:,2]]]
filtered_coords.shape, merged_mask[transformed_coords[:,0], transformed_coords[:,1], transformed_coords[:,2]]

((11, 4),
 array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True]))

### Now line algorithm

In [135]:
entry_target_interest

array([[-9.9107, 42.8761, 24.1209, 'LAm'],
       [-52.4179, 45.4588, 23.6232, 'LAm']], dtype=object)

In [136]:
# Example usage
target_point = entry_target_interest[0,:-1]  # Replace with actual entry point coordinates
entry_point = entry_target_interest[1,:-1]  # Replace with actual target point coordinates
contacts = filtered_coords  # Replace with your file path
max_distance_target = 0.75  # Percentange from average 
min_distance_prior = 0.5

found_contacts, avg_distance = run_segmentation_qc(entry_point, target_point, contacts, max_distance_target, min_distance_prior)

print("Found Contacts:", found_contacts)
print("Average Distance:", avg_distance)


LAm1 

LAm2 

Adjusted Mean Distance: 4.076516282317538
LAm3 

Cambio direccion 

[-0.99823278  0.04862784  0.03415636]
Adjusted Mean Distance: 3.636648659774292
LAm4 

Cambio direccion 

[-0.9993945  -0.00462079  0.03448584]
Adjusted Mean Distance: 4.050916377673691
LAm5 

Cambio direccion 

[-0.9973837   0.05590617 -0.04582848]
Adjusted Mean Distance: 3.868075867644764
LAm6 

Cambio direccion 

[-0.99818531  0.04363586  0.041497  ]
Adjusted Mean Distance: 3.8977263786810115
LAm7 

Cambio direccion 

[-9.99033104e-01 -1.78704158e-04  4.39639178e-02]
Adjusted Mean Distance: 3.9929379962564155
LAm8 

Cambio direccion 

[-0.99640328  0.06403329 -0.05549996]
Adjusted Mean Distance: 3.893336805540281
LAm9 

Cambio direccion 

[-0.99555114  0.06608677 -0.06715996]
Adjusted Mean Distance: 3.916117460387997
LAm10 

Cambio direccion 

[-0.99691473  0.06269055 -0.04723261]
Adjusted Mean Distance: 3.933834029942431
Found Contacts: [array([-11.744, 43.34, 23.99], dtype=object), array([-15.808, 43

In [137]:
len(found_contacts), len(filtered_coords)

(10, 11)

In [69]:
found_contacts

[array([-15.172, 37.191, -6.681], dtype=object),
 array([-19.891, 38.252, -7.148], dtype=object),
 array([-24.733, 39.094, -6.858], dtype=object),
 array([-29.466, 39.388, -7.549], dtype=object),
 array([-35.097, 40.279, -7.375], dtype=object),
 array([-39.027, 41.29, -7.727], dtype=object),
 array([-44.279, 41.391, -7.719], dtype=object),
 array([-48.997, 42.451, -8.186], dtype=object),
 array([-54.628, 43.342, -8.012], dtype=object),
 array([-59.346, 44.402, -8.479], dtype=object),
 array([-66.152, 45.75, -8.365], dtype=object)]