In [None]:
# Modules for deep learning and other operations
import os
from math import floor, ceil
import numpy as np
import torch
from torchvision import transforms
import pytorch_lightning as pl
# Modules with networks
import sys
sys.path.append('../../')
from noise_models.cnn import PixelCNN
from decoder.net import SDecoder
from LVAE.models.lvae import LadderVAE
from utils.datasets import TrainDatasetUnsupervised, PredictDatasetVAE
# Modules for loading and viewing images
import tifffile
import matplotlib.pyplot as plt
%matplotlib inline

Benchmark denoised images are provided by:<br> Henninen, T.R., Bon, M., Wang, F., Passerone, D. and Erni, R., 2020. The Structure of Sub‐nm Platinum Clusters at Elevated Temperatures. Angewandte Chemie International Edition, 59(2), pp.839-845.

In [None]:
# Read the noisy acquisitions straight into a PyTorch tensor
# Expected layout of the tensor: [Number, Colours, Height, Width]
low_snr = torch.from_numpy(tifffile.imread('../../data/SEM/low_snr.tif'))

# Read the benchmark denoising results the same way
benchmark = torch.from_numpy(tifffile.imread('../../data/SEM/benchmark.tif'))

In [None]:
# Sanity check: show one randomly chosen noisy frame next to its benchmark
idx = np.random.randint(len(low_snr))
fig, ax = plt.subplots(1, 2)
for panel, stack in zip(ax, (low_snr, benchmark)):
    # First colour channel of the selected frame
    panel.imshow(stack[idx, 0])
fig.show()

In [None]:
# Build the training, validation and prediction dataloaders
batch_size = 4
crop_size = 256
# How many crops of side crop_size fit along the longer image side
n_iters = max(low_snr.shape[-2:]) // crop_size
transform = transforms.RandomCrop(crop_size)

train_val_set = TrainDatasetUnsupervised(low_snr, n_iters=n_iters, transform=transform)
# 80/20 split; floor on one part and ceil on the other keeps the lengths summing to the total
train_set, val_set = torch.utils.data.random_split(
    train_val_set,
    (floor(0.8 * len(train_val_set)), ceil(0.2 * len(train_val_set))),
)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=batch_size, shuffle=False
)

n_samples = 10  # The number of times the same image should be evaluated on
predict_set = PredictDatasetVAE(low_snr, n_samples=n_samples)
# Batch size equals n_samples and order is preserved (no shuffling)
predict_loader = torch.utils.data.DataLoader(
    predict_set, batch_size=n_samples, shuffle=False
)

In [None]:
# Width (number of features) of the clean signal code; the noise model and
# the signal decoder both need to be told this value
s_code_features = 32

# Build the kernel mask for the noise model.
# Everything in the bottom row and in the last column is switched on, except
# the final (bottom-right) entry. For this one-row mask that leaves 15 ones
# followed by a zero, i.e. 15 horizontally adjacent pixels in the receptive field.
kernel_mask = torch.zeros((1, 16))
kernel_mask[-1, :-1] = 1  # bottom row without its final entry
kernel_mask[:-1, -1] = 1  # last column without its final entry (empty selection for a single row)

# Create the noise model.
# It receives the kernel mask built above and the size of the signal code
# (code_features) so that it agrees with the signal decoder.
noise_model = PixelCNN(colour_channels=1,
                       code_features=s_code_features,
                       kernel_mask=kernel_mask,
                       n_filters=32,
                       n_layers=4,
                       n_gaussians=3)

# Create the signal decoder.
# Uses the same code_features value as the noise model.
s_decoder = SDecoder(colour_channels=1, code_features=s_code_features)

# Create the VAE
# Pass the noise model and signal decoder as an argument
# data_mean / data_std are computed over the whole noisy stack, and img_shape
# is the size of the random training crops.
# NOTE(review): torch's .mean()/.std() require a floating-point tensor - confirm
# the TIFF is stored as float, otherwise these calls raise.
vae = LadderVAE(colour_channels=1,
                data_mean=low_snr.mean(),
                data_std=low_snr.std(),
                img_shape=(crop_size, crop_size),
                noise_model=noise_model,
                s_decoder=s_decoder,
                n_filters=s_code_features)

In [None]:
# Where training logs and the trained parameters will be saved
checkpoint_path = '../../checkpoints/SEM/DVLAE'

