In [None]:
from tip_finding import tip_finding
from drive import segment_points
import numpy as np
import pandas as pd


In [None]:
root_id = 864691135552841025
endpoints = []
# Unpack every output of endpoints_from_rid; only the "good" tips feed `endpoints`,
# but the remaining names (just_tips, just_means, mesh_obj, ...) are used by later cells.
(
    good_tips_thick, good_tips_thin,
    good_tips_bad_thick, good_tips_bad_thin,
    just_tips, just_means,
    skel_mp, mesh_obj,
) = tip_finding.endpoints_from_rid(root_id)
# Thick tips first, then thin tips, one element per tip.
endpoints.extend(good_tips_thick)
endpoints.extend(good_tips_thin)

In [None]:
flat_regions = tip_finding.get_flat_regions(
    mesh_obj,
    minsize=10,
    n_iter=3,
)
# NOTE(review): the last element of get_flat_regions' result is treated as the
# flat-region mean locations (it replaces `just_means`) — confirm in tip_finding.
just_means = flat_regions[-1]

In [None]:
# Visualize the mesh together with the tip points and the flat-region means
from meshparty import trimesh_io, trimesh_vtk, skeletonize, mesh_filters

mesh_actor = trimesh_vtk.mesh_actor(mesh_obj, color=(0.7, 0.7, 0.7), opacity=1)
# Tips drawn with RGB (0.2, 0.9, 0.9); region means with RGB (0.9, 0.2, 0.9)
syn_actor = trimesh_vtk.point_cloud_actor(just_tips, color=(0.2, 0.9, 0.9), size=400)
syn_actor2 = trimesh_vtk.point_cloud_actor(just_means, color=(0.9, 0.2, 0.9), size=400)
trimesh_vtk.render_actors([mesh_actor, syn_actor, syn_actor2])

In [None]:
import trimesh
import networkx as nx
from tqdm import tqdm

# Inline version of the flat-region search, with its tunable parameters exposed.
minsize = 10     # minimum summed face area for a flat component to be examined
n_iter = 3       # number of face-adjacency rings grown around each component
threshold = 0    # max ring vertices allowed above/below the plane before rejecting
z_tol = 5        # z distance (mesh vertex units) beyond which a vertex counts as above/below
scale = np.array([1, 1, 1])  # per-axis divisor applied to vertices (identity for now)

mesh = mesh_obj
# NOTE(review): mesh_obj is rebound here, so with a non-identity `scale` re-running
# this cell would rescale twice — re-run the loading cell first in that case.
mesh_obj = trimesh.Trimesh(np.divide(mesh.vertices, scale), mesh.faces)
mesh_coords = mesh_obj.vertices[mesh_obj.faces]  # per face: its three vertex coordinates
normals = mesh_obj.face_normals

# Candidate flat faces: normal (nearly) parallel to the z axis.
# Make sure to fix normals in mesh — this and `direc` below rely on their orientation.
one_mask = np.abs(normals[:, 2]) > .95

# Keep only adjacency edges whose two faces are both candidates.
adjacency_mask = one_mask[mesh_obj.face_adjacency]
adjacency_mask_both = adjacency_mask[:, 0] & adjacency_mask[:, 1]
E = mesh_obj.face_adjacency[adjacency_mask_both]

face_areas = mesh_obj.area_faces
face_centers = np.mean(mesh_coords, axis=1)
area_dict = {int(i): face_areas[i] for i in range(face_areas.shape[0])}
loc_dict = {int(i): face_centers[i] for i in range(face_centers.shape[0])}

G = nx.from_edgelist(E)
nx.set_node_attributes(G, area_dict, name='area')
nx.set_node_attributes(G, loc_dict, name='mean_loc')

# One subgraph per connected patch of candidate flat faces.
graphs = [G.subgraph(c).copy() for c in nx.connected_components(G)]
sums = np.array([np.sum(list(nx.get_node_attributes(g, 'area').values())) for g in graphs])
# Components have different face counts, so keep a plain list: np.array over these
# ragged per-component arrays raises ValueError on NumPy >= 1.24.
locs = [np.array(list(nx.get_node_attributes(g, 'mean_loc').values())) for g in graphs]

# Indices into `graphs` of components large enough to test. flatnonzero is always 1-D,
# including when exactly one component qualifies (np.squeeze would give a 0-d array).
sums_mask = np.flatnonzero(sums > minsize)

n_aboves = np.zeros(sums_mask.shape[0])
n_belows = np.zeros(sums_mask.shape[0])
flat_mask = np.full(sums_mask.shape[0], True)

mean_locs = []
mean_locs_all = []
mean_locs_bad = []
inds_good = []
inds_bad = []

