In [1]:
import numpy as np 
from PIL import Image 
import pywt

In [2]:
# Define Params for the Algorithm
lambd = 1.0 # currently L_y=5
gamma = 0.98 /(5 + 1/lambd)
beta = 0.8
sigma = 1
dim_Y = 256
d = dim_Y**2

# Get the slices
coeffs = pywt.wavedec2(np.ones((dim_Y ,dim_Y)), 'haar', level=4)
_, slices = pywt.coeffs_to_array(coeffs)

In [3]:
def get_nabla_f(sigma, y):
    """Returns the gradient of the function f w.r.t to X 
    """
    # Precalculate variables to save time
    psi_y = pywt.coeffs_to_array(pywt.wavedec2(y, 'haar', level=4))[0]
    sig_2 = sigma**2
    def nabla_f(X):
        return - (psi_y - X) / sig_2
    
    return nabla_f

In [4]:
def proximal(X, theta):
    """Computes the proximal operator parametrized by lambda and theta^T * g of X     
    
    Note:   Our prox operator of x is argmin_{\bar{x}}(\theta * |\bar{x}| + \frac{\bar{x} - x}{2\lambda})
            which is equivalent to argmin_{\bar{x}}(\theta * \lambda * |\bar{x}| + \frac{\bar{x} - x}{2})
    The code below represents the hand calculated solution we found for the argmin
    """
    X_bar = np.copy(X)
    for i in range(len(X)):
        for j in range(len(X[0])):
            sign = 1 if X_bar[i, j] > 0 else -1
            X_bar[i, j] = sign * max(0, np.absolute(X_bar[i, j]) - lambd * theta)
    return X_bar

In [5]:
def myula_kernel(X, nabla_f, theta):
    """Returns the new value of X based on the MYULA kernel
    """
    Z = np.random.normal(size=X.shape)
    return X - gamma * nabla_f(X) - gamma * (X - proximal(X, theta))/lambd + np.sqrt(2*gamma) * Z

In [6]:
def projection(a):
    return 0.000001 if a < 0 else a

In [38]:
def sapg_scalar_homogeneous(theta, X, nabla_f, N, true_X):
    """SAPG Algorithm for scalar theta and alpha positively homogeneous regularizer g"""
    thetas = [theta]
    alpha = 1/theta
    for i in range(N):
        # i**(-beta) leads to div by 0 (duh..)
        delta = alpha * (i+1)**(-beta) / dim_Y
        X = myula_kernel(X, nabla_f, thetas[i])
        #print(np.linalg.norm(X - true_X))
        print(np.linalg.norm(X, ord=1))
        if i % 50 == 0:
            alpha = 1 / theta
        theta = projection(theta + delta * (dim_Y / (alpha * theta) - np.linalg.norm(X, ord=1))) # g(x) = |x|_1
        thetas.append(theta)
    return np.mean(thetas), thetas

In [39]:
def generate_data():
    X = np.random.laplace(size=(dim_Y, dim_Y))
    coeffs = pywt.array_to_coeffs(X, slices, output_format='wavedec2')
    y = pywt.waverec2(coeffs, 'haar')
    
    return X, y

In [40]:
X, y = generate_data()
nabla_f = get_nabla_f(sigma, y)
sapg_scalar_homogeneous(1.1, np.zeros((dim_Y, dim_Y)), nabla_f, 200, X)

143.21910287691247
185.41435914693645
202.59456407093518
222.33295415648794
231.2994894989464
244.4379764592257
241.1552477384551
249.10388476513708
254.18167479927442
260.9264887857433
264.9777580917864
257.99913248633953
267.02923545794164
272.12939586442474
262.84947263139134
269.1739651141104
277.5907225263106
284.3190088410619
286.58883548542167
280.80576132680045
280.0666410534094
284.02700228075526
277.9312108339806
275.8643886046274
282.3268365893305
291.6034209964904
288.05711435453463
289.1053021281903
300.53482754813797
299.656010977787
281.8121703651435
287.6221543985869
283.37585458396575
281.2927089127204
289.5177936665635
285.7110338789344
288.0360540010495
285.4637726498747
282.50687649159556
278.03520983000976
283.4156433678135
277.92037222146115
279.0997322619291
280.42585566311044
285.64871398547535
285.497023704955
288.1920843131015
293.1240038698393
284.6985902801497
281.133672106355
284.8069877101751
283.720752206574
284.0875124376154
293.04362811318674
286.718993

(0.918256135385917,
 [1.1,
  1.500500344897328,
  1.5051024280726883,
  1.4822497784957152,
  1.4443520228201134,
  1.408748364875707,
  1.3710226285102125,
  1.3442494457891891,
  1.3175932369497736,
  1.2928198818643604,
  1.2685579269073244,
  1.2461369191502003,
  1.230560991875431,
  1.213136175558838,
  1.1959346965801183,
  1.1847915242665255,
  1.1726208889869016,
  1.1588360732366834,
  1.1443058082920508,
  1.1306653873116363,
  1.120402449338266,
  1.1114713875302782,
  1.102285291301323,
  1.0957924813992794,
  1.090517352081267,
  1.084000372412846,
  1.0756605675529165,
  1.068982583850126,
  1.0626411213128717,
  1.0541081166997421,
  1.0465101535772623,
  1.0436127040881855,
  1.0396642281345627,
  1.036953213064464,
  1.0348963595492724,
  1.0312989259283674,
  1.028742300110675,
  1.0259157552528675,
  1.0237923353398735,
  1.022380656793363,
  1.0218980067283836,
  1.020469083391492,
  1.0201175392327158,
  1.019582562149769,
  1.0188541179569064,
  1.017289536530086