In [1]:
import numpy as np
import skimage as ski
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display

In [2]:
def soft_shrinkage(x, alpha):
    return np.maximum(np.abs(x)-alpha, 0)*np.sign(x)

def gradient_step(x, y, A, A_adjoint, t):
    grad = A_adjoint(A(x) - y)
    return x - 2*t*grad

def ista(y, A, A_adjoint, prox, t, lamda, iter):
    x = ski.transform.iradon(y*y.shape[-2])

    for i in range(iter):#
        e = np.linalg.norm(A(x)-y)**2 + lamda*np.linalg.norm(x, ord=1)
        print('Iteration ' + str(i) + ', Energy: ' + str(e))
        lin_up = gradient_step(x, y, A, A_adjoint, t)
        x = prox(lin_up, lamda*t)
        
    return x


In [3]:
x_opt = ski.transform.resize(ski.data.shepp_logan_phantom(), (100,100))
theta = np.linspace(0,180, endpoint= False, num=100)

def radon(x):
    return ski.transform.radon(x, theta)/x.shape[-1]

def radon_adjoint(y):
    return ski.transform.iradon(y,theta,filter_name=None) / (y.shape[-2]*np.pi/(2 * len(theta)))

y = radon(x_opt) + np.random.normal(loc = 0, scale= 0.01, size = [100,100])
 
approx = []
for lamda in range(10):
    approx.append(ista(y, radon, radon_adjoint, soft_shrinkage, t=0.0003, lamda = lamda+1, iter = 50))


Iteration 0, Energy: 27.52624193448563
Iteration 1, Energy: 27.49698874508569
Iteration 2, Energy: 27.46824311980224
Iteration 3, Energy: 27.440004265696217
Iteration 4, Energy: 27.412272148438863
Iteration 5, Energy: 27.385045923363567
Iteration 6, Energy: 27.358323648719008
Iteration 7, Energy: 27.332105730213318
Iteration 8, Energy: 27.30638988058317
Iteration 9, Energy: 27.281175956867298
Iteration 10, Energy: 27.256461627851156
Iteration 11, Energy: 27.23224911261906
Iteration 12, Energy: 27.208540296434666
Iteration 13, Energy: 27.185335191587143
Iteration 14, Energy: 27.162632158963216
Iteration 15, Energy: 27.1404280168305
Iteration 16, Energy: 27.11871823304807
Iteration 17, Energy: 27.09750586580371
Iteration 18, Energy: 27.076790186157
Iteration 19, Energy: 27.056565452724747
Iteration 20, Energy: 27.036830843939082
Iteration 21, Energy: 27.017646004347398
Iteration 22, Energy: 26.999199504451145
Iteration 23, Energy: 26.981242768474164
Iteration 24, Energy: 26.9637902936448

KeyboardInterrupt: 

In [None]:
def plot_lamda(idx):
    
    fig = plt.figure()
    plt.imshow(approx[idx])


slider = widgets.IntSlider(min = 0, max = 9, step = 1, value = 0, continuous_update = True)
interactive_plot = interactive(plot_lamda, idx = slider)
display(interactive_plot)

In [None]:
plt.imshow(ski.transform.iradon(y,theta))
plt.figure()
plt.imshow(approx)

In [None]:
alpha = 3
x = np.linspace(-10,10,num=100)
prox_x = soft_shrinkage(x, alpha)
plt.plot(x, prox_x)