# Precompute the slice of face_adjacency touching any face within n_iter rings of a
# candidate component, so get_next() below searches a much smaller array.
all_nodes = set()
for s in range(sums_mask.shape[0]):
    all_nodes.update(graphs[sums_mask[s]].nodes())
for _ in range(n_iter):
    masked_ad_rows = np.isin(mesh_obj.face_adjacency, list(all_nodes)).any(axis=1)
    face_ad_sub = mesh_obj.face_adjacency[masked_ad_rows]
    face_ad_sub_flat = face_ad_sub.flatten()
    all_nodes.update(face_ad_sub_flat)

for s in tqdm(range(sums_mask.shape[0])):
    original_nodes = list(graphs[sums_mask[s]].nodes())
    curr_nodes = list(original_nodes)

    # Grow the component outward by n_iter rings of neighbouring faces.
    tris_list = []
    for _ in range(n_iter):
        new_tris, f = tip_finding.get_next(curr_nodes, face_ad_sub, face_ad_sub_flat)
        curr_nodes.extend(new_tris)
        tris_list.extend(new_tris)
    if len(tris_list) == 0:
        continue
    new_tris = np.array(tris_list)

    # Exactly axis-aligned faces (|n_z| == 1) define the reference z plane.
    flat_mask_faces = np.abs(normals[original_nodes][:, 2]) == 1
    flat_faces = mesh_coords[original_nodes][flat_mask_faces]
    if flat_faces.shape[0] == 0:
        flat_mask[s] = False
        continue
    face_z = flat_faces[0, 0, 2]

    # Count ring vertices lying clearly above / below that plane.
    diffs = mesh_coords[new_tris][:, :, 2] - face_z
    n_above = np.sum(diffs > z_tol)
    n_below = np.sum(diffs < -z_tol)
    n_aboves[s] = n_above
    n_belows[s] = n_below

    # Rounded mean z-normal of the component (+1 / -1 for a consistently facing patch).
    direc = np.round(np.mean(normals[original_nodes][:, 2]))
    af = mesh_obj.area_faces[new_tris]
    af = af / np.sum(af)
    # Clip: float error can push |dot| slightly past 1, which would make arccos NaN.
    cos_angs = np.clip(np.dot(normals[new_tris], np.array([0, 0, direc])), -1.0, 1.0)
    angs = np.arccos(cos_angs)
    mloc = np.mean(np.mean(flat_faces, axis=1), axis=0)

    if (n_above > threshold) and (n_below > threshold):
        # Ring extends on both sides of the plane.
        flat_mask[s] = False
    elif (n_above > threshold and direc == 1) or (n_below > threshold and direc == -1):
        # Ring extends on the side the component faces.
        flat_mask[s] = False
    elif not f:
        # NOTE(review): rejection flag from tip_finding.get_next — confirm its meaning.
        flat_mask[s] = False
    elif np.sum(af * (angs < np.pi / 4)) < .35:
        # Under 35% of ring area is oriented within 45 degrees of direc.
        mean_locs_all.append(mloc)
        mean_locs.append(mloc)
        inds_good.append(s)
    else:
        mean_locs_all.append(mloc)
        mean_locs_bad.append(mloc)
        inds_bad.append(s)

mean_locs = np.array(mean_locs)
mean_locs_bad = np.array(mean_locs_bad)


In [None]:
# Centroid of each flat component: average of the per-face locations in `locs`
l = list(map(lambda pts: np.mean(pts, axis=0), locs))
# Sanity check — one centroid per component, one line per count
print(len(locs), len(l), len(graphs), sep="\n")
print(len(graphs))





In [None]:
# Visualize the mesh together with the tip points and the component centroids `l`
from meshparty import trimesh_io, trimesh_vtk, skeletonize, mesh_filters

mesh_actor = trimesh_vtk.mesh_actor(mesh_obj, color=(0.7, 0.7, 0.7), opacity=1)
# Tips drawn with RGB (0.2, 0.9, 0.9); centroids with RGB (0.9, 0.2, 0.9)
syn_actor = trimesh_vtk.point_cloud_actor(just_tips, color=(0.2, 0.9, 0.9), size=400)
syn_actor2 = trimesh_vtk.point_cloud_actor(l, color=(0.9, 0.2, 0.9), size=400)
trimesh_vtk.render_actors([mesh_actor, syn_actor, syn_actor2])

In [None]:
# Print each endpoint divided element-wise by [4, 4, 40].
# NOTE(review): presumably converts coordinates to 4x4x40 voxel units — confirm.
for endpoint in endpoints:
    endpoint = np.array(endpoint)
    scaled = np.divide(endpoint, [4, 4, 40])
    print(scaled)

    