Just trying stuff out: finding the brightest regions of the masked WT jaw greyscale images
====

In [None]:
"""
Read in some WT jaws and make some greyscale images but only containing the voxels from the mask
"""

import tifffile
from tqdm.notebook import tqdm

from fishjaw.inference import read
from fishjaw.util import files

output_dir = files.script_out_dir() / "jaw_segmentations"

img_dir = output_dir / "imgs"
mask_dir = output_dir / "masks"

masked_out_dir = output_dir / "masked"
masked_out_dir.mkdir(parents=True, exist_ok=True)

metadata_df = read.mastersheet()

# WT fish are the rows whose name contains "wt"
wt_mdata_df = metadata_df[metadata_df["name"].str.contains("wt")]

imgs = []
masks = []
masked = []
img_ns = []

for n in tqdm(wt_mdata_df.index):
    try:
        new_img = tifffile.imread(img_dir / f"ak_{n}.tif")
        new_mask = tifffile.imread(mask_dir / f"ak_{n}.tif")
    except FileNotFoundError:
        # Skip this fish entirely: falling through would re-write and re-append the
        # previous fish's arrays under this fish's number (or NameError on the first one)
        print(f"{n} doesn't exist")
        continue

    # Zero every voxel outside the mask
    new_masked = new_mask * new_img

    tifffile.imwrite(masked_out_dir / f"ak_{n}.tif", new_masked)

    imgs.append(new_img)
    masks.append(new_mask)
    masked.append(new_masked)
    img_ns.append(n)

In [None]:
"""
Plot some of them
"""

import numpy as np
import matplotlib.pyplot as plt

# hacky
from scripts.pipeline.plot_3d import _calculate_point_size


def plot_projections(imgs, masks, names=None):
    """
    Plot each mask's voxels as a 3D scatter, coloured by the image's greyscale value there.

    One jaw per panel on a 5x8 grid, so at most 40 jaws are drawn; any further images
    are ignored and unused panels are left blank.

    :param imgs: sequence of 3D greyscale images
    :param masks: sequence of masks matching `imgs`; non-zero voxels are plotted
    :param names: optional panel titles, one per image; defaults to the image's position
    :raises ValueError: if `imgs` is empty
    """
    if len(imgs) == 0:
        raise ValueError("No images to plot")
    if names is None:
        names = range(len(imgs))

    fig, axes = plt.subplots(5, 8, figsize=(24, 15), subplot_kw={"projection": "3d"})

    plot_kw = {
        "marker": "s",
        "cmap": "inferno",
        "vmin": 0,
        "vmax": 2**16,
        # Marker size is computed from the first image's shape and reused for every panel
        "s": _calculate_point_size(axes[0, 0], imgs[0].shape),
    }

    for img, mask, axis, n in zip(imgs, masks, tqdm(axes.flat), names, strict=False):
        co_ords = np.argwhere(mask)
        greyscale_vals = img[co_ords[:, 0], co_ords[:, 1], co_ords[:, 2]]

        im = axis.scatter(
            co_ords[:, 0], co_ords[:, 1], co_ords[:, 2], c=greyscale_vals, **plot_kw
        )

        axis.view_init(elev=180, azim=30)
        axis.set_axis_off()
        axis.set_title(n)

    # Panels beyond the number of images would otherwise show empty 3D axes
    for axis in axes.flat[len(imgs):]:
        axis.set_axis_off()

    # Every panel shares vmin/vmax, so a single colourbar applies to all of them
    fig.colorbar(im)
    fig.tight_layout()


plot_projections(imgs, masks, img_ns)

In [None]:
"""
Exclude segmentations that are broken, also quite a few of them look the same for some reason
"""

broken_n = [238, 239, 245, 293, 417, 416, 415, 414, 413, 412, 346, 345, 344, 343, 342, 341, 340]
keep = ~np.isin(img_ns, broken_n)
print(f"{np.sum(~keep)} broken segmentations")

