Just trying stuff out: locating greyscale maxima in WT jaw segmentations
====

In [None]:
"""
Read in some WT jaws and make some greyscale images but only containing the voxels from the mask
"""

import tifffile
from tqdm.notebook import tqdm

from fishjaw.inference import read
from fishjaw.util import files

output_dir = files.script_out_dir() / "jaw_segmentations"

img_dir = output_dir / "imgs"
mask_dir = output_dir / "masks"

masked_out_dir = output_dir / "masked"
masked_out_dir.mkdir(exist_ok=True)

metadata_df = read.mastersheet()

wt_mdata_df = metadata_df[metadata_df["name"].str.contains("wt")]

imgs = []
masks = []
masked = []
img_ns = []

for n in tqdm(wt_mdata_df.index):
    try:
        new_img = tifffile.imread(img_dir / f"ak_{n}.tif")
        new_mask = tifffile.imread(mask_dir / f"ak_{n}.tif")
        new_masked = new_mask * new_img
    except FileNotFoundError:
        print(f"{n} doesn't exist")

    tifffile.imwrite(masked_out_dir / f"ak_{n}.tif", new_masked)

    imgs.append(new_img)
    masks.append(new_mask)
    masked.append(new_masked)
    img_ns.append(n)

In [None]:
"""
Plot some of them
"""

import numpy as np
import matplotlib.pyplot as plt

# hacky
from scripts.pipeline.plot_3d import _calculate_point_size


def plot_projections(imgs, masks):
    fig, axes = plt.subplots(5, 8, figsize=(24, 15), subplot_kw={"projection": "3d"})

    plot_kw = {
        "marker": "s",
        "cmap": "inferno",
        "vmin": 0,
        "vmax": 2**16,
        "s": _calculate_point_size(axes[0, 0], imgs[0].shape),
    }

    for img, mask, axis, n in zip(imgs, masks, tqdm(axes.flat), img_ns, strict=False):
        co_ords = np.argwhere(mask)
        greyscale_vals = img[co_ords[:, 0], co_ords[:, 1], co_ords[:, 2]]

        im = axis.scatter(
            co_ords[:, 0], co_ords[:, 1], co_ords[:, 2], c=greyscale_vals, **plot_kw
        )

        axis.view_init(elev=180, azim=30)
        axis.set_axis_off()
        axis.set_title(n)

    fig.colorbar(im)
    fig.tight_layout()


plot_projections(imgs, masks)

In [None]:
"""
Exclude segmentations that are broken, also quite a few of them look the same for some reason
"""

broken_n = [
    238,
    239,
    245,
    293,
    417,
    416,
    415,
    414,
    413,
    412,
    346,
    345,
    344,
    343,
    342,
    341,
    340,
]
keep = ~np.isin(img_ns, broken_n)
print(f"{np.sum(~keep)} broken segmentations")

imgs = np.array(imgs)[keep]
masks = np.array(masks)[keep]
masked = np.array(masked)[keep]
img_ns = np.array(img_ns)[keep]

In [None]:
"""
Uniform filter the masks and find maxima
"""
%load_ext autoreload
%autoreload 2

from fishjaw.density_analysis import find_attachments

n_loc = 6
exclusion_r = 25

filtered = [find_attachments.masked_smooth(i, m) for i, m in zip(tqdm(masked), masks)]
max_locations = [find_attachments.get_maxima(i, n_loc, removal_radius=25) for i in filtered]

In [None]:
"""
Plot the mask, original, smoothed + location of maximum
"""

n_rows = 10
fig, axes = plt.subplots(n_rows, 3, figsize=(9, 30))

for i in range(n_rows):
    # Get the location of the first maximum - i.e. the global max
    z, x, y = max_locations[i][0]
    for axis, name in zip(axes[i], ["masks", "masked", "filtered"]):
        axis.imshow(locals()[name][i][z], interpolation="none", cmap="grey")
        axis.set_axis_off()

        if not i:
            axis.set_title(name)

        axis.scatter(y, x, marker="x", color="r")

fig.tight_layout()
plt.close(fig)

In [None]:
"""
For all the jaws plot the location of all of the maxima
"""

