In [None]:
import skimage.data
import skimage.filters
import skimage.segmentation

from careamics.config.configuration_factories import (
    _create_ng_data_configuration,
    _list_spatial_augmentations,
)
from careamics.config.data import NGDataConfig
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule

In [None]:
# IPython autoreload extension: mode 2 reloads imported modules when their
# source changes, so edits to the library are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2


In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
from careamics_portfolio import PortfolioManager

# instantiate the data portfolio manager
portfolio = PortfolioManager()

# and download the BSD68 data into ./data
root_path = Path("./data")
files = portfolio.denoising.N2V_BSD68.download(root_path)

# create paths for the data; `root_path / ...` is already a Path, so no extra
# Path(...) wrapping is needed
data_path = root_path / "denoising-N2V_BSD68.unzip" / "BSD68_reproducibility_data"
train_path = data_path / "train"
val_path = data_path / "val"
test_path = data_path / "test" / "images"  # prediction inputs
gt_path = data_path / "test" / "gt"  # ground-truth counterparts of test_path

In [None]:
# Dataset-level normalization statistics of the training data.
# The statistics are pooled over every pixel of every training file: the mean
# is weighted by pixel count, and the variance adds the spread of the per-file
# means to the per-file variances. Simply averaging per-file stds is not the
# std of the pooled pixels when files differ in mean or size.
train_files = sorted(train_path.glob("*.tiff"))
if not train_files:
    raise FileNotFoundError(f"No .tiff files found in {train_path}")

pixel_counts, file_means, file_vars = [], [], []
for tiff_file in train_files:
    image = tifffile.imread(tiff_file)
    pixel_counts.append(image.size)
    # accumulate in float64 to avoid precision loss on large float32 stacks
    file_means.append(image.mean(dtype=np.float64))
    file_vars.append(image.var(dtype=np.float64))

pixel_counts = np.asarray(pixel_counts)
file_means = np.asarray(file_means)
file_vars = np.asarray(file_vars)

# law of total variance: pooled var = weighted mean of (var_i + (mean_i - mean)^2)
# cast to built-in floats so the config receives plain Python numbers
image_mean = float(np.average(file_means, weights=pixel_counts))
image_std = float(
    np.sqrt(np.average(file_vars + (file_means - image_mean) ** 2, weights=pixel_counts))
)

In [None]:
# from path, train and val, no target
# Train/val data are given as the TIFF folders; patches are normalized with the
# training-set statistics and the spatial augmentations are enabled.

config = _create_ng_data_configuration(
    data_type="tiff",
    axes="SYX",
    patch_size=(64, 64),
    batch_size=64,
    augmentations=_list_spatial_augmentations(),
)

config.set_means_and_stds([image_mean], [image_std])
config.val_dataloader_params = {"shuffle": False}  # deterministic val order

data_module = CareamicsDataModule(
    data_config=config,
    train_data=train_path,
    val_data=val_path,
)
for stage in ("fit", "validate"):
    data_module.setup(stage)

train_batch = next(iter(data_module.train_dataloader()))
val_batch = next(iter(data_module.val_dataloader()))

# One figure per batch (train first, then val): index 0 along the second axis
# (presumably the channel) of the first 8 patches.
for batch in (train_batch, val_batch):
    fig, ax = plt.subplots(1, 8, figsize=(10, 5))
    for i, axis in enumerate(ax):
        axis.imshow(batch[0].data[i][0].numpy(), cmap="gray")

In [None]:
# from path, only predict
# Tiled prediction on the test images: 128x128 tiles with overlaps=(32, 32),
# normalized with the training-set statistics.

config = NGDataConfig(
    data_type="tiff",
    patching={
        "name": "tiled",
        "patch_size": (128, 128),
        "overlaps": (32, 32),
    },
    axes="YX",
    batch_size=8,
    image_means=[image_mean],
    image_stds=[image_std],
)

data_module = CareamicsDataModule(
    data_config=config,
    pred_data=test_path,
)
data_module.setup('predict')

pred_batch = next(iter(data_module.predict_dataloader()))

# first 8 tiles of the first batch (batch_size above is 8)
fig, ax = plt.subplots(1, 8, figsize=(10, 5))

for i in range(8):
    ax[i].imshow(pred_batch[0].data[i][0].numpy(), cmap="gray")