# Create the pytorch lightning trainer
# The logging interval is the number of full batches in the training split, so
# metrics are written roughly once per epoch. It is clamped to at least 1
# because len(train_set) // batch_size evaluates to 0 whenever the training
# split holds fewer samples than one batch, which is not a usable interval.
trainer = pl.Trainer(default_root_dir=checkpoint_path,
                     accelerator='gpu',
                     devices=1,
                     max_epochs=500,
                     log_every_n_steps=max(1, len(train_set) // batch_size))

In [None]:
# Train and save model
# Fits the VAE with the training loader, using the validation loader for
# validation, then writes the final parameters to
# <checkpoint_path>/final_params.ckpt (the file loaded again for prediction).
trainer.fit(vae, train_loader, val_loader)
trainer.save_checkpoint(os.path.join(checkpoint_path, 'final_params.ckpt'))

In [None]:
# Load the trained model
# Note the new hyperparameters: the noise model is no longer necessary and the
# loss does not need to be calculated so <mode_pred> is set True.
# The image shape <img_shape> differs from training: the model was trained on
# (crop_size, crop_size) crops, whereas here it is (450, 512) - presumably the
# full size of the frames in low_snr (TODO confirm). It would need to be
# changed again to evaluate on images of a different size.
vae = LadderVAE.load_from_checkpoint(os.path.join(checkpoint_path, 'final_params.ckpt'),
                                     noise_model=None,
                                     s_decoder=s_decoder,
                                     mode_pred=True,
                                     img_shape=(450, 512),
                                     strict=False).eval()

# Evaluate the model on the test set
# The returned "predictions" will be a list with length equal to the number of noisy images.
# Each element of the list is a tensor of shape [n_samples, colours, height, width] and is n_samples many signal estimates for a noisy image
predictions = trainer.predict(vae, predict_loader)

# To get an MMSE, we take the mean of each batch returned by predict()
# (averaging over the first axis, i.e. over the n_samples estimates of one image)
MMSEs = [samples.mean(dim=0) for samples in predictions]

In [None]:
# Write the MMSE estimates to disk so they can be used as pseudo ground truth
# for training the noise model of HDN36
MMSEs_numpy = torch.stack(MMSEs).numpy()  # stacked along a new leading axis
tifffile.imwrite('../../data/SEM/DVLAE_results.tif', MMSEs_numpy)

In [None]:
# Peak-of-histogram background subtraction
# Code taken from: https://github.com/hentr/peak-of-histogram
from numpy import argmax, histogram, zeros
def peak_of_histogram(imgdata, individual_frames=True):
    """Subtract the background level found at the peak of the intensity histogram.

    The background level is the left edge of the most populated bin of a
    100-bin histogram. After subtraction every value below 1 is set to 1 to
    avoid divide-by-zero errors in later processing.

    Parameters
    ----------
    imgdata : array-like
        Image stack. With ``individual_frames=True`` the first axis is
        iterated as frames and the input must be at least 3-D, because each
        result is written back with ``newstack[i, :, :]``.
    individual_frames : bool, default True
        ``True``  - estimate and subtract the background for each frame
        separately; the first histogram bin (e.g. black pixels) is ignored
        when looking for the peak.
        ``False`` - use a single histogram of the whole series; here the
        first bin is *not* ignored.

    Returns
    -------
    ``individual_frames=True``: a new float NumPy array with the shape of
    ``imgdata``. ``individual_frames=False``: the result of
    ``imgdata - background`` (whatever type that expression yields).
    The input is not modified.

    Raises
    ------
    ValueError
        If ``individual_frames`` is neither ``True`` nor ``False``.
    """
    nbins = 100  # number of bins for the histogram, 100 should be precise enough

    # Validate first. The original code did `raise('...')`, which itself fails
    # with "TypeError: exceptions must derive from BaseException".
    if individual_frames not in (True, False):
        raise ValueError('"individual_frames" must be True or False')

    if not individual_frames:
        # Background from the peak of the entire image series' histogram
        counts, edges = histogram(imgdata, bins=nbins)
        hmax = edges[argmax(counts)]
        subimg = imgdata - hmax
        subimg[subimg < 1] = 1  # set negative and 0 values to 1
        return subimg

    # Background removed from each frame individually
    newstack = zeros(imgdata.shape)
    for i, img in enumerate(imgdata):
        counts, edges = histogram(img, bins=nbins)
        # argmax runs over counts[1:] (first bin skipped), so its result is
        # shifted by one relative to `edges`; add 1 to address the left edge
        # of the bin that actually holds the peak.
        hmax = edges[argmax(counts[1:]) + 1]
        subimg = img - hmax
        subimg[subimg < 1] = 1  # set negative and 0 values to 1
        newstack[i, :, :] = subimg
    return newstack

# Background-subtract every batch of sampled signal estimates
predictions_no_background = [peak_of_histogram(batch) for batch in predictions]

# Average each background-free batch over its first axis
MMSEs_no_background = [samples.mean(axis=0) for samples in predictions_no_background]

In [None]:
# Pick a random noisy image together with its benchmark and its background-free estimates
idx = np.random.randint(len(predictions))

noisy_image = predict_set[idx * n_samples, 0]
bench = benchmark[idx, 0]
samples = predictions_no_background[idx]
MMSE = MMSEs_no_background[idx][0]

# Pick one of the sampled estimates at random
sample_idx = np.random.randint(n_samples)
sample = samples[sample_idx, 0]

# Show all four images on a 2x2 grid, filled row by row
fig, ax = plt.subplots(2, 2, figsize=[15, 15])
panels = (
    ('Input', noisy_image),
    ('Benchmark', bench),
    ('MMSE', MMSE),
    ('Sample', sample),
)
for axis, (title, image) in zip(ax.flat, panels):
    axis.imshow(image, cmap='inferno')
    axis.axis('off')
    axis.set_title(title)

fig.show()