In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
from careamics_portfolio import PortfolioManager

from careamics.config.configuration_factories import (
    _create_ng_data_configuration,
    create_n2v_configuration,
)
from careamics.config.data import NGDataConfig
from careamics.lightning.callbacks import HyperParametersCallback
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule
from careamics.lightning.dataset_ng.lightning_modules import N2VModule

In [None]:
# Set seeds for reproducibility.
# `workers=True` additionally configures dataloaders used by the Trainer with a
# worker_init_fn so that dataloader worker processes are seeded reproducibly,
# which `seed_everything(seed)` alone does not do.
from pytorch_lightning import seed_everything

seed = 42
seed_everything(seed, workers=True)

### Download the data and set up the file paths

In [None]:
# Instantiate the data portfolio manager and download the BSD68 dataset
root_path = Path("./data")

portfolio = PortfolioManager()
portfolio.denoising.N2V_BSD68.download(root_path)

# Build paths into the unzipped dataset
data_path = root_path / "denoising-N2V_BSD68.unzip" / "BSD68_reproducibility_data"
train_path = data_path / "train"
val_path = data_path / "val"
test_path = data_path / "test" / "images"
gt_path = data_path / "test" / "gt"

# List the train, val and test files (sorted for a deterministic order)
train_files = sorted(train_path.rglob("*.tiff"))
val_files = sorted(val_path.rglob("*.tiff"))
test_files = sorted(test_path.rglob("*.tiff"))

# Fail early with a clear message if the expected directory layout is missing,
# instead of an IndexError in a later cell.
for split_name, split_files in (
    ("train", train_files),
    ("val", val_files),
    ("test", test_files),
):
    assert split_files, f"No .tiff files found for the {split_name} split under {data_path}"

### Visualize a single train and val image

In [None]:
# Load the first training and validation image (index 0 along the leading axis;
# the configuration below declares these stacks as "SYX") and show them side by side
single_train_image = tifffile.imread(train_files[0])[0]
single_val_image = tifffile.imread(val_files[0])[0]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for axis, image, title in zip(
    ax,
    (single_train_image, single_val_image),
    ("Training Image", "Validation Image"),
    strict=True,
):
    axis.imshow(image, cmap="gray")
    axis.set_title(title)
# Render the figure explicitly so the cell does not echo a matplotlib Text repr
plt.show()

### Create config

In [None]:
# NOTE: num_epochs is kept equal to the Trainer's max_epochs (50) in the training
# cell, so the configuration logged by HyperParametersCallback describes the run
# that actually happens.
config = create_n2v_configuration(
    experiment_name="bsd68_n2v",
    data_type="tiff",
    axes="SYX",
    patch_size=(64, 64),
    batch_size=64,
    num_epochs=50,
)

# TODO: until NGDataConfig is accepted by Configuration, the two configs are kept
# separate. The NG data config mirrors the fields of the legacy data config.
legacy_data_config = config.data_config
ng_data_config = _create_ng_data_configuration(
    data_type=legacy_data_config.data_type,
    axes=legacy_data_config.axes,
    patch_size=legacy_data_config.patch_size,
    batch_size=legacy_data_config.batch_size,
    augmentations=legacy_data_config.transforms,
    train_dataloader_params=legacy_data_config.train_dataloader_params,
    val_dataloader_params=legacy_data_config.val_dataloader_params,
    seed=seed,
)


### Create Lightning datamodule and model

In [None]:
# Lightning datamodule built from the NG data config and the train/val file lists
train_data_module = CareamicsDataModule(
    data_config=ng_data_config, train_data=train_files, val_data=val_files
)

# N2V Lightning module, built from the algorithm section of the configuration.
# Constructed after the datamodule, as before, so RNG consumption order is unchanged.
model = N2VModule(config.algorithm_config)

### Manually set up the datamodule and visualize a single train and val batch

In [None]:
# Set up the train and validation datasets by hand so batches can be inspected
# before training (the Trainer calls setup itself during fit).
train_data_module.setup("fit")
train_data_module.setup("validate")

train_batch = next(iter(train_data_module.train_dataloader()))
val_batch = next(iter(train_data_module.val_dataloader()))


def show_batch(batch, title: str, n_show: int = 8) -> None:
    """Plot channel 0 of the first ``n_show`` patches of ``batch`` in a single row.

    ``batch[0].data`` is indexed as ``[sample][channel]`` and converted with
    ``.numpy()``; presumably a (B, C, Y, X) tensor — TODO confirm.
    At most ``len(batch[0].data)`` patches are shown, so small batches do not fail.
    """
    n_show = min(n_show, len(batch[0].data))
    fig, axes = plt.subplots(1, n_show, figsize=(10, 2), squeeze=False)
    # One title for the whole row instead of a title on the first panel only
    fig.suptitle(title)
    for i, axis in enumerate(axes[0]):
        axis.imshow(batch[0].data[i][0].numpy(), cmap="gray")
    plt.show()


