In [1]:
%load_ext autoreload
%autoreload 2
import os

import matplotlib.pyplot as plt
import numpy as np

from boostlets_mod import Boostlet_syst, sk_to_phys, rm_sk_index_in_horiz_cone
from mod_plotting_utilities import plot_array_images
from mod_RIRIS_func import load_DB_ZEA, rand_downsamp_RIR, ImageOps, jitter_downsamp_RIR
from mod_RIRIS_func import computePareto, ista, iffst, linear_interpolation_fft, perforMetrics
from mod_RIRIS_func import load_sk


# Inputs 
Parameters for the dictionary and the image.

In [2]:
# ---- Sizes ----------------------------------------------------------
# Dictionary size // size of the interpolated image
M = N = 128
# Image size
M0 = N0 = 100

# ---- Dictionary -----------------------------------------------------
n_v_scales = n_h_scales = 10
base_v = base_h = 1 / 2
n_v_thetas = n_h_thetas = 15

# ---- Image ----------------------------------------------------------
room = "Balder"
ratio_mics = 0.5
# Inverse of the microphone ratio, rounded to an integer
u = round(1 / ratio_mics)
# Extrapolation mode
extrap_mode = "pad"

# ---- Sub-image ------------------------------------------------------
# Row window [Tstart, Tend) spanning M0 rows
Tstart = 0
Tend = Tstart + M0

# ---- Sampling -------------------------------------------------------
dx = 3e-2
fs = 11250
cs = 340
dt = 1 / fs

# ---- Pareto ---------------------------------------------------------
# 50 log-spaced candidate values between 10**-2.5 and 10**-1
beta_set = np.logspace(-2.5, -1, num=50)

# ---- ISTA -----------------------------------------------------------
epsilon = 9.4e-6



## Create dict

In [None]:
# Disabled alternative: build the boostlet dictionary from the parameters of
# the configuration cell instead of loading a saved one. Everything below is
# commented out; uncomment it to regenerate `Sk`.
# BS = Boostlet_syst(dx=dx, dt=1/fs, cs=cs,
#                  M=M, N=N, 
#                  n_v_scales=n_v_scales, n_h_scales=n_h_scales,
#                  n_v_thetas=n_v_thetas, n_h_thetas=n_h_thetas, 
#                  base_v=base_v, base_h=base_h, 
#                  )

# Sk = BS.get_boostlet_dict()
# Sk = BS.get_boostlet_dict2()

# BS.print_max_scales()
# plot_array_images(Sk, num_cols=10)

Alternatively, load a previously saved dictionary:

In [None]:
# Load a saved dictionary from disk. `load_sk` is imported in the first cell
# together with the other `mod_RIRIS_func` helpers, so that all imports live
# in one place.

# Alternative dictionary (disabled):
# folder_dict = 'saved_dicts/tan_dicts'
# file_dict = 'BS_m_128_n_128_vsc_2_hsc_2_bases_0.5_0.5_thV_3_thH_3.mat'
folder_dict = 'ss_saved_dicts/'
file_dict = 'SS_m_512_n_128.mat'

# NOTE(review): the file name suggests a 512 x 128 dictionary, whereas the
# configuration cell sets M, N = 128, 128 -- confirm which size is intended.
Sk = load_sk(folder=folder_dict, file=file_dict, build_dict=None)
# plot_array_images(Sk, num_cols=5)


## Load Image

In [None]:
# ---------- Load image ---------------------------
# Build the path of the measurement file for the selected room.
folder = "./dependencies/measurementData"
file = f"{room}RIR.mat"
file_path = os.path.join(folder, file)
print("Image loaded:", file_path, sep="\n")

# Keep the first element returned by the loader as the full image, then cut
# the sub-image (rows Tstart..Tend, first N0 columns) that will be decomposed.
full_image = load_DB_ZEA(file_path)[0]
orig_image = full_image[Tstart:Tend, :N0]

# Sampling mask over the sub-image.
# NOTE(review): no random seed is set anywhere in view; if the jittered
# sampling is random, the mask changes on every run -- confirm.
# mask0, _ = rand_downsamp_RIR(orig_image.shape, ratio_t=1, ratio_x=ratio_mics)
mask0, _ = jitter_downsamp_RIR(
    orig_image.shape,
    ratio_t=1,
    ratio_x=ratio_mics,
)


## Extrapolation

In [None]:

# ----------------------------------------------------
# Extrapolation
# ----------------------------------------------------

# Expand the sub-image to the first two dimensions of the dictionary so that
# image and atoms can be combined element-wise below.
extr_size = Sk.shape[:2]
imOps = ImageOps(orig_image.shape, mask=mask0, extrap_shape=extr_size, mode=extrap_mode)
image = imOps.expand_image(orig_image)
mask = imOps.get_mask(image)

