In [None]:
import albumentations as Aug
from pytorch_lightning import Trainer

from careamics import (
    CAREamicsModule,
    CAREamicsTrainDataModule,
)
from careamics.transforms import ManipulateN2V

In [None]:
# Instantiate the Lightning module, configured with algorithm="n2v".
# The comments in each parameter dict point to where the accepted keys
# are defined.
unet_parameters = {
    # parameters such as depth, n2v2, etc. See UNet definition.
}
adam_parameters = {
    "lr": 1e-4,
    # parameters from torch.optim
}
scheduler_parameters = {
    # parameters from torch.optim.lr_scheduler
}

model = CAREamicsModule(
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
    model_parameters=unet_parameters,
    optimizer="Adam",  # see SupportedOptimizer
    optimizer_parameters=adam_parameters,
    lr_scheduler="ReduceLROnPlateau",  # see SupportedLRScheduler
    lr_scheduler_parameters=scheduler_parameters,
)

In [None]:
# Create the Lightning trainer; training length is controlled by MAX_EPOCHS.
MAX_EPOCHS = 50
trainer = Trainer(max_epochs=MAX_EPOCHS)

### Possibility 1: Use CAREamics data module

In [None]:
# Define the function that reads one file of your custom data type.
# TODO: replace the body with a real reader for your format.


def read_my_data_type(file, *args, **kwargs):
    """Read a single file of a custom data type (placeholder).

    This is passed to ``CAREamicsTrainDataModule`` as ``read_source_func``
    in the next cell, and must be implemented before training.

    Parameters
    ----------
    file : path-like
        The file to read.
    *args, **kwargs
        Any extra arguments the data module passes to the reader.
        NOTE(review): accepted so the placeholder does not fail with a
        ``TypeError`` if the reader is called with more than the path
        (e.g. an axes string); confirm the exact ``read_source_func``
        call signature in the careamics documentation.

    Returns
    -------
    The loaded data. NOTE(review): presumably an array laid out according
    to the ``axes`` given to the data module; verify.

    Raises
    ------
    NotImplementedError
        Always, until the body is replaced. This makes a forgotten
        placeholder fail loudly instead of silently returning ``None``.
    """
    raise NotImplementedError(
        "read_my_data_type is a placeholder: implement it to load your "
        "custom data format before training."
    )


# Build the augmentation pipeline with albumentations; the careamics
# ManipulateN2V transform is kept as the final step.
# NOTE(review): `Aug.Normalize()` without arguments uses albumentations'
# defaults (ImageNet RGB mean/std, max_pixel_value=255) - confirm these suit
# your data, or pass explicit statistics.
pipeline_steps = [
    Aug.Flip(),
    Aug.RandomRotate90(),
    Aug.Normalize(),
    ManipulateN2V(),
]
transforms = Aug.Compose(pipeline_steps)

# Locations of the training and validation data (fill these in).
train_path = ...
val_path = ...

# Instantiate the CAREamics data module. With data_type="custom", a reader
# must be provided through `read_source_func`.
train_data_module = CAREamicsTrainDataModule(
    train_path=train_path,
    val_path=val_path,
    data_type="custom",
    read_source_func=read_my_data_type,  # function to read data
    axes="SYX",
    patch_size=(64, 64),
    batch_size=128,
    num_workers=4,
    transforms=transforms,
)

### Possibility 2: Use your own Lightning data module

In [None]:
# Write your own Lightning data module and assign an instance of it below.
# The placeholder only replaces `train_data_module` once it has been filled
# in, so running all cells keeps the data module from Possibility 1 instead
# of overwriting it with `...`.
custom_data_module = ...  # replace `...` with your LightningDataModule instance

if custom_data_module is not Ellipsis:
    train_data_module = custom_data_module

### Train and predict using Lightning

In [None]:
# Train the model using whichever data module `train_data_module` now holds.
trainer.fit(model=model, datamodule=train_data_module)

In [None]:
# Predict with the trained model. Fill in the placeholders:
# `pred_dataloader` is handed to `Trainer.predict` as `dataloaders`;
# `path_to_pred_data` is not used by the code below yet.
path_to_pred_data = ...
pred_dataloader = ...

pred = trainer.predict(model=model, dataloaders=pred_dataloader)