In [None]:
# Modules for deep learning and other operations
import os
from math import floor, ceil
import numpy as np
import torch
from torchvision import transforms
import pytorch_lightning as pl
# Modules with networks
import sys
sys.path.append('../../')
from decoder.net import SDecoder
from N2V.models.lvae import N2V
from N2V.masks import RowMask, ColumnMask, CrossMask
from utils.datasets import TrainDatasetUnsupervised, PredictDatasetUNet
# Modules for loading and viewing images
import tifffile
import matplotlib.pyplot as plt
%matplotlib inline

The benchmark denoised images, used below for visual comparison with our results, are taken from:<br> Henninen, T.R., Bon, M., Wang, F., Passerone, D. and Erni, R., 2020. The Structure of Sub-nm Platinum Clusters at Elevated Temperatures. *Angewandte Chemie International Edition*, 59(2), pp. 839-845.

In [None]:
# Load the noisy SEM stack and the benchmark results. torch.from_numpy wraps
# each array read from disk without copying it.
# NOTE(review): later cells index both as [image, colour, height, width];
# this assumes the .tif files are stored in that layout (and in a dtype
# torch.from_numpy supports) — TODO confirm.
low_snr = torch.from_numpy(tifffile.imread('../../data/SEM/low_snr.tif'))
benchmark = torch.from_numpy(tifffile.imread('../../data/SEM/benchmark.tif'))

In [None]:
# Sanity check: show a random noisy image next to its benchmark denoised version
idx = np.random.randint(len(low_snr))
fig, ax = plt.subplots(1, 2)
ax[0].imshow(low_snr[idx, 0])
ax[0].set_title('Noisy input')
ax[1].imshow(benchmark[idx, 0])
ax[1].set_title('Benchmark')
# plt.show() rather than fig.show(): Figure.show() can emit a
# "non-GUI backend" UserWarning under the inline backend.
plt.show()

In [None]:
# Create training, validation and prediction dataloaders
batch_size = 4
crop_size = 256
# n_iters is passed to the training dataset as max(height, width) // crop_size.
# NOTE(review): presumably the number of random crops drawn per image per
# epoch — TODO confirm against TrainDatasetUnsupervised.
n_iters = max(low_snr.shape[-2], low_snr.shape[-1]) // crop_size
transform = transforms.RandomCrop(crop_size)

train_val_set = TrainDatasetUnsupervised(low_snr, n_iters=n_iters, transform=transform)

