In [1]:
import os
import numpy as np
from functools import partial
import math
from tqdm import tqdm
import time as time

import torch

M1 = False  # set True on Apple-silicon machines to use the MPS backend

if not M1:
    # Expose only physical GPU 2 to this process (it is then addressed as
    # "cuda:0"). This only takes effect if CUDA has not been initialised yet
    # in the running kernel.
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if device.type == "cuda":
        # availability flag, visible-device count, active index and its name,
        # one per line
        print(
            torch.cuda.is_available(),
            torch.cuda.device_count(),
            torch.cuda.current_device(),
            torch.cuda.get_device_name(torch.cuda.current_device()),
            sep="\n",
        )
else:
    # Metal (MPS) backend when present, CPU otherwise
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import skimage as ski

import large_scale_UQ as luq
from large_scale_UQ.utils import to_numpy, to_tensor
from convex_reg import utils as utils_cvx_reg

True
1
0
NVIDIA A100-PCIE-40GB
Using device: cuda


In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# MAP optimisation: stopping tolerance, iteration cap, logging period
options = dict(tol=1e-5, iter=15000, update_iter=4999, record_iters=False)

# Input / output locations
repo_dir = "./../../.."
base_savedir = "/disk/xray99/tl3/proj-convex-UQ/outputs/new_UQ_results/CRR"
save_dir = f"{base_savedir}/vars/"
savefig_dir = f"{base_savedir}/figs/"

# Torch dtypes (CRR requires torch.float32)
myType, myComplexType = torch.float32, torch.complex64

# Which trained CRR model to load (selects the "Sigma_<s>_t_<t>/" sub-directory)
sigma_training, t_model = 5, 5
CRR_dir_name = "./../../../trained_models/"
# CRR regularisation strength(s) and the scaling applied to the model input
reg_params = [5e4]
mu = 20

# Credible-region level used for the HPD bound
alpha_prob = 0.01

# Bisection search used for the local credible intervals (LCI)
LCI_iters, LCI_tol = 200, 1e-4
LCI_bottom, LCI_top = -10, 10

# Superpixel sizes for the MAP-based UQ plots (alternative: [32, 16, 8, 4])
superpix_MAP_sizes = [16, 8]
# Clipping values for MAP-based LCI. Set as None for no clipping
clip_low_val, clip_high_val = 0.0, 1.0

# Superpixel sizes for the sampling-based UQ plots
superpix_sizes = [32, 16, 8, 4, 1]

# Sampling algorithm settings
frac_delta = 0.98
frac_burnin = 0.1
n_samples = np.int64(50_000)
thinning = np.int64(10)
maxit = np.int64(n_samples * thinning * (1.0 + frac_burnin))
# SKROCK settings
nStages, eta, dt_perc = 10, 0.05, 0.99

# Plotting
cmap = "cubehelix"
nLags = 100

# Images to process (alternatives: 'M31', 'W28', 'CYN', '3c288')
img_name_list = ["W28"]
# Input SNR used to set the noise level
input_snr = 30.0

save_fig_vals = False

In [3]:
def _imshow_panel(fig, ax, image, title, cmap, **imshow_kwargs):
    """Draw ``image`` on ``ax`` with a title, a side colorbar and no ticks.

    Used for each of the four panels of the MAP-based UQ figure.

    Args:
        fig: Figure that owns ``ax``; it receives the colorbar.
        ax: Axes to draw on.
        image: Array passed to ``ax.imshow``.
        title: Panel title.
        cmap: Matplotlib colormap name.
        **imshow_kwargs: Extra arguments for ``ax.imshow`` (e.g. ``vmin``/``vmax``).
    """
    im = ax.imshow(image, cmap=cmap, **imshow_kwargs)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax, orientation="vertical")
    ax.set_title(title)
    ax.set_yticks([])
    ax.set_xticks([])


