In [1]:
import os
import numpy as np
from functools import partial
import math
from tqdm import tqdm
import time as time

import torch
# Set M1 = True on Apple-silicon machines to use the Metal (MPS) backend;
# otherwise GPU 0 is used through CUDA, falling back to the CPU.
M1 = False

if not M1:
    # Restrict this process to the first GPU.
    # NOTE(review): only effective if no CUDA call has happened earlier in the kernel.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if has_cuda else 'cpu')
    if has_cuda:
        # Availability flag, number of visible GPUs, active index, GPU model name.
        print(has_cuda)
        print(torch.cuda.device_count())
        active_gpu = torch.cuda.current_device()
        print(active_gpu)
        print(torch.cuda.get_device_name(active_gpu))
else:
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.patches as patches
import matplotlib.ticker as tick

import skimage as ski
import scipy as sp
import large_scale_UQ as luq
from large_scale_UQ.utils import to_numpy, to_tensor
from convex_reg import utils as utils_cvx_reg


True
1
0
NVIDIA A100-PCIE-40GB
Using device: cuda


In [2]:
# ---- Paths -------------------------------------------------------------------
repo_dir = '/disk/xray0/tl3/repos/large-scale-UQ'
save_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/v2/wavelets/hypothesis_test_paper_figs/'
load_var_dir = '/disk/xray99/tl3/proj-convex-UQ/outputs/new_UQ_results/wavelets/vars/'

# ---- Uncertainty-quantification settings -------------------------------------
alpha_prob = 0.01   # confidence value
G_sigma = 3.5       # standard deviation of the Gaussian blurring kernel

# Optimisation settings for the inpainting step.
inptaint_options = dict(tol=5e-6, iter=15000, update_iter=4999, record_iters=False)

# ---- Result accumulators (appended to by the experiment cells) --------------
map_potential_list = []
likelihood_map_potential_list = []
prior_map_potential_list = []
surrogate_potential_list = []
likelihood_surrogate_potential_list = []
prior_surrogate_potential_list = []
gamma_alpha_list = []
Hnot_reject_list = []
potential_blurring_list = []
Hnot_reject_blurring_list = []
SNR_list = []
PSNR_list = []

# ---- Plotting / reporting ----------------------------------------------------
cmap = 'cubehelix'
model_prefix = '-CRR'
input_snr = 30.
cbar_font_size = box_font_size = 18

# ---- CRR model: checkpoint selection and prior parameters --------------------
CRR_dir_name = '/disk/xray0/tl3/repos/convex_ridge_regularizers/trained_models/'
sigma_training = 5
t_model = 5
reg_params = [5e4]
mu = 20

# ---- Wavelet parameters ------------------------------------------------------
wavs_list = ['db8']
levels = 4

# ---- Per-test-case settings (index-aligned lists, one entry per case) --------
# '3c288' appears twice; the two entries differ only in their plotting
# annotations (text label, text position, pysiscal flag).
img_name_list = ['CYN', 'M31', '3c288', '3c288', 'W28']

_map_vars_suffix = '_wavelets_UQ_MAP_reg_param_5.0e+02_MAP_vars.npy'
_samp_vars_suffix = '_SKROCK_wavelets_reg_param_5.0e+02_nsamples_5.0e+04_thinning_1.0e+01_vars.npy'
map_vars_path_arr = [load_var_dir + name + _map_vars_suffix for name in img_name_list]
samp_vars_path_arr = [load_var_dir + name + _samp_vars_suffix for name in img_name_list]

pysiscal_list = [True, True, True, False, True]
vmin_log_arr = [-3., -2., -2., -2., -2.]
saving_text_str_arr = ['1', '1', '1', '2', '1']
# LaTeX versions of the labels above, e.g. '1' -> '$1$'.
text_str_arr = ['$' + label + '$' for label in saving_text_str_arr]
text_pos_arr = [[0, 0.12], [0, 0.06], [0, 0.06], [-0.05, -0.01], [0, 0.06]]


# %%
# Inference only: turn autograd off globally and cap the number of CPU threads.
torch.set_grad_enabled(False)
torch.set_num_threads(4)

# Checkpoint sub-directory is keyed by the training noise level and t.
exp_name = 'Sigma_{}_t_{}/'.format(sigma_training, t_model)
# NOTE(review): the device is hard-coded to 'cuda:0' / device_type='gpu',
# independent of the `device` selected in the first cell -- confirm before
# running on a CPU or MPS machine.
model = utils_cvx_reg.load_model(CRR_dir_name + exp_name, 'cuda:0', device_type='gpu')

