In [1]:
import os
import numpy as np
from functools import partial
import math
import time as time

import torch

# Device selection: set M1 = True to run on Apple MPS (falls back to CPU if unavailable).
M1 = False

if M1:
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
else:
    # Default to GPU index 2, but respect a CUDA_VISIBLE_DEVICES value that was already
    # set in the environment (e.g. by a job scheduler) instead of silently overriding it.
    # NOTE: this must happen before the first CUDA call in the process to take effect.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print(torch.cuda.is_available())
        print(torch.cuda.device_count())
        print(torch.cuda.current_device())
        print(torch.cuda.get_device_name(torch.cuda.current_device()))


import quantifai as qai
from quantifai.utils import to_numpy, to_tensor
from convex_reg import utils as utils_cvx_reg

True
1
0
NVIDIA A100-PCIE-40GB


In [2]:
# Test image to reconstruct ('M31' is an alternative choice).
img_name = "W28"
# Target input SNR used to derive the measurement noise level.
input_snr = 30.0

# Benchmark CRR-NN-based model's optimisation

In [5]:
# --- MAP estimation settings ---------------------------------------------------
options = dict(
    tol=1e-5,
    iter=15000,
    update_iter=4999,
    record_iters=False,
)
# Repository root relative to this notebook (passed to the image loader).
repo_dir = "./../.."

# Numeric precision: the CRR-NN model requires single precision (torch.float32).
myType, myComplexType = torch.float32, torch.complex64

# --- CRR-NN checkpoint selection -----------------------------------------------
sigma_training = 5
t_model = 5
CRR_dir_name = "./../../trained_models/"

# --- CRR regularisation parameters ---------------------------------------------
lmbd = 5e4
mu = 20

# Build observations and operators

In [6]:
# Load the ground-truth image and the Fourier sampling mask
img, mat_mask = qai.helpers.load_imgs(img_name, repo_dir)

# Aliases for the ground truth
x = img
ground_truth = img

# Ground truth as a (1, 1, H, W) tensor on the target device
torch_img = torch.tensor(np.copy(img), dtype=myType, device=device)
torch_img = torch_img.reshape((1, 1) + img.shape)

# Masked Fourier measurement operator
phi = qai.operators.MaskedFourier_torch(
    shape=img.shape, ratio=0.5, mask=mat_mask, norm="ortho", device=device
)

# Noiseless measurements, brought back to NumPy
y = phi.dir_op(torch_img).detach().cpu().squeeze().numpy()

# X Cai noise level: eff_sigma is the std of each real/imaginary noise component
eff_sigma = qai.helpers.compute_complex_sigma_noise(y, input_snr)
sigma = eff_sigma * np.sqrt(2)

# Seeded complex Gaussian noise, added only to the non-zero (sampled) entries
rng = np.random.default_rng(seed=0)
sampled = y != 0
n_re = rng.normal(0, eff_sigma, y[sampled].shape)
n_im = rng.normal(0, eff_sigma, y[sampled].shape)
y[sampled] += n_re + 1.0j * n_im

# Noisy observation as a (1, H, W) complex tensor, and its dirty (adjoint) image
torch_y = torch.tensor(np.copy(y), device=device, dtype=myComplexType).reshape(
    (1,) + img.shape
)
x_init = torch.abs(phi.adj_op(torch_y))

# Define likelihood and prior

In [7]:
# %%
# Define the likelihood
likelihood = qai.operators.L2Norm_torch(
    sigma=sigma,
    data=torch_y,
    Phi=phi,
)
# Its Lipschitz constant is computed automatically and stored in likelihood.beta
# (used for the step size below).

# Define real prox
prox_op = qai.operators.RealProx_torch()


# %%
# Load CRR model
torch.set_grad_enabled(False)
torch.set_num_threads(4)

exp_name = f"Sigma_{sigma_training}_t_{t_model}/"
# Only the CPU and CUDA loading paths are wired up here. Any other device type
# (e.g. "mps" when M1 = True) previously left CRR_model undefined and crashed with a
# NameError a few lines below; fail explicitly instead.
if device.type == "cpu":
    CRR_model = utils_cvx_reg.load_model(
        CRR_dir_name + exp_name, "cpu", device_type="cpu"
    )
elif device.type == "cuda":
    CRR_model = utils_cvx_reg.load_model(
        CRR_dir_name + exp_name, "cuda", device_type="gpu"
    )
else:
    raise NotImplementedError(
        f"Loading the CRR model is not supported on device type '{device.type}'; "
        "use 'cpu' or 'cuda'."
    )