show_batch(train_batch, "Training Batch")
show_batch(val_batch, "Validation Batch")

### Train the model

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

root = Path("bsd68_n2v")

# Keep the checkpoint with the lowest validation loss, plus the last one
checkpoint_callback = ModelCheckpoint(
    dirpath=root / "checkpoints",
    filename="bsd68_new_lightning_module",
    save_last=True,
    monitor="val_loss",
    mode="min",
)
# HyperParametersCallback receives the full CAREamics configuration
# (presumably to log it with the run — confirm against its implementation)
callbacks = [checkpoint_callback, HyperParametersCallback(config)]

logger = WandbLogger(project="bsd68-n2v", name="bsd68_new_lightning_module")

trainer = Trainer(
    default_root_dir=root,
    max_epochs=50,
    logger=logger,
    callbacks=callbacks,
)
trainer.fit(model, datamodule=train_data_module)

### Create an inference config and datamodule

In [None]:
from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos
from careamics.prediction_utils import convert_outputs

# Use a distinct name: reassigning `config` here would shadow the training
# Configuration, and re-running the training cell afterwards would pass this
# NGDataConfig to HyperParametersCallback.
# Tiled prediction on 128x128 tiles with 32-pixel overlaps; normalization
# statistics are taken from the training dataset.
pred_data_config = NGDataConfig(
    data_type="tiff",
    patching={
        "name": "tiled",
        "patch_size": (128, 128),
        "overlaps": (32, 32),
    },
    axes="YX",
    batch_size=1,
    image_means=train_data_module.train_dataset.input_stats.means,
    image_stds=train_data_module.train_dataset.input_stats.stds,
)

inf_data_module = CareamicsDataModule(
    data_config=pred_data_config, pred_data=test_files
)

### Predict, convert the outputs to the legacy format and stitch the tiles

In [None]:
# Run prediction, then convert the per-tile outputs to the legacy tile-info
# format and stitch them back into full images. The raw Lightning output gets
# its own name so `predictions` always holds the stitched images.
raw_predictions = trainer.predict(model, datamodule=inf_data_module)
tile_infos = imageregions_to_tileinfos(raw_predictions)
predictions = convert_outputs(tile_infos, tiled=True)

### Visualize predictions and compute metrics

In [None]:
from careamics.utils.metrics import psnr, scale_invariant_psnr

# Read the noisy inputs from `test_files`, the exact list passed to prediction,
# so noises[k] and predictions[k] always refer to the same image.
noises = [tifffile.imread(f) for f in test_files]
gts = [tifffile.imread(f) for f in sorted(gt_path.glob("*.tiff"))]
assert len(gts) == len(noises), (
    f"Found {len(noises)} noisy test images but {len(gts)} ground-truth images"
)

images = [0, 1, 2]  # indices of the test images to display, one row each
fig, ax = plt.subplots(
    len(images), 3, figsize=(15, 5 * len(images)), squeeze=False
)

for row, idx in enumerate(images):
    gt = gts[idx]
    noisy = noises[idx]
    pred_image = predictions[idx].squeeze()
    # PSNR uses the dynamic range of the ground truth for both comparisons
    data_range = gt.max() - gt.min()

    psnr_noisy = psnr(gt, noisy, data_range=data_range)
    psnr_result = psnr(gt, pred_image, data_range=data_range)
    scale_invariant_psnr_result = scale_invariant_psnr(gt, pred_image)

    ax[row, 0].imshow(noisy, cmap="gray")
    ax[row, 0].set_title(f"Noisy\nPSNR: {psnr_noisy:.2f}")

    ax[row, 1].imshow(pred_image, cmap="gray")
    ax[row, 1].set_title(
        f"Prediction\nPSNR: {psnr_result:.2f}\n"
        f"Scale invariant PSNR: {scale_invariant_psnr_result:.2f}"
    )

    ax[row, 2].imshow(gt, cmap="gray")
    ax[row, 2].set_title("Ground-truth")

# tight_layout must run after the titles are set, otherwise it cannot make room
# for the multi-line titles.
fig.tight_layout()
plt.show()

In [None]:
# Per-image PSNR and scale-invariant PSNR over the whole test set
psnrs = np.zeros(len(predictions))
scale_invariant_psnrs = np.zeros(len(predictions))

# strict=True: if the numbers of predictions and ground-truth images differ,
# fail loudly instead of silently averaging over a truncated subset.
for i, (pred, gt) in enumerate(zip(predictions, gts, strict=True)):
    pred_image = pred.squeeze()
    psnrs[i] = psnr(gt, pred_image, data_range=gt.max() - gt.min())
    scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred_image)

print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}")
print(
    f"Scale invariant PSNR: "
    f"{scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}"
)
print("Reported PSNR: 27.71")