In [None]:
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from topostats.plottingfuncs import Colormap
from topostats.measure.feret import get_feret_from_mask
from topostats.hariboscripts import calculate_contour_length_from_points

# Build the TopoStats colormap helper once. `cmap` is the colormap object that the plotting
# functions and imshow calls in later cells use as their default.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
# --- Configuration ---------------------------------------------------------------------------
# SAMPLE_TYPE names the sample sub-folder to load, e.g. "cas9_ON_SC". It must contain "unbound"
# or "cas9": that substring selects the colour limits below.
SAMPLE_TYPE = "cas9_ON_SC"
DATE_TO_LOAD_FROM = "2024-05-21"
# Root folder of the extracted grains; this is the one line to edit when running on another machine.
BASE_DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/extracted_grains")
DATA_DIR = BASE_DATA_DIR / SAMPLE_TYPE / f"date_{DATE_TO_LOAD_FROM}"

# --- Load the data ---------------------------------------------------------------------------
FILE_PATH = DATA_DIR / "ferets_dict.pkl"
# NOTE: unpickling can execute arbitrary code, so only load files produced by a trusted pipeline.
with open(FILE_PATH, "rb") as f:
    grain_dicts = pickle.load(f)

# --- Colour limits passed to imshow (vmin / vmax) for this sample family ----------------------
if "unbound" in SAMPLE_TYPE:
    VMIN = -3
    VMAX = 4
elif "cas9" in SAMPLE_TYPE:
    VMIN = -0.5
    VMAX = 7.0
else:
    # The message names the substrings that are actually tested above.
    raise ValueError(f"SAMPLE_TYPE {SAMPLE_TYPE!r} contains neither 'unbound' nor 'cas9'")

print(f"number of images for [{SAMPLE_TYPE}]: {len(grain_dicts)}")
if grain_dicts:
    # Show the keys of the first grain dictionary as a quick schema check.
    print(f"keys in grain dicts: {next(iter(grain_dicts.values())).keys()}")
else:
    # An empty pickle used to crash here with an IndexError; report it instead.
    print("keys in grain dicts: (no grains loaded)")

