In [None]:
from brainlit.utils.Neuron_trace import NeuronTrace
from brainlit.viz.visualize import napari_viewer
import numpy as np
from skimage import io
from scipy.ndimage.morphology import distance_transform_edt
from pathlib import Path
from brainlit.algorithms.image_processing import Bresenham3D
from brainlit.utils.benchmarking_params import (
    brain_offsets,
    vol_offsets,
    scales,
    type_to_date,
)

In [None]:
# Generate one binary mask per benchmarking image by rasterizing its manual
# .swc traces into the image's voxel grid and dilating them.
# Expected layout under `base_dir`:
#   benchmarking_datasets/Images/**/*.tif  -> gfp volumes (listed in `gfp_files`)
#   benchmarking_datasets/Manual-GT/...    -> manual traces (.swc)
#   benchmarking_masks/                    -> output masks (created if missing)
# NOTE(review): absolute, machine-specific path — adjust for your machine.
base_dir = Path("D:/Study/Nuero Data Design/brainlit")
data_dir = base_dir / "benchmarking_datasets"
im_dir = data_dir / "Images"
mask_dir = base_dir / "benchmarking_masks"
gfp_files = list(im_dir.glob("**/*.tif"))
swc_base_path = data_dir / "Manual-GT"
save = True

# Trace coordinates (after offset correction) are divided by `scale` and
# multiplied by this factor to obtain voxel indices.
# NOTE(review): presumably `scale` is nm/voxel and traces are in µm — confirm.
COORD_TO_VOXEL_FACTOR = 1000
# Voxels whose distance (in `scale` units) to a traced voxel is <= this are labeled.
DILATION_RADIUS = 1000


def _swc_dir(swc_root: Path, date: str, image: str, num: int) -> Path:
    """Return the Manual-GT folder that holds the .swc traces of volume `num`.

    Volumes are grouped five per folder:
    <swc_root>/<date>_<image>_<lower>-<upper>/<date>_<image>_<num>
    """
    lower = (num - 1) // 5 * 5 + 1
    upper = lower + 4
    return swc_root / f"{date}_{image}_{lower}-{upper}" / f"{date}_{image}_{num}"


def _trace_voxel_paths(swc_file: Path, im_offset, scale) -> list:
    """Load one .swc file and return its paths converted to image voxel coordinates."""
    trace = NeuronTrace(path=swc_file)
    paths = trace.get_paths()
    swc_offset, _, _, _ = trace.get_df_arguments()
    # Shift traces from the swc's frame into the image volume's frame.
    offset_diff = np.subtract(swc_offset, im_offset)
    return [(p + offset_diff) / scale * COORD_TO_VOXEL_FACTOR for p in paths]


def _rasterize_paths(voxel_paths: list, shape: tuple) -> np.ndarray:
    """Draw consecutive path points as line segments (via Bresenham3D) into a uint8 mask.

    Segment voxels falling outside `shape` are dropped.
    """
    mask = np.zeros(shape, dtype=np.uint8)
    for path in voxel_paths:
        for start, end in zip(path[:-1], path[1:]):
            # int() truncates toward zero, matching the original notebook.
            # NOTE(review): rounding may match the voxel-center convention better — confirm.
            xs, ys, zs = Bresenham3D(
                int(start[0]),
                int(start[1]),
                int(start[2]),
                int(end[0]),
                int(end[1]),
                int(end[2]),
            )
            pts = np.column_stack((xs, ys, zs)).astype(int)
            inside = np.all((pts >= 0) & (pts < shape), axis=1)
            mask[tuple(pts[inside].T)] = 1
    return mask


def _dilate_mask(mask: np.ndarray, spacing, radius: float) -> np.ndarray:
    """Label every voxel within `radius` (in `spacing` units) of a labeled voxel."""
    if not mask.any():
        # No traced voxels: the EDT input would contain no zero (background)
        # element to measure distances to, so keep the mask empty.
        return mask
    # Distance of every voxel to the nearest labeled voxel (labeled voxels get 0).
    dists = distance_transform_edt(mask == 0, sampling=spacing)
    return (dists <= radius).astype(mask.dtype)


if save:
    mask_dir.mkdir(parents=True, exist_ok=True)

for im_path in gfp_files:
    print(str(im_path))
    im = io.imread(im_path, plugin="tifffile")
    im = np.swapaxes(im, 0, 2)

    # Strip the last 8 characters of the file name (presumably "-gfp.tif" — confirm);
    # the remainder is "<image>_<num>".
    file_name = im_path.name[:-8]
    name_parts = file_name.split("_")
    image = name_parts[0]
    num = int(name_parts[1])
    date = type_to_date[image]

    scale = scales[date]
    im_offset = np.add(brain_offsets[date], vol_offsets[date][num])

    # Collect every path of every trace belonging to this volume.
    voxel_paths = []
    for swc in _swc_dir(swc_base_path, date, image, num).glob("**/*.swc"):
        if "cube" in swc.name:
            # skip the bounding box swc
            continue
        print(swc)
        voxel_paths.extend(_trace_voxel_paths(swc, im_offset, scale))

    labels_total = _dilate_mask(
        _rasterize_paths(voxel_paths, im.shape), scale, DILATION_RADIUS
    )

    if save:
        out_file = mask_dir / (file_name + "_mask.tif")
        # A 0/1 uint8 mask is "low contrast" by design; silence skimage's warning.
        io.imsave(out_file, labels_total, plugin="tifffile", check_contrast=False)

In [None]:
# Check that every generated mask can be loaded and lines up with its image.
# Set `show_napari = True` to inspect each image/mask pair visually.
show_napari = False
mask_files = list(mask_dir.glob("**/*.tif"))

for im_path in gfp_files:
    # Same naming rule as the mask-generation cell: drop the last 8 characters.
    file_name = im_path.name[:-8]
    mask_path = mask_dir / (file_name + "_mask.tif")

    # Announce before reading so a failed load is attributable to this file.
    print("loading the mask of", file_name, "...")
    im = np.swapaxes(io.imread(im_path, plugin="tifffile"), 0, 2)
    mask = io.imread(mask_path, plugin="tifffile")

    # Masks were created with the axes-swapped image's shape, so they must match.
    assert mask.shape == im.shape, (
        f"mask shape {mask.shape} != image shape {im.shape} for {file_name}"
    )

    if show_napari:
        napari_viewer(im, labels=mask, label_name="mask")