In [13]:
from pathlib import Path
import numpy as np
from aicsimageio import AICSImage
from aicsimageio.writers import OmeTiffWriter
from mmv_im2im.configs.config_base import ProgramConfig, configuration_validation
from mmv_im2im import ProjectTester
from skimage.morphology import remove_small_objects, remove_small_holes
from skimage.measure import label, regionprops
from aicssegmentation.core.utils import topology_preserving_thinning

First, we need to redefine a utility function so that configuration parsing works inside a Jupyter notebook. No changes are needed; just run the cell.

In [14]:
from dataclasses import dataclass
from pathlib import Path
from pyrallis import field

import argparse
import dataclasses
import sys
import warnings
from argparse import HelpFormatter, Namespace
from collections import defaultdict
from logging import getLogger
from typing import Dict, List, Sequence, Text, Type, Union, TypeVar, Generic, Optional

from pyrallis import utils, cfgparsing
from pyrallis.help_formatter import SimpleHelpFormatter
from pyrallis.parsers import decoding
from pyrallis.utils import Dataclass, PyrallisException
from pyrallis.wrappers import DataclassWrapper

# Module-level logger; ArgumentParser below uses it for debug messages.
logger = getLogger(__name__)

# Type variable for the config dataclass type that ArgumentParser and
# parse_adaptor_jpnb are parameterized with.
T = TypeVar("T")


