In [1]:
import sys
sys.path.append('..')

In [2]:
import numpy

In [32]:
def Dx(image: numpy.ndarray) -> numpy.ndarray:
    """Horizontal backward finite difference with periodic boundary.

    out[:, j] = image[:, j] - image[:, j - 1], where column -1 wraps around
    to the last column (circular boundary, matching ``Dy``).

    Args:
        image: 2-D array.

    Returns:
        float64 array with the same shape as ``image``.
    """
    nb_rows, nb_cols = numpy.shape(image)
    image_derivated = numpy.zeros(shape=(nb_rows, nb_cols))

    image_derivated[:, 1:nb_cols] = \
        image[:, 1:nb_cols] - image[:, 0:nb_cols-1]

    # Periodic wrap: the first column is differenced against the LAST column
    # (index nb_cols-1; index nb_cols would be out of bounds).
    image_derivated[:, 0] = image[:, 0] - image[:, nb_cols-1]

    return image_derivated

def Dy(image: numpy.ndarray) -> numpy.ndarray:
    """Vertical backward finite difference with periodic boundary.

    out[i, :] = image[i, :] - image[i - 1, :], with row -1 wrapping to the
    last row. The result is always a float64 array of the input's shape.
    """
    rows, cols = numpy.shape(image)
    result = numpy.zeros(shape=(rows, cols))

    # Wrap-around row first: row 0 minus the last row.
    result[0, :] = image[0, :] - image[-1, :]
    # Interior rows: each row minus the one above it.
    result[1:, :] = image[1:, :] - image[:-1, :]

    return result

def Dxt(image: numpy.ndarray) -> numpy.ndarray:
    """Adjoint (transpose) of ``Dx``: horizontal forward difference, periodic.

    out[:, j] = image[:, j] - image[:, j + 1], where column nb_cols wraps to
    column 0. This makes <Dx(u), v> == <u, Dxt(v)> for all u, v.

    Args:
        image: 2-D array.

    Returns:
        float64 array with the same shape as ``image``.
    """
    nb_rows, nb_cols = numpy.shape(image)
    image_derivated = numpy.zeros(shape=(nb_rows, nb_cols))

    image_derivated[:, 0:nb_cols-1] = \
        image[:, 0:nb_cols-1] - image[:, 1:nb_cols]

    # Periodic wrap: the last column is differenced against column 0
    # (mirrors Dyt; using column 1 breaks the adjoint relation with Dx).
    image_derivated[:, nb_cols-1] = image[:, nb_cols-1] - image[:, 0]

    return image_derivated

def Dyt(image: numpy.ndarray) -> numpy.ndarray:
    """Adjoint (transpose) of ``Dy``: vertical forward difference, periodic.

    out[i, :] = image[i, :] - image[i + 1, :], with the last row wrapping to
    row 0. The result is always a float64 array of the input's shape.
    """
    rows, cols = numpy.shape(image)
    result = numpy.zeros(shape=(rows, cols))

    # Wrap-around row first: last row minus row 0.
    result[-1, :] = image[-1, :] - image[0, :]
    # Remaining rows: each row minus the one below it.
    result[:-1, :] = image[:-1, :] - image[1:, :]

    return result




In [47]:
def decimation(image: numpy.ndarray, d: int) -> numpy.ndarray:
    """Down-sampling operator S: keep every d-th row and column from index 0.

    Args:
        image: 2-D array.
        d: decimation factor.

    Returns:
        A copy (not a view) of ``image[::d, ::d]``.

    Raises:
        AssertionError: if ``d <= 0``.
    """
    if d <= 0:
        raise AssertionError('d <= 0')
    kept = image[::d, ::d]
    return numpy.copy(kept)

def decimation_transpose(image: numpy.ndarray, d: int) -> numpy.ndarray:
    """Adjoint S^T of ``decimation``: zero-insertion up-sampling.

    Places ``image`` on every d-th row and column of a float64 array of shape
    ``(d * rows, d * cols)``; all other entries are zero.
    """
    rows, cols = image.shape
    upsampled = numpy.zeros((d * rows, d * cols))
    # Slice assignment copies the values, so the input is never aliased.
    upsampled[::d, ::d] = image
    return upsampled

