In [None]:

from pathlib import Path

import zarr
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics.config.data import NGDataConfig
from careamics.config.architectures import UNetModel
from careamics.lightning.dataset_ng.callbacks.prediction_writer import (
    PredictionWriterCallback,
)
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule

from careamics_seg.configuration import SegAlgorithm
from careamics_seg.model import SegModule


In [None]:
# parameters
# is_2d: True for 2D data ("YX" axes, 2D convolutions),
#        False for 3D volumes ("ZYX" axes, 3D convolutions).
is_2d = True
# n_classes: number of output classes of the UNet.
n_classes = 1

In [None]:
# configuration
# Segmentation algorithm: a UNet trained with a Dice loss; the convolution
# dimensionality follows `is_2d`.
unet_config = UNetModel(
    architecture="UNet",
    conv_dims=(2 if is_2d else 3),
    n_classes=n_classes,
    independent_channels=False,
)
algorithm_config = SegAlgorithm(model=unet_config, loss="dice")

# Data pipeline: random patches read from TIFF files
# (this setup is only compatible with TIFF files!).
patch_size = (64, 64) if is_2d else (32, 64, 64)  # adjust patch sizes (YX / ZYX)
data_config = NGDataConfig(
    data_type="tiff",
    axes=("YX" if is_2d else "ZYX"),
    patching=dict(name="random", patch_size=patch_size),
    batch_size=8,  # adjust batch size
    # Hack: fixed target statistics (mean 0, std 1) so that target
    # normalization is not computed from the data.
    target_means=[0],
    target_stds=[1],
    train_dataloader_params=dict(
        num_workers=0,  # can be increased on VDI or HPC systems
        shuffle=True,
    ),
    val_dataloader_params=dict(num_workers=0),
)


In [None]:
# Paths to the training/validation images and their target masks.
# TODO: replace each `...` placeholder with the path to your data.
train_data_dir = ...
val_data_dir = ...
train_target_data_dir = ...
val_target_data_dir = ...

# Fail early with a clear message if a placeholder was left as `...`,
# instead of passing `Ellipsis` on to the data module.
unset_data_paths = [
    name
    for name, value in {
        "train_data_dir": train_data_dir,
        "val_data_dir": val_data_dir,
        "train_target_data_dir": train_target_data_dir,
        "val_target_data_dir": val_target_data_dir,
    }.items()
    if value is ...
]
if unset_data_paths:
    raise ValueError(
        "Set the following data paths before running this cell: "
        + ", ".join(unset_data_paths)
    )

# Dataset: training and validation inputs paired with their targets.
data = CareamicsDataModule(
    data_config=data_config,
    train_data=train_data_dir,
    val_data=val_data_dir,
    train_data_target=train_target_data_dir,
    val_data_target=val_target_data_dir,
)


In [None]:
# Seed Python, NumPy and PyTorch RNGs (and dataloader workers) before the
# model is built, so weight initialization, random patch sampling and
# shuffling are reproducible across "Restart & Run All".
seed = 42
seed_everything(seed, workers=True)

# model
model = SegModule(algorithm_config=algorithm_config)

# Prediction writer callback (output directory `predict_output`). It is
# attached to the trainer so it is active when `trainer.predict` is called.
predict_writer = PredictionWriterCallback(dirpath=Path("predict_output"))

# create trainer; checkpoints are stored inside the experiment directory
experiment_dir = Path("experiment")
trainer = Trainer(
    max_epochs=10,  # change number of epochs as needed
    limit_train_batches=100,  # max training batches per epoch; change as needed
    default_root_dir=experiment_dir,
    callbacks=[
        ModelCheckpoint(
            dirpath=experiment_dir / "checkpoints",
            filename="test_seg",
        ),
        predict_writer,
    ],
)

# train
trainer.fit(model, datamodule=data)

In [None]:
# Path to the images to run prediction on.
# TODO: replace the `...` placeholder with the path to your data.
pred_data_dir = ...

# Fail early with a clear message if the placeholder was left as `...`,
# instead of passing `Ellipsis` on to the data module.
if pred_data_dir is ...:
    raise ValueError(
        "Set `pred_data_dir` to the images to predict on before running this cell."
    )

# Configure the prediction writer to save TIFF files, without tiling.
predict_writer.set_writing_strategy(write_type="tiff", tiled=False)

# Reuse the training data configuration, switched to prediction mode and
# with a single-process prediction dataloader.
pred_dataset_cfg = data_config.convert_mode(
    new_mode="predicting",
    new_dataloader_params={"num_workers": 0},
)

predict_data = CareamicsDataModule(
    data_config=pred_dataset_cfg,
    pred_data=pred_data_dir,
)

# predict: outputs are handled by the writer callback attached to the
# trainer, so they are not returned (`return_predictions=False`).
trainer.predict(model, datamodule=predict_data, return_predictions=False)


