In [None]:
import os
from pathlib import Path

from IPython.display import clear_output
from loguru import logger
import matplotlib.pyplot as plt
import numpy as np

from topostats.io import LoadScans
from topostats.filters import Filters
from topostats.plottingfuncs import Colormap

# Shared plotting configuration.
colormap = Colormap()
CMAP = colormap.get_cmap()
# Colour limits for image plots; None lets matplotlib autoscale each image.
VMIN = None
VMAX = None

# Directory of raw scans. Defaults to the original author's local path; set the
# TOPOSTATS_DATA_DIR environment variable to run the notebook against other data.
DATA_DIR = Path(os.environ.get("TOPOSTATS_DATA_DIR", "/Users/sylvi/topo_data/perovskite/textured_silicon/"))
# Keep only regular, non-hidden files: a bare glob("*") would also hand sub-directories
# and OS metadata files (e.g. .DS_Store) to LoadScans. Sorting makes the load/plot order
# reproducible, since glob order depends on the filesystem.
files = sorted(path for path in DATA_DIR.glob("*") if path.is_file() and not path.name.startswith("."))
logger.info(f"Found {len(files)} files in {DATA_DIR}")
if not files:
    logger.warning(f"No scan files found in {DATA_DIR}; check DATA_DIR / TOPOSTATS_DATA_DIR")
for file in files:
    logger.info(file)

# Load the "Height Sensor" channel of every scan; img_dict maps each scan to its data dict
# (later cells read its "image_original" and "pixel_to_nm_scaling" entries).
loadscans = LoadScans(files, channel="Height Sensor")
loadscans.get_data()
image_dictionaries = loadscans.img_dict
logger.info(f"Loaded {len(image_dictionaries)} images")

In [None]:
# Flatten every loaded scan with the TopoStats Filters pipeline, keyed by the scan's base name.
flattened_images = {}
for scan_key, scan in image_dictionaries.items():
    scan_name = Path(scan_key).name
    logger.info(f"Flattening {scan_name}")
    raw_image = scan["image_original"]
    pixel_to_nm = scan["pixel_to_nm_scaling"]
    scan_filters = Filters(
        image=raw_image,
        filename=scan_name,
        pixel_to_nm_scaling=pixel_to_nm,
        row_alignment_quantile=None,
        threshold_method=None,
        # NOTE(review): magic value; units/meaning are defined by Filters — confirm intent.
        gaussian_size=0.01,
        remove_scars={"run": False},
    )
    scan_filters.filter_image()
    # The Gaussian-filtered output is used as the "flattened" image downstream.
    flattened_images[scan_name] = {
        "original_image": raw_image,
        "flattened_image": scan_filters.images["gaussian_filtered"],
        "pixel_to_nm_scaling": pixel_to_nm,
    }
# Wipe the per-scan log output from the cell so only the summary line below remains.
clear_output()
logger.info(f"Flattened {len(flattened_images)} images")

In [None]:
def plot_images(images: list, masks: list, px_to_nms: list, filenames: list, width=3, VMIN=VMIN, VMAX=VMAX, cmap=CMAP):
    """Plot each image next to its mask in a grid of ``width`` image/mask pairs per row.

    Parameters
    ----------
    images : list
        Arrays to display with ``cmap`` and the ``VMIN``/``VMAX`` colour limits.
    masks : list
        One mask per image (same length as ``images``); each is shown, cast to ``int``,
        in the panel to the right of its image.
    px_to_nms : list
        Pixel-to-nanometre scaling for each image, shown in the image title.
    filenames : list
        Name of each image, shown in the image title.
    width : int
        Number of image/mask pairs per row.
    VMIN, VMAX : float or None
        Colour limits passed to ``imshow`` for the images (None autoscales).
    cmap
        Colormap used for the images.

    Raises
    ------
    ValueError
        If ``width`` is not positive or the four lists differ in length.
    """
    if width < 1:
        raise ValueError(f"width must be a positive integer, got {width}")
    num_images = len(images)
    if not (len(masks) == len(px_to_nms) == len(filenames) == num_images):
        # zip() would otherwise silently drop the surplus items.
        raise ValueError(
            f"images, masks, px_to_nms and filenames must have equal lengths, got "
            f"{num_images}, {len(masks)}, {len(px_to_nms)}, {len(filenames)}"
        )
    if num_images == 0:
        logger.warning("plot_images called with no images; nothing to plot")
        return

    panels_per_image = 2  # image panel + mask panel
    # Ceiling division: `num_images // width + 1` added an empty row whenever
    # num_images was an exact multiple of width.
    num_rows = (num_images + width - 1) // width
    # squeeze=False keeps `axes` 2-D even for a single row; otherwise axes[row, col]
    # raises IndexError whenever there are fewer images than `width`.
    fig, axes = plt.subplots(
        num_rows,
        width * panels_per_image,
        figsize=(width * 8, num_rows * 4),
        squeeze=False,
    )
    for i, (image, mask, p_to_nm, filename) in enumerate(zip(images, masks, px_to_nms, filenames)):
        row, slot = divmod(i, width)
        # Plot image (honour the `cmap` argument rather than the global CMAP).
        im_ax = axes[row, slot * panels_per_image]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        # pixel_to_nm_scaling is nanometres per pixel, hence "nm/px".
        im_ax.set_title(f"{p_to_nm:.3f} nm/px\n{filename}")
        im_ax.axis("off")
        # Plot mask
        mask_ax = axes[row, slot * panels_per_image + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.axis("off")

    # Hide the empty panels left over in the final row.
    for unused_ax in axes.ravel()[num_images * panels_per_image:]:
        unused_ax.axis("off")

    fig.tight_layout()
    plt.show()

In [None]:
logger.info(f"{flattened_images.keys()}")
# One pass over the results; dict iteration order keeps names and scans aligned.
scan_names = list(flattened_images)
scans = list(flattened_images.values())
plot_images(
    images=[scan["flattened_image"] for scan in scans],
    # Placeholder masks: all zeros, same shape as each flattened image.
    masks=[np.zeros_like(scan["flattened_image"]) for scan in scans],
    px_to_nms=[scan["pixel_to_nm_scaling"] for scan in scans],
    filenames=scan_names,
)