# Report the parameter count around pruning.
print('Numbers of parameters before prunning: {}'.format(model.num_params))
model.prune()
print('Numbers of parameters after prunning: {}'.format(model.num_params))

# Refine the model's Lipschitz bound with a power iteration that couples the
# activations and the convolutions. initializeEigen only seeds the eigenvector
# (size 100x100) and is optional.
model.initializeEigen(size=100)
model.precise_lipschitz_bound(n_iter=100)
L_CRR = model.L.data.item()  # the bound is stored on the model
print('Lipschitz bound {:.3f}'.format(L_CRR))

# Toggle for writing result variables to disk.
save_results = False



--- loading checkpoint from epoch 10 ---
---------------------
Building a CRR-NN model with 
 - [1, 8, 32] channels 
 - linear_spline activation functions
  (LinearSpline(mode=conv, num_activations=32, init=zero, size=21, grid=0.010, monotonic_constraint=True.))
---------------------
Numbers of parameters before prunning: 13610
---------------------
 PRUNNING 
 Found 22 filters with non-vanishing potential functions
---------------------
Numbers of parameters after prunning: 4183
Lipschitz bound 0.770


In [3]:
# Sweep grid for the regularisation strength: 30 values, logarithmically
# spaced between 1e2 and 1e6, both endpoints included.
reg_param_list = np.logspace(
    start=np.log10(100),
    stop=np.log10(1e6),
    num=30,
    endpoint=True,
    base=10.0,
)


In [4]:
# Display the sweep grid (30 log-spaced values from 1e2 to 1e6, built in the previous cell).
reg_param_list

array([1.00000000e+02, 1.37382380e+02, 1.88739182e+02, 2.59294380e+02,
       3.56224789e+02, 4.89390092e+02, 6.72335754e+02, 9.23670857e+02,
       1.26896100e+03, 1.74332882e+03, 2.39502662e+03, 3.29034456e+03,
       4.52035366e+03, 6.21016942e+03, 8.53167852e+03, 1.17210230e+04,
       1.61026203e+04, 2.21221629e+04, 3.03919538e+04, 4.17531894e+04,
       5.73615251e+04, 7.88046282e+04, 1.08263673e+05, 1.48735211e+05,
       2.04335972e+05, 2.80721620e+05, 3.85662042e+05, 5.29831691e+05,
       7.27895384e+05, 1.00000000e+06])

In [5]:

# Sweep the CRR regularisation strength on one test image (M31). For each value
# the MAP estimate is computed with an accelerated gradient method, its quality
# (PSNR/SNR/SSIM) is reported, and the HPD-region threshold gamma_alpha is
# computed together with the decomposed MAP potentials.
#
# Everything that does not depend on reg_param (image, operator, noisy
# observation, likelihood) is built once, before the loop.

# Reset this cell's result lists so that re-running the cell does not append
# duplicate entries to lists left over from a previous run.
map_potential_list = []
likelihood_map_potential_list = []
prior_map_potential_list = []
gamma_alpha_list = []
SNR_list = []
PSNR_list = []

# ---- Test case ------------------------------------------------------------------
it_img = 1  # M31
img_name = img_name_list[it_img]
map_vars_path = map_vars_path_arr[it_img]
samp_vars_path = samp_vars_path_arr[it_img]
vmin_log = vmin_log_arr[it_img]
text_pos = text_pos_arr[it_img]
textstr = text_str_arr[it_img]
saving_text_str = saving_text_str_arr[it_img]

# Previously saved wavelet MAP / SKROCK variables (not used in this cell).
# NOTE(review): allow_pickle=True can execute code from the file; only load trusted files.
map_vars = np.load(map_vars_path, allow_pickle=True)[()]
samp_vars = np.load(samp_vars_path, allow_pickle=True)[()]

# ---- Observation (independent of reg_param) -------------------------------------
img, mat_mask = luq.helpers.load_imgs(img_name, repo_dir)

# CRR requires torch.float32
myType = torch.float32
myComplexType = torch.complex64

# Aliases
x = img
ground_truth = img

torch_img = torch.tensor(
    np.copy(img), dtype=myType, device=device
).reshape((1, 1) + img.shape)
phi = luq.operators.MaskedFourier_torch(
    shape=img.shape,
    ratio=0.5,
    mask=mat_mask,
    norm='ortho',
    device=device,
)
y = phi.dir_op(torch_img).detach().cpu().squeeze().numpy()

