## Code for training the noise model

TODO: describe the theory behind noise models (NMs) and why one is needed here.

In [None]:
import pooch
import matplotlib.pyplot as plt
from careamics import CAREamist
from careamics.models.lvae.noise_models import GaussianMixtureNoiseModel, create_histogram
from careamics.config import GaussianMixtureNMConfig
from careamics.config import create_n2v_configuration

from microsplit_reproducibility.configs.data.HT_LIF24 import get_data_configs
from microsplit_reproducibility.datasets.HT_LIF24 import get_train_val_data
from microsplit_reproducibility.utils.utils import plot_probability_distribution

### Load and prepare data

Prepare the data. This step is dataset-specific: either prepare the data manually or load it with a helper function from this repository.



In [None]:
# Temporary, machine-specific directory holding the dataset; it is passed as
# the `path` of the pooch object created in the next cell.
tmp_local_path = "/localscratch/data/pavia3_sequential_cropped"

In [None]:
# Wrap the local data directory in a pooch object so that later cells can read
# the location from ``DATA.path``. The base URL and registry are left empty
# (they appear to be placeholders until a real download source is configured).
# TODO: the data should be downloaded and stored locally, e.g. with
#       path=pooch.os_cache("microsplit_reproducibility_pavia_p24")
empty_registry = {"": ""}
DATA = pooch.create(path=tmp_local_path, base_url="", registry=empty_registry)

In [None]:
# Fetch the data configurations (train, validation, test) defined for the
# HT_LIF24 dataset; only the training one is used in the cells shown here.
train_data_config, val_data_config, test_data_config = get_data_configs()

In [None]:
# Load the data described by the training configuration from the local data
# directory, holding out 10% for validation.
# TODO: replace ``DATA.path`` with the actual local path after downloading.
input_data = get_train_val_data(
    datadir=DATA.path,
    data_config=train_data_config,
    val_fraction=0.1,
)

In [None]:
# Take the first entry of the dataset's `_data` attribute and drop its
# singleton dimensions.
# NOTE(review): this reaches into a private attribute; the result is presumably
# a (Y, X, C) array, matching axes="YXC" in the N2V configuration — confirm.
train_data = input_data._data[0].squeeze()

### Create N2V configuration

In [None]:
# Settings for the Noise2Void run: in-memory array input with axes "YXC" and
# two channels, trained on 64x64 patches in batches of 64 for a single epoch.
n2v_settings = {
    "experiment_name": "bla",
    "data_type": "array",
    "axes": "YXC",
    "n_channels": 2,
    "patch_size": (64, 64),
    "batch_size": 64,
    "num_epochs": 1,
}
config = create_n2v_configuration(**n2v_settings)

# Show the resulting configuration for inspection.
print(config)

### Train N2V 

In [None]:
# Build the CAREamist from the N2V configuration created above.
careamist = CAREamist(source=config)

# Train directly on the in-memory training array; the validation split is
# controlled by `val_minimum_split`.
careamist.train(train_source=train_data, val_minimum_split=5)

In [None]:
# Run the trained N2V model on the training data, processing it in 256x256 tiles.
# NOTE: the variable name is misspelt ("predicition") but is kept as-is because
# the plotting and noise-model cells below refer to it by this name.
predicition = careamist.predict(train_data, tile_size=(256, 256))

### Visualize N2V predictions

In [None]:
# One row per channel: raw input on the left, N2V output on the right.
_, ax = plt.subplots(2, 2, figsize=(20, 30))
for row in range(2):
    input_panel, denoised_panel = ax[row][0], ax[row][1]
    # The input is indexed channel-last, the prediction channel-first.
    input_panel.imshow(train_data[..., row])
    input_panel.set_title(f"Input channel {row + 1}")
    denoised_panel.imshow(predicition[0].squeeze()[row])
    denoised_panel.set_title(f"Denoised channel {row + 1}")
plt.show()

### Train the noise model and visualize the results

Here we train a noise model using the denoised images obtained from the N2V model.

We train a separate noise model for each channel.


In [None]:
# Fit, save and plot one Gaussian-mixture noise model per channel (the last
# axis of `train_data`).
for channel_idx in range(train_data.shape[-1]):
    channel_data = train_data[..., channel_idx]
    print(f"Training noise model for channel {channel_idx}")

    # The intensity range of the raw channel bounds both the noise model and
    # the histogram built below.
    signal_min = channel_data.min()
    signal_max = channel_data.max()
    noise_model_config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        min_signal=signal_min,
        max_signal=signal_max,
    )
    noise_model = GaussianMixtureNoiseModel(noise_model_config)

    # N2V output for this channel (the prediction is indexed channel-first).
    channel_prediction = predicition[0].squeeze()[channel_idx]

    # NOTE(review): the raw channel is passed as `signal` and the N2V output as
    # `observation`; confirm this matches the argument roles expected by
    # `GaussianMixtureNoiseModel.fit` and `create_histogram`.
    # TODO change n_epochs
    noise_model.fit(
        signal=channel_data,
        observation=channel_prediction,
        n_epochs=100,
    )
    noise_model.save(
        path="noise_models",
        name=f"noise_model_pavia_p24_channel_{channel_idx}",
    )

    # Histogram of the same signal/observation pair, used as the reference in
    # the probability-distribution plot.
    histogram = create_histogram(
        bins=100,
        min_val=signal_min,
        max_val=signal_max,
        signal=channel_data,
        observation=channel_prediction,
    )
    plot_probability_distribution(
        noise_model,
        signalBinIndex=50,
        histogram=histogram[0],
        channel=channel_idx,
    )

### Plot the results

#### Plot the noise model for each channel

In [None]:
# Absolute, machine-specific path to a noise model stored as an .npz file;
# it is loaded in the next cell.
nm_path = "/group/jug/ashesh/training/noise_model/2404/94/GMMNoiseModel_pavia3_sequential_singlefiles-Cond_1__6_4_Clip0.0-1.0_Sig0.125_UpNone_Norm0_bootstrap.npz"
# TODO remove after testing

In [None]:
# Build a noise model from the file at `nm_path` instead of fitting a new one.
noise_model_config = GaussianMixtureNMConfig(
    model_type="GaussianMixtureNoiseModel",
    path=nm_path,
)
loaded_noise_model = GaussianMixtureNoiseModel(noise_model_config)

In [None]:
# Plot the probability distribution of the noise model loaded from `nm_path`.
# The original call omitted the noise model and the channel, unlike the call in
# the training loop, so `loaded_noise_model` was never used.
# `histogram` and `channel_idx` are the values left over from the last
# iteration of the training loop, so the reference histogram and the channel
# label both refer to the last channel.
plot_probability_distribution(
    loaded_noise_model,
    signalBinIndex=50,
    histogram=histogram[0],
    channel=channel_idx,
)