In [None]:
# Same pipelines, fed from in-memory NumPy arrays instead of file paths

In [None]:
def _first_tiff(folder):
    """Return the first TIFF file (in sorted order) found under `folder`.

    Globbing for "*.tif*" instead of "*" skips directories and stray non-image
    files (e.g. .DS_Store) that would otherwise be handed to tifffile.

    Raises
    ------
    FileNotFoundError
        If no .tif/.tiff file exists under `folder`.
    """
    tiff_files = sorted(folder.rglob("*.tif*"))
    if not tiff_files:
        raise FileNotFoundError(f"No TIFF file found under {folder}")
    return tiff_files[0]


# load the first file of each split as an in-memory array
train_array = tifffile.imread(_first_tiff(train_path))
val_array = tifffile.imread(_first_tiff(val_path))
test_array = tifffile.imread(_first_tiff(test_path))

In [None]:
# from array, train and val, no target
# Same as the folder-based cell, but the data module receives the arrays
# loaded above instead of paths.

config = _create_ng_data_configuration(
    data_type="array",
    axes="SYX",
    patch_size=(64, 64),
    batch_size=64,
    augmentations=_list_spatial_augmentations(),
)

config.set_means_and_stds([image_mean], [image_std])
config.val_dataloader_params = {"shuffle": False}

data_module = CareamicsDataModule(
    data_config=config,
    train_data=train_array,
    val_data=val_array,
)
data_module.setup("fit")
data_module.setup("validate")

train_loader = data_module.train_dataloader()
train_batch = next(iter(train_loader))
val_loader = data_module.val_dataloader()
val_batch = next(iter(val_loader))

# train patches in the first figure, val patches in the second
for batch in (train_batch, val_batch):
    fig, ax = plt.subplots(1, 8, figsize=(10, 5))
    for i in range(len(ax)):
        ax[i].imshow(batch[0].data[i][0].numpy(), cmap="gray")

In [None]:
# Training/validation with a target: image + segmentation pairs

In [None]:
# Toy image/target pair: the scikit-image human mitosis example image and a
# watershed segmentation of it. The scikit-image submodules are imported
# explicitly in the imports cell, because a bare `import skimage` does not
# load them on every scikit-image version.
example_data = skimage.data.human_mitosis()

# seed markers: 0 = undecided, 1 = dark seeds (< 25), 2 = bright seeds (> 50)
markers = np.zeros_like(example_data)
markers[example_data < 25] = 1
markers[example_data > 50] = 2

# flood the Sobel edge map from the seeds
elevation_map = skimage.filters.sobel(example_data)
segmentation = skimage.segmentation.watershed(elevation_map, markers)

fig, ax = plt.subplots(1, 2)
ax[0].imshow(example_data)
ax[0].set_title("image")
ax[1].imshow(segmentation)
ax[1].set_title("watershed segmentation")
plt.show()

In [None]:
# Train/val on the image + segmentation pair (each passed as a one-item list);
# the same pair serves as both training and validation data.
config = _create_ng_data_configuration(
    data_type="array",
    axes="YX",
    patch_size=(64, 64),
    batch_size=64,
    augmentations=_list_spatial_augmentations()
)
# input and target are each normalized with their own statistics
config.set_means_and_stds(
    [example_data.mean()],
    [example_data.std()],
    [segmentation.mean()],
    [segmentation.std()]
)
# deterministic validation order, as in the other train/val cells
config.val_dataloader_params = {"shuffle": False}

data_module = CareamicsDataModule(
    data_config=config,
    train_data=[example_data],
    train_data_target=[segmentation],
    val_data=[example_data],
    val_data_target=[segmentation]
)
data_module.setup('fit')
data_module.setup('validate')

train_batch = next(iter(data_module.train_dataloader()))
val_batch = next(iter(data_module.val_dataloader()))

# top row: input patches (batch[0]); bottom row: presumably the matching
# target patches (batch[1])
fig, ax = plt.subplots(2, 8, figsize=(10, 3))

for i in range(8):
    ax[0][i].imshow(train_batch[0].data[i][0].numpy(), cmap="gray")
    ax[1][i].imshow(train_batch[1].data[i][0].numpy())


fig, ax = plt.subplots(2, 8, figsize=(10, 3))
for i in range(8):
    ax[0][i].imshow(val_batch[0].data[i][0].numpy(), cmap="gray")
    ax[1][i].imshow(val_batch[1].data[i][0].numpy())

