In [1]:
import os
import numpy as np
from functools import partial
import math
from tqdm import tqdm
import time as time

import torch
# Pick the compute device.
#   M1 = True  -> Apple-silicon Metal (MPS) backend, falling back to CPU.
#   M1 = False -> expose only physical GPU 2 and use CUDA if it is available,
#                 printing a short summary of the visible GPU.
M1 = False

if not M1:
    # NOTE(review): CUDA_VISIBLE_DEVICES presumably only takes effect if CUDA
    # has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if has_cuda else torch.device('cpu')
    if has_cuda:
        print(has_cuda)
        print(torch.cuda.device_count())
        gpu_idx = torch.cuda.current_device()
        print(gpu_idx)
        print(torch.cuda.get_device_name(gpu_idx))
else:
    mps_ok = torch.backends.mps.is_available()
    device = torch.device('mps') if mps_ok else torch.device('cpu')


from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchmetrics.functional import structural_similarity_index_measure 
from torchmetrics.functional import peak_signal_noise_ratio 

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import scipy.io as sio
from astropy.io import fits
import skimage as ski

import large_scale_UQ as luq
from large_scale_UQ.utils import to_numpy, to_tensor
from convex_reg import utils as utils_cvx_reg


True
1
0
NVIDIA A100-PCIE-40GB
Using device: cuda


In [2]:
# Optimisation options for the MAP estimation
options = {"tol": 1e-5, "iter": 15000, "update_iter": 4999, "record_iters": False}
# Save param
repo_dir = '/disk/xray0/tl3/repos/large-scale-UQ'
base_savedir = '/disk/xray99/tl3/proj-convex-UQ/outputs/new_UQ_results/wavelets'
save_dir = base_savedir + '/vars/'
savefig_dir = base_savedir + '/figs/'

# Define my torch types (CRR requires torch.float32, wavelets require torch.float64)
myType = torch.float64
myComplexType = torch.complex128

# Wavelet parameters
reg_params = [5e2] # [5e2, 5e1, 1e3, 5e3, 1e4, 5e4]
wavs_list = ['db8']
levels = 4

# LCI params
alpha_prob = 0.01

# LCI algorithm parameters (bisection)
LCI_iters = 200
LCI_tol = 1e-4
LCI_bottom = -10
LCI_top = 10

# Compute the MAP-based UQ plots
superpix_MAP_sizes = [16, 8]  # [32, 16, 8, 4]
# Clipping values for MAP-based LCI. Set as None for no clipping
clip_high_val = 1.
clip_low_val = 0.

# Compute the sampling UQ plots
superpix_sizes = [32,16,8,4,1]

# Sampling alg params
frac_delta = 0.98
frac_burnin = 0.1
n_samples = np.int64(5e4)
thinning = np.int64(1e1)
maxit = np.int64(n_samples * thinning * (1. + frac_burnin))
# SKROCK params
nStages = 10
eta = 0.05
dt_perc = 0.99

# Plot parameters
cmap = 'cubehelix'
nLags = 100


# Img name list
img_name_list = ['W28']  # ['M31', 'W28', 'CYN', '3c288']
# Input noise level
input_snr = 30.


save_fig_vals = False



In [3]:
def _imshow_with_colorbar(fig, ax, image, **imshow_kwargs):
    """Draw `image` on `ax` with a vertical colorbar attached on its right.

    Extra keyword arguments are forwarded to `ax.imshow`. Returns the image artist.
    """
    im = ax.imshow(image, **imshow_kwargs)
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')
    return im