class ArgumentParser(Generic[T], argparse.ArgumentParser):
    """Notebook-friendly variant of pyrallis' ``ArgumentParser``.

    Adds one command-line argument per field of ``config_class`` (through
    pyrallis ``DataclassWrapper``s), optionally merges the parsed values over
    a YAML config file (given as ``config`` or via ``--<utils.CONFIG_ARG>``)
    and decodes the merged dictionary into a ``config_class`` instance.
    All argument values are kept as strings and re-parsed with
    ``cfgparsing.parse_string`` so that type handling happens through
    pyrallis' YAML decoding.
    """

    def __init__(
        self,
        config_class: Type[T],
        config: Optional[str] = None,
        formatter_class: Type[HelpFormatter] = SimpleHelpFormatter,
        *args,
        **kwargs,
    ):
        """Creates an ArgumentParser instance.

        Args:
            config_class: dataclass type whose fields become arguments.
            config: optional path to a YAML file providing values; values
                given on the command line take precedence over it.
            formatter_class: argparse help formatter class.
            *args, **kwargs: forwarded to ``argparse.ArgumentParser``.

        Raises:
            PyrallisException: if ``config_class`` has a field whose name
                equals the reserved ``utils.CONFIG_ARG``.
        """
        kwargs["formatter_class"] = formatter_class
        super().__init__(*args, **kwargs)

        # constructor arguments for the dataclass instances.
        # (a Dict[dest, [attribute, value]])
        self.constructor_arguments: Dict[str, Dict] = defaultdict(dict)

        self._wrappers: List[DataclassWrapper] = []

        self.config = config
        self.config_class = config_class

        self._assert_no_conflicts()
        self.add_argument(
            f"--{utils.CONFIG_ARG}",
            type=str,
            help="Path for a config file to parse with pyrallis",
        )
        self.set_dataclass(config_class)

    def set_dataclass(
        self,
        dataclass: Union[Type[Dataclass], Dataclass],
        prefix: str = "",
        default: Union[Dataclass, Dict] = None,
        dataclass_wrapper_class: Type[DataclassWrapper] = DataclassWrapper,
    ):
        """Adds command-line arguments for the fields of `dataclass`.

        ``dataclass`` may be a type or an instance; an instance is used as the
        default values unless ``default`` is given explicitly.
        """
        if not isinstance(dataclass, type):
            default = dataclass if default is None else default
            dataclass = type(dataclass)

        new_wrapper = dataclass_wrapper_class(dataclass, prefix=prefix, default=default)
        self._wrappers.append(new_wrapper)
        # nested dataclass fields get their own wrappers
        self._wrappers += new_wrapper.descendants

        for wrapper in self._wrappers:
            logger.debug(
                f"Adding arguments for dataclass: {wrapper.dataclass} "
                f"at destination {wrapper.dest}"
            )
            wrapper.add_arguments(parser=self)

    def _assert_no_conflicts(self):
        """Checks for a field name that conflicts with utils.CONFIG_ARG"""
        if any(
            config_field.name == utils.CONFIG_ARG
            for config_field in dataclasses.fields(self.config_class)
        ):
            raise PyrallisException(
                f"{utils.CONFIG_ARG} is a reserved word for pyrallis"
            )

    def parse_args(self, args=None, namespace=None) -> T:
        """Parse ``args`` and return a ``config_class`` instance."""
        return super().parse_args(args, namespace)

    def parse_known_args(
        self,
        args: Optional[Sequence[Text]] = None,
        namespace: Optional[Namespace] = None,
        attempt_to_reorder: bool = False,
    ):
        """Parse known arguments and post-process them into a config object.

        Returns:
            A ``(config, unparsed_args)`` tuple, where ``config`` is a
            ``config_class`` instance.
        """
        # NOTE: since the usual ArgumentParser.parse_args() calls
        # parse_known_args, we therefore just need to overload the
        # parse_known_args method to support both.
        if args is None:
            # args default to the system args
            args = sys.argv[1:]
        else:
            # make sure that args are mutable
            args = list(args)

        if "--help" not in args:
            for action in self._actions:
                # TODO: Find a better way to do that?
                # To avoid setting of defaults in actual run
                action.default = argparse.SUPPRESS
                # In practice, we want all processing to happen with yaml
                action.type = str
        parsed_args, unparsed_args = super().parse_known_args(args, namespace)

        parsed_args = self._postprocessing(parsed_args)
        return parsed_args, unparsed_args

    def print_help(self, file=None):
        return super().print_help(file)

    def _preset_dir(self) -> Path:
        """Return the directory searched for ``preset_<name>.yaml`` files.

        Uses the directory of this module when ``__file__`` is defined.
        Inside a Jupyter kernel ``__file__`` is usually undefined, so fall
        back to the directory of the module defining ``config_class``
        (presumably where the package ships its presets - verify), and
        finally to the current working directory.
        """
        this_file = globals().get("__file__")
        if this_file is not None:
            return Path(this_file).parent
        module = sys.modules.get(getattr(self.config_class, "__module__", ""))
        module_file = getattr(module, "__file__", None)
        if module_file is not None:
            return Path(module_file).parent
        return Path.cwd()

    def _postprocessing(self, parsed_args: Namespace) -> T:
        """Merge parsed arguments over the YAML config and decode them.

        Args:
            parsed_args: raw namespace from argparse (all values strings).

        Returns:
            A ``config_class`` instance decoded by pyrallis.
        """
        logger.debug("\nPOST PROCESSING\n")
        logger.debug(f"(raw) parsed args: {parsed_args}")

        parsed_arg_values = vars(parsed_args)

        for key in parsed_arg_values:
            parsed_arg_values[key] = cfgparsing.parse_string(parsed_arg_values[key])

        config = self.config  # Could be NONE

        if utils.CONFIG_ARG in parsed_arg_values:
            # parse_string may have turned the value into a non-string
            # (e.g. an int), which Path() would reject; normalize first.
            new_config = str(parsed_arg_values[utils.CONFIG_ARG])
            if config is not None:
                warnings.warn(
                    UserWarning(f"Overriding default {config} with {new_config}")
                )
            ######################################################################
            # adapted from original implementation in pyrallis
            ######################################################################
            if Path(new_config).is_file():
                # pass in a absolute path
                config = new_config
            else:
                print(f"trying to locate preset config for {new_config} ...")
                config = self._preset_dir() / f"preset_{new_config}.yaml"
            del parsed_arg_values[utils.CONFIG_ARG]

        if config is not None:
            print(f"loading configuration from {config} ...")
            # close the file handle once the YAML has been read
            with open(config, "r") as config_file:
                file_args = cfgparsing.load_config(config_file)
            file_args = utils.flatten(file_args, sep=".")
            # command-line values take precedence over the file
            file_args.update(parsed_arg_values)
            parsed_arg_values = file_args
            print("configuration loading is completed")

        deflat_d = utils.deflatten(parsed_arg_values, sep=".")
        return decoding.decode(self.config_class, deflat_d)

