# Mumford-Shah Fast Super-Resolution: Exploration

## Google Colab Integration

In [230]:
# !git clone https://github.com/Stage-SuperResolution/MumfordShah.git
# !mv MumfordShah/* .
# !rm -rf MumfordShah
# !pip install git+https://github.com/akhaten/lasp.git@a22e3634d8326742d27d7ea7fa216b59b46a573a

## Import

In [231]:
# Lasp module

import sys
sys.path.append('../../..')

import lasp.io
import lasp.filters.linear
import lasp.noise
import lasp.convert
import lasp.algorithm.experimental
import lasp.thresholding
import lasp.differential
import lasp.utils

# Other

import scipy.signal
import scipy.io.matlab
import numpy
import pandas
import tqdm
import pathlib
import matplotlib.pyplot
import typing
import yaml

# PATH = pathlib.Path('./5-MumfordShahFSR')
# if not(PATH.exists()):
#     PATH.mkdir()

## Mumford-Shah FSR

In [232]:
def mumford_shah_fsr(
    y: numpy.ndarray,
    h: numpy.ndarray,
    alpha: float,
    beta0: float,
    beta1: float,
    sigma: float,
    d: int,
    nb_iterations: int,
    tolerance: float,
    gamma: float = 0.,
    error_history: typing.Optional[list[float]] = None
) -> numpy.ndarray:
    """Mumford-Shah super-resolution solved by split Bregman with the FSR inverse.

    # TODO: make test
    Solve
    $$argmin_{x} (alpha/2) || y - S H x ||^2 + (beta0/2) || nabla x ||^2 + beta1 || nabla x ||_1$$
    where H is the convolution by `h` (applied through its Fourier
    diagonalization, so boundaries are presumably circular) and S keeps one
    pixel out of `d` in each direction (its adjoint S^T inserts zeros, see STy).

    Params:
        - y: low-resolution image (2D array of shape (rows, cols))
        - h: blur kernel, i.e. the operator H of the model
        - alpha: weight of the data-fidelity term
        - beta0: weight of the Dirichlet energy (squared gradient norm)
        - beta1: weight of the total-variation term
        - sigma: split-Bregman penalty parameter
        - d: decimation factor; the reconstruction grid is (d*rows, d*cols)
        - nb_iterations: maximum number of iterations
        - tolerance: stop as soon as ||u - u_prev||_F / ||u||_F < tolerance
        - gamma: constant added to the Laplacian symbol before dividing by it
          (presumably to avoid a division by zero at the DC frequency)
        - error_history: if given, the relative change of every iteration is
          appended to it

    Returns:
        - the high-resolution estimate, passed through `lasp.utils.normalize`
    """
    import PIL.Image  # only needed for the bicubic initial guess

    def block_mm(nr, nc, Nb, x1, order: str) -> numpy.ndarray:
        # Reshape x1 with lasp's blockproc_reshape into (nr*nc, Nb) and sum
        # along axis 1, giving an (nr, nc) array. Presumably this adds up the
        # Nb blocks of size (nr, nc) tiling x1 (Fourier counterpart of S^T S).
        block_shape = numpy.array([nr, nc])
        x1 = lasp.utils.blockproc_reshape(x1, block_shape, order)
        x1 = numpy.reshape(x1, (nr*nc, Nb), order=order)
        x1 = numpy.sum(x1, axis=1)
        return numpy.reshape(x1, (nr, nc), order=order)

    y_rows, y_cols = y.shape
    hr_shape = (d*y_rows, d*y_cols)

    # Fourier symbols of the finite-difference operators and of their adjoints.
    kernel2d_id = numpy.pad(numpy.array([[1]]), pad_width=1)
    dx_diag = lasp.utils.fourier_diagonalization(lasp.differential.dx(kernel2d_id), hr_shape)
    dxT_diag = numpy.conj(dx_diag)
    dy_diag = lasp.utils.fourier_diagonalization(lasp.differential.dy(kernel2d_id), hr_shape)
    dyT_diag = numpy.conj(dy_diag)

    # Symbol of D^T D (+ gamma); every solve below divides by it.
    lap_diag = dxT_diag * dx_diag + dyT_diag * dy_diag + gamma

    h_diag = lasp.utils.fourier_diagonalization(
        kernel=h,
        shape_out=numpy.array(hr_shape)
    )
    h_diag_transp = numpy.conj(h_diag)
    h2_diag = numpy.abs(h_diag)**2

    # Constant part of the right-hand side: alpha * H^T S^T y.
    STy = numpy.zeros(shape=hr_shape)
    STy[0::d, 0::d] = y
    rhs1fft = alpha * h_diag_transp * numpy.fft.fft2(STy)

    # The rhs-independent part of the FSR inverse does not change between
    # iterations: compute it once instead of at every iteration.
    invW = block_mm(y_rows, y_cols, d*d, h2_diag / lap_diag, order='F')
    invW_reg = invW + (beta0 + sigma) * (d*d / alpha)
    lr_block_shape = numpy.array([y_rows, y_cols])

    # Bicubic upsampling as initial guess. PIL sizes are (width, height),
    # i.e. (cols, rows): passing (rows, cols) would transpose u for
    # non-square images and break the FFT products below.
    u = numpy.array(
        PIL.Image.fromarray(y).resize(
            (y_cols*d, y_rows*d),
            PIL.Image.Resampling.BICUBIC
        )
    )
    d_x = numpy.zeros_like(u)
    d_y = numpy.zeros_like(u)
    b_x = numpy.zeros_like(u)
    b_y = numpy.zeros_like(u)

    for _ in range(nb_iterations):

        # Right-hand side of the u-subproblem: alpha H^T S^T y + sigma D^T (d - b).
        rhs2fft = sigma*dxT_diag*numpy.fft.fft2(d_x - b_x) \
            + sigma*dyT_diag*numpy.fft.fft2(d_y - b_y)
        rhsfft = rhs1fft + rhs2fft

        u_prev = u  # u is rebound below, never mutated in place

        # Closed-form FSR inverse (Zhao et al., 2016) of the u-subproblem,
        # using the block structure of S^T S in the Fourier domain.
        x1 = h_diag*rhsfft / lap_diag
        fbr = block_mm(y_rows, y_cols, d*d, x1, order='F')
        invWBR = fbr / invW_reg
        FCBinvWBR = lasp.utils.blockproc(
            numpy.copy(h_diag_transp),
            lr_block_shape,
            lambda block: block*invWBR
        )
        u_fft = (rhsfft - FCBinvWBR) / lap_diag
        u_fft /= (beta0 + sigma)

        u = numpy.real(numpy.fft.ifft2(u_fft))

        # Relative change; fall back to the absolute change when u vanishes,
        # otherwise 0/0 gives NaN and the tolerance test could never pass.
        u_norm = numpy.linalg.norm(u, 'fro')
        diff_norm = numpy.linalg.norm(u - u_prev, 'fro')
        err = diff_norm / u_norm if u_norm > 0 else diff_norm

        if error_history is not None:
            error_history.append(err)

        if err < tolerance:
            break

        # d-subproblem (multidimensional soft-thresholding of the gradient)
        # followed by the Bregman variable update.
        u_dx = numpy.real(numpy.fft.ifft2(dx_diag * u_fft))
        u_dy = numpy.real(numpy.fft.ifft2(dy_diag * u_fft))

        d_x, d_y = lasp.thresholding.multidimensional_soft(
            d=numpy.array([u_dx + b_x, u_dy + b_y]),
            epsilon=beta1 / sigma
        )

        b_x += (u_dx - d_x)
        b_y += (u_dy - d_y)

    return lasp.utils.normalize(u)