# 80/20 train/validation split.
# - The validation size is the remainder, so the two lengths always sum to
#   len(train_val_set), as random_split requires. Separate floor/ceil of
#   float products only sum correctly if the float rounding works out.
# - A seeded generator makes the split identical across kernel restarts.
split_seed = 0
n_train = floor(0.8 * len(train_val_set))
n_val = len(train_val_set) - n_train
train_set, val_set = torch.utils.data.random_split(
    train_val_set,
    (n_train, n_val),
    generator=torch.Generator().manual_seed(split_seed),
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

# Prediction runs over the full noisy stack in a fixed order (shuffle=False),
# so output batch k holds images k*batch_size ... k*batch_size + batch_size - 1.
predict_set = PredictDatasetUNet(low_snr)
predict_loader = torch.utils.data.DataLoader(predict_set, batch_size=batch_size, shuffle=False)

In [None]:
# Number of feature channels in the clean-signal code. The signal decoder is
# built with the same value that N2V receives as n_filters.
s_code_features = 32

# Masking function handed to N2V (the bound forward method of a RowMask).
# NOTE(review): the exact masking scheme is defined in N2V.masks (presumably
# masking along image rows within a 5x5 search area) — TODO confirm.
mask_func = RowMask(n_masks=10, search_area=(5, 5)).forward

# Signal decoder: decodes the signal code into a signal estimate
s_decoder = SDecoder(colour_channels=1, code_features=s_code_features)

# Normalisation statistics of the noisy data, given to the network.
# NOTE(review): torch mean/std require a floating-point tensor; this assumes
# low_snr was loaded as floats — TODO confirm the dtype of low_snr.tif.
data_mean, data_std = low_snr.mean(), low_snr.std()

# The N2V network is trained on crop_size x crop_size crops
n2v = N2V(
    colour_channels=1,
    data_mean=data_mean,
    data_std=data_std,
    img_shape=(crop_size, crop_size),
    s_decoder=s_decoder,
    n_filters=s_code_features,
    mask_func=mask_func,
)

In [None]:
# Where training logs and the trained parameters will be saved
checkpoint_path = '../../checkpoints/SEM/N2V'

# Create the pytorch lightning trainer.
# len(train_loader) is the number of optimisation steps per epoch. The last,
# possibly partial, batch is kept because drop_last defaults to False, so
# floor(len(train_set) / batch_size) undercounts. Using it makes logging
# happen about once per epoch. max(1, ...) keeps the interval positive when
# the training set holds fewer samples than one batch.
trainer = pl.Trainer(default_root_dir=checkpoint_path,
                     accelerator='gpu',
                     devices=1,
                     max_epochs=500,
                     log_every_n_steps=max(1, len(train_loader)))

In [None]:
# Train on the training/validation loaders, then store the final weights
# under a fixed file name inside checkpoint_path.
final_ckpt_path = os.path.join(checkpoint_path, 'final_params.ckpt')
trainer.fit(n2v, train_loader, val_loader)
trainer.save_checkpoint(final_ckpt_path)

In [None]:
# Reload the trained weights for inference, overriding two hyperparameters:
#   - mode_pred=True: the loss does not need to be computed at prediction time.
#   - img_shape: training used crop_size x crop_size crops, but prediction
#     runs on whole images, so the full image size is given here instead.
#     NOTE(review): (450, 512) is hard-coded; it must match the height and
#     width of the images served by predict_set — TODO confirm.
# s_decoder is passed explicitly, as it was when the model was constructed.
final_ckpt_path = os.path.join(checkpoint_path, 'final_params.ckpt')
n2v = N2V.load_from_checkpoint(
    final_ckpt_path,
    s_decoder=s_decoder,
    img_shape=(450, 512),
    mode_pred=True,
)
n2v.eval()

# Denoise the whole noisy stack (the same data used for self-supervised
# training). trainer.predict returns one element per batch of predict_loader,
# i.e. ceil(len(predict_set) / batch_size) elements.
# NOTE(review): each element is presumably a tensor of shape
# [batch, colours, height, width]; the last one may hold fewer than
# batch_size images.
predictions = trainer.predict(n2v, predict_loader)

In [None]:
# Peak-of-histogram background subtraction
# Code taken from: https://github.com/hentr/peak-of-histogram
from numpy import argmax, histogram, zeros
def _histogram_peak(data, nbins, skip_first_bin=False):
    """Return the left edge of the most populated histogram bin of ``data``."""
    counts, edges = histogram(data, bins=nbins)
    first = 1 if skip_first_bin else 0
    # argmax is relative to the sliced counts; shift it back by `first` so it
    # indexes the same bin in `edges`.
    return edges[argmax(counts[first:]) + first]


def _subtract_level(data, level):
    """Subtract ``level`` from ``data`` and raise every value below 1 to 1."""
    subimg = data - level
    # set negative and 0 values to 1 to avoid divide-by-zero errors in later processing
    subimg[subimg < 1] = 1
    return subimg


def peak_of_histogram(imgdata, individual_frames=True, nbins=100):
    """Subtract the histogram-peak background level from an image stack.

    The background level is the left edge of the most populated bin of an
    ``nbins``-bin histogram. After subtraction, values below 1 are set to 1.

    Parameters
    ----------
    imgdata : array_like
        Image stack indexed by frame along the first axis, e.g. a
        [frames, H, W] or [batch, colours, H, W] array. Must have a ``shape``
        attribute and be accepted by ``numpy.histogram``.
    individual_frames : bool, default True
        True: estimate the level separately for each ``imgdata[i]``, ignoring
        the first histogram bin (e.g. black pixels).
        False: estimate a single level from the histogram of the whole stack.
    nbins : int, default 100
        Number of histogram bins; 100 should be precise enough.

    Returns
    -------
    Background-subtracted data with the same shape as ``imgdata``.
    In per-frame mode it is a float64 ``numpy.ndarray``. In whole-stack mode
    it is ``imgdata - level``, so its type follows ``imgdata``.

    Raises
    ------
    ValueError
        If ``individual_frames`` is neither True nor False.
    """
    if individual_frames not in (True, False):
        raise ValueError('"individual_frames" must be True or False')

    if not individual_frames:
        # One level from the peak of the entire image series' histogram
        return _subtract_level(imgdata, _histogram_peak(imgdata, nbins))

    # Per-frame levels; the first bin is ignored so dark pixels cannot be the peak
    newstack = zeros(imgdata.shape)
    for i, img in enumerate(imgdata):
        level = _histogram_peak(img, nbins, skip_first_bin=True)
        newstack[i, :, :] = _subtract_level(img, level)
    return newstack

# Subtract the histogram-peak background from every batch of predictions,
# using peak_of_histogram's default per-frame mode.
predictions_no_background = [peak_of_histogram(batch_pred) for batch_pred in predictions]

In [None]:
# 1st and 99th percentiles of the noisy intensities, presumably intended as
# imshow display limits (vmin/vmax) for the noisy images.
low_snr_values = low_snr.numpy()
vmin, vmax = np.percentile(low_snr_values, [1, 99])

In [None]:
# Select a random image and gather its noisy input, benchmark result and
# denoised estimate. One global index is drawn and split into
# (batch, position within batch). This way the last batch of predictions,
# which may hold fewer than batch_size images, is never indexed out of range.
sample_idx = np.random.randint(len(predict_set))
batch_idx, img_idx = divmod(sample_idx, batch_size)

noisy_image = predict_set[sample_idx, 0]
bench = benchmark[sample_idx, 0]
denoised = predictions_no_background[batch_idx][img_idx, 0]

# Display all three side by side. The noisy input uses the 1st-99th
# percentile limits (vmin, vmax) of low_snr, so a few extreme pixels do not
# compress its contrast.
# NOTE(review): assumes predict_set returns pixel values on the same scale
# as low_snr — TODO confirm.
fig, ax = plt.subplots(1, 3, figsize=[15, 15])
ax[0].imshow(noisy_image, cmap='inferno', vmin=vmin, vmax=vmax)
ax[1].imshow(bench, cmap='inferno')
ax[2].imshow(denoised, cmap='inferno')
for axis, title in zip(ax, ['Input', 'Benchmark', 'Denoised']):
    axis.axis('off')
    axis.set_title(title)

# plt.show() rather than fig.show(): Figure.show() can emit a
# "non-GUI backend" UserWarning under the inline backend.
plt.show()