- Mumford-Shah pour la super-résolution
- Notebook basé sur la v4 de CMST_...

In [1]:
import sys
sys.path.append('../../..')

## Path

In [None]:
import pathlib

PATH_DATASET = pathlib.Path('dataset_2')
PATH_DATASET_NPY = PATH_DATASET / 'npy'
PATH_DATASET_PNG = PATH_DATASET / 'png'


PATH_RESULTS = PATH_DATASET / pathlib.Path('results')
PATH_RESULTS_PNG = PATH_RESULTS / 'png'
PATH_RESULTS_NPY = PATH_RESULTS / 'npy'
PATH_RESULTS_CSV = PATH_RESULTS / 'csv'

# Create every results directory (PATH_RESULTS itself is created as a parent).
# `parents=True` also creates PATH_DATASET / PATH_RESULTS when they are missing
# (a bare `mkdir()` raises FileNotFoundError in that case), and `exist_ok=True`
# replaces the separate `exists()` checks.
for _results_dir in (PATH_RESULTS_PNG, PATH_RESULTS_NPY, PATH_RESULTS_CSV):
    _results_dir.mkdir(parents=True, exist_ok=True)

## Configuration matplotlib

In [1]:
import matplotlib
# Dark theme for the notebook's figures: dark-grey figure/axes backgrounds
# with white text, tick marks and axis labels.
_dark_theme = {
    'figure.facecolor': '303030',
    'axes.facecolor': '303030',
    'text.color': 'FFFFFF',
    'xtick.color': 'FFFFFF',
    'ytick.color': 'FFFFFF',
    'axes.labelcolor': 'FFFFFF',
}
matplotlib.rcParams.update(_dark_theme)

## Mumford-Shah for Super-Resolution

In [2]:
import numpy
import lasp.differential
import lasp.utils
import lasp.filters.linear
import lasp.thresholding

import PIL.Image

def decimation(image: numpy.ndarray, d: int) -> numpy.ndarray:
    """Subsample a 2-D array by keeping every `d`-th row and column.

    Rows and columns are taken starting from index 0. The result is an
    independent copy, never a view into `image`.

    Raises:
        AssertionError: if `d` is not strictly positive.
    """
    if d <= 0:
        raise AssertionError('d <= 0')
    every_dth = slice(None, None, d)
    return numpy.copy(image[every_dth, every_dth])

