In [None]:
# Modules for deep learning and other operations
import os
from math import floor, ceil
from skimage.metrics import peak_signal_noise_ratio as PSNR
import numpy as np
import torch
from torchvision import transforms
import pytorch_lightning as pl
# Modules with networks
import sys
sys.path.append('../../')
from decoder.net import SDecoder
from HDN.noise_model.gaussianMixtureNoiseModel import GaussianMixtureNoiseModel
from HDN.models.lvae import LadderVAE
from utils.datasets import TrainDatasetUnsupervised, PredictDatasetVAE
# Modules for loading and viewing images
import tifffile
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Read the paired noisy (low SNR) and reference (high SNR) image stacks from disk.
# Each stack is converted to a PyTorch tensor; the expected dimensions are
# [Number, Channels, Height, Width].
low_snr = torch.from_numpy(tifffile.imread('../../data/nucleus/low_snr.tif'))
high_snr = torch.from_numpy(tifffile.imread('../../data/nucleus/high_snr.tif'))

In [None]:
# Scale high_snr to match low_snr by minimising mean square error between the two
def MSE(array1, array2):
    """Return the mean squared error between two tensors.

    The element-wise difference is squared and then averaged over every
    element, giving a scalar tensor.
    """
    residual = array1 - array2
    return torch.mean(residual ** 2)


def grid_search(search_range, step_size, high_snr, low_snr):
    """Evaluate the MSE between a scaled ``high_snr`` and ``low_snr`` on a grid.

    Candidate scale factors start at 0 and increase in steps of ``step_size``,
    stopping before ``search_range`` (the end point is excluded).

    Returns:
        A tuple ``(search_values, MSEs)``: the tensor of candidate scale
        factors and a list with the MSE obtained for each candidate, in the
        same order.
    """
    search_values = torch.arange(start=0, end=search_range, step=step_size)

    MSEs = []
    for scale in search_values:
        MSEs.append(MSE(high_snr * scale, low_snr))

    return search_values, MSEs


# For every image pair: zero-centre both images, then multiply the high SNR
# image by the grid-searched factor that minimises its MSE to the noisy image.
high_snr_scaled = []
low_snr_scaled = []
for i, noisy in enumerate(low_snr):
    noisy = noisy - torch.mean(noisy)
    gt = high_snr[i] - torch.mean(high_snr[i])

    search_values, MSEs = grid_search(search_range=1,
                                      step_size=0.001,
                                      high_snr=gt,
                                      low_snr=noisy)

    # Scale factor with the smallest error among the candidates
    optimal_scale = search_values[torch.argmin(torch.tensor(MSEs))]
    gt = gt * optimal_scale

    high_snr_scaled.append(gt)
    low_snr_scaled.append(noisy)

# Collect the per-image results back into single tensors (image index first)
high_snr_scaled = torch.stack(high_snr_scaled, dim=0)
low_snr_scaled = torch.stack(low_snr_scaled, dim=0)

In [None]:
# Sanity check: display a randomly chosen noisy image next to its high SNR counterpart
idx = np.random.randint(len(low_snr))
fig, ax = plt.subplots(1, 2)
ax[0].imshow(low_snr[idx, 0])
ax[0].set_title('Low SNR')
ax[1].imshow(high_snr[idx, 0])
ax[1].set_title('High SNR')
# plt.show() renders through whichever backend is active; Figure.show() is only
# meant for GUI backends and warns under the inline backend.
plt.show()

In [None]:
# Create GMM noise model
# The signal range handed to the noise model is the range of pixel values in the
# scaled high SNR images.
min_signal = np.min(high_snr_scaled.numpy())
max_signal = np.max(high_snr_scaled.numpy())

# Directory for the noise model files.
# exist_ok=True makes this a no-op when the directory already exists, without a
# separate existence check that could race with another process.
nm_checkpoint_path = '../../checkpoints/Hagen/nucleus/HDN/noise_model/'
os.makedirs(nm_checkpoint_path, exist_ok=True)

