In [None]:
from jax import grad, jit
import jax
import jax.numpy as np
import matplotlib.pyplot as plt
import spectrum_recovery_pool as srp
import copy
import numpy
from skimage import io
from matplotlib.patches import Rectangle
import time

from lippmann import show_spectrum, show_lippmann_transform, lippmann_transform
from display_spectral_data import load_specim_data
from color_tools import upsample_hue_saturation, from_spectrum_to_xyz, from_xyz_to_rgb

# Alternative optimisation
Using autograd and jax.  
The goal of this notebook is to optimise the spectrum assuming it is smooth, but not necessarily band-limited (see the regularisation terms below). 

I did not estimate depth or decay here.

In [None]:
# Prefix of the file with data and results.
name = "parrot"

# The data is expected to be stored in the "Cubes" directory.
# `ds` and `cut` are forwarded to load_specim_data unchanged
# (presumably a downsampling factor and a cropping switch — TODO confirm).
cube_path = "Cubes/" + name
downsampled, wavelengths = load_specim_data(cube_path, ds=25, cut=True)
print(downsampled.shape)

In [None]:
def forward_model(power_spectrum, Z, k0):
    """Predict the spectrum produced by `power_spectrum` under the matrix model.

    The matrix is built by ``srp.generate_matrix_A`` from the notebook-level
    ``omegas`` and ``r`` together with the given ``Z`` and ``k0``
    (presumably plate depth and a decay/wavenumber parameter — TODO confirm).

    Returns the element-wise squared magnitude of ``A @ power_spectrum``.
    """
    transfer_matrix = srp.generate_matrix_A(omegas, Z, r=r, k0=k0)
    amplitude = transfer_matrix @ power_spectrum
    return np.abs(amplitude) ** 2

def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def data_fidelity(power_spectrum, Z, k0, measured_spectrum, mask=1):
    """Norm of the (optionally masked) residual between data and model.

    The model prediction is ``forward_model(power_spectrum, Z, k0)``; the
    residual ``measured_spectrum - prediction`` is multiplied by ``mask``
    (default 1, i.e. no weighting) before ``np.linalg.norm`` is taken.
    """
    predicted_spectrum = forward_model(power_spectrum, Z, k0)
    weighted_residual = (measured_spectrum - predicted_spectrum) * mask
    return np.linalg.norm(weighted_residual)

def derivative_regularisation(power_spectrum):
    """Norm of the first finite difference of `power_spectrum`.

    Penalises differences between neighbouring samples, i.e. favours
    smooth spectra.
    """
    spectrum = jax.device_put(power_spectrum)
    first_difference = spectrum[1:] - spectrum[:-1]
    return np.linalg.norm(first_difference)

def second_derivative_regularisation(power_spectrum):
    """Norm of the second finite difference of `power_spectrum`.

    Uses the stencil ``s[i+1] - 2*s[i] + s[i-1]`` on the interior samples.
    """
    spectrum = jax.device_put(power_spectrum)
    ahead, centre, behind = spectrum[2:], spectrum[1:-1], spectrum[:-2]
    return np.linalg.norm(ahead - 2 * centre + behind)

In [None]:
# Candidate reflection coefficients (names suggest a mercury and an air
# interface — TODO confirm); the mercury value is the one used by forward_model.
r_air = 0.2
r_mercury = 0.7 * np.exp(1j * np.deg2rad(-148))
r = r_mercury

# One angular frequency per sampled wavelength: omega = 2 * pi * c / lambda.
omegas = 2 * np.pi * srp.c / wavelengths

# Sigmoid ramp along the wavelength axis, going from sigmoid(-2) to sigmoid(30).
# NOTE(review): the optimisation loop calls data_grad/data_fidelity without
# passing this mask, so the default mask=1 is what is actually used.
mask = sigmoid(np.linspace(-2, 30, len(wavelengths)))

# JIT-compiled gradients of the data term w.r.t. the power spectrum (arg 0),
# Z (arg 1) and k0 (arg 2), and of the first-difference regulariser.
data_grad = jit(grad(data_fidelity, argnums=0))
Z_grad = jit(grad(data_fidelity, argnums=1))
K_grad = jit(grad(data_fidelity, argnums=2))
regularization_grad = jit(grad(derivative_regularisation))

# Weight of the regularisation gradient relative to the data gradient.
reg_conts = 5e-1

# Per-pixel outputs, filled in scan order by the optimisation loop.
result, depths, kzeros = [], [], []

In [None]:
# Projected gradient descent, run independently for every pixel of `downsampled`.
#
# The accumulators are reset here so that this cell is idempotent: re-running it
# neither appends to results of a previous run nor fails because `result` or
# `depths` were rebound to arrays by later cells.
result = []
depths = []
kzeros = []