def parse_adaptor_jpnb(
    config_class: Type[T],
    config: Optional[Union[Path, str]] = None,
    args: Optional[Sequence[str]] = None,
) -> T:
    """Build a ``config_class`` instance from a YAML file inside a notebook.

    When ``args`` is None an empty argument list is parsed instead of
    ``sys.argv``. Inside a notebook, ``sys.argv`` typically holds the
    kernel's own launch flags, which the parser would reject.

    Args:
        config_class: dataclass type to build (e.g. ``ProgramConfig``).
        config: path to the YAML configuration file.
        args: optional command-line style overrides, e.g.
            ``["--model.checkpoint", "./my.ckpt"]`` (pyrallis dotted-field
            syntax). These take precedence over values from ``config``.

    Returns:
        The decoded ``config_class`` instance.
    """
    parser = ArgumentParser(config_class=config_class, config=config)
    # honour explicit overrides; previously `args` was silently ignored
    return parser.parse_args(args=[] if args is None else list(args))


Now we start the actual processing.

<b>Note:</b>
<br>
<br>
Some changes are required depending on the model you intend to use for prediction:

To change the model's configuration, replace the existing `config` argument with the desired YAML file in the following line of code:

```python 
   cfg = parse_adaptor_jpnb(config_class=ProgramConfig, config="./semantic_seg_2d_inference_2class.yaml")
```

For the retrained (improved) version of the existing model, no changes are needed here:

```python
   cfg = parse_adaptor_jpnb(config_class=ProgramConfig, config="./semantic_seg_2d_inference_2class.yaml") 
```
   
For the probabilistic model, use:
 
```python
   cfg = parse_adaptor_jpnb(config_class=ProgramConfig, config="./probabilistic_semantic_seg2D_inference_2Class.yaml")  
```
 
Depending on your model selection, you need to choose the matching weights.<br>
This means changing the path in the following line of code:

```python 
   cfg.model.checkpoint = Path("./version_2023_06.ckpt")
```

To test the retrained version of the latest model, use:

```python 
   cfg.model.checkpoint = Path("./ClassicUnet_16_06_2025.ckpt")
```

To test the probabilistic model, use:

```python 
   cfg.model.checkpoint = Path("./ProbUnet_16_06_2025.ckpt")
```