n_gaussian = 3
n_coeff = 2
device = torch.device('cuda')
# NOTE(review): min_sigma=50 is in the units of the scaled pixel values presumably;
# confirm it suits this dataset.
gaussianMixtureNoiseModel = GaussianMixtureNoiseModel(min_signal=min_signal,
                                                      max_signal=max_signal,
                                                      path=nm_checkpoint_path,
                                                      weight=None,
                                                      n_gaussian=n_gaussian,
                                                      n_coeff=n_coeff,
                                                      min_sigma=50,
                                                      device=device)

In [None]:
# Fit the noise model on the scaled high SNR / low SNR pairs.
# The cell that follows loads 'GMM_noise_model.npz' from nm_checkpoint_path.
gaussianMixtureNoiseModel.train(
    high_snr_scaled.numpy(),
    low_snr_scaled.cpu().numpy(),
    batchSize=250000,
    n_epochs=2000,
    learning_rate=0.1,
    name='GMM_noise_model',
)

In [None]:
# Reload the saved noise model parameters from disk and build a noise model from them
noise_model_file = os.path.join(nm_checkpoint_path, 'GMM_noise_model.npz')
noise_model_params = np.load(noise_model_file)
noise_model = GaussianMixtureNoiseModel(device=device, params=noise_model_params)

In [None]:
# Create training, validation and prediction dataloaders
# Although this is an unsupervised method, it is being compared to a supervised method,
# so it should have the same prediction set: the last 10 images.
predict_set_gt = high_snr_scaled[-10:]
predict_set = low_snr_scaled[-10:]

# Training and validation dataloaders, built from crops of the noisy images
batch_size = 4
crop_size = 256
largest_side = max(low_snr_scaled.shape[-2], low_snr_scaled.shape[-1])
n_iters = largest_side // crop_size
transform = transforms.RandomCrop(crop_size)

train_val_set = TrainDatasetUnsupervised(low_snr_scaled, n_iters=n_iters, transform=transform)

# 80% of the samples go to training, the remaining 20% to validation
n_train = floor(0.8 * len(train_val_set))
n_val = ceil(0.2 * len(train_val_set))
train_set, val_set = torch.utils.data.random_split(train_val_set, (n_train, n_val))

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

# Prediction dataloader: one batch holds n_samples entries
n_samples = 10  # The number of times the same image should be evaluated on
predict_set = PredictDatasetVAE(predict_set, n_samples=n_samples)
predict_loader = torch.utils.data.DataLoader(predict_set, batch_size=n_samples, shuffle=False)

In [None]:
# Number of features in the clean signal code; shared by the signal decoder and the VAE
s_code_features = 32

# Signal decoder
s_decoder = SDecoder(code_features=s_code_features, colour_channels=1)

# Statistics of the noisy training data, handed to the VAE
data_mean = low_snr_scaled.mean()
data_std = low_snr_scaled.std()

# Ladder VAE, given the noise model and the signal decoder.
# use_uncond_mode_at is a list of layers for which the posterior distribution will not be
# conditioned on x, with 0 being the final layer
vae = LadderVAE(
    colour_channels=1,
    data_mean=data_mean,
    data_std=data_std,
    img_shape=(crop_size, crop_size),
    noise_model=noise_model,
    s_decoder=s_decoder,
    n_filters=s_code_features,
    use_uncond_mode_at=[0, 1],
)

In [None]:
# Where training logs and the trained parameters will be saved.
# Used as the trainer's default_root_dir and as the directory of 'final_params.ckpt'.
checkpoint_path = '../../checkpoints/Hagen/nucleus/HDN'

