In [None]:
# Imports necessary to execute the code
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pooch
import tifffile
import torch
from careamics import CAREamist
from careamics.config import GaussianMixtureNMConfig, create_microsplit_configuration, create_n2v_configuration
from careamics.lightning import (
    create_microsplit_predict_datamodule,
    create_microsplit_train_datamodule,
)
from careamics.lightning.callbacks import DataStatsCallback
from careamics.lightning.lightning_module import VAEModule
from careamics.lvae_training.dataset import DataSplitType
from careamics.lvae_training.eval_utils import get_device
from careamics.models.lvae.noise_models import (
    GaussianMixtureNoiseModel,
    create_histogram,
)
from careamics.prediction_utils import convert_outputs_microsplit
from careamics.utils.metrics import psnr
from PIL import Image
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from utils import get_train_val_data

## Import the dataset

The dataset is downloaded directly with `pooch`. (It will eventually be fetched through the
`careamics-portfolio` package, which also uses `pooch` under the hood.)

In [None]:
# TODO replace with PortfolioManager
# Register the reduced HT-LIF24 archive with pooch. No hash is pinned (`None`),
# so the downloaded file is not checksum-verified.
DATA = pooch.create(
    path="./data/",
    base_url="https://download.fht.org/jug/msplit/ht_lif24/data_tiff/",
    registry={"ht_lif24_5ms_reduced.zip": None},
)

# Fetch every registered archive and unzip it next to the downloaded file.
for archive_name in DATA.registry:
    DATA.fetch(archive_name, processor=pooch.Unzip(), progressbar=True)

# The extracted TIFF files live under "<archive>.unzip/5ms/data/".
unzipped_subdir = DATA.registry_files[0] + ".unzip/5ms/data/"
DATA_PATH = DATA.abspath / unzipped_subdir

In [None]:
# Load the training split (val_fraction / test_fraction of 0.1 each are held out).
# The same array is used to train N2V and then to fit the per-channel noise models.
input_data = get_train_val_data(
    datadir=DATA_PATH,
    datasplit_type=DataSplitType.Train,
    val_fraction=0.1,
    test_fraction=0.1,
)
nm_input = input_data

In [None]:
# Directory holding the N2V run and the per-channel noise models
# (files named noise_model_Ch{i}.npz). Create it up front so the cells that
# write into it do not depend on another cell having created it first.
NM_PATH = Path("./noise_models/")
NM_PATH.mkdir(parents=True, exist_ok=True)

## Visualize data

In [None]:
# First frame of each of the two channels, side by side.
_, ax = plt.subplots(1, 2, figsize=(10, 5))
for ch in range(2):
    ax[ch].imshow(nm_input[0, ..., ch])
    ax[ch].set_title(f"Input channel {ch + 1}")
plt.show()

## Train with the CAREamics Lightning API

With the CAREamics Lightning API, you instantiate the Lightning module, the data module
and the trainer yourself.

## Create Noise Models

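Note that for this step we use the high-level CAREamics API (`CAREamist`): we train Noise2Void on the data and use its denoised output to fit one noise model per channel.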

### Configure N2V


In [None]:
# Noise2Void configuration, used only to obtain denoised images for noise-model fitting.
# The channel count is read from the data (axes "SYXC", channels last) instead of
# being hard-coded, so the configuration cannot drift from the array it trains on.
config = create_n2v_configuration(
    experiment_name="my_data_noise_models_n2v",
    data_type="array",
    axes="SYXC",
    n_channels=nm_input.shape[-1],
    patch_size=(64, 64),
    batch_size=64,
    num_epochs=5,  # We set the training to 5 epochs, but you can change this to a higher number if you want a better Noise Model.
)

print("N2V configuration generated.")

### Train N2V on the data we prepared
This may take a while, especially if you increased `num_epochs` above or do not have a fast GPU.

In [None]:
# Train N2V with the high-level API. Its working directory is the noise-model
# directory NM_PATH rather than a second hard-coded copy of that path.
careamist = CAREamist(source=config, work_dir=NM_PATH)
# NOTE(review): val_minimum_split presumably sets the minimum validation split size — confirm.
careamist.train(train_source=nm_input, val_minimum_split=5)

### Denoise loaded data with the N2V model we just trained

In [None]:
# Denoise the training array with the trained N2V model, predicting in 256x256 tiles.
prediction = careamist.predict(
    nm_input,
    tile_size=(256, 256),
)

In [None]:
# Make your choice. If 'False', the entire image will be shown...
do_crop = True

# Slice bounds: `None` means "to the end of the axis", so the uncropped view really
# shows the entire image (a bound of -1 would drop the last row and column).
xfrom = yfrom = 0
xto = yto = None
strcrop = ""
if do_crop:
    strcrop = " (crop)"
    yfrom, yto = 200, 600
    xfrom, xto = 800, 1200
crop = np.s_[yfrom:yto, xfrom:xto]

# One row per channel: noisy input on the left, N2V prediction on the right.
_, ax = plt.subplots(2, 2, figsize=(10, 10))
for ch in range(2):
    ax[ch][0].imshow(nm_input[0, ..., ch][crop])
    ax[ch][0].set_title(f"Input channel {ch + 1}" + strcrop)
    ax[ch][1].imshow(prediction[0].squeeze()[ch][crop])
    ax[ch][1].set_title(f"Denoised channel {ch + 1}" + strcrop)
plt.show()

### Train the Noise Model

In [None]:
# Stack the per-image N2V outputs once; the loop is over channels only.
# Indexing with [:, channel_idx] below assumes each prediction is (1, C, Y, X) — TODO confirm.
denoised = np.concatenate(prediction)

for channel_idx in range(nm_input.shape[-1]):

    # train Noise Model for current channel
    print(f"Training noise model for channel {channel_idx}")
    # The noise model describes observation given signal: the *signal* is the clean
    # estimate (the N2V prediction) and the *observation* is the raw noisy data.
    channel_observation = nm_input[..., channel_idx]
    channel_signal = denoised[:, channel_idx]
    noise_model_config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        # the signal range bounds the values the model is evaluated on
        min_signal=channel_signal.min(),
        max_signal=channel_signal.max(),
        n_coeff=3,
        n_gaussian=3,
    )
    noise_model = GaussianMixtureNoiseModel(noise_model_config)
    noise_model.fit(signal=channel_signal, observation=channel_observation, n_epochs=1000)

    # save result on disk for later re-use (same directory as NM_PATH)
    noise_model.save(path="noise_models/", name=f"noise_model_Ch{channel_idx}")

    # show the result: 2D histogram of observed intensities against signal intensities
    histogram = create_histogram(
        bins=100,
        min_val=channel_observation.min(),
        max_val=channel_observation.max(),
        signal=channel_signal,
        observation=channel_observation,
    )
    # NOTE(review): assumes histogram[0] holds the normalised 2D counts — confirm against create_histogram.
    _, hist_ax = plt.subplots(figsize=(5, 5))
    hist_ax.imshow(histogram[0])
    hist_ax.set_title(f"Signal/observation histogram, channel {channel_idx}")
    plt.show()

# TODO NM preparation needs refactoring

### Create the MicroSplit configuration

In [None]:
# One noise model per output channel, as written by the noise-model cell into NM_PATH.
n_output_channels = 2

config = create_microsplit_configuration(
    experiment_name="ht_lif24_5ms_reduced",
    data_type="tiff",  # TODO: after refactoring should mean data extension. Originally referred to Dataset (e.g. ht_lif24)
    axes="SYX",
    z_dims=[32] * 4,
    patch_size=(64, 64),
    grid_size=32,
    output_channels=n_output_channels,
    multiscale_count=3,
    batch_size=64,
    num_epochs=10,
    predict_logvar="pixelwise",
    # Absolute paths derived from NM_PATH, so the notebook runs on any machine
    # (no path into a particular user's home directory).
    nm_paths=[
        str((NM_PATH / f"noise_model_Ch{ch}.npz").resolve())
        for ch in range(n_output_channels)
    ],
    # TODO path to be changed after NM refactoring
    train_dataloader_params={"num_workers": 0},
    val_dataloader_params={"num_workers": 0},
    logger=None,
)

print(config)

### Instantiate the Lightning module

In [None]:
# Build the MicroSplit VAE Lightning module from the algorithm part of the config.
model = VAEModule(
    config.algorithm_config,
)

#### Optional: Load model from checkpoint

In [None]:
# TODO: move to somewhere
def load_pretrained_model(model: VAEModule, ckpt_path, strict: bool = False):
    """Load the ``state_dict`` stored in a checkpoint file into ``model`` in place.

    Parameters
    ----------
    model : VAEModule
        Model whose weights are overwritten.
    ckpt_path : str or Path
        Checkpoint file. It must contain a ``"state_dict"`` entry and is read with
        ``torch.load(..., weights_only=True)`` onto the device from ``get_device()``.
    strict : bool, default False
        Forwarded to ``model.load_state_dict``. With ``False``, mismatching keys
        are tolerated but reported rather than ignored silently.

    Returns
    -------
    The value returned by ``model.load_state_dict``, which lists the
    ``missing_keys`` and ``unexpected_keys``.
    """
    device = get_device()
    ckpt_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
    load_result = model.load_state_dict(ckpt_dict["state_dict"], strict=strict)
    # Non-strict loading skips mismatching keys; surface them so a wrong checkpoint is noticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: {len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected keys when loading {ckpt_path}"
        )
    print(f"Loaded model from {ckpt_path}")
    return load_result

In [None]:
# load_pretrained_model(model, ckpt_path) # TODO optional, temporary

In [None]:
# Training/validation data module. Apart from the data path and the (empty) transforms,
# every setting is read from `config.data_config` so it matches the model config.
data_cfg = config.data_config
train_data_module = create_microsplit_train_datamodule(
    train_data=DATA_PATH,
    data_type=data_cfg.data_type,
    axes=data_cfg.axes,
    patch_size=data_cfg.image_size,  # TODO, it's not patch size because of ugly duplication
    grid_size=data_cfg.grid_size,
    multiscale_count=data_cfg.multiscale_lowres_count,  # TODO, same, because of ugly duplication
    batch_size=data_cfg.batch_size,  # TODO, should be inside dataloader params?
    transforms=[],
    train_dataloader_params=data_cfg.train_dataloader_params,
    val_dataloader_params=data_cfg.val_dataloader_params,
)

### Create the trainer

Note that we modify the prediction loop here; this will change in the near
future.

In [None]:
# Checkpoints and Lightning logs are written under ht_lif24/.
root = Path("ht_lif24")

# Keep the named checkpoint and additionally the last epoch's checkpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath=root / "checkpoints",
    filename="ht_lif24_lightning_api",
    save_last=True,
)
callbacks = [checkpoint_callback, DataStatsCallback()]

# Lightning Trainer; the epoch budget comes from the MicroSplit training config.
max_epochs = config.training_config.lightning_trainer_config["max_epochs"]
trainer = Trainer(
    max_epochs=max_epochs,
    default_root_dir=root,
    callbacks=callbacks,
)


In [None]:
# Train the MicroSplit model on the training data module.
trainer.fit(model=model, datamodule=train_data_module)

## Predict with CAREamics Lightning API

### Define the prediction datamodule

In [None]:
# TODO ugly that we need to initialize train_data_module to get data stats
# Statistics computed on the training data; they are reused by the prediction data module.
(data_stats, max_val) = train_data_module.get_data_stats()


In [None]:
# Prediction data module over DATA_PATH. Data type, batch size and grid size are
# read from the same config the training data module used, so they cannot drift;
# the training statistics (data_stats, max_val) are reused for normalisation.
pred_data_module = create_microsplit_predict_datamodule(
    pred_data=DATA_PATH,
    data_type=config.data_config.data_type,  # TODO see train dm
    axes="YX",
    batch_size=config.data_config.batch_size,
    multiscale_count=config.data_config.multiscale_lowres_count,
    data_stats=data_stats,
    max_val=max_val,  # TODO should be in the config?
    tile_size=(64, 64),  # same as patch_size in the MicroSplit config
    grid_size=config.data_config.grid_size,  # TODO rename to overlap
)

### Predict

In [None]:
# Run inference with the trained model over the prediction data module.
predicted_tiles = trainer.predict(model=model, datamodule=pred_data_module)


In [None]:
# Quick inspection of the raw prediction output: shape of item 1 of the first batch output.
# NOTE(review): what index 1 holds depends on VAEModule's predict step, which is not visible here — confirm.
predicted_tiles[0][1].shape

In [None]:
# Quick look at a single predicted tile from the first batch output.
# NOTE(review): the meaning of each index level is assumed, not shown here — confirm against the predict step.
plt.imshow(predicted_tiles[0][0][1][0])

In [None]:
# Convert the tiled outputs back to the original image format (mostly useful when
# tiling is used); the second returned value is not needed here.
pred_dataset = pred_data_module.predict_dataset
predictions, _ = convert_outputs_microsplit(predicted_tiles, pred_dataset)

In [None]:
# Zoom into a 500x500 region of frame 10, predicted channel index 1
# (predictions assumed to be indexed as (frame, Y, X, channel)).
region = np.s_[10, 200:700, 200:700, 1]
plt.imshow(predictions[region])

### Visualize the prediction

In [None]:
# Show a few randomly chosen frames of the MicroSplit output, one column per
# predicted channel. This cell does not read `test_path` / `gt_path`, which this
# notebook never defines (they came from the BSD68 Noise2Void example).
# NOTE(review): assumes `predictions` is (frame, Y, X, channel) — confirm.
rng = np.random.default_rng(seed=42)  # seeded so the same frames are shown on re-run
n_rows = min(3, len(predictions))
frame_ids = rng.choice(len(predictions), size=n_rows, replace=False)  # no duplicate rows
n_pred_channels = predictions.shape[-1]

fig, ax = plt.subplots(
    n_rows, n_pred_channels, figsize=(5 * n_pred_channels, 5 * n_rows), squeeze=False
)
for row, frame_idx in enumerate(frame_ids):
    for ch in range(n_pred_channels):
        ax[row, ch].imshow(predictions[frame_idx, ..., ch], cmap="gray")
        ax[row, ch].set_title(f"Frame {frame_idx}: predicted channel {ch + 1}")
fig.tight_layout()
plt.show()

### Compute metrics

In [None]:
# Per-image PSNR of the predictions against the ground truth.
# NOTE(review): `gts` and the reference value printed below come from the BSD68
# Noise2Void example and are not defined for this dataset — provide matching ground truth.
if len(prediction) != len(gts):
    # zip() would silently truncate and leave unscored images out of the statistics
    raise ValueError(
        f"Got {len(prediction)} predictions but {len(gts)} ground-truth images."
    )
psnrs = np.array(
    [psnr(gt, pred.squeeze(), pred.max()) for pred, gt in zip(prediction, gts)]
)

print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}")
print("Reported PSNR: 27.71")

## Create cover

In [None]:
# create a cover image: left half from the noisy image, right half from the
# prediction, both centre-cropped to COVER_SIZE x COVER_SIZE and scaled to [0, 1].
# NOTE(review): `noises` / `prediction` and the output file name come from the BSD68 example.
COVER_SIZE = 256
half = COVER_SIZE // 2

im_idx = 3
cv_image_noisy = noises[im_idx]
cv_image_pred = prediction[im_idx].squeeze()

(height, width) = cv_image_noisy.shape
# A centre crop of COVER_SIZE pixels only needs the image to be at least that large.
assert height >= COVER_SIZE, f"image height {height} is smaller than {COVER_SIZE}"
assert width >= COVER_SIZE, f"image width {width} is smaller than {COVER_SIZE}"


def _to_unit_range(image):
    """Min-max scale ``image`` to [0, 1]; a constant image maps to all zeros."""
    lo, hi = image.min(), image.max()
    if hi == lo:  # avoid 0/0 = NaN, which has no valid grey level
        return np.zeros_like(image, dtype=float)
    return (image - lo) / (hi - lo)


# normalize train and prediction
norm_noise = _to_unit_range(cv_image_noisy)
norm_pred = _to_unit_range(cv_image_pred)

# fill in halves
rows = slice(height // 2 - half, height // 2 + half)
cover = np.zeros((COVER_SIZE, COVER_SIZE))
cover[:, :half] = norm_noise[rows, width // 2 - half : width // 2]
cover[:, half:] = norm_pred[rows, width // 2 : width // 2 + half]

# plot the single image
plt.imshow(cover, cmap="gray")

# save the image
im = Image.fromarray(cover * 255)
im = im.convert("L")
im.save("BSD68_Noise2Void_lightning_api.jpeg")