print(f"Numbers of parameters before prunning: {CRR_model.num_params}")
CRR_model.prune()
print(f"Numbers of parameters after prunning: {CRR_model.num_params}")


# [not required] initialise the eigenvector of dimension (size, size) associated with the largest eigenvalue
CRR_model.initializeEigen(size=100)
# Compute the bound via a power iteration which couples the activations and the convolutions
L_CRR = CRR_model.precise_lipschitz_bound(n_iter=100)
# The bound is also stored in the model, presumably readable as:
# L_CRR = CRR_model.L.data.item()
print(f"Lipschitz bound {L_CRR:.3f}")

--- loading checkpoint from epoch 10 ---
---------------------
Building a CRR-NN model with 
 - [1, 8, 32] channels 
 - linear_spline activation functions
  (LinearSpline(mode=conv, num_activations=32, init=zero, size=21, grid=0.010, monotonic_constraint=True.))
---------------------
Numbers of parameters before prunning: 13610
---------------------
 PRUNNING 
 Found 22 filters with non-vanishing potential functions
---------------------
Numbers of parameters after prunning: 4183
Lipschitz bound 0.780


# Run the optimisation algorithm and compute the MAP reconstruction

In [8]:
# Step size: 0.98 over the total Lipschitz constant of the objective's gradient
alpha = 0.98 / (
    likelihood.beta  # data-fidelity term
    + mu * lmbd * L_CRR  # CRR-NN regulariser, scaled by mu and lmbd
)

# MAP estimate with the CRR-NN regulariser, starting from the dirty image
x_hat = qai.optim.FISTA_CRR_torch(
    x_init=x_init,
    options=options,
    likelihood=likelihood,
    prox_op=prox_op,
    CRR_model=CRR_model,
    alpha=alpha,
    lmbd=lmbd,
    mu=mu,
)

