In [2]:
import sys
import os
import time
from functools import partial
import numpy as np
from astropy.io import fits
import scipy.io as sio

import torch
# Set M1 = True on Apple-silicon machines to run on the MPS backend; otherwise
# restrict CUDA to GPU index 2 and fall back to the CPU when CUDA is missing.
M1 = False

if not M1:
    # CUDA reads this variable when it is initialised, so set it before any CUDA call.
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    cuda_ok = torch.cuda.is_available()
    device = torch.device('cuda' if cuda_ok else 'cpu')
    if cuda_ok:
        # Report availability, device count, active index and its name.
        print(cuda_ok)
        print(torch.cuda.device_count())
        active_idx = torch.cuda.current_device()
        print(active_idx)
        print(torch.cuda.get_device_name(active_idx))
else:
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')


import large_scale_UQ as luq

from tqdm import tqdm

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from large_scale_UQ.utils import to_numpy, to_tensor

from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


True
1
0
NVIDIA A100-PCIE-40GB


In [3]:
# Optimisation options: tolerance, iteration budget, update interval, iterate recording.
options = dict(tol=1e-5, iter=5000, update_iter=50, record_iters=False)


In [6]:

# Repository root. Defaults to the original checkout; set the
# LARGE_SCALE_UQ_REPO environment variable to run from another location.
repo_dir = os.environ.get('LARGE_SCALE_UQ_REPO', '/disk/xray0/tl3/repos/large-scale-UQ')
save_dir = repo_dir + '/notebooks/SAPG/output/'
savefig_dir = repo_dir + '/notebooks/SAPG/figs/'

# Optimization settings
wavs = ["db8"]     # wavelet dictionaries to combine (alternative: ["db1", "db4"])
levels = 4         # wavelet levels to consider [1-6]
reg_param = 2.e-3
img_name = 'M31'

# Output file stem. The regularisation parameter is written in scientific
# notation: a fixed '{:.1f}' renders every reg_param below 0.05 as '0.0',
# which made runs with different reg_param overwrite each other's outputs.
save_name = '{:s}_256_wavelet_SAPG-{:s}_{:d}_reg_{:.1e}'.format(
    img_name, wavs[0], levels, reg_param
)


# Load the ground-truth image. The context manager closes the FITS file once
# the data (read into memory because memmap=False) has been copied out.
img_path = repo_dir + '/data/imgs/{:s}.fits'.format(img_name)
with fits.open(img_path, memmap=False) as img_data:
    # First plane of the primary HDU, cast to float64
    img = np.copy(img_data[0].data)[0, :, :].astype(np.float64)
# Flipping data
img = np.flipud(img)

# Aliases
x = img
ground_truth = img

# Load op from X Cai
op_mask = sio.loadmat(repo_dir + '/data/operators_masks/fourier_mask.mat')['Ma']

# Matlab's reshape works with 'F'-like ordering. Collapse axis 0 and reshape
# onto the image grid (rather than a hard-coded 256x256) so the mask always
# matches the image; any non-zero entry becomes True.
mat_mask = np.reshape(np.sum(op_mask, axis=0), img.shape, order='F').astype(bool)

# Define my torch types
myType = torch.float64
myComplexType = torch.complex128

torch_img = torch.tensor(np.copy(img), dtype=myType, device=device).reshape((1, 1) + img.shape)
dim = img.shape[0]

# A mock radio imaging forward model with half of the Fourier coefficients masked
# Use X. Cai's Fourier mask
phi = luq.operators.MaskedFourier_torch(
    dim=dim,
    ratio=0.5,
    mask=mat_mask,
    norm='ortho',
    device=device
)

# Define X Cai noise level
sigma = 0.0024
sigma2 = sigma**2

sigma_GT = np.copy(sigma)
sigma2_GT = np.copy(sigma2)

y = phi.dir_op(torch_img).detach().cpu().squeeze().numpy()