In [None]:
# from array, only predict, with target
# Tiled prediction on the human mitosis image, with its segmentation as target.

config = NGDataConfig(
    data_type="array",
    patching={
        "name": "tiled",
        "patch_size": (128, 128),
        "overlaps": (32, 32),
    },
    axes="YX",
    batch_size=8,
)
# Normalize with the statistics of *this* image and target. image_mean /
# image_std were computed on the BSD68 training files, a different dataset,
# and would mis-normalize the mitosis image; the target also needs its own
# statistics.
config.set_means_and_stds(
    [example_data.mean()],
    [example_data.std()],
    [segmentation.mean()],
    [segmentation.std()],
)

data_module = CareamicsDataModule(
    data_config=config,
    pred_data=example_data,
    pred_data_target=segmentation,
)
data_module.setup('predict')

pred_batch = next(iter(data_module.predict_dataloader()))

# first 8 input tiles of the first batch (batch_size above is 8)
fig, ax = plt.subplots(1, 8, figsize=(10, 5))

for i in range(8):
    ax[i].imshow(pred_batch[0].data[i][0].numpy(), cmap="gray")

In [None]:
# From an explicit list of file paths instead of a folder

In [None]:
config = _create_ng_data_configuration(
    data_type="tiff",
    axes="SYX",
    patch_size=(64, 64),
    batch_size=64,
    augmentations=_list_spatial_augmentations(),
)

config.set_means_and_stds([image_mean], [image_std])
config.val_dataloader_params = {"shuffle": False}

# Same as the folder-based cell, but the data module receives sorted lists of
# .tiff files instead of directories.
train_files = sorted(train_path.glob("*.tiff"))
val_files = sorted(val_path.glob("*.tiff"))

data_module = CareamicsDataModule(
    data_config=config,
    train_data=train_files,
    val_data=val_files,
)
data_module.setup("fit")
data_module.setup("validate")

train_batch = next(iter(data_module.train_dataloader()))
val_batch = next(iter(data_module.val_dataloader()))

# train figure first, then val figure; 8 patches each
for batch in (train_batch, val_batch):
    fig, ax = plt.subplots(1, 8, figsize=(10, 5))
    for i, axis in enumerate(ax):
        patch = batch[0].data[i][0]
        axis.imshow(patch.numpy(), cmap="gray")

In [None]:
# From files read by a user-supplied reader function (data_type="custom")

In [None]:
# intensity offset used by the custom reader to invert images
INVERSION_OFFSET = 255


def read_source_func(path):
    """Read a TIFF file and invert its intensities.

    Parameters
    ----------
    path : pathlib.Path
        File to read.

    Returns
    -------
    numpy.ndarray
        ``INVERSION_OFFSET - image``. NOTE: for unsigned-integer images with
        values above INVERSION_OFFSET this would wrap around.
    """
    image = tifffile.imread(path)
    return INVERSION_OFFSET - image


config = _create_ng_data_configuration(
    data_type="custom",
    axes="SYX",
    patch_size=(64, 64),
    batch_size=64,
    augmentations=_list_spatial_augmentations()
)

# image_mean / image_std were computed on the raw (non-inverted) files. The
# inversion x -> OFFSET - x maps the mean m to OFFSET - m and leaves the std
# unchanged, so the mean must be adjusted to match what read_source_func returns.
config.set_means_and_stds([INVERSION_OFFSET - image_mean], [image_std])
config.val_dataloader_params = {"shuffle": False}

data_module = CareamicsDataModule(
    data_config=config,
    train_data=sorted(train_path.glob("*.tiff")),
    val_data=sorted(val_path.glob("*.tiff")),
    read_source_func=read_source_func
)
data_module.setup('fit')
data_module.setup('validate')

train_batch = next(iter(data_module.train_dataloader()))
val_batch = next(iter(data_module.val_dataloader()))

fig, ax = plt.subplots(1, 8, figsize=(10, 5))

for i in range(8):
    ax[i].imshow(train_batch[0].data[i][0].numpy(), cmap="gray")


fig, ax = plt.subplots(1, 8, figsize=(10, 5))
for i in range(8):
    ax[i].imshow(val_batch[0].data[i][0].numpy(), cmap="gray")