In [14]:
from pathlib import Path
from typing import Union

import numpy as np
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset

from careamics.lightning import CAREamicsModel
from careamics.config.algorithm import Algorithm
from careamics.config.data import Data


In [6]:
# N2V algorithm settings, assembled from small named pieces for readability
unet_settings = {"architecture": "UNet", "is_3D": False, "parameters": {}}
optimizer_settings = {"name": "Adam"}
scheduler_settings = {"name": "ReduceLROnPlateau"}

config_dict = {
    "algorithm_type": "n2v",
    "model": unet_settings,
    "loss": "n2v",
    "optimizer": optimizer_settings,
    "lr_scheduler": scheduler_settings,
}

# building the Algorithm object from the dictionary validates the configuration
config = Algorithm(**config_dict)

In [8]:
# instantiate the Lightning model from the validated algorithm configuration
model = CAREamicsModel(config)

In [9]:
# create a Lightning trainer limited to a single epoch (max_epochs=1)
trainer = Trainer(max_epochs=1)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/joran.deschamps/miniconda3/envs/careamics/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


### Possibility 1: Subclass CAREamics Dataset

In [None]:
# declare dataset using CAREamics (reader function passed as argument)
def read_czi(path) -> np.ndarray:
    """Read the CZI file located at ``path``.

    Placeholder: the body currently returns ``...`` (Ellipsis) and has to be
    replaced by an actual reader that returns a ``np.ndarray``, as the return
    annotation announces.
    """
    return ...

class CZIDataset(CAREamicsDataset):
    """Dataset that hands ``read_czi`` to its base class as the reader function.

    NOTE(review): ``CAREamicsDataset`` is not imported in the visible cells;
    confirm where it should be imported from before running this cell.
    """
 
    def __init__(self, path: Union[str, Path], config: Data) -> None:
        """Forward ``path`` and ``config`` to the base class with ``read_func=read_czi``."""
        super().__init__(path, config, read_func=read_czi)

# transform entries, listed in the order they appear in the configuration
augmentations = [
    # any albumentations transform accepted
    {"name": "RandomRotate90"},
    {"name": "RandomFlip"},
]
# CAREamics N2V transform should come last
n2v_manipulation = {
    "name": "ManipulateN2V",
    "parameters": {"roi_size": 11, "masked_pixel_percentage": 0.198},
}

# create data configuration
data_config_dict = {
    "in_memory": True,
    "dataloader": "Custom",  # as opposed to Zarr or Tif
    "patch_size": [68, 68],
    "axes": "YX",
    "transforms": [*augmentations, n2v_manipulation],
}

# validate the configuration
data_config = Data(**data_config_dict)

# instantiate dataloaders
path_to_train_data = ...  # placeholder: location of the training data
path_to_val_data = ...  # placeholder: location of the validation data

# Pass the validated `data_config` object rather than the raw dictionary:
# `CZIDataset.__init__` declares its `config` parameter as `Data`.
train_czi_dataloader = DataLoader(
    CZIDataset(path_to_train_data, data_config),
    batch_size=64,
    num_workers=4,
)
val_czi_dataloader = DataLoader(
    CZIDataset(path_to_val_data, data_config),
    batch_size=1,
    num_workers=0,
)

### Possibility 2: Write your own CAREamics-compatible Dataset

In [None]:
# write your own Dataset class, output must be SC(Z)YX
from careamics.augmentation import ManipulateN2V

class CZIDataset(Dataset):
    """Placeholder for a user-written dataset; the body is still to be implemented.

    NOTE(review): this reuses the name ``CZIDataset`` from the "Possibility 1"
    cell, so running both cells leaves only this definition bound to the name.
    """

    ... # call default_manipulate on your data
    # NOTE(review): the cell imports `ManipulateN2V` while the comment above
    # mentions `default_manipulate` -- confirm which one is intended.

# placeholders: fill in the data locations and the dataset arguments
path_to_train_data = ...
path_to_val_data = ...

# training loader
train_dataset = CZIDataset(...)
train_czi_dataloader = DataLoader(train_dataset, batch_size=64, num_workers=4)

# validation loader
val_dataset = CZIDataset(...)
val_czi_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=0)

### Train and predict using Lightning

In [None]:
# train the model on the training dataloader, with the validation dataloader
# passed as the third argument
trainer.fit(model, train_czi_dataloader, val_czi_dataloader)

In [None]:
# predict: wrap the prediction dataset in a dataloader and run the trainer on it
path_to_pred_data = ...
pred_dataset = CZIDataset(...)
pred_czi_dataloader = DataLoader(pred_dataset)
pred = trainer.predict(model, pred_czi_dataloader)