In [None]:
# Modules for deep learning and other operations
import os
from math import floor, ceil
from skimage.metrics import peak_signal_noise_ratio as PSNR
import numpy as np
import torch
from torchvision import transforms
import pytorch_lightning as pl
# Modules with networks
import sys
sys.path.append('../../')
from decoder.net import SDecoder
from UNet.models.lvae import UNet
from utils.datasets import TrainDatasetSupervised, PredictDatasetUNet
# Modules for loading and viewing images
import tifffile
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Read the paired noisy (low-SNR) and cleaner (high-SNR) stacks from disk and wrap
# each one as a PyTorch tensor (sharing memory with the NumPy array tifffile returns).
# Expected layout: [Number, Channels, Height, Width]; later cells index [image, channel].
low_snr, high_snr = (
    torch.from_numpy(tifffile.imread(tif_path))
    for tif_path in ('../../data/nucleus/low_snr.tif', '../../data/nucleus/high_snr.tif')
)

In [None]:
# Sanity check: the two stacks must pair up one-to-one, since later cells slice
# them identically to build (input, target) pairs.
assert len(low_snr) == len(high_snr), 'low/high SNR stacks have different lengths'

# Show one randomly chosen noisy image next to its cleaner counterpart.
idx = np.random.randint(len(low_snr))
fig, ax = plt.subplots(1, 2)
ax[0].imshow(low_snr[idx, 0])
ax[0].set_title('Low SNR (input)')
ax[1].imshow(high_snr[idx, 0])
ax[1].set_title('High SNR (target)')
# plt.show() rather than fig.show(): Figure.show() warns on the non-GUI inline backend.
plt.show()

In [None]:
# Create training, validation and prediction dataloaders
# The last `n_predict` images are held out as the prediction set.
n_predict = 10
train_val_set = low_snr[:-n_predict]
train_val_set_gt = high_snr[:-n_predict]
predict_set = low_snr[-n_predict:]
predict_set_gt = high_snr[-n_predict:]

batch_size = 4
crop_size = 256
# Presumably the number of random crops drawn per image per epoch (see
# TrainDatasetSupervised): one per crop-width along the longer image side.
n_iters = max(train_val_set.shape[-2], train_val_set.shape[-1]) // crop_size
transform = transforms.RandomCrop(crop_size)

# NOTE: from here on `train_val_set` and `predict_set` name Dataset objects,
# not the raw tensors sliced above.
train_val_set = TrainDatasetSupervised(train_val_set, train_val_set_gt, n_iters=n_iters, transform=transform)

# 80/20 train/validation split. The validation size is the remainder, so the two
# lengths always sum to len(train_val_set). The seeded generator makes the split
# identical across kernel restarts.
n_train = floor(0.8 * len(train_val_set))
n_val = len(train_val_set) - n_train
split_generator = torch.Generator().manual_seed(0)
train_set, val_set = torch.utils.data.random_split(train_val_set, (n_train, n_val), generator=split_generator)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

# shuffle=False keeps predictions in the same order as predict_set_gt.
predict_set = PredictDatasetUNet(predict_set)
predict_loader = torch.utils.data.DataLoader(predict_set, batch_size=batch_size, shuffle=False)

In [None]:
# The number of features the clean signal code will have, the signal decoder will need to know this
s_code_features = 32

# Create the s decoder. This will learn to decode the signal code into a signal estimate
s_decoder = SDecoder(colour_channels=1, code_features=s_code_features)

# Where training logs and the trained parameters will be saved
checkpoint_path = '../../checkpoints/Hagen/nucleus/CARE'

# Create the supervised UNet (no noise model is involved).
# The normalisation statistics come from the training/validation images only,
# so the held-out prediction images (the last len(predict_set_gt) frames) do not
# leak into them.
train_val_noisy = low_snr[:-len(predict_set_gt)]
unet = UNet(colour_channels=1,
            data_mean=train_val_noisy.mean(),
            data_std=train_val_noisy.std(),
            img_shape=(crop_size, crop_size),
            s_decoder=s_decoder,
            n_filters=s_code_features)

# Log once per epoch. len(train_loader) is the number of batches per epoch,
# including a partial final batch. max(1, ...) keeps the value a positive step
# count even when the training set is smaller than one batch.
trainer = pl.Trainer(default_root_dir=checkpoint_path,
                     accelerator='gpu',
                     devices=1,
                     max_epochs=500,
                     log_every_n_steps=max(1, len(train_loader)))

In [None]:
# Fit the UNet on the training loader (checked against the validation loader each
# epoch), then write the final weights into the same directory as the training logs.
trainer.fit(unet, train_loader, val_loader)
trainer.save_checkpoint(
    filepath=os.path.join(checkpoint_path, 'final_params.ckpt'),
)

In [None]:
# Load the trained model
# The loss is not needed for prediction, so <mode_pred> is set True.
# <img_shape> must match the images being evaluated. Training used
# (crop_size, crop_size) crops, but the prediction set is sliced straight from
# `low_snr` with no crop transform. The shape is therefore read from the data
# instead of being hard-coded.
# NOTE(review): assumes img_shape is ordered (height, width); TODO confirm against UNet.
predict_img_shape = tuple(low_snr.shape[-2:])
unet = UNet.load_from_checkpoint(os.path.join(checkpoint_path, 'final_params.ckpt'),
                                 s_decoder=s_decoder,
                                 img_shape=predict_img_shape,
                                 mode_pred=True).eval()

