In [1]:
from dask_image.imread import imread

In [2]:
#filter_list = (
#    filters.Gaussian,
#    filters.LaplacianOfGaussian,
#    filters.GaussianGradientMagnitude,
#    filters.DifferenceOfGaussians,
#    filters.StructureTensorEigenvalues,
#   filters.HessianOfGaussianEigenvalues,)

In [None]:
from typing import Sequence

import dask.array as da
from dask_image.ndfilters import (
    gaussian_filter,
    gaussian_gradient_magnitude,
    gaussian_laplace,
    laplace,
)

from ilastik.napari.filters import GaussianDask, FilterSet

# Two dask-backed Gaussian filters at scales 0.3 and 0.7 -> lazy, chunk-wise feature computation.
features = FilterSet(filters=[GaussianDask(scale=sigma) for sigma in (0.3, 0.7)])
features

In [None]:
import numpy as np
import dask
from sklearn.pipeline import Pipeline

from spatialdata import read_zarr

import os

# Daemonic processes may not spawn children; disabling it lets dask workers run
# libraries that start their own worker processes (e.g. joblib / n_jobs=-1).
dask.config.set({'distributed.worker.daemon': False})

# Single place to point the notebook at the data: override with the
# ILASTIK_TEST_DATA environment variable instead of editing absolute paths.
DATA_DIR = os.environ.get("ILASTIK_TEST_DATA", "/Users/arnedf/VIB/DATA/test_data_ilastik")

# Output directory for the feature map, the preprocessing pipeline and the model.
path = os.path.join(DATA_DIR, "output")

sdata = read_zarr(os.path.join(DATA_DIR, "sdata_multi_channel.zarr"))

image = sdata["raw_image"].data
image

In [None]:
from matplotlib.pyplot import imshow

import matplotlib.pyplot as plt

# Plot index 0 of the leading axis, i.e. the first channel of the raw image.
fig, ax = plt.subplots()
ax.imshow(image[0])
ax.set(title="raw_image - channel 0", xlabel="x (px)", ylabel="y (px)");

In [5]:
import os
import joblib

def preprocessing_dask(image, estimators, preprocessing_path=None, overwrite=False):
    """Apply a feature pipeline lazily and persist both its output and the pipeline.

    Parameters
    ----------
    image
        Input passed to ``Pipeline.transform`` (here: one channel of the raw image).
    estimators
        ``(name, transformer)`` steps used to build a :class:`sklearn.pipeline.Pipeline`.
    preprocessing_path
        Output directory. The feature map is written to ``<preprocessing_path>/array.zarr``
        and the pipeline is pickled to ``<preprocessing_path>/preprocessing_pipe.pkl``.
        Required; the ``None`` default is kept only for signature compatibility.
    overwrite
        Forwarded to ``to_zarr``. Set to ``True`` to replace an existing ``array.zarr``
        store (e.g. when re-running the notebook); by default an existing store is an error.

    Returns
    -------
    Pipeline
        The pipeline that was applied and saved.

    Raises
    ------
    ValueError
        If ``preprocessing_path`` is ``None``.
    """
    if preprocessing_path is None:
        raise ValueError("`preprocessing_path` is required to store the feature map and pipeline.")
    pipe = Pipeline(estimators)
    feature_map_lazy = pipe.transform(image)
    # The feature map could be large, so stream it to a zarr store instead of computing it in memory.
    feature_map_lazy.to_zarr(
        os.path.join(preprocessing_path, "array.zarr"), overwrite=overwrite
    )
    joblib.dump(pipe, os.path.join(preprocessing_path, "preprocessing_pipe.pkl"))
    return pipe

# Reuse the FilterSet `features` defined in the filter cell instead of redefining it here.
estimators = [("features", features)]

preprocessing_path = path

# overwrite=True keeps Restart & Run All working when array.zarr already exists.
preprocessing_pipe = preprocessing_dask(
    image[0], estimators=estimators, preprocessing_path=preprocessing_path, overwrite=True
)

In [None]:
# Lazily re-open the feature map that preprocessing_dask wrote to disk.
preprocessed_image = da.from_zarr(url=os.path.join(preprocessing_path, "array.zarr"))

preprocessed_image  # lazy dask view of the preprocessed (feature) image

In [None]:
# Create dummy annotations over the spatial axes of the image: ~80 % label 0 (not annotated;
# only non-zero labels are used for training) and ~10 % each for classes 1 and 2.
# The generator is seeded so Restart & Run All trains on the same labels every time.
labels = np.random.default_rng(seed=0).choice(
    [0, 1, 2], size=image.shape[1:], p=[0.8, 0.1, 0.1]
)

In [None]:
from dask.distributed import Client

from ilastik.napari.classifier import NDSparseDaskClassifier
from sklearn.ensemble import RandomForestClassifier

import loguru

logger = loguru.logger

