In [None]:
from jax import grad, jit
import jax
import jax.numpy as np
import matplotlib.pyplot as plt
import spectrum_recovery_pool as srp
import copy
import numpy
from skimage import io
from matplotlib.patches import Rectangle
import time
from IPython.display import display, clear_output

from lippmann import show_spectrum, show_lippmann_transform, lippmann_transform
from display_spectral_data import load_specim_data

# Alternative optimisation
Using autograd and jax.  
The goal of this notebook is to optimise the spectrum assuming it is smooth, but not necessarily bandlimited (see the Regularisation section below).

I did not estimate depth or decay here.

In [None]:
# Prefix of the files holding the data and results; the cube, the RGB preview
# (".png") and the crop indices ("_cut.txt") all share it.
name = "color_checker"
data_folder = "Cubes_ours/"
base_path = data_folder + name

# Spectral cube and the wavelengths it is sampled at.
downsampled, wavelengths = load_specim_data(base_path, ds=1, cut=True)

# RGB preview: keep the first three channels, reverse the second axis and swap
# the first two axes, then crop with the index pairs stored next to the data.
# NOTE(review): presumably this aligns the preview with `downsampled` - confirm.
image = io.imread(base_path + ".png")
image = np.swapaxes(image[:, ::-1, :3], 1, 0)
cut_idx = numpy.loadtxt(base_path + "_cut.txt").astype(int)
crop_rows = slice(cut_idx[0, 0], cut_idx[0, 1])
crop_cols = slice(cut_idx[1, 0], cut_idx[1, 1])
image = image[crop_rows, crop_cols]

# Pixel whose spectrum is analysed, used as an index into `downsampled`.
pixel = (80, 180)
pixel_row, pixel_col = pixel

# Outline a 4x4 box around the selected pixel (x = column, y = row).
fig, ax = plt.subplots(figsize=(5, 5))
rect = Rectangle(
    (pixel_col - 2, pixel_row - 2),
    4,
    4,
    alpha=1,
    color="none",
    ec="white",
    lw=1,
    zorder=10,
)
ax.add_patch(rect)
ax.imshow(image)
plt.show()

In [None]:
# Angular frequencies for the sampled wavelengths (omega = 2*pi*c / lambda,
# with c taken from `srp.c`).
omegas = 2 * np.pi * srp.c / wavelengths


def forward_model(power_spectrum, Z, k0):
    """Return the squared magnitude of `A @ power_spectrum`.

    `A` is built by `srp.generate_matrix_A` from the global `omegas`, the given
    `Z` and `k0`, and the global `r`.

    Note: `r` is looked up at call time and is assigned in a later cell, so
    this function must not be called before that cell has run.
    """
    system_matrix = srp.generate_matrix_A(omegas, Z, r=r, k0=k0)
    response = system_matrix @ power_spectrum
    return np.abs(response) ** 2


# Spectrum recorded at the selected pixel.
measured_spectrum = downsampled[pixel]

def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

# Soft weighting over the spectral samples: a sigmoid evaluated on a ramp from
# -3 to 50, so the first few samples are down-weighted and the rest are ~1.
mask = sigmoid(np.linspace(-3, 50, len(wavelengths)))

def data_fidelity(power_spectrum, Z, k0, measured_spectrum, mask=1):
    """L2 norm of the (optionally weighted) mismatch with the measurement.

    The residual `measured_spectrum - forward_model(power_spectrum, Z, k0)` is
    multiplied element-wise by `mask` (default 1, i.e. unweighted) before the
    norm is taken.
    """
    residual = measured_spectrum - forward_model(power_spectrum, Z, k0)
    weighted_residual = residual * mask
    return np.linalg.norm(weighted_residual)

### Regularisation
A simple penalty on the (L2) norm of the first derivative seems to work well, but we could use the second derivative too.

In [None]:
def derivative_regularisation(power_spectrum):
    """L2 norm of the first finite difference of `power_spectrum`."""
    spectrum = jax.device_put(power_spectrum)
    forward_difference = spectrum[1:] - spectrum[:-1]
    return np.linalg.norm(forward_difference)

def second_derivative_regularisation(power_spectrum):
    """L2 norm of the second finite difference (x[i+1] - 2*x[i] + x[i-1])."""
    spectrum = jax.device_put(power_spectrum)
    ahead, centre, behind = spectrum[2:], spectrum[1:-1], spectrum[:-2]
    return np.linalg.norm(ahead - 2 * centre + behind)

### Initialise with the recorded spectrum

In [None]:
# Spectrum recorded at the selected pixel; it is used as the initial estimate below.
show_spectrum(wavelengths, measured_spectrum, show_background=True)
plt.show()

In [None]:
# Start the optimisation from a copy of the measurement: the loop below updates
# `power_spectrum` with `-=`, which must not touch `measured_spectrum`.
power_spectrum = copy.copy(measured_spectrum)
show_spectrum(wavelengths, power_spectrum, show_background=True)
plt.show()

In [None]:
# Visualise the weighting applied to the residual in the masked data term.
show_spectrum(wavelengths, mask, show_background=True)
# Render explicitly, like the other plotting cells, so the call's return value
# is not echoed as the cell output.
plt.show()

