In [None]:

from pathlib import Path

import zarr
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics_portfolio import PortfolioManager

from careamics.config.data import NGDataConfig
from careamics.config.architectures import UNetModel
from careamics.lightning.dataset_ng.callbacks.prediction_writer import (
    PredictionWriterCallback,
)
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule


from careamics_seg.configuration import SegAlgorithm
from careamics_seg.model import SegModule


In [None]:
# Fetch the DSB2018 (n0) DenoiSeg dataset into ./data via the CAREamics portfolio
download_dir = Path(".") / "data"
portfolio = PortfolioManager()
portfolio.denoiseg.DSB2018_n0.download(download_dir)

In [None]:
# Load the DSB2018 (n0) training archive from the folder the dataset was
# downloaded into (./data), instead of an absolute path that only exists on one
# machine.
# NOTE(review): assumes the portfolio unpacks into
# `data/denoiseg-DSB2018_n0.unzip/` (same layout as the previous absolute
# path) - confirm.
train_data_path = (
    Path(".")
    / "data"
    / "denoiseg-DSB2018_n0.unzip"
    / "DSB2018_n0"
    / "train"
    / "train_data.npz"
)

# read the arrays inside a context manager so the .npz file handle is released
with np.load(train_data_path) as archive:
    x_train = archive['X_train']
    y_train = archive['Y_train']
    x_val = archive['X_val']
    y_val = archive['Y_val']


# threshold Y: every non-zero label becomes foreground (1), the rest stays 0
y_train = (y_train > 0).astype(np.int8)
y_val = (y_val > 0).astype(np.int8)

print(x_train.shape)

# show the first training image next to its binary label
fig, axs = plt.subplots(1, 2, figsize=(5, 8))
axs[0].imshow(x_train[0], cmap='gray')
axs[0].set_title('Training image 1')
axs[1].imshow(y_train[0], cmap='gray')
axs[1].set_title('Label image 1')
plt.show()



In [None]:
# Persist the training images and their binary masks as zarr arrays.
# The archive is only written when it is not on disk yet.
zarr_exists = Path('denoiseg_dsb2018_n0.zarr').exists()
if not zarr_exists:
    z = zarr.open('denoiseg_dsb2018_n0.zarr', mode='w')
    train_g = z.create_group('train')
    label_g = z.create_group('train_labels')

    # one array per sample, stored under the same key in both groups
    for idx, image in enumerate(x_train):
        array_name = f"array{idx}"
        train_g.create_array(array_name, data=image, chunks=(64, 64))
        label_g.create_array(array_name, data=y_train[idx], chunks=(64, 64))


In [None]:
# Collect the URIs of the stored arrays and split them into training and
# validation sets (images and their labels share the same array key).
N_VAL = 10  # number of arrays held out for validation

z = zarr.open('denoiseg_dsb2018_n0.zarr', mode='r')

# Sort the keys so the split is reproducible.
# NOTE(review): the order returned by `array_keys()` presumably follows the
# store's listing order, which is not guaranteed to be stable - confirm.
train_array_keys = sorted(z['train'].array_keys())

# Require more arrays than the validation hold-out; otherwise the training
# split below would silently end up empty.
assert len(train_array_keys) > N_VAL, (
    f"Need more than {N_VAL} arrays to build a train/validation split, "
    f"found {len(train_array_keys)}"
)

train_keys = train_array_keys[:-N_VAL]  # everything but the last N_VAL
val_keys = train_array_keys[-N_VAL:]  # last N_VAL for validation

train_uris = [str(z['train'][key].store_path) for key in train_keys]
train_target_uris = [str(z['train_labels'][key].store_path) for key in train_keys]

val_uris = [str(z['train'][key].store_path) for key in val_keys]
val_target_uris = [str(z['train_labels'][key].store_path) for key in val_keys]

print(f"Number of training samples: {len(train_uris)}")

In [None]:
# Segmentation settings: number of output classes and data dimensionality
n_classes = 1
is_2d = True

