In [1]:
import os

# Hardware selection. CUDA_VISIBLE_DEVICES only takes effect if it is set
# before CUDA is initialised, so it is pinned *before* torch is imported
# instead of relying on torch initialising CUDA lazily.
M1 = False     # True: use Apple's MPS backend instead of CUDA
GPU_ID = "2"   # CUDA device index exposed to this process when M1 is False
if not M1:
    os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

import math
import time
from functools import partial

import numpy as np
import torch

if M1:
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        # Report which GPU the process actually sees
        print(torch.cuda.is_available())
        print(torch.cuda.device_count())
        print(torch.cuda.current_device())
        print(torch.cuda.get_device_name(torch.cuda.current_device()))

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import large_scale_UQ as luq
from large_scale_UQ.utils import to_numpy, to_tensor
from convex_reg import utils as utils_cvx_reg



True
1
0
NVIDIA A100-PCIE-40GB
Using device: cuda


In [2]:
# Test image to reconstruct ('M31' is the alternative)
img_name = 'W28'

# Input noise level: target SNR used to simulate the noisy measurements
input_snr = 30.0

# Benchmark wavelet-based model optimisation


In [12]:
# Optimisation options for the MAP estimation
options = {"tol": 1e-5, "iter": 15000, "update_iter": 4999, "record_iters": False}

# Repository root passed to the image loader. Set the LARGE_SCALE_UQ_REPO_DIR
# environment variable to use another checkout instead of editing this path.
repo_dir = os.environ.get('LARGE_SCALE_UQ_REPO_DIR', '/disk/xray0/tl3/repos/large-scale-UQ')

# Define my torch types (CRR requires torch.float32, wavelets require torch.float64)
myType = torch.float64
myComplexType = torch.complex128

# Wavelet parameters
reg_params = [5e2]  # other candidates: [5e2, 5e1, 1e3, 5e3, 1e4, 5e4]
wavs_list = ['db1', 'db2', 'db3', 'db4', 'db5', 'db6', 'db7', 'db8']
levels = 4



In [13]:
# %%
# Load the ground-truth image and the sampling mask
img, mat_mask = luq.helpers.load_imgs(img_name, repo_dir)

# Aliases used by later cells
x = img
ground_truth = img

# Ground truth as a (1, 1) + img.shape tensor in the configured real dtype
torch_img = torch.tensor(np.copy(img), dtype=myType, device=device)
torch_img = torch_img.reshape((1, 1) + img.shape)

# Masked Fourier measurement operator
phi = luq.operators.MaskedFourier_torch(
    shape=img.shape,
    ratio=0.5,
    mask=mat_mask,
    norm='ortho',
    device=device,
)

# Noiseless measurements, brought back to a NumPy array
y = phi.dir_op(torch_img).detach().cpu().squeeze().numpy()

# Noise std matching the requested input SNR; sigma is the total complex std
eff_sigma = luq.helpers.compute_complex_sigma_noise(y, input_snr)
sigma = eff_sigma * np.sqrt(2)

# Complex Gaussian noise (std eff_sigma per real/imag part), added only to the
# non-zero measurements. Real parts are drawn before imaginary parts.
rng = np.random.default_rng(seed=0)
sampled = y != 0
n_draws = y[sampled].shape
n_re = rng.normal(0, eff_sigma, n_draws)
n_im = rng.normal(0, eff_sigma, n_draws)
y[sampled] += n_re + 1j * n_im

# Noisy observation tensor and the magnitude of its adjoint as the initial image
torch_y = torch.tensor(np.copy(y), device=device, dtype=myComplexType).reshape((1,) + img.shape)
x_init = torch.abs(phi.adj_op(torch_y))


In [14]:
# %%
# Likelihood (data-fidelity term) built on the measurement operator phi.
# Its Lipschitz constant is computed automatically by g and stored in g.beta.
g = luq.operators.L2Norm_torch(
    sigma=sigma,
    data=torch_y,
    Phi=phi,
)

# Real prox (presumably restricts the estimate to real values — confirm in luq)
f = luq.operators.RealProx_torch()

# %%
# Prior parameters: read the regularisation strength from the config cell's
# reg_params so the value is defined in exactly one place.
reg_param = reg_params[0]

# l1 norm on the coefficients of the wavelet dictionary psi
psi = luq.operators.DictionaryWv_torch(wavs_list, levels)
h = luq.operators.L1Norm_torch(1., psi, op_to_coeffs=True)
h.gamma = reg_param

# Step size: 0.98 / g.beta, i.e. just below the inverse Lipschitz constant of g
alpha = 0.98 / g.beta

# Effective threshold
print('Threshold: ', h.gamma * alpha)




Threshold:  0.0025999027914999657


In [6]:
# Run forward-backward for the MAP estimate. The solver gets a copy of x_init,
# so it never holds a reference to the initial image itself.
fb_kwargs = dict(
    options=options,
    g=g,
    f=f,
    h=h,
    alpha=alpha,
    tau=alpha,
    viewer=None,
)
x_hat, diagnostics = luq.optim.FB_torch(torch.clone(x_init), **fb_kwargs)

Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.16e-01
[Forward Backward] converged in 144 iterations


In [14]:
%%timeit

# Run the optimisation
x_hat, diagnostics = luq.optim.FB_torch(
    torch.clone(x_init),
    options=options,
    g=g,
    f=f,
    h=h,
    alpha=alpha,
    tau=alpha,
    viewer=None
)



Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.18e-01
[Forward Backward] converged in 207 iterations
Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.18e-01
[Forward Backward] converged in 207 iterations
Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.18e-01
[Forward Backward] converged in 207 iterations
Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.18e-01
[Forward Backward] converged in 207 iterations
Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.18e-01
[Forward Backward] converged in 207 iterations
Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.18e-01
[Forward Backward] converged in 207 iterations
Running Base Forward Backward
[Forward Backward] 0 out of 15000 iterations, tol = 5.18e-01
[Forward Backward] converged in 207 iterations
Running Base Forward Backward
[For

In [7]:
# %%
# NumPy versions of the ground truth, the initial image and the MAP estimate
np_x = np.copy(x)
np_x_init, np_x_hat = to_numpy(x_init), to_numpy(x_hat)


In [8]:

# SNR against the ground truth of (1) the initial image |adj_op(y)| and
# (2) the forward-backward reconstruction. The second line was previously
# mislabelled "Dirty image SNR".
print('Dirty image SNR: ', luq.utils.eval_snr(np_x, np_x_init))
print('Reconstruction SNR: ', luq.utils.eval_snr(np_x, np_x_hat))


Dirty image SNR:  3.39
Dirty image SNR:  23.68
