In [None]:
# Modules for deep learning and other operations
import os
from math import floor, ceil
from skimage.metrics import peak_signal_noise_ratio as PSNR
import numpy as np
import torch
from torchvision import transforms
import pytorch_lightning as pl
# Modules with networks
import sys
sys.path.append('../../')
from decoder.net import SDecoder
from N2V.models.lvae import N2V
from N2V.masks import RowMask, ColumnMask, CrossMask
from utils.datasets import TrainDatasetUnsupervised, PredictDatasetUNet
# Modules for loading and viewing images
import tifffile
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load the noisy (low SNR) and clean (high SNR) images.
# These will be lists of PyTorch tensors with dimensions: [Channels, Height, Width]
# Later cells pair the two lists by index (low_snr[i] with high_snr[i]). os.listdir returns
# entries in arbitrary order, so files are read in sorted filename order to make that pairing
# deterministic.
# NOTE(review): this assumes corresponding files sort identically in both directories
# (e.g. share the same names) — confirm for this dataset.


def load_tif_images(directory):
    """Read every ``.tif`` file in ``directory`` into a list of PyTorch tensors.

    Args:
        directory: Folder to scan (non-recursive).

    Returns:
        A list with one tensor per ``.tif`` file, ordered by filename.
    """
    filenames = sorted(name for name in os.listdir(directory) if name.endswith('.tif'))
    return [torch.from_numpy(tifffile.imread(os.path.join(directory, name))) for name in filenames]


low_snr = load_tif_images('../../data/stripes/low_snr')
high_snr = load_tif_images('../../data/stripes/high_snr')
# Index-based pairing is meaningless if the two folders hold different numbers of images.
assert len(low_snr) == len(high_snr), (
    f"Unpaired data: {len(low_snr)} low-SNR images vs {len(high_snr)} high-SNR images")

In [None]:
# Normalisation statistics for the noisy images: the mean and standard deviation taken over
# every pixel of every low-SNR image together. The std is the population std (it divides by
# the pixel count, not count - 1). Two passes: mean first, then deviations from it.
running_total, n_pixels = 0, 0
for image in low_snr:
    running_total += image.sum()
    n_pixels += image.numel()
low_snr_mean = running_total / n_pixels

squared_deviation = 0
for image in low_snr:
    squared_deviation += ((image - low_snr_mean) ** 2).sum()
low_snr_std = (squared_deviation / n_pixels) ** 0.5

In [None]:
# Visual sanity check: show the first channel of a random noisy image beside its clean
# counterpart to confirm the two lists are paired correctly.
idx = np.random.randint(len(low_snr))
fig, ax = plt.subplots(1, 2)
fig.suptitle(f'Image {idx}')

ax[0].imshow(low_snr[idx][0], cmap='Greys_r')
ax[0].axis('off')
ax[0].set_title('Low SNR (input)')

ax[1].imshow(high_snr[idx][0], cmap='Greys_r')
ax[1].axis('off')
ax[1].set_title('High SNR (reference)')

# plt.show() rather than fig.show(): the latter warns under the non-GUI inline backend.
plt.show()

In [None]:
# Create training, validation and prediction dataloaders.
# Although this is a self-supervised method, it is being compared to a supervised method, so it
# should have the same prediction set: the last 20 image pairs, cropped to 256 x 256.
# NOTE(review): train_val_set below is built from all of low_snr, so these 20 noisy images are
# also seen during training. That is permissible for a self-supervised method, but confirm it
# matches the protocol used for the supervised comparison.
n_predict = 20
predict_size = 256
predict_set = [image[:, :predict_size, :predict_size] for image in low_snr[-n_predict:]]
predict_set_gt = [image[:, :predict_size, :predict_size] for image in high_snr[-n_predict:]]