def mumford_shah_sr(
    img: numpy.ndarray,
    h: numpy.ndarray,
    alpha: float,
    beta0: float,
    beta1: float,
    sigma: float,
    d: int,
    nb_iterations: int,
    tolerance: float,
    error_history: "list[float] | None" = None
) -> numpy.ndarray:
    """Split-Bregman iterations for the Mumford-Shah super-resolution model.

    Each iteration solves the quadratic u-subproblem exactly in the Fourier
    domain (division by the diagonal symbol `uker`). It then updates the
    auxiliary gradient variables (d_x, d_y) with multidimensional
    soft-thresholding, and the Bregman variables (b_x, b_y).

    Args:
        img: input image, also used as the initial iterate. Presumably already
            resized to the target grid by the caller; TODO confirm.
        h: blur kernel, diagonalised in Fourier at `img.shape`.
        alpha: weight of the data-fidelity (kernel) term.
        beta0: weight of the Laplacian term in the u-subproblem (enters as 2*beta0).
        beta1: numerator of the shrinkage threshold `beta1 / sigma`.
        sigma: split-Bregman penalty parameter.
        d: decimation factor; it only scales the |h|^2 term of `uker` by 1/d.
        nb_iterations: maximum number of iterations.
        tolerance: stop as soon as ||u - u_prev||_F / ||u||_F < tolerance.
        error_history: optional list; the relative change of every performed
            iteration is appended to it.

    Returns:
        The last iterate passed through `lasp.utils.normalize`.
    """
    Dx = lasp.differential.dx
    Dy = lasp.differential.dy
    Dxt = lasp.differential.dxT
    Dyt = lasp.differential.dyT

    # Fourier-domain symbols of the Laplacian and of the blur kernel.
    lap_diag = lasp.utils.fourier_diagonalization(
        kernel=lasp.filters.linear.laplacian(),
        shape_out=img.shape
    )
    h_diag = lasp.utils.fourier_diagonalization(
        kernel=h,
        shape_out=img.shape
    )
    h2_diag = numpy.abs(h_diag) ** 2

    # Diagonal symbol of the linear operator inverted at every iteration.
    uker = (1/d) * alpha * h2_diag + (2*beta0+sigma) * lap_diag

    # Iteration-independent part of the right-hand side.
    # NOTE(review): the adjoint of a convolution is conj(h_diag); using h_diag
    # is only equivalent when h_diag is real (e.g. a symmetric, centred kernel).
    # Also, unlike uker, this term is not scaled by 1/d. Confirm both against
    # the model derivation / MATLAB reference.
    rhs1fft = alpha * h_diag * numpy.fft.fft2(img)

    # Initialization
    u = numpy.copy(img)
    d_x = numpy.zeros_like(img)
    d_y = numpy.zeros_like(img)
    b_x = numpy.zeros_like(img)
    b_y = numpy.zeros_like(img)

    for _ in range(nb_iterations):
        rhs2 = sigma*Dxt(d_x-b_x) + sigma*Dyt(d_y-b_y)
        rhsfft = rhs1fft + numpy.fft.fft2(rhs2)

        # `u` is rebound (never mutated in place), so no copy is needed here.
        u_prev = u
        u = numpy.real(numpy.fft.ifft2(rhsfft / uker))

        norm_diff = numpy.linalg.norm(u - u_prev, 'fro')
        norm_u = numpy.linalg.norm(u, 'fro')
        if norm_u > 0:
            err = norm_diff / norm_u
        else:
            # Avoid 0/0 (nan) when the iterate vanishes.
            err = 0.0 if norm_diff == 0 else numpy.inf

        if error_history is not None:
            error_history.append(err)

        if err < tolerance:
            break

        # Gradient of u, computed once and reused by both updates below.
        u_x = Dx(u)
        u_y = Dy(u)

        d_x, d_y = lasp.thresholding.multidimensional_soft(
            d=numpy.array([u_x + b_x, u_y + b_y]),
            epsilon=beta1/sigma
        )

        b_x = b_x + u_x - d_x
        b_y = b_y + u_y - d_y

    return lasp.utils.normalize(u)

In [None]:
import lasp.filters.linear
import lasp.noise
import lasp.io


import pathlib

import numpy

import scipy.signal
import scipy.io

import matplotlib.pyplot

# Degradation pipeline: blur -> decimate -> bicubic upsampling back to the
# original grid -> additive white Gaussian noise.
img_original = numpy.array(
    matplotlib.pyplot.imread('Boats.bmp'),
    dtype=numpy.double
)

h = lasp.filters.linear.gaussian_filter(size=7, sigma=3)

img_blurred = numpy.array(
    scipy.signal.convolve2d(img_original, h, mode='same'),
    dtype=numpy.double
)

img_decim = decimation(img_blurred, d=2)

# Upsample the *decimated* image back to the blurred image's size. PIL takes
# (width, height). Previously img_blurred was resized to its own shape, which
# is a no-op and left the decimation step unused.
img_decim_resized = numpy.array(
    PIL.Image.fromarray(img_decim).resize(
        (img_blurred.shape[1], img_blurred.shape[0]),
        PIL.Image.Resampling.BICUBIC
    )
)

img_noised = lasp.noise.awgn(img_decim_resized, snr=30)

img = numpy.copy(img_noised)

In [None]:
path = pathlib.Path('outputs')
# Create the output directory up front; writing into a missing directory fails.
path.mkdir(parents=True, exist_ok=True)

lasp.io.save(img, path / 'input.png')
lasp.io.save(h, path / 'kernel.png')

