In [1]:
from jax import value_and_grad
import numpy as tnp
from katsu.katsu_math import np, set_backend_to_jax
from katsu.mueller import linear_retarder, linear_polarizer, _empty_mueller
from katsu.polarimetry import drrp_data_reduction_matrix, condition_number

In [2]:
# Fixed optical configuration read by ``forward``.
lp_angle = 0  # analyzer (linear polarizer) angle; presumably radians like ``x`` in ``forward``
# Retardance of the rotating retarder (radians). A quarter-wave plate would be ``np.pi / 2``.
wp_retardance = 2 * np.pi * np.sqrt(3)
# Backend toggle: uncomment to run katsu on JAX (presumably required before ``loss_fg`` can be called).
# set_backend_to_jax()

def identity_matrix(shape):
    """Return a stack of 4x4 identity Mueller matrices.

    Parameters
    ----------
    shape : list of int
        Leading dimensions forwarded to ``_empty_mueller``; ``forward`` passes
        ``[NMEAS]`` to get one identity matrix per measurement.

    Returns
    -------
    ndarray
        Array with the same shape and dtype as ``_empty_mueller(shape)`` whose
        trailing 4x4 block is the identity.
    """
    template = _empty_mueller(shape)
    # Build the result functionally instead of writing the diagonal element by
    # element. This works unchanged for mutable NumPy arrays and immutable JAX
    # arrays (no backend check needed). It also does not depend on
    # ``_empty_mueller`` returning zero-initialised memory.
    return np.zeros_like(template) + np.eye(4, dtype=template.dtype)

def forward(x):
    """Condition number of the DRRP data-reduction matrix for retarder angles ``x``.

    Each measurement's polarization state analyzer (PSA) is built as
    ``linear_polarizer(lp_angle) @ linear_retarder(x[k], wp_retardance)``.
    The polarization state generator (PSG) is the identity. The function is
    used as the objective when searching for a well-conditioned set of
    measurement angles.

    Parameters
    ----------
    x : ndarray, shape (n,)
        Retarder angles in radians, one per measurement (n = 4 in this notebook).

    Returns
    -------
    scalar
        ``condition_number`` of the data-reduction matrix built with
        ``invert=False``.
    """
    # One measurement per angle. This was previously hard-coded to 4, which
    # silently mismatched any ``x`` of a different length.
    nmeas = len(x)

    wvp = linear_retarder(x, wp_retardance, shape=[nmeas])
    pol = linear_polarizer(lp_angle, shape=[nmeas])
    PSA = pol @ wvp  # Mueller product: retarder acts first, then the polarizer

    PSG = identity_matrix(shape=[nmeas])
    # With invert=False this is presumably the forward matrix W, not its
    # inverse, hence the name ``W``. TODO confirm in katsu.polarimetry.
    W = drrp_data_reduction_matrix(PSG, PSA, invert=False)
    cond = condition_number(W)

    return cond

# Maps x -> (forward(x), d forward / d x); differentiates w.r.t. the first argument.
# NOTE(review): tracing ``forward`` presumably requires the JAX backend
# (``set_backend_to_jax()``). It is unused by the finite-difference optimization below.
loss_fg = value_and_grad(forward, argnums=0)

In [7]:
from scipy.optimize import minimize

# Seed the starting guess so the optimization, and the angles printed from it,
# is reproducible on Restart & Run All. Real NumPy (``tnp``) is used because
# ``np`` may be switched to the JAX backend, whose numpy namespace has no
# stateful ``random`` API.
SEED = 0
x0 = tnp.random.default_rng(SEED).random(4)  # 4 starting retarder angles in [0, 1) rad

# jac=False: no gradient function is supplied, so L-BFGS-B estimates the
# gradient of ``forward`` by finite differences.
results = minimize(forward, x0=x0, method='L-BFGS-B', jac=False)
# print(results)

In [8]:
angles_deg = np.degrees(results.x)  # optimized retarder angles, radians -> degrees
print(angles_deg)

[-45.54256551  30.0149666   59.98490176   0.1316377 ]
