In [None]:
import datajoint as dj
import numpy as np
import time
from tqdm import tqdm

import matplotlib.pyplot as plt
import ipyvolume.pylab as p3

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

In [None]:
ta3p100 = dj.create_virtual_module('ta3p100', 'microns_ta3p100')

In [None]:
fetched_mesh = ta3p100.Decimation35.fetch(limit=1, as_dict=True)[0]

# Skeleton

In [None]:
# Once I make several, small voxels it would be quite easy to turn that into a skeleton structure by just connecting voxels
# that are next to each other.

# I could use the centroids of the voxels as bone points, and also downsample the bone structure by checking to see how far
# a bone vector strays from the centroids of adjacent voxels, or rather the downstream voxels from a starting voxel.

# If I can scan across the whole mesh and start building regression lines... and correcting and branching off... hmm that might be possible.
# Basically I would start a bone at a vertex, and continue that bone as I scan through the mesh across one axis. If I find a vertex far enough
# away from my current bone then I either start a new bone OR I branch off from the current bone.
# I think that I will basically grant membership of vertices to each bone in order to create the skeleton structure. So I need to keep track
# of which vertices are assigned to each bone. Probably using a dict, and from there I can just append vertices to the bone index. This will
# also allow me to compare 2 bones (and their vertices) with each other by looking at 2 bone indices in the dict.
# I should also make a verification function to see how close the vertices are to their assigned bones.
# Depending on how I merge bones, I might get spines also defined as bones, which might be beneficial.
# Essentially I'm going through a decision tree:
# 1. If I find a vertex on its own, I create a bone.
# 2. If I find a vertex near another vertex, I keep the bone growing.
# 

class Skeleton:
    _bones = list()
    
    def __init__(self, vertices, triangles):
        self._original_vertices = vertices
        self._original_triangles = triangles
        
        self._vertices = vertices
        self._triangles = triangles
    
    @property
    def vertices(self):
        return self._vertices
    
    @property
    def triangles(self):
        return self._triangles
    
    @property
    def _nonisolated_vertices(self):
        unique_vertex_idx = np.unique(self.triangles)
        return self.vertices[unique_vertex_idx]
    
    @property
    def _sorted_vertices(self):
        """
        Returns sorted nonisolated vertices.
        """
        verts = self._nonisolated_vertices
        sorted_idx = verts.T[2].argsort()
        return verts[sorted_idx]
    
    @property
    def bbox(self):
        return np.array([(np.min(axis), np.max(axis)) for axis in self.vertices.T]) # Should I use nonisolated vertices though?
    
    def plot_mesh(self, width=1024, height=1024, **kwargs):
        p3.figure(width=width, height=height)
        p3.plot_trisurf(*self.vertices.T/1000, self.triangles, **kwargs)
        p3.squarelim()
        p3.show()
    
    def plot_verts(self, width=1024, height=1024, targeted_verts=None, **kwargs):
        if targeted_verts is None:
            verts = self.vertices
        else:
            verts = targeted_verts
        
        p3.figure(width=width, height=height)
        p3.scatter(*verts.T/1000, **kwargs)
        p3.squarelim()
        p3.show()
        
    def thick_plane(self, num_planes=50):
        z_space = np.linspace(*self.bbox[2], num=num_planes)

        starting_idx = 0
        partition_edge_idx = list([starting_idx])
        verts = self._sorted_vertices
        for j, z_edge in enumerate(z_space[1:-1]):
            for i, vert in enumerate(verts.T[2][starting_idx:]):
                if vert > z_edge:
                    starting_idx += i
                    partition_edge_idx.append(starting_idx)
                    break
        partition_edge_idx.append(-1) #(len(verts) - 1)
        
        return np.array(partition_edge_idx)
    
    @property
    def bones(self):
        return self._bones
    
    def add_bone(self, starting_point, ending_point):
        # The bones should probably be connected in some way? Though a branching bone might not want to be connected.
        self._bones.append((starting_point, ending_point))
        
    def scan(self, axis='x'):        
        if axis.lower() == 'x':
            axis_idx = 0
        elif axis.lower() == 'y':
            axis_idx = 1
        elif axis.lower() == 'z':
            axis_idx = 2
        else:
            raise ValueError("Invalid value for axis, choose between 'x', 'y', and 'z'.")
            
        verts = self._sorted_vertices
        
        bone_membership = dict()
        for vert in verts:
            

