In [None]:
# Modules for deep learning and other operations
import os
from math import floor, ceil
from skimage.metrics import peak_signal_noise_ratio as PSNR
import numpy as np
import torch
from torchvision import transforms
import pytorch_lightning as pl
# Modules with networks
import sys
sys.path.append('../../')
from noise_models.cnn import PixelCNN
from decoder.net import SDecoder
from LVAE.models.lvae import LadderVAE
from utils.datasets import TrainDatasetUnsupervised, PredictDatasetVAE
# Modules for loading and viewing images
import tifffile
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load the noisy and clean images
# These will be lists of PyTorch tensors with dimensions: [Channels, Height, Width]
def load_tiff_stack(directory):
    """Read every '.tif' file in `directory` into a list of PyTorch tensors.

    File names are sorted before loading: os.listdir() returns entries in
    arbitrary order, and the noisy and clean images are later paired by list
    index, so both directories must be read in the same deterministic order.
    """
    filenames = sorted(file for file in os.listdir(directory) if file.endswith('.tif'))
    return [torch.from_numpy(tifffile.imread(os.path.join(directory, file))) for file in filenames]


low_snr = load_tiff_stack('../../data/stripes/low_snr')
high_snr = load_tiff_stack('../../data/stripes/high_snr')

In [None]:
# Global statistics over every pixel of every noisy image, used to normalise the data.
# First pass: total pixel count and pixel sum give the mean.
cumulative_size = sum(image.numel() for image in low_snr)
cumulative_sum = sum(image.sum() for image in low_snr)
low_snr_mean = cumulative_sum / cumulative_size

# Second pass: summed squared deviation from the mean gives the (population) standard deviation.
cumulative_deviation = sum(((image - low_snr_mean) ** 2).sum() for image in low_snr)
low_snr_std = (cumulative_deviation / cumulative_size) ** 0.5

In [None]:
# Sanity check: show a randomly chosen noisy image (channel 0) next to the
# clean image stored at the same list index
idx = np.random.randint(len(low_snr))
fig, ax = plt.subplots(1, 2)
ax[0].imshow(low_snr[idx][0])
ax[0].set_title('Low SNR')
ax[1].imshow(high_snr[idx][0])
ax[1].set_title('High SNR')
# plt.show() instead of fig.show(): with the non-GUI inline backend, fig.show()
# can emit a "cannot show the figure" warning
plt.show()

In [None]:
# Create training, validation and prediction dataloaders
# Although this is an unsupervised method, it is being compared to a supervised method, so should have the same prediction set.
predict_set = low_snr[-20:]
predict_set_gt = high_snr[-20:]
# Keep only the top-left 256x256 region of each prediction image and its ground truth
predict_set = [image[:, :256, :256] for image in predict_set]
predict_set_gt = [image[:, :256, :256] for image in predict_set_gt]

# Set batch size, size of the random crops and number of times to iterate over the dataset in an epoch
batch_size = 4
crop_size = 256
# Images are [Channels, Height, Width], so the spatial extent is given by the last
# two dimensions (shape[0] is the channel count and must not be used here)
n_iters = max(low_snr[0].shape[-2:]) // crop_size
transform = transforms.RandomCrop(crop_size)

train_val_set = TrainDatasetUnsupervised(low_snr, n_iters=n_iters, transform=transform)
# 80/20 split; the validation size is the remainder so the two lengths always sum to the dataset length
n_train = floor(0.8 * len(train_val_set))
n_val = len(train_val_set) - n_train
train_set, val_set = torch.utils.data.random_split(train_val_set, (n_train, n_val))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

n_samples = 10 # The number of times the same image should be evaluated on
predict_set = PredictDatasetVAE(predict_set, n_samples=n_samples)
# batch_size equals n_samples and shuffling is off, so each batch is intended to hold
# all samples of a single image (assumes PredictDatasetVAE yields them consecutively — TODO confirm)
predict_loader = torch.utils.data.DataLoader(predict_set, batch_size=n_samples, shuffle=False)

In [None]:
# Number of features in the clean signal code. The noise model and the signal
# decoder are both constructed with this value.
s_code_features = 32

# Kernel mask for the noise model: a 16x16 grid of zeros where the bottom row and
# the right-hand column are switched on, except for the bottom-right corner itself.
# This admits 15 horizontally adjacent and 15 vertically adjacent pixels into the receptive field.
kernel_mask = torch.zeros((16, 16))
kernel_mask[-1, :-1] = 1
kernel_mask[:-1, -1] = 1

