In [1]:
import os
import numpy as np
from functools import partial
import math
from tqdm import tqdm
import time as time

import torch
# Set to True to run on the Apple-silicon (MPS) backend instead of CUDA.
M1 = False

if not M1:
    # Expose only the first GPU to this process before CUDA is queried.
    os.environ["CUDA_VISIBLE_DEVICES"]="0"
    if torch.cuda.is_available():
        device = torch.device('cuda')
        # Report availability, device count, active index and device name.
        print(torch.cuda.is_available())
        print(torch.cuda.device_count())
        print(torch.cuda.current_device())
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    else:
        device = torch.device('cpu')
else:
    # Fall back to the CPU when the MPS backend is not available.
    if torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')


from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchmetrics.functional import structural_similarity_index_measure 
from torchmetrics.functional import peak_signal_noise_ratio 

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.ticker as tick

import scipy.io as sio
from astropy.io import fits
import skimage as ski

import large_scale_UQ as luq
from large_scale_UQ.utils import to_numpy, to_tensor
from convex_reg import utils as utils_cvx_reg



True
1
0
NVIDIA A100-PCIE-40GB
Using device: cuda


In [4]:

# ---------------------------------------------------------------------------
# Filesystem locations
# ---------------------------------------------------------------------------
repo_dir = '/disk/xray0/tl3/repos/large-scale-UQ'
CRR_save_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/CRR/new_pixel_UQ/'
wav_save_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/wavelets/paper_figs_new_UQ/'
load_var_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/CRR/vars/'
wav_load_var_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/wavelets/vars/'

# ---------------------------------------------------------------------------
# Plot appearance
# ---------------------------------------------------------------------------
cmap = 'cubehelix'
cbar_font_size = 18

# ---------------------------------------------------------------------------
# Images to process; `vmin_log_arr[i]` is the lower colour limit used for the
# log10 plots of `img_name_arr[i]`.
# ---------------------------------------------------------------------------
img_name_arr = ['CYN', 'M31', '3c288', 'W28']
vmin_log_arr = [-3., -2., -2., -2.]

# Per-image files holding previously computed variables (same order as
# `img_name_arr`): one list for the MAP run, one for the SKROCK sampling run.
map_vars_path_arr = [
    load_var_dir + img_tag + '_CRR_UQ_MAP_lmbd_5.0e+04_MAP_vars.npy'
    for img_tag in img_name_arr
]
samp_vars_path_arr = [
    load_var_dir + img_tag
    + '_SKROCK_CRR_lmbd_5.0e+04_mu_2.0e+01_nsamples_5.0e+04_thinning_1.0e+01_vars.npy'
    for img_tag in img_name_arr
]

# Optimiser settings: stopping tolerance, iteration cap and logging period.
options = dict(tol=1e-5, iter=15000, update_iter=4999, record_iters=False)

# ---------------------------------------------------------------------------
# CRR model
# ---------------------------------------------------------------------------
# Which trained model to load (used to build the `Sigma_<..>_t_<..>/` folder name).
sigma_training = 5
t_model = 5
CRR_dir_name = '/disk/xray0/tl3/repos/convex_ridge_regularizers/trained_models/'
# Regularisation strength and scaling parameter of the CRR prior.
reg_param = 5e4
mu = 20

# Torch dtypes used throughout (CRR requires torch.float32).
myType = torch.float32
myComplexType = torch.complex64

# ---------------------------------------------------------------------------
# Uncertainty quantification
# ---------------------------------------------------------------------------
# Probability level entering the HPD-region bound.
alpha_prob = 0.01

# Wavelet dictionary used for the UQ maps.
wavs_list = ['db8']
levels = 4
# Bisection search for the threshold: bracket, iteration budget, tolerance.
start_interval = [0, 10]
iters = 5e2
tol = 1e-2

# Suffix identifying the model in every output file name.
model_prefix = '-CRR'


In [None]:

def _save_log10_figure(image, vmin, out_path, cmap_name, label_size):
    """Plot ``log10(|image|)`` with a colour bar, save it to ``out_path`` and show it.

    Parameters
    ----------
    image : np.ndarray
        2D array to display; its absolute value is taken before the log.
    vmin : float
        Lower limit of the colour scale (in log10 units); the upper limit is 0.
    out_path : str
        Destination file passed to ``plt.savefig``.
    cmap_name : str
        Matplotlib colour-map name.
    label_size : int
        Font size of the colour-bar tick labels.
    """
    fig, axs = plt.subplots(figsize=(5,5), dpi=200)
    # Exact zeros give -inf; only the divide-by-zero RuntimeWarning is
    # suppressed here, the plotted values are unchanged.
    with np.errstate(divide='ignore'):
        im_log = np.log10(np.abs(image))
    plt_im = axs.imshow(im_log, cmap=cmap_name, vmin=vmin, vmax=0)
    divider = make_axes_locatable(axs)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(plt_im, cax=cax)
    cbar.ax.yaxis.set_major_formatter(tick.FormatStrFormatter('%.2f'))
    cbar.ax.tick_params(labelsize=label_size)
    axs.set_yticks([]);axs.set_xticks([])
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches='tight', dpi=200)
    plt.show()
    # Many figures are produced per image; release each one explicitly.
    plt.close(fig)


