In [None]:
import pickle as pkl
import re
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import FuncFormatter

In [None]:
# --- Locations ---------------------------------------------------------------
# NOTE(review): both paths are machine-specific absolute paths; adjust them when
# running this notebook anywhere else.
DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/extracted_grains/")
FIG_SAVE_DIR = Path("/Volumes/shared-3/pyne_group/Shared/Papers/cas9_minicircles/figure_1/")

# --- Dates -------------------------------------------------------------------
# LOAD_DATE selects the `date_<LOAD_DATE>` sub-folder the grain pickles are read from.
LOAD_DATE = "2024-05-21"
# Today's date as YYYY-MM-DD, for stamping output file names.
TODAY_DATE = f"{datetime.today():%Y-%m-%d}"

# Stop straight away if the input folder is not reachable.
assert DATA_DIR.exists()

# --- Filtering ---------------------------------------------------------------
# Grains with a p_to_nm above this value are left out of the analysis.
MAX_P_TO_NM = 10.0

# --- Samples -----------------------------------------------------------------
# One sub-folder of DATA_DIR per sample.
SAMPLES = [
    # no Cas9
    "unbound_ON_REL",
    "unbound_ON_SC",
    "unbound_OT1_REL",
    "unbound_OT1_SC",
    "unbound_OT2_REL",
    "unbound_OT2_SC",
    # with Cas9
    "cas9_ON_SC",
    "cas9_OT1_SC",
    "cas9_OT2_SC",
]

# Create a big dataframe holding, for every accepted grain:
# - sample type
# - p_to_nm
# - min_feret
# - max_feret
#
# Grains whose min feret lies outside [lower, upper] threshold are collected
# separately, together with their image and mask, so they can be inspected by eye.

min_feret_lower_threshold = 4
min_feret_upper_threshold = 20

# Column order of the dataframe. Passing it explicitly means the columns exist
# even when no grain passes the filters.
DF_COLUMNS = ["sample_type", "p_to_nm", "min_feret", "max_feret"]


def _mask_key_for(sample_type):
    """Return the grain-dict key that holds the mask for ``sample_type``.

    Samples containing ``"cas9_"`` store their mask under ``"predicted_mask"``,
    samples containing ``"unbound_"`` store it under ``"mask"``.

    Raises:
        ValueError: if ``sample_type`` matches neither naming scheme.
    """
    if "cas9_" in sample_type:
        return "predicted_mask"
    if "unbound_" in sample_type:
        return "mask"
    raise ValueError(f"cannot choose a mask key for unrecognised sample type: {sample_type!r}")


data_list = []
bad_feret_list = []

for sample_type in SAMPLES:
    print(f"loading {sample_type}")
    # Resolve the mask key before loading, so a bad sample name fails immediately.
    mask_key = _mask_key_for(sample_type)

    # Load the data from pickle.
    # NOTE: unpickling can execute arbitrary code; only load files produced by a trusted pipeline.
    with open(DATA_DIR / sample_type / f"date_{LOAD_DATE}" / "feret_grain_dict_fig1.pkl", "rb") as f:
        feret_data = pkl.load(f)

    for grain in feret_data.values():
        image = grain["image"]
        mask = grain[mask_key]
        p_to_nm = grain["p_to_nm"]
        min_feret = grain["min_feret"]
        max_feret = grain["max_feret"]

        min_feret_coords = grain["min_feret_coords"]
        max_feret_coords = grain["max_feret_coords"]

        # Skip grains above the p_to_nm limit. The comparisons below are kept in
        # their original form on purpose so that NaN values are routed exactly as before.
        if p_to_nm <= MAX_P_TO_NM:
            if min_feret < min_feret_lower_threshold or min_feret > min_feret_upper_threshold:
                bad_feret_list.append(
                    {
                        "sample_type": sample_type,
                        "image": image,
                        "mask": mask,
                        "p_to_nm": p_to_nm,
                        "min_feret": min_feret,
                        "max_feret": max_feret,
                        "min_feret_coords": min_feret_coords,
                        "max_feret_coords": max_feret_coords,
                    }
                )
            else:
                data_list.append(
                    {
                        "sample_type": sample_type,
                        "p_to_nm": p_to_nm,
                        "min_feret": min_feret,
                        "max_feret": max_feret,
                    }
                )

print(f"num grains: {len(data_list)}")

df = pd.DataFrame(data_list, columns=DF_COLUMNS)

print(df.head())

print(f"num bad feret: {len(bad_feret_list)}")

# Save the dataframe
# df.to_csv(DATA_DIR / f"feret_data_{TODAY_DATE}_max_p_to_nm_{MAX_P_TO_NM}.csv", index=False)

In [None]:
# Plot violin plots of the min ferets for each sample group.

# Display label for every sample. Keyed by sample name so the tick labels cannot
# drift out of sync with the violins when `sample_order` is changed.
x_tick_label_dict = {
    "unbound_ON_SC": "ON SC",
    "unbound_OT1_SC": "OT1 SC",
    "unbound_OT2_SC": "OT2 SC",
    "unbound_ON_REL": "ON REL",
    "unbound_OT1_REL": "OT1 REL",
    "unbound_OT2_REL": "OT2 REL",
    "cas9_ON_SC": "ON SC +dCas9",
    "cas9_OT1_SC": "OT1 SC +dCas9",
    "cas9_OT2_SC": "OT2 SC +dCas9",
}