# def blockproc(im: numpy.ndarray, block_sz: numpy.ndarray, func):
#     h, w = im.shape
#     m, n = block_sz
#     for x in range(0, h, m):
#         for y in range(0, w, n):
#             block = im[x:x+m, y:y+n]
#             block[:,:] = func(block)
#     return im

# blockproc(decimation_transpose(arr, d), (3, 3), print)
# def block_mm(nr, nc, nb, m, x1):
#     x1 = blockproc()

import lasp.utils

def fourier_diagonalization(kernel: numpy.ndarray, shape_out) -> numpy.ndarray:
    """Eigenvalues of the BCCB matrix that circularly convolves with ``kernel``.

    Same convention as MATLAB's ``psf2otf``: the kernel is zero-padded to
    ``shape_out`` and circularly shifted so that its centre element
    (index ``kernel.shape // 2``) lands at (0, 0). Then
    ``ifft2(D * fft2(x))`` equals the circular convolution of ``x`` with
    ``kernel`` without any added translation.

    Args:
        kernel: 2-D convolution kernel.
        shape_out: target 2-D shape. Accepts a tuple (e.g. ``image.shape``)
            or a length-2 array.

    Returns:
        Complex array of shape ``shape_out`` with the BCCB eigenvalues.

    Raises:
        ValueError: if the kernel does not fit inside ``shape_out``.
    """
    # Callers pass ``.shape`` tuples; normalise so the arithmetic below works.
    shape_out = tuple(int(s) for s in shape_out)
    nb_rows, nb_cols = kernel.shape
    if nb_rows > shape_out[0] or nb_cols > shape_out[1]:
        raise ValueError('kernel is larger than shape_out')

    kernel_padded = numpy.zeros(shape_out)
    kernel_padded[:nb_rows, :nb_cols] = kernel

    # Centre the kernel on its OWN size (not the output size); centring on
    # shape_out would add a half-image translation phase to every eigenvalue.
    kernel_padded = numpy.roll(
        kernel_padded,
        shift=(-(nb_rows // 2), -(nb_cols // 2)),
        axis=(0, 1),
    )
    return numpy.fft.fft2(kernel_padded)
    


In [45]:


def split_bregman(
    g: numpy.ndarray,
    h: numpy.ndarray,
    beta0: float,
    beta1: float,
    sigma: float,
    d: float,
    nb_iterations: int
) -> numpy.ndarray:
    """Super-resolve a blurred, decimated image with Split Bregman iterations.

    Observation model: ``g = S H f``. ``H`` is circular convolution with
    ``h`` on the high-resolution grid, and ``S`` keeps every ``d``-th row and
    column. Each iteration:

    * f-update: solved exactly in the Fourier domain from
          (c |H|^2 + (2*beta0 + sigma) * Lap) f
              = H^T S^T g + sigma * D^T (d_k - b_k),
      where ``Lap = Dx^T Dx + Dy^T Dy`` and ``S^T S`` is approximated by
      ``c * I`` with ``c = 1/d`` (the factor used by the original code).
    * d-update: isotropic shrinkage of ``D f + b`` by ``beta1 / sigma``.
    * b-update: ``b += D f - d``.

    NOTE(review): the ``2*beta0`` factor corresponds to a quadratic penalty
    ``beta0 * ||D f||^2``, and ``c = 1/d`` is an approximation (the 2-D
    sampling density is ``1/d**2``). Both are inferred from the code, so
    confirm them against the reference derivation.

    Args:
        g: low-resolution observation (2-D).
        h: blur kernel, centred at index ``h.shape // 2``.
        beta0: weight of the quadratic gradient penalty.
        beta1: weight of the isotropic TV penalty.
        sigma: Bregman penalty parameter (must be non-zero).
        d: positive integer decimation factor (integral floats such as 2.0
            are accepted).
        nb_iterations: number of Split Bregman iterations.

    Returns:
        Real float64 estimate of shape ``(d * g.shape[0], d * g.shape[1])``.
        With ``nb_iterations == 0`` this is the zero-filled ``S^T g``.

    Raises:
        ValueError: if ``d`` is not a positive integer.
    """
    factor = int(d)
    if factor != d or factor <= 0:
        raise ValueError('d must be a positive integer')

    # S^T g: zero-filled up-sampling onto the high-resolution grid. Every
    # operator below lives on this grid, not on g's low-resolution grid.
    g_up = decimation_transpose(g, factor)
    hr_shape = g_up.shape

    # Eigenvalues of the BCCB blur H (and of H^T H) on the high-res grid.
    h_eig = fourier_diagonalization(h, hr_shape)
    h2_eig = numpy.abs(h_eig) ** 2

    # Eigenvalues of Dx^T Dx + Dy^T Dy for the periodic Dx/Dy, i.e. the
    # 5-point negative Laplacian. The result is already in the Fourier domain.
    laplacian = numpy.array(
        [
            [0, -1, 0],
            [-1, 4, -1],
            [0, -1, 0]
        ],
        dtype=float
    )
    laplacian_eig = fourier_diagonalization(laplacian, hr_shape).real

    # Loop-invariant pieces of the f-update. Only the data term carries the
    # S^T S ~ (1/d) I approximation; scaling the sigma term as well would
    # break the balance with sigma * D^T(d - b) on the right-hand side.
    denominator = h2_eig / factor + (2 * beta0 + sigma) * laplacian_eig
    ## H^T S^T g: H^T is multiplication by the conjugate eigenvalues.
    data_term = numpy.fft.ifft2(numpy.conj(h_eig) * numpy.fft.fft2(g_up)).real
    threshold = beta1 / sigma

    f = g_up
    d_x = numpy.zeros(hr_shape)
    d_y = numpy.zeros(hr_shape)
    b_x = numpy.zeros(hr_shape)
    b_y = numpy.zeros(hr_shape)

    for _ in range(nb_iterations):

        # Compute f^{k+1}. The solution is real up to round-off, so drop the
        # imaginary part.
        rhs = sigma * (Dxt(d_x - b_x) + Dyt(d_y - b_y)) + data_term
        f = numpy.fft.ifft2(numpy.fft.fft2(rhs) / denominator).real

        # Compute d_{x}^{k+1} and d_{y}^{k+1} by isotropic shrinkage.
        dx_f = Dx(f)
        dy_f = Dy(f)
        s_x = dx_f + b_x
        s_y = dy_f + b_y
        s = numpy.sqrt(s_x ** 2 + s_y ** 2)
        # numpy.maximum is elementwise; numpy.max(., 0) would reduce axis 0.
        # Where s == 0, divide by 1 instead to avoid 0/0 = NaN (shrink is 0).
        scale = numpy.maximum(s - threshold, 0) / numpy.where(s > 0, s, 1.0)
        d_x = scale * s_x
        d_y = scale * s_y

        # Compute b_{x}^{k+1} and b_{y}^{k+1}.
        b_x = b_x + (dx_f - d_x)
        b_y = b_y + (dy_f - d_y)

    return f


    


In [46]:
y = numpy.arange(1, 17).reshape((4, 4))
print(y)

d = 2


def S(x):
    """Decimate ``x`` by the notebook-level factor ``d`` (looked up at call time)."""
    return decimation(x, d)


def ST(x):
    """Zero-insertion up-sampling of ``x`` by the notebook-level factor ``d``."""
    return decimation_transpose(x, d)


for shown in (S(y), ST(y)):
    print(shown)


[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]]
[[ 1  3]
 [ 9 11]]
[[ 1.  0.  2.  0.  3.  0.  4.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 5.  0.  6.  0.  7.  0.  8.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 9.  0. 10.  0. 11.  0. 12.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [13.  0. 14.  0. 15.  0. 16.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]]


array([[ 1,  2,  1,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12],
       [13, 14, 15, 16]])