In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
from careamics_portfolio import PortfolioManager
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics.config.configuration_factories import create_n2v_configuration
from careamics.config.inference_model import InferenceConfig
from careamics.lightning.dataset_ng.lightning_module import N2VModule
from careamics.lightning.dataset_ng.predict_data_module import PredictDataModule
from careamics.lightning.dataset_ng.train_data_module import TrainDataModule
from careamics.utils.metrics import psnr, scale_invariant_psnr

In [None]:
# instantiate the data portfolio manager
portfolio = PortfolioManager()

# and download the BSD68 data into ./data
root_path = Path("./data")
files = portfolio.denoising.N2V_BSD68.download(root_path)

# create paths for the data (root_path is already a Path, no re-wrapping needed)
data_path = root_path / "denoising-N2V_BSD68.unzip" / "BSD68_reproducibility_data"
train_path = data_path / "train"
val_path = data_path / "val"
test_path = data_path / "test" / "images"
gt_path = data_path / "test" / "gt"

# collect the tiff files of each split, sorted so the order is deterministic
train_files = sorted(train_path.rglob("*.tiff"))
val_files = sorted(val_path.rglob("*.tiff"))
test_files = sorted(test_path.rglob("*.tiff"))

# first slice of the first training / validation stack, shown side by side
single_train_image = tifffile.imread(train_files[0])[0]
single_val_image = tifffile.imread(val_files[0])[0]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(single_train_image, cmap="gray")
ax[0].set_title("Training Image")
ax[1].imshow(single_val_image, cmap="gray")
ax[1].set_title("Validation Image");  # `;` suppresses the Text(...) repr

In [None]:
config = create_n2v_configuration(
    experiment_name="bsd68_n2v",
    data_type="tiff",
    axes="SYX",
    patch_size=(64, 64),
    batch_size=64,
    num_epochs=100,
)

# Normalization statistics are estimated over every training pixel. Using a
# single 2D slice would bias the mean/std toward that one image's intensities.
# Arrays are flattened first, so stacks with different spatial sizes still combine.
train_pixels = np.concatenate([tifffile.imread(f).ravel() for f in train_files])
config.data_config.set_means_and_stds([train_pixels.mean()], [train_pixels.std()])

model = N2VModule(config.algorithm_config)

data_module = TrainDataModule(
    data_config=config.data_config,
    train_data=train_files,
    val_data=val_files,
)

In [None]:
def show_patches(batch, title, n_patches=8):
    """Display the first ``n_patches`` patches of a dataloader batch in one row.

    Each patch is read as ``batch.data[0][i][0]`` for ``i`` in
    ``range(n_patches)`` and converted with ``.numpy()``.
    NOTE(review): confirm which batch/channel axes this indexing selects.

    Returns the created matplotlib figure.
    """
    # squeeze=False keeps a 2D axes array even when n_patches == 1
    fig, axes = plt.subplots(1, n_patches, figsize=(10, 5), squeeze=False)
    fig.suptitle(title)
    for i, ax in enumerate(axes[0]):
        ax.imshow(batch.data[0][i][0].numpy(), cmap="gray")
        ax.axis("off")
    return fig


train_batch = next(iter(data_module.train_dataloader()))
val_batch = next(iter(data_module.val_dataloader()))

show_patches(train_batch, "Training patches")
show_patches(val_batch, "Validation patches");

In [None]:
# Trainer / ModelCheckpoint are imported in the top import cell.
root = Path("bsd68_n2v")
callbacks = [
    ModelCheckpoint(
        dirpath=root / "checkpoints",
        filename="bsd68_lightning_api",
        save_last=True,
    )
]

# Create a Lightning Trainer.
# NOTE: max_epochs duplicates num_epochs passed to create_n2v_configuration;
# keep the two values in sync when changing either.
trainer = Trainer(max_epochs=100, default_root_dir=root, callbacks=callbacks)

# Train the model
trainer.fit(model, datamodule=data_module)

In [None]:

# Tiled prediction settings. Stored as `inference_config` rather than reusing
# `config`, so the training configuration is not overwritten and this cell can be
# re-run safely (otherwise `model_config=config` would receive an InferenceConfig).
inference_config = InferenceConfig(
    model_config=config,
    data_type="tiff",
    tile_size=(128, 128),
    tile_overlap=(32, 32),
    axes="YX",
    batch_size=1,
    # reuse the normalization statistics of the training dataset
    image_means=data_module.train_dataset.input_stats.means,
    image_stds=data_module.train_dataset.input_stats.stds,
)

inference_data_module = PredictDataModule(
    data_config=inference_config,
    pred_data=test_files,
)