In [None]:
# Initial values of the two scalar parameters passed to the forward model.
Z = 8.2e-6
k0 = 1.0

# Candidate reflection coefficients; `forward_model` reads the global `r`.
r_air = 0.2
r_mercury = 0.7 * np.exp(-1j * np.deg2rad(180 - 148))
r = r_air

# Weight of the regularisation term relative to the data term.
reg_conts = 1e-1

# Jitted gradients of the data term with respect to the spectrum (arg 0),
# Z (arg 1) and k0 (arg 2), and of the second-derivative regulariser.
data_grad = jit(grad(data_fidelity, argnums=0))
Z_grad = jit(grad(data_fidelity, argnums=1))
K_grad = jit(grad(data_fidelity, argnums=2))
regularization_grad = jit(grad(second_derivative_regularisation))

# Compare the initial gradient contributions: scaled regulariser, and the data
# term with and without the mask.
regularisation_term = reg_conts * regularization_grad(power_spectrum)
masked_data_term = data_grad(power_spectrum, Z, k0, measured_spectrum, mask)
unmasked_data_term = data_grad(power_spectrum, Z, k0, measured_spectrum)

plt.plot(wavelengths, regularisation_term, c="g", label="regularisation")
plt.plot(wavelengths, masked_data_term, label="with mask")
plt.plot(wavelengths, unmasked_data_term, c="r", label= "no mask")
plt.legend()
plt.show()


In [None]:
# Gradient-descent step sizes for the spectrum, Z and k0 respectively.
rate = 0.2
Z_rate = 5e-12
K_rate = 1e-2
# Cost of the previous iteration. Start at +inf so the first comparison in the
# loop can never trigger a spurious "lowering the rate" (the former sentinel of
# 100 did so whenever the initial cost happened to exceed 100).
cost = float("inf")

In [None]:
start = time.time()
# Loop-invariant normalisation, used only for progress reporting.
measured_norm = np.linalg.norm(measured_spectrum)
for i in range(1000):
    # Gradient step on the spectrum: masked data term + weighted regulariser.
    spectrum_grad = (
        data_grad(power_spectrum, Z, k0, measured_spectrum, mask)
        + reg_conts * regularization_grad(power_spectrum)
    )
    power_spectrum -= rate * spectrum_grad
    power_spectrum = np.maximum(0, power_spectrum) # Force power_spectrum to be non-negative

    # Sequential steps on k0, then Z (the Z step sees the updated, not yet
    # clamped, k0).
    # NOTE(review): these steps and the cost below use the unmasked fidelity,
    # while the spectrum step uses the mask - confirm this is intentional.
    k0 -= K_rate * K_grad(power_spectrum, Z, k0, measured_spectrum)
    Z -= Z_rate * Z_grad(power_spectrum, Z, k0, measured_spectrum)
    k0 = max(0.00001, k0)  # keep k0 strictly positive

    # Evaluate the cost once per iteration and reuse it for both the
    # step-size check and the bookkeeping.
    new_cost = data_fidelity(power_spectrum, Z, k0, measured_spectrum)
    if cost < new_cost:
        # The cost went up: halve every step size.
        rate = 0.5 * rate
        Z_rate = 0.5 * Z_rate
        K_rate = 0.5 * K_rate
        display("lowering the rate")
    cost = new_cost

    if i % 10 == 0:
        clear_output(wait=True)
        display(f'Iteration {i} Cost: {cost/measured_norm}')

print(f"time: {time.time() - start:.2f}s")

In [None]:
# Four panels: measurement, current estimate, Lippmann pattern, and the
# spectrum predicted by the forward model from the current estimate.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(16, 4))

show_spectrum(wavelengths, measured_spectrum, show_background=True, ax=ax1)
show_spectrum(wavelengths, power_spectrum, show_background=True, ax=ax2)

# NOTE(review): the pattern is computed from `measured_spectrum`, not from the
# estimated `power_spectrum` - confirm this is the intended input.
depths = np.linspace(0,Z,200)
pattern = lippmann_transform(wavelengths, measured_spectrum, depths, r=r, k0=k0)[0]
show_lippmann_transform(depths, pattern, ax=ax3, short_display=True)

reestimated_spectrum = forward_model(power_spectrum, k0=k0, Z=Z)
show_spectrum(wavelengths, reestimated_spectrum, show_background=True, ax=ax4)

# Titles are set after plotting so they are not overwritten by the helpers.
ax1.set_title("Reflected spectrum")
ax2.set_title("Current estimate")  # fixed typo: was "Curent estimate"
ax3.set_title("Pattern")
ax4.set_title("Re-estimated reflected spectrum")
plt.show()

In [None]:
# Values of Z and k0 after the optimisation loop has updated them.
print(Z)
print(k0)

In [None]:
# Overlay the measurement and the forward-model prediction from the final estimate.
reestimated = forward_model(power_spectrum, k0=k0, Z=Z)

fig, ax = plt.subplots(figsize=(5,5))
ax.plot(wavelengths, measured_spectrum, label="measured")
ax.plot(wavelengths, reestimated, "r", label="re-estimated")
ax.legend()
plt.show()