# Plotting helper bound to the notebook-wide colour map and font size.
save_log10_figure = partial(
    _save_log10_figure, cmap_name=cmap, label_size=cbar_font_size
)

for it in range(len(img_name_arr)):
    # Set paths
    if model_prefix == '-CRR':
        img_name = img_name_arr[it]
        vmin_log = vmin_log_arr[it]
        save_dir = CRR_save_dir
        save_var_dir = load_var_dir
    else:
        # Without this, the names above would be undefined (or stale from a
        # previous run) and the failure would surface far from its cause.
        raise ValueError(
            "Unsupported model_prefix {!r}; this cell only handles '-CRR'".format(
                model_prefix
            )
        )

    # Common prefix of every figure written for this image
    fig_prefix = '{:s}{:s}{:s}'.format(save_dir, img_name, model_prefix)

    # Load image and mask
    img, mat_mask = luq.helpers.load_imgs(img_name, repo_dir)
    # Aliases
    x = img
    ground_truth = img
    # Convert Torch
    torch_img = torch.tensor(
        np.copy(img), dtype=myType, device=device).reshape((1,1) + img.shape
    )
    # Init Fourier mask op
    phi = luq.operators.MaskedFourier_torch(
        shape=img.shape, 
        ratio=0.5 ,
        mask=mat_mask,
        norm='ortho',
        device=device
    )
    # Define X Cai noise level
    sigma = 0.0024
    y = phi.dir_op(torch_img).detach().cpu().squeeze().numpy()
    # Generate noise (fixed seed, so every image gets a reproducible draw)
    # NOTE(review): `n` is real-valued; if `y` is complex only its real part
    # is perturbed -- confirm this is the intended noise model.
    rng = np.random.default_rng(seed=0)
    n = rng.normal(0, sigma, y[y!=0].shape)
    # Add noise to the non-zero entries only
    y[y!=0] += n
    # Observation
    torch_y = torch.tensor(np.copy(y), device=device, dtype=myComplexType).reshape((1,) + img.shape)
    x_init = torch.abs(phi.adj_op(torch_y))
    # Define the likelihood
    g = luq.operators.L2Norm_torch(
        sigma=sigma,
        data=torch_y,
        Phi=phi,
    )
    # Lipschitz constant computed automatically by g, stored in g.beta
    # Define real prox
    f = luq.operators.RealProx_torch()


    # Load CRR model
    torch.set_grad_enabled(False)
    torch.set_num_threads(4)

    exp_name = f'Sigma_{sigma_training}_t_{t_model}/'
    # NOTE(review): the device string is hard-coded here rather than derived
    # from `device` -- confirm before running on a machine without CUDA.
    model = utils_cvx_reg.load_model(CRR_dir_name + exp_name, 'cuda:0', device_type='gpu')

    print(f'Numbers of parameters before prunning: {model.num_params}')
    model.prune()
    print(f'Numbers of parameters after prunning: {model.num_params}')

    # [not required] initialize the eigen vector of dimension (size, size) associated to the largest eigen value
    model.initializeEigen(size=100)
    # compute bound via a power iteration which couples the activations and the convolutions
    model.precise_lipschitz_bound(n_iter=100)
    # the bound is stored in the model
    L_CRR = model.L.data.item()
    print(f"Lipschitz bound {L_CRR:.3f}")
    

    ## Compute MAP solution
    # Prior parameters
    lmbd = reg_param

    # Compute stepsize
    alpha = 0.98 / (g.beta + mu * lmbd * L_CRR)

    # initialization
    x_hat = torch.clone(x_init)
    z = torch.clone(x_init)
    t = 1

    # Accelerated gradient descent
    for it_2 in range(options['iter']):
        x_hat_old = torch.clone(x_hat)
        
        # Gradient step on likelihood + prior, evaluated at the extrapolated point z
        x_hat = z - alpha *(
            g.grad(z) + lmbd * model(mu * z)
        )
        # Reality constraint
        x_hat = f.prox(x_hat)
        
        # Momentum update of the extrapolated point
        t_old = t 
        t = 0.5 * (1 + math.sqrt(1 + 4*t**2))
        z = x_hat + (t_old - 1)/t * (x_hat - x_hat_old)

        # relative change of norm for terminating
        res = (torch.norm(x_hat_old - x_hat)/torch.norm(x_hat_old)).item()

        if res < options['tol']:
            print("[GD] converged in %d iterations"%(it_2))
            break

        if it_2 % options['update_iter'] == 0:
            print(
                "[GD] %d out of %d iterations, tol = %f" %(            
                    it_2,
                    options['iter'],
                    res,
                )
            )


    # Save MAP
    np_x_hat = to_numpy(x_hat)
    np_x = np.copy(x)
    # Evaluate performance
    print(img_name, ' PSNR: ', psnr(np_x, np_x_hat, data_range=np_x.max()-np_x.min()))
    print(img_name, ' SNR: ', luq.utils.eval_snr(x, np_x_hat))


    # Function handle for the potential (prior cost + likelihood)
    def _fun(_x, model, mu, lmbd):
        return (lmbd / mu) * model.cost(mu * _x) + g.fun(_x)

    # Evaluation of the potential
    fun = partial(_fun, model=model, mu=mu, lmbd=lmbd)
    # Evaluation of the potential in numpy
    fun_np = lambda _x : fun(luq.utils.to_tensor(_x, dtype=myType)).item()

    # Compute HPD region bound
    N = np_x_hat.size
    tau_alpha = np.sqrt(16*np.log(3/alpha_prob))
    gamma_alpha = fun(x_hat).item() + tau_alpha*np.sqrt(N) + N

    
    # Define the wavelet dict
    # Define the l1 norm with dict psi
    Psi = luq.operators.DictionaryWv_torch(wavs_list, levels)
    oper2wavelet = luq.operators.Operation2WaveletCoeffs_torch(Psi=Psi)

    # Clone MAP estimation and cast type for wavelet operations
    torch_map = torch.clone(x_hat).to(torch.float64)
    torch_x = to_tensor(np_x).to(torch.float64)

    
    # Signed distance of the thresholded MAP's potential to the HPD bound;
    # `gamma_alpha` is read from the enclosing scope of the current iteration.
    def _potential_to_bisect(thresh, fun_np, oper2wavelet, torch_map):

        thresh_img = oper2wavelet.full_op_threshold_img(torch_map, thresh)

        return gamma_alpha - fun_np(thresh_img)

    # Evaluation of the potential
    potential_to_bisect = partial(
        _potential_to_bisect,
        fun_np=fun_np,
        oper2wavelet=oper2wavelet,
        torch_map=torch_map
    )


    selected_thresh = luq.map_uncertainty.bisection_method(
        potential_to_bisect, start_interval, iters, tol
    )
    select_thresh_img = oper2wavelet.full_op_threshold_img(
        torch_map, selected_thresh
    )
    print('selected_thresh: ', selected_thresh)
    print('gamma_alpha: ', gamma_alpha)
    print('MAP image: ', fun_np(torch_map.squeeze()))
    print('thresholded image: ', fun_np(select_thresh_img))

    # Plot MAP
    save_log10_figure(
        np_x_hat, vmin_log, fig_prefix + '-newPixelUQ-MAP.pdf'
    )

    # Plot Thresholded image
    save_log10_figure(
        to_numpy(select_thresh_img),
        vmin_log,
        fig_prefix + '-newPixelUQ-ThresholdedImage.pdf',
    )

    # Plot MAP - Thresholded error
    save_log10_figure(
        to_numpy(torch_map - select_thresh_img),
        vmin_log-2,
        fig_prefix + '-newPixelUQ-MAP_thresholded_error.pdf',
    )

    # Plot GT - MAP error
    save_log10_figure(
        np_x - np_x_hat,
        vmin_log-2,
        fig_prefix + '-newPixelUQ-GT_MAP_error.pdf',
    )


    modif_img_list = []
    GT_modif_img_list = []
    SNR_at_lvl_list = []

    # Binary operation handed to `full_op_two_img`: always returns its second argument
    op = lambda x1, x2: x2

    for modif_level in range(levels+1):

        modif_img = oper2wavelet.full_op_two_img(
            torch.clone(torch_map),
            torch.clone(select_thresh_img),
            op,
            level=modif_level
        )
        GT_modif_img = oper2wavelet.full_op_two_img(
            torch.clone(torch_x),
            torch.clone(torch_map),
            op,
            level=modif_level
        )
        # Evaluate the SNR once and reuse it for both the log line and the record
        snr_at_lvl = luq.utils.eval_snr(to_numpy(torch_map), to_numpy(modif_img))
        print('SNR at lvl {:d}: {:f}'.format(modif_level, snr_at_lvl))
        modif_img_list.append(to_numpy(modif_img))
        GT_modif_img_list.append(to_numpy(GT_modif_img))
        SNR_at_lvl_list.append(snr_at_lvl)

        lvl_suffix = '{:d}{:s}'.format(modif_level, '.pdf')

        # Plot MAP - Thresholded error at this level
        save_log10_figure(
            to_numpy(torch_map - modif_img),
            vmin_log-2,
            fig_prefix + '-newPixelUQ-MAP_thresholded_error_level_' + lvl_suffix,
        )

        # Plot GT - MAP error at this level
        save_log10_figure(
            np_x - to_numpy(GT_modif_img),
            vmin_log-2,
            fig_prefix + '-newPixelUQ-GT_MAP_error_level_' + lvl_suffix,
        )

    config_dict = {
        'sigma_training': sigma_training,
        't_model': t_model,
        'reg_param': reg_param,
        'mu': mu,
        'alpha_prob': alpha_prob,
        'wavs_list': wavs_list,
        'levels': levels,
        'start_interval': start_interval,
        'iters': iters,
        'tol': tol,
        'optim_options': options,
    }
    # NOTE: the two '*_error_at_level' entries store the per-level modified
    # images themselves (not differences); key names are kept for compatibility
    # with existing readers of these files.
    save_dict = {
        'gt': np_x,
        'map': np_x_hat,
        'thresholded_img': to_numpy(select_thresh_img),
        'map_thresh_error_at_level': np.array(modif_img_list),
        'gt_map_error_at_level': np.array(GT_modif_img_list),
        'SNR_at_level': np.array(SNR_at_lvl_list),
        'config_dict': config_dict,
    }

    # We will overwrite the dict with new results
    # Saving is best-effort: a failure is reported and the loop moves on to the next image.
    try:
        saving_var_path = '{:s}{:s}{:s}{:s}'.format(
                save_var_dir,
                img_name,
                model_prefix,
                '-new_pixel_UQ_vars.npy',
            )
        if os.path.isfile(saving_var_path):
            os.remove(saving_var_path)
        np.save(saving_var_path, save_dict, allow_pickle=True)
    except Exception as e:
        print('Could not save vairables. Exception caught: ', e)    