# Filter the per-jaw volume lists directly rather than via np.array(...)[keep]:
# stacking copies every volume, and can fail if the volumes differ in shape
imgs = [img for img, kept in zip(imgs, keep) if kept]
masks = [mask for mask, kept in zip(masks, keep) if kept]
masked = [vol for vol, kept in zip(masked, keep) if kept]
# img_ns stays an array: later cells index it with a boolean array
img_ns = np.array(img_ns)[keep]

In [None]:
"""
Uniform filter the mask
"""

from scipy.ndimage import uniform_filter


def masked_mean_filter(
    img: np.ndarray, mask: np.ndarray, filter_size: int = 5
) -> np.ndarray:
    """
    Mean-filter an image, averaging only over voxels that lie inside the mask.

    The windowed mean of the image is divided by the windowed mean of the mask, so
    the denominator is effectively the number of mask voxels in the window and
    background zeros do not drag down values near the edge of the mask.

    :param img: image whose voxels outside the mask are already 0
    :param mask: array of the same shape as `img`; any non-zero value counts as inside
    :param filter_size: width of the averaging window along every axis
    :returns: float array of the same shape as `img`, 0 outside the mask
    :raises ValueError: if any voxel outside the mask is non-zero
    """
    img = np.asarray(img)
    # Accept integer masks (e.g. uint8 read from a tif) as well as boolean ones;
    # `~` on an integer array is bitwise NOT, which would turn into fancy indexing
    mask = np.asarray(mask).astype(bool)

    if not (img[~mask] == 0).all():
        raise ValueError("Img pixels outside the mask must be set to 0")

    # Filter in floating point: uniform_filter keeps the input dtype, so an integer
    # image would have its local means truncated
    window_mean = uniform_filter(img.astype(np.float64), size=filter_size)
    mask_fraction = uniform_filter(mask.astype(np.float64), size=filter_size)

    # Avoid div by 0. Without `out`, entries where the condition is False are left
    # uninitialised (possibly NaN, which would survive the multiplication below)
    retval = np.divide(
        window_mean,
        mask_fraction,
        out=np.zeros_like(window_mean),
        where=mask_fraction > 0,
    )
    return retval * mask


def get_max_loc(img):
    """Return the index tuple of the first occurrence of the largest value in an N-d array."""
    flat_position = img.argmax()
    return np.unravel_index(flat_position, img.shape)


# Smooth each masked jaw, then record where its smoothed intensity peaks
filtered = [
    masked_mean_filter(masked_img, jaw_mask)
    for masked_img, jaw_mask in zip(tqdm(masked), masks)
]
max_locations = list(map(get_max_loc, filtered))

In [None]:
"""
Plot the mask, original, smoothed + location of maximum
"""
# Show at most the first 10 jaws; there may be fewer than that
n_rows = min(10, len(filtered))
panels = {"masks": masks, "masked": masked, "filtered": filtered}

# squeeze=False keeps `axes` 2D even when there is only one row
fig, axes = plt.subplots(n_rows, len(panels), figsize=(9, 3 * n_rows), squeeze=False)

for i in range(n_rows):
    z, x, y = max_locations[i]
    for axis, (name, stack) in zip(axes[i], panels.items()):
        axis.imshow(stack[i][z], interpolation="none", cmap="grey")
        axis.set_axis_off()

        if i == 0:
            axis.set_title(name)

        # imshow puts the first index of the slice on the vertical axis, so y is the column
        axis.scatter(y, x, marker="x", color="r")

fig.tight_layout()
# NOTE(review): closing the figure stops the inline backend from rendering it;
# remove this line if the plot should be shown
plt.close(fig)

In [None]:
"""
For all the jaws plot the location of the maximum as well as its greyscale val in the smoothed hist
"""

