In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import zarr

from careamics.config.configuration_factories import (
    _create_ng_data_configuration,
    create_n2v_configuration,
)
from careamics.config.data import NGDataConfig
from careamics.lightning.callbacks import HyperParametersCallback
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule
from careamics.lightning.dataset_ng.lightning_modules import N2VModule

In [None]:
# Seed the global RNGs up front via Lightning for reproducibility; the same
# `seed` value is also handed to the NG data configuration further down.
from pytorch_lightning import seed_everything

seed = 42
seed_everything(seed=seed)

In [None]:
import platform

import torch

# Report whether this is an ARM (Apple-silicon style) processor on which
# PyTorch's MPS backend is also available; short-circuits on other CPUs.
is_arm_processor = platform.processor() in ("arm", "arm64")
mps_usable = is_arm_processor and torch.backends.mps.is_available()
print(mps_usable)

## Create zarr dataset

In [None]:
# One-off conversion of the BSD68 TIFF files into `data/bsd68.zarr`, the store
# opened by the next cell. Uncomment and run once to build it.
# NOTE: this code uses `PortfolioManager` and `tifffile`, neither of which is
# imported in this notebook; add those imports before uncommenting.

# # instantiate data portfolio manager and download the data
# root_path = Path("data")

# portfolio = PortfolioManager()
# files = portfolio.denoising.N2V_BSD68.download(root_path)

# # create paths for the data
# data_path = root_path / "denoising-N2V_BSD68.unzip/BSD68_reproducibility_data"
# train_path = data_path / "train" / "DCNN400_train_gaussian25.tiff"
# val_path = data_path / "val" / "DCNN400_validation_gaussian25.tiff"

# root = zarr.create_group(root_path / "bsd68.zarr")
# print(f"Creating zarr at: {root.store_path}")

# # add train images to train group (one 2D array per slice of the stack)
# train = root.create_group('train')
# train_img = tifffile.imread(train_path)

# for i in range(train_img.shape[0]):

#     img = train_img[i]
#     name = f"img_{i:04d}"
#     train.create_array(name=name, data=img, chunks=(128, 128))

# # add validation images to val group
# val = root.create_group('val')
# val_img = tifffile.imread(val_path)

# for i in range(val_img.shape[0]):

#     img = val_img[i]
#     name = f"img_{i:04d}"
#     val.create_array(name=name, data=img, chunks=(128, 128))


# # add test gt to zarr (array names are the TIFF file stems)
# test_gt_f = sorted([f for f in (data_path / "test" / "gt").glob("*.tiff")])
# test_gt_z = root.create_group('test_gt')
# for i in test_gt_f:
#     img = tifffile.imread(i)
#     name = i.stem
#     test_gt_z.create_array(name=name, data=img, chunks=(128, 128))


# # add noisy test images to zarr (array names are the TIFF file stems)
# test_raw_f = sorted([f for f in (data_path / "test" / "images").glob("*.tiff")])
# test_raw_z = root.create_group('test_raw')
# for i in test_raw_f:
#     img = tifffile.imread(i)
#     name = i.stem
#     test_raw_z.create_array(name=name, data=img, chunks=(128, 128))

In [None]:
# Open the BSD68 zarr store produced by the conversion cell above.
# zarr.open defaults to append mode, which would silently create an empty store
# if the path is missing (and then fail later with an opaque KeyError); check
# for the store explicitly and open it read-only since it is never written here.
zarr_path = Path("data") / "bsd68.zarr"
if not zarr_path.exists():
    raise FileNotFoundError(
        f"{zarr_path} not found; build it with the zarr creation cell first."
    )
g = zarr.open(zarr_path, mode="r")

# String store paths of the train/val groups, passed to the datamodule.
train_path = str(g["train"].store_path)
val_path = str(g["val"].store_path)

### Visualize a single train and val image

In [None]:
# Pick the first array listed in each split and show train vs. val side by side.
list_train_arrays = list(g["train"].array_keys())
list_val_arrays = list(g["val"].array_keys())
single_train_image = g["train"][list_train_arrays[0]]
single_val_image = g["val"][list_val_arrays[0]]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (single_train_image, "Training Image"),
    (single_val_image, "Validation Image"),
)
for panel_ax, (panel_image, panel_title) in zip(ax, panels):
    panel_ax.imshow(panel_image, cmap="gray")
    panel_ax.set_title(panel_title)