# Main experiment: for each image, simulate noisy masked-Fourier measurements,
# compute the wavelet-regularised MAP estimate with forward-backward, then
# compute MAP-based local credible intervals (LCI) over superpixels.
for img_name in img_name_list:

    # NOTE(review): never filled in this notebook; kept for the commented-out report.
    optim_iters = []
    # Cumulative bisection iterations of each LCI computation for this image,
    # appended per superpixel size (inner loop) for each reg_param (outer loop).
    lci_uq_iters_arr = []

    # %%
    # Load image and mask
    img, mat_mask = luq.helpers.load_imgs(img_name, repo_dir)

    # Aliases
    x = img
    ground_truth = img
    # Dynamic range of the ground truth, used by the PSNR / SSIM metrics
    data_range = ground_truth.max() - ground_truth.min()

    # Image as a (1, 1, H, W) tensor on the compute device
    torch_img = torch.tensor(
        np.copy(img), dtype=myType, device=device).reshape((1,1) + img.shape
    )

    phi = luq.operators.MaskedFourier_torch(
        shape=img.shape,
        ratio=0.5 ,
        mask=mat_mask,
        norm='ortho',
        device=device
    )

    # Noise-free measurements as a numpy array
    y = phi.dir_op(torch_img).detach().cpu().squeeze().numpy()

    # Define X Cai noise level
    eff_sigma = luq.helpers.compute_complex_sigma_noise(y, input_snr)
    sigma = eff_sigma * np.sqrt(2)

    # Generate noise (fixed seed for reproducibility). Only the non-zero entries
    # of y receive noise (presumably the locations sampled by the mask).
    rng = np.random.default_rng(seed=0)
    n_re = rng.normal(0, eff_sigma, y[y!=0].shape)
    n_im = rng.normal(0, eff_sigma, y[y!=0].shape)
    # Add noise
    y[y!=0] += (n_re + 1.j*n_im)

    # Observation
    torch_y = torch.tensor(np.copy(y), device=device, dtype=myComplexType).reshape((1,) + img.shape)
    # Dirty image |Phi^* y|, used as the optimiser's starting point
    x_init = torch.abs(phi.adj_op(torch_y))


    # %%
    # Define the likelihood
    g = luq.operators.L2Norm_torch(
        sigma=sigma,
        data=torch_y,
        Phi=phi,
    )
    # Lipschitz constant computed automatically by g, stored in g.beta

    # Define real prox
    f = luq.operators.RealProx_torch()

    # %%

    for reg_param in reg_params:

        # Define the wavelet dict
        # Define the l1 norm with dict psi
        psi = luq.operators.DictionaryWv_torch(wavs_list, levels)
        h = luq.operators.L1Norm_torch(1., psi, op_to_coeffs=True)
        h.gamma = reg_param

        # Compute stepsize
        alpha = 0.98 / g.beta

        # Effective threshold
        print('Threshold: ', h.gamma * alpha)

        # Run the optimisation
        x_hat, diagnostics = luq.optim.FB_torch(
            x_init,
            options=options,
            g=g,
            f=f,
            h=h,
            alpha=alpha,
            tau=alpha,
            viewer=None
        )

        # %%
        np_x_init = to_numpy(x_init)
        np_x = np.copy(x)
        np_x_hat = to_numpy(x_hat)


        # %%
        # Truth / dirty / MAP / residual panels, annotated with quality metrics
        images = [np_x, np_x_init, np_x_hat, np_x - np.abs(np_x_hat)]
        labels = ["Truth", "Dirty", "Reconstruction", "Residual (x - x^hat)"]
        fig, axs = plt.subplots(1,4, figsize=(24,6), dpi=200)
        for i, (ax, image) in enumerate(zip(axs, images)):
            _imshow_with_colorbar(
                fig, ax, image, cmap=cmap, vmax=np.nanmax(image), vmin=np.nanmin(image)
            )
            if i > 0:
                stats_str = '\n(PSNR: {},\n SNR: {}, SSIM: {})'.format(
                    round(psnr(ground_truth, image, data_range=data_range), 2),
                    round(luq.utils.eval_snr(x, image), 2),
                    round(ssim(ground_truth, image, data_range=data_range), 2),
                    )
                labels[i] += stats_str
                print(labels[i])
            ax.set_title(labels[i], fontsize=16)
            ax.axis('off')
        if save_fig_vals:
            plt.savefig('{:s}{:s}_SKROCK_wavelets_reg_param_{:.1e}_optim_MAP.pdf'.format(savefig_dir, img_name, reg_param))
        plt.close(fig)


        ### MAP-based UQ

        # Define prior potential
        fun_prior = lambda _x : h._fun_coeffs(h.dir_op(_x))
        # Define posterior potential
        loss_fun_torch = lambda _x : g.fun(_x) +  fun_prior(_x)

        def loss_fun_np(_x):
            """Posterior potential of a numpy image, returned as a Python float."""
            # Convert once and reuse for both terms; presumably this is called
            # at every bisection step of the LCI computation.
            _x_torch = luq.utils.to_tensor(_x, dtype=myType)
            return g.fun(_x_torch).item() + fun_prior(_x_torch).item()

        # Compute HPD region bound:
        # gamma_alpha = f(x_map) + g(x_map) + tau_alpha * sqrt(N) + N
        N = np_x_hat.size
        tau_alpha = np.sqrt(16*np.log(3/alpha_prob))
        gamma_alpha = loss_fun_torch(x_hat).item() + tau_alpha*np.sqrt(N) + N


        # Compute the LCI
        error_p_arr = []
        error_m_arr = []
        mean_img_arr = []
        computing_time = []

        # Compute ground truth block means, one per superpixel size
        gt_mean_img_arr = [
            ski.measure.block_reduce(
                np.copy(img), block_size=(superpix_size, superpix_size), func=np.mean
            )
            for superpix_size in superpix_MAP_sizes
        ]

        # Define prefix
        save_MAP_prefix = '{:s}_wavelets_UQ_MAP_reg_param_{:.1e}'.format(img_name, reg_param)

        for it_pixs, superpix_size in enumerate(superpix_MAP_sizes):

            pr_time_1 = time.process_time()
            wall_time_1 = time.time()

            error_p, error_m, mean, lci_iters_cumul = luq.map_uncertainty.create_local_credible_interval(
                x_sol=np_x_hat,
                region_size=superpix_size,
                function=loss_fun_np,
                bound=gamma_alpha,
                iters=LCI_iters,
                tol=LCI_tol,
                bottom=LCI_bottom,
                top=LCI_top,
                return_iters=True,
            )
            pr_time_2 = time.process_time()
            wall_time_2 = time.time()

            # Save iteration number
            lci_uq_iters_arr.append(lci_iters_cumul)

            # Add values to array to save it later
            error_p_arr.append(np.copy(error_p))
            error_m_arr.append(np.copy(error_m))
            mean_img_arr.append(np.copy(mean))
            computing_time.append((
                pr_time_2 - pr_time_1,
                wall_time_2 - wall_time_1
            ))
            # Clip plot values
            error_length = luq.utils.clip_matrix(
                np.copy(error_p), clip_low_val, clip_high_val
            ) - luq.utils.clip_matrix(
                np.copy(error_m), clip_low_val, clip_high_val
            )
            # Recover the ground truth mean
            gt_mean = gt_mean_img_arr[it_pixs]

            # Shared colour scale for the MAP, LCI and shifted-LCI panels
            vmin = np.min((gt_mean, mean, error_length))
            vmax = np.max((gt_mean, mean, error_length))

            # Plot UQ
            fig, axs = plt.subplots(1, 4, figsize=(24,5))
            ax_map, ax_res, ax_lci, ax_lci_shift = axs

            ax_map.set_title('MAP estimation,\n superpix = {:d}'.format(superpix_size))
            _imshow_with_colorbar(fig, ax_map, mean, cmap=cmap, vmin=vmin, vmax=vmax)

            # Root-mean-square error between ground-truth and MAP superpixel means
            ax_res.set_title(
                'Residual (GT - MAP),\n RMSE = {:.3e}'.format(
                    np.sqrt(np.mean((gt_mean - mean)**2))
                )
            )
            _imshow_with_colorbar(fig, ax_res, gt_mean - mean, cmap=cmap)

            ax_lci.set_title('LCI (max={:.5f})\n (<LCI>={:.5f})'.format(
                    np.max(error_length), np.mean(error_length)
                )
            )
            _imshow_with_colorbar(fig, ax_lci, error_length, cmap=cmap, vmin=vmin, vmax=vmax)

            ax_lci_shift.set_title('LCI - min(LCI)')
            _imshow_with_colorbar(
                fig, ax_lci_shift, error_length - np.min(error_length),
                cmap=cmap, vmin=vmin, vmax=vmax
            )

            for ax in axs:
                ax.set_yticks([])
                ax.set_xticks([])

            if save_fig_vals:
                plt.savefig(
                    savefig_dir+save_MAP_prefix+'_UQ-MAP_pixel_size_{:d}.pdf'.format(superpix_size)
                )
            plt.close(fig)


        # Likelihood and prior values at the MAP estimate
        f_xmap = g.fun(x_hat).item()
        g_xmap = fun_prior(x_hat).item()
        print(
            'f(x_map): ', f_xmap,
            '\ng(x_map): ', g_xmap,
            '\ntau_alpha*np.sqrt(N): ', tau_alpha*np.sqrt(N),
            '\nN: ', N,
        )
        print('tau_alpha: ', tau_alpha)
        print('gamma_alpha: ', float(gamma_alpha))
        #
        opt_params = {
            'wav': wavs_list,
            'levels': levels,
            'reg_param': reg_param,
            'sigma_noise': sigma,
            'opt_tol': options['tol'],
            'opt_max_iter': options['iter'],
        }
        hpd_results = {
            'alpha': alpha_prob,
            'gamma_alpha': gamma_alpha,
            'f_xmap': f_xmap,
            'g_xmap': g_xmap,
            'h_alpha_N': tau_alpha*np.sqrt(N) + N,
        }
        LCI_params ={
            'iters': LCI_iters,
            'tol': LCI_tol,
            'bottom': LCI_bottom,
            'top': LCI_top,
            'clip_low_val': clip_low_val,
            'clip_high_val': clip_high_val,
        }
        save_map_vars = {
            'x_ground_truth': img,
            'x_map': np_x_hat,
            'x_init': np_x_init,
            'opt_params': opt_params,
            'hpd_results': hpd_results,
            'error_p_arr': error_p_arr,
            'error_m_arr': error_m_arr,
            'mean_img_arr': mean_img_arr,
            'gt_mean_img_arr': gt_mean_img_arr,
            'computing_time': computing_time,
            'superpix_sizes': superpix_MAP_sizes,
            'LCI_params': LCI_params,
        }
        # Optionally persist the results; np.save overwrites an existing file.
        # Saving is best-effort: failures are reported but do not stop the run.
        if save_fig_vals:
            saving_map_path = save_dir + save_MAP_prefix + '_MAP_vars.npy'
            try:
                np.save(saving_map_path, save_map_vars, allow_pickle=True)
            except Exception as e:
                print('Could not save variables. Exception caught: ', e)