## Dataset utils

## Make images/inputs

In [233]:
import dataset

# Source images, and the directory receiving the degraded inputs generated
# from them together with the table describing them.
IMAGES_PATH = pathlib.Path('../0-Images')
GENERATION_PATH = pathlib.Path('./images-generated')
# exist_ok avoids the separate exists()/mkdir() check-then-create.
GENERATION_PATH.mkdir(exist_ok=True)
IMG_GEN_PKL = GENERATION_PATH / 'df_imgs.pkl'

In [234]:
# Noise level: an SNR of 60 dB converted to linear scale.
noise = lasp.convert.snrdb_to_snr(60)

# One row per degraded version of Baboon.bmp: the blur setting varies,
# the decimation factor is always 2.
df_imgs = pandas.DataFrame(columns=['original', 'blur', 'decimation', 'noise'])
for blur_setting in [(15, 5), (7, 3)]:
    df_imgs = dataset.add_image(df_imgs, IMAGES_PATH / 'Baboon.bmp', blur_setting, 2, noise)

df_imgs.attrs['imgs_gen_path'] = str(GENERATION_PATH)
print(df_imgs.attrs)
df_imgs

{'imgs_gen_path': 'images-generated'}


Unnamed: 0,original,blur,decimation,noise
0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0
1,../0-Images/Baboon.bmp,"(7, 3)",2,1000.0


