In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from topostats.io import LoadScans
from topostats.plotting import Colormap

# Plot styling shared by every figure in this notebook.
colormap = Colormap()  # TopoStats colour-map helper
cmap, blu = colormap.get_cmap(), colormap.blue()
# Fixed colour-scale limits so every image is rendered on the same scale
# (units are not stated here; presumably nm — TODO confirm).
vmin, vmax = -3, 4

In [None]:
# Root of the local TopoStats working tree. Set the TOPOSTATS_ROOT environment variable
# to run this notebook on another machine instead of editing the hard-coded default.
topostats_root = Path(os.environ.get("TOPOSTATS_ROOT", "/Users/sylvi/topo_data/topostats_2"))

base_dir = topostats_root / "datasets" / "topology-plasmids"
if not base_dir.exists():
    raise FileNotFoundError(f"Dataset directory not found: {base_dir}")

# Every run stores its processed .topostats files under <run>/plasmid_sup/processed.
processed_subdir = Path("plasmid_sup") / "processed"
dir_height_09 = base_dir / "output_height_09" / processed_subdir
dir_conf_03 = base_dir / "output_old_catsnet_c_03" / processed_subdir
dir_conf_05 = base_dir / "output_old_catsnet_c_05" / processed_subdir
dir_conf_07 = base_dir / "output_old_catsnet_c_07" / processed_subdir
dir_conf_09 = base_dir / "output_old_catsnet_c_09" / processed_subdir

# Run label -> directory holding that run's processed output.
files = {
    "height_09": dir_height_09,
    "unet_03": dir_conf_03,
    "unet_05": dir_conf_05,
    "unet_07": dir_conf_07,
    "unet_09": dir_conf_09,
}

figure_dir = topostats_root / "figures" / "fig-unet"
if not figure_dir.exists():
    raise FileNotFoundError(f"Figure output directory not found: {figure_dir}")

# The same scan is compared across all runs.
filename = Path("20230929_unknot_SC_4ng_mgni_eph.0_00054.topostats")

# Run label -> the per-image data dict LoadScans produced for `filename`.
images_data_by_type = {}
for image_id, filedir in files.items():
    filepath = filedir / filename
    print(filepath)
    if not filepath.exists():
        raise FileNotFoundError(f"Processed file missing for run {image_id!r}: {filepath}")
    # NOTE(review): channel="dummy" suggests the channel is not used for .topostats input — TODO confirm.
    loadscans = LoadScans([filepath], channel="dummy")
    loadscans.get_data()
    # img_dict is keyed by the file stem (as indexed here).
    images_data_by_type[image_id] = loadscans.img_dict[filename.stem]

print("\n---")

# Crop window shared by every run so the cropped panels are directly comparable.
# bbox_topleft is (row, col) of the window's top-left pixel; bbox_size is its side length in pixels.
bbox_topleft = (270, 90)
bbox_size = 400
bbox = [bbox_topleft[0], bbox_topleft[1], bbox_topleft[0] + bbox_size, bbox_topleft[1] + bbox_size]

for image_id, image_data in images_data_by_type.items():
    print(image_id)
    print(image_data.keys())
    image = image_data["image"]
    p2nm = image_data["pixel_to_nm_scaling"]

    # Full image. The filename carries image_id: previously the f-string had no placeholder,
    # so every iteration overwrote one file and only the last run's figure survived.
    fig, ax = plt.subplots()
    ax.imshow(image, vmin=vmin, vmax=vmax, cmap=cmap)
    fig.savefig(figure_dir / f"fig-unet-masking-grain-mask-comparison-image-{image_id}.png")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop on non-inline backends

    # Channel 1 of the "above" grain tensor (assumed to be the grain class — TODO confirm).
    mask = image_data["grain_tensors"]["above"][:, :, 1]
    fig, ax = plt.subplots()
    ax.imshow(mask)
    plt.show()
    plt.close(fig)

    # Rows bbox[0]:bbox[2], columns bbox[1]:bbox[3]; numpy slicing truncates silently
    # if the image is smaller than the window, which the printed size would reveal.
    mask_crop = mask[bbox[0]:bbox[2], bbox[1]:bbox[3]]
    image_crop = image[bbox[0]:bbox[2], bbox[1]:bbox[3]]
    print(f"image crop size: {image_crop.shape[0]} px, {image_crop.shape[0] * p2nm} nm")

    # Cropped image with the mask overlaid; zero-valued mask pixels are hidden so the image shows through.
    # Filename spelling aligned with the other figures ("comparison", not "comparision").
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image_crop, vmin=vmin, vmax=vmax, cmap=cmap)
    ax.imshow(np.ma.masked_where(mask_crop == 0, mask_crop), cmap="cool", alpha=0.8)
    fig.savefig(figure_dir / f"fig-unet-masking-grain-mask-comparison-{image_id}.png")
    plt.show()
    plt.close(fig)

    # Cropped image alone, for side-by-side comparison with the overlay above.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image_crop, cmap=cmap, vmin=vmin, vmax=vmax)
    fig.savefig(figure_dir / f"fig-unet-masking-grain-mask-comparison-image-crop-{image_id}.png")
    plt.show()
    plt.close(fig)