In [None]:
def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    min_ferets: list,
    max_ferets: list,
    px_to_nms: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=cmap,
):
    """Plot every grain image next to its mask in a grid and show the figure.

    Each grain occupies two adjacent panels (image, then mask) and `width` grains are placed
    per row. The input lists are zipped together, so the shortest list determines how many
    grains are drawn.

    Parameters
    ----------
    images : list
        Images drawn with ``imshow`` using ``cmap``, ``VMIN`` and ``VMAX``.
    masks : list
        Masks; each one is cast with ``.astype(int)`` before being drawn.
    grain_indexes : list
        Identifier of each grain, shown in the panel title.
    min_ferets : list
        Minimum feret of each grain, shown in the panel title.
    max_ferets : list
        Maximum feret of each grain, shown in the panel title.
    px_to_nms : list
        Scaling value of each grain, shown in the panel title.
    width : int
        Number of grains (image/mask pairs) per row.
    VMIN, VMAX : float
        Colour limits for the image panels. The defaults are read from the notebook globals
        when this cell is executed, so re-run the cell after changing those globals.
    cmap
        Colormap for the image panels (default captured the same way as ``VMIN``/``VMAX``).
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division gives exactly the rows needed (no spurious empty row when the count is a
    # multiple of `width`); keep at least one row so an empty input still produces a figure.
    num_rows = max(1, -(-num_images // width))
    # squeeze=False keeps `axes` two-dimensional even when there is only one row, so the
    # [row, column] indexing below works for fewer than `width` images.
    fig, axes = plt.subplots(
        num_rows,
        width * num_images_in_batch,
        figsize=(width * 4, num_rows * 4),
        squeeze=False,
    )
    # Hide the frame and ticks of every panel up front so unused panels stay blank.
    for ax in axes.ravel():
        ax.axis("off")

    for i, (image, mask, grain_index, min_feret, max_feret, p_to_nm) in enumerate(
        zip(images, masks, grain_indexes, min_ferets, max_ferets, px_to_nms)
    ):
        row, pair = divmod(i, width)
        first_col = pair * num_images_in_batch

        # Plot image
        im_ax = axes[row, first_col]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(
            f"Grain {grain_index} {p_to_nm:.2f} p/nm\nmin feret: {min_feret:.4f} max feret: {max_feret:.4f}"
        )

        # Plot mask
        mask_ax = axes[row, first_col + 1]
        mask_ax.imshow(mask.astype(int))

    fig.tight_layout()
    plt.show()

In [None]:
# Decide the sample family once so that every lookup below agrees on it. Previously the unbound
# branch was tested with both "unbound" and "unbound_", so a SAMPLE_TYPE of "unbound" passed the
# first check and then raised at the second.
IS_CAS9_SAMPLE = "cas9_" in SAMPLE_TYPE
if not IS_CAS9_SAMPLE and "unbound" not in SAMPLE_TYPE:
    raise ValueError("Sample type not got cas9 or unbound in it.")

# Keys under which each sample family stores its mask, trace and pooled trace.
if IS_CAS9_SAMPLE:
    MASK_KEY, TRACE_KEY, POOLED_TRACE_KEY = "predicted_mask", "path", "pooled_path"
else:
    MASK_KEY, TRACE_KEY, POOLED_TRACE_KEY = "mask", "trace", "pooled_trace"

# Contour lengths above this value (nm) are reported while processing.
LARGE_CONTOUR_LENGTH_NM = 50

image_dict_ferets = {}

for index, image_dict_item in grain_dicts.items():
    image = image_dict_item["image"]
    mask = image_dict_item[MASK_KEY]
    p_to_nm = image_dict_item["p_to_nm"]
    trace = image_dict_item[TRACE_KEY]
    pooled_trace = image_dict_item[POOLED_TRACE_KEY]

    # get feret from the trace instead of the mask:
    # Turn the trace into a binary mask with the same shape as the image.
    # NOTE(review): each trace point is unpacked as (x, y) and written at [y, x]; confirm the
    # coordinate order of the stored traces, since the feret coordinates below inherit it.
    trace_mask = np.zeros_like(image, dtype=bool)
    for x, y in trace:
        trace_mask[int(y), int(x)] = True

    results = get_feret_from_mask(mask_im=trace_mask)

    # Scale the feret results by the grain's p_to_nm value.
    min_feret = results["min_feret"] * p_to_nm
    max_feret = results["max_feret"] * p_to_nm

    feret_ratio = max_feret / min_feret

    min_feret_coords = results["min_feret_coords"]
    max_feret_coords = results["max_feret_coords"]

    # contour length
    # NOTE(review): circular=True is used for every sample type; an earlier, commented-out
    # version of this cell set circular=False for cas9 samples - confirm which is intended.
    contour_length = calculate_contour_length_from_points(
        points=pooled_trace, pixel_to_nm_scaling=p_to_nm, circular=True
    )

    if contour_length > LARGE_CONTOUR_LENGTH_NM:
        print(f"large contour length for {index}: {contour_length} nm")

    # Work on a shallow copy so the dictionaries loaded into `grain_dicts` are not modified in
    # place; the arrays themselves are shared, not duplicated.
    grain_result = dict(image_dict_item)
    grain_result["min_feret"] = min_feret
    grain_result["max_feret"] = max_feret
    grain_result["feret_ratio"] = feret_ratio
    grain_result["min_feret_coords"] = min_feret_coords
    grain_result["max_feret_coords"] = max_feret_coords
    grain_result["contour_length"] = contour_length

    if IS_CAS9_SAMPLE:
        # Pixels labelled 2 in the predicted mask are treated as protein.
        protein_pixels = mask == 2
        grain_result["protein_area"] = np.sum(protein_pixels) * p_to_nm**2
        # Calculate volume of protein by summing the heights of the protein pixels
        grain_result["protein_volume"] = np.sum(image[protein_pixels]) * p_to_nm**2
    else:
        grain_result["protein_area"] = np.nan
        grain_result["protein_volume"] = np.nan

    image_dict_ferets[index] = grain_result


def _grain_values(key):
    """Return the value stored under `key` for every grain in `image_dict_ferets`."""
    return [grain[key] for grain in image_dict_ferets.values()]


# Feret diameters: minimum and maximum drawn on the same axes
sns.kdeplot(_grain_values("min_feret"), label="min feret")
sns.kdeplot(_grain_values("max_feret"), label="max feret")
plt.title(f"Feret diameters for {SAMPLE_TYPE} (n = {len(image_dict_ferets)})")
plt.legend()
plt.show()

# Contour lengths
sns.kdeplot(_grain_values("contour_length"))
plt.title(f"Contour lengths for {SAMPLE_TYPE} (n = {len(image_dict_ferets)})")
plt.show()

# Protein area and volume are only plotted for cas9 samples
if "cas9_" in SAMPLE_TYPE:
    sns.kdeplot(_grain_values("protein_area"))
    plt.title(f"Protein areas for {SAMPLE_TYPE} (n = {len(image_dict_ferets)})")
    plt.show()

    sns.kdeplot(_grain_values("protein_volume"))
    plt.title(f"Protein volumes for {SAMPLE_TYPE} (n = {len(image_dict_ferets)})")
    plt.show()

In [None]:
# # # plot an image of the sample

# # print keys of grain_dicts
# # print(f"keys in grain_dicts: {grain_dicts.keys()}")
# # image_index = list(grain_dicts.keys())[0]
# image_index = 21
# print(f"image index: {image_index}")
# image_dict = grain_dicts[image_index]
# print(f"image dict keys: {image_dict.keys()}")
# image = image_dict["image"]
# mask = image_dict["predicted_mask"]
# p_to_nm = image_dict["p_to_nm"]
# min_feret = image_dict_ferets[image_index]["min_feret"]
# max_feret = image_dict_ferets[image_index]["max_feret"]
# min_feret_coords = image_dict_ferets[image_index]["min_feret_coords"]
# max_feret_coords = image_dict_ferets[image_index]["max_feret_coords"]
# contour_length = image_dict_ferets[image_index]["contour_length"]
# trace = image_dict["path"]
# pooled_trace = image_dict["pooled_path"]

# plt.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
# plt.title(f"Grain {image_index} {p_to_nm:.2f} p/nm\nmin feret: {min_feret:.4f} max feret: {max_feret:.4f}")
# plt.show()
# plt.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
# plt.plot()
# plt.imshow(mask, alpha=0.5)
# plt.plot(trace[:, 1], trace[:, 0])
# plt.plot(pooled_trace[:, 1], pooled_trace[:, 0])

# # Draw scale bar on image of 10nm
# scale_bar_nm = 10
# scale_bar_px = scale_bar_nm / p_to_nm
# pos_scale = 0.05
# plt.plot(
#     [pos_scale * image.shape[0], pos_scale * image.shape[0] + scale_bar_px],
#     [pos_scale * image.shape[1], pos_scale * image.shape[1]],
#     color="white",
#     lw=2,
# )
# plt.title(f"Contour length: {contour_length:.2f} nm | scale bar: {scale_bar_nm} nm")
# plt.show()

In [None]:
# Write the augmented grain dictionary into the folder the input pickle was read from.
OUTPUT_PATH = FILE_PATH.parent / "feret_grain_dict_fig1_with_contour_length.pkl"

with OUTPUT_PATH.open("wb") as out_file:
    pickle.dump(image_dict_ferets, out_file)

In [None]:
# Presumably a deliberate halt: this unconditional raise stops a "Run All" at this point, so the
# exploratory plotting cells below only execute when they are run by hand.
raise ValueError("stop here")

In [None]:
plot_until_index = 40

# Select the grains whose key is below the cut-off (keys must be comparable with an integer).
keys_to_plot = [key for key in image_dict_ferets if key < plot_until_index]
grains_to_plot = [image_dict_ferets[key] for key in keys_to_plot]

# NOTE(review): this reads the "mask" key, whereas the processing cell reads "predicted_mask"
# for cas9 samples - confirm that "mask" exists for the sample type being plotted.
plot_images(
    [grain["image"] for grain in grains_to_plot],
    [grain["mask"] for grain in grains_to_plot],
    keys_to_plot,
    [grain["min_feret"] for grain in grains_to_plot],
    [grain["max_feret"] for grain in grains_to_plot],
    [grain["p_to_nm"] for grain in grains_to_plot],
)

In [None]:
# Inspect the first grain: print its min feret coordinates, then overlay both feret
# end-point pairs on the image and mask.
i = list(image_dict_ferets)[0]
thisdict = image_dict_ferets[i]
print(thisdict["min_feret_coords"])
print(thisdict["min_feret_coords"][1, 0])

plt.imshow(thisdict["image"], cmap=cmap, vmin=VMIN, vmax=VMAX)
plt.imshow(thisdict["mask"], alpha=0.5)
# Column 1 of each coordinate pair goes on the horizontal axis, column 0 on the vertical axis.
for coords_key, marker_colour in (("min_feret_coords", "blue"), ("max_feret_coords", "orange")):
    for endpoint in (0, 1):
        plt.scatter(
            thisdict[coords_key][endpoint, 1],
            thisdict[coords_key][endpoint, 0],
            color=marker_colour,
        )
plt.title(f"grain {i} | min feret: {thisdict['min_feret']:.2f} | max feret: {thisdict['max_feret']:.2f}")

In [None]:
def plot_images_with_min_feret_coords(
    images: list,
    masks: list,
    grain_indexes: list,
    min_ferets: list,
    max_ferets: list,
    px_to_nms: list,
    min_feret_coord_x_1s: list,
    min_feret_coord_y_1s: list,
    min_feret_coord_x_2s: list,
    min_feret_coord_y_2s: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=cmap,
    title="",
):
    """Plot every grain image and mask in a grid, marking the two min feret end points.

    Each grain occupies two adjacent panels (image, then mask) and `width` grains are placed
    per row. Both panels get a marker at each min feret end point; the mask panel also gets a
    line joining them. The input lists are zipped together, so the shortest list determines
    how many grains are drawn.

    Parameters
    ----------
    images : list
        Images drawn with ``imshow`` using ``cmap``, ``VMIN`` and ``VMAX``.
    masks : list
        Masks; only the pixels equal to 1 (after ``.astype(int)``) are shown.
    grain_indexes : list
        Identifier of each grain, shown in the panel title.
    min_ferets : list
        Minimum feret of each grain, shown in the panel title.
    max_ferets : list
        Maximum feret of each grain, shown in the panel title.
    px_to_nms : list
        Scaling value of each grain, shown in the panel title.
    min_feret_coord_x_1s, min_feret_coord_y_1s : list
        Coordinates of the first min feret end point of each grain. The ``x`` value is drawn on
        the vertical axis and the ``y`` value on the horizontal axis.
    min_feret_coord_x_2s, min_feret_coord_y_2s : list
        Coordinates of the second min feret end point, drawn the same way.
    width : int
        Number of grains (image/mask pairs) per row.
    VMIN, VMAX : float
        Colour limits for the image panels. The defaults are read from the notebook globals
        when this cell is executed, so re-run the cell after changing those globals.
    cmap
        Colormap for the image panels (default captured the same way as ``VMIN``/``VMAX``).
    title : str
        Figure-level title.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division gives exactly the rows needed (no spurious empty row when the count is a
    # multiple of `width`); keep at least one row so an empty input still produces a figure.
    num_rows = max(1, -(-num_images // width))
    # squeeze=False keeps `axes` two-dimensional even when there is only one row, so the
    # [row, column] indexing below works for fewer than `width` images.
    fig, axes = plt.subplots(
        num_rows,
        width * num_images_in_batch,
        figsize=(width * 4, num_rows * 4),
        squeeze=False,
    )
    # Hide the frame and ticks of every panel up front so unused panels stay blank.
    for ax in axes.ravel():
        ax.axis("off")

    for i, (
        image,
        mask,
        grain_index,
        min_feret,
        max_feret,
        p_to_nm,
        min_feret_coord_x_1,
        min_feret_coord_y_1,
        min_feret_coord_x_2,
        min_feret_coord_y_2,
    ) in enumerate(
        zip(
            images,
            masks,
            grain_indexes,
            min_ferets,
            max_ferets,
            px_to_nms,
            min_feret_coord_x_1s,
            min_feret_coord_y_1s,
            min_feret_coord_x_2s,
            min_feret_coord_y_2s,
        )
    ):
        row, pair = divmod(i, width)
        first_col = pair * num_images_in_batch

        # Plot image with both min feret end points
        im_ax = axes[row, first_col]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX, interpolation="none")
        im_ax.scatter(min_feret_coord_y_1, min_feret_coord_x_1, color="cornflowerblue")
        im_ax.scatter(min_feret_coord_y_2, min_feret_coord_x_2, color="cornflowerblue")
        im_ax.set_title(
            f"Grain {grain_index} {p_to_nm:.2f} p/nm\nmin feret: {min_feret:.4f} max feret: {max_feret:.4f}"
        )

        # Plot mask (label 1 only) with the end points and the line joining them
        mask_ax = axes[row, first_col + 1]
        mask_ax.imshow(mask.astype(int) == 1, interpolation="none")
        mask_ax.scatter(min_feret_coord_y_1, min_feret_coord_x_1, color="cornflowerblue")
        mask_ax.scatter(min_feret_coord_y_2, min_feret_coord_x_2, color="cornflowerblue")
        mask_ax.plot(
            [min_feret_coord_y_1, min_feret_coord_y_2],
            [min_feret_coord_x_1, min_feret_coord_x_2],
            color="cornflowerblue",
        )

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

In [None]:
# large_feret_threshold = 14
# large_feret_grains = [i for i in image_dict_ferets if image_dict_ferets[i]["min_feret"] > large_feret_threshold]

# plot_images_with_min_feret_coords(
#     [image_dict_ferets[i]["image"] for i in large_feret_grains],
#     [image_dict_ferets[i]["mask"] for i in large_feret_grains],
#     [i for i in large_feret_grains],
#     [image_dict_ferets[i]["min_feret"] for i in large_feret_grains],
#     [image_dict_ferets[i]["max_feret"] for i in large_feret_grains],
#     [image_dict_ferets[i]["p_to_nm"] for i in large_feret_grains],
#     [image_dict_ferets[i]["min_feret_coords"][0, 0] for i in large_feret_grains],
#     [image_dict_ferets[i]["min_feret_coords"][0, 1] for i in large_feret_grains],
#     [image_dict_ferets[i]["min_feret_coords"][1, 0] for i in large_feret_grains],
#     [image_dict_ferets[i]["min_feret_coords"][1, 1] for i in large_feret_grains],
#     title="ot1 rel grains with large min ferets",
# )