The BSD68 dataset was adapted from K. Zhang et al. (TIP, 2017) and is composed of natural
images. The noise was artificially added, allowing for quantitative comparisons with the
ground truth; the dataset is one of the benchmarks used in many denoising publications.
Here, we check the performance of Hierarchical DivNoising (HDN) using the Lightning API of
CAREamics.

This API gives you more freedom to customize the training by using wrappers around the
main elements of CAREamics: the datasets and the Lightning module.

In [None]:
# Imports necessary to execute the code
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
from PIL import Image
from careamics.lightning import (
    create_predict_datamodule,
    create_train_datamodule,
)
from careamics.lightning.lightning_module import VAEModule
from careamics.config import create_hdn_configuration
from careamics.config.support import SupportedTransform
from careamics.prediction_utils import convert_outputs
from careamics.utils.metrics import psnr
from careamics_portfolio import PortfolioManager
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

## Import the dataset

The dataset can be directly downloaded using the `careamics-portfolio` package, which
uses `pooch` to download the data.

In [None]:
# Instantiate the data portfolio manager
portfolio = PortfolioManager()

# Download the BSD68 data into ./data
root_path = Path("./data")
files = portfolio.denoising.N2V_BSD68.download(root_path)

# Build the paths to the unzipped data. `root_path / ...` is already a Path,
# so it does not need to be wrapped in another Path(...) call.
data_path = root_path / "denoising-N2V_BSD68.unzip" / "BSD68_reproducibility_data"
train_path = data_path / "train"
val_path = data_path / "val"
test_path = data_path / "test" / "images"  # noisy test images
gt_path = data_path / "test" / "gt"  # matching ground-truth images

## Visualize data

In [None]:
# Read the first plane (index 0 along the leading axis) of one training file
# and one validation file, then display the two side by side.
train_file = next(train_path.rglob("*.tiff"))
val_file = next(val_path.rglob("*.tiff"))
single_train_image = tifffile.imread(train_file)[0]
single_val_image = tifffile.imread(val_file)[0]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (single_train_image, "Training Image"),
    (single_val_image, "Validation Image"),
)
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image, cmap="gray")
    axis.set_title(title)

## Train with the CAREamics Lightning API

Using the Lightning API of CAREamics, you need to instantiate the Lightning module, the
data module and the trainer yourself.

### Create the Lightning module

In [None]:
# loss_config = get_loss_config(**experiment_params)
# model_config = get_model_config(**experiment_params)
# gaussian_lik_config, noise_model_config, nm_lik_config = get_likelihood_config(
#     **experiment_params
# )
# training_config = get_training_config(**experiment_params)

# # setting up learning rate scheduler and optimizer (using default parameters)
# lr_scheduler_config = get_lr_scheduler_config(**experiment_params)
# optimizer_config = get_optimizer_config(**experiment_params)

In [None]:
# HDN settings for BSD68: TIFF data with a sample axis (SYX), 128x128 patches,
# four latent levels of 32 dimensions each, and no logger.
hdn_kwargs = dict(
    experiment_name="bsd68_hdn",
    data_type="tiff",
    axes="SYX",
    z_dims=[32] * 4,
    patch_size=(128, 128),
    batch_size=64,
    num_epochs=5,
    predict_logvar="pixelwise",
    train_dataloader_params={"num_workers": 4},
    val_dataloader_params={"num_workers": 4},
    logger=None,
)
config = create_hdn_configuration(**hdn_kwargs)

print(config)


### Create the data module

In [None]:
# Build the Lightning module (a VAE module, as used by HDN) from the
# algorithm section of the configuration.
algorithm_config = config.algorithm_config
model = VAEModule(algorithm_config)

In [None]:
# All data settings come from the configuration. The only explicit choice here
# is an empty transform list (presumably meaning no augmentation is applied).
data_cfg = config.data_config
train_data_module = create_train_datamodule(
    train_data=train_path,
    val_data=val_path,
    data_type=data_cfg.data_type,
    patch_size=data_cfg.patch_size,
    axes=data_cfg.axes,
    batch_size=data_cfg.batch_size,
    transforms=[],
    train_dataloader_params=data_cfg.train_dataloader_params,
    val_dataloader_params=data_cfg.val_dataloader_params,
)

### Create the trainer

Note that here we modify the prediction loop, but this will be changed in the near
future.

In [None]:
# Checkpoints are written to <root>/checkpoints; save_last=True also keeps a
# "last" checkpoint.
root = Path("bsd68_n2v")
checkpoint_callback = ModelCheckpoint(
    dirpath=root / "checkpoints",
    filename="bsd68_lightning_api",
    save_last=True,
)
callbacks = [checkpoint_callback]

# Lightning Trainer; the number of epochs comes from the configuration
trainer = Trainer(
    max_epochs=config.training_config.num_epochs,
    default_root_dir=root,
    callbacks=callbacks,
)

# Train the model
trainer.fit(model, datamodule=train_data_module)

## Predict with CAREamics Lightning API

### Define the prediction datamodule