### Create config

In [None]:
# Training budget; deliberately tiny values (presumably for a quick demo run).
epochs = 1
steps = 50
batch_size = 1

# Legacy N2V configuration: its algorithm section builds the model, and its
# data section is mirrored into the NG data configuration below.
config = create_n2v_configuration(
    experiment_name="bsd68_n2v",
    data_type="custom",
    axes="YX",
    patch_size=(64, 64),
    batch_size=batch_size,
    num_epochs=epochs,
    num_steps=steps,
)

# TODO until the NGDataConfig is accepted by the Configuration, these are separate
legacy_data_config = config.data_config
ng_data_config = _create_ng_data_configuration(
    data_type="zarr",  # specific to NG Dataset
    axes=legacy_data_config.axes,
    patch_size=legacy_data_config.patch_size,
    batch_size=legacy_data_config.batch_size,
    augmentations=legacy_data_config.transforms,
    train_dataloader_params=legacy_data_config.train_dataloader_params,
    val_dataloader_params=legacy_data_config.val_dataloader_params,
    seed=seed,
)

# Fixed normalization statistics of mean 0 / std 1.
# NOTE(review): confirm these placeholders are intended rather than statistics
# computed from the training images.
ng_data_config.set_means_and_stds(image_means=[0], image_stds=[1])


### Create Lightning datamodule and model

In [None]:
# NG datamodule over the zarr train/val groups, driven by the NG data config.
datamodule_kwargs = {
    "data_config": ng_data_config,
    "train_data": [train_path],
    "val_data": [val_path],
}
train_data_module = CareamicsDataModule(**datamodule_kwargs)

# N2V Lightning module built from the algorithm section of the N2V config.
algorithm_config = config.algorithm_config
model = N2VModule(algorithm_config)

### Manually initialize the datamodule and visualize single train and val batches

In [None]:
# Set up the datamodule manually so one batch of each split can be inspected
# before training.
train_data_module.setup("fit")
train_data_module.setup("validate")

tdm = train_data_module.train_dataloader()
vdm = train_data_module.val_dataloader()
train_batch = next(iter(tdm))
val_batch = next(iter(vdm))


def _show_batch(batch, title):
    """Plot the first (presumably channel) slice of every sample in a batch.

    Parameters
    ----------
    batch
        A batch from the datamodule; ``batch.data`` is indexed as
        ``data[sample][0]`` and converted with ``.numpy()``
        (assumes CPU tensors — TODO confirm).
    title
        Title placed above the first panel.
    """
    n_samples = len(batch.data)
    # squeeze=False always yields a 2D array of Axes. Without it a single
    # sample (batch_size == 1) returns a bare Axes and indexing raises TypeError.
    fig, axes = plt.subplots(1, n_samples, figsize=(10, 5), squeeze=False)
    axes[0, 0].set_title(title)
    for i in range(n_samples):
        axes[0, i].imshow(batch.data[i][0].numpy(), cmap="gray")


_show_batch(train_batch, "Training Batch")
_show_batch(val_batch, "Validation Batch")

### Train the model

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Root folder for this experiment; checkpoints go into its "checkpoints" subdir.
root = Path("bsd68_n2v")

# Track the checkpoint with the lowest validation loss and also save the last one.
checkpoint_callback = ModelCheckpoint(
    dirpath=root / "checkpoints",
    filename="bsd68_new_lightning_module",
    save_last=True,
    monitor="val_loss",
    mode="min",
)
callbacks = [checkpoint_callback, HyperParametersCallback(config)]

# `limit_train_batches=steps` caps the number of training batches per epoch.
trainer = Trainer(
    max_epochs=epochs,
    default_root_dir=root,
    callbacks=callbacks,
    limit_train_batches=steps,
)
trainer.fit(model, datamodule=train_data_module)

### Create an inference config and datamodule

In [None]:
from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos
from careamics.prediction_utils import convert_outputs

test_files = [str(g["test_raw"].store_path)]