# Set batch size, size of the random crops and number of times to iterate over the dataset in an epoch.
batch_size = 4
crop_size = 256
# Images are [Channels, Height, Width], so the spatial size is given by the last two dimensions
# (shape[0] is the channel count, not a spatial size). Clamp to at least 1 so the dataset is
# never empty.
n_iters = max(1, max(low_snr[0].shape[-2:]) // crop_size)
transform = transforms.RandomCrop(crop_size)

# 80/20 train/validation split. Deriving the training size from the validation size guarantees
# the lengths sum to len(train_val_set), as random_split requires.
train_val_set = TrainDatasetUnsupervised(low_snr, n_iters=n_iters, transform=transform)
n_val = ceil(0.2 * len(train_val_set))
n_train = len(train_val_set) - n_val
train_set, val_set = torch.utils.data.random_split(train_val_set, (n_train, n_val))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

# predict_set is rebound here from the list of cropped tensors to the dataset wrapping them.
# The loader is not shuffled, so batch order follows predict_set order.
predict_set = PredictDatasetUNet(predict_set)
predict_loader = torch.utils.data.DataLoader(predict_set, batch_size=batch_size, shuffle=False)

In [None]:
# Number of feature channels in the clean-signal code; the signal decoder and the N2V
# network are both given this value so they agree on the code size.
s_code_features = 32

# Masking function (per the original note: a median mask applied to random pixels in a cross
# shape), configured with n_masks=10 and search_area=(5, 5). Only its bound forward method is
# handed to the network.
cross_mask = CrossMask(n_masks=10, search_area=(5, 5))
mask_func = cross_mask.forward

# Signal decoder: learns to decode the signal code into a signal estimate.
s_decoder = SDecoder(colour_channels=1, code_features=s_code_features)

# N2V UNet, normalised with the noisy-image statistics and sized to the training crops.
n2v_kwargs = dict(
    colour_channels=1,
    data_mean=low_snr_mean,
    data_std=low_snr_std,
    img_shape=(crop_size, crop_size),
    s_decoder=s_decoder,
    n_filters=s_code_features,
    mask_func=mask_func,
)
n2v = N2V(**n2v_kwargs)

In [None]:
# Where training logs and the trained parameters will be saved
checkpoint_path = '../../checkpoints/stripes/N2V'

# Log once per epoch. len(train_loader) is the number of batches per epoch, including the final
# partial batch, since the loader does not set drop_last. Floor-dividing the dataset size by the
# batch size would undercount whenever it doesn't divide evenly, and gives 0 when the training
# set is smaller than one batch; hence the clamp to at least 1.
# NOTE(review): accelerator='gpu' requires a CUDA-capable device.
trainer = pl.Trainer(default_root_dir=checkpoint_path,
                     accelerator='gpu',
                     devices=1,
                     max_epochs=500,
                     log_every_n_steps=max(1, len(train_loader)))

In [None]:
# Fit the model on the training/validation loaders, then save the final parameters to
# <checkpoint_path>/final_params.ckpt so they can be reloaded for prediction.
final_params_file = os.path.join(checkpoint_path, 'final_params.ckpt')
trainer.fit(n2v, train_loader, val_loader)
trainer.save_checkpoint(final_params_file)

In [None]:
# Reload the trained model for inference.
# Unlike training, mode_pred=True because the loss does not need to be calculated.
# img_shape is the same here, but would need to change to evaluate on images of a different
# size from the training ones.
final_params_file = os.path.join(checkpoint_path, 'final_params.ckpt')
n2v = N2V.load_from_checkpoint(
    final_params_file,
    s_decoder=s_decoder,
    img_shape=(256, 256),
    mode_pred=True,
).eval()

# Evaluate the model on the prediction set.
# `predictions` is a list with one entry per batch of predict_loader, i.e.
# ceil(number of noisy images / batch_size) entries; the last batch may be smaller.
# Each entry is a tensor of shape [batch_size, colours, height, width].
predictions = trainer.predict(n2v, predict_loader)

In [None]:
# Select a random prediction batch and a random image within it. The last batch may hold
# fewer than batch_size images, so the in-batch index is drawn from that batch's actual size.
batch_idx = np.random.randint(len(predictions))
img_idx = np.random.randint(len(predictions[batch_idx]))
# predict_loader is not shuffled, so this is the image's position in the prediction set.
sample_idx = batch_idx * batch_size + img_idx

noisy_image = predict_set[sample_idx][0]
gt = predict_set_gt[sample_idx][0]
denoised = predictions[batch_idx][img_idx, 0]

# Display input, ground truth and denoised estimate side by side.
fig, ax = plt.subplots(1, 3, figsize=[20, 20])
panels = ((noisy_image, 'Input'), (gt, 'Ground truth'), (denoised, 'Denoised'))
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image, cmap='Greys_r')
    axis.axis('off')
    axis.set_title(title)

# plt.show() rather than fig.show(): the latter warns under the non-GUI inline backend.
plt.show()

In [None]:
# Calculate PSNR of each denoised prediction against its ground truth.
# Flatten the per-batch prediction tensors into one tensor per image, in prediction-set order.
# Iterating a tensor yields its slices along dim 0, so partial final batches need no
# special handling.
predictions_unbatched = [image for batch in predictions for image in batch]
assert len(predictions_unbatched) == len(predict_set_gt), (
    f"Expected {len(predict_set_gt)} predictions, got {len(predictions_unbatched)}")

psnrs = []
for gt_image, prediction in zip(predict_set_gt, predictions_unbatched):
    gt = gt_image[0].numpy()
    denoised = prediction[0].numpy()
    # The ground truth's own dynamic range is used as the PSNR data range.
    data_range = gt.max() - gt.min()
    psnrs.append(PSNR(gt, denoised, data_range=data_range.item()))

print(f"PSNR: {np.mean(psnrs)}")