In [None]:
import os
import numpy as np
from functools import partial
import math
from tqdm import tqdm
import time as time

import torch
# Select the torch compute device.
# M1=True -> Apple-silicon MPS backend (CPU fallback).
# M1=False -> CUDA GPU 0 (CPU fallback), printing basic GPU info when a GPU is present.
M1 = False

if not M1:
    # Expose only the first GPU to PyTorch; this runs before any CUDA query below.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        current_idx = torch.cuda.current_device()
        gpu_info = (
            torch.cuda.is_available(),
            torch.cuda.device_count(),
            current_idx,
            torch.cuda.get_device_name(current_idx),
        )
        for info in gpu_info:
            print(info)
else:
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')


from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchmetrics.functional import structural_similarity_index_measure 
from torchmetrics.functional import peak_signal_noise_ratio 

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.ticker as tick

import scipy.io as sio
from astropy.io import fits
import skimage as ski

import large_scale_UQ as luq
from large_scale_UQ.utils import to_numpy, to_tensor
from convex_reg import utils as utils_cvx_reg



In [None]:
# Machine-specific locations of the repository, the saved UQ variables, and the figure output folders.
repo_dir = '/disk/xray0/tl3/repos/large-scale-UQ'
CRR_save_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/CRR/paper_figs_new_UQ/'
wav_save_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/wavelets/paper_figs_new_UQ/'
load_var_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/CRR/vars/'
wav_load_var_dir = '/disk/xray0/tl3/outputs/large-scale-UQ/def_UQ_results/wavelets/vars/'

cmap = 'cubehelix'

# Test images, and the lower log10 colour limit used when plotting each one (same order).
img_name_arr = ['CYN', 'M31', '3c288', 'W28']
vmin_log_arr = [-3., -2., -2., -2.]

# Saved-variable files, one per image in `img_name_arr` order.
# The file names encode the run hyper-parameters.
map_vars_path_arr = [
    f'{load_var_dir}{name}_CRR_UQ_MAP_lmbd_5.0e+04_MAP_vars.npy'
    for name in img_name_arr
]
samp_vars_path_arr = [
    f'{load_var_dir}{name}_SKROCK_CRR_lmbd_5.0e+04_mu_2.0e+01_nsamples_5.0e+04_thinning_1.0e+01_vars.npy'
    for name in img_name_arr
]
wav_map_vars_path_arr = [
    f'{wav_load_var_dir}{name}_wavelets_UQ_MAP_reg_param_5.0e+02_MAP_vars.npy'
    for name in img_name_arr
]
wav_samp_vars_path_arr = [
    f'{wav_load_var_dir}{name}_SKROCK_wavelets_reg_param_5.0e+02_nsamples_5.0e+04_thinning_1.0e+01_vars.npy'
    for name in img_name_arr
]



In [None]:
# Which model's saved results to plot: '-CRR' or '-WAV'. Also used as a tag in every output filename.
model_prefix = '-WAV'  # '-WAV' # '-CRR'

# Only superpixel sizes up to this value get UQ maps.
MAX_SUPERPIX_SIZE = 8
# Lower bound on the log10 colour limit of the |LCI - mean(LCI)| maps.
LCI_LOG_FLOOR = -4
# Clipping range applied to the LCI bounds before taking their difference
# (hard-coded rather than read from map_vars['LCI_params']).
CLIP_LOW_VAL, CLIP_HIGH_VAL = 0., 1.
# Seeded generator so the ground-truth log plot is reproducible across runs.
rng = np.random.default_rng(0)


def _log10_abs(x):
    """Return log10(|x|); exact zeros map to -inf without emitting a divide-by-zero warning."""
    with np.errstate(divide='ignore'):
        return np.log10(np.abs(x))


def _fig_path(save_dir, img_name, model_prefix, suffix):
    """Build the output path `<save_dir><img_name><model_prefix><suffix>`."""
    return '{:s}{:s}{:s}{:s}'.format(save_dir, img_name, model_prefix, suffix)


def _uq_fig_path(save_dir, img_name, model_prefix, pix_size, suffix):
    """Build the output path for a UQ figure tagged with its (integer) superpixel size."""
    return '{:s}{:s}{:s}{:s}{:d}{:s}'.format(
        save_dir, img_name, model_prefix, '-LCI_pixSize_', pix_size, suffix
    )


def _snr_label(snr):
    """Format the mean of `snr` (in dB) as the annotation text used on reconstruction plots."""
    return r'$\mathrm{SNR}=%.2f$ dB' % np.mean(snr)