for img_name in img_name_list:
    # Per-image bookkeeping; the summary cell after this loop reads these.
    optim_iters = []
    lci_uq_iters_arr = []

    # %%
    # Load image and mask
    img, mat_mask = luq.helpers.load_imgs(img_name, repo_dir)

    # Aliases
    x = img
    ground_truth = img

    # Ground truth as a (1, 1, H, W) tensor on the compute device
    torch_img = torch.tensor(np.copy(img), dtype=myType, device=device).reshape(
        (1, 1) + img.shape
    )

    # Masked Fourier measurement operator
    phi = luq.operators.MaskedFourier_torch(
        shape=img.shape, ratio=0.5, mask=mat_mask, norm="ortho", device=device
    )

    # Noiseless measurements as a numpy array
    y = phi.dir_op(torch_img).detach().cpu().squeeze().numpy()

    # Define X Cai noise level. `eff_sigma` is the std used for each of the
    # real and imaginary noise components below.
    eff_sigma = luq.helpers.compute_complex_sigma_noise(y, input_snr)
    sigma = eff_sigma * np.sqrt(2)

    # Seeded complex Gaussian noise, added to the non-zero entries of y only
    # (presumably the entries kept by the mask).
    rng = np.random.default_rng(seed=0)
    n_re = rng.normal(0, eff_sigma, y[y != 0].shape)
    n_im = rng.normal(0, eff_sigma, y[y != 0].shape)
    y[y != 0] += n_re + 1.0j * n_im

    # Observation tensor and the dirty (back-projected) initialisation
    torch_y = torch.tensor(np.copy(y), device=device, dtype=myComplexType).reshape(
        (1,) + img.shape
    )
    x_init = torch.abs(phi.adj_op(torch_y))

    # %%
    # Likelihood; its Lipschitz constant is computed automatically and stored
    # in likelihood.beta
    likelihood = luq.operators.L2Norm_torch(
        sigma=sigma,
        data=torch_y,
        Phi=phi,
    )

    # Projection onto real-valued images
    cvx_set_prox_op = luq.operators.RealProx_torch()

    # %%
    # Load CRR model (autograd is switched off globally from here on)
    torch.set_grad_enabled(False)
    torch.set_num_threads(4)

    exp_name = f"Sigma_{sigma_training}_t_{t_model}/"
    # NOTE(review): the model device is hard-coded to "cuda:0" rather than
    # derived from `device`, so this presumably fails on CPU/MPS runs — confirm
    # which device strings load_model accepts before switching it to `device`.
    model = utils_cvx_reg.load_model(
        CRR_dir_name + exp_name, "cuda:0", device_type="gpu"
    )

    print(f"Numbers of parameters before prunning: {model.num_params}")
    model.prune()
    print(f"Numbers of parameters after prunning: {model.num_params}")

    # [not required] initialise the eigenvector of dimension (size, size)
    # associated with the largest eigenvalue
    model.initializeEigen(size=100)
    # Bound via a power iteration that couples the activations and the
    # convolutions; the bound is stored in the model
    model.precise_lipschitz_bound(n_iter=100)
    L_CRR = model.L.data.item()
    print(f"Lipschitz bound {L_CRR:.3f}")

    # %
    for it_1 in range(len(reg_params)):
        # Prior parameters
        lmbd = reg_params[it_1]

        # Step size: 0.98 / (Lipschitz constant of the likelihood gradient plus
        # mu * lmbd * CRR Lipschitz bound)
        alpha = 0.98 / (likelihood.beta + mu * lmbd * L_CRR)

        # FISTA-style accelerated gradient descent started at the dirty image
        x_hat = torch.clone(x_init)
        z = torch.clone(x_init)
        t = 1

        for it_2 in range(options["iter"]):
            x_hat_old = torch.clone(x_hat)
            # Gradient step. Presumably model(u) returns the gradient of the CRR
            # potential at u, so lmbd * model(mu * z) is the gradient of the
            # (lmbd / mu) * model.cost(mu * z) prior used further below.
            x_hat = z - alpha * (likelihood.grad(z) + lmbd * model(mu * z))
            # Reality constraint
            x_hat = cvx_set_prox_op.prox(x_hat)

            # Momentum update
            t_old = t
            t = 0.5 * (1 + math.sqrt(1 + 4 * t**2))
            z = x_hat + (t_old - 1) / t * (x_hat - x_hat_old)

            # relative change of norm for terminating
            res = (torch.norm(x_hat_old - x_hat) / torch.norm(x_hat_old)).item()

            if res < options["tol"]:
                print("[GD] converged in %d iterations" % (it_2))
                break

            if it_2 % options["update_iter"] == 0:
                print(
                    "[GD] %d out of %d iterations, tol = %f"
                    % (it_2, options["iter"], res)
                )

        # NOTE: `it_2` is the zero-based index of the last iteration, so
        # it_2 + 1 gradient steps were taken; recorded as-is so the counts stay
        # comparable with previously reported numbers.
        optim_iters.append(it_2)

        # %%
        np_x_init = to_numpy(x_init)
        np_x = np.copy(x)
        np_x_hat = to_numpy(x_hat)

        images = [np_x, np_x_init, np_x_hat, np_x - np.abs(np_x_hat)]

        # %%
        labels = ["Truth", "Dirty", "Reconstruction", "Residual (x - x^hat)"]
        fig, axs = plt.subplots(1, 4, figsize=(20, 8), dpi=200)
        for i in range(4):
            im = axs[i].imshow(
                images[i],
                cmap=cmap,
                vmax=np.nanmax(images[i]),
                vmin=np.nanmin(images[i]),
            )
            divider = make_axes_locatable(axs[i])
            cax = divider.append_axes("right", size="5%", pad=0.05)
            fig.colorbar(im, cax=cax, orientation="vertical")
            reg_cost = model.cost(to_tensor(mu * images[i], device=device))[0].item()
            if i == 0:
                stats_str = "\nRegCost {:.3f}".format(reg_cost)
            else:
                stats_str = "\n(PSNR: {:.2f}, SNR: {:.2f},\nSSIM: {:.2f}, RegCost: {:.3f})".format(
                    psnr(np_x, images[i], data_range=np_x.max() - np_x.min()),
                    luq.utils.eval_snr(x, images[i]),
                    ssim(np_x, images[i], data_range=np_x.max() - np_x.min()),
                    reg_cost,
                )
            labels[i] += stats_str
            axs[i].set_title(labels[i], fontsize=16)
            axs[i].axis("off")
        if save_fig_vals:
            fig.savefig(
                "{:s}{:s}_lmbd_{:.1e}_optim_MAP.pdf".format(savefig_dir, img_name, lmbd)
            )
        plt.close(fig)

        ### MAP-based UQ

        def _fun(_x, likelihood, model, mu, lmbd):
            """Full potential: (lmbd / mu) * model.cost(mu * _x) + likelihood.fun(_x)."""
            return (lmbd / mu) * model.cost(mu * _x) + likelihood.fun(_x)

        def _grad_fun(_x, likelihood, model, mu, lmbd):
            """Real part of likelihood.grad(_x) + lmbd * model(mu * _x).

            Presumably the gradient of `_fun`; kept for ULA-type samplers.
            """
            return torch.real(likelihood.grad(_x) + lmbd * model(mu * _x))

        def _prior_fun(_x, model, mu, lmbd):
            """Prior potential: (lmbd / mu) * model.cost(mu * _x)."""
            return (lmbd / mu) * model.cost(mu * _x)

        # Potentials and gradient with every parameter bound explicitly
        fun = partial(_fun, likelihood=likelihood, model=model, mu=mu, lmbd=lmbd)
        prior_fun = partial(_prior_fun, model=model, mu=mu, lmbd=lmbd)
        grad_f = partial(
            _grad_fun, likelihood=likelihood, model=model, mu=mu, lmbd=lmbd
        )

        def fun_np(_x):
            """Evaluate `fun` on a numpy image; the tensor is placed on `device`."""
            return fun(luq.utils.to_tensor(_x, device=device, dtype=myType)).item()

        # Approximate HPD-region threshold:
        #   gamma_alpha = fun(x_map) + tau_alpha * sqrt(N) + N,
        #   tau_alpha = sqrt(16 * log(3 / alpha_prob))
        N = np_x_hat.size
        tau_alpha = np.sqrt(16 * np.log(3 / alpha_prob))
        gamma_alpha = fun(x_hat).item() + tau_alpha * np.sqrt(N) + N

        error_p_arr = []
        error_m_arr = []
        mean_img_arr = []
        computing_time = []

        x_init_np = luq.utils.to_numpy(x_init)

        # Ground-truth superpixel means, one per superpixel size
        gt_mean_img_arr = [
            ski.measure.block_reduce(
                np.copy(img), block_size=(superpix_size, superpix_size), func=np.mean
            )
            for superpix_size in superpix_MAP_sizes
        ]

        # Define prefix
        save_MAP_prefix = "{:s}_CRR_UQ_MAP_lmbd_{:.1e}".format(img_name, lmbd)

        for it_pixs, superpix_size in enumerate(superpix_MAP_sizes):
            pr_time_1 = time.process_time()
            wall_time_1 = time.time()

            (
                error_p,
                error_m,
                mean,
                lci_iters_cumul,
            ) = luq.map_uncertainty.create_local_credible_interval(
                x_sol=np_x_hat,
                region_size=superpix_size,
                function=fun_np,
                bound=gamma_alpha,
                iters=LCI_iters,
                tol=LCI_tol,
                bottom=LCI_bottom,
                top=LCI_top,
                return_iters=True,
            )
            pr_time_2 = time.process_time()
            wall_time_2 = time.time()

            # Save iteration number
            lci_uq_iters_arr.append(lci_iters_cumul)

            # Keep copies of the results; CPU and wall-clock durations
            error_p_arr.append(np.copy(error_p))
            error_m_arr.append(np.copy(error_m))
            mean_img_arr.append(np.copy(mean))
            computing_time.append((pr_time_2 - pr_time_1, wall_time_2 - wall_time_1))

            # LCI length computed from the clipped upper and lower bounds
            error_length = luq.utils.clip_matrix(
                np.copy(error_p), clip_low_val, clip_high_val
            ) - luq.utils.clip_matrix(np.copy(error_m), clip_low_val, clip_high_val)
            gt_mean = gt_mean_img_arr[it_pixs]

            # Shared colour range for the MAP and LCI panels
            vmin = np.min((gt_mean, mean, error_length))
            vmax = np.max((gt_mean, mean, error_length))

            residual = gt_mean - mean
            # Root-mean-square error: the mean (not the sum) of squared residuals
            rmse = np.sqrt(np.mean(residual**2))

            # Plot UQ
            fig, axs = plt.subplots(1, 4, figsize=(24, 5))
            _imshow_panel(
                fig,
                axs[0],
                mean,
                "MAP estimation,\n superpix = {:d}".format(superpix_size),
                cmap,
                vmin=vmin,
                vmax=vmax,
            )
            _imshow_panel(
                fig,
                axs[1],
                residual,
                "Residual (GT - MAP),\n RMSE = {:.3e}".format(rmse),
                cmap,
            )
            _imshow_panel(
                fig,
                axs[2],
                error_length,
                "LCI (max={:.5f})\n (<LCI>={:.5f})".format(
                    np.max(error_length), np.mean(error_length)
                ),
                cmap,
                vmin=vmin,
                vmax=vmax,
            )
            _imshow_panel(
                fig,
                axs[3],
                error_length - np.min(error_length),
                "LCI - min(LCI)",
                cmap,
                vmin=vmin,
                vmax=vmax,
            )
            if save_fig_vals:
                fig.savefig(
                    savefig_dir
                    + save_MAP_prefix
                    + "_UQ-MAP_pixel_size_{:d}.pdf".format(superpix_size)
                )
            plt.close(fig)

        # Terms of the HPD bound at the MAP, evaluated once and reused below
        f_xmap = likelihood.fun(x_hat).item()
        g_xmap = prior_fun(x_hat).item()
        print(
            "f(x_map): ",
            f_xmap,
            "\ng(x_map): ",
            g_xmap,
            "\ntau_alpha*np.sqrt(N): ",
            tau_alpha * np.sqrt(N),
            "\nN: ",
            N,
        )
        print("tau_alpha: ", tau_alpha)
        print("gamma_alpha: ", gamma_alpha.item())

        opt_params = {
            "lmbd": lmbd,
            "mu": mu,
            "sigma_training": sigma_training,
            "t_model": t_model,
            "sigma_noise": sigma,
            "eff_sigma_noise": eff_sigma,
            "opt_tol": options["tol"],
            "opt_max_iter": options["iter"],
        }
        hpd_results = {
            "alpha": alpha_prob,
            "gamma_alpha": gamma_alpha,
            "f_xmap": f_xmap,
            "g_xmap": g_xmap,
            "h_alpha_N": tau_alpha * np.sqrt(N) + N,
        }
        LCI_params = {
            "iters": LCI_iters,
            "tol": LCI_tol,
            "bottom": LCI_bottom,
            "top": LCI_top,
            "clip_low_val": clip_low_val,
            "clip_high_val": clip_high_val,
        }
        # NOTE(review): this dict is assembled but never written to disk in
        # this notebook (`save_dir` is otherwise unused) — TODO persist it under
        # save_dir or drop it.
        save_map_vars = {
            "x_ground_truth": img,
            "x_map": np_x_hat,
            "x_init": np_x_init,
            "opt_params": opt_params,
            "hpd_results": hpd_results,
            "error_p_arr": error_p_arr,
            "error_m_arr": error_m_arr,
            "mean_img_arr": mean_img_arr,
            "gt_mean_img_arr": gt_mean_img_arr,
            "computing_time": computing_time,
            "superpix_sizes": superpix_MAP_sizes,
            "LCI_params": LCI_params,
        }

