In [None]:
!pip install SimpleITK
!pip install skan 
!pip install tifffile

In [None]:
# Mount Google Drive under /content/drive so the zipped dataset can be copied.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import SimpleITK as sitk
from skimage.morphology import skeletonize, thin, medial_axis, skeletonize_3d
from scipy import ndimage
from skan import skeleton_to_csgraph, Skeleton
from skan import summarize
from skan import draw
import tifffile as tiff

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import cv2
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import glob
from tqdm.notebook import tnrange

In [None]:
def load_img_from_tiff(path2img):
    """
    Read an image file with SimpleITK and return its pixel data.

    Parameters
    ----------
    path2img: str
        path to the (TIFF) image file

    Returns
    -------
    np.array
        pixel data, as produced by ``sitk.GetArrayFromImage``
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(path2img))

In [None]:
!cp -r "/content/drive/MyDrive/mydata/sea_urchin_data/Galleried suture 1.zip" /content/
!unzip "/content/Galleried suture 1.zip"

In [None]:
def tiffs_to_3d_numpy(folderpath, reader=None):
    '''
    Load every ``*.tiff`` file in a folder and stack them into one array.

    Files are ordered by the integer immediately preceding the ``.tiff``
    extension (e.g. ``slice_007.tiff`` -> 7), compared numerically, so both
    zero-padded and unpadded numbering sort correctly.

    args:
        folderpath (str): path to the folder containing the tiff files; a
            trailing path separator is optional.
        reader (callable, optional): function mapping a file path to an
            image array. Defaults to ``plt.imread``.

    Returns:
        np.array: the images stacked along a new leading axis, in
            file-number order.

    Raises:
        FileNotFoundError: if the folder contains no ``*.tiff`` files.
        ValueError: if a file name has no number right before ``.tiff``.
    '''
    import os
    import re

    if reader is None:
        reader = plt.imread

    # glob.escape keeps folder names containing '[', '*' or '?' literal.
    filelist = glob.glob(os.path.join(glob.escape(folderpath), '*.tiff'))
    if not filelist:
        raise FileNotFoundError(f"no *.tiff files found in {folderpath!r}")

    def _slice_number(path):
        # Trailing integer of the file name, used as the numeric sort key.
        match = re.search(r'(\d+)\.tiff$', os.path.basename(path))
        if match is None:
            raise ValueError(f"cannot parse a slice number from {path!r}")
        return int(match.group(1))

    # Sort the real paths (no rebuilding from the first file's prefix).
    sorted_filelist = sorted(filelist, key=_slice_number)

    return np.array([reader(fname) for fname in sorted_filelist])

In [None]:
# Load the segmentation stack and browse it slice by slice: a length-1
# leading axis is added for facet_col, and the slider steps through the
# stack's first axis.
segmentation_data_np = tiffs_to_3d_numpy("/content/Galleried suture 1/")
images_to_show = segmentation_data_np[np.newaxis, ...]

fig = px.imshow(
    images_to_show,
    facet_col=0,
    animation_frame=1,
    color_continuous_scale="gray",
)
fig.show()

In [None]:
# 3D Euclidean distance transform: every non-zero voxel gets its distance to
# the nearest zero (background) voxel; used below as a per-voxel thickness.
# NOTE(review): this is a radius-like measure, not a full diameter.
dist_trans_3d = ndimage.distance_transform_edt(segmentation_data_np, return_distances=True)

# Lee 3D skeletonization of the segmentation.
lee_3d_skimage_skeleton = skeletonize(segmentation_data_np, method='lee')

# Weight skeleton voxels by the distance transform so that skan's per-branch
# 'mean-pixel-value' is the mean distance along the branch. The skeleton is
# binarized first: in some scikit-image versions skeletonize(method='lee')
# returns the input's maximum intensity (e.g. 255 after uint8 conversion)
# rather than 1, which would multiply every thickness by that value.
lee_3d_skan_skeleton = Skeleton((lee_3d_skimage_skeleton > 0) * dist_trans_3d)
df_3d = summarize(lee_3d_skan_skeleton)

In [None]:
# Give skan's summary columns short names. Image axis 0 is treated as z,
# axis 1 as y and axis 2 as x throughout this notebook; 'mean-pixel-value'
# is the mean distance-transform value along each branch ("thickness").
df_3d = df_3d.rename(columns={'image-coord-src-0': 'src-z',
                              'image-coord-src-1': 'src-y',
                              'image-coord-src-2': 'src-x',
                              'image-coord-dst-0': 'dst-z',
                              'image-coord-dst-1': 'dst-y',
                              'image-coord-dst-2': 'dst-x',
                              'mean-pixel-value': 'thickness'})

# Drop columns not used downstream: the thickness spread and ALL of skan's
# 'coord-*' columns (physical-space coordinates, which duplicate the image
# coordinates at the default spacing). The previous list dropped only four
# of the six, leaving 'coord-src-2' and 'coord-dst-0' behind.
df_3d = df_3d.drop(columns=['stdev-pixel-value',
                            'coord-src-0',
                            'coord-src-1',
                            'coord-src-2',
                            'coord-dst-0',
                            'coord-dst-1',
                            'coord-dst-2'])

In [None]:
# Per-branch table: mean thickness plus the ids of the branch's two end
# nodes, written to CSV; the last rows are shown as a sanity check.
branches_3d = df_3d[['thickness', 'node-id-src', 'node-id-dst']].rename(
    columns={'node-id-src': 'source_node_id',
             'node-id-dst': 'destination_node_id'}
)
branches_3d.to_csv("3d_branches_lee.csv", index=False)
branches_3d.tail()

In [None]:
# Node table: one row per skeleton node with its id and (x, y, z) image
# coordinates, gathered from both ends of every branch in df_3d.
nodes_1_3d = df_3d[['node-id-src', 'src-x', 'src-y', 'src-z']].rename(columns={
    'node-id-src': 'node_id',
    'src-x': 'node_coordinate_x',
    'src-y': 'node_coordinate_y',
    'src-z': 'node_coordinate_z'})
nodes_2_3d = df_3d[['node-id-dst', 'dst-x', 'dst-y', 'dst-z']].rename(columns={
    'node-id-dst': 'node_id',
    'dst-x': 'node_coordinate_x',
    'dst-y': 'node_coordinate_y',
    'dst-z': 'node_coordinate_z'})

# DataFrame.append was removed in pandas 2.0; concatenate once instead.
# A node shared by several branches appears repeatedly: keep one row per id.
final_nodes_3d = (pd.concat([nodes_1_3d, nodes_2_3d])
                  .drop_duplicates(subset=['node_id'], keep='last')
                  .reset_index(drop=True))

final_nodes_3d.to_csv("3d_nodes_lee.csv", index=False)

In [None]:
# Extract pairs of distinct nodes closer than CLOSE_NODE_RADIUS voxels; each
# pair is later treated as an extra short branch. Every unordered pair is
# recorded twice, as (a, b) and (b, a), ordered by a's row then b's row.
CLOSE_NODE_RADIUS = 2.0  # voxels; pairs with 0 < distance < radius are kept

# (n_nodes, 3) array of (x, y, z) node coordinates.
node_xyz = np.column_stack([np.array(final_nodes_3d['node_coordinate_x']),
                            np.array(final_nodes_3d['node_coordinate_y']),
                            np.array(final_nodes_3d['node_coordinate_z'])])

# Layout: [[(x1, y1, z1), (x2, y2, z2), avg_thickness], ...]
close_node_pair_coordinates = []

for i in tnrange(len(node_xyz)):
    x1, y1, z1 = node_xyz[i]

    # Distance from node i to every node at once (replaces the per-pair
    # Python inner loop); dist > 0 excludes the node itself and any
    # coincident duplicates.
    dists = np.sqrt(np.sum((node_xyz - node_xyz[i]) ** 2, axis=1))
    for j in np.flatnonzero((dists > 0.0) & (dists < CLOSE_NODE_RADIUS)):
        x2, y2, z2 = node_xyz[j]

        # Volumes are indexed (z, y, x).
        thickness_1 = dist_trans_3d[int(z1), int(y1), int(x1)]
        thickness_2 = dist_trans_3d[int(z2), int(y2), int(x2)]
        avg_thickness = np.mean([thickness_1, thickness_2])
        close_node_pair_coordinates.append([(x1, y1, z1), (x2, y2, z2), avg_thickness])

        # A zero distance-transform value means the node sits on background;
        # report it so the skeleton/segmentation alignment can be checked.
        if thickness_1 == 0 or thickness_2 == 0:
            print("check", thickness_1, thickness_2)
            print("coords ", x1, y1, z1, x2, y2, z2)

In [None]:
# Close-node pairs as a flat table: thickness, then source and destination
# coordinates. Building the frame directly from records also works when no
# close pairs were found (the old column-splitting code raised on an empty
# list because the split frame had no columns).
fine_branches = pd.DataFrame(
    [(thickness, *src, *dst) for src, dst, thickness in close_node_pair_coordinates],
    columns=['thickness', 'x1', 'y1', 'z1', 'x2', 'y2', 'z2'],
)

fine_branches
# now add intermediate nodes from branches

In [None]:
# Treat each pair of consecutive skeleton voxels along every skan path as a
# fine-grained branch and append them to the close-node branches.
# skan's path_coordinates() rows are (axis0, axis1, axis2) image coordinates,
# i.e. (z, y, x) under this notebook's axis naming. They were previously
# stored as (x, y, z), so x and z were swapped relative to the close-node
# rows in the same CSV; they are now labelled consistently.
segment_frames = [fine_branches]
for path_id in tnrange(lee_3d_skan_skeleton.n_paths):
    path_coords = lee_3d_skan_skeleton.path_coordinates(path_id)
    if len(path_coords) < 2:
        continue

    # Distance-transform value at each voxel of the path (volume is (z, y, x)).
    voxel_idx = path_coords.astype(int)
    voxel_thickness = dist_trans_3d[voxel_idx[:, 0], voxel_idx[:, 1], voxel_idx[:, 2]]

    src, dst = path_coords[:-1], path_coords[1:]
    segment_frames.append(pd.DataFrame({
        # Mean of the two end-voxel values of each segment.
        'thickness': (voxel_thickness[:-1] + voxel_thickness[1:]) / 2,
        'x1': src[:, 2], 'y1': src[:, 1], 'z1': src[:, 0],
        'x2': dst[:, 2], 'y2': dst[:, 1], 'z2': dst[:, 0],
    }))

# DataFrame.append was removed in pandas 2.0 (and was quadratic row by row);
# concatenate all segments in one go.
fine_branches_2 = pd.concat(segment_frames, ignore_index=True)

fine_branches_2.to_csv("finer_branches.csv", index=False)

In [None]:
# Rasterise node positions into a uint8 volume with the same shape as the
# segmentation (previously hard-coded to 100x100x100), so it can be summed
# with the other volumes for display.
NODE_MARKER_VALUE = 4  # intensity that makes nodes stand out in the overlay

xs_3d = np.array(final_nodes_3d['node_coordinate_x'])
ys_3d = np.array(final_nodes_3d['node_coordinate_y'])
zs_3d = np.array(final_nodes_3d['node_coordinate_z'])

node_image_3d = np.zeros(segmentation_data_np.shape, dtype=np.uint8)
# Volumes are indexed (z, y, x); intp is numpy's native index dtype.
node_image_3d[zs_3d.astype(np.intp), ys_3d.astype(np.intp), xs_3d.astype(np.intp)] = NODE_MARKER_VALUE

In [None]:
# Side-by-side slice viewer: 3D distance transform (first facet) and the
# segmentation (second facet); the slider steps through the volumes' first
# axis.
images_to_show = np.array([dist_trans_3d, segmentation_data_np])
fig = px.imshow(images_to_show,
                facet_col=0,
                animation_frame=1)
# Replace the default "facet_col=..." titles with readable labels (one per
# facet; the previous commented-out code listed three labels for two facets).
for i, label in enumerate(['3d distance transform', 'segmentation']):
    fig.layout.annotations[i]['text'] = label
fig.show()

In [None]:
# Slice viewer of skeleton + nodes + segmentation summed into one overlay,
# next to the distance transform. The operands are upcast before summing:
# a uint8 sum wraps around (e.g. 255 + 255 -> 254) and would garble the
# overlay when the masks are stored as 0/255.
overlay_3d = (lee_3d_skimage_skeleton.astype(np.int32)
              + node_image_3d.astype(np.int32)
              + segmentation_data_np.astype(np.int32))
images_to_show = np.array([overlay_3d, dist_trans_3d])
fig = px.imshow(images_to_show,
                facet_col=0,
                animation_frame=1,
                color_continuous_scale="inferno")
fig.show()

In [None]:
# Static matplotlib 3D scatter of the skeleton voxels in the first
# MAX_SLICES_TO_PLOT slices along axis 0 (kept small for rendering speed).
MAX_SLICES_TO_PLOT = 100

plt.rcParams["figure.figsize"] = (10, 10)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

m = lee_3d_skimage_skeleton[:MAX_SLICES_TO_PLOT]

# np.nonzero handles bool, 0/1 and 0/255 skeletons alike; `m == 1` silently
# selected nothing when skeleton voxels were stored as 255.
# `pos` is also reused by the plotly cell.
pos = np.nonzero(m)
ax.scatter(pos[0], pos[1], pos[2], c='green', alpha=0.3)
ax.set_xlabel('axis 0 (z)')
ax.set_ylabel('axis 1 (y)')
ax.set_zlabel('axis 2 (x)')
ax.set_title('Lee skeleton voxels')
plt.show()

In [None]:
# Interactive plotly version of the skeleton scatter: one small marker per
# voxel in `pos` (from the matplotlib cell), all given the same colour value.
skeleton_markers = dict(
    size=2,
    color=np.ones_like(pos[0]),
    colorscale='Viridis',
    opacity=0.6,
)
skeleton_trace = go.Scatter3d(x=pos[0], y=pos[1], z=pos[2],
                              mode='markers', marker=skeleton_markers)
fig = go.Figure(data=[skeleton_trace])

# Remove default margins so the 3D scene fills the output area.
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()