In [None]:
from pathlib import Path
from typing import Union

import numpy as np
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset

from careamics_portfolio import PortfolioManager

from careamics import CAREamicsModule
from careamics.config.algorithm import AlgorithmModel
from careamics.config.data import DataModel
from careamics.dataset.in_memory_dataset import InMemoryDataset
from careamics.dataset.iterable_dataset import IterableDataset


In [None]:
# Algorithm settings: n2v algorithm and loss on a 2D UNet,
# optimised with Adam and a ReduceLROnPlateau learning-rate scheduler
config_dict = {
    "algorithm_type": "n2v",
    "model": {"architecture": "UNet", "is_3D": False, "parameters": {}},
    "loss": "n2v",
    "optimizer": {"name": "Adam"},
    "lr_scheduler": {"name": "ReduceLROnPlateau"},
}

# Building the AlgorithmModel from the raw dictionary validates the configuration
config = AlgorithmModel(**config_dict)

In [None]:
# instantiate the model from the validated algorithm configuration;
# this object is later handed to the Lightning Trainer for fitting and prediction
model = CAREamicsModule(config)

In [None]:
# create the Lightning trainer; max_epochs=1 limits training to a single epoch
trainer = Trainer(max_epochs=1)

### Possibility 1: Subclass CAREamics Dataset

In [None]:
# declare dataset using CAREamics (reader function passed as argument)
from careamics.dataset.dataset_utils import read_tiff

def read_czi(path: Union[str, Path]) -> np.ndarray:
    """Placeholder reader for CZI files.

    This is a template stub: it currently returns ``...`` (Ellipsis), not an
    array. Replace the body with code that loads the file at ``path`` before
    passing this function as ``read_source_func``.
    """
    return ...

class CZIDataset(InMemoryDataset):
    """In-memory CAREamics dataset that supplies its own file reader.

    NOTE(review): despite the class name, files are read with ``read_tiff``;
    presumably because the example data is TIFF and ``read_czi`` is still a
    stub - confirm before using real CZI data.
    """

    def __init__(self, path: Union[str, Path], config: DataModel) -> None:
        """Forward the data location and data configuration to the parent class."""
        super().__init__(
            data_path=path,
            data=config,
            read_source_func=read_tiff,
        )

# Data settings: in-memory TIFF data, 64x64 patches, axes "SYX"
data_config_dict = {
    "in_memory": True,
    "data_format": "tiff",  # file format of the input data
    "patch_size": [64, 64],
    "axes": "SYX",
    # Transforms are listed in order; any albumentations transform is accepted.
    "transforms": [
        {"name": "RandomRotate90"},
        {"name": "Flip"},
        {"name": "Normalize"},
        # The CAREamics N2V manipulation should be the last entry.
        {
            "name": "ManipulateN2V",
            "parameters": {"roi_size": 11, "masked_pixel_percentage": 0.198},
        },
    ],
}

# Constructing the DataModel validates the dictionary
data_config = DataModel(**data_config_dict)


In [None]:
# Show what the portfolio offers for denoising
portfolio = PortfolioManager()
print(portfolio.denoising)

# Fetch the N2V BSD68 dataset (download + unzip) into the relative "data" folder
root_path = Path("data")
files = portfolio.denoising.N2V_BSD68.download(root_path)
print(f"List of downloaded files: {files}")

# Locations of the training and validation splits inside the unzipped archive;
# root_path is already a Path, so the "/" operator yields a Path directly
data_path = root_path / "denoising-N2V_BSD68.unzip/BSD68_reproducibility_data"
train_path = data_path / "train"
val_path = data_path / "val"

# Make sure both split folders exist (no error if they are already present)
train_path.mkdir(parents=True, exist_ok=True)
val_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Wrap the training and validation datasets in PyTorch dataloaders

# training: batches of 64 items, loaded by 4 worker processes
train_czi_dataloader = DataLoader(
    dataset=CZIDataset(train_path, data_config),
    batch_size=64,
    num_workers=4,
)

# validation: one item per batch, loaded in the main process
val_czi_dataloader = DataLoader(
    dataset=CZIDataset(val_path, data_config),
    batch_size=1,
    num_workers=0,
)

### Possibility 2: Pass your own PyTorch Dataset

In [None]:
# write your own Dataset class, output must be SC(Z)YX
from careamics.transforms import ManipulateN2V

class CZIDataset(Dataset):
    """Template for a user-written dataset; the body is intentionally left as ``...``.

    Note: this reuses the name ``CZIDataset`` from Possibility 1, so running
    this cell replaces that earlier class.
    """

    # NOTE(review): the comment below mentions default_manipulate, while this
    # cell imports ManipulateN2V - confirm which one is meant.
    ... # call default_manipulate on your data

# placeholders: replace ``...`` with the locations of your training and validation data
path_to_train_data = ...
path_to_val_data = ...
# NOTE: the dataset arguments below are also ``...`` placeholders and the two
# path variables above are not used yet; fill them in before running.
# These assignments overwrite the dataloaders of the same names created under
# Possibility 1.
train_czi_dataloader = DataLoader(
    CZIDataset(...),
    batch_size=64,
    num_workers=4,
)
val_czi_dataloader = DataLoader(
    CZIDataset(...),
    batch_size=1,
    num_workers=0,
)

### Train and predict using Lightning

In [None]:
# train the model on the training dataloader, with the validation dataloader
# passed alongside it
trainer.fit(model, train_czi_dataloader, val_czi_dataloader)

In [None]:
# predict
# placeholder: replace ``...`` with the location of the data to predict on
path_to_pred_data = ...


class CZIPredictionDataset(IterableDataset):
    """Iterable CAREamics dataset for prediction that reads files with ``read_czi``."""

    def __init__(self, path: Union[str, Path], config: DataModel) -> None:
        """Forward the data location and data configuration to the parent class."""
        super().__init__(
            data_path=path,
            data=config,
            read_source_func=read_czi,
        )


# prediction dataloader: one item per batch, loaded in the main process
pred_czi_dataloader = DataLoader(
    dataset=CZIPredictionDataset(path_to_pred_data, data_config),
    batch_size=1,
    num_workers=0,
)

In [None]:
# run prediction with the trained model over the prediction dataloader
pred = trainer.predict(model, pred_czi_dataloader)

In [None]:
from careamics.prediction.prediction_utils import stitch_prediction


# stitch prediction
# NOTE(review): ``pred`` is passed exactly as returned by trainer.predict -
# confirm this matches the input stitch_prediction expects.
stitched_prediction = stitch_prediction(pred)