In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import skimage
import tifffile

from careamics.config.configuration_factories import (
    _create_ng_data_configuration,
    _list_spatial_augmentations,
)
from careamics.config.data import NGDataConfig
from careamics.dataset_ng.dataset import Mode
from careamics.dataset_ng.factory import create_dataset

In [None]:
# Load scikit-image's bundled `human_mitosis` sample image and derive a label
# mask from it, giving the sections below a paired input/target to work with.
example_data = skimage.data.human_mitosis()

# Seed markers for the watershed: label 1 where intensity < 25, label 2 where
# intensity > 50. Pixels in between stay 0 (unlabelled) and are left for the
# watershed to assign.
markers = np.zeros_like(example_data)
markers[example_data < 25] = 1
markers[example_data > 50] = 2

# Run the watershed on the Sobel edge magnitude, starting from the markers.
elevation_map = skimage.filters.sobel(example_data)
segmentation = skimage.segmentation.watershed(elevation_map, markers)

# Show the input next to its segmentation; titles keep the figure readable on
# its own, like the other figures in this notebook.
fig, ax = plt.subplots(1, 2)
ax[0].imshow(example_data)
ax[0].set_title("Input image")
ax[1].imshow(segmentation)
ax[1].set_title("Watershed segmentation")
plt.show()

### 1. From an array 

In [None]:
# 1. Training and validation datasets built directly from in-memory arrays

# Settings shared by both configurations; only the augmentations differ.
array_config_kwargs = dict(
    data_type="array",
    axes="YX",
    patch_size=(32, 32),
    batch_size=1,
)

train_data_config = _create_ng_data_configuration(
    **array_config_kwargs,
    augmentations=_list_spatial_augmentations(),
)
val_data_config = _create_ng_data_configuration(
    **array_config_kwargs,
    augmentations=[],
)

# Both datasets wrap the same image / segmentation pair.
train_dataset = create_dataset(
    config=train_data_config,
    mode=Mode.TRAINING,
    inputs=[example_data],
    targets=[segmentation],
    in_memory=True,
)
val_dataset = create_dataset(
    config=val_data_config,
    mode=Mode.VALIDATING,
    inputs=[example_data],
    targets=[segmentation],
    in_memory=True,
)

# Top row: first five training inputs; bottom row: the matching targets.
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
input_row, target_row = ax
input_row[0].set_title("Train input")
target_row[0].set_title("Train target")
for index, (input_ax, target_ax) in enumerate(zip(input_row, target_row)):
    sample, target = train_dataset[index]
    input_ax.imshow(sample.data[0])
    target_ax.imshow(target.data[0])

### 2. From TIFF files

In [None]:
# Write two input/target pairs to disk as TIFF files: the full image, and a
# crop restricted to its first 256 rows and columns.
tiff_inputs = {
    "example_data1.tiff": example_data,
    "example_data2.tiff": example_data[:256, :256],
}
tiff_targets = {
    "example_target1.tiff": segmentation,
    "example_target2.tiff": segmentation[:256, :256],
}
for file_name, image in {**tiff_inputs, **tiff_targets}.items():
    tifffile.imwrite(file_name, image)

train_data_config = _create_ng_data_configuration(
    data_type="tiff",
    axes="YX",
    patch_size=(32, 32),
    batch_size=1,
    augmentations=_list_spatial_augmentations(),
)

val_data_config = _create_ng_data_configuration(
    data_type="tiff",
    axes="YX",
    patch_size=(32, 32),
    batch_size=1,
    augmentations=[],
)

# List exactly the files written by this cell (instead of globbing the working
# directory) so that unrelated or stale files matching the name pattern cannot
# sneak into the dataset or shift inputs out of alignment with their targets.
# Sorting keeps the i-th input paired with the i-th target.
data = sorted(Path(file_name) for file_name in tiff_inputs)
targets = sorted(Path(file_name) for file_name in tiff_targets)

train_dataset = create_dataset(
    config=train_data_config,
    mode=Mode.TRAINING,
    inputs=data,
    targets=targets,
    in_memory=True,
)
val_dataset = create_dataset(
    config=val_data_config,
    mode=Mode.VALIDATING,
    inputs=data,
    targets=targets,
    in_memory=True,
)

# Top row: first five training inputs; bottom row: the matching targets.
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
ax[0, 0].set_title("Train input")
ax[1, 0].set_title("Train target")
for i in range(5):
    sample, target = train_dataset[i]
    ax[0, i].imshow(sample.data[0])
    ax[1, i].imshow(target.data[0])

### 3. Prediction from an array

In [None]:
# Prediction configuration, built directly with `NGDataConfig` (imported in the
# notebook's first cell) so that tiled patching can be specified explicitly.
prediction_config = NGDataConfig(
    data_type="array",
    # 32x32 tiles with a 16-pixel overlap along each axis.
    patching={
        "name": "tiled",
        "patch_size": (32, 32),
        "overlaps": (16, 16),
    },
    axes="YX",
    batch_size=1,
    # Statistics are computed from the image itself
    # (presumably used for normalization -- confirm against NGDataConfig).
    image_means=[example_data.mean()],
    image_stds=[example_data.std()],
)

# No targets are supplied in prediction mode.
prediction_dataset = create_dataset(
    config=prediction_config,
    mode=Mode.PREDICTING,
    inputs=[example_data],
    targets=None,
    in_memory=True,
)

# Display the first five tiles.
fig, ax = plt.subplots(1, 5, figsize=(10, 5))
ax[0].set_title("Prediction input")
for i in range(5):
    # Only the first returned element (the input region) is displayed.
    sample, *_ = prediction_dataset[i]
    ax[i].imshow(sample.data[0])

### 4. From a custom data type

In [None]:
# Training configuration for the "custom" data type; the dataset built from it
# below is given a user-defined read function.
custom_augmentations = _list_spatial_augmentations()
train_data_config = _create_ng_data_configuration(
    data_type="custom",
    axes="YX",
    patch_size=(32, 32),
    batch_size=1,
    augmentations=custom_augmentations,
)


def read_data_func_test(example_data):
    """Return the intensity-inverted input, computed as ``255 - example_data``.

    Toy read function used to demonstrate the custom data type: whatever is
    passed in is subtracted from 255 and returned.
    """
    inverted = 255 - example_data
    return inverted


# Build the training dataset, handing the custom read function to the factory.
train_dataset = create_dataset(
    config=train_data_config,
    mode=Mode.TRAINING,
    inputs=[example_data],
    targets=[segmentation],
    in_memory=True,
    read_func=read_data_func_test,
    read_kwargs={},
)

# Show the first five training inputs side by side.
fig, ax = plt.subplots(1, 5, figsize=(10, 5))
for index, axis in enumerate(ax):
    sample, _ = train_dataset[index]
    axis.imshow(sample.data[0])