# Noise level giving the requested input SNR; sigma is the total complex std.
eff_sigma = luq.helpers.compute_complex_sigma_noise(y, input_snr)
sigma = eff_sigma * np.sqrt(2)

# Seeded complex Gaussian noise, added only on the non-zero (measured) entries.
rng = np.random.default_rng(seed=0)
n_re = rng.normal(0, eff_sigma, y[y != 0].shape)
n_im = rng.normal(0, eff_sigma, y[y != 0].shape)
y[y != 0] += (n_re + 1.j * n_im)

torch_y = torch.tensor(np.copy(y), device=device, dtype=myComplexType).reshape((1,) + img.shape)
x_init = torch.abs(phi.adj_op(torch_y))

# Likelihood; its Lipschitz constant is computed by the operator and stored in g.beta.
g = luq.operators.L2Norm_torch(
    sigma=sigma,
    data=torch_y,
    Phi=phi,
)
# Proximal operator enforcing a real-valued image.
f = luq.operators.RealProx_torch()

# Optimisation options for the MAP estimation.
options = {"tol": 1e-5, "iter": 15000, "update_iter": 4999, "record_iters": False}

# Step size for a likelihood-only gradient step (not used in this cell).
alpha = 0.98 / g.beta

np_x_init = to_numpy(x_init)
np_x = np.copy(x)
x_gt = np.copy(x)
x_dirty = np.copy(np_x_init)
data_range = np_x.max() - np_x.min()

# Concentration constant of the HPD-region bound; depends only on alpha_prob.
tau_alpha = np.sqrt(16 * np.log(3 / alpha_prob))


def _crr_map_fista(x_init, g, f, model, mu, lmbd, L_CRR, options):
    """Accelerated (FISTA-style) gradient descent for the CRR MAP estimate.

    Each iteration takes a gradient step on ``g.grad(z) + lmbd * model(mu * z)``,
    applies ``f.prox`` and extrapolates with the usual momentum sequence ``t``.

    Args:
        x_init: starting image tensor (cloned, not modified).
        g: likelihood exposing ``grad`` and its Lipschitz constant ``beta``.
        f: operator exposing ``prox`` applied after every gradient step.
        model: CRR network; ``model(mu * z)`` is the regulariser gradient term.
        mu, lmbd: CRR scaling and regularisation strength.
        L_CRR: Lipschitz bound of ``model``, used in the step size.
        options: dict with ``'tol'``, ``'iter'`` and ``'update_iter'``.

    Returns:
        The last iterate ``x_hat``.
    """
    step = 0.98 / (g.beta + mu * lmbd * L_CRR)
    x_hat = torch.clone(x_init)
    z = torch.clone(x_init)
    t = 1
    res = float('inf')

    for it_2 in range(options['iter']):
        x_hat_old = torch.clone(x_hat)
        x_hat = f.prox(z - step * (g.grad(z) + lmbd * model(mu * z)))

        t_old = t
        t = 0.5 * (1 + math.sqrt(1 + 4 * t ** 2))
        z = x_hat + (t_old - 1) / t * (x_hat - x_hat_old)

        # Relative change between consecutive iterates, used as stopping rule.
        res = (torch.norm(x_hat_old - x_hat) / torch.norm(x_hat_old)).item()

        if res < options['tol']:
            print("[GD] converged in %d iterations" % (it_2))
            break

        if it_2 % options['update_iter'] == 0:
            print(
                "[GD] %d out of %d iterations, tol = %f" % (
                    it_2,
                    options['iter'],
                    res,
                )
            )
    else:
        # Make an unconverged run visible instead of silently returning.
        print("[GD] stopped after %d iterations without converging, tol = %f" % (options['iter'], res))

    return x_hat


# Potentials used for the hypothesis test / ULA.
# _fun reads the module-level likelihood `g` defined above.
def _fun(_x, model, mu, lmbd):
    return (lmbd / mu) * model.cost(mu * _x) + g.fun(_x)


def _grad_fun(_x, g, model, mu, lmbd):
    return torch.real(g.grad(_x) + lmbd * model(mu * _x))


def _prior_fun(_x, model, mu, lmbd):
    return (lmbd / mu) * model.cost(mu * _x)


