In [90]:
%matplotlib inline
import matplotlib.pyplot as plt

def plot(*args, **kwargs):
    """Draw a line plot on a new 5x5-inch figure.

    Every positional and keyword argument is forwarded unchanged to
    `Axes.plot` on the figure's single subplot. Returns None.
    """
    figure = plt.figure(figsize=(5, 5))
    axes = figure.add_subplot(111)
    axes.plot(*args, **kwargs)
    
def imshow(*args, **kwargs):
    """Display an image on a new 10x10-inch figure.

    Arguments are passed straight through to `Axes.imshow` on the figure's
    single subplot. Returns None.
    """
    plt.figure(figsize=(10, 10)).add_subplot(111).imshow(*args, **kwargs)

In [91]:
import os
import h5py
import math
import scipy
import numpy as np
from stl import mesh
import mahotas as mh
import scipy.spatial as kd
from skimage import measure

# --- configuration -----------------------------------------------------------
NEURON_ID = 3036              # id value in the IdMap tiles whose mask is meshed
X_SHAPE = 1024                # columns of one assembled z slice, in pixels
Y_SHAPE = 1024                # rows of one assembled z slice, in pixels
Z_SHAPE = (0, 75)             # [start, stop) range of z slice indices to read
SPLINE_RESOLUTION = 1 / 16.   # resampled spline points per pixel of contour length

# Paths can be overridden from the environment instead of editing the notebook.
# os.path.join(p, '') guarantees a trailing separator, because later code
# concatenates file names onto these strings.
OUT_FOLDER = os.path.join(
    os.environ.get('NEURON_MESH_OUT_FOLDER',
                   '/home/john/data/2017/winter/3dxp/3dxp_data/hohoho/'),
    '')
DATA = os.path.join(
    os.environ.get('NEURON_MESH_DATA',
                   '/home/d/data/ac3x75/mojo/ids/tiles/w=00000000/'),
    '')

def threshold(arr, val):
    """Return a boolean mask that is True exactly where `arr` equals `val`.

    arr: numpy array of labels (any shape).
    val: label value to select.
    Returns a new bool array with the same shape as `arr`.

    The builtin `bool` dtype is used implicitly, which avoids the `np.bool`
    alias. That alias was removed in NumPy 1.24 and re-added only in 2.0.
    """
    return np.asarray(arr) == val


# Build a boolean (z, y, x) mask of NEURON_ID from the tiled IdMap slices.
z_start, z_stop = Z_SHAPE
TILE = 512  # edge length, in pixels, of each stored IdMap tile
# Size the stack by the number of slices actually read, so a non-zero start
# does not leave empty leading slices. Builtin `bool` avoids the `np.bool`
# alias, which is missing from NumPy 1.24-1.26.
thresholded_3d = np.zeros((z_stop - z_start, Y_SHAPE, X_SHAPE), dtype=bool)

for SLICE in range(z_start, z_stop):
    slice_dir = os.path.join(DATA, 'z=' + str(SLICE).zfill(8))
    img = np.zeros((Y_SHAPE, X_SHAPE), dtype=np.uint64)

    for t in sorted(os.listdir(slice_dir)):
        if t.startswith('.'):  # skip hidden files
            continue
        # Tile row/column indices are parsed from a name of the form
        # '<key>=<row>,<key>=<col>.<ext>'.
        y = int(t.split(',')[0].split('=')[1])
        x = int(t.split(',')[1].split('=')[1].split('.')[0])
        with h5py.File(os.path.join(slice_dir, t), 'r') as f:
            # Indexing raises KeyError naming a missing dataset, whereas
            # f.get() would silently return None.
            img[y * TILE:(y + 1) * TILE, x * TILE:(x + 1) * TILE] = f['IdMap'][()]

    thresholded_3d[SLICE - z_start] = threshold(img, NEURON_ID)