for x in range(downsampled.shape[0]):
    for y in range(downsampled.shape[1]):
        print(x, y)
        # Step sizes for the spectrum, Z and k0 updates; all three are halved
        # whenever the data-fidelity cost goes up.
        rate = 0.2
        Z_rate = 1e-13
        K_rate = 1e-1
        cost = 100
        # Initial guesses for k0 and Z.
        k0 = 3.7
        Z = 3e-6
        measured_spectrum = downsampled[x, y]
        # Start from the measured spectrum; copy so the data cube is not modified.
        power_spectrum = copy.copy(measured_spectrum)
        start = time.time()
        for i in range(500):
            power_spectrum -= rate * (
                data_grad(power_spectrum, Z, k0, measured_spectrum)
                + reg_conts * regularization_grad(power_spectrum)
            )
            power_spectrum = np.maximum(0, power_spectrum)  # Force power_spectrum to be non-negative
            # Evaluate the (un-jitted) cost once and reuse it for both the
            # step-size check and the bookkeeping.
            new_cost = data_fidelity(power_spectrum, Z, k0, measured_spectrum)
            if cost < new_cost:
                rate = 0.5 * rate
                Z_rate = 0.5 * Z_rate
                K_rate = 0.5 * K_rate
                print("lowering the rate")
            cost = new_cost
            # Every 10th iteration also take a gradient step on k0 and then on Z
            # (the Z step already sees the updated k0).
            if i % 10 == 0:
                k0 -= K_rate * K_grad(power_spectrum, Z, k0, measured_spectrum)
                Z -= Z_rate * Z_grad(power_spectrum, Z, k0, measured_spectrum)

        result.append(copy.copy(power_spectrum))
        depths.append(Z)
        kzeros.append(k0)
        print(cost)
        print(f"time: {time.time() - start:.2f}s")

In [None]:
# Stack the per-pixel estimates and give them the same shape as `downsampled`,
# then store the array on disk.
result = np.reshape(np.array(result), downsampled.shape)
np.save("PNAS/gradient_descent/result0", result)

In [None]:
# Convert the recorded and the estimated spectral cubes to RGB and show them
# side by side. Each XYZ image is divided by the smallest per-pixel channel sum
# before the RGB conversion.
recorded_xyz = from_spectrum_to_xyz(wavelengths, downsampled, normalize=False)
recorded_xyz = recorded_xyz / np.min(np.sum(recorded_xyz, axis=2))
recorded_rgb = from_xyz_to_rgb(recorded_xyz)

estimated_xyz = from_spectrum_to_xyz(wavelengths, result, normalize=False)
estimated_xyz = estimated_xyz / np.min(np.sum(estimated_xyz, axis=2))
estimated_rgb = from_xyz_to_rgb(estimated_xyz)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(recorded_rgb)
ax2.imshow(estimated_rgb)
ax1.set_title("original colors")
# The right-hand panel was previously untitled.
ax2.set_title("estimated colors")
# Optional highlight of a single pixel (needs `pixel` to be defined first):
# rect = Rectangle((pixel[1] - 0.5, pixel[0] - 0.5), 1, 1, alpha=1, color="none", ec="white", lw=2,zorder=10)
# rect2 = Rectangle((pixel[1] - 0.5, pixel[0] - 0.5), 1, 1, alpha=1, color="none", ec="white", lw=2,zorder=10)
# ax1.add_patch(rect)
# ax2.add_patch(rect2)
plt.show()

In [None]:
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(12, 3))
pixel = (6, 6)

# Use the Z and k0 that were estimated for *this* pixel. The loop appends one
# entry per pixel in row-major order, so the flat index is row * n_cols + col.
# (Previously the values left over from the last pixel of the loop were used,
# and the `depths` list of estimates was overwritten by the plotting grid.)
pixel_index = pixel[0] * downsampled.shape[1] + pixel[1]
pixel_Z = depths[pixel_index]
pixel_k0 = kzeros[pixel_index]

show_spectrum(wavelengths, downsampled[pixel], show_background=True, short_display=True, ax=ax1, visible=True)
show_spectrum(wavelengths, result[pixel], show_background=True, short_display=True, ax=ax2, visible=True)

# Depth grid from 0 to the estimated Z on which the pattern is evaluated.
pattern_depths = np.linspace(0, pixel_Z, 200)
pattern = lippmann_transform(wavelengths, result[pixel], pattern_depths, r=r, k0=pixel_k0)[0]
show_lippmann_transform(pattern_depths, pattern, ax=ax3, short_display=True)

# Push the estimate back through the forward model to compare with the measurement.
show_spectrum(wavelengths, forward_model(result[pixel], pixel_Z, pixel_k0), show_background=True, short_display=True, ax=ax4, visible=True)
ax1.set_title("Reflected spectrum")
ax2.set_title("estimated")
ax3.set_title("Pattern")
ax4.set_title("Re-estimated reflected spectrum")
# plt.savefig(f"{results_path}_point.pdf")  # NOTE: results_path is not defined in this notebook
plt.show()

In [None]:
# Overlay the measured and the estimated spectrum of the selected pixel.
# Labels and a legend are added so the two curves can be told apart.
fig, ax = plt.subplots()
ax.plot(wavelengths, downsampled[pixel], label="measured")
ax.plot(wavelengths, result[pixel], label="estimated")
ax.set_xlabel("wavelength")
ax.set_ylabel("power")
ax.set_title(f"Spectra at pixel {pixel}")
ax.legend()
plt.show()