# Left-to-right order of the violins (previous ordering).
sample_order = [
    "unbound_ON_SC",
    "unbound_OT1_SC",
    "unbound_OT2_SC",
    "unbound_ON_REL",
    "unbound_OT1_REL",
    "unbound_OT2_REL",
    "cas9_ON_SC",
    "cas9_OT1_SC",
    "cas9_OT2_SC",
]

# Alternative ordering (REL before SC); the tick labels follow automatically.
# sample_order = [
#     "unbound_ON_REL",
#     "unbound_OT1_REL",
#     "unbound_OT2_REL",
#     "unbound_ON_SC",
#     "unbound_OT1_SC",
#     "unbound_OT2_SC",
#     "cas9_ON_SC",
#     "cas9_OT1_SC",
#     "cas9_OT2_SC",
# ]

# Tick labels in plotting order.
x_ticks = [x_tick_label_dict[sample] for sample in sample_order]

# Fill colour of each violin.
colour_dict = {
    "unbound_ON_SC": "#C1879C",
    "unbound_OT1_SC": "#99ACBD",
    "unbound_OT2_SC": "#ECDFB6",
    "unbound_ON_REL": "#CE5782",
    "unbound_OT1_REL": "#6396C3",
    "unbound_OT2_REL": "#F3D16D",
    "cas9_ON_SC": "#D81B60",
    "cas9_OT1_SC": "#1E88E5",
    "cas9_OT2_SC": "#FFC107",
}

# Print ns for each sample
sample_counts = df["sample_type"].value_counts()
for sample in sample_order:
    print(f"{sample}: {sample_counts.get(sample, 0)}")

fig, ax = plt.subplots(figsize=(12, 8))

sns.violinplot(
    data=df, ax=ax, x="sample_type", y="min_feret", hue="sample_type", palette=colour_dict, order=sample_order
)

ax.set_ylabel("Minimum width (nm)", fontsize=20)
ax.set_xlabel("Sample type", fontsize=20)

# Only change the font size of the y tick labels. Overwriting the label text with
# the current tick values would freeze the labels while the tick positions stay
# automatic, so the two could disagree once the figure is drawn.
ax.tick_params(axis="y", labelsize=18)

# Pin the x tick positions first, then attach the readable labels to them.
ticks = ax.get_xticks()
ax.set_xticks(ticks)
ax.set_xticklabels(x_ticks, fontsize=18, rotation=45, ha="right")
# ax.set_title(f"Min Feret width for grains with p_to_nm < {MAX_P_TO_NM}", fontsize=20)
# ax.set_ylim(0, 20)
fig.tight_layout()
# plt.savefig(FIG_SAVE_DIR / f"min_feret_violin_plot_{TODAY_DATE}_max_p_to_nm_{MAX_P_TO_NM}.png", dpi=500)
plt.show()

In [None]:
# Plot the bad ferets: image and mask side by side for every grain that was
# rejected by the min feret thresholds.


def _scaled_tick_formatter(scale):
    """Return a tick formatter that labels a pixel position as ``pixel * scale``.

    A formatter is used instead of fixed tick labels so the labels stay correct
    when matplotlib recomputes the tick positions at draw time.

    Args:
        scale: factor the pixel position is multiplied by (the grain's p_to_nm).

    Returns:
        A ``FuncFormatter`` printing the scaled position with one decimal place.
    """
    return FuncFormatter(lambda pixel, _pos: f"{pixel * scale:.1f}")


for grain in bad_feret_list:
    sample_type = grain["sample_type"]
    min_feret_coords = grain["min_feret_coords"]

    print(sample_type)
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    ax[0].imshow(grain["image"])
    ax[0].set_title(f"sample: {sample_type} image")

    ax[1].imshow(grain["mask"])
    ax[1].set_title(f"sample: {sample_type} mask")
    # Mark the two min feret coordinates on the mask.
    # NOTE(review): imshow draws x = column and y = row; this assumes the
    # coordinates are stored as (x, y). If they are (row, col), swap them - confirm.
    ax[1].scatter(min_feret_coords[0, 0], min_feret_coords[0, 1], color="r")
    ax[1].scatter(min_feret_coords[1, 0], min_feret_coords[1, 1], color="r")

    # Label both axes of both panels with pixel position * p_to_nm. Each axis gets
    # its own formatter instance because a formatter is bound to a single axis.
    for panel in ax:
        panel.xaxis.set_major_formatter(_scaled_tick_formatter(grain["p_to_nm"]))
        panel.yaxis.set_major_formatter(_scaled_tick_formatter(grain["p_to_nm"]))

    # grain["image"] is the image array itself, so it is not put in the title;
    # show the feret values that caused the grain to be flagged instead.
    fig.suptitle(
        f"sample: {sample_type} min_feret: {grain['min_feret']:.2f} "
        f"max_feret: {grain['max_feret']:.2f} p_to_nm: {grain['p_to_nm']}"
    )
    plt.show()
    # Release the figure so that a long list of grains does not pile up in memory.
    plt.close(fig)