Just trying stuff out: plotting and smoothing WT jaw intensities
====

In [None]:
"""
Read in some WT jaws
"""

import tifffile
from tqdm.notebook import tqdm

from fishjaw.inference import read
from fishjaw.util import files

output_dir = files.script_out_dir() / "jaw_segmentations"

img_dir = output_dir / "imgs"
mask_dir = output_dir / "masks"

metadata_df = read.mastersheet()

wt_mdata_df = metadata_df[metadata_df["name"].str.contains("wt")]

imgs = []
masks = []
img_ns = []

for n in tqdm(wt_mdata_df.index):
    try:
        new_img = tifffile.imread(img_dir / f"ak_{n}.tif")
        new_mask = tifffile.imread(mask_dir / f"ak_{n}.tif")
    except FileNotFoundError:
        print(f"{n} doesn't exist")

    imgs.append(new_img)
    masks.append(new_mask)
    img_ns.append(n)

In [None]:
"""
Plot some of them
"""

import numpy as np
import matplotlib.pyplot as plt

# hacky
from scripts.pipeline.plot_3d import _calculate_point_size

fig, axes = plt.subplots(3, 6, figsize=(12, 6), subplot_kw={"projection": "3d"})

plot_kw = {
    "marker": "s",
    "cmap": "inferno",
    "vmin": 0,
    "vmax": 2**16,
    "s": _calculate_point_size(axes[0, 0], imgs[0].shape),
}

for img, mask, axis, n in zip(imgs, masks, tqdm(axes.flat), img_ns, strict=False):
    co_ords = np.argwhere(mask)
    greyscale_vals = img[co_ords[:, 0], co_ords[:, 1], co_ords[:, 2]]

    im = axis.scatter(
        co_ords[:, 0], co_ords[:, 1], co_ords[:, 2], c=greyscale_vals, **plot_kw
    )

    axis.view_init(elev=180, azim=30)
    axis.set_axis_off()
    axis.set_title(n)

fig.colorbar(im)
fig.tight_layout()

In [None]:
"""
Use a kernel to average/smooth it
"""

from scipy.ndimage import convolve
from skimage.morphology import ball

kernel_size = 7
kernel = ball(kernel_size)

n_plots = len(axes.flat)

smoothed = []
for img, mask, n in zip(
    tqdm(imgs[:n_plots]), masks[:n_plots], img_ns[:n_plots], strict=True
):
    co_ords = np.argwhere(mask)
    zmin, ymin, xmin = co_ords.min(axis=0)
    zmax, ymax, xmax = co_ords.max(axis=0)

    masked = img[zmin:zmax, ymin:ymax, xmin:xmax].copy()
    masked[~mask[zmin:zmax, ymin:ymax, xmin:xmax]] = 0

    out = np.zeros_like(img, dtype=np.uint64)
    out[zmin:zmax, ymin:ymax, xmin:xmax] = convolve(masked, kernel, mode="constant", cval=0)

    smoothed.append(out)
    break

In [None]:
# Show the smoothed first jaw next to its original intensities. Both panels
# must use the first jaw's mask so exactly the same voxels are drawn; zipping
# against the whole `masks` list would pair imgs[0] with masks[1].
fig, axes = plt.subplots(1, 2, subplot_kw={"projection": "3d"}, figsize=(16, 8))

co_ords = np.argwhere(masks[0])
panels = [
    (smoothed[0], f"{img_ns[0]} (smoothed)"),
    (imgs[0], f"{img_ns[0]} (original)"),
]

for (img, title), axis in zip(panels, tqdm(axes.flat), strict=True):
    greyscale_vals = img[co_ords[:, 0], co_ords[:, 1], co_ords[:, 2]]

    im = axis.scatter(
        co_ords[:, 0], co_ords[:, 1], co_ords[:, 2], c=greyscale_vals, **plot_kw
    )

    axis.view_init(elev=180, azim=30)
    axis.set_axis_off()
    # Explicit per-panel title; the original reused a stale `n` from an
    # earlier cell, giving both panels the same label.
    axis.set_title(title)

# Both panels share plot_kw's vmin/vmax, so one colorbar covers both.
fig.colorbar(im)