In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
import open3d as o3d
import PIL
from PIL import ImageFilter
import cv2
import h5py

from skimage.segmentation import watershed
from skimage.feature import peak_local_max
import skimage.filters as skfilters

import perlin_numpy as perlin
import neuroglancer

import os 
import sys
import itertools
import importlib
# Put the parent of the current working directory on the import path
# (presumably where the `image_synthesis` package lives -- confirm).
lib_dir = os.path.dirname(os.path.realpath(os.curdir))
print(lib_dir)
if sys.path.count(lib_dir) == 0:
    sys.path.append(lib_dir)

In [None]:
def reload():
    """Reload the project helper modules in place.

    Imports ``image_synthesis.utils`` and ``image_synthesis.point_generation``
    and hands each of them to ``importlib.reload``, utils first.
    Returns nothing.
    """
    import image_synthesis.utils as utils_module
    import image_synthesis.point_generation as pg_module

    for module in (utils_module, pg_module):
        importlib.reload(module)

In [None]:
import image_synthesis.utils as utils
import image_synthesis.point_generation as pg

# Reload both modules right after importing them (utils first, then pg);
# importlib.reload returns the reloaded module object.
utils = importlib.reload(utils)
pg = importlib.reload(pg)

In [None]:
def remove_keymap_conflicts(new_keys_set):
    """Unbind the given keys from every Matplotlib ``keymap.*`` setting.

    Args:
        new_keys_set: Set of key names to strip from the key lists stored
            under the ``keymap.*`` entries of ``plt.rcParams``.

    The key lists are modified in place; nothing is returned.
    """
    keymap_names = (name for name in plt.rcParams if name.startswith('keymap.'))
    for name in keymap_names:
        bound_keys = plt.rcParams[name]
        for clash in set(bound_keys) & new_keys_set:
            bound_keys.remove(clash)


# 'j' and 'k' are the keys the 3D slice viewer's key handler reacts to.
remove_keymap_conflicts({'j', 'k'})

# Watershed for 2D data

In [None]:
reload()

# Extent of the synthetic data along each axis.
limits = (200, 200, 100)
radius_limit = 20
x, y = np.indices(limits[:2])
point_count = None

# Per-axis factor the extent is multiplied by to obtain the noise resolution.
noise_resolution = np.array([0.25] * 3)
noise_intensity = 0.5
# Each extent has to be an exact multiple of its (integer) noise resolution.
assert not (np.array(limits) % (np.array(limits) * noise_resolution).astype(int)).any()

# Parameters handed to the point generator below.
min_dist = 5
max_dist = 10
min_child_count = 3
max_child_count = 6
angle_noise = 0.01

points = pg.generate_cell_centers(
    np.array([[0, upper] for upper in limits[:2]]),
    min_dist,
    max_dist,
    min_child_count,
    max_child_count,
    angle_noise,
)[:, :-1]  # the last column of the returned array is dropped
print(f'generated {points.shape[0]} points')

## Generating the starting points

In [None]:
%matplotlib qt

In [None]:
# The spacing and child-count arguments are hard-coded here (3, 30, 1, 2)
# instead of using min_dist / max_dist / min_child_count / max_child_count
# from the configuration cell.
points = pg.generate_cell_centers(
    np.array([[0, upper] for upper in limits[:2]]), 3, 30, 1, 2, angle_noise
)[:, :-1]
point_count = points.shape[0]
print(f'generated {points.shape[0]} points')

# Rasterise the points: one white pixel per (integer) point on a black image.
points = points.astype(int)
vis = np.zeros(limits[:2])
vis[tuple(points[:, :2].T)] = 1

fig, ax = plt.subplots()
ax.set_title('Starting points')
ax.imshow(vis, cmap=plt.cm.gray)
plt.show()

## Perform watershed

In [None]:
noise_intensity = 1

# Boolean image that is False exactly at the seed points and True elsewhere,
# so the Euclidean distance transform gives each pixel's distance to the
# nearest seed point.
black_centers = points.astype(int)
image = np.ones(limits[:2], dtype=bool)
image[black_centers[:, 0], black_centers[:, 1]] = 0
distance = ndi.distance_transform_edt(image)