Threshold:  0.002599902946466276
Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.18e-01
[Forward Backward] converged in 207 iterations
Dirty
(PSNR: 32.25,
 SNR: 3.39, SSIM: 0.58)
Reconstruction
(PSNR: 49.35,
 SNR: 20.49, SSIM: 0.99)
Residual (x - x^hat)
(PSNR: 29.07,
 SNR: 0.21, SSIM: 0.93)
Calculating credible interval for superpxiel:  (256, 256)
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
Calculating credible interval for superpxiel:  (256, 256)
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this ra

In [4]:


# Report the cumulative bisection iterations used by each LCI computation of the
# last processed image. Entries of lci_uq_iters_arr were appended once per
# superpixel size (inner loop) for each reg_param (outer loop), so the entry
# index determines the (reg_param, superpixel size) pair it belongs to.
print('Iteration number for {:s}'.format(img_name))

n_sp_sizes = len(superpix_MAP_sizes)
for k_entry, n_lci_iters in enumerate(lci_uq_iters_arr):
    sp_size = superpix_MAP_sizes[k_entry % n_sp_sizes]
    rp_value = reg_params[k_entry // n_sp_sizes]
    print(
        'LCI iterations {0:d}x{0:d} super pixel (reg_param = {1:.1e}): '.format(sp_size, rp_value),
        n_lci_iters,
    )



Iteration number for W28
LCI iterations 16x16 super pixe:  21202
LCI iterations 8x8 super pixe:  81576