--- loading checkpoint from epoch 10 ---
---------------------
Building a CRR-NN model with 
 - [1, 8, 32] channels 
 - linear_spline activation functions
  (LinearSpline(mode=conv, num_activations=32, init=zero, size=21, grid=0.010, monotonic_constraint=True.))
---------------------
Numbers of parameters before prunning: 13610
---------------------
 PRUNNING 
 Found 22 filters with non-vanishing potential functions
---------------------
Numbers of parameters after prunning: 4183
Lipschitz bound 0.770
[GD] 0 out of 15000 iterations, tol = 0.103799
[GD] converged in 541 iterations
-----------------------
Updating spline coefficients for the reg cost
 (the gradient-step model is trained and intergration is required to compute the regularization cost)
-----------------------
Calculating credible interval for superpxiel:  (256, 256)
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisection Method] There is no root in this range.
[Bisec

In [4]:
# Iteration counts for the last processed image (`img_name`). Both lists are
# reset per image inside the loop above. optim_iters[0] and the first
# len(superpix_MAP_sizes) LCI entries belong to the first value of reg_params.
# Labels are derived from the config rather than hard-coded.
print("Iteration number for {:s}".format(img_name))

print("Optimisation iterations: ", optim_iters[0])
for superpix_size, lci_iters in zip(superpix_MAP_sizes, lci_uq_iters_arr):
    print(
        "LCI iterations {:d}x{:d} super pixel: ".format(superpix_size, superpix_size),
        lci_iters,
    )

Iteration number for W28
Optimisation iterations:  541
LCI iterations 16x16 super pixe:  21188
LCI iterations 8x8 super pixe:  81540