In [None]:
# Create the pytorch lightning trainer
# The logging interval is the number of full batches in the training set.
# max(1, ...) keeps it valid when the training set holds fewer samples than one
# batch, where the integer division alone would give 0.
trainer = pl.Trainer(default_root_dir=checkpoint_path,
                     accelerator='gpu',
                     devices=1,
                     max_epochs=500,
                     log_every_n_steps=max(1, len(train_set) // batch_size))

In [None]:
# Fit the VAE on the training/validation loaders, then write the final checkpoint
final_checkpoint_file = os.path.join(checkpoint_path, 'final_params.ckpt')
trainer.fit(vae, train_loader, val_loader)
trainer.save_checkpoint(final_checkpoint_file)

In [None]:
# Load the trained model
# Note the new hyperparameters: the noise model is no longer necessary and the loss does not
# need to be calculated. The model was trained on (crop_size, crop_size) crops, so img_shape
# must be set to the size of the images being evaluated; it is taken from the prediction set
# rather than hard-coded so it stays correct for other image sizes.
predict_img_shape = tuple(predict_set_gt.shape[-2:])
vae = LadderVAE.load_from_checkpoint(os.path.join(checkpoint_path, 'final_params.ckpt'),
                                     noise_model=None,
                                     s_decoder=s_decoder,
                                     mode_pred=True,
                                     img_shape=predict_img_shape,
                                     strict=False).eval()

# Evaluate the model on the test set
# The returned "predictions" will be a list with length equal to the number of noisy images.
# Each element of the list is a tensor of shape [n_samples, colours, height, width] and is
# n_samples many signal estimates for a noisy image
predictions = trainer.predict(vae, predict_loader)

# To get an MMSE, we take the mean of each batch returned by predict()
MMSEs = [samples.mean(dim=0) for samples in predictions]

In [None]:
# Select a random image, its corresponding clean estimates and its ground truth
idx = np.random.randint(len(predict_set_gt))

# predict_loader batches n_samples consecutive dataset entries per image, so the
# entry for image idx is read at idx * n_samples.
noisy_image = predict_set[idx * n_samples, 0]
gt = predict_set_gt[idx, 0]
MMSE = MMSEs[idx][0]
samples = predictions[idx]
# Select a random sample
sample_idx = np.random.randint(n_samples)
sample = samples[sample_idx, 0]

# Display all four panels in a 2x2 grid (row-major: input, ground truth, MMSE, sample)
fig, ax = plt.subplots(2, 2, figsize=[10, 10])
panels = [('Input', noisy_image),
          ('Ground truth', gt),
          ('MMSE', MMSE),
          ('Sample', sample)]
for axis, (title, image) in zip(ax.flat, panels):
    axis.imshow(image, cmap='inferno')
    axis.axis('off')
    axis.set_title(title)

# plt.show() renders through whichever backend is active; Figure.show() is only
# meant for GUI backends and warns under the inline backend.
plt.show()

In [None]:
# Scale each ground truth image to match its MMSE estimate by minimising the mean square error between the two, then compute the PSNR
def MSE(array1, array2):
    """Return the mean squared error between two numpy arrays.

    NOTE: this numpy version replaces the torch-based ``MSE`` defined earlier
    in the notebook once this cell has run.
    """
    squared_error = (array1 - array2) ** 2
    return np.mean(squared_error)

def grid_search(search_range, step_size, high_snr, low_snr):
    """Evaluate the MSE between a scaled ``high_snr`` and ``low_snr`` on a grid.

    Candidate scale factors start at 0 and increase in steps of ``step_size``,
    stopping before ``search_range``.

    NOTE: this numpy version replaces the torch-based ``grid_search`` defined
    earlier in the notebook once this cell has run.

    Returns:
        A tuple ``(search_values, MSEs)``: the array of candidate scale
        factors and a list with the MSE obtained for each candidate, in the
        same order.
    """
    search_values = np.arange(start=0, stop=search_range, step=step_size)

    MSEs = []
    for scale in search_values:
        MSEs.append(MSE(high_snr * scale, low_snr))

    return search_values, MSEs

from tqdm import tqdm

# For every prediction image: zero-centre the ground truth and the MMSE estimate,
# rescale the ground truth by the grid-searched factor that minimises the MSE to
# the estimate, and record the PSNR between the two.
psnrs = []
for i in tqdm(range(len(predict_set_gt))):
    gt = predict_set_gt[i, 0].numpy()
    gt = gt - gt.mean()

    noisy = MMSEs[i][0].numpy()
    noisy = noisy - noisy.mean()

    search_values, MSEs = grid_search(search_range=1.5,
                                      step_size=0.001,
                                      high_snr=gt,
                                      low_snr=noisy)

    # Scale factor with the smallest error among the candidates
    optimal_scale = search_values[np.argmin(MSEs)]
    gt = gt * optimal_scale

    # PSNR is measured relative to the dynamic range of the rescaled ground truth
    psnrs.append(PSNR(gt, noisy, data_range=gt.max() - gt.min()))

# Average PSNR over the prediction set
print(np.mean(psnrs))