# settings that depend on whether the data is 2D or 3D
if is_2d:
    spatial_dims, data_axes, train_patch_size = 2, "YX", (64, 64)
else:
    spatial_dims, data_axes, train_patch_size = 3, "ZYX", (32, 64, 64)

# algorithm: UNet trained with a dice loss
unet_model = UNetModel(
    architecture="UNet",
    conv_dims=spatial_dims,
    n_classes=n_classes,
    independent_channels=False,
)
algorithm_config = SegAlgorithm(loss="dice", model=unet_model)

# data: random patches read from zarr; image statistics are fixed here, and the
# target statistics (mean 0, std 1) are the identity normalization
data_config = NGDataConfig(
    data_type="zarr",
    axes=data_axes,
    patching=dict(name="random", patch_size=train_patch_size),
    batch_size=8,
    image_means=[13.587576],
    image_stds=[18.4636317],
    target_means=[0],
    target_stds=[1],
    train_dataloader_params=dict(num_workers=0, shuffle=True),
    val_dataloader_params=dict(num_workers=0),
)


In [None]:
# Data module: each image URI list is passed together with its label URI list
data = CareamicsDataModule(
    data_config=data_config,
    train_data=train_uris,
    train_data_target=train_target_uris,
    val_data=val_uris,
    val_data_target=val_target_uris,
)


In [None]:
# Lightning module built from the segmentation algorithm configuration
model = SegModule(algorithm_config=algorithm_config)

In [None]:
# callbacks: prediction writer (used again at prediction time) and checkpointing
predict_writer = PredictionWriterCallback(dirpath=Path("predict_output"))
checkpoint_callback = ModelCheckpoint(
    dirpath=Path("data/checkpoints"),
    filename="test_seg",
)

# trainer: 10 epochs, at most 100 training batches per epoch
trainer = Trainer(
    default_root_dir=Path("data"),
    max_epochs=10,
    limit_train_batches=100,
    callbacks=[checkpoint_callback, predict_writer],
)

# run the training loop on the data module
trainer.fit(model, datamodule=data)

In [None]:
# Inspect the target statistics held by the training dataset
target_stats = data.train_dataset.target_stats
means, stds = target_stats.means, target_stats.stds
(means, stds)

In [None]:

# Prediction setup: write tiled predictions to zarr and normalize the inputs
# with the image statistics of the training dataset
predict_writer.set_writing_strategy(write_type="zarr", tiled=True)
input_stats = data.train_dataset.input_stats
means = input_stats.means
stds = input_stats.stds

# tiled prediction: 64x64 tiles with an overlap of 32 pixels, no transforms
tiling = {
    "name": "tiled",
    "patch_size": (64, 64),
    "overlaps": (32, 32),
}
pred_dataset_cfg = NGDataConfig(
    data_type="zarr",
    axes="YX",
    batch_size=4,
    patching=tiling,
    transforms=[],
    image_means=means,
    image_stds=stds,
    test_dataloader_params={"num_workers": 0},
)

# predict on the first 20 training images
predict_data = CareamicsDataModule(
    data_config=pred_dataset_cfg,
    pred_data=train_uris[:20],
)

# predictions are handed to the writer callback rather than returned
trainer.predict(model, datamodule=predict_data, return_predictions=False)




In [None]:
# show predictions
# Re-open the input images read-only instead of relying on `train_g`: that
# variable only exists when the zarr archive was created in the current kernel
# session, so a fresh run with an existing archive would raise a NameError.
original = zarr.open('denoiseg_dsb2018_n0.zarr', mode='r')['train']

prediction = zarr.open_group(
    Path("predict_output/denoiseg_dsb2018_n0_output.zarr"),
    path="train",
    mode="r",
)
preds = list(prediction.array_keys())


# plot original and prediction for up to the first 3 predicted images
# (slicing avoids an IndexError when fewer than 3 predictions were written)
for name in preds[:3]:
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(original[name], cmap='gray')
    axs[0].set_title(f'Original image {name}')
    axs[1].imshow(prediction[name], cmap='gray')
    axs[1].set_title(f'Predicted image {name}')
    plt.show()