def _save_image_plot(image, save_path, vmin=None, vmax=None, cbar_fmt='%.1f', textstr=None):
    """Render `image` with the notebook's standard layout, save it, display it, and close it.

    Args:
        image: 2-D array to show (the caller applies any log scaling).
        save_path: File path passed to `Figure.savefig`.
        vmin, vmax: Colour limits forwarded to `imshow`; None autoscales.
        cbar_fmt: printf-style format for the colour-bar tick labels.
        textstr: Optional annotation drawn in a rounded box in the upper-left corner.
    """
    fig, axs = plt.subplots(figsize=(5, 5), dpi=200)
    plt_im = axs.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    # Colour bar in its own thin axis so it matches the image height.
    divider = make_axes_locatable(axs)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(plt_im, cax=cax)
    cbar.ax.yaxis.set_major_formatter(tick.FormatStrFormatter(cbar_fmt))
    axs.set_yticks([])
    axs.set_xticks([])
    if textstr is not None:
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        axs.text(
            0.05, 0.95, textstr, transform=axs.transAxes,
            fontsize=12, verticalalignment='top', bbox=props
        )
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.show()
    # Release the figure explicitly; this loop creates dozens of figures.
    plt.close(fig)


model_results = {
    '-CRR': (map_vars_path_arr, samp_vars_path_arr, CRR_save_dir),
    '-WAV': (wav_map_vars_path_arr, wav_samp_vars_path_arr, wav_save_dir),
}
if model_prefix not in model_results:
    raise ValueError(
        f"model_prefix must be one of {sorted(model_results)}, got {model_prefix!r}"
    )
map_paths, samp_paths, save_dir = model_results[model_prefix]
assert len(img_name_arr) == len(map_paths) == len(samp_paths) == len(vmin_log_arr), \
    "per-image config lists must all have the same length"