# For every jaw, show the slice img[z] of the smoothed image that contains its
# maximum, with the maximum marked. NOTE(review): the histogram mentioned in the
# cell docstring is not drawn here.
grid_cols = 7
grid_rows = max(1, -(-len(filtered) // grid_cols))  # ceiling division, at least one row
fig, axes = plt.subplots(
    grid_rows, grid_cols, figsize=(3 * grid_cols, 5 * grid_rows), squeeze=False
)

# NOTE(review): `bins` is not used in this cell - presumably meant for that histogram
bins = np.linspace(1, 60000, 50)
for max_loc, img, axs, n in zip(max_locations, filtered, axes.flat, img_ns):
    z, x, y = max_loc

    axs.imshow(img[z], cmap="gist_ncar")
    axs.set_axis_off()
    axs.scatter(y, x, marker=".", color="k", s=10)

    axs.set_title(n)

# Blank out any spare panels at the end of the grid
for axs in axes.flat[len(filtered):]:
    axs.set_axis_off()

fig.tight_layout()

In [None]:
"""
Exclude a radius around the muscle attachment, find the maximum again and plot
"""


def remove_ball(img: np.ndarray, centre: np.ndarray, radius: int):
    """
    Return a copy of a 3D image with every voxel within `radius` of `centre` set to 0.

    Distances are Euclidean in index units, and `centre` is given in the same axis
    order as `img`. The input image is left untouched.
    """
    axis_ranges = np.ogrid[: img.shape[0], : img.shape[1], : img.shape[2]]
    offsets = [axis_range - centre[dim] for dim, axis_range in enumerate(axis_ranges)]

    sq_dist = offsets[0] ** 2 + offsets[1] ** 2 + offsets[2] ** 2
    inside_ball = sq_dist <= radius**2

    return np.where(inside_ball, 0, img)


# Radius (in index units) of the ball blanked out around each jaw's first maximum
exclusion_r = 10

removed1 = [
    remove_ball(smoothed, first_max, exclusion_r)
    for smoothed, first_max in zip(tqdm(filtered), max_locations)
]
second_loc = list(map(get_max_loc, tqdm(removed1)))

In [None]:
# One panel per jaw, 7 per row; the grid grows with the number of jaws instead of
# requiring exactly 21 of them
grid_cols = 7
grid_rows = max(1, -(-len(second_loc) // grid_cols))  # ceiling division, at least one row
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(24, 5 * grid_rows), squeeze=False)

# strict=True still checks that the per-jaw lists line up; the grid may have spare panels
per_jaw = zip(second_loc, removed1, img_ns, strict=True)
for (loc, img, n), ax in zip(per_jaw, axes.flat):
    z, x, y = loc

    ax.imshow(img[z], cmap="gist_ncar")

    ax.scatter(y, x, marker=".", color="k", s=16)
    ax.scatter(y, x, marker="o", facecolor="none", edgecolor="k", s=49)
    ax.set_title(n)
    ax.set_axis_off()

# Blank out any spare panels at the end of the grid
for ax in axes.flat[len(second_loc):]:
    ax.set_axis_off()

fig.tight_layout()

In [None]:
"""
Plot the distance between the two locations - they should be much more than the exclusion radius
"""

# Euclidean distance (in index units) between each jaw's first and second maxima
delta = np.array(
    [
        np.linalg.norm(np.subtract(first, second))
        for first, second in zip(tqdm(max_locations), second_loc, strict=True)
    ]
)

fig, axis = plt.subplots()

axis.hist(delta, bins=50, color="k")
axis.axvline(exclusion_r, color="r")
axis.set_xlabel("Distance between first and second maxima (voxels)")
axis.set_ylabel("Count")
fig.suptitle(
    "Hist shouldn't push up against the red line\nIt does but we'll deal with it later"
)
fig.tight_layout()

# Flag jaws whose second maximum lies less than this far beyond the exclusion radius
close_margin = 10
close = (delta - exclusion_r) < close_margin

img_ns[close]