# Repeat each z slice 10 times along axis 0.
# NOTE(review): presumably compensates for coarser z resolution - confirm factor.
upsampled = thresholded_3d.repeat(10, axis=0)
# Swap the first two axes: volume[k] is then a (z, x) plane at row y = k.
volume = upsampled.swapaxes(0, 1)

In [108]:
class KidTree:
    """Nearest-neighbour lookup over a fixed set of points, backed by scipy's KDTree."""

    def __init__(self, pairs):
        # Leaf size grows with the square root of the number of points.
        point_count = len(pairs)
        self.tree = kd.KDTree(pairs, leafsize=2 * int(math.sqrt(point_count)))

    def nearest(self, now):
        """Return the stored point closest to `now`."""
        _, index = self.tree.query(now, k=1)
        return self.tree.data[index]

class Edger:
    """Trace the largest contour of one 2-D mask slice and rasterise it.

    On construction the level-0 contours of `spots` are found and ordered by
    bounding-box area, largest first. `runAll` resamples that contour with a
    parametric spline and draws it into `edge_image`. When it is given the
    previous slice's points, it also fills the polygon joining the two.
    """

    def __init__(self, spots):
        # Output canvas for this slice; same shape as the input mask.
        self.edge_image = np.zeros(spots.shape, dtype=int)
        # Largest valid (row, col) index, used to clip resampled points.
        self.max_shape = np.array(self.edge_image.shape) - 1
        # NOTE(review): level 0 equals the background value of a boolean mask;
        # 0.5 would place contours between object and background - confirm intent.
        self.edges = measure.find_contours(spots, 0)
        # Largest first. A key gives a consistent total order (the cmp-style
        # `sortAll` never returns 0) and works on both Python 2 and 3.
        self.edges.sort(key=self._bbox_area, reverse=True)

    def _bbox_area(self, contour):
        """Area of the axis-aligned bounding box of a contour's points."""
        points = np.asarray(contour)
        extents = points.max(axis=0) - points.min(axis=0)
        return float(np.prod(extents))

    def _resample(self, edgen):
        """Fit a parametric spline through `edgen` and resample it.

        Returns a list of integer (row, col) points, clipped to the image.
        The spline is sampled at int(SPLINE_RESOLUTION * arc_length) evenly
        spaced parameter values. Contours with 3 or fewer points cannot be fit
        by splprep's default cubic spline (it needs m > k), so they yield [].
        """
        if len(edgen) <= 3:
            return []
        # Explicit import: `import scipy` alone does not load the submodule
        # on older SciPy releases.
        import scipy.interpolate
        y, x = zip(*edgen)
        # Total arc length of the contour polyline.
        dist = np.sqrt((np.diff(x)) ** 2 + (np.diff(y)) ** 2).cumsum()[-1]
        spline, u = scipy.interpolate.splprep([x, y])
        res = int(SPLINE_RESOLUTION * dist)
        sampler = np.linspace(0, u[-1], res)
        interp_x, interp_y = scipy.interpolate.splev(sampler, spline)
        rows = [int(math.floor(v)) for v in interp_y]
        cols = [int(math.floor(v)) for v in interp_x]
        return [np.clip(point, [0, 0], self.max_shape) for point in zip(rows, cols)]

    def run(self, edgen, old_interp):
        """Draw one contour into `edge_image` and bridge it to the previous slice.

        edgen: sequence of (row, col) contour points, as from find_contours.
        old_interp: list of points returned for the previous slice ([] if none).
        Returns the resampled (row, col) points for this contour, to be passed
        as `old_interp` for the next slice.
        """
        interp = self._resample(edgen)

        # Connect consecutive resampled points with straight segments.
        for j in range(1, len(interp)):
            mh.polygon.line(interp[j - 1], interp[j], self.edge_image)

        if len(old_interp):
            # Fill the polygon formed by walking the previous contour backwards
            # and then this contour forwards.
            polygo = old_interp[::-1] + interp
            mh.polygon.fill_polygon(polygo, self.edge_image)

        # Return spline for next slice
        return interp

    def sortAll(self, a, b):
        """cmp-style comparator over two contours.

        Returns 1 (a sorts after b) when every bounding-box extent of `a` is
        strictly smaller than the matching extent of `b`; otherwise -1. It never
        returns 0, so it is not a consistent total order. It is kept for callers
        that use it; __init__ orders contours with `_bbox_area` instead.
        """
        xylists = [zip(*a), zip(*b)]
        da, db = [np.array([max(v) - min(v) for v in l]) for l in xylists]
        return 2 * int((da - db < 0).all()) - 1

    def runAll(self, old_interp):
        """Process this slice's largest contour and return its resampled points.

        Returns [] and leaves `edge_image` blank when the slice has no contours.
        """
        if not self.edges:
            return []
        return self.run(self.edges[0], old_interp)

    def runAll(self,old_interp):
        new_interp = self.run(self.edges[0], old_interp)
        return new_interp

