In [1]:
import os
import nibabel as nib
import numpy as np
import pyvista as pv
from datetime import datetime

from utils import *

##### \/ \/ \/ Inputs
input_dir = "D:\QTFE_local\Python\ElectrodesLocalization\sub11\in"
ct_path              = os.path.join(input_dir, "CT.nii.gz")
ct_mask_path         = os.path.join(input_dir, "CTMask.nii.gz")
electrodes_info_path = os.path.join(input_dir, "entry_points.txt")
##### /\ /\ /\

##### \/ \/ \/ HYPERPARAMETERS
# Masking electrodes in CT scan (Houndsfield units)
ELECTRODE_THRESHOLD = 2500

# Nb of attempts to find c_1
NB_ATTEMPTS_C1 = 50

# Max number of contacts computed in an electrode (without reaching tail) 
# before raising an exception (due to lack of convergence)
MAX_ITER = 50
##### /\ /\ /\


##### Fetching inputs
raw_ct, electrodes = get_inputs(ct_path, ct_mask_path, electrodes_info_path)

##### Preprocessing: applying threshold on CT
raw_ct.mask &= raw_ct.ct > ELECTRODE_THRESHOLD

In [2]:
#from scipy.ndimage import binary_erosion, binary_dilation
from skimage.morphology import binary_erosion, reconstruction
from scipy.ndimage import label, center_of_mass

"""
# TODO debug remove
plots = []
"""

def binary_ultimate_erosion(image: np.ndarray, ks=3):
    # Ref: Proakis, G. John

    
    """ 
    # TODO debug remove
    global plots
    """

    result = np.zeros_like(image)
    _ndims = len(image.shape)
    struct = np.ones((ks,)*_ndims, dtype=np.int32)

    while image.sum() > 0:
        """
        # TODO debug remove
        plot_line = []
        """

        eroded        = binary_erosion(image, struct)
        reconstructed = reconstruction(eroded, image, 'dilation', struct)
        
        # The xor is equivalent to image - opened if they only consist of 0's and 1's
        # knowing that opened < image
        result |= np.logical_xor(image, reconstructed)

        """
        # TODO debug remove
        plot_line.append(image)
        plot_line.append(eroded)
        plot_line.append(reconstructed)
        plot_line.append(np.copy(result))
        plots.append(plot_line)
        """

        image = eroded
    return result, struct

In [None]:
# Toy: plot iterations
"""
import matplotlib.pyplot as plt

def plot_lines():
    nrows = len(plots)
    ncols = len(plots[0])
    fig, axs = plt.subplots(nrows, ncols)
    for i, line in enumerate(plots[:3]):
        for j, plot in enumerate(line):
            axs[i][j].imshow(plot)
    plt.show()
"""


In [None]:
# TODO: add multiprocessing because slow
# TODO thesis: plot steps

ue, struct = binary_ultimate_erosion(raw_ct.mask, 3)
labels, n_contacts = label(ue, struct)

contacts_com = []
for i in range(1, n_contacts+1):
    contacts_com.append(center_of_mass(raw_ct.ct, labels, i))

contacts_com = np.stack(contacts_com)

In [2]:
def find_closest(
        target: np.ndarray, 
        coords: np.ndarray, 
        indices: set, 
        update_indices: bool = True
) -> np.ndarray:
    """Returns the coordinates in 'coords' the closest to 'target'.
    
    Inputs:
    - target: the target coordinates. Must be of shape (3,).
    - coords: all the candidate coordinates among which we want the closest to 
    'target'. Must be of shape (N, 3).
    - indices: the set of indices of the candidates in 'coords' to consider 
    (in range {0, ..., N-1}). In 'coords', only the rows present in 'indices'
    are searched. If 'update_indices' is True, at the end of the execution, 
    the best index is removed from the set.
    - update_indices: if True, the set 'indices' is modified during the execution
    to remove the best index found. If False, 'indices' is left untouched.
    
    Output:
    - best (np.ndarray): the closest point in 'target' found, of shape (3,)."""

    best_idx, best_dist = None, 1e20

    for i in indices:
        dist = np.linalg.norm(target-coords[i]) 
        if dist < best_dist:
            best_idx  = i
            best_dist = dist
    if update_indices:
        indices.remove(best_idx)
    return coords[best_idx]

def are_same_contact(contact_a, contact_b):
    return np.linalg.norm(contact_a - contact_b) < 1e-6

In [None]:
def segment_electrode(electrode: Electrode, contacts_com: np.ndarray, indices: set) -> None:
    """Returns the center of mass of all contacts of the electrode that spans
    between coordinates H (head) and T (tail).
    
    Inputs:
    - raw_ct (numpy.array): the CT image in which the contacts are segmented. Must be of shape (X, Y, Z).
    - H (numpy.array): The coordinates of the first contact of the electrode. Must be of shape (3,).
    - T (numpy.array): The coordinates of the last contact of the electrode. Must be of shape (3,).
    
    Returns:
    contacts (numpy.array): an array of shape (N, 3) of the coordinates of all N contacts
    found in the electrode. The coordinates contacts[0] (resp. contacts[-1])
    refer to the center of mass of the contact closest to H (resp. T).
    
    Ref: Arnulfo et. al."""

    # Finding the first contact
    c_0 = find_closest(electrode.head, contacts_com, indices, update_indices=False)
    electrode.add_contact(c_0)

    # Estimating the position of the second contact by starting from first contact
    # and moving along the vector T-H
    for j in range(1, NB_ATTEMPTS_C1+1):
        c_1_approx = c_0 + (j/NB_ATTEMPTS_C1) * (electrode.tail - electrode.head)

        c_1 = find_closest(c_1_approx, contacts_com, indices, update_indices=False)
        if not are_same_contact(c_0, c_1):
            electrode.add_contact(c_1)
            break
    
    # Now that c_0 and c_1 are known, estimate the rest of the contacts iteratively
    # by following vector c_i - c_(i-1) and finding closest contact
    c_end = find_closest(electrode.tail, contacts_com, indices, update_indices=False)
    iter = 0
    while not are_same_contact(electrode.contacts[-1], c_end) and iter < MAX_ITER:
        c_pp, c_p = electrode.contacts[-2], electrode.contacts[-1]
        c_i_approx = c_p + (c_p - c_pp)
        c_i = find_closest(c_i_approx, contacts_com, indices)
        electrode.add_contact(c_i)

        iter += 1
    
        if iter == MAX_ITER:
            pass # TODO handle case
            #raise RuntimeError("Max number of iterations/contacts reached but tail contact has not been found")

In [4]:
def segment_all_electrodes(raw_ct: RawCT, electrodes: List[Electrode], contacts_com: np.ndarray) -> None:
    # Segmenting one electrode at a time, for all electrodes
    n_contacts = contacts_com.shape[0]
    indices = set(range(n_contacts))
    
    for e in electrodes:
        segment_electrode(e, contacts_com, indices)

In [7]:
# Toy: resetting the electrodes
contacts_com = np.loadtxt("centers_of_mass.txt", dtype=np.float32)

electrodes = electrodes_copy
electrodes_copy = []
for e in electrodes:
    electrodes_copy.append(Electrode(e.name, e.head, e.tail))


In [10]:
# Toy: testing the algo
segment_all_electrodes(raw_ct, electrodes, contacts_com)
for e in electrodes:
    print(len(e.contacts), end='  ')



16  6  52  19  4  4  7  11  