# Generate noise for the non-zero entries of y only (selection computed once)
rng = np.random.default_rng(seed=0)
sampled = y != 0
n = rng.normal(0, sigma, y[sampled].shape)
# Add noise
# NOTE(review): `n` is real-valued, so if `y` is complex only its real part is
# perturbed — confirm this matches the intended (X. Cai) noise model.
y[sampled] += n

# Observation
torch_y = torch.tensor(np.copy(y), device=device, dtype=myComplexType).reshape((1,) + img.shape)
x_init = torch.abs(phi.adj_op(torch_y))



INSTRUME                                                                         [astropy.io.fits.card]


In [10]:
# Display the noise variance sigma**2 used to simulate the observation.
sigma2

5.759999999999999e-06

In [12]:

# Likelihood: l2 data-fidelity term on the observation `torch_y` through `phi`,
# with noise level `sigma`.
g = luq.operators.L2Norm_torch(sigma=sigma, data=torch_y, Phi=phi)

# Real prox from luq.
# NOTE(review): `f` is rebound to the negative log-likelihood in the next
# likelihood cell, so this prox object is shadowed from then on.
f = luq.operators.RealProx_torch()

# Wavelet dictionary and the l1 norm acting on its coefficients.
psi = luq.operators.DictionaryWv_torch(wavs, levels)
h = luq.operators.L1Norm_torch(1., psi, op_to_coeffs=True)

# Regularisation strength: reg_param times the max-abs wavelet coefficient
# (as reported by `_get_max_abs_coeffs`) of the back-projected data `x_init`.
x_init_coeffs = h.dir_op(torch.clone(x_init))
gamma = h._get_max_abs_coeffs(x_init_coeffs) * reg_param
h.gamma = gamma
h.beta = 1.0


In [13]:

# Bounds [min_sigma2, max_sigma2] and starting value for the noise variance,
# stored as tensors of the working real dtype on the compute device.
_to_real_tensor = partial(torch.tensor, device=device, dtype=myType)
min_sigma2 = _to_real_tensor(1e-10)
max_sigma2 = _to_real_tensor(1e1)
sigma2_init = _to_real_tensor(1e-2)



In [14]:
def f(_x, sigma2):
    """Negative log-likelihood -log p(y | x, sigma^2), evaluated by `g.fun`."""
    return g.fun(_x, sigma2=sigma2)


# --- Gradient w.r.t. sigma^2
dimx = x.size  # number of image pixels
alpha_homogenious = 1


def df_wrt_sigma(_x, sigma2):
    """Gradient w.r.t. sigma^2: `g.grad_sigma2` minus dimx / (alpha_homogenious * sigma2).

    The second term corresponds to the normalisation constant Z of the posterior.
    NOTE(review): it scales with the number of image pixels ``dimx``; for a
    Gaussian likelihood the sigma^2-dependence of the normaliser usually scales
    with the number of measurements instead — confirm against ``g.grad_sigma2``.
    """
    return g.grad_sigma2(_x, sigma2=sigma2) - dimx / (alpha_homogenious * sigma2)


# --- Gradient w.r.t. x
def df_wrt_x(_x, sigma2):
    """Gradient of the negative log-likelihood w.r.t. x, via `g.grad`."""
    return g.grad(_x, sigma2=sigma2)


In [15]:
# Define prior
def fun_prior(_x):
    """Value of the l1 prior `h` on the dictionary coefficients of ``_x``."""
    return h._fun_coeffs(h.dir_op(torch.clone(_x)))


def sub_op(_x1, _x2):
    """Difference of two coefficient sets, used to combine them in prox_prior_cai."""
    return _x1 - _x2


# Proximity operators
def prox_prior_cai(_x, lmbd):
    """Prior prox in X. Cai's form: x + Psi^*(prox(Psi x) - Psi x).

    Every operator call receives a fresh clone, so ``_x`` itself is never passed on.
    """
    x_copy = torch.clone(_x)
    prox_coeffs = h.prox(h.dir_op(torch.clone(_x)), lmbd)
    coeffs = h.dir_op(torch.clone(_x))
    return x_copy + h.adj_op(h._op_to_two_coeffs(prox_coeffs, coeffs, sub_op))