In [235]:
# if not(GENERATION_PATH.exists()):
#     GENERATION_PATH.mkdir()
# dataset.make_images(df_imgs, GENERATION_PATH)

In [236]:
# with open(GENERATION_PATH / 'metadatas.yml', 'w') as file:
#     yaml.dump(df_imgs.attrs, file)
# pandas.to_pickle(df_imgs, IMG_GEN_PKL)

## Make Mumford-Shah parameters

In [237]:
# Directory receiving the parameter table and one sub-directory per run.
OUTPUT_PATH = pathlib.Path('./outputs_1')
# exist_ok avoids the separate exists()/mkdir() check-then-create.
OUTPUT_PATH.mkdir(exist_ok=True)

PARAMS_PKL = OUTPUT_PATH / 'df_params.pkl'

In [238]:
df_params = pandas.DataFrame(columns=['df_imgs_index', 'deblur_kernel', 'alpha', 'beta0', 'beta1', 'sigma', 'iterations', 'tol'])

# Hyper-parameters kept fixed during the sweep.
alpha, sigma = 1, 2
iterations = 100

# Grid over beta0 (Dirichlet weight) and beta1 (TV weight): n evenly spaced
# values in [inf, sup). linspace(endpoint=False) always yields exactly n
# values, whereas arange with a float step may gain or lose one to rounding.
beta0_inf, beta0_sup = 0, 1
beta0_n = 5

beta1_inf, beta1_sup = 0, 1
beta1_n = 5

range_beta0 = numpy.linspace(beta0_inf, beta0_sup, beta0_n, endpoint=False)
range_beta1 = numpy.linspace(beta1_inf, beta1_sup, beta1_n, endpoint=False)

# Every (beta0, beta1) pair is applied to image 0 with deblur kernel (15, 5)
# and tolerance 0 (so all iterations run).
for beta0 in range_beta0:
    for beta1 in range_beta1:
        df_params = dataset.add_params(df_params, 0, (15, 5), alpha, beta0, beta1, sigma, iterations, 0)

# First row processed by the run loop (it iterates df_to_process[begin:]).
begin = 0

df_params.attrs['output_path'] = str(OUTPUT_PATH)
print(df_params.attrs)
df_params

{'output_path': 'outputs_1'}


Unnamed: 0,df_imgs_index,deblur_kernel,alpha,beta0,beta1,sigma,iterations,tol
0,0,"(15, 5)",1,0.0,0.0,2,100,0
1,0,"(15, 5)",1,0.0,0.2,2,100,0
2,0,"(15, 5)",1,0.0,0.4,2,100,0
3,0,"(15, 5)",1,0.0,0.6,2,100,0
4,0,"(15, 5)",1,0.0,0.8,2,100,0
5,0,"(15, 5)",1,0.2,0.0,2,100,0
6,0,"(15, 5)",1,0.2,0.2,2,100,0
7,0,"(15, 5)",1,0.2,0.4,2,100,0
8,0,"(15, 5)",1,0.2,0.6,2,100,0
9,0,"(15, 5)",1,0.2,0.8,2,100,0