class Mesher:
    """Run `Edger` over every slice along axis 0 of a volume.

    Construction does all the work. For each slice k, the slice's
    `edge_image` is written to `edge_vol[k]`. The points returned by one
    slice's `Edger.runAll` are passed to the next slice.
    """

    def __init__(self, volume):
        self.volume = volume
        self.slice_run = range(self.volume.shape[0])
        # float64 array with the same shape as `volume`; can be large.
        self.edge_vol = np.zeros(volume.shape)
        # Resampled contour points from the previous slice ([] before the first).
        # Kept per instance so separate Mesher objects never share state.
        self.old_interp = []
        self.runAll()

    def run(self, k):
        """Process slice k and store its edge image in `edge_vol[k]`."""
        edgy = Edger(self.volume[k])
        self.old_interp = edgy.runAll(self.old_interp)
        self.edge_vol[k] = edgy.edge_image
        # A single-argument print renders identically on Python 2 and 3.
        print('k  %s' % k)

    def runAll(self):
        """Process every slice in index order."""
        for sli in self.slice_run:
            self.run(sli)

In [None]:
# Trace and rasterise contours slice by slice along axis 0 of `volume`.
mesher = Mesher(volume)
meshed = mesher.edge_vol

k  0
k  1
k  2
k  3
k  4
k  5
k  6
k  7
k  8
k  9
k  10
k  11
k  12
k  13
k  14
k  15
k  16
k  17
k  18
k  19
k  20
k  21
k  22
k  23
k  24
k  25
k  26
k  27
k  28
k  29
k  30
k  31


In [None]:
def store_mesh(arr, filename):
    """Extract an iso-surface from `arr` with marching cubes and save it as STL.

    arr: 3-D array; the surface is taken at level 0.
    filename: output path; the file is written in binary mode.
    Returns the numpy-stl Mesh that was saved.
    """
    # scikit-image >= 0.19 returns (verts, faces, normals, values); older
    # releases returned (verts, faces). Take the first two either way.
    # NOTE(review): level 0 equals the background of a 0/1 volume; 0.5 may be
    # the intended iso-level - confirm.
    surface = measure.marching_cubes(arr, 0, spacing=(1., 1., 1.), gradient_direction='ascent')
    verts, faces = surface[0], surface[1]

    # One (3 vertices x 3 coords) entry per triangle, filled in one assignment
    # instead of a Python loop over triangles.
    mesh_data = np.zeros(len(faces), dtype=mesh.Mesh.dtype)
    mesh_data['vectors'] = verts[faces]

    m = mesh.Mesh(mesh_data)
    # numpy-stl's binary STL writer needs a handle opened in binary mode.
    with open(filename, 'wb') as f:
        m.save(filename, f)

    return m

# Save the surface of the edge volume as <OUT_FOLDER>/<NEURON_ID>_smooth.stl.
# os.path.join is used so the path is correct even without a trailing slash.
m1 = store_mesh(meshed, os.path.join(OUT_FOLDER, str(NEURON_ID) + '_smooth.stl'))