In [1]:
from pathlib import Path
from typing import Union

import numpy as np
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset

from careamics_portfolio import PortfolioManager

from careamics import CAREamicsModule
from careamics.config.data import DataModel
from careamics.dataset.in_memory_dataset import InMemoryDataset
from careamics.dataset.iterable_dataset import IterableDataset




In [7]:
# Instantiate the Lightning module.
# The empty dictionaries below are left as placeholders; the comments inside them
# point to where the accepted parameters are defined.
model = CAREamicsModule(
    algorithm="n2v",
    loss="n2v", 
    architecture="UNet",
    model_parameters={
        # parameters such as depth, n2v2, etc. See UNet definition.
    },
    optimizer="Adam", # see SupportedOptimizer
    optimizer_parameters={
        "lr": 1e-4,
        # parameters from torch.optim
    },
    lr_scheduler="ReduceLROnPlateau", # see SupportedLRScheduler
    lr_scheduler_parameters={
        # parameters from torch.optim.lr_scheduler
    }
)

In [8]:
# Create the Lightning trainer; training is capped at a single epoch
trainer = Trainer(max_epochs=1)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/joran.deschamps/miniconda3/envs/careamics/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


### Possibility 1: Subclass CAREamics Dataset

In [None]:
# declare dataset using CAREamics (reader function passed as argument)
from careamics.dataset.dataset_utils import read_tiff

def read_czi(path) -> np.ndarray:
    """Read a CZI file into a NumPy array (placeholder to be implemented).

    This notebook does not provide a CZI reader: replace the body of this
    function with your own implementation before using it.

    Parameters
    ----------
    path
        Location of the file to read.

    Returns
    -------
    np.ndarray
        Image data loaded from `path`, once the function is implemented.

    Raises
    ------
    NotImplementedError
        Always, until a real reader is provided.
    """
    # Fail loudly instead of returning the `...` placeholder (Ellipsis), which is
    # not an array and would only surface later as an obscure error.
    raise NotImplementedError(
        f"read_czi is a placeholder: implement a CZI reader to load {path}."
    )

class CZIDataset(InMemoryDataset):
    """CAREamics in-memory dataset with a reader function passed explicitly.

    NOTE(review): despite the class name, the reader handed to the parent class
    is `read_tiff`, not `read_czi` -- presumably because `read_czi` is still a
    placeholder; confirm.
    """

    def __init__(self, path: Union[str, Path], config: DataModel) -> None:
        """Forward the data location, configuration and reader to the parent class.

        Parameters
        ----------
        path : Union[str, Path]
            Passed to `InMemoryDataset` as `data_path`.
        config : DataModel
            Passed to `InMemoryDataset` as `data`.
        """
        super().__init__(data_path=path, data=config, read_source_func=read_tiff)

# Data configuration, written as a plain dictionary
data_config_dict = {
    "in_memory": True,
    "data_format": "tiff",  # the only format working at the moment
    "patch_size": [64, 64],
    "axes": "SYX",
    # Any albumentations transform is accepted; the CAREamics N2V manipulation
    # should be the last entry of the list.
    "transforms": [
        {"name": "RandomRotate90"},
        {"name": "Flip"},
        {"name": "Normalize"},
        {
            "name": "ManipulateN2V",
            "parameters": {"roi_size": 11, "masked_pixel_percentage": 0.198},
        },
    ],
}
# Build the data model from the dictionary; the resulting object is the
# configuration handed to the datasets in the cells below.
data_config = DataModel(**data_config_dict) # validate the configuration


In [10]:
# Example data is fetched through the CAREamics portfolio
portfolio = PortfolioManager()
root_path = Path("data")
files = portfolio.denoising.N2V_BSD68.download(root_path)
print(f"List of downloaded files: {files}")

# Locations of the `train` and `val` sub-folders inside the download folder
data_path = root_path / "denoising-N2V_BSD68.unzip" / "BSD68_reproducibility_data"
train_path, val_path = data_path / "train", data_path / "val"