In [239]:
# Persist the table attributes (YAML) next to the pickled parameter table.
(OUTPUT_PATH / 'metadatas.yml').write_text(yaml.dump(df_params.attrs))
df_params.to_pickle(PARAMS_PKL)

## Run Mumford-Shah

In [240]:
def run(
    df_to_process: pandas.DataFrame,
    index: int
) -> tuple[numpy.ndarray, list[float]]:
    """Run `mumford_shah_fsr` on one row of the joined images/parameters table.

    Params:
        - df_to_process: table produced by joining df_imgs with df_params; its
          attrs must contain 'imgs_gen_path' and its rows the columns 'index',
          'deblur_kernel', 'decimation', 'alpha', 'beta0', 'beta1', 'sigma',
          'iterations' and 'tol'
        - index: row label in `df_to_process`

    Returns:
        - (output, error_history): the reconstructed image and the relative
          change recorded at each iteration
    """
    imgs_gen_path = pathlib.Path(df_to_process.attrs['imgs_gen_path'])
    params = df_to_process.loc[index]

    # 'index' is the df_imgs row label; the generated input of that image is
    # stored in a sub-directory named after it.
    df_imgs_index = params['index']

    input_normalized = lasp.io.read(
        imgs_gen_path / str(df_imgs_index) / 'input_normalized.npy'
    )

    deblur_kernel = lasp.filters.linear.gaussian_filter(
        params['deblur_kernel'][0],
        params['deblur_kernel'][1]
    )

    # The decimation factor is a column of the joined table; reading it from
    # the row removes the hidden dependency on the global `df_imgs`.
    decim = int(params['decimation'])

    error_history: list[float] = []

    # gamma is added to the Laplacian symbol inside mumford_shah_fsr before
    # dividing by it; a tiny value presumably only guards against zeros.
    output = mumford_shah_fsr(
        input_normalized, deblur_kernel,
        params['alpha'], params['beta0'], params['beta1'], params['sigma'], decim,
        params['iterations'], params['tol'], gamma=1e-16,
        error_history=error_history
    )

    return output, error_history
    

In [241]:
# Reload the parameter table and restore its attributes from the YAML file
# (attrs are re-read from YAML instead of relying on the pickle).
df_params = pandas.read_pickle(OUTPUT_PATH / 'df_params.pkl')
params_metadata = OUTPUT_PATH / 'metadatas.yml'
with params_metadata.open('r') as stream:
    df_params.attrs = yaml.safe_load(stream)
print(df_params.attrs)
df_params

{'output_path': 'outputs_1'}


Unnamed: 0,df_imgs_index,deblur_kernel,alpha,beta0,beta1,sigma,iterations,tol
0,0,"(15, 5)",1,0.0,0.0,2,100,0
1,0,"(15, 5)",1,0.0,0.2,2,100,0
2,0,"(15, 5)",1,0.0,0.4,2,100,0
3,0,"(15, 5)",1,0.0,0.6,2,100,0
4,0,"(15, 5)",1,0.0,0.8,2,100,0
5,0,"(15, 5)",1,0.2,0.0,2,100,0
6,0,"(15, 5)",1,0.2,0.2,2,100,0
7,0,"(15, 5)",1,0.2,0.4,2,100,0
8,0,"(15, 5)",1,0.2,0.6,2,100,0
9,0,"(15, 5)",1,0.2,0.8,2,100,0


In [242]:
# Reload the generated-images table and restore its attributes from YAML.
df_imgs = pandas.read_pickle(GENERATION_PATH / 'df_imgs.pkl')
imgs_metadata = GENERATION_PATH / 'metadatas.yml'
with imgs_metadata.open('r') as stream:
    df_imgs.attrs = yaml.safe_load(stream)
print(df_imgs.attrs)
df_imgs

{'imgs_gen_path': 'images-generated'}


Unnamed: 0,original,blur,decimation,noise
0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0
1,../0-Images/Baboon.bmp,"(7, 3)",2,1000.0