def prox_prior(_x, lmbd):
    """Prior prox computed as Psi^*(prox(Psi x))."""
    return h.adj_op(h.prox(h.dir_op(torch.clone(_x)), lmbd))


# Gradient of the prior
def gradg(x, lam, lambda_prox):
    """Smoothed-prior gradient (x - prox_prior_cai(x, lam)) / lambda_prox."""
    return (x - prox_prior_cai(x, lam)) / lambda_prox


In [None]:
# Log of posterior distribution
def logPi(_x, sigma2, theta):
    """Return -f(_x, sigma2) - theta * fun_prior(_x)."""
    return -f(_x, sigma2) - theta * fun_prior(_x)


In [None]:

## Lipschitz Constants

def max_eigenval(dir_op, adj_op, x_shape, tol=1e-4, max_iter=int(1e4), verbose=0, seed=0):
    """Estimate the largest eigenvalue of x -> Re(adj_op(dir_op(x))) by power iteration.

    Args:
        dir_op: forward operator A (applied to real tensors of shape ``x_shape``).
        adj_op: adjoint operator A^H.
        x_shape: shape of the real-valued input of ``dir_op``.
        tol: stop once the relative change of the estimate drops below this.
        max_iter: maximum number of power iterations.
        verbose: if non-zero, print the estimate at every iteration.
        seed: seed of the CPU generator used for the random starting vector.

    Returns:
        The eigenvalue estimate as a Python float (0.0 if the operator maps the
        iterate to zero).
    """
    gen = torch.Generator().manual_seed(seed)
    x_k = torch.randn(tuple(x_shape), generator=gen, dtype=myType).to(device)
    x_k = x_k / torch.linalg.vector_norm(x_k)
    eig_est = 0.0
    for it in range(int(max_iter)):
        y_k = adj_op(dir_op(x_k))
        # Keep the iterate real so it stays a valid input for dir_op.
        if torch.is_complex(y_k):
            y_k = y_k.real
        # The adjoint may return a different (size-compatible) layout.
        y_k = y_k.reshape(x_k.shape)
        new_est = torch.linalg.vector_norm(y_k).item()
        if new_est == 0.0:
            return 0.0
        x_k = y_k / new_est
        rel_change = abs(new_est - eig_est) / new_est
        eig_est = new_est
        if verbose:
            print('iter {:d}: eigenvalue estimate {:.6e}'.format(it, eig_est))
        if rel_change < tol:
            break
    return eig_est


# --- Maximum eigenvalue of the normal operator of the measurement operator phi.
AAt_norm = max_eigenval(phi.dir_op, phi.adj_op, torch_img.shape, 1e-4, int(1e4), 0)

# Lipschitz constant of f as a function of sigma^2.
Lp_fun = lambda sigma2: AAt_norm**2 / sigma2
# NOTE(review): Lp_fun decreases with sigma2, so `min` picks the value at
# max_sigma2 (the smallest constant on the interval), not a worst-case bound —
# confirm this is intended.
L_f = min(Lp_fun(min_sigma2), Lp_fun(max_sigma2))

# --- regularization parameter of proximity operator (\lambda).
lambdaMax = 2
lambda_prox = min((5 / L_f), lambdaMax)

# --- Lipschitz constant of g (gradg divides by lambda_prox).
L_g = 1 / lambda_prox

# --- Lipschitz constant of g + f
L = L_f + L_g

# --- Stepsize of MCMC algorithm.
# NOTE: this rebinds `gamma`, which the wavelet cell used for the l1 threshold
# (that value is already stored in h.gamma).
gamma = 0.98 * 1 / L
# --- end


In [16]:

# Ask the likelihood operator `g` to compute its Lipschitz constant.
# NOTE(review): this is a private luq helper; what it returns or sets on `g`
# is not visible here — confirm before relying on the displayed value.
g._compute_lip_constant()