In [None]:
skel = Skeleton()

# Voxel

In [249]:
# I can do the full bbox for each plane, then do it across the other planes to reduce the size of the voxels and see what I get.
# I can also do it across multiple restrictions pretty damn quickly with the method I created to partition it out.
# Yeah this will be a very fast and efficient method without worrying about hyperparameterizing the clustering methods.

# Create a 3D grid and scan through each block essentially. Can do it based on planes first, and then further break down the bboxes
# generated by those planes.

# Yeah I can literally just get the bboxes for several rectangular structures findable by the grid partitioning.

# I can turn this into a skeleton by merging touching voxels, or by simple connecting the centroids of voxels by some rule (basically make a tree structure
# starting from one centroid).

# I can merge voxels by checking to see how much their adjacent faces are overlapping (and I should also ensure they don't get too big, by keeping a threshold
# either on the volume/size it can be, or by restricting it by how much empty space I'd be adding in).

class Mesh:
    def __init__(self, vertices, triangles, ignore_isolated_vertices=True):
        self._original_vertices = vertices
        self._original_triangles = triangles
        
        self._vertices = vertices
        self._triangles = triangles
        self._ignore_isolated = ignore_isolated_vertices
        
        self._voxels = list()
    
    class Voxel: # Override some operators to allow direct manipulation to the bbox inside.
        """
        If I really want Voxels that are cubes, I can push the boundaries of the Mesh bounding box to fit what I need (to allow the cubes to fit edge to edge).
        """
        def __init__(self, bbox):
            self._bbox = bbox
            
        @property
        def bbox(self):
            return self._bbox
        
        @bbox.setter
        def bbox(self, bbox):
            if not isinstance(bbox, np.ndarray):
                bbox = np.array(bbox)
            
            if bbox.shape == (3, 2):
                self._bbox = bbox
            else:
                raise ValueError("Bounding box is not in required form which is: array-like with shape (3, 2).")
            
        @property
        def centroid(self):
            return self.bbox.T.mean(axis=0)
        
        def __str__(self):
            return str(self.bbox)
    
    @property
    def voxels(self):
        return self._voxels
    
    @voxels.setter
    def voxels(self, voxels):
        self._voxels = voxels
    
    @property
    def vertices(self): # Could "potentially" have the ignore_isolated_vertices_check in here... but it would mess up indices, so put in
        # another property entirely.
        return self._vertices
    
    @property
    def triangles(self):
        return self._triangles
    
    @property
    def _nonisolated_vertices(self):
        unique_vertex_idx = np.unique(self.triangles)
        return self.vertices[unique_vertex_idx]
    
    @property
    def bbox(self):
        return np.array([(np.min(axis), np.max(axis)) for axis in self.vertices.T])
    
    @staticmethod
    def get_bbox(vertices):
        return np.array([(np.min(axis), np.max(axis)) for axis in vertices.T])
    
    def merge_touching_voxels(self):
        raise NotImplementedError
            
        # Someone said that looping over numpy arrays should be avoided... can this be done more efficiently in some way?
