In [1]:
from blurring import blurring
from besov_prior import besov_prior
from inverse_problem import inverse_problem
from rto_mh import rto_mh
import numpy as np
import cuqi
import time
from scipy.special import gamma
import arviz as az

In [2]:
def Test_Signal(x):
    """Piecewise test signal used to synthesise data on [0, 1].

    Evaluated element-wise:
      * 0.10 <= x <= 0.35 : exp(-150 * (x - 0.35)**2)   (reaches 1 at x = 0.35)
      * 0.35 <  x <= 0.45 : 1.0
      * 0.45 <  x <= 0.60 : 0.8
      * 0.60 <  x <  0.90 : 0.4
      * elsewhere         : 0.0

    Parameters
    ----------
    x : array_like
        Evaluation points. Lists, tuples and ndarrays of any shape are
        accepted; the input is converted to a float array.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    # zeros_like preserves x's shape; np.zeros(len(x)) only worked for 1-D input.
    signal = np.zeros_like(x)
    bump = (0.1 <= x) & (x <= 0.35)
    plateau_high = (0.35 < x) & (x <= 0.45)
    plateau_mid = (0.45 < x) & (x <= 0.60)
    plateau_low = (0.60 < x) & (x < 0.90)
    signal[bump] = np.exp(-(x[bump] - 0.35)**2*150)
    signal[plateau_high] = 1.0
    signal[plateau_mid] = 0.8
    signal[plateau_low] = 0.40
    return signal

In [3]:
# Reference problem: blur the test signal on a fine N-point grid, add noise,
# and keep every (N // m)-th datum as the coarse observation vector `Data`
# shared by all resolutions in the loop below.
np.random.seed(5)
N = 4096
m = 32  # number of retained (coarse) observations
x = np.linspace(0, 1, N, endpoint=False)
signal = Test_Signal(x)
likelihood = blurring(x, sigma_kernel=0.02)
likelihood.set_data(signal, noise_level=2.0)
lam = likelihood.lam
# Stride derived from N and m instead of the hard-coded 128 (= 4096 // 32),
# so changing either constant keeps the subsampling consistent.
# Assumes likelihood.data has N entries, one per grid point — TODO confirm.
assert N % m == 0, "N must be a multiple of m for uniform subsampling"
Data = likelihood.data[0::N // m]

In [None]:
# Discretisation-invariance study: for every wavelet and grid size n, sample the
# posterior with RTO-MH using the same coarse data `Data` and `lam` from the
# reference problem, and save the chain plus `index_accept` for each run.
np.random.seed(5)
n_range = [32, 64, 128, 256, 512, 1024, 2048]
delt = 1
wavelets = ['db1', 'db8']
level = 0
nsamp = 10000
s = 1.0  # Besov prior parameter s (also encoded in the output filenames)
p = 1.5  # Besov prior parameter p (also encoded in the output filenames)
for wavelet in wavelets:
    for n in n_range:
        # n must be a power of two (so J = log2(n), matching the 2**J grids used
        # elsewhere in this notebook) and a multiple of the m coarse observations.
        assert n & (n - 1) == 0 and n % m == 0, f"n={n} must be a power of two divisible by m={m}"
        # Derived directly from n rather than a running counter, so the value no
        # longer depends on n_range being sorted and starting at 32.
        J = n.bit_length() - 1
        x_n = np.linspace(0, 1, n, endpoint=False)
        x0 = np.ones(n)
        likelihood = blurring(x_n, sigma_kernel=0.02, Factor=n // m)
        likelihood.data = Data
        likelihood.lam = lam
        prior = besov_prior(J=J, delt=delt, level=level, s=s, p=p, wavelet=wavelet)
        jac_const = likelihood.jac_const(n) @ prior.jac_const(n)
        problem = inverse_problem(likelihood, prior, jac_const)
        Nrand = m + n
        rto_sampler = rto_mh(x0, Nrand, samp=nsamp)
        z_Map = rto_sampler.initialize_Q(problem)
        rto_sampler.x0 = z_Map  # start the chain at the optimiser's point
        xchain, acc_rate, index_accept, log_c_chain = rto_sampler.sample(problem)
        # NOTE: np.save appends ".npy" when the name does not already end with it,
        # so these are written as "...p=1.5.npz.npy" in plain .npy format (not an
        # .npz archive). The names are kept unchanged so any code already loading
        # these files keeps working.
        np.save(f"Discrete_Invariant_Samples_{wavelet}_n={n}_s={s}_p={p}.npz", xchain)
        np.save(f"Discrete_Invariant_index_accept_{wavelet}_n={n}_s={s}_p={p}.npz", index_accept)
        print(acc_rate)

`ftol` termination condition is satisfied.
Function evaluations 6, initial cost 3.5286e+04, final cost 5.5639e+01, first-order optimality 1.01e-05.
1.4723540938156667e-05
1.0
`ftol` termination condition is satisfied.
Function evaluations 7, initial cost 3.5898e+04, final cost 9.8962e+01, first-order optimality 5.87e-05.
9.032807731016026e-05
1.0
`ftol` termination condition is satisfied.
Function evaluations 8, initial cost 3.6873e+04, final cost 1.6796e+02, first-order optimality 9.30e-05.
0.00015292166437982203
1.0
`ftol` termination condition is satisfied.
Function evaluations 8, initial cost 3.7812e+04, final cost 2.7563e+02, first-order optimality 1.91e-04.
0.0003200217782789977
1.0
`ftol` termination condition is satisfied.
Function evaluations 8, initial cost 3.8649e+04, final cost 4.4113e+02, first-order optimality 2.54e-04.
0.00041430824377798614
1.0
`ftol` termination condition is satisfied.
Function evaluations 10, initial cost 3.9449e+04, final cost 6.9532e+02, first-order

In [16]:
# Single-resolution comparison problem on a 2**J point grid (J = 9 -> 512 points).
# The seed is reset first; presumably set_data draws the synthetic noise — confirm.
np.random.seed(5)
J = 9
n = 2**J
x = np.linspace(0, 1, n, endpoint=False)
signal = Test_Signal(x)
likelihood = blurring(x, sigma_kernel=0.02)
likelihood.set_data(signal, noise_level=2.0)
m = len(likelihood.data)  # number of observations produced by the blurring model

In [17]:
# Besov prior hyper-parameters for the comparison run.
s, p = 1.0, 1.5
wavelet, level, delt = 'db1', 0, 1.0

# RTO-MH configuration. 1400 draws — presumably chosen to equal the
# 1000 + 400 iterations of the NUTS run below; confirm.
x0 = np.ones(n)
Nrand = n + m
nsamp = 1400

prior = besov_prior(J=J, delt=delt, level=level, s=s, p=p, wavelet=wavelet)
jac_const = likelihood.jac_const(n) @ prior.jac_const(n)
problem = inverse_problem(likelihood, prior, jac_const)
rto_sampler = rto_mh(x0, Nrand, samp=nsamp)
zMap = rto_sampler.initialize_Q(problem)
# prior.transform maps the optimiser's point zMap to the MAP estimate used to
# start NUTS (presumably from a transformed/whitened variable back to signal space).
MAP = prior.transform(zMap)

`ftol` termination condition is satisfied.
Function evaluations 7, initial cost 6.0895e+05, final cost 7.7611e+02, first-order optimality 3.28e-04.
0.0006521909648807057


In [18]:
# NUTS reference run on the same posterior, started from the MAP estimate.
np.random.seed(5)

# Normalising factor of the Besov log-density. It depends only on p, so it is
# hoisted out of logpdf/gradient instead of being recomputed on every call.
# The evaluation order matches the original expressions, so values are bit-identical.
besov_const = (np.sqrt(gamma(1/p)/gamma(3/p)))**p
logpdf_scale = -delt/besov_const
grad_scale = -p*delt/besov_const


def logpdf(z):
    """Besov log-prior: -delt/c * ||prior.wavelet_weigth(z)||_p ** p."""
    # NOTE(review): `wavelet_weigth` is spelled exactly as the besov_prior method
    # this notebook successfully called; do not "correct" the spelling here.
    return logpdf_scale*np.linalg.norm(prior.wavelet_weigth(z), ord=p)**p


def gradient(z):
    """Gradient of `logpdf`: -p*delt/c * adjoint(sign(w) * |w|**(p-1)), w = wavelet_weigth(z)."""
    # The weighted wavelet transform is evaluated once; the previous lambda
    # computed it twice per call (once for sign, once for magnitude).
    w = prior.wavelet_weigth(z)
    return grad_scale*prior.wavelet_weight_adjoint(np.sign(w)*np.abs(w)**(p-1))


xx = cuqi.distribution.UserDefinedDistribution(dim=n, logpdf_func=logpdf, gradient_func=gradient)
model = cuqi.model.LinearModel(likelihood.jac_const(n))
y = cuqi.distribution.Gaussian(model(xx), likelihood.lam**2)
joint = cuqi.distribution.JointDistribution(y, xx)
posterior = joint(y=likelihood.data)
sampler = cuqi.sampler.NUTS(posterior, MAP, adapt_step_size=True, opt_acc_rate=0.8)
# process_time is the process's user+system CPU time, the same clock used for the RTO run.
t0 = time.process_time()
# Presumably (kept samples, burn-in): the output reports 1400 total iterations.
chain_NUTS = sampler.sample(1000, 400)
t1 = time.process_time()
total_time_NUTS = t1-t0
print(total_time_NUTS)
np.save("Comparison_samples_deconvolution_NUTS.npy", chain_NUTS.samples)
np.save("Time_NUTS.npy", total_time_NUTS)


Sample 1400 / 1400
634.859375


In [19]:
# Timed RTO-MH run on the comparison problem, started from the optimiser's point zMap.
np.random.seed(5)
rto_sampler.x0 = zMap
t0 = time.process_time()  # CPU time, the same clock used for the NUTS run
chain_RTO, acc_rate, index_accept, log_c_chain = rto_sampler.sample(problem)
t1 = time.process_time()
total_time_RTO = t1 - t0
print(acc_rate, total_time_RTO, sep="\n")
# NOTE(review): `index_accept` selects columns of the raw chain. If it lists only
# accepted proposals, rather than the current-state index at every iteration, this
# is not a valid MH chain. The reported acc_rate (0.70571... = 988/1400) would then
# leave fewer than 1000 columns for the slice below. Confirm against rto_mh.sample.
chain_RTO_accept = chain_RTO[:, index_accept]
np.save("Comparison_samples_deconvolution_RTO.npy", chain_RTO_accept[:, :1000])
np.save("Time_RTO.npy", total_time_RTO)

0.7057142857142857
757.59375


In [20]:
def _ess_summary(samples, elapsed):
    """Per-component effective sample size, with a printed summary.

    Parameters
    ----------
    samples : numpy.ndarray
        2-D array indexed as ``samples[component, draw]``.
    elapsed : float or 0-d array
        Sampling time in seconds (CPU time, as saved by the sampler cells).

    Returns
    -------
    numpy.ndarray
        ``arviz.ess`` of each row. Prints min/median/max ESS, then the same
        three statistics divided by ``elapsed`` (ESS per second).
    """
    # Number of components comes from the loaded array itself, not the global
    # `n`, so a file from a different resolution cannot silently mismatch.
    ess = np.array([az.ess(samples[i, :]) for i in range(samples.shape[0])], dtype=float)
    print(ess.min(), np.median(ess), ess.max())
    print(ess.min()/elapsed, np.median(ess)/elapsed, ess.max()/elapsed)
    return ess


chain_RTO = np.load("Comparison_samples_deconvolution_RTO.npy")
time_RTO = np.load("Time_RTO.npy")
RTO_samples = chain_RTO
ESS_RTO = _ess_summary(RTO_samples, time_RTO)
chain_NUTS = np.load("Comparison_samples_deconvolution_NUTS.npy")
time_NUTS = np.load("Time_NUTS.npy")
ESS_NUTS = _ess_summary(chain_NUTS, time_NUTS)

580.653631570516 954.1456042915211 1178.2184998559271
0.7664445906140541 1.2594422859105174 1.5552114835370898
240.20501208935383 677.5247660702041 1064.8625096460817
0.3783593998109484 1.0672044751173504 1.6773202878922306