for i in range(n_loc):
    fig, axes = plt.subplots(3, 7, figsize=(21, 15))
    bins = np.linspace(1, 60000, 50)
    for max_loc, img, axs, n in zip(
        [m[i] for m in max_locations], filtered, axes.flat, img_ns, strict=False
    ):
        z, x, y = max_loc

        im = axs.imshow(img[z], cmap="gist_ncar")
        axs.set_axis_off()
        axs.scatter(y, x, marker=".", color="k", s=10)

        axs.set_title(n)

    fig.tight_layout()

    cax = fig.add_axes([1.02, 0.1, 0.02, 0.8])
    cbar = fig.colorbar(im, cax=cax)
    cbar.set_ticks([])

In [None]:
"""
Plot the distance between the two locations - they should be much more than the exclusion radius
"""
pairwise_distances = [find_attachments.get_pairwise_distances(l) for l in tqdm(max_locations)]
pairwise_distances = np.array(pairwise_distances).ravel()

fig, axis = plt.subplots()

axis.hist(pairwise_distances, bins=100, color="k")
axis.set_xlabel("Distance between maxima")

axis.axvline(exclusion_r, color="r")
axis.text(exclusion_r * 1.05, axis.get_ylim()[1] * 0.96, "Exclusion radius", color="red")
fig.suptitle(
    "Hist shouldn't push up against the red line\nIt does but we'll deal with it later"
)
fig.tight_layout()

In [None]:
"""
Plot each jaw in 3D (mask voxels, coloured by greyscale) with the locations of its maxima overlaid
"""
fig, axes = plt.subplots(3, 7, figsize=(21, 9), subplot_kw={"projection": "3d"})

plot_kw = {
    "marker": "s",
    "cmap": "grey",
    "vmin": 0,
    "vmax": 2**16,
    "s": _calculate_point_size(axes[0, 0], imgs[0].shape),
    # Very transparent so the maxima markers are visible through the jaw
    "alpha": 0.005,
}

# One colour + marker per maximum; also reused by the age plot below
all_markers = ["x", "+", "o", "^", "v", "s"]
assert n_loc <= len(all_markers), f"Only {len(all_markers)} markers for {n_loc} maxima"
colours = [plt.cm.tab10(k) for k in range(n_loc)]
markers = all_markers[:n_loc]

# Title with the jaw's own number, not a loop index leaked from an earlier cell
for img, mask, max_locs, axis, n in zip(
    imgs, masks, max_locations, tqdm(axes.flat), img_ns, strict=True
):
    co_ords = np.argwhere(mask)
    greyscale_vals = img[co_ords[:, 0], co_ords[:, 1], co_ords[:, 2]]

    axis.scatter(
        co_ords[:, 0],
        co_ords[:, 1],
        co_ords[:, 2],
        c=greyscale_vals,
        **plot_kw,
        zorder=2,
    )

    axis.view_init(elev=180, azim=30)
    axis.set_axis_off()
    axis.set_title(n)

    # Each location is (axis0, axis1, axis2), the same order as the mask co-ords above
    for loc, colour, marker in zip(max_locs, colours, markers):
        axis.scatter(*loc, s=81, color=colour, marker=marker)

fig.tight_layout()

In [None]:
"""
Get the greyscale values at each point of interest

"""

# Get the overall average greyscale
avg_greyscales = [np.median(i[m]) for i, m in zip(filtered, masks)]

# Get the average greyscales in a region around each point of interest
roi_greyscales = []
with tqdm(total=len(imgs) * n_loc) as pbar:
    for i, (img, max_locs) in enumerate(zip(filtered, max_locations)):
        roi_greyscales.append(
            [find_attachments.ball_median(img, x, exclusion_r) for x in max_locs]
        )
        pbar.update(n_loc)

In [None]:
"""
Compare the density around each maximum (left) with each jaw's overall median (right), against age
"""
fig, axes = plt.subplots(1, 2, sharey=True, figsize=(10, 5))

# Ages in the same order as img_ns
ages = metadata_df.loc[img_ns, "age"]
age_vals = ages.to_numpy()

# roi_greyscales is (n_jaws, n_loc); transposing gives one row of medians per maximum
for k, (colour, marker, medians) in enumerate(
    zip(colours, markers, np.array(roi_greyscales).T, strict=True)
):
    axes[0].scatter(age_vals, medians, marker=marker, color=colour, label=str(k))

axes[1].scatter(age_vals, avg_greyscales, color="gray")

axes[0].set_xlabel("Age (months)")
axes[1].set_xlabel("Age (months)")
axes[0].set_ylabel("Density")

axes[0].legend(title="Maximum")
axes[0].set_title("Regions of interest")
axes[1].set_title("Jaw medians")
fig.tight_layout()