#         corner_coords = list()
#         for x in x_split:
#             for y in y_split:
#                 for z in z_split:
#                     corner_coords.append((x, y, z))
#         corner_coords = np.array(corner_coords)
    
    # Turn these splits into bboxes so I can turn them into Voxels
    # I need pairwise bboxes for all of these
    # I should hmm. I was going to say to compute the centroids for these "rectangles" and then add the radius, but they're rectangles.
    # First thing I need is all of the points at which will be the corners for the rectangles.
    # Or I can potentially look at the walls formed and check from there.
    def grid_split(self, num_splits_each_axis=10):
        """
        :param num_splits_each_axis: 'automated' just means it'll be computed based on the volume of the mesh bbox.
        For now it just defaults to 10 per axis (so 1000 initial bboxes).
        """
        x_split, y_split, z_split = [np.linspace(minimum, maximum, num=num_splits_each_axis) for minimum, maximum in self.bbox]
                
        bboxes = list()
        x_pairs = np.array([(x_split[i], x_split[i+1]) for i in range(len(x_split) - 1)])
        y_pairs = np.array([(y_split[i], y_split[i+1]) for i in range(len(y_split) - 1)])
        z_pairs = np.array([(z_split[i], z_split[i+1]) for i in range(len(z_split) - 1)])
        for xs in x_pairs:
            for ys in y_pairs:
                for zs in z_pairs:
                    bboxes.append((xs, ys, zs))
        bboxes = np.array(bboxes)
        
        return bboxes
    
    def initialize_voxels(self, num_splits_each_axis=10):
        bboxes = self.grid_split(num_splits_each_axis=num_splits_each_axis)
        self._voxels = [self.Voxel(bbox) for bbox in bboxes]
        return self.voxels
    
    def plot_mesh(self, width=1024, height=1024, **kwargs):
        p3.figure(width=width, height=height)
        p3.plot_trisurf(*self.vertices.T/1000, self.triangles, **kwargs)
        p3.squarelim()
        p3.show()
        
    def plot_voxels(self, width=1024, height=1024):
        fig = p3.figure(width=width, heigh=height)
        p3.scatter(*np.array([voxel.centroid for voxel in self.voxels]).T/1000, marker='sphere')
        p3.squarelim()
        p3.show()
        
#         return fig
        
    def restrict_bboxes(self, num_splits_each_axis=10):
        voxels = self.initialize_voxels(num_splits_each_axis=num_splits_each_axis)
        nonisolated_vertices = self._nonisolated_vertices
        t_nonisolated_vertices = nonisolated_vertices.T
#         x_sorted_idx = t_nonisolated_vertices[0].argsort()
#         y_sorted_idx = t_nonisolated_vertices[1].argsort()
#         z_sorted_idx = t_nonisolated_vertices[2].argsort()
        
#         x_verts = nonisolated_vertices[x_sorted_idx]
#         y_verts = nonisolated_vertices[y_sorted_idx]
#         z_verts = nonisolated_vertices[z_sorted_idx]
        
        x_axis, y_axis, z_axis = t_nonisolated_vertices
        
        # I can make this even faster still! Which is extremely necessary as the time complexity of the number of Voxels to make increases.
        # Less voxels might still be what I want though.
        
        masks = list()
        not_empty_voxels = list()
        for i, voxel in tqdm(enumerate(voxels)):
            # Now I need to use each voxel to retrieve the vertices within its bbox, then reduce the voxel's bbox to the bbox of the vertices present within it.
            # First retrive the vertices:
            (x_min, x_max), (y_min, y_max), (z_min, z_max) = voxel.bbox
            x_mask = np.all(((x_axis >= x_min), (x_axis <= x_max)), axis=0)
            y_mask = np.all(((y_axis >= y_min), (y_axis <= y_max)), axis=0)
            z_mask = np.all(((z_axis >= z_min), (z_axis <= z_max)), axis=0)
            mask = np.all((x_mask, y_mask, z_mask), axis=0)
            if mask.sum() > 0:
                # Next get the bbox of the vertices.
                relevant_verts = nonisolated_vertices[mask]
                new_bbox = self.get_bbox(relevant_verts)
                # Now insert the new bbox into the voxel object.
                voxel.bbox = new_bbox
                not_empty_voxels.append(i)
                
        self.voxels = np.array(voxels)[not_empty_voxels]
        
"""
    def thick_plane(self, num_planes=50):
        z_space = np.linspace(*self.bbox[2], num=num_planes)

        starting_idx = 0
        partition_edge_idx = list([starting_idx])
        verts = self._sorted_vertices
        for j, z_edge in enumerate(z_space[1:-1]):
            for i, vert in enumerate(verts.T[2][starting_idx:]):
                if vert > z_edge:
                    starting_idx += i
                    partition_edge_idx.append(starting_idx)
                    break
        partition_edge_idx.append(-1) #(len(verts) - 1)
        
        return np.array(partition_edge_idx)
"""
pass

'\n    def thick_plane(self, num_planes=50):\n        z_space = np.linspace(*self.bbox[2], num=num_planes)\n\n        starting_idx = 0\n        partition_edge_idx = list([starting_idx])\n        verts = self._sorted_vertices\n        for j, z_edge in enumerate(z_space[1:-1]):\n            for i, vert in enumerate(verts.T[2][starting_idx:]):\n                if vert > z_edge:\n                    starting_idx += i\n                    partition_edge_idx.append(starting_idx)\n                    break\n        partition_edge_idx.append(-1) #(len(verts) - 1)\n        \n        return np.array(partition_edge_idx)\n'

