## Code for training the noise model

A pixel noise model is a collection of probability distributions that relate a true pixel intensity to the noisy measurements observed for it (and vice versa).
Such models have been used effectively for unsupervised denoising, which we aim to replicate in our setup. Given a noisy pixel intensity, they provide a distribution over possible clean signal values and their probabilities.
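
As a rough illustration of what such a model represents, the sketch below evaluates p(observation | signal) for a hypothetical two-component Gaussian mixture whose width grows with the signal. The function name `gmm_likelihood`, the weights, the offsets, and the noise parameters are assumptions chosen for illustration; this is not the CAREamics implementation.

```python
import numpy as np

def gmm_likelihood(observation, signal, weights=(0.7, 0.3), offsets=(0.0, 5.0),
                   gain=0.5, read_std=2.0):
    """Hypothetical p(observation | signal) as a mixture of Gaussians.

    Each component is centred at signal + offset. The standard deviation
    combines a signal-dependent term and a constant read-noise term.
    """
    observation = np.asarray(observation, dtype=float)
    std = np.sqrt(gain * signal + read_std**2)
    density = np.zeros_like(observation)
    for weight, offset in zip(weights, offsets):
        mean = signal + offset
        density += weight * np.exp(-0.5 * ((observation - mean) / std) ** 2) / (
            std * np.sqrt(2 * np.pi)
        )
    return density

# Distribution of noisy observations for a clean signal value of 100
x = np.linspace(0.0, 200.0, 2001)
p = gmm_likelihood(x, signal=100.0)
area = float(np.sum(p) * (x[1] - x[0]))  # numerical integral of the density
```

The numerical integral is a quick sanity check that the mixture weights sum to one and the density is normalized.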

Physical Availability of Microscope
When the microscope is available, the same content is imaged $N$ times to obtain $N$ noisy versions of the specimen. A high-SNR version is generated by averaging the noisy samples pixel-wise, creating data pairs of clean signal values and $N$ noisy intensities. This data is used to train a Gaussian-mixture noise model. To simulate this, noise (Gaussian and/or Poisson) is repeatedly added to intensity values in the range [0, 65535].
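
The simulated version of this procedure can be sketched with NumPy as follows. The image size, the number of repeats, and the noise levels are assumed values for illustration only, not the settings used for the actual noise models.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical clean image with intensities inside the 16-bit range
clean = rng.uniform(100.0, 1000.0, size=(32, 32))

n_repeats = 50       # N acquisitions of the same content
gaussian_std = 20.0  # assumed read-noise level

# Poisson (shot) noise followed by additive Gaussian (read) noise
noisy = rng.poisson(clean, size=(n_repeats, *clean.shape)).astype(float)
noisy += rng.normal(0.0, gaussian_std, size=noisy.shape)
noisy = np.clip(noisy, 0, 65535)

# Pixel-wise average over the N acquisitions gives a high-SNR estimate
high_snr = noisy.mean(axis=0)

# Pair every noisy intensity with the high-SNR value of the same pixel
signal = np.broadcast_to(high_snr, noisy.shape).ravel()
observation = noisy.ravel()
```

Each pixel contributes N pairs, so even a small image yields many training samples for the noise model.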

Physical Unavailability of Microscope
When the microscope is unavailable, as with publicly available datasets, DivNoising proposes a bootstrap noise model. Noisy data is denoised using unsupervised techniques, and the noisy/denoised pairs are used to train the noise model. To simulate this, noise is added to clean training data to produce noisy-clean pairs.

To address errors introduced by the denoising process, an additional method is tested: noise is added to clean data, and an off-the-shelf denoiser like N2V is used. The resulting noisy-denoised pairs are then used to train the noise model.
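
A minimal sketch of how such noisy/denoised pairs can be assembled is shown below. A simple mean filter stands in for the off-the-shelf denoiser (N2V in this notebook); `box_denoise`, the flat test image, and the noise level are hypothetical and only serve to make the example self-contained.

```python
import numpy as np

def box_denoise(image, radius=1):
    """Stand-in denoiser: mean filter over a (2 * radius + 1)^2 window."""
    padded = np.pad(image, radius, mode="reflect")
    out = np.zeros_like(image, dtype=float)
    size = 2 * radius + 1
    for dy in range(size):
        for dx in range(size):
            out += padded[dy:dy + image.shape[0], dx:dx + image.shape[1]]
    return out / size**2

rng = np.random.default_rng(1)
clean = np.full((64, 64), 500.0)                    # hypothetical flat specimen
noisy = clean + rng.normal(0.0, 30.0, clean.shape)  # noise added to clean data
denoised = box_denoise(noisy)                       # output of the stand-in denoiser

# One row per pixel: (noisy value, denoised value)
pairs = np.stack([noisy.ravel(), denoised.ravel()], axis=1)
```

Because the denoiser is imperfect, the denoised values still carry residual error, which is the effect this additional method is meant to expose.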

### Important!

This step can be skipped! Only run this notebook if you want to train the noise models from scratch. Pretrained noise models are available.

In [None]:
import os
import pooch
import matplotlib.pyplot as plt
from careamics import CAREamist
from careamics.models.lvae.noise_models import GaussianMixtureNoiseModel, create_histogram
from careamics.lvae_training.dataset import DataSplitType
from careamics.config import GaussianMixtureNMConfig
from careamics.config import create_n2v_configuration

from microsplit_reproducibility.configs.data.HT_LIF24 import get_data_configs
from microsplit_reproducibility.datasets.HT_LIF24 import get_train_val_data
from microsplit_reproducibility.utils.utils import plot_probability_distribution



### Load and prepare data

Data preparation is dataset specific. Here we download the HT_LIF24 data and load it with the helper functions provided in this repository.



In [None]:
DATA = pooch.create(
    path="./data",
    base_url="https://download.fht.org/jug/ht_lif24",
    registry={"ht_lif24.zip": None},
)

In [None]:
fp = DATA.fetch("ht_lif24.zip", processor=pooch.Unzip())

In [None]:
train_data_config, val_data_config, test_data_config = get_data_configs(dset_type="20ms", channel_idx_list=[1, 2, 3])

In [None]:
input_data = get_train_val_data(data_config=train_data_config,
                                datadir=DATA.path / "ht_lif24.zip.unzip/ht_lif24",
                                datasplit_type=DataSplitType.Train,
                                val_fraction=0.1,
                                test_fraction=0.1,)

In [None]:
# We use a subset of the data for training the noise model
train_data = input_data[0:-1:10].squeeze()
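
The effect of the stride slice used above can be checked on a small dummy array. The stack shape below is an assumption for illustration; the real data has its own dimensions.

```python
import numpy as np

# Hypothetical stack with axes (S, Y, X, C): 95 frames, 3 channels
stack = np.zeros((95, 8, 8, 3))

# Keep every tenth frame; the stop index -1 excludes the last frame
subset = stack[0:-1:10]

# squeeze() only removes axes of length 1, so the shape is unchanged here
shape = subset.squeeze().shape
```

Note that `squeeze()` would also drop the sample or channel axis if either had length 1, which would no longer match the `"SYXC"` axes passed to the N2V configuration.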

### Create N2V configuration

In [None]:
config = create_n2v_configuration(
    experiment_name="ht_lif24",
    data_type="array",
    axes="SYXC",
    n_channels=3,
    patch_size=(64, 64),
    batch_size=64,
    num_epochs=1,
)

print(config)

### Train N2V 

In [None]:
# Check whether noise models for all channels already exist (this dataset has 3 channels)
os.path.exists("noise_models") and len(os.listdir("noise_models")) == 3
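
The same check can be wrapped in a small helper so that it is reusable for datasets with a different number of channels. The helper name `noise_models_ready` and the file names in the usage example are hypothetical; the sketch only assumes one saved file per channel, as in the expression above.

```python
import os
import tempfile

def noise_models_ready(directory="noise_models", n_channels=3):
    """Return True if `directory` exists and holds one entry per channel."""
    return os.path.isdir(directory) and len(os.listdir(directory)) == n_channels

# Usage with a temporary directory standing in for "noise_models"
with tempfile.TemporaryDirectory() as tmp:
    empty = noise_models_ready(tmp)
    for idx in range(3):
        # Placeholder files with made-up names
        open(os.path.join(tmp, f"channel_{idx}.npz"), "w").close()
    complete = noise_models_ready(tmp)
```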

In [None]:
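if not os.path.exists("noise_models"):
    # Instantiate a CAREamist from the N2V configuration
    careamist = CAREamist(source=config)

    # Train N2V on the noisy data
    careamist.train(
        train_source=train_data,
        val_minimum_split=5,
    )
else:
    # The cells below need a trained `careamist` and can be skipped in this case
    print("Noise models already exist; skipping N2V training.")

In [None]:
# Denoise the training data with the trained N2V model
predicition = careamist.predict(train_data, tile_size=(256, 256))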

### Visualize N2V predictions

In [None]:
# Compare the first noisy sample with its N2V prediction for two channels
_, ax = plt.subplots(2, 2, figsize=(20, 30))
ax[0][0].imshow(train_data[0, ..., 0])
ax[0][0].set_title("Input channel 1")
ax[0][1].imshow(predicition[0].squeeze()[0])
ax[0][1].set_title("Denoised channel 1")
ax[1][0].imshow(train_data[0, ..., 1])
ax[1][0].set_title("Input channel 2")
ax[1][1].imshow(predicition[0].squeeze()[1])
ax[1][1].set_title("Denoised channel 2")
plt.show()

### Train the noise model and visualize the results

Here we train a noise model using the denoised images acquired from the N2V model.

We train a separate noise model for each channel.


In [None]:
for channel_idx in range(train_data.shape[-1]):
    channel_data = train_data[..., channel_idx]
    denoised_channel = predicition[0].squeeze()[channel_idx]
    print(f"Training noise model for channel {channel_idx}")

    # The intensity range of this channel bounds the noise model
    noise_model_config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        min_signal=channel_data.min(),
        max_signal=channel_data.max(),
    )
    noise_model = GaussianMixtureNoiseModel(noise_model_config)
    noise_model.fit(signal=channel_data, observation=denoised_channel, n_epochs=100)  # TODO change n_epochs
    noise_model.save(path="noise_models", name=f"noise_model_pavia_p24_channel_{channel_idx}")

    # Plot the fitted model against the empirical histogram
    histogram = create_histogram(
        bins=100, min_val=channel_data.min(), max_val=channel_data.max(),
        signal=channel_data, observation=denoised_channel,
    )
    plot_probability_distribution(noise_model, signalBinIndex=50, histogram=histogram[0], channel=channel_idx)
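
To give a feel for the kind of signal/observation histogram that the fitted model is plotted against, the sketch below builds a row-normalized 2D histogram with plain NumPy. The synthetic data and the normalization are assumptions for illustration; this is not the implementation of `create_histogram`.

```python
import numpy as np

rng = np.random.default_rng(2)

# Synthetic signal values and noisy observations (assumed Gaussian noise)
signal = rng.uniform(0.0, 1000.0, size=100_000)
observation = signal + rng.normal(0.0, 25.0, signal.shape)

bins = 100
counts, signal_edges, obs_edges = np.histogram2d(
    signal, observation, bins=bins, range=[[0.0, 1000.0], [0.0, 1000.0]]
)

# Normalize each signal bin (row) so it approximates p(observation | signal)
row_sums = counts.sum(axis=1, keepdims=True)
conditional = counts / np.maximum(row_sums, 1)

# Empirical distribution of observations for the middle signal bin
middle_row = conditional[50]
```

Picking row 50 of a 100-bin histogram mirrors the `signalBinIndex=50` argument used in the plotting call: it selects the signal bin in the middle of the intensity range.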

    