[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations


In [9]:
%%timeit

x_hat = qai.optim.FISTA_CRR_torch(
    x_init=x_init,
    options=options,
    likelihood=likelihood,
    prox_op=prox_op,
    CRR_model=CRR_model,
    alpha=alpha,
    lmbd=lmbd,
    mu=mu,
)



[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations
[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations
[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations
[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations
[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations
[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations
[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations
[GD] 0 out of 15000 iterations, tol = 0.102651
[GD] converged in 544 iterations
634 ms ± 574 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
# %%
# NumPy copies of ground truth, dirty image and CRR-NN estimate for evaluation
np_x = np.copy(x)
np_x_init, np_x_hat = to_numpy(x_init), to_numpy(x_hat)

In [11]:
# Reconstruction quality vs. ground truth: dirty (adjoint) image and CRR-NN MAP estimate.
# The second line previously reused the "Dirty image SNR" label for the estimate.
print("Dirty image SNR: ", qai.utils.eval_snr(np_x, np_x_init))
print("CRR estimation image SNR: ", qai.utils.eval_snr(np_x, np_x_hat))

Dirty image SNR:  3.39
Dirty image SNR:  26.85


# Benchmark wavelet-based model optimisation


In [25]:
# --- MAP estimation settings (same as the CRR-NN benchmark) ---------------------
options = {
    "tol": 1e-5,
    "iter": 15000,
    "update_iter": 4999,
    "record_iters": False,
}
# Repository root relative to this notebook.
repo_dir = "./../.."


# Precision: the wavelet pipeline requires double precision,
# whereas the CRR-NN section above runs in torch.float32.
myType = torch.float64
myComplexType = torch.complex128

# --- Wavelet prior -------------------------------------------------------------
# Other regularisation strengths tried: 5e2, 5e1, 1e3, 5e3, 1e4, 5e4.
reg_params = [1e2]
# Dictionary: db1..db8 plus "self" (a single-wavelet alternative is ["db8"]).
wavs_list = [f"db{k}" for k in range(1, 9)] + ["self"]
levels = 4

In [26]:
# %%
# Load the ground-truth image and the Fourier sampling mask
img, mat_mask = qai.helpers.load_imgs(img_name, repo_dir)

# Aliases for the ground truth
x = img
ground_truth = img

# Ground truth as a (1, 1, H, W) tensor on the target device
torch_img = torch.tensor(np.copy(img), dtype=myType, device=device)
torch_img = torch_img.reshape((1, 1) + img.shape)

# Masked Fourier measurement operator
phi = qai.operators.MaskedFourier_torch(
    shape=img.shape, ratio=0.5, mask=mat_mask, norm="ortho", device=device
)

# Noiseless measurements, brought back to NumPy
y = phi.dir_op(torch_img).detach().cpu().squeeze().numpy()

# X Cai noise level: eff_sigma is the std of each real/imaginary noise component
eff_sigma = qai.helpers.compute_complex_sigma_noise(y, input_snr)
sigma = eff_sigma * np.sqrt(2)

# Seeded complex Gaussian noise, added only to the non-zero (sampled) entries
rng = np.random.default_rng(seed=0)
sampled = y != 0
n_re = rng.normal(0, eff_sigma, y[sampled].shape)
n_im = rng.normal(0, eff_sigma, y[sampled].shape)
y[sampled] += n_re + 1.0j * n_im

# Noisy observation as a (1, H, W) complex tensor, and its dirty (adjoint) image
torch_y = torch.tensor(np.copy(y), device=device, dtype=myComplexType).reshape(
    (1,) + img.shape
)
x_init = torch.abs(phi.adj_op(torch_y))

In [27]:
# %%
# Define the likelihood
likelihood = qai.operators.L2Norm_torch(
    sigma=sigma,
    data=torch_y,
    Phi=phi,
)
# Its Lipschitz constant is computed automatically and stored in likelihood.beta.

# Prox for the real-valued constraint set
cvx_set_prox_op = qai.operators.RealProx_torch()

# %%


# Prior parameters: take the regularisation strength from the config cell
# (first entry of reg_params) instead of re-hardcoding it here, so the two
# values cannot silently diverge.
reg_param = reg_params[0]

# Define the wavelet dict
# Define the l1 norm with dict psi
psi = qai.operators.DictionaryWv_torch(wavs_list, levels, shape=x_init.shape)
reg_prox_op = qai.operators.L1Norm_torch(1.0, psi, op_to_coeffs=True)
reg_prox_op.gamma = reg_param

# Compute stepsize (only the likelihood term is smooth here)
alpha = 0.98 / likelihood.beta

# Effective soft-threshold applied by the l1 prox: gamma * step size
print("Threshold: ", reg_prox_op.gamma * alpha)

Threshold:  0.0005199805582999931


In [28]:
# Run FISTA with the wavelet l1 prior, starting from a copy of the dirty image
x_hat_FISTA, diagnostics = qai.optim.FISTA_torch(
    x_init.clone(),
    options=options,
    likelihood=likelihood,
    cvx_set_prox_op=cvx_set_prox_op,
    reg_prox_op=reg_prox_op,
    alpha=alpha,
    tau=alpha,  # prox step equals the gradient step
    viewer=None,
)

Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 5.24e-01
[Forward Backward] converged in 230 iterations


In [29]:
%%timeit

# Run the optimisation
x_hat_FISTA, diagnostics = qai.optim.FISTA_torch(
    torch.clone(x_init),
    options=options,
    likelihood=likelihood,
    cvx_set_prox_op=cvx_set_prox_op,
    reg_prox_op=reg_prox_op,
    alpha=alpha,
    tau=alpha,
    viewer=None,
)

Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 5.24e-01
[Forward Backward] converged in 230 iterations
Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 5.24e-01
[Forward Backward] converged in 230 iterations
Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 5.24e-01
[Forward Backward] converged in 230 iterations
Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 5.24e-01
[Forward Backward] converged in 230 iterations
Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 5.24e-01
[Forward Backward] converged in 230 iterations
Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 5.24e-01
[Forward Backward] converged in 230 iterations
Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 5.24e-01
[Forward Backward] converged in 230 iterations
Running FISTA algorithm
[Forward Backward] 0 out of 15000 iterations, tol = 

In [30]:
# %%
# NumPy copies for evaluation. NOTE: x_hat is the CRR-NN estimate from the
# earlier section (this section's wavelet result is x_hat_FISTA).
np_x = np.copy(x)
np_x_init, np_x_hat = to_numpy(x_init), to_numpy(x_hat)

In [31]:
# Compare dirty image, CRR-NN estimate (x_hat from the earlier section) and the
# wavelet FISTA estimate against the ground truth. The CRR line was previously
# labelled just "Estimation", which read as the wavelet result in this section.
print("Dirty image SNR: ", qai.utils.eval_snr(np_x, np_x_init))
print("CRR estimation image SNR: ", qai.utils.eval_snr(np_x, np_x_hat))

print("FISTA image SNR: ", qai.utils.eval_snr(np_x, to_numpy(x_hat_FISTA)))

Dirty image SNR:  3.39
Estimation image SNR:  26.85
FISTA image SNR:  29.27
