In [None]:
import importlib.resources as pkg_resources
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, Union

import numpy as np
import pandas as pd
import yaml

from topostats.run_topostats import run_topostats, find_files
from topostats.grain_finding_cats_unet import test_GPU, predict_unet
from topostats.io import read_yaml, LoadScans
from topostats.processing import process_scan, get_out_path, save_array
from topostats.utils import update_config, update_plotting_config, create_empty_dataframe
from topostats.validation import validate_config, DEFAULT_CONFIG_SCHEMA, PLOTTING_SCHEMA
from topostats.filters import Filters
from topostats.grains import Grains
from topostats.grainstats import GrainStats
from topostats.tracing.dnatracing import trace_image, dnaTrace
from topostats.plotting import plot_crossing_linetrace_halfmax
from topostats.plottingfuncs import Images

import argparse as arg

# Run the GPU check from topostats.grain_finding_cats_unet before any processing.
# NOTE(review): presumably reports whether a GPU is available for the UNet grain
# finder (predict_unet) — confirm against test_GPU's implementation.
test_GPU()

In [None]:
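# Alternative (disabled): call `run_topostats` directly with the config file, instead of the
# step-by-step processing in the next cell. Uncomment these three lines to use it.
# CONFIG_PATH = Path("/Users/sylvi/topo_data/cats/catsconf.yaml")
# args = arg.Namespace(config_file=str(CONFIG_PATH), create_config_file=False)
# run_topostats(args=args)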

In [None]:
# Path to the YAML configuration. The default is a machine-specific absolute path; set the
# TOPOSTATS_CONFIG environment variable to point the notebook at a different config file.
CONFIG_PATH = Path(os.environ.get("TOPOSTATS_CONFIG", "/Users/sylvi/topo_data/cats/catsconf.yaml"))

# Read the config, merge in (empty) overrides, then validate it against the default schema.
config = read_yaml(CONFIG_PATH)
config = update_config(config, args={})
validate_config(config, schema=DEFAULT_CONFIG_SCHEMA, config_type="YAML configuration file")

# NOTE(review): assumes update_config has turned "output_dir" into a pathlib.Path — confirm.
config["output_dir"].mkdir(parents=True, exist_ok=True)

# Load the plotting dictionary shipped inside the topostats package and validate it.
# A context manager ensures the handle returned by open_text is closed after reading
# (the previous version left it open).
with pkg_resources.open_text("topostats", "plotting_dictionary.yaml") as plotting_dictionary:
    config["plotting"]["plot_dict"] = yaml.safe_load(plotting_dictionary.read())
validate_config(config["plotting"]["plot_dict"], schema=PLOTTING_SCHEMA, config_type="YAML plotting configuration file")
config["plotting"] = update_plotting_config(config["plotting"])
plotting_config = config["plotting"]

# Extra output directories, relative to the current working directory.
core_out_path = Path("./output")
core_out_path.mkdir(exist_ok=True)
grain_out_path = Path("./output_grains")
grain_out_path.mkdir(exist_ok=True)
dna_tracing_out_path = Path("./output_dnatracing")
dna_tracing_out_path.mkdir(exist_ok=True)

# Find every scan under base_dir with the configured extension and load them.
img_files = find_files(config["base_dir"], file_ext=config["file_ext"])
all_scan_data = LoadScans(img_files, **config["loading"])
all_scan_data.get_data()
# NOTE(review): presumably one entry per loaded scan, each value holding the image together
# with its path and pixel-to-nm scaling (it is passed to process_scan as img_path_px2nm) — confirm.
scan_data_dict = all_scan_data.img_dict

# ==================================================================

# Run process_scan on every loaded scan, passing it the filter, grains, grainstats,
# dnatracing and plotting sections of the config. Per-image outputs are keyed by the
# image path that process_scan returns.
# Plain dicts: the previous `defaultdict()` had no default_factory, so it already
# behaved exactly like a dict and only obscured that.
results = {}
node_results = {}
for img_path_px2nm in scan_data_dict.values():
    image_path, result, node_result = process_scan(
        img_path_px2nm=img_path_px2nm,
        base_dir=config["base_dir"],
        filter_config=config["filter"],
        grains_config=config["grains"],
        grainstats_config=config["grainstats"],
        dnatracing_config=config["dnatracing"],
        plotting_config=config["plotting"],
        output_dir=config["output_dir"],
    )

    results[str(image_path)] = result
    node_results[str(image_path)] = node_result

# Stack the per-image results into one DataFrame. pd.concat raises ValueError when
# there is nothing to concatenate (no images) or every entry is None.
try:
    results = pd.concat(results.values())
except ValueError as error:
    print("No grains found in any images, consider adjusting your thresholds.")
    print(error)
    # Keep `results` a DataFrame on every path so later cells never receive the
    # per-image dict by accident.
    results = pd.DataFrame()


# ===========================================================================================