for reg_param in reg_param_list:
    lmbd = reg_param

    # ---- MAP estimate ---------------------------------------------------------
    x_hat = _crr_map_fista(x_init, g, f, model, mu, lmbd, L_CRR, options)
    np_x_hat = to_numpy(x_hat)

    map_snr = luq.utils.eval_snr(np_x, np_x_hat)
    map_psnr = psnr(np_x, np_x_hat, data_range=data_range)
    map_ssim = ssim(np_x, np_x_hat, data_range=data_range)

    print('\nMAP reg_param: ', reg_param)
    print('Image: ', img_name)
    print('PSNR: {},\n SNR: {}, SSIM: {}'.format(
        round(map_psnr, 2),
        round(map_snr, 2),
        round(map_ssim, 2),
    ))

    SNR_list.append(map_snr)
    PSNR_list.append(map_psnr)

    x_map = np.copy(np_x_hat)
    x_map_torch = to_tensor(x_map)

    # ---- Potentials for the current reg_param ----------------------------------
    fun = partial(_fun, model=model, mu=mu, lmbd=lmbd)
    prior_fun = partial(_prior_fun, model=model, mu=mu, lmbd=lmbd)
    grad_f = partial(_grad_fun, g=g, model=model, mu=mu, lmbd=lmbd)
    # Potential evaluated on numpy input.
    fun_np = lambda _x: fun(luq.utils.to_tensor(_x, dtype=myType)).item()

    # ---- HPD region bound ------------------------------------------------------
    N = x_map.size
    map_potential = fun(x_map_torch).item()
    gamma_alpha = map_potential + tau_alpha * np.sqrt(N) + N

    print('gamma_alpha: ', gamma_alpha)
    print('fun(x_map).item(): ', map_potential)
    print('tau_alpha*np.sqrt(N) + N: ', tau_alpha * np.sqrt(N) + N)

    # Decompose the potential into likelihood and prior terms.
    map_likelihood_potential = g.fun(x_map_torch).item()
    map_prior_potential = prior_fun(x_map_torch).item()

    print(img_name, '_gamma_alpha: ', gamma_alpha)
    print(img_name, '-MAP_potential: ', map_potential)

    map_potential_list.append(map_potential)
    gamma_alpha_list.append(gamma_alpha)
    likelihood_map_potential_list.append(map_likelihood_potential)
    prior_map_potential_list.append(map_prior_potential)


save_dict = {
    'alpha_prob': alpha_prob,
    'reg_param_list': reg_param_list,
    'SNR_list': SNR_list,
    'PSNR_list': PSNR_list,
    'map_potential_list': map_potential_list,
    'likelihood_map_potential_list': likelihood_map_potential_list,
    'prior_map_potential_list': prior_map_potential_list,
    'gamma_alpha_list': gamma_alpha_list,
    'optim_options': options,
}

# Save variables only when requested by the save_results flag.
if save_results:
    save_path = '{:s}{:s}{:s}{:s}'.format(
        save_dir, 'test_reg_stregth', model_prefix, '_vars.npy'
    )
    try:
        os.makedirs(save_dir, exist_ok=True)
        # np.save overwrites an existing file at save_path.
        np.save(save_path, save_dict, allow_pickle=True)
    except OSError as e:
        print('Could not save variables. Exception caught: ', e)




INSTRUME                                                                         [astropy.io.fits.card]


[GD] 0 out of 15000 iterations, tol = 0.386035
[GD] 4999 out of 15000 iterations, tol = 0.000012
[GD] converged in 5137 iterations

MAP reg_param:  100.0
Image:  M31
PSNR: 48.97,
 SNR: 26.53, SSIM: 0.98
-----------------------
Updating spline coefficients for the reg cost
 (the gradient-step model is trained and intergration is required to compute the regularization cost)
-----------------------
gamma_alpha:  68077.71238599616
fun(x_map).item():  96.1348648071289
tau_alpha*np.sqrt(N) + N:  67981.57752118903
M31 _gamma_alpha:  68077.71238599616
M31 -MAP_potential:  96.1348648071289
[GD] 0 out of 15000 iterations, tol = 0.379874
[GD] converged in 4443 iterations

MAP reg_param:  137.38237958832624
Image:  M31
PSNR: 48.99,
 SNR: 26.55, SSIM: 0.98
gamma_alpha:  68113.43843732672
fun(x_map).item():  131.8609161376953
tau_alpha*np.sqrt(N) + N:  67981.57752118903
M31 _gamma_alpha:  68113.43843732672
M31 -MAP_potential:  131.8609161376953
[GD] 0 out of 15000 iterations, tol = 0.371724
[GD] con