# Noise model
noise_model = PixelCNN(
    colour_channels=1,
    code_features=s_code_features,
    kernel_mask=kernel_mask,
    n_filters=32,
    n_layers=4,
    n_gaussians=5,
)

# Signal decoder
s_decoder = SDecoder(
    colour_channels=1,
    code_features=s_code_features,
)

# The VAE receives the noise model and the signal decoder as arguments,
# together with the normalisation statistics of the noisy data
vae = LadderVAE(
    colour_channels=1,
    data_mean=low_snr_mean,
    data_std=low_snr_std,
    img_shape=(crop_size, crop_size),
    noise_model=noise_model,
    s_decoder=s_decoder,
    n_filters=s_code_features,
)

In [None]:
# Where training logs and the trained parameters will be saved
checkpoint_path = '../../checkpoints/stripes/DVLAE'

# Create the pytorch lightning trainer
trainer = pl.Trainer(default_root_dir=checkpoint_path,
                     accelerator='gpu',
                     devices=1,
                     max_epochs=500,
                     # Log every len(train_loader) steps, i.e. the number of batches in one
                     # pass over train_loader (including a final partial batch).
                     # max(1, ...) keeps the interval positive for a very small training set.
                     log_every_n_steps=max(1, len(train_loader)))

In [None]:
# Fit the VAE on the training/validation loaders, then write the final
# parameters into the checkpoint directory
final_checkpoint = os.path.join(checkpoint_path, 'final_params.ckpt')
trainer.fit(vae, train_loader, val_loader)
trainer.save_checkpoint(final_checkpoint)

In [None]:
# Restore the trained model for inference.
# Hyperparameters overridden at load time:
#   - noise_model=None: the noise model is no longer necessary.
#   - mode_pred=True: the loss does not need to be calculated.
#   - img_shape: the same as during training here, but it would need to change
#     if evaluating on images of a different size to the training ones.
trained_params = os.path.join(checkpoint_path, 'final_params.ckpt')
vae = LadderVAE.load_from_checkpoint(
    trained_params,
    noise_model=None,
    s_decoder=s_decoder,
    mode_pred=True,
    img_shape=(256, 256),
    strict=False,
)
vae = vae.eval()

# Run the model over the test set.
# "predictions" is a list with one entry per noisy image; each entry is a tensor of
# shape [n_samples, colours, height, width] holding n_samples signal estimates.
predictions = trainer.predict(vae, predict_loader)

# The MMSE estimate for an image is the mean over its samples (the batch dimension)
MMSEs = []
for image_samples in predictions:
    MMSEs.append(image_samples.mean(dim=0))

In [None]:
# Select a random image, its corresponding clean estimates and its ground truth
idx = np.random.randint(len(predictions))

# The prediction dataset is indexed at idx * n_samples to fetch the noisy input
# for image idx; [0] selects the first channel
noisy_image = predict_set[idx * n_samples][0]
gt = predict_set_gt[idx][0]
MMSE = MMSEs[idx][0]
samples = predictions[idx]

# Select a random sample
sample_idx = np.random.randint(n_samples)
sample = samples[sample_idx, 0]

# Display all four panels: input, ground truth, MMSE estimate and a single sample
panels = [
    (noisy_image, 'Input'),
    (gt, 'Ground truth'),
    (MMSE, 'MMSE'),
    (sample, 'Sample'),
]
fig, ax = plt.subplots(2, 2, figsize=[10, 10])
# ax.flat walks the 2x2 grid row by row, matching the order of `panels`
for axis, (image, title) in zip(ax.flat, panels):
    axis.imshow(image, cmap='Greys_r')
    axis.axis('off')
    axis.set_title(title)

# plt.show() instead of fig.show(): with the non-GUI inline backend, fig.show()
# can emit a "cannot show the figure" warning
plt.show()

In [None]:
# Score every MMSE estimate against its ground truth with PSNR, then report the average
psnrs = []
for i, gt_tensor in enumerate(predict_set_gt):
    gt = gt_tensor[0].numpy()
    denoised = MMSEs[i][0].numpy()

    # The dynamic range of the ground truth is used as the PSNR data range
    data_range = (gt.max() - gt.min()).item()

    psnrs.append(PSNR(gt, denoised, data_range=data_range))

print(f"PSNR: {np.mean(psnrs)}")