def pixel_training_dask(
    X, labels, model_path=None, **client_kwargs,
):
    """Fit an ``NDSparseDaskClassifier`` wrapping a random forest on a temporary dask cluster.

    Parameters
    ----------
    X
        Feature map to train on (e.g. the ``array.zarr`` store written by ``preprocessing_dask``).
    labels
        Annotations passed to ``fit`` together with ``X``.
    model_path
        If given, the fitted classifier is pickled to this path with ``joblib.dump``.
    **client_kwargs
        Forwarded to ``dask.distributed.Client`` (e.g. ``n_workers``, ``threads_per_worker``).

    Returns
    -------
    The fitted classifier.
    """
    clf = NDSparseDaskClassifier(RandomForestClassifier(n_jobs=-1))

    # Using the client as a context manager closes it (and the LocalCluster it started)
    # when training is done, so repeated calls don't leak clusters.
    with Client(**client_kwargs) as client:
        logger.info(f"Client dashboard link {client.dashboard_link}")
        # note, NDSparseDaskClassifier with dask backend will still load data that was annotated
        # in memory (although not the full dataset, only non-zero labels)
        with joblib.parallel_backend("dask"):
            clf.fit(X, labels)

    if model_path is not None:
        joblib.dump(clf, model_path)
    return clf

# Load the features back from the zarr store under a new name, so `image` keeps holding the raw image.
feature_map = da.from_zarr(os.path.join(preprocessing_path, "array.zarr"))
pixel_classifier = pixel_training_dask(
    X=feature_map,
    labels=labels,
    model_path=os.path.join(path, "model.pkl"),
    n_workers=1,
    threads_per_worker=10,
)

In [9]:
def pixel_classification_dask(
    image: da.Array | None,
    preprocessing_path,
    model_path,
    tmp_path,
    **client_kwargs,
):
    """Predict pixel labels block-wise with a pickled classifier on a temporary dask cluster.

    Work in progress.

    Parameters
    ----------
    image
        ``None`` to classify the feature map already stored in
        ``<preprocessing_path>/array.zarr`` (train and predict on the same image).
        Otherwise a new image: the pipeline pickled at
        ``<preprocessing_path>/preprocessing_pipe.pkl`` is applied to it, the resulting
        feature map is written to ``tmp_path``, and that feature map is classified.
    preprocessing_path
        Directory written by ``preprocessing_dask``.
    model_path
        Path of the classifier pickled by ``pixel_training_dask``.
    tmp_path
        Zarr store for the feature map of a new ``image``. Required when ``image`` is
        not ``None``; ignored otherwise.
    **client_kwargs
        Forwarded to ``dask.distributed.Client``.

    Returns
    -------
    The computed predictions, with the last (feature) axis dropped.

    Raises
    ------
    ValueError
        If ``image`` is given but ``tmp_path`` is ``None``.
    """
    if image is None:
        # case where we train and run inference on same image
        features = da.from_zarr(os.path.join(preprocessing_path, "array.zarr"))
    else:
        # New image: preprocess it with the saved pipeline, then classify it with the pretrained model.
        if tmp_path is None:
            raise ValueError("`tmp_path` is required to store the feature map of a new `image`.")
        # preprocessing_dask saves the pipeline as "preprocessing_pipe.pkl".
        # joblib.load unpickles arbitrary objects: only load files you produced yourself.
        preprocessing_pipe = joblib.load(
            os.path.join(preprocessing_path, "preprocessing_pipe.pkl")
        )
        features = preprocessing_pipe.transform(image)
        # The feature map could be large: write it to zarr and re-open it lazily.
        features.to_zarr(tmp_path)
        features = da.from_zarr(tmp_path)

    clf = joblib.load(model_path)

    def _predict_clf(arr, model):
        # squeeze(-1) assumes predict returns a trailing singleton axis - TODO confirm.
        return model.predict(arr).squeeze(-1)

    # Using the client as a context manager closes it (and the LocalCluster it started)
    # after the result is computed, so repeated calls don't leak clusters.
    with Client(**client_kwargs) as client:
        # Scatter the model once; otherwise it ends up in a (large) task graph.
        clf_scatter = client.scatter(clf)

        # probably need to use map_overlap instead of map_blocks here
        array_result = da.map_blocks(
            _predict_clf,
            features,
            dtype=features.dtype,  # TODO output dtype not correct, need to fix via meta
            drop_axis=-1,  # assumes the last axis holds the features consumed by predict
            chunks=features.chunks[:-1],
            model=clf_scatter,
        )
        results = array_result.compute()
    return results

In [None]:
results = pixel_classification_dask(
    image=None,  # classify the feature map already stored under `path`
    preprocessing_path=path,
    model_path=os.path.join(path, "model.pkl"),
    tmp_path=None,  # only used when a new image has to be preprocessed
    n_workers=1,
    threads_per_worker=10,
)

In [None]:
results  # predicted labels returned (already computed) by pixel_classification_dask