In [11]:

for it in range(len(img_name_arr)):
    # Resolve the per-image name and directories for the CRR model
    if model_prefix == '-CRR':
        img_name = img_name_arr[it]
        save_dir = CRR_save_dir
        save_var_dir = load_var_dir

    # File written by the previous cell for this image
    saving_var_path = save_var_dir + img_name + model_prefix + '-new_pixel_UQ_vars.npy'

    # The dict was stored through np.save, so it comes back wrapped in a
    # 0-d object array; `.item()` unwraps it.
    data = np.load(saving_var_path, allow_pickle=True).item()

    print('\n\n', img_name)
    # Reconstruction quality of the MAP against the ground truth
    print('SNR (MAP wrt GT): \t\t', luq.utils.eval_snr(data['gt'], data['map']))
    # How far the thresholded image is from the MAP
    print('SNR (thresholded wrt MAP): \t', luq.utils.eval_snr(data['map'], data['thresholded_img']))
    # Stored per-level SNR values
    for modif_level in range(levels+1):
        level_snr = data['SNR_at_level'][modif_level]
        print('SNR at lvl {:d}: \t\t\t{:.2f}'.format(modif_level, level_snr))





 CYN
SNR (MAP wrt GT): 		 28.13
SNR (thresholded wrt MAP): 	 12.28
SNR at lvl 0: 			15.80
SNR at lvl 1: 			20.07
SNR at lvl 2: 			19.34
SNR at lvl 3: 			20.39
SNR at lvl 4: 			26.14


 M31
SNR (MAP wrt GT): 		 32.82
SNR (thresholded wrt MAP): 	 22.61
SNR at lvl 0: 			34.17
SNR at lvl 1: 			29.91
SNR at lvl 2: 			26.90
SNR at lvl 3: 			27.20
SNR at lvl 4: 			38.93


 3c288
SNR (MAP wrt GT): 		 25.89
SNR (thresholded wrt MAP): 	 23.55
SNR at lvl 0: 			31.88
SNR at lvl 1: 			30.49
SNR at lvl 2: 			27.91
SNR at lvl 3: 			29.44
SNR at lvl 4: 			39.76


 W28
SNR (MAP wrt GT): 		 26.59
SNR (thresholded wrt MAP): 	 15.57
SNR at lvl 0: 			28.14
SNR at lvl 1: 			23.04
SNR at lvl 2: 			21.21
SNR at lvl 3: 			20.38
SNR at lvl 4: 			23.46