# Save the input image and the kernel as arrays for the MATLAB implementation.
datas_matlab = {
    'input': img,
    'kernel': h,
}
scipy.io.savemat(path / 'datas_for_matlab.mat', datas_matlab)

In [None]:
figure = matplotlib.pyplot.figure(figsize=(20, 20))

# 2x2 overview of each stage of the degradation pipeline, filled row by row.
panels = (
    ('Original', img_original),
    ('Blurred', img_blurred),
    ('Decimate and Resized', img_decim_resized),
    ('Noised', img_noised),
)
for position, (label, picture) in enumerate(panels, start=1):
    matplotlib.pyplot.subplot(2, 2, position)
    matplotlib.pyplot.axis('off')
    matplotlib.pyplot.title(label)
    _ = matplotlib.pyplot.imshow(picture, cmap='gray')

In [None]:
# img_decim = decimation(Img, 5)
 
# bicubic_resized = numpy.array(
#     PIL.Image.Image.resize(
#         PIL.Image.fromarray(img_decim), 
#         (Img.shape[1], Img.shape[0]), 
#         PIL.Image.Resampling.BICUBIC
#     )
# )
# matplotlib.pyplot.imshow(bicubic_resized, cmap='gray')

In [None]:
import lasp.utils

import numpy
import numpy.linalg


# Hyper-parameters of the super-resolution run.
sr_parameters = dict(
    alpha=100,
    beta0=1/2,
    beta1=1,
    sigma=2,
    d=2,
    nb_iterations=300,
    tolerance=10**(-6),
)

img_normalized = lasp.utils.normalize(img)
errors = []  # filled by mumford_shah_sr with the per-iteration relative change
res = mumford_shah_sr(
    img=img_normalized,
    h=h,
    error_history=errors,
    **sr_parameters
)

In [None]:
# nb_iter = len(errors)
# for i in range(0, nb_iter):
#     if i%10 == 0:
#         print(
#             'Iteration {}'.format(i), 
#             ':', 
#             '\t\t {}'.format(errors[i])
#         )
# print('Nb iterations : {}'.format(nb_iter))

In [None]:
nb_iter = len(errors)
x, y = numpy.arange(nb_iter), errors

# Only the first 20 recorded iterations are drawn: a line, then red dots on top.
first = slice(0, 20)
matplotlib.pyplot.xlabel(xlabel='Nb iteration')
matplotlib.pyplot.ylabel(ylabel='Error')
matplotlib.pyplot.title('Errors history')
for fmt_args in ((), ('or',)):
    _ = matplotlib.pyplot.plot(x[first], y[first], *fmt_args)

In [None]:
import lasp.io

figure = matplotlib.pyplot.figure(figsize=(20, 20))

# Side by side: degraded input (left) and restored output (right).
for column, shown in enumerate((img, numpy.real(res)), start=1):
    matplotlib.pyplot.subplot(1, 2, column)
    matplotlib.pyplot.imshow(shown, cmap='gray')

lasp.io.save(res, path / 'output.png')

In [None]:
# import scipy.io.matlab
# import numpy
# a = scipy.io.matlab.loadmat('uu.mat')
# mat_res = numpy.array(numpy.array(a['uu']), dtype=numpy.double) 
# matplotlib.pyplot.imshow(mat_res, cmap='gray')
# print(numpy.max(numpy.abs(mat_res-res)))

In [None]:
# diff = numpy.abs(mat_res-res)
# print(numpy.max(diff))
# diff[1e-6 < diff].shape

In [None]:
# _ = matplotlib.pyplot.imshow(numpy.abs(mat_res-res), cmap='gray')

Unnamed: 0,image,alpha,beta,sigma,tol,iterations,blur,noise
0,dataset_1,100,1,2,0.0001,300,"(3, 3)",


3


Unnamed: 0,image,alpha,beta,sigma,tol,iterations,blur,noise
0,dataset_1,100,1,2,0.0001,300,"(3, 3)",


100