# Evaluate the model on the test set
# "predictions" is a list with one entry per batch of predict_loader (the final
# batch may be smaller than batch_size). Each entry is expected to be a tensor of
# shape [batch, colours, height, width].
predictions = trainer.predict(unet, predict_loader)

In [None]:
# Select a random image, its clean estimate and its ground truth.
# Draw one index over all prediction images and split it into (batch, position).
# Drawing the position from range(batch_size) separately could overrun the
# smaller final batch.
image_idx = np.random.randint(len(predict_set_gt))
batch_idx, img_idx = divmod(image_idx, batch_size)

# `predict_set` was re-bound to a PredictDatasetUNet when the loaders were built,
# so read the raw noisy input from `low_snr` instead. The prediction images are
# its last len(predict_set_gt) frames.
noisy_image = low_snr[-len(predict_set_gt):][image_idx, 0]
gt = predict_set_gt[image_idx, 0]
denoised = predictions[batch_idx][img_idx, 0]

# Display all three side by side
panels = [(noisy_image, 'Input'), (gt, 'Ground truth'), (denoised, 'Denoised')]
fig, ax = plt.subplots(1, 3, figsize=[15, 15])
for panel_ax, (panel_image, panel_title) in zip(ax, panels):
    panel_ax.imshow(panel_image, cmap='inferno')
    panel_ax.axis('off')
    panel_ax.set_title(panel_title)

# plt.show() rather than fig.show(): Figure.show() warns on the non-GUI inline backend.
plt.show()

In [None]:
# Calculate PSNR
# Flatten the per-batch prediction tensors into one list of per-image tensors, in
# prediction-set order. Iterating a tensor walks its first (batch) dimension, so a
# smaller final batch needs no special handling. Unlike a bare `except`, this does
# not hide unrelated errors.
predictions_unbatched = [image for batch in predictions for image in batch]

# The ground truth is later rescaled to best match each prediction by minimising
# the mean squared error between the two.
def MSE(array1, array2):
    """Return the mean of the element-wise squared difference ``array1 - array2``."""
    residual = array1 - array2
    return np.mean(residual * residual)

def grid_search(search_range, step_size, high_snr, low_snr):
    """Scan candidate scale factors for ``high_snr`` and return the MSE of each.

    For every ``s`` in ``np.arange(0, search_range, step_size)`` this returns
    ``mean((high_snr * s - low_snr) ** 2)``.

    That MSE is a quadratic in ``s``:
        MSE(s) = s**2 * mean(h*h) - 2*s * mean(h*l) + mean(l*l)
    So the three means are computed once and the whole curve costs O(N + K),
    instead of building K full-size temporary images (O(N * K)).

    Parameters
    ----------
    search_range : float
        Exclusive upper bound of the scale grid; the grid starts at 0.
    step_size : float
        Spacing between candidate scales.
    high_snr, low_snr : array-like
        Arrays of the same shape; ``high_snr`` is the one being scaled.

    Returns
    -------
    search_values : np.ndarray
        The candidate scales.
    MSEs : list of float
        MSE for each candidate, in the same order as ``search_values``.
    """
    search_values = np.arange(start=0, stop=search_range, step=step_size)

    # float64 limits cancellation error in the expanded quadratic.
    scaled = np.asarray(high_snr, dtype=np.float64)
    target = np.asarray(low_snr, dtype=np.float64)
    mean_ss = np.mean(scaled * scaled)
    mean_st = np.mean(scaled * target)
    mean_tt = np.mean(target * target)

    curve = search_values ** 2 * mean_ss - 2.0 * search_values * mean_st + mean_tt
    # Rounding in the expansion can push near-zero MSEs marginally below zero.
    curve = np.maximum(curve, 0.0)

    return search_values, curve.tolist()

from tqdm.auto import tqdm

# PSNR of every denoised prediction against its ground truth.
# 1. Both images are mean-centred.
# 2. The ground truth is multiplied by the scale that minimises its MSE to the
#    prediction. The scale is searched on a grid over [0, 1.5) in steps of 1e-4,
#    so any larger optimum is clipped.
# 3. The PSNR peak is the dynamic range of the rescaled ground truth.
if len(predictions_unbatched) != len(predict_set_gt):
    raise ValueError(f'Got {len(predictions_unbatched)} predictions for '
                     f'{len(predict_set_gt)} ground-truth images')

psnrs = []
for i in tqdm(range(len(predict_set_gt)), desc='PSNR'):
    gt = predict_set_gt[i, 0].numpy()
    gt = gt - np.mean(gt)

    # The network's clean estimate for image i.
    denoised = predictions_unbatched[i][0].numpy()
    denoised = denoised - np.mean(denoised)

    search_values, MSEs = grid_search(search_range=1.5,
                                      step_size=0.0001,
                                      high_snr=gt,
                                      low_snr=denoised)
    optimal_scale = search_values[np.argmin(MSEs)]
    gt = gt * optimal_scale

    # If the fitted scale is 0, this range is 0 and the PSNR is not meaningful.
    data_range = gt.max() - gt.min()
    psnrs.append(PSNR(gt, denoised, data_range=data_range))

print(np.mean(psnrs))