# Perturb the distance map with 2D Perlin noise. Only the first two entries
# of `limits` / `noise_resolution` belong to this 2D image, so the resolution
# handed to the 2D generator is restricted to them as well.
noise = perlin.generate_perlin_noise_2d(
    limits[:2],
    (np.array(limits[:2]) * noise_resolution[:2]).astype(int),
)
distance += noise * noise_intensity

# Basin markers are the local minima of the noisy distance map (found as the
# local maxima of its negation); each connected marker region gets one label.
coords = peak_local_max(-distance, min_distance=min_dist, exclude_border=False)
mask = np.zeros(distance.shape, dtype=bool)
mask[tuple(coords.T)] = True
markers = ndi.label(mask)[0]  # ndi.label returns (labelled array, count)
labels = watershed(distance, markers, watershed_line=True)

fig, axes = plt.subplots(ncols=4, figsize=(9, 4), sharex=True, sharey=True)
ax = axes.ravel()

ax[0].imshow(image, cmap=plt.cm.gray)
ax[0].set_title('Starting points')
ax[1].imshow(distance, cmap=plt.cm.gray)
ax[1].set_title('Distances with noise')
ax[2].imshow(mask, cmap=plt.cm.gray)
ax[2].set_title('Basin starts')
ax[3].imshow(labels, cmap=plt.cm.nipy_spectral)
ax[3].set_title('Separated objects')

for a in ax:
    a.set_axis_off()

fig.tight_layout()
plt.show()

# Watershed for 3D data

In [None]:
# Extent of the synthetic volume along each axis.
limits = (200, 200, 100)
noise_intensity = 10
membrane_threshold = 2

# Noise resolution per axis: each extent divided by the factor.
noise_resolution_factor = 20
noise_resolution = np.array(limits) // noise_resolution_factor
# Every extent has to be an exact multiple of its noise resolution.
assert not (np.array(limits) % noise_resolution).any()

In [None]:
# Fixed seed for NumPy's global random state before generating the centres.
np.random.seed(0)

# NOTE(review): angle_noise is the value set in the 2D configuration cell.
points = pg.generate_3d_centers(
    limits=np.array([[0, extent] for extent in limits]),  # one [0, extent] row per axis
    min_dist=25,
    max_dist=60,
    min_child_count=3,
    max_child_count=5,
    angle_noise=angle_noise,
    plane_distance=50,
    max_offset_from_plane=20,
    first_plane_offset=20,
    max_center_offset=[10, 10],
)
print(f'generated {points.shape[0]} many points')

In [None]:
# Boolean volume that is False at every centre (cast to integer indices) and
# True everywhere else; its Euclidean distance transform is the distance of
# each voxel to the nearest centre.
black_centers = points.astype(int)
image = np.ones(limits, dtype=bool)
image[tuple(black_centers[:, :3].T)] = False
distance = ndi.distance_transform_edt(image)

In [None]:
# 3D Perlin noise scaled by noise_intensity, added on top of the distance map.
noise = noise_intensity * perlin.generate_perlin_noise_3d(limits, noise_resolution)
noisy_distance = noise + distance

In [None]:
# Basin markers: local minima of the noisy distance map, found as the local
# maxima of its negation.
# NOTE(review): min_dist is the value from the 2D configuration cell, not the
# min_dist=25 passed to generate_3d_centers -- confirm this is intended.
coords = peak_local_max(-noisy_distance, exclude_border=False, min_distance=min_dist)
mask = np.zeros(noisy_distance.shape, dtype=bool)
mask[tuple(coords.T)] = True
markers, _ = ndi.label(mask)

# Same segmentation twice: once without and once with watershed lines.
labels = watershed(noisy_distance, markers, watershed_line=False)
labels_with_lines = watershed(noisy_distance, markers, watershed_line=True)

# uint8 image: 0 where labels_with_lines is 0, 255 everywhere else.
cell_image = (labels_with_lines != 0).astype(np.uint8) * 255