In [250]:
mesh = Mesh(fetched_mesh['vertices'], fetched_mesh['triangles'])

In [256]:
start = time.time()
mesh.restrict_bboxes(50)
time.time() - start

117649it [05:06, 384.37it/s]


306.6842346191406

In [257]:
mesh.plot_voxels()

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …

In [258]:
mesh.plot_mesh()

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …

In [246]:
# What I really want to be able to do, is to shrink the bbox of the Mesh down to fit the (nonisolated) vertices.
# From there it should split the resulting complex shape into cubes..

# So to start, I need to get bboxes on a chosen plane at each layer (like the whole bbox).
# So do that several times for, say, the xy plane at each z block.
# Then I can start looking at each "slice" and get inner bounding boxes.

# Due to the odd shape of some of these neurons and their parts, I might really just want to use rectangular shapes that can be rotated...
# but at that point wouldn't it make more sense to create skeletons? If I figure out how to do that it would make my life (and getting
# accurate minimum distance measures) much easier. Though it would be difficult to extract the info needed for those skeletons.

# I can do the full bbox for each plane, then do it across the other planes to reduce the size of the voxels and see what I get.
# I can also do it across multiple restrictions pretty damn quickly with the method I created to partition it out.
# Yeah this will be a very fast and efficient method without worrying about hyperparameterizing the clustering methods.

# Create a 3D grid and scan through each block essentially. Can do it based on planes first, and then further break down the bboxes
# generated by those planes.

# Yeah I can literally just get the bboxes for several rectangular structures findable by the grid partitioning.