In [243]:
# One row per (image, parameter set): left-join the parameters on the image
# row label, drop rows with any missing value (e.g. images without
# parameters), and turn the image label into an 'index' column.
df_to_process = (
    df_imgs
    .join(df_params.set_index('df_imgs_index'))
    .dropna(how='any')
    .reset_index()
)
for attrs_source in (df_imgs, df_params):
    df_to_process.attrs.update(attrs_source.attrs)
output_path = pathlib.Path(df_to_process.attrs['output_path'])
print(df_to_process.attrs)
df_to_process

{'imgs_gen_path': 'images-generated', 'output_path': 'outputs_1'}


  return Index(sequences[0], name=names)


Unnamed: 0,index,original,blur,decimation,noise,deblur_kernel,alpha,beta0,beta1,sigma,iterations,tol
0,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.0,0.0,2,100,0
1,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.0,0.2,2,100,0
2,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.0,0.4,2,100,0
3,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.0,0.6,2,100,0
4,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.0,0.8,2,100,0
5,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.2,0.0,2,100,0
6,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.2,0.2,2,100,0
7,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.2,0.4,2,100,0
8,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.2,0.6,2,100,0
9,0,../0-Images/Baboon.bmp,"(15, 5)",2,1000.0,"(15, 5)",1,0.2,0.8,2,100,0


In [244]:
# output1, error_history1 = run(df_to_process, 0)

In [245]:
# output2, error_history2 = run(df_to_process, 15)

In [246]:
# numpy.max(numpy.abs(output1-output2))

In [247]:
# lasp.metrics.PSNR(output1, lasp.io.read(pathlib.Path('../0-Images/Baboon.bmp')))

In [248]:
# lasp.metrics.PSNR(output2, lasp.io.read(pathlib.Path('../0-Images/Baboon.bmp')))

In [249]:
# Run every remaining row and store its results in output_path/<row>/.
for index in tqdm.tqdm(df_to_process[begin:].index):

    CURRENT = output_path / str(index)
    CURRENT.mkdir(exist_ok=True)

    output, error_history = run(df_to_process, index)

    lasp.io.save(output, CURRENT / 'output.npy')
    lasp.io.save(output, CURRENT / 'output.png')
    # Store the convergence history itself (the image was saved here before).
    numpy.save(CURRENT / 'error_history.npy', numpy.asarray(error_history))

100%|██████████| 25/25 [05:35<00:00, 13.43s/it]


## Graph 3D

In [250]:
# import lasp.metrics


# def compute_metrics(
#     df_to_process: pandas.DataFrame, 
#     index: int
# ) -> None:

#     imgs_gen_path = pathlib.Path(df_to_process.attrs['imgs_gen_path'])
#     output_path = pathlib.Path(df_to_process.attrs['output_path'])
#     params = df_to_process.loc[index]
#     # print(params.keys())

#     df_imgs_index = params['index']

#     original = lasp.io.read(
#         imgs_gen_path / str(df_imgs_index) / 'original.npy'
#     )

#     output = lasp.io.read(
#         output_path / str(df_imgs_index) / 'output.npy'
#     )

#     return lasp.metrics.PSNR(original, output, intensity_max=255)


# psnr_metrics: list[float] = []

# for index in tqdm.tqdm(df_to_process.index):
#    psnr_metrics.append(compute_metrics(df_to_process, index))

In [251]:
# value = numpy.reshape(numpy.array(psnr_metrics), (9, 9), order='F')
# print(value.shape)
# axe_values = numpy.arange(0, 9, 1)
# print(axe_values.shape)
# lines, cols = numpy.meshgrid(axe_values, axe_values)

In [252]:
# import matplotlib.pyplot
# from mpl_toolkits.mplot3d import Axes3D

In [253]:
# %matplotlib widget


# # %matplotlib notebook 
# # %matplotlib inline



# import numpy

# # import numpy as np



# # Z_mae_plt: numpy.ndarray = mae

# # coord_best: tuple[int, int] = i_mae, j_mae

