In [None]:
!pip install plotly==5.3.1
!pip install SimpleITK
!pip install skan

In [3]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import SimpleITK as sitk
import plotly.express as px
import plotly.graph_objects as go
import math
from skimage.morphology import skeletonize, thin, medial_axis
from skan import skeleton_to_csgraph
from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import plotly.figure_factory as ff
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

In [4]:
def load_img_from_tiff(path2img):
    """Read an image file with SimpleITK and return its pixels as a NumPy array.

    Parameters
    ----------
    path2img : str
        Path of the image to read (a multi-page TIFF in this notebook).

    Returns
    -------
    numpy.ndarray
        The array produced by ``sitk.GetArrayFromImage``. SimpleITK returns the
        axes in reverse index order, i.e. (z, y, x) for a 3-D volume.
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(path2img))

In [5]:
# Load the volume and stretch every slice along axis 0 independently to the 0-255 range.
path = "/content/drive/MyDrive/mydata/sea_urchin_data/3D/Galleria Piastra Modello 1 100x100 Echi1-10x.tif"
data = load_img_from_tiff(path)
for i in range(len(data)):
    slice_max = data[i].max()
    # Skip empty (all-zero) slices: 255.0 / 0 would turn them into NaN (float volumes)
    # or an invalid NaN cast (integer volumes).
    if slice_max > 0:
        # Assigning back into `data` keeps its original dtype, so for integer
        # volumes the scaled values are truncated toward zero.
        data[i] = data[i] * (255.0 / slice_max)

In [6]:
# Per-slice pipeline: Gaussian blur -> fixed binary threshold -> 2-D medial-axis skeleton.
BLUR_KSIZE = (3, 3)      # Gaussian kernel size (width, height)
BLUR_SIGMA = 3           # Gaussian sigma (sigmaX) in pixels
THRESHOLD_VALUE = 185    # intensity cut-off applied to the 0-255 normalised slices

blurred_data = np.zeros_like(data)
for i in range(len(blurred_data)):
    blurred_data[i] = cv2.GaussianBlur(data[i], BLUR_KSIZE, BLUR_SIGMA)

thresholded_data = np.zeros_like(data)
for i in range(len(thresholded_data)):
    # cv2.threshold returns (threshold_used, image); only the image is needed.
    # Binding the first value to `_` avoids shadowing the `bin` builtin.
    _, thresholded_data[i] = cv2.threshold(blurred_data[i], THRESHOLD_VALUE, 255, cv2.THRESH_BINARY)

medial_axis_skel = np.zeros_like(data)
for i in range(len(thresholded_data)):
    # medial_axis yields a boolean mask; storing it in an array created like `data`
    # records skeleton pixels as 1 and everything else as 0.
    medial_axis_skel[i] = medial_axis(thresholded_data[i])

In [None]:
# Build the voxel-adjacency graph of the stacked skeleton with skan and view its degree image.
# (`Z` was undefined here; `skeleton_to_csgraph` is the skan function imported for this.)
csgraph_result = skeleton_to_csgraph(medial_axis_skel)
pixel_graph, coordinates = csgraph_result[0], csgraph_result[1]
if len(csgraph_result) >= 3:
    # Older skan releases also return the per-voxel degree image.
    degrees = csgraph_result[2]
else:
    # Newer skan releases return only (graph, coordinates): rebuild the degree image as the
    # number of skeleton neighbours (full connectivity) of every skeleton voxel.
    from scipy import ndimage as ndi
    skel_mask = medial_axis_skel.astype(bool)
    neighbour_kernel = np.ones((3,) * skel_mask.ndim, dtype=int)
    neighbour_kernel[(1,) * skel_mask.ndim] = 0
    degrees = ndi.convolve(skel_mask.astype(int), neighbour_kernel, mode="constant") * skel_mask

im_3 = np.array([degrees])
fig = px.imshow(im_3,
                animation_frame=1,
                facet_col=0,
                color_continuous_scale="gray",
                title="Skeletonized")
fig.show()

In [14]:
def make_mesh(image, step_size=1):
    """Extract a triangle surface mesh from a 3-D volume with marching cubes.

    Parameters
    ----------
    image : numpy.ndarray
        3-D volume. Its axes are reversed (``transpose(2, 1, 0)``) before meshing,
        so vertex coordinates come out in the reversed axis order.
    step_size : int, optional
        Voxel stride of the marching-cubes grid; larger values give a coarser,
        faster mesh.

    Returns
    -------
    verts : numpy.ndarray
        Vertex coordinates, one row per vertex.
    faces : numpy.ndarray
        Triangles as rows of indices into ``verts``.
    """
    print("Transposing surface")
    volume = image.transpose(2, 1, 0)

    print("Calculating surface")
    # Recent scikit-image releases no longer provide ``marching_cubes_lewiner``; its
    # replacement ``marching_cubes`` uses the Lewiner method by default. Use the old
    # name only where it still exists, because on very old releases ``marching_cubes``
    # is the classic algorithm and does not accept ``step_size``.
    marching_cubes = getattr(measure, "marching_cubes_lewiner", None) or measure.marching_cubes
    verts, faces, _normals, _values = marching_cubes(volume, step_size=step_size, allow_degenerate=True)
    return verts, faces


def plotly_3d(verts, faces, z_range=(-1, 100)):
    """Render a triangle mesh as an interactive Plotly trisurf figure.

    Parameters
    ----------
    verts : sequence of (x, y, z) triples
        Vertex coordinates, e.g. the ``verts`` returned by ``make_mesh``.
    faces : array-like of index triples
        Triangles as indices into ``verts`` (passed as ``simplices``).
    z_range : pair of numbers, optional
        Visible z-axis range of the scene, ``(-1, 100)`` by default. Widen it
        for meshes that extend beyond 100 along z.
    """
    x, y, z = zip(*verts)

    print("Drawing")

    # create_trisurf interpolates face colours between these two colours; with no
    # color_func, Plotly colours faces by their z coordinate, so colour encodes height.
    colormap = ['rgb(128,255,128)', 'rgb(255,128,128)']

    fig = ff.create_trisurf(x=x, y=y, z=z, plot_edges=False,
                            colormap=colormap,
                            simplices=faces,
                            title="3D mesh")
    fig.update_layout(scene=dict(zaxis=dict(nticks=4, range=list(z_range))))
    iplot(fig)

def plt_3d(verts, faces):
    """Draw a triangle mesh with Matplotlib's 3-D toolkit and show it.

    ``verts`` must be a NumPy array with one (x, y, z) row per vertex, because it
    is fancy-indexed by ``faces`` (rows of vertex indices). Each axis runs from 0
    to the largest vertex coordinate on that axis, on a grey background.
    """
    print("Drawing")
    x_max, y_max, z_max = (max(column) for column in zip(*verts))

    figure = plt.figure(figsize=(10, 10))
    axes = figure.add_subplot(111, projection='3d')

    # verts[faces] gathers the corner coordinates of every triangle in one array.
    triangles = Poly3DCollection(verts[faces], linewidths=0.01, alpha=1)
    triangles.set_facecolor([1, 1, 1])
    axes.add_collection3d(triangles)

    for set_limit, upper in ((axes.set_xlim, x_max),
                             (axes.set_ylim, y_max),
                             (axes.set_zlim, z_max)):
        set_limit(0, upper)
    axes.set_facecolor((0.7, 0.7, 0.7))
    plt.show()

In [None]:
# Mesh only the first slices (along axis 0) of the skeleton stack, presumably to keep
# marching cubes and the Plotly render quick.
MESH_SLICE_COUNT = 10
skeleton_verts, skeleton_faces = make_mesh(medial_axis_skel[:MESH_SLICE_COUNT])
plotly_3d(skeleton_verts, skeleton_faces)

In [None]:
# Create the folder for the per-slice frames at the same absolute path the frames are saved to.
# exist_ok=True makes the cell safe to re-run (plain `mkdir` fails if the folder exists).
os.makedirs("/content/animation3", exist_ok=True)

In [None]:
# Side-by-side view of every pipeline stage for each slice; each figure is also saved as a PNG frame.
# Large output warning: one figure is displayed per slice.
plt.rcParams['figure.figsize'] = 14, 5
frames_dir = "/content/animation3"
os.makedirs(frames_dir, exist_ok=True)  # don't depend on a separate mkdir cell having run

for i in range(len(data)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)
    ax1.imshow(data[i], cmap='gray')
    ax1.set_xlabel("Raw image  z-axis slice: " + str(i + 1))
    ax2.imshow(blurred_data[i], cmap='gray')
    ax2.set_xlabel("Gaussian blurred")
    ax3.imshow(thresholded_data[i], cmap='gray')
    ax3.set_xlabel("Thresholded")
    ax4.imshow(medial_axis_skel[i] * 255, cmap='gray')
    ax4.set_xlabel("Medial axis")
    fig.savefig(os.path.join(frames_dir, str(i) + ".png"))
    # Display and release each figure now; pyplot otherwise keeps every figure
    # open (and in memory) until it is explicitly closed.
    plt.show()
    plt.close(fig)

In [None]:
# Bundle the saved PNG frames into animation3.zip; both paths are relative to the current directory.
# NOTE(review): frames are written under /content/animation3, so this assumes the working
# directory is /content (the Colab default) — confirm when running elsewhere.
!zip -r animation3.zip animation3/