class Mesh:
    def __init__(self, vertices, triangles):
        self._original_vertices = vertices
        self._original_triangles = triangles
        
        self._vertices = vertices
        self._triangles = triangles
    
    @staticmethod
    def get_bbox(vertices):
        np.array([(np.min(axis), np.max(axis)) for axis in vertices.T])
    
    def grid_split(self, num_splits_each_axis=50):
        splits = [np.linspace(minimum, maximum, num=num_splits_each_axis) for minimum, maximum in self.bbox]
        return splits        
        
    def merge_small_voxels(self):
        pass
    
    class Voxel:
        def __init__(self, centroid, radius):
            self._x, self._y, self._z = centroid
            self._radius = radius
            
        @property
        def centroid(self):
            return np.array((self._x, self._y, self._z))
        
        @property
        def radius(self):
            return self._radius
        
        @radius.setter
        def radius(self, radius):
            self._radius = radius
            
        @property
        def bbox(self):
            # Just add and subtract the radius from the centroid vector
            return np.vstack((self.centroid - self.radius, self.centroid + self.radius)).T
    
    @property
    def vertices(self):
        return self._vertices
    
    @property
    def triangles(self):
        return self._triangles
    
    @property
    def _nonisolated_vertices(self):
        unique_vertex_idx = np.unique(self.triangles)
        return self.vertices[unique_vertex_idx]
    
    @property
    def _sorted_vertices(self):
        """
        Returns sorted nonisolated vertices. This is best used when only dealing with vertices and not triangles too.
        """
        raise NotImplementedError
        
        verts = self._nonisolated_vertices
        sorted_idx = verts.T[2].argsort()
        return verts[sorted_idx]
    
    @property
    def bbox(self):
        return np.array([(np.min(axis), np.max(axis)) for axis in self.vertices.T]) # Should I use nonisolated vertices though?
    
    def plot_mesh(self, width=1024, height=1024, **kwargs):
        p3.figure(width=width, height=height)
        p3.plot_trisurf(*self.vertices.T/1000, self.triangles, **kwargs)
        p3.squarelim()
        p3.show()
    
    def plot_verts(self, width=1024, height=1024, targeted_verts=None, **kwargs):
        if targeted_verts is None:
            verts = self.vertices
        else:
            verts = targeted_verts
        
        p3.figure(width=width, height=height)
        p3.scatter(*verts.T/1000, **kwargs)
        p3.squarelim()
        p3.show()
        
    def get_adjacent_voxels(self):
        
        # I just need to look at how the Voxel bboxes touch other Voxel bboxes
        
        raise NotImplementedError
    
    # Should I make this a lazy generator?
    def partition(self, axis, bbox_view=None, num_partitions=50):
        """
        Splits up the vertices along an axis.
        :param axis: The axis along which to partition the vertices.
        :param view: BBox view which will be used to look at a specific portion of the mesh. Leave as None to partition the entire mesh.
        :param num_partitions: Number of partitions/slices to be returned.
        """
        
        if bbox_view is None:
            axis_space = np.linspace(self.bbox[axis], num=num_planes)
            starting_idx = 0
            partition_edge_idx = list([starting_idx])
        else:
            raise NotImplementedError("bbox_view hasn't been implemented yet, bitch")
            axis_space = np.linspace(bbox_view[axis], num=num_planes)
            starting_idx = 0 # Would have to figure out if I want to keep indices based on the total _sorted_vertices or...
            partition_edge_idx = list([starting_idx])
            
        nonisolated_verts = self._nonisolated_vertices
        sorted_idx = nonisolated_verts.T[axis].argsort()
        axis_verts = nonisolated_verts[sorted_idx].T[axis]
        
        
        
    def thick_plane(self, num_planes=50):
        z_space = np.linspace(*self.bbox[2], num=num_planes)

        starting_idx = 0
        partition_edge_idx = list([starting_idx])
        verts = self._sorted_vertices
        for j, z_edge in enumerate(z_space[1:-1]):
            for i, vert in enumerate(verts.T[2][starting_idx:]):
                if vert > z_edge:
                    starting_idx += i
                    partition_edge_idx.append(starting_idx)
                    break
        partition_edge_idx.append(-1) #(len(verts) - 1)
        
        return np.array(partition_edge_idx)
    
    # Should still go through with this approach regardless.
    # Might either want to look into density or hierarchical clustering methods. But try to do a simple (note: efficient) version.
    # Could be accomplished by choosing a voxel center, then expanding outwards until there is a density drought. Then choose another voxel
    # center nearby.
    # Also the center of these Voxel bboxes should be between the 2 edges being looked at.
    def get_bbox_of_plane(self, num_planes=50):
        part_idx = self.thick_plane(num_planes)
        verts = self._sorted_vertices
        
        for i in range(len(part_idx) - 1):
            start = time.time()
            lower_idx, upper_idx = part_idx[i], part_idx[i+1]
            targets = verts[lower_idx:upper_idx]
            
            plt.scatter(*targets.mean(axis=0)[0:2])
            
#             find_n_clusters = PCA()
#             find_n_clusters.fit(targets)
#             return find_n_clusters.components_
        
#             clusterer = KMeans(n_clusters=4)
#             clusterer.fit(targets)
            
#             print(time.time() - start)
#             plt.scatter(*targets.T[0:2], c=clusterer.labels_, cmap="Set1")
#             plt.hist(verts[lower_idx:upper_idx].T[0], bins=50, color=(i/50, i/100, i/50, 0.9))
#             plt.show()
#             print(time.time() - start)
            
#         self.plot_verts(targeted_verts=verts, size=0.1, marker='sphere')
        
#         plt.hist(verts.T[0])
#         plt.show()
#         plt.hist(verts.T[1])
#         plt.show()
#         plt.hist(verts.T[2])
#         plt.show()

In [247]:
mesh = Mesh(fetched_mesh['vertices'], fetched_mesh['triangles'])

In [248]:
mesh.get_bbox_of_plane()

NotImplementedError: 

In [None]:
mesh.get_bbox_of_plane()

In [None]:
mesh.plot_ipv()

In [None]:
part_idx = mesh.thick_plane()

In [None]:
part_idx

In [None]:
mesh._sorted_vertices[part_idx[-1]], mesh.bbox[-1][-1]

In [None]:
mesh._sorted_vertices[-1]

In [None]:
voxel = mesh.Voxel((1, 3, 5), 10)

In [None]:
print(voxel.centroid)
print(voxel.radius)
print(voxel.bbox)

In [None]:
(ta3p100.AllenSomaClass() & 'cell_class="glia"') - ta3p100.SegmentExclude

In [None]:
ta3p100.AllenSomaClass()