These .ckpt files are now available via the same download [LINK](https://ambiomcloud.isas.de/index.php/s/CwcfFRt8eQ9gKWj).

In [None]:

# Load the inference configuration. To switch models, change the YAML file
# in the next line (see the notes above for which file goes with which model).
cfg = parse_adaptor_jpnb(config_class=ProgramConfig, config="./semantic_seg_2d_inference_2class.yaml")
cfg = configuration_validation(cfg)

# Select which model weights (checkpoint) to use:
# - version_2023_06.ckpt: the older model, used to generate the results evaluated by Jan
# - version_2023_09.ckpt: a more recent model, slightly improved for antigen data
# The retrained / probabilistic checkpoints listed in the notes above
# (ClassicUnet_16_06_2025.ckpt, ProbUnet_16_06_2025.ckpt) are set on this line as well.
cfg.model.checkpoint = Path("./version_2023_06.ckpt")

# define the executor for inference (no need to change anything)
executor = ProjectTester(cfg)
executor.setup_model()
executor.setup_data_processing()

In [None]:
# get the data, run inference, and save the result (set your path)

# root path, should be the same as "out_path_base" in the data wrangling notebook
path_base = Path("/path/to/splitted/3d/files")

# Input 3D tiff files. Note: we assume the data were generated by the
# data_wrangling notebook. Each tiff file has two channels: first CD31, second Coll IV.
input_path = path_base / "split_3d"
out_p = path_base / "pred_2class"
out_p.mkdir(parents=True, exist_ok=True)

# Number of pixels removed from the outer layer of string vessels during
# thinning. Increase (e.g. to 2) if 1 is not enough.
THIN = 1


def _predict_stack(img):
    """Run the 2D model slice by slice on a CZYX image; return a ZYX label map."""
    out_list = []
    for zz in range(img.shape[1]):
        # channel order: CD31, Coll IV
        seg = executor.process_one_image(img[:, zz, :, :])
        out_list.append(np.squeeze(seg))
    return np.stack(out_list, axis=0)


def _remove_pericytes(seg_full):
    """Attempt to remove pericytes: relabel small / roundish class-2 objects as class 1.

    Class-2 objects smaller than 64 voxels are dropped from class 2. For
    mid-size objects (64 <= size < 300), each z-slice is checked and 2D
    regions that are roundish (eccentricity < 0.88), not too concave
    (solidity > 0.85) and small (area < 150) are dropped from class 2 as
    well. Everything dropped from class 2 becomes class 1. Returns a new array.
    """
    seg_2 = remove_small_objects(seg_full == 2, min_size=64)
    seg_2_mid = np.logical_xor(seg_2, remove_small_objects(seg_2, min_size=300))
    for zz in range(seg_2_mid.shape[0]):
        seg_label = label(seg_2_mid[zz, :, :])
        for region in regionprops(seg_label):
            if region.eccentricity < 0.88 and region.solidity > 0.85 and region.area < 150:
                # seg_2[zz] is a view, so this edits seg_2 in place
                seg_2[zz][seg_label == region.label] = False

    seg = seg_full.copy()
    seg[seg == 2] = 1
    seg[seg_2] = 2
    return seg


def _fill_small_holes(seg_full, hole_size_threshold=15):
    """Minor fix for segmentation errors.

    Removes class-1 objects below 50 voxels and fills per-slice holes smaller
    than ``hole_size_threshold`` pixels in both classes. Returns a new label map.
    """
    seg_1 = remove_small_objects(seg_full == 1, min_size=50)
    seg_2 = seg_full == 2
    for zz in range(seg_full.shape[0]):
        seg_1[zz] = remove_small_holes(seg_1[zz], area_threshold=hole_size_threshold)
        seg_2[zz] = remove_small_holes(seg_2[zz], area_threshold=hole_size_threshold)

    # Write the cleaned masks back; previously they were computed and discarded.
    # Where a filled hole of one class overlaps the other class, class 2 wins
    # (NOTE(review): confirm the preferred precedence).
    seg = np.zeros_like(seg_full)
    seg[seg_1] = 1
    seg[seg_2] = 2
    return seg


def _thin_string_vessels(seg_full, thin=THIN):
    """Thin class-2 (string vessel) masks without breaking topology.

    Removes ``thin`` pixel layer(s) from the outside of string vessels.
    Normal vessels (class 1) are unchanged.
    """
    seg_string = topology_preserving_thinning(seg_full == 2, min_thickness=1, thin=thin)
    seg_thin = np.zeros_like(seg_full)
    seg_thin[seg_string > 0] = 2
    seg_thin[seg_full == 1] = 1  # the normal vessels are unchanged
    return seg_thin


# get all files to be processed
filenames = sorted(input_path.glob("*.tiff"))
if not filenames:
    print(f"No *.tiff files found in {input_path}; check path_base.")

for num, fn in enumerate(filenames, start=1):
    print("--Predicting: ", num, '/', len(filenames), "...")
    img = AICSImage(fn).get_image_data("CZYX", T=0)

    seg_full = _predict_stack(img)
    seg_full = _remove_pericytes(seg_full)
    seg_full = _fill_small_holes(seg_full)
    seg_thin = _thin_string_vessels(seg_full, thin=THIN)

    # Save the post-processed result; previously the un-thinned seg_full was
    # written, so the hole filling and thinning never reached the output.
    out_fn = out_p / fn.name
    OmeTiffWriter.save(seg_thin, out_fn, dim_order="ZYX")