# Prediction data config. It uses its own name so the training `config`
# (the N2V Configuration used by the model and HyperParametersCallback cells)
# is not overwritten; rebinding it broke re-running those cells.
pred_data_config = NGDataConfig(
    data_type="zarr",
    # Tiled prediction: 128x128 tiles overlapping by 48 px in Y and X.
    patching={
        "name": "tiled",
        "patch_size": (128, 128),
        "overlaps": (48, 48),
    },
    axes="YX",
    batch_size=1,
    # Reuse the normalization statistics of the training dataset.
    image_means=train_data_module.train_dataset.input_stats.means,
    image_stds=train_data_module.train_dataset.input_stats.stds,
)

inf_data_module = CareamicsDataModule(
    data_config=pred_data_config, pred_data=test_files
)

### Run prediction, convert outputs to the legacy format and stitch the tiles

In [None]:
# Run tiled prediction, convert the raw outputs to legacy tile infos, then
# stitch them (tiled=True) into one prediction per test image.
predicted_regions = trainer.predict(model, datamodule=inf_data_module)
tile_infos = imageregions_to_tileinfos(predicted_regions)
predictions = convert_outputs(tile_infos, tiled=True)

### Visualize predictions and compute metrics

In [None]:
from careamics.utils.metrics import psnr, scale_invariant_psnr

# Noisy inputs and ground truths are paired by position after sorting names.
# NOTE(review): assumes this sorted order also matches the order of
# `predictions` — confirm against how the prediction datamodule lists arrays.
noises_str = sorted(g["test_raw"].array_keys())
gts_str = sorted(g["test_gt"].array_keys())

noises = [g["test_raw"][arr] for arr in noises_str]
gts = [g["test_gt"][arr] for arr in gts_str]

# Indices of test images to display; one row each:
# noisy input | prediction | ground truth.
images = [0, 1, 2]
fig, ax = plt.subplots(
    len(images), 3, figsize=(15, 5 * len(images)), squeeze=False
)

for row, idx in enumerate(images):
    # Read each zarr array once and reuse it for both metrics and display.
    gts_arrs = gts[idx][...]
    noises_arrs = noises[idx][...]
    pred_image = predictions[idx].squeeze()

    data_range = gts_arrs.max() - gts_arrs.min()
    psnr_noisy = psnr(gts_arrs, noises_arrs, data_range=data_range)
    psnr_result = psnr(gts_arrs, pred_image, data_range=data_range)
    scale_invariant_psnr_result = scale_invariant_psnr(gts_arrs, pred_image)

    ax[row, 0].imshow(noises_arrs, cmap="gray")
    ax[row, 0].title.set_text(f"Noisy\nPSNR: {psnr_noisy:.2f}")

    ax[row, 1].imshow(pred_image, cmap="gray")
    ax[row, 1].title.set_text(
        f"Prediction\nPSNR: {psnr_result:.2f}\n"
        f"Scale invariant PSNR: {scale_invariant_psnr_result:.2f}"
    )

    ax[row, 2].imshow(gts_arrs, cmap="gray")
    ax[row, 2].title.set_text("Ground-truth")

# Lay out only after the images and multi-line titles exist; tight_layout
# measures the artists present at call time, so calling it earlier leaves the
# titles overlapping the row above.
fig.tight_layout()

In [None]:
# PSNR and scale-invariant PSNR over every test prediction.
# strict=True: with strict=False a length mismatch between predictions and
# ground truths would silently drop unmatched items (or leave zero entries in
# the preallocated arrays), skewing the mean and std reported below.
n_images = len(predictions)
psnrs = np.zeros(n_images)
scale_invariant_psnrs = np.zeros(n_images)

for i, (pred, gt) in enumerate(zip(predictions, gts, strict=True)):
    gt_arr = gt[...]  # materialize the zarr array as a NumPy array
    pred_arr = pred.squeeze()
    psnrs[i] = psnr(gt_arr, pred_arr, data_range=gt_arr.max() - gt_arr.min())
    scale_invariant_psnrs[i] = scale_invariant_psnr(gt_arr, pred_arr)

print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}")
print(
    f"Scale invariant PSNR: "
    f"{scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}"
)
print("Reported PSNR: 27.71")