In [None]:
# Make the membrane thicker.
#
# Start from the un-thickened image derived from the watershed lines, so that
# re-running this cell always yields the same result instead of thickening an
# already thickened membrane a little more on every execution.
cell_image = (labels_with_lines != 0).astype(np.uint8) * 255

# Distance of every voxel to the nearest zero (membrane) voxel, smoothed with
# a Gaussian filter.
post_distance = ndi.distance_transform_edt(cell_image)
post_distance = skfilters.gaussian(post_distance, sigma=2)

# Voxels whose smoothed distance is within the threshold become membrane too.
cell_image[post_distance <= membrane_threshold] = 0

In [None]:
fig, axes = plt.subplots(ncols=3, figsize=(19, 10), sharex=True, sharey=True)
ax = axes.ravel()

slice_dim = 0   # index of the slice currently shown
slice_axis = 2  # axis along which the volumes are sliced
mask_image = cell_image

max_slice_dim = labels.shape[slice_axis] - 1
axis_string = ['x', 'y', 'z'][slice_axis]


def index():
    """Return the index tuple that selects slice ``slice_dim`` along ``slice_axis``."""
    i = [slice(None)] * 3
    i[slice_axis] = slice_dim
    return tuple(i)


def process_key(event, panels=tuple(ax), volumes=(noise, mask_image, labels)):
    """Step through the volumes: 'j' shows the previous slice, 'k' the next.

    Args:
        event: Matplotlib key event; only ``event.key`` and ``event.canvas``
            are used.
        panels: Axes to update, one per volume. Bound when this function is
            defined, so rebinding ``ax`` in a later cell does not affect an
            open viewer.
        volumes: Arrays shown in ``panels``, in the same order. Bound when
            this function is defined, so rebinding ``noise``, ``mask_image``
            or ``labels`` in a later cell does not affect an open viewer.

    Any other key is ignored. The slice index is clamped to
    ``[0, max_slice_dim]`` and stored in the global ``slice_dim``.
    """
    global slice_dim
    step = {'j': -1, 'k': 1}.get(event.key)
    if step is None:
        return
    slice_dim = min(max_slice_dim, max(0, slice_dim + step))
    current = index()
    for panel, volume in zip(panels, volumes):
        panel.images[0].set_array(volume[current])
    panels[2].set_title(f'Label ({axis_string} = {slice_dim})')
    # Redraw through the canvas that delivered the event rather than a global.
    event.canvas.draw()


ax[0].imshow(noise[index()], cmap=plt.cm.gray)
ax[0].set_title('noise')
ax[1].imshow(mask_image[index()], cmap=plt.cm.gray)
ax[1].set_title('Separated objects')
ax[2].imshow(labels[index()], cmap=plt.cm.nipy_spectral)
ax[2].set_title(f'Label ({axis_string} = {slice_dim})')

for a in ax:
    a.set_axis_off()

fig.canvas.mpl_connect('key_press_event', process_key)
fig.tight_layout()
plt.show()

#### Plotting the different parts

In [None]:
# The generated centres, shown as a uniformly black Open3D point cloud.
pcd_points = o3d.utility.Vector3dVector(points)
pcd = o3d.geometry.PointCloud()
pcd.points = pcd_points
pcd.paint_uniform_color([0, 0, 0])
o3d.visualization.draw_geometries([pcd])

In [None]:
# The noise: a small standalone Perlin volume, generated only for this plot.
#
# It is stored under its own name because rebinding `noise` would replace the
# full-size volume that the 3D cells and the slice viewer's key handler use.
# It also gets its own figure: a bare plt.imshow draws into the current
# figure, which with the Qt backend can be the still-open slice viewer.
noise_preview = perlin.generate_perlin_noise_3d((100, 100, 5), noise_resolution)

preview_fig, preview_ax = plt.subplots()
preview_ax.imshow(noise_preview[:, :, 0], cmap=plt.cm.gray)
preview_ax.set_title('Perlin noise preview (first slice of the last axis)')
plt.show()