In [None]:
# Colab setup: install third-party packages used below. Only plotly is pinned.
# NOTE(review): skan is annotated "older version" but no version is pinned; later cells rely on
# skan.draw.overlay_euclidean_skeleton_2d and the hyphenated 'branch-distance' summary column —
# pin the skan release this notebook was written against (TODO: confirm which one).
!pip install plotly==5.3.1
!pip install SimpleITK
!pip install skan #older version

In [None]:
# tifffile is used below to write the thresholded stack to a TIFF file.
!pip install tifffile

In [None]:
#copy and extract valentinas seg data
!cp -r /content/drive/MyDrive/mydata/sea_urchin_data/3D/val_segmentation.zip ./val_segmentation.zip
!unzip ./val_segmentation.zip

In [4]:
!mkdir euc_skels #for storing euclidean skeleton figs

In [5]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import SimpleITK as sitk
import plotly.express as px
import plotly.graph_objects as go
import math
from skimage.morphology import skeletonize, thin, medial_axis
from skan import skeleton_to_csgraph, Skeleton
from skan import summarize
from skan import draw
from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import plotly.figure_factory as ff
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import tifffile as tiff
import glob
from tqdm import tqdm
from IPython.display import clear_output

In [6]:
def load_img_from_tiff(path2img):
    """Read an image file with SimpleITK and return its pixels as a numpy array.

    Parameters
    ----------
    path2img : str
        Path of the image file (used in this notebook for multi-page TIFF stacks).

    Returns
    -------
    numpy.ndarray
        The array produced by ``sitk.GetArrayFromImage`` for the loaded image.
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(path2img))

In [7]:
# Load the raw TIFF stack; data[i] is treated as one 2D slice by the cells below.
path = "/content/drive/MyDrive/mydata/sea_urchin_data/3D/Galleria Piastra Modello 1 100x100 Echi1-10x.tif"
data = load_img_from_tiff(path)
# Rescale every slice independently so its own maximum maps to 255.
# The product is cast back to data's dtype on assignment (truncated for integer dtypes).
for i in range(len(data)):
    slice_max = data[i].max()
    # A blank (all-zero) slice would divide by zero and produce NaN/inf values;
    # leave such slices untouched instead.
    if slice_max > 0:
        data[i] = data[i] * (255.0 / slice_max)

In [8]:
# Per-slice 2D pipeline: Gaussian blur -> fixed binary threshold -> skeletonize.
BLUR_KSIZE = (3, 3)     # Gaussian kernel size in pixels
BLUR_SIGMA = 3          # Gaussian sigma
THRESHOLD_VALUE = 175   # intensity cut-off on the 0-255 rescaled data

blurred_data = np.zeros_like(data)
thresholded_data = np.zeros_like(data)
# NOTE: despite its name this array holds skeletonize() output, not medial_axis();
# the name is kept because later cells use it. medial_axis(..., return_distance=True)
# was tried here before and remains a possible alternative.
medial_axis_skel = np.zeros_like(data)
for i in range(len(data)):
    blurred_data[i] = cv2.GaussianBlur(data[i], BLUR_KSIZE, BLUR_SIGMA)
    # cv2.threshold returns (threshold_used, image); only the image is needed.
    _thresh, thresholded_data[i] = cv2.threshold(blurred_data[i], THRESHOLD_VALUE, 255, cv2.THRESH_BINARY)
    # //255 turns the {0, 255} mask into {0, 1} for skeletonize.
    medial_axis_skel[i] = skeletonize(thresholded_data[i] // 255)

# Pixel graph of the skeleton, computed on the whole stack (all slices at once).
# TODO: use it to recover node degrees — https://jni.github.io/skan/getting_started.html
pixel_graph, coordinates = skeleton_to_csgraph(medial_axis_skel)
#rescue degrees!
#https://jni.github.io/skan/getting_started.html

In [None]:
# Round-trip check: write the thresholded stack to a TIFF file, read it back and show slice 0.
# tifffile.imwrite is the current writer; imsave is a deprecated alias of it.
tiff.imwrite("threshold.tiff", thresholded_data)
test_data = load_img_from_tiff("threshold.tiff")
plt.imshow(test_data[0]);

In [None]:
# Compare the reference segmentation (Seg3D) side by side with our threshold and the raw data.
SEG_GLOB = 'Segmentation/*.tiff'
seg_files = glob.glob(SEG_GLOB)
if not seg_files:
    raise FileNotFoundError("No files match {!r}; run the copy/unzip cell first.".format(SEG_GLOB))

# Each filename ends with a 3-character slice number right before '.tiff';
# order the slices by that number (sorting the real paths, not rebuilt ones).
sorted_filelist = sorted(seg_files, key=lambda fname: int(fname[-8:-5]))

val_seg = np.array([plt.imread(fname) for fname in sorted_filelist])
# NOTE(review): `val_seg*255` assumes val_seg is a {0, 1} mask. If plt.imread returns uint8
# values already in {0, 255}, the multiplication wraps around — verify the value range.
seg_images = np.array([val_seg*255, thresholded_data, data])
fig = px.imshow(seg_images,
                facet_col=0,
                animation_frame=1,
                color_continuous_scale="gray")

for i, label in enumerate(['Seg3D', 'OpenCV (Python)', "Raw"]):
    fig.layout.annotations[i]['text'] = label

fig.layout.template = 'plotly_dark'
fig.show()

In [38]:
# Outer contours of each thresholded slice: draw them into contour_arr and keep their
# point arrays (one list per slice) in perimeter_coordinates for the curvature heatmap.
contour_arr = np.zeros_like(data)
perimeter_coordinates = []
contour_area_thresh = 0  # contours with area <= this are dropped
for i in range(len(contour_arr)):
    contours, _hierarchy = cv2.findContours(thresholded_data[i].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    # Filter first so the drawn image contains exactly the contours kept for the heatmap
    # (previously every contour was drawn as soon as any one passed the area filter).
    kept = [cnt for cnt in contours if cv2.contourArea(cnt) > contour_area_thresh]
    if kept:
        contour_arr[i] = cv2.drawContours(contour_arr[i].copy(), kept, -1, (255), 1)
    perimeter_coordinates.append([np.array(cnt) for cnt in kept])

In [13]:
def define_circle(p1, p2, p3):
    """
    Returns the center, radius and curvature of the circle passing through the given 3 points.

    Parameters
    ----------
    p1, p2, p3 : sequence of two numbers
        Points as ``(x, y)``. Coordinates are converted to float first, so integer
        inputs (e.g. numpy int32 contour points) cannot overflow in the squared terms.

    Returns
    -------
    tuple
        ``((cx, cy), radius, curvature)`` with ``curvature = 1 / radius``.
        If the 3 points are (numerically) collinear — which includes any two of
        them coinciding — returns ``(None, np.inf, 0)``.
    """
    x1, y1 = float(p1[0]), float(p1[1])
    x2, y2 = float(p2[0]), float(p2[1])
    x3, y3 = float(p3[0]), float(p3[1])

    temp = x2 * x2 + y2 * y2
    bc = (x1 * x1 + y1 * y1 - temp) / 2
    cd = (temp - x3 * x3 - y3 * y3) / 2
    det = (x1 - x2) * (y2 - y3) - (x2 - x3) * (y1 - y2)

    # A (near-)zero determinant means the points are collinear: no finite circle.
    if abs(det) < 1.0e-6:
        return (None, np.inf, 0)

    # Center of circle
    cx = (bc * (y2 - y3) - cd * (y1 - y2)) / det
    cy = ((x1 - x2) * cd - (x2 - x3) * bc) / det

    radius = np.sqrt((cx - x1) ** 2 + (cy - y1) ** 2)
    curvature = 1 / radius
    return ((cx, cy), radius, curvature)

def frame_2_contour_heatmap(frame, perimeter_coordinates, curvature_scale_factor, kernel=3):
    """Paint per-point contour curvature into an image shaped like ``frame``.

    For every point of every contour, a circle is fitted (via ``define_circle``)
    through the point and the points ``kernel`` positions before and after it on
    the same contour. Both neighbour indices wrap around, treating the contour
    as closed. The curvature (1 / radius, or 0 for collinear points) times
    ``curvature_scale_factor`` is written at that point's pixel.

    Parameters
    ----------
    frame : numpy.ndarray
        2D image; only its shape and dtype are used (output is ``np.zeros_like(frame)``).
    perimeter_coordinates : list of numpy.ndarray
        Contours as produced by ``cv2.findContours``; ``contour[k][0]`` is the
        ``(x, y)`` of point ``k``.
    curvature_scale_factor : float
        Multiplier applied to each curvature before it is stored.
    kernel : int, optional
        Offset along the contour of the two neighbours used for the circle fit
        (default 3, the value previously hard-coded).

    Returns
    -------
    numpy.ndarray
        Heatmap with ``frame``'s shape and dtype; pixels not on a contour are 0.
        Stored values are cast to ``frame``'s dtype (truncated for integer dtypes).
        If a pixel occurs several times, the last value written wins.
    """
    curvature_heatmap = np.zeros_like(frame)
    for contour in perimeter_coordinates:
        n_points = len(contour)
        for i in range(n_points):
            # Wrap both neighbours modulo the contour length so contours shorter
            # than ``kernel`` points cannot index out of range.
            prev_pt = tuple(contour[(i - kernel) % n_points][0])
            this_pt = tuple(contour[i][0])
            next_pt = tuple(contour[(i + kernel) % n_points][0])
            _, _, curvature = define_circle(prev_pt, this_pt, next_pt)
            x, y = this_pt
            # Image arrays are indexed [row, col] == [y, x].
            curvature_heatmap[y, x] = curvature * curvature_scale_factor
    return curvature_heatmap

In [14]:
# Curvature heatmap for every slice, with curvatures scaled by 255.
curvature_heatmap_arr = np.zeros_like(data)
for slice_idx, contour_frame in enumerate(contour_arr):
    heatmap = frame_2_contour_heatmap(
        frame=contour_frame,
        perimeter_coordinates=perimeter_coordinates[slice_idx],
        curvature_scale_factor=255,
    )
    curvature_heatmap_arr[slice_idx] = heatmap

# Optional interactive view (disabled):
# fig = px.imshow(np.array([curvature_heatmap_arr]),
#                 facet_col=0,
#                 animation_frame=1,
#                 color_continuous_scale="turbo")
# fig.layout.annotations[0]['text'] = "Contour Curvatures"
# fig.layout.template = 'plotly_dark'
# fig.show()

# Note: contours whose area is not above contour_area_thresh were already dropped when
# perimeter_coordinates was built, so they receive no curvature values.

In [37]:
# Build a skan Skeleton from the full skeleton stack, tabulate it with summarize()
# and save the resulting table to summary.csv.
skan_skel = Skeleton(medial_axis_skel)
skl=summarize(skan_skel)
skl.to_csv("summary.csv")

In [32]:
# Start from an empty ./animation directory for the per-slice animation frames.
!rm -rf animation
!mkdir animation

In [None]:
# Render one 4-panel PNG per slice (raw | thresholded | skeleton | Euclidean skeleton)
# into ANIMATION_DIR, which the previous cell recreates and the zip cell archives.
ANIMATION_DIR = "animation"  # relative path, matching the mkdir and zip cells
plt.rcParams['figure.figsize'] = 15, 5

for i in tqdm(range(len(data))):
    fig, axarr = plt.subplots(1, 4)

    # Panel 3: skan Euclidean-skeleton overlay of this slice, coloured by branch distance.
    temp_skan_skel = Skeleton(medial_axis_skel[i])
    temp_summary = summarize(temp_skan_skel)
    draw.overlay_euclidean_skeleton_2d(medial_axis_skel[i],
                                       temp_summary,
                                       skeleton_color_source='branch-distance',
                                       axes=axarr[3])
    axarr[3].set_title("Euclidean Skeleton", fontsize=12)

    axarr[2].imshow(medial_axis_skel[i], cmap='gray')
    axarr[2].set_title("Medial Axis Skeleton", fontsize=12)
    axarr[2].axis("off")
    # Alternative panel 2 (disabled): the curvature heatmap of the same slice.
    # axarr[2].imshow(curvature_heatmap_arr[i], cmap='jet')
    # axarr[2].set_title("Curvature Heatmap", fontsize=12)
    # axarr[2].axis("off")

    axarr[1].imshow(thresholded_data[i], cmap='gray')
    axarr[1].set_title("Thresholded Image", fontsize=12)
    axarr[1].axis("off")
    axarr[0].imshow(data[i], cmap='gray')
    axarr[0].set_title("Raw Image - slice "+str(i), fontsize=12)
    axarr[0].axis("off")
    clear_output(wait=True)

    fig.savefig(os.path.join(ANIMATION_DIR, str(i) + ".png"))
    # Close each saved figure; otherwise one open figure accumulates per slice
    # and memory grows with the number of slices.
    plt.close(fig)
# plt.savefig("euc_skels/"+str(i)+".png")

In [None]:
# Bundle everything in animation/ into animation.zip (e.g. for download from Colab).
!zip -r animation.zip animation/

In [None]:
def make_mesh(image, step_size=1):
    """Extract a triangle surface mesh from a 3D volume with Lewiner marching cubes.

    The volume's axes are reversed (``transpose(2, 1, 0)``) before meshing, and the
    iso-level is left at scikit-image's default.

    Parameters
    ----------
    image : numpy.ndarray
        3D volume to mesh.
    step_size : int, optional
        Passed through to marching cubes as ``step_size``.

    Returns
    -------
    verts, faces
        Vertex coordinates (in the transposed axis order) and triangle vertex indices.
    """
    print("Transposing surface")
    p = image.transpose(2, 1, 0)
    print("Calculating surface")
    # measure.marching_cubes_lewiner no longer exists in newer scikit-image releases;
    # use it only where present, otherwise the equivalent marching_cubes(method="lewiner").
    if hasattr(measure, "marching_cubes_lewiner"):
        verts, faces, _normals, _values = measure.marching_cubes_lewiner(
            p, step_size=step_size, allow_degenerate=True)
    else:
        verts, faces, _normals, _values = measure.marching_cubes(
            p, step_size=step_size, allow_degenerate=True, method="lewiner")
    return verts, faces


def plotly_3d(verts, faces, z_range=(-1, 100)):
    """Display a triangle mesh as an interactive plotly trisurf figure.

    Parameters
    ----------
    verts : sequence of (x, y, z)
        Mesh vertex coordinates (e.g. as returned by ``make_mesh``).
    faces : array-like
        Triangles given as index triples into ``verts``.
    z_range : pair of numbers or None, optional
        Displayed z-axis range. Defaults to ``(-1, 100)``, the value previously
        hard-coded; pass ``None`` to let plotly choose the range.
    """
    x, y, z = zip(*verts)

    print("Drawing")

    # Two-colour map handed to create_trisurf, which uses it to colour the faces
    # (presumably interpolated along z — TODO confirm against plotly docs).
    colormap = ['rgb(128,255,128)', 'rgb(255,128,128)']

    fig = ff.create_trisurf(x=x, y=y, z=z, plot_edges=False,
                            colormap=colormap,
                            simplices=faces,
                            title="3D mesh")
    zaxis = dict(nticks=4)
    if z_range is not None:
        zaxis['range'] = list(z_range)
    fig.update_layout(scene=dict(zaxis=zaxis))
    iplot(fig)

def plt_3d(verts, faces):
    """Render a triangle mesh as a static matplotlib 3D figure.

    Faces are drawn white on a light-grey background. Each axis spans from 0 to
    the largest vertex coordinate along it.

    Parameters
    ----------
    verts : numpy.ndarray
        Vertex coordinates, one (x, y, z) row per vertex.
    faces : numpy.ndarray
        Triangles as index triples into ``verts``.
    """
    print("Drawing")
    xs, ys, zs = zip(*verts)
    figure = plt.figure(figsize=(10, 10))
    axes3d = figure.add_subplot(111, projection='3d')

    # verts[faces] gathers the three vertices of every face -> one triangle per face.
    triangles = Poly3DCollection(verts[faces], linewidths=0.01, alpha=1)
    triangles.set_facecolor([1, 1, 1])
    axes3d.add_collection3d(triangles)

    limit_setters = (axes3d.set_xlim, axes3d.set_ylim, axes3d.set_zlim)
    for set_limit, coords in zip(limit_setters, (xs, ys, zs)):
        set_limit(0, max(coords))
    axes3d.set_facecolor((0.7, 0.7, 0.7))
    plt.show()

In [None]:
# Mesh the first 10 slices of the skeleton stack and show the surface interactively.
v, f = make_mesh(medial_axis_skel[:10])
plotly_3d(v, f)