In [None]:
# run tiled prediction with the weights currently held by `model`
predictions = trainer.predict(model=model, datamodule=inference_data_module)

In [None]:
def gather_data_samples(prediction):
    """Group consecutive predicted tiles by the input image they belong to.

    Parameters
    ----------
    prediction : iterable
        Tile predictions in the order returned by ``trainer.predict``. Each item
        must expose ``region_spec["data_idx"]``. Tiles of the same image are
        assumed to be contiguous.

    Returns
    -------
    list of list
        One list of tiles per input image, in order of appearance. An empty
        input yields an empty list.
    """
    result = []
    items = []
    current_sample_id = None
    for pred in prediction:
        sample_id = pred.region_spec["data_idx"]
        # `items` is checked first so the very first tile never triggers a flush
        # (and never produces an empty leading group).
        if items and sample_id != current_sample_id:
            result.append(items)
            items = []
        current_sample_id = sample_id
        items.append(pred)
    # Flush the final group; without this the last image would be dropped.
    if items:
        result.append(items)
    return result


# one list of tiles per test image
samples = gather_data_samples(prediction=predictions)

In [None]:
def stitch_sample(sample):
    """Reassemble one full-size prediction array from its tiles.

    For every tile, ``region_spec["crop_coords"]`` and ``region_spec["crop_size"]``
    select a region of the tile's ``data``. That region is written into the output
    at ``region_spec["stitch_coords"]`` with the same sizes. Leading axes are
    left untouched via ``...``.

    Returns a float32 array of shape ``sample[0].data_shape``.
    """
    stitched = np.zeros(sample[0].data_shape, dtype=np.float32)
    for tile in sample:
        spec = tile.region_spec
        sizes = spec["crop_size"]
        crop = (..., *(slice(c, c + s) for c, s in zip(spec["crop_coords"], sizes)))
        target = (..., *(slice(c, c + s) for c, s in zip(spec["stitch_coords"], sizes)))
        stitched[target] = tile.data[crop].astype(np.float32)
    return stitched


# keep the first entry of the two leading axes -> one array per test image
prediction = [stitch_sample(sample)[0][0] for sample in samples]



In [None]:
# Compare noisy input, prediction and ground truth for a few random test images.
# `noises` is read from `test_files`, the exact list the predictions were made on,
# so index i refers to the same image in `noises` and `prediction`.
noises = [tifffile.imread(f) for f in test_files]
gts = [tifffile.imread(f) for f in sorted(gt_path.glob("*.tiff"))]

N_SHOW = 3
SEED = 42  # fixed so the same images are shown on every run
rng = np.random.default_rng(SEED)
n_show = min(N_SHOW, len(noises))
# sample without replacement so no image appears twice
shown_indices = rng.choice(len(noises), size=n_show, replace=False)

# squeeze=False keeps `ax` 2D even if only one row is shown
fig, ax = plt.subplots(n_show, 3, figsize=(15, 5 * n_show), squeeze=False)

for row, idx in enumerate(shown_indices):
    gt = gts[idx]
    noisy = noises[idx]
    pred_image = prediction[idx].squeeze()
    data_range = gt.max() - gt.min()

    psnr_noisy = psnr(gt, noisy, data_range=data_range)
    psnr_result = psnr(gt, pred_image, data_range=data_range)
    scale_invariant_psnr_result = scale_invariant_psnr(gt, pred_image)

    ax[row, 0].imshow(noisy, cmap="gray")
    ax[row, 0].title.set_text(f"Noisy\nPSNR: {psnr_noisy:.2f}")

    ax[row, 1].imshow(pred_image, cmap="gray")
    ax[row, 1].title.set_text(f"Prediction\nPSNR: {psnr_result:.2f}\nScale invariant PSNR: {scale_invariant_psnr_result:.2f}")

    ax[row, 2].imshow(gt, cmap="gray")
    ax[row, 2].title.set_text("Ground-truth")

# lay out after content and titles exist so multi-line titles are accounted for
fig.tight_layout()


In [None]:
# `zip` would silently stop at the shorter list and report metrics on a subset;
# fail loudly instead if predictions and ground truths are misaligned in count.
assert len(prediction) == len(gts), (
    f"{len(prediction)} predictions for {len(gts)} ground-truth images"
)

psnrs = np.zeros(len(prediction))
scale_invariant_psnrs = np.zeros(len(prediction))

for i, (pred, gt) in enumerate(zip(prediction, gts)):
    pred_image = pred.squeeze()
    psnrs[i] = psnr(gt, pred_image, data_range=gt.max() - gt.min())
    scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred_image)

print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}")
print(f"Scale invariant PSNR: {scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}")
print("Reported PSNR: 27.71")