# # Set up a figure twice as tall as it is wide
# fig = matplotlib.pyplot.figure(figsize=(10, 10), dpi=100)


# # Graph 3D MSE

# # ax = Axes3D(fig)
# ax = fig.add_subplot(projection='3d')
# # ax = Axes3D(fig)

# x = cols.ravel()
# y = lines.ravel()

# ## Display surface and points
# # _ = ax.plot_surface(x, y, value, color='blue', alpha=0.1)
# # _ = ax.scatter(x, y, value, color='green', marker='x')

# ## Display best point
# # s_color, s_spatial, mae = X_plt[coord_best], Y_plt[coord_best], Z_mae_plt[coord_best]
# # _ = ax.scatter(s_color, s_spatial, mae, color='red', marker='o')
# # ax.text(s_color, s_spatial, mae, '({:.3f}, {}, {:.3f})'.format(s_color, s_spatial, mae), color='red')

# # setting title and labels
# ax.set_title('PSNR', fontsize=30)
# ax.set_xlabel('$\\beta_{0}$')
# ax.set_ylabel('$\\beta_{1}$')
# ax.set_zlabel('psnr')

# # matplotlib.pyplot.plot(x, y, value)

# # matplotlib.pyplot.savefig(output_path / 'graph3D.png')
# matplotlib.pyplot.show()


In [254]:
# %matplotlib widget


# # Z_mse_plt: numpy.ndarray = mse


# # coord_best: tuple[int, int] = i_mse, j_mse


# # Set up a figure twice as tall as it is wide
# fig = matplotlib.pyplot.figure(figsize=(10, 10), dpi=100)


# # Graph 3D MSE

# # ax = Axes3D(fig)
# ax = fig.add_subplot(projection='3d')

# ## Display surface and points
# # _ = ax.plot_surface(X_plt, Y_plt, Z_mse_plt, color='blue', alpha=0.1)
# # _ = ax.scatter(X_plt, Y_plt, Z_mse_plt, color='green', marker='x')

# ## Display best point
# # s_color, s_spatial, mse = X_plt[coord_best], Y_plt[coord_best], Z_mse_plt[coord_best]
# # _ = ax.scatter(s_color, s_spatial, mse, color='red', marker='o')
# # ax.text(s_color, s_spatial, mse, '({:.3f}, {}, {:.3f})'.format(s_color, s_spatial, mse), color='red')


# # setting title and labels
# ax.set_title('MSE', fontsize=10)
# ax.set_xlabel('$\sigma_{spatial}$')
# ax.set_ylabel('$\sigma_{color}$')
# ax.set_zlabel('mse')

# matplotlib.pyplot.show()

In [255]:
# import matplotlib.pyplot
# import numpy

# # import numpy as np
# from mpl_toolkits.mplot3d import Axes3D

# %matplotlib widget


# # Z_mse_plt: numpy.ndarray = mse


# # coord_best: tuple[int, int] = i_mse, j_mse


# # Set up a figure twice as tall as it is wide
# fig = matplotlib.pyplot.figure(figsize=(10, 10), dpi=100)


# # Graph 3D MSE

# # ax = Axes3D(fig)
# ax = fig.add_subplot(projection='3d')

# ## Display surface and points
# # _ = ax.plot_surface(X_plt, Y_plt, Z_mse_plt, color='blue', alpha=0.1)
# # _ = ax.scatter(X_plt, Y_plt, Z_mse_plt, color='green', marker='x')

# ## Display best point
# # s_color, s_spatial, mse = X_plt[coord_best], Y_plt[coord_best], Z_mse_plt[coord_best]
# # _ = ax.scatter(s_color, s_spatial, mse, color='red', marker='o')
# # ax.text(s_color, s_spatial, mse, '({:.3f}, {}, {:.3f})'.format(s_color, s_spatial, mse), color='red')


# # setting title and labels
# ax.set_title('MSE', fontsize=10)
# ax.set_xlabel('$\sigma_{spatial}$')
# ax.set_ylabel('$\sigma_{color}$')
# ax.set_zlabel('mse')

# matplotlib.pyplot.show()