In [None]:
import numpy

import eaglecore.utils
import eaglecore.differential
import eaglecore.thresholding
import eaglecore.signal.processing

def tv(
    y: numpy.ndarray,
    h: numpy.ndarray,
    lamda: float,
    sigma: float,
    nb_iterations: int,
    tol: float = 0,
    verbose: bool = True,
) -> numpy.ndarray:
    """Total Variation regularization with Split Bregman.

    Each iteration performs:

    1. a u-update solved exactly in the Fourier domain::

           u = ifft2( (conj(H) * fft2(y) + fft2(a)) / (|H|^2 + sigma * L) )
           a = sigma * (dxT(d_x - b_x) + dyT(d_y - b_y))

    2. a d-update by soft-thresholding ``(dx(u) + b_x, dy(u) + b_y)`` with
       threshold ``lamda / sigma``;
    3. a Bregman update ``b += grad(u) - d``.

    Args:
        y: Observed image. Must be 2-D (``fft2`` and the Frobenius norm
            are used). Integer inputs are promoted to float64.
        h: Convolution kernel, diagonalized in Fourier space to ``y.shape``.
        lamda: TV regularization weight.
        sigma: Split Bregman penalty parameter (must be non-zero; it
            divides ``lamda``).
        nb_iterations: Maximum number of iterations.
        tol: Stop as soon as the relative change
            ``||u_k - u_{k-1}||_F / ||u_k||_F`` is ``<= tol``. When ``u_k`` is
            identically zero, the absolute change is used instead.
        verbose: If True, print the current error every 10 iterations.

    Returns:
        The restored image, rescaled to ``[0, 1]`` with
        ``eaglecore.signal.processing.normalize``.
    """
    # NOTE(review): the original referenced the never-imported package
    # ``lasp``. This assumes it was renamed to ``eaglecore`` (as was done for
    # ``thresholding``) — confirm that eaglecore.filters.linear exists.
    from eaglecore.filters.linear import laplacian

    y = numpy.asarray(y)
    # With an integer image, the in-place Bregman updates below would fail:
    # float results cannot be cast back into an int array.
    if not numpy.issubdtype(y.dtype, numpy.inexact):
        y = y.astype(numpy.float64)

    # NOTE(review): assumes this yields the eigenvalues of dxT∘dx + dyT∘dy,
    # i.e. the sign convention makes the denominator below non-negative.
    # TODO confirm against the project's laplacian().
    lap_diag = eaglecore.utils.fourier_diagonalization(
        kernel=laplacian(),
        shape_out=y.shape,
    )
    h_diag = eaglecore.utils.fourier_diagonalization(
        kernel=h,
        shape_out=y.shape,
    )

    # Loop-invariant parts of the Fourier-domain u-update.
    denominator = numpy.abs(h_diag) ** 2 + sigma * lap_diag
    data_term = numpy.conj(h_diag) * numpy.fft.fft2(y)

    # Initialization
    u = numpy.copy(y)
    d_x = numpy.zeros_like(y)
    d_y = numpy.zeros_like(y)
    b_x = numpy.zeros_like(y)
    b_y = numpy.zeros_like(y)

    for i in range(1, nb_iterations + 1):
        a = sigma * (
            eaglecore.differential.dxT(d_x - b_x)
            + eaglecore.differential.dyT(d_y - b_y)
        )

        # ``u`` is rebound below, never mutated, so no copy is needed.
        u_prev = u
        u = numpy.real(
            numpy.fft.ifft2((numpy.fft.fft2(a) + data_term) / denominator)
        )

        # Relative change between iterates. Fall back to the absolute change
        # when u is identically zero; otherwise 0/0 = nan never meets tol.
        diff_norm = numpy.linalg.norm(u - u_prev, 'fro')
        u_norm = numpy.linalg.norm(u, 'fro')
        err = diff_norm / u_norm if u_norm > 0 else diff_norm

        if verbose and i % 10 == 0:
            print('Iterations: {} ! \t error is: {}'.format(i, err))

        if err <= tol:
            break

        u_dx = eaglecore.differential.dx(u)
        u_dy = eaglecore.differential.dy(u)

        # d-update: shrink (grad u + b) with threshold lamda / sigma.
        d_x, d_y = eaglecore.thresholding.multidimensional_soft(
            numpy.array([u_dx + b_x, u_dy + b_y]),
            lamda / sigma,
        )

        # Bregman update.
        b_x += u_dx - d_x
        b_y += u_dy - d_y

    return eaglecore.signal.processing.normalize(
        signal=u,
        new_min=0.0,
        new_max=1.0,
    )

In [None]:
import numpy
import numpy.fft

def total_variation_step(
    func_dx,
    func_dy,
    func_dxT,
    func_dyT,
    lamda,
    sigma,
    cst1,
    cst2,
    u=None,
    d_x=None,
    d_y=None,
    b_x=None,
    b_y=None,
    func_threshold=None,
):
    """Run a single Split Bregman iteration for TV-regularized deconvolution.

    The original version could not compile (``break`` outside a loop) and read
    state it was never given. The iteration state is now passed in and the
    updated state is returned. Stopping and progress reporting are left to
    the caller.

    Args:
        func_dx, func_dy: Discrete gradient operators along x and y.
        func_dxT, func_dyT: Their adjoints (transposes).
        lamda: TV regularization weight.
        sigma: Split Bregman penalty parameter (non-zero; it divides
            ``lamda``).
        cst1: Fourier-domain denominator of the u-update. In ``tv`` this is
            ``|H|^2 + sigma * L``.
        cst2: Fourier-domain data term. In ``tv`` this is
            ``conj(H) * fft2(y)``. Its shape defines the image shape.
        u: Previous iterate. It is only used to measure the change. When
            None, it is treated as zeros.
        d_x, d_y: Split variables. When None, they are zeros of
            ``cst2``'s shape.
        b_x, b_y: Bregman variables. When None, they are zeros of
            ``cst2``'s shape.
        func_threshold: ``f(stacked, threshold) -> (d_x, d_y)`` shrinkage
            operator. Defaults to
            ``eaglecore.thresholding.multidimensional_soft``.

    Returns:
        Tuple ``(u, d_x, d_y, b_x, b_y, err)``. ``err`` is
        ``||u_new - u||_F / ||u_new||_F``, or the absolute change when
        ``u_new`` is identically zero. The input arrays are not modified.
    """
    if func_threshold is None:
        # The original called the undefined ``lasp`` package. It is imported
        # lazily so that a custom shrinkage operator works without it.
        from eaglecore.thresholding import multidimensional_soft as func_threshold

    shape = numpy.shape(cst2)
    u_prev = numpy.zeros(shape) if u is None else u
    d_x = numpy.zeros(shape) if d_x is None else d_x
    d_y = numpy.zeros(shape) if d_y is None else d_y
    b_x = numpy.zeros(shape) if b_x is None else b_x
    b_y = numpy.zeros(shape) if b_y is None else b_y

    # u-update, solved exactly in the Fourier domain.
    a = sigma * (func_dxT(d_x - b_x) + func_dyT(d_y - b_y))
    u_new = numpy.real(numpy.fft.ifft2((numpy.fft.fft2(a) + cst2) / cst1))

    # Relative change. Use the absolute change when u_new is all zeros, to
    # avoid a 0/0 = nan.
    diff_norm = numpy.linalg.norm(u_new - u_prev, 'fro')
    u_norm = numpy.linalg.norm(u_new, 'fro')
    err = diff_norm / u_norm if u_norm > 0 else diff_norm

    # d-update: shrink (grad u + b) with threshold lamda / sigma.
    u_dx = func_dx(u_new)
    u_dy = func_dy(u_new)
    d_x_new, d_y_new = func_threshold(
        numpy.array([u_dx + b_x, u_dy + b_y]),
        lamda / sigma,
    )

    # Bregman update. New arrays are built so the caller's inputs stay intact.
    b_x_new = b_x + (u_dx - d_x_new)
    b_y_new = b_y + (u_dy - d_y_new)

    return u_new, d_x_new, d_y_new, b_x_new, b_y_new, err


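# TODO: unfinished sketch of a stateful ADMM total-variation solver
# (kept commented out; not part of the executed code).
# class ADMMTotalVariation:
#
#     def __init__(self) -> None:
#         self.h_diag = None
#         self.h2_diag = None
#         self.d = None
#         self.b = None
#
#     def initialize(self):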
        