In [None]:
from tqdm import tqdm
from array_lib import *
from point3d_lib import Point
from ply_creation_lib import create_ply
from skimage.morphology import skeletonize
import matplotlib.pyplot as plt
import pydicom as dicom
import numpy as np
import itertools
import struct
import pickle
import time
import os

In [None]:
"""
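input_folder - folder that directly contains the .dcm files of exactly one scan.
offsets - range of colors to include in the flood fill.
seed - entry coordinate of the aorta where the flood fill starts; must be at the center of the aorta.
threshold - lower threshold on the normalized image color (NOTE: no `threshold` variable is defined in this cell).
image_center - point between the arteries of interest.
min_skeleton_length - minimum number of points a branch path must have to be kept.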
"""
# Scan folder, resolved against the current working directory. os.path.join keeps the
# path valid on every OS; a hard-coded '\\' separator only works on Windows.
input_folder = os.path.join(os.getcwd(), '20240923')
# Passed to the flood fill as its `offsets` argument.
offsets = (-15, 20)
# Flood-fill start voxel, indexed like the image volume.
seed = (20, 310, 160)
# Reference voxel used to pick the closest skeletons and the branch heads.
image_center = (150, 280, 230)
# Head-to-end paths with fewer points than this are discarded.
min_skeleton_length = 100

In [None]:
def read_dicom(input_folder: str) -> np.ndarray:
    """
    Load every ``.dcm`` file found directly inside ``input_folder`` into one array.

    Args:
        input_folder: folder that directly contains the DICOM files of a single scan.

    Returns:
        Array whose first axis indexes the files (one pixel array per file),
        stacked in sorted file-name order.
    """
    # os.listdir returns entries in arbitrary order, so sort the names to get a
    # deterministic slice order.
    # NOTE(review): this assumes the file names sort in acquisition order - confirm for the data set.
    files: list[str] = sorted(file for file in os.listdir(input_folder) if file.endswith('.dcm'))
    # os.path.join instead of a hard-coded '\\' so the loader also works outside Windows.
    data = [dicom.dcmread(os.path.join(input_folder, file)) for file in files]
    image = np.array([dicom.pixel_array(datum) for datum in data])
    return image

def normalize_image_colors(image: np.ndarray) -> np.ndarray:
    """
    Linearly rescale the intensities of ``image`` to the range [0, 255].

    Args:
        image: array of raw intensities (any shape, must not be empty).

    Returns:
        Floating-point array of the same shape where the minimum maps to 0 and
        the maximum to 255. A constant image maps to all zeros.

    Raises:
        ValueError: if ``image`` is empty (raised by ``np.min``).
    """
    image = np.asarray(image)
    # Promote integer data first: on narrow dtypes such as int16 both
    # `image - min` and `max - min` can overflow and wrap around.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)

    min_val = np.min(image)
    max_val = np.max(image)
    value_range = max_val - min_val
    # A constant image has no range; dividing would fill the result with NaN.
    if value_range == 0:
        return np.zeros_like(image)

    return (image - min_val) / value_range * 255

# Load the scan from `input_folder` and rescale its intensities to the 0-255 range.
image = normalize_image_colors(read_dicom(input_folder))

In [None]:
# Segmentation pipeline. The helpers are not defined in this notebook; presumably they
# come from the `array_lib` star import - check there for their exact semantics.
# Each step consumes the result of the previous one, so the order matters.
# Flood fill starting at `seed`, with `offsets` as the accepted color range.
filtered_mask = custom_floodfill_3d(image, seed_point=seed, new_value=-1, offsets=offsets)
# NOTE(review): going by the helper names, the next steps erode the filled region,
# strip the heart, and trim the remainder - confirm against array_lib.
eroded_mask = erode_3d(filtered_mask)
heartless_mask = remove_heart(eroded_mask)
trimmed_mask = distinguish_3d(heartless_mask)

In [None]:
# Skeletonize the trimmed mask, then keep only the skeletons selected relative to `image_center`.
skeleton_mask = skeletonize(trimmed_mask)
closest_skeletons = find_closest_skeletons(skeleton_mask, image_center)

# Mark the coordinates of the selected skeletons in a fresh boolean volume.
# The first item of every entry is used as that skeleton's coordinate array (one row per point).
skeleton_points = np.concatenate([entry[0] for entry in closest_skeletons])
filtered_skeleton_mask = np.zeros_like(skeleton_mask, dtype=bool)
filtered_skeleton_mask[tuple(skeleton_points.T)] = True

selected_skeletons = floodfill_nearby_skeletons(heartless_mask, closest_skeletons)

In [None]:
def add_skeleton_points(skeletons: list[np.ndarray], filtered_skeleton_mask: np.ndarray, nearby_pixels: list) -> list[Point]:
    """
    Convert the coordinates of one skeleton into linked ``Point`` objects.

    Args:
        skeletons: the coordinates of a single skeleton, one entry per point
            (despite the plural name, callers pass one skeleton at a time).
        filtered_skeleton_mask: boolean volume marking the selected skeleton voxels.
        nearby_pixels: neighbour offsets handed to ``Point.get_surround_points``.

    Returns:
        One ``Point`` per input coordinate, in input order. Each point has had
        ``add_nearby`` called for every neighbouring point of the same skeleton
        and ``check_state`` called once afterwards.
    """
    skeleton: list[Point] = [Point(skeleton_point) for skeleton_point in skeletons]

    # Index the points by coordinate so each neighbour lookup is O(1) instead of a
    # scan over the whole skeleton. setdefault keeps the first point for a
    # coordinate, matching a first-match linear search.
    point_by_coordinates: dict[tuple, Point] = {}
    for point in skeleton:
        point_by_coordinates.setdefault(tuple(point.coordinates), point)

    for point in skeleton:
        for p in point.get_surround_points(nearby_pixels):
            # Only voxels marked in the mask can be neighbours.
            if not filtered_skeleton_mask[p]:
                continue
            # A marked voxel may belong to another skeleton; then there is nothing to link.
            another_point = point_by_coordinates.get(tuple(p))
            if another_point is not None:
                point.add_nearby(another_point)
        point.check_state()
    return skeleton

# Coordinate array of every selected skeleton (first item of each entry).
skeletons = [entry[0] for entry in closest_skeletons]

# Neighbour offsets of a voxel: every combination of -1/0/1 per axis except the voxel itself.
nearby_pixels = [
    offset
    for offset in itertools.product([-1, 0, 1], repeat=3)
    if offset != (0, 0, 0)
]

# The first selected skeleton is treated as the left one, the second as the right one.
left_skeleton = add_skeleton_points(skeletons[0], filtered_skeleton_mask, nearby_pixels)
right_skeleton = add_skeleton_points(skeletons[1], filtered_skeleton_mask, nearby_pixels)

In [None]:
def remove_skeleton_close_ends(skeleton: list[Point], closeness: int = 20):
    """
    Prune end points that sit close to a cross.

    Every pass collects the points currently flagged as ends and calls
    ``remove_point`` on each one for which ``is_cross_close(closeness)`` holds.
    Passes repeat until one removes nothing.

    Returns the points whose ``value`` is still greater than -1
    (presumably ``remove_point`` pushes ``value`` below that - verify in point3d_lib).
    """
    while True:
        pruned_any = False
        # Snapshot the ends before touching any of them.
        current_ends = [node for node in skeleton if node.end]
        for tip in current_ends:
            if not tip.is_cross_close(closeness):
                continue
            tip.remove_point()
            pruned_any = True
        if not pruned_any:
            break
    return [node for node in skeleton if node.value > -1]

# Prune ends that lie close to a cross from both skeletons (default closeness of 20).
right_skeleton = remove_skeleton_close_ends(right_skeleton)
left_skeleton = remove_skeleton_close_ends(left_skeleton)

In [None]:
def find_closest_point(skeleton: list[Point], center_point: Point) -> Point:
    """
    Find the point of a skeleton that is closest to another point.

    Args:
        skeleton: candidate points; must not be empty.
        center_point: reference point the distances are measured from.

    Returns:
        The skeleton point with the smallest ``center_point.distance_to_point``
        value. On ties the earliest point in ``skeleton`` wins.

    Raises:
        IndexError: if ``skeleton`` is empty.
    """
    if len(skeleton) == 0:
        raise IndexError('cannot find the closest point of an empty skeleton')
    # min() needs no "large enough" sentinel distance and returns the first
    # minimal item, so ties resolve to the earliest point.
    return min(skeleton, key=lambda point: center_point.distance_to_point(another=point))

# The head of each skeleton is its point nearest to `image_center`;
# the paths to the branch ends are traced from these heads.
center_point = Point(image_center)
right_head = find_closest_point(right_skeleton, center_point)
left_head = find_closest_point(left_skeleton, center_point)

In [None]:
def find_path_to_ends(skeletons: list[tuple[list[Point], Point, str]], min_skeleton_length: int) -> list[tuple[str, list[Point]]]:
    """
    Collect the paths from each skeleton's head to every one of its end points.

    Args:
        skeletons: ``(skeleton, head, name)`` triples - the points of a skeleton,
            the point the paths start from, and a label for the skeleton.
        min_skeleton_length: paths with fewer points than this are dropped.

    Returns:
        ``(name, path)`` pairs, one per kept path, in skeleton order and then
        end-point order.
    """
    paths: list[tuple[str, list[Point]]] = []
    for skeleton, head, name in skeletons:
        ends = [p for p in skeleton if p.end]
        for end in ends:
            # path_to_end returns a sequence; its second item is the path itself.
            path = head.path_to_end(end)[1]
            if len(path) >= min_skeleton_length:
                paths.append((name, path))
    return paths

# Bundle each skeleton with its head point and a label, then collect every
# head-to-end path that has at least `min_skeleton_length` points.
# NOTE: this rebinds `skeletons`, which an earlier cell used for the raw coordinate arrays.
skeletons = [(left_skeleton, left_head, 'left'), (right_skeleton, right_head, 'right')]
branches = find_path_to_ends(skeletons, min_skeleton_length)

In [None]:
def display_branches(branches: list[tuple[str, list[Point]]]) -> None:
    """
    Export every branch as its own PLY file through ``create_ply``.

    Args:
        branches: ``(name, path)`` pairs; ``path`` is the list of points of one branch.

    Each branch is drawn into a boolean volume with the shape of the global
    ``image`` and handed to ``create_ply`` under the file name
    ``<name>_branch_<index>.ply``.
    """
    for i, (branch_name, branch) in enumerate(branches):
        # Allocate the boolean mask directly rather than creating a full-size copy
        # in the dtype of `image` and converting it afterwards.
        path_mask = np.zeros(image.shape, dtype=bool)
        for point in branch:
            path_mask[tuple(point.coordinates)] = True
        create_ply(path_mask, f'{branch_name}_branch_{i}.ply')
        
# Export one PLY per extracted branch (via `create_ply`).
display_branches(branches)

In [None]:
def display_mpr(named_branch: tuple[str, list[Point]]):
    """
    Plot three views of the global ``image`` sampled along one branch.

    For every point of the branch, the three axis-aligned lines of voxels that
    pass through it are taken from ``image``. Each line becomes one column of
    the panel for its axis, so column ``k`` of a panel belongs to branch point ``k``.
    The panels are shown side by side in gray scale under the branch name.
    """
    branch_name, branch = named_branch
    n_points = len(branch)
    # One panel per volume axis, one column per branch point.
    panels = [np.zeros((image.shape[axis], n_points)) for axis in range(3)]

    for column, point in enumerate(branch):
        position = point.coordinates
        i0, i1, i2 = position[0], position[1], position[2]
        panels[0][:, column] = image[:, i1, i2]
        panels[1][:, column] = image[i0, :, i2]
        panels[2][:, column] = image[i0, i1, :]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, panel in zip(axes, panels):
        axis.imshow(panel, cmap='gray')
    fig.suptitle(branch_name)
    plt.show()
    
# Show the three-panel figure for every extracted branch.
# Each item of `branches` is a (name, path) pair.
for branch in branches:
    display_mpr(branch)

In [None]:
# Export every intermediate stage of the pipeline, numbered in processing order.
ply_exports = [
    (image, '1.0_image.ply'),
    (filtered_mask, '1.1_filtered.ply'),
    (eroded_mask, '1.2_eroded.ply'),
    (heartless_mask, '1.3_heartless.ply'),
    (trimmed_mask, '1.4_trimmed.ply'),
    (skeleton_mask, '1.5_skeleton.ply'),
    (filtered_skeleton_mask, '1.6_closest_skeletons.ply'),
]
for volume, filename in ply_exports:
    create_ply(volume, filename)