# Centred 2-D spectra of the original, expanded and masked images.
fft1 = np.fft.fftshift(np.fft.fft2(orig_image))
fft2 = np.fft.fftshift(np.fft.fft2(image))
fft3 = np.fft.fftshift(np.fft.fft2(mask * image))

# One dictionary atom is overlaid on the expanded-image spectrum for visual
# reference. The index is clamped so that dictionaries with fewer than 13
# atoms do not raise an IndexError; the gain only scales the overlay.
atom_idx = min(12, Sk.shape[2] - 1)
atom_gain = 200

images = [
    orig_image,
    image,
    mask * image,
    abs(fft1),
    abs(fft2) + atom_gain * abs(Sk[:, :, atom_idx]),
    abs(fft3),
]
titles = ['Original Image', 'Expanded Image', 'Masked image', r'$\mathcal{F}(orig. im)$', r'$\mathcal{F}(exp. im)$', r'$\mathcal{F}(mask. im)$']

fig, axs = plt.subplots(2, 3, figsize=(6, 6))
for ax, im, titl in zip(axs.flatten(), images, titles):
    # pcolormesh draws the same grid as pcolor but is much faster on large arrays.
    ax.pcolormesh(im)
    ax.set_title(titl)
    ax.axis('off')
plt.tight_layout()
plt.show()

## Remove elements from dict
See `test_dict_to_remove.ipynb` for details.

In [None]:
# Indices of the dictionary elements to drop, as reported by
# rm_sk_index_in_horiz_cone for the current sampling (dx, dt, cs).
rm_sk_ids = rm_sk_index_in_horiz_cone(dx=dx, dt=dt, cs=cs, Sk=Sk)
print(f"Removed IDs: {rm_sk_ids}")
# Remove those entries along the third axis of Sk.
# NOTE(review): Sk is overwritten here, so re-running this cell operates on
# the already-pruned dictionary -- reload Sk before re-executing if the
# indices are meant to refer to the full dictionary.
Sk = np.delete(Sk, rm_sk_ids, axis=2)

Dictionary in physical space

In [None]:
# Convert the (pruned) dictionary with sk_to_phys and display the result as a
# gray-scale image grid with 5 columns.
phys_sk = sk_to_phys(Sk)
plot_array_images(
    phys_sk,
    num_cols=5,
    cmap='gray',
)

# ISTA


In [None]:
# ----------------------------------------------------
# Pareto: evaluate the candidate values in beta_set
# ----------------------------------------------------
beta_star, Jcurve = computePareto(
    image,
    mask,
    Sk,
    beta_set,
    f_plot=True,
)

# ----------------------------------------------------
# ISTA recovery, using the beta returned by the Pareto step
# ----------------------------------------------------
alpha = ista(
    image,
    mask,
    Sk,
    beta=beta_star,
    epsilon=epsilon,
    max_iterations=15,
    f_plot=True,
    f_verbose=True,
)


## Recover image

In [10]:
# Recover the inpainted image from the sparse coefficients (Eq. 19).
# `image_recov` is kept as its own variable because the metrics cell uses it.
image_recov = iffst(alpha, Sk)
# Pass the recovered image through imOps.recover_image -- presumably this
# undoes the earlier expand_image step; confirm against ImageOps.
final_image = imOps.recover_image(image_recov)


## Linear interpolation

In [11]:
# Baseline: linear interpolation of the masked image, then the same
# recover_image step that is applied to the dictionary-based result.
image_linear = linear_interpolation_fft(
    image * mask,
    dx=dx,
    fs=fs,
    cs=cs,
)
image_lin = imOps.recover_image(image_linear)


# Performance metrics.
# NMSE_nlin is indexed as [0] -> "lin" and [1] -> "dict" in the figure title
# of the final cell.
NMSE_nlin, MAC, frqMAC = perforMetrics(
    image=image,
    image_recov=image_recov,
    image_under=image * mask,
    fs=fs,
    u=u,
    dx=dx,
    room=room,
)


## Visual results

In [None]:

# Side-by-side comparison; only the first 100 rows of each image are shown.
images = [
    panel[:100, :]
    for panel in (orig_image, orig_image * mask0, final_image, image_lin)
]
titles = ['Original Image', 'Masked Image', 'Final reconst image', "Linear reconst"]

fig, ax = plt.subplots(1, len(images), figsize=(18, 6))
for axis, img, label in zip(ax, images, titles):
    axis.imshow(img)
    axis.set_title(label)
    axis.axis('off')
plt.suptitle(f"NMSE lin = {NMSE_nlin[0]};    NMSE dict = {NMSE_nlin[1]}")
plt.tight_layout()
plt.show()