Just trying stuff out: masking and mean-filtering WT jaw images
====

In [None]:
"""
Read in some WT jaws and make some greyscale images but only containing the voxels from the mask
"""

import tifffile
from tqdm.notebook import tqdm

from fishjaw.inference import read
from fishjaw.util import files

output_dir = files.script_out_dir() / "jaw_segmentations"

img_dir = output_dir / "imgs"
mask_dir = output_dir / "masks"

# Masked copies of each image (greyscale kept only where the mask is non-zero) go here
masked_out_dir = output_dir / "masked"
masked_out_dir.mkdir(exist_ok=True)

metadata_df = read.mastersheet()

# Select the WT fish by name; na=False so rows with a missing name are dropped
# rather than breaking the boolean index
wt_mdata_df = metadata_df[metadata_df["name"].str.contains("wt", na=False)]

imgs = []
masks = []
masked = []
img_ns = []

for n in tqdm(wt_mdata_df.index):
    try:
        new_img = tifffile.imread(img_dir / f"ak_{n}.tif")
        new_mask = tifffile.imread(mask_dir / f"ak_{n}.tif")
    except FileNotFoundError:
        # Skip this fish: falling through would re-save and re-append the arrays
        # left over from the previous iteration under this fish's number
        print(f"{n} doesn't exist")
        continue

    new_masked = new_mask * new_img

    tifffile.imwrite(masked_out_dir / f"ak_{n}.tif", new_masked)

    imgs.append(new_img)
    masks.append(new_mask)
    masked.append(new_masked)
    img_ns.append(n)

In [None]:
"""
Plot some of them
"""

import numpy as np
import matplotlib.pyplot as plt

# hacky
from scripts.pipeline.plot_3d import _calculate_point_size


def plot_projections(imgs, masks, names=None):
    """
    Plot each image as a 3D scatter of its mask voxels, coloured by greyscale value.

    Images are drawn on a 3x6 grid, so at most 18 are shown; extra images are ignored
    and any unused axes are hidden.

    :param imgs: 3D greyscale arrays
    :param masks: arrays matching ``imgs``; voxels where the mask is non-zero are plotted
    :param names: subplot titles, one per image. Defaults to each image's position in ``imgs``.
    :raises ValueError: if there is nothing to plot
    """
    names = list(range(len(imgs)) if names is None else names)

    fig, axes = plt.subplots(3, 6, figsize=(12, 6), subplot_kw={"projection": "3d"})

    n_plotted = min(len(imgs), len(masks), len(names), axes.size)
    if n_plotted == 0:
        plt.close(fig)
        raise ValueError("No images to plot")

    plot_kw = {
        "marker": "s",
        "cmap": "inferno",
        "vmin": 0,
        "vmax": 2**16,
        "s": _calculate_point_size(axes[0, 0], imgs[0].shape),
    }

    for img, mask, axis, name in tqdm(
        zip(imgs, masks, axes.flat, names), total=n_plotted
    ):
        # Co-ordinates of every voxel inside the mask, and the greyscale there
        x, y, z = np.nonzero(mask)
        im = axis.scatter(x, y, z, c=img[x, y, z], **plot_kw)

        axis.view_init(elev=180, azim=30)
        axis.set_axis_off()
        axis.set_title(name)

    # Hide any grid cells we didn't draw into
    for axis in axes.flat[n_plotted:]:
        axis.set_axis_off()

    fig.colorbar(im)
    fig.tight_layout()


plot_projections(imgs, masks, img_ns)

In [None]:
"""
Uniform filter the mask
"""

from scipy.ndimage import uniform_filter


def masked_mean_filter(
    img: np.ndarray, mask: np.ndarray, filter_size: int = 5
) -> np.ndarray:
    """
    Mean-filter an image, averaging only over voxels inside the mask.

    Each output voxel is the sum of the image over a ``filter_size``-wide window
    divided by the number of mask voxels in that window (not by the window volume),
    so voxels near the mask boundary aren't pulled towards the zero background.

    :param img: image whose voxels outside the mask are all 0
    :param mask: same shape as ``img``; non-zero / True marks voxels inside the mask
    :param filter_size: side length of the averaging window
    :returns: floating-point array shaped like ``img``: the filtered values inside
              the mask and exactly 0 outside it
    :raises ValueError: if any image voxel outside the mask is non-zero
    """
    # The mask may be an integer array (e.g. 0/1); `~` on integers is a bitwise NOT
    # (~1 == 254 for uint8), so normalise to bool before using it as an index/condition
    mask = mask.astype(bool)

    if not (img[~mask] == 0).all():
        raise ValueError("Img pixels outside the mask must be set to 0")

    # uniform_filter writes into an array of its input's dtype, so filtering an
    # integer image directly would truncate the local means; work in float
    # (float32 for small ints like uint16, wider if the input needs it)
    work_dtype = np.result_type(img.dtype, np.float32)
    total = uniform_filter(img.astype(work_dtype), size=filter_size)
    count = uniform_filter(mask.astype(np.float32), size=filter_size)

    # Both filters divide by the window volume, so their ratio is
    # (sum of values in the window) / (number of mask voxels in the window).
    # np.divide leaves entries where `where` is False untouched, so `out` must be
    # zero-initialised; otherwise those entries hold uninitialised memory
    return np.divide(
        total, count, out=np.zeros_like(total), where=mask & (count > 0)
    )


# Smooth every masked image with its own mask, pairing them up positionally
filtered = list(map(masked_mean_filter, tqdm(masked), masks))

In [None]:
# Show the same slice through the first fish's mask, masked image and filtered image
fig, axes = plt.subplots(1, 3, figsize=(9, 3))

n = 100  # index along the first axis of each volume
# Explicit mapping instead of a locals() lookup, which breaks inside functions
# and hides the dependency from linters
panels = {"masks": masks, "masked": masked, "filtered": filtered}
for axis, (name, volumes) in zip(axes, panels.items()):
    axis.imshow(volumes[0][n], interpolation="none")
    axis.set_title(name)
    axis.set_axis_off()

fig.tight_layout()