In [None]:
# Reuse the statistics computed on the training data so that the test images
# are normalized the same way as the training patches.
means, stds = train_data_module.get_data_statistics()
pred_data_module = create_predict_datamodule(
    pred_data=test_path,
    # Read the data type from the configuration, as the training datamodule
    # does, so the two cannot drift apart.
    data_type=config.data_config.data_type,
    axes="YX",  # test files are read as single 2D images (no S axis)
    batch_size=1,
    tta_transforms=False,  # test-time augmentation: not implemented
    image_means=means,
    image_stds=stds,
    tile_size=(128, 128),
    tile_overlap=(32, 32),
)

### Predict

In [None]:
# Run inference over the (tiled) test set
raw_outputs = trainer.predict(model, datamodule=pred_data_module)

# Turn the raw outputs back into the original image format; this step mostly
# matters when tiling is used
prediction = convert_outputs(raw_outputs, tiled=True)

### Visualize the prediction

In [None]:
# Load the noisy test images and their ground truth, sorted by file name.
# NOTE(review): this assumes the prediction datamodule reads the test files in
# the same sorted order, so that prediction[i] pairs with noises[i] and gts[i].
noises = [tifffile.imread(f) for f in sorted(test_path.glob("*.tiff"))]
gts = [tifffile.imread(f) for f in sorted(gt_path.glob("*.tiff"))]

# Show up to three distinct images. np.random.choice samples with replacement
# by default, which could show the same image twice; replace=False prevents it.
n_show = min(3, len(noises))
images = np.random.choice(len(noises), size=n_show, replace=False)

# squeeze=False keeps `ax` two-dimensional even when n_show == 1
fig, ax = plt.subplots(n_show, 3, figsize=(15, 5 * n_show), squeeze=False)

for row, idx in enumerate(images):
    noisy = noises[idx]
    gt = gts[idx]
    pred_image = prediction[idx].squeeze()

    # data_range must describe the reference (ground-truth) signal. Deriving it
    # from the image being scored makes the two PSNRs incomparable and lets the
    # prediction's own intensity scale inflate its score.
    data_range = float(gt.max()) - float(gt.min())
    psnr_noisy = psnr(gt, noisy, data_range=data_range)
    psnr_result = psnr(gt, pred_image, data_range=data_range)

    ax[row, 0].imshow(noisy, cmap="gray")
    ax[row, 0].title.set_text(f"Noisy\nPSNR: {psnr_noisy:.2f}")

    ax[row, 1].imshow(pred_image, cmap="gray")
    ax[row, 1].title.set_text(f"Prediction\nPSNR: {psnr_result:.2f}")

    ax[row, 2].imshow(gt, cmap="gray")
    ax[row, 2].title.set_text("Ground-truth")

# Lay out the figure only after every subplot and title exists
fig.tight_layout()

### Compute metrics

In [None]:
# Score every prediction against its ground truth. The ground-truth range is
# used as PSNR data_range, because it describes the reference signal rather
# than the image being scored.
# zip() would silently drop unpaired images, so check the counts first.
if len(prediction) != len(gts):
    raise ValueError(
        f"Got {len(prediction)} predictions but {len(gts)} ground-truth images."
    )

psnrs = np.array(
    [
        psnr(gt, pred.squeeze(), data_range=float(gt.max()) - float(gt.min()))
        for pred, gt in zip(prediction, gts)
    ]
)

print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}")
print("Reported PSNR: 27.71")

## Create cover

In [None]:
def _min_max_normalize(image):
    """Rescale ``image`` linearly to [0, 1]; a constant image maps to zeros."""
    image = np.asarray(image, dtype=np.float64)
    value_range = image.max() - image.min()
    if value_range == 0:  # avoid dividing by zero on a flat image
        return np.zeros_like(image)
    return (image - image.min()) / value_range


# Create a cover image: the left half shows the noisy input and the right half
# the prediction, both cropped around the center of the same test image.
im_idx = 3
half = 128  # half of the cover's side length, in pixels
cover_size = 2 * half  # defined from `half` so the two halves always fit

cv_image_noisy = noises[im_idx]
cv_image_pred = prediction[im_idx].squeeze()

(height, width) = cv_image_noisy.shape
# The centered crop only needs the image to be at least cover_size wide/high.
# Raise instead of assert so the check is not stripped under `python -O`.
if height < cover_size or width < cover_size:
    raise ValueError(
        f"Image {im_idx} is {height}x{width}, smaller than the "
        f"{cover_size}x{cover_size} cover."
    )

# normalize the noisy image and the prediction independently to [0, 1]
norm_noise = _min_max_normalize(cv_image_noisy)
norm_pred = _min_max_normalize(cv_image_pred)

# fill in halves, split at the vertical center line of the image
rows = slice(height // 2 - half, height // 2 + half)
cover = np.zeros((cover_size, cover_size))
cover[:, :half] = norm_noise[rows, width // 2 - half:width // 2]
cover[:, half:] = norm_pred[rows, width // 2:width // 2 + half]

# plot the single image
plt.imshow(cover, cmap="gray")

# Save the image. Converting to uint8 explicitly avoids relying on PIL's
# implicit float ("F") to 8-bit ("L") conversion.
im = Image.fromarray(np.round(cover * 255).astype(np.uint8))
im = im.convert('L')
im.save("BSD68_Noise2Void_lightning_api.jpeg")