List of downloaded files: ['/Users/joran.deschamps/git/careamics/careamics/examples/data/denoising-N2V_BSD68.unzip/BSD68_reproducibility_data/test/bsd68_gaussian25.npy', '/Users/joran.deschamps/git/careamics/careamics/examples/data/denoising-N2V_BSD68.unzip/BSD68_reproducibility_data/test/bsd68_groundtruth.npy', '/Users/joran.deschamps/git/careamics/careamics/examples/data/denoising-N2V_BSD68.unzip/BSD68_reproducibility_data/train/DCNN400_train_gaussian25.npy', '/Users/joran.deschamps/git/careamics/careamics/examples/data/denoising-N2V_BSD68.unzip/BSD68_reproducibility_data/val/DCNN400_validation_gaussian25.npy']


In [None]:
# Instantiate the dataloaders: one dataset per folder, both sharing `data_config`.

# Training: batches of 64, loaded by 4 worker processes
train_czi_dataloader = DataLoader(
    CZIDataset(train_path, data_config),
    batch_size=64,
    num_workers=4,
)
# Validation: one sample per batch, loaded in the main process (num_workers=0)
val_czi_dataloader = DataLoader(
    CZIDataset(val_path, data_config),
    batch_size=1,
    num_workers=0,
)

### Possibility 2: Pass your own Dataset

In [None]:
# write your own Dataset class, output must be SC(Z)YX
from careamics.transforms import ManipulateN2V

class CZIDataset(Dataset):
    """Template for a user-defined PyTorch dataset; the implementation is left out.

    NOTE(review): this reuses the name `CZIDataset` from "Possibility 1" and
    shadows that class if both cells are run in the same kernel.
    """

    ... # call default_manipulate on your data
    # NOTE(review): `default_manipulate` is not imported anywhere in this notebook,
    # only `ManipulateN2V` is -- confirm which one is meant.

# Placeholders: replace `...` with the locations of your own data
path_to_train_data = ...
path_to_val_data = ...
# NOTE(review): the datasets below are built from `...` rather than from the two
# paths above -- presumably the paths are meant to be passed in; confirm.

# Training: batches of 64, loaded by 4 worker processes
train_czi_dataloader = DataLoader(
    CZIDataset(...),
    batch_size=64,
    num_workers=4,
)
# Validation: one sample per batch, loaded in the main process (num_workers=0)
val_czi_dataloader = DataLoader(
    CZIDataset(...),
    batch_size=1,
    num_workers=0,
)

### Train and predict using Lightning

In [None]:
# Train the model; the first dataloader is used for training, the second for validation
trainer.fit(model, train_czi_dataloader, val_czi_dataloader)

In [None]:
# predict
# Placeholder: replace `...` with the location of the data to predict on
path_to_pred_data = ...


class CZIPredictionDataset(IterableDataset):
    """CAREamics iterable dataset that reads its source files with `read_czi`."""

    def __init__(self, path: Union[str, Path], config: DataModel) -> None:
        """Forward the data location, configuration and reader to the parent class.

        Parameters
        ----------
        path : Union[str, Path]
            Passed to `IterableDataset` as `data_path`.
        config : DataModel
            Passed to `IterableDataset` as `data`.
        """
        super().__init__(data_path=path, data=config, read_source_func=read_czi)


# Prediction dataloader: one sample per batch, loaded in the main process.
# NOTE(review): `data_config` is the configuration built for training, whose
# transforms include RandomRotate90, Flip and ManipulateN2V -- confirm that these
# are intended at prediction time.
pred_czi_dataloader = DataLoader(
    CZIPredictionDataset(path_to_pred_data, data_config),
    batch_size=1,
    num_workers=0,)

In [None]:
# Run prediction with the model over the prediction dataloader
pred = trainer.predict(model, pred_czi_dataloader)