for img_name, map_vars_path, samp_vars_path, vmin_log in zip(
    img_name_arr, map_paths, samp_paths, vmin_log_arr
):
    # Load variables. These are pickled dicts (allow_pickle=True), so only load trusted result files.
    map_vars = np.load(map_vars_path, allow_pickle=True)[()]
    samp_vars = np.load(samp_vars_path, allow_pickle=True)[()]

    # Point estimates
    x_gt = samp_vars['X_ground_truth']
    x_dirty = samp_vars['X_dirty']
    x_map = samp_vars['X_MAP']
    x_mmse = samp_vars['X_MMSE']
    post_meanvar = samp_vars['post_meanvar']
    x_sampling_var = post_meanvar.get_var().detach().cpu().squeeze().numpy()

    # UQ variables (sampling-based from samp_vars, superpixel LCI-based from map_vars)
    sampling_st_dev_arr = samp_vars['st_dev_down']
    sampling_quantiles_arr = samp_vars['quantiles']
    mean_img_arr = map_vars['mean_img_arr']
    error_p_arr = map_vars['error_p_arr']
    error_m_arr = map_vars['error_m_arr']
    superpix_sizes = map_vars['superpix_sizes']
    print(map_vars.keys())

    # Superpixel-averaged ground truth, one image per superpixel size.
    img, _mat_mask = luq.helpers.load_imgs(img_name, repo_dir)
    gt_mean_img_arr = [
        ski.measure.block_reduce(
            np.copy(img), block_size=(superpix_size, superpix_size), func=np.mean
        )
        for superpix_size in superpix_sizes
    ]

    # ---- Point-estimate figures ----
    # Zeros are replaced with tiny positive values (< 1e-7) so log10 stays finite.
    # These fall well below vmin_log and render at the bottom of the colour map.
    plot_x_gt = np.copy(x_gt)
    zero_mask = plot_x_gt == 0
    plot_x_gt[zero_mask] = rng.random(int(zero_mask.sum())) * 1e-7
    _save_image_plot(
        _log10_abs(plot_x_gt),
        _fig_path(save_dir, img_name, model_prefix, '-GroundTruth_image.pdf'),
        vmin=vmin_log, vmax=0,
    )
    _save_image_plot(
        x_gt,
        _fig_path(save_dir, img_name, model_prefix, '-GroundTruth_image_normalScale.pdf'),
        vmin=0, vmax=1,
    )

    dirty_snr = luq.utils.eval_snr(x_gt, x_dirty)
    _save_image_plot(
        _log10_abs(x_dirty),
        _fig_path(save_dir, img_name, model_prefix, '-dirty_image.pdf'),
        vmin=vmin_log, vmax=0, textstr=_snr_label(dirty_snr),
    )

    # Reconstruction errors use one extra decade below vmin_log.
    _save_image_plot(
        _log10_abs(x_gt - x_dirty),
        _fig_path(save_dir, img_name, model_prefix, '-dirty-recon_error_image.pdf'),
        vmin=vmin_log - 1, vmax=0,
    )
    _save_image_plot(
        _log10_abs(x_gt - x_map),
        _fig_path(save_dir, img_name, model_prefix, '-error_image.pdf'),
        vmin=vmin_log - 1, vmax=0,
    )

    map_snr = luq.utils.eval_snr(x_gt, x_map)
    _save_image_plot(
        _log10_abs(x_map),
        _fig_path(save_dir, img_name, model_prefix, '-MAP_image.pdf'),
        vmin=vmin_log, vmax=0, textstr=_snr_label(map_snr),
    )

    mmse_snr = luq.utils.eval_snr(x_gt, x_mmse)
    _save_image_plot(
        _log10_abs(x_mmse),
        _fig_path(save_dir, img_name, model_prefix, '-MMSE_image.pdf'),
        vmin=vmin_log, vmax=0, textstr=_snr_label(mmse_snr),
    )

    _save_image_plot(
        _log10_abs(np.sqrt(x_sampling_var)),
        _fig_path(save_dir, img_name, model_prefix, '-sampling_StDev_image.pdf'),
    )

    # ---- Per-superpixel-size UQ figures ----
    # Compute the clipped LCI length once per selected superpixel size.
    uq_entries = []
    for samp_st_dev_im, sampling_quantiles, mean_im, gt_mean_im, error_p, error_m, pix_size in zip(
        sampling_st_dev_arr, sampling_quantiles_arr, mean_img_arr, gt_mean_img_arr,
        error_p_arr, error_m_arr, superpix_sizes,
    ):
        if pix_size <= MAX_SUPERPIX_SIZE:
            LCI_length = (
                luq.utils.clip_matrix(np.copy(error_p), CLIP_LOW_VAL, CLIP_HIGH_VAL)
                - luq.utils.clip_matrix(np.copy(error_m), CLIP_LOW_VAL, CLIP_HIGH_VAL)
            )
            uq_entries.append(
                (samp_st_dev_im, sampling_quantiles, mean_im, gt_mean_im, pix_size, LCI_length)
            )

    if not uq_entries:
        print(f'No superpixel sizes <= {MAX_SUPERPIX_SIZE} for {img_name}; skipping UQ maps.')
        continue

    # Shared colour limits across all selected superpixel sizes.
    lci_dev_logs = [_log10_abs(lci - np.mean(lci)) for *_, lci in uq_entries]
    st_dev_logs = [_log10_abs(st_dev) for st_dev, *_ in uq_entries]
    lci_log_min = np.min([np.min(a) for a in lci_dev_logs])
    lci_log_max = np.max([np.max(a) for a in lci_dev_logs])
    st_dev_log_min = np.min([np.min(a) for a in st_dev_logs])
    st_dev_log_max = np.max([np.max(a) for a in st_dev_logs])
    vmin_LCI = lci_log_min if lci_log_min > LCI_LOG_FLOOR else LCI_LOG_FLOOR

    for entry, lci_dev_log, st_dev_log in zip(uq_entries, lci_dev_logs, st_dev_logs):
        samp_st_dev_im, sampling_quantiles, mean_im, gt_mean_im, pix_size, LCI_length = entry
        sampling_LCI = sampling_quantiles[1, :, :] - sampling_quantiles[0, :, :]

        _save_image_plot(
            _log10_abs(mean_im),
            _uq_fig_path(save_dir, img_name, model_prefix, pix_size, '_meanMAP.pdf'),
            cbar_fmt='%.2f',
        )
        _save_image_plot(
            lci_dev_log,
            _uq_fig_path(save_dir, img_name, model_prefix, pix_size, '_LCI-meanLCI.pdf'),
            vmin=vmin_LCI, vmax=lci_log_max, cbar_fmt='%.2f',
            textstr=r'$\mathrm{mean}=%.4f$' % (np.mean(LCI_length)),
        )
        _save_image_plot(
            _log10_abs(gt_mean_im - mean_im),
            _uq_fig_path(save_dir, img_name, model_prefix, pix_size, '_error.pdf'),
            cbar_fmt='%.2f',
        )
        _save_image_plot(
            _log10_abs(sampling_LCI),
            _uq_fig_path(save_dir, img_name, model_prefix, pix_size, '_sampling_LCI.pdf'),
            cbar_fmt='%.2f',
        )
        _save_image_plot(
            st_dev_log,
            _uq_fig_path(save_dir, img_name, model_prefix, pix_size, '_sampling_StDev.pdf'),
            vmin=st_dev_log_min, vmax=st_dev_log_max, cbar_fmt='%.2f',
        )
