In [None]:
import jax.numpy as jnp
from jax import jit


def calculate(input_x, v, n, S, L, f_obs, ff=0.008):
    """Least-squares misfit between modelled and observed image-plane coordinates.

    The pose ``input_x`` defines a rotation ``R = Rx(alpha) @ Ry(beta) @ Rz(gamma)``
    and a position ``(X, Y, Z)``. Each 3-D point in ``S`` is mapped to two
    coordinates ``(f_1, f_2)``. The interleaved, 1e3-scaled coordinates are
    compared against ``f_obs`` and the sum of squared differences is returned.

    Args:
        input_x: Six values ``(alpha, beta, gamma, X, Y, Z)``. The angles go
            through ``jnp.cos``/``jnp.sin`` and are therefore in radians.
            Any sequence or array works.
        v: 3-vector; ``o = R @ v`` is the rotated direction used throughout.
        n: 3-vector; ``s_2 = R @ n`` and ``s_1 = o x s_2`` are the two axes
            onto which each projected point is decomposed.
        S: 3-D points, array-like of shape ``(m, 3)``. May be empty.
        L: Scalar. ``x = (X, Y, Z) - L * o``, and ``L`` also enters the
            scale factor ``ff / (a + L - ff)``.
        f_obs: Observed coordinates of length ``2 * m``, ordered
            ``[f_1(S[0]), f_2(S[0]), f_1(S[1]), f_2(S[1]), ...]``.
        ff: Constant in the scale factor ``ff / (a + L - ff)``. Presumably a
            focal length in the same units as ``L``; TODO confirm.

    Returns:
        Scalar ``sum((1e3 * f - f_obs) ** 2)``.

    Note:
        Degenerate geometry divides by zero and yields inf/nan instead of
        raising, because JAX does not trap division by zero. Cases:
        ``o[2] == 0``, a point with ``o.x == o.S_i``, zero-length ``s_1`` or
        ``s_2``, or ``a + L == ff``.
    """
    alpha, beta, gamma, X, Y, Z = input_x
    # Build the position explicitly: slicing a list/tuple input and
    # subtracting a JAX array raises TypeError.
    position = jnp.array([X, Y, Z])

    # Elementary rotations about x, y and z, composed as Rx @ Ry @ Rz.
    ca, sa = jnp.cos(alpha), jnp.sin(alpha)
    cb, sb = jnp.cos(beta), jnp.sin(beta)
    cg, sg = jnp.cos(gamma), jnp.sin(gamma)
    Rx = jnp.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]])
    Ry = jnp.array([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]])
    Rz = jnp.array([[cg, -sg, 0], [sg, cg, 0], [0, 0, 1]])
    R = Rx @ Ry @ Rz

    # Rotated direction, computed once and reused below.
    o = R @ jnp.asarray(v)
    x = position - L * o

    # a = |Z / o_z| * ||o||. This is the length of the step along o that
    # takes (X, Y, Z) to the plane z = 0.
    a = jnp.sqrt((o[0] * Z / o[2]) ** 2 + (o[1] * Z / o[2]) ** 2 + Z * Z)
    Q = position + a * o

    # The in-plane axes do not depend on the point, so compute them once.
    s_2 = R @ jnp.asarray(n)
    s_1 = jnp.cross(o, s_2)  # always perpendicular to s_2

    # Vectorized over all points. An empty S becomes a (0, 3) array.
    points = jnp.reshape(jnp.asarray(S), (-1, 3))

    # Intersect the line through x and each point S_i with the plane
    # through Q that has normal o:
    #   p_i = S_i + t_i (x - S_i),  t_i = (o.Q - o.S_i) / (o.x - o.S_i)
    o_dot_S = points @ o
    num = (o @ Q - o_dot_S)[:, None]
    den = (o @ x - o_dot_S)[:, None]
    proj = points + num * (x - points) / den

    # Coefficients of (p_i - Q) along s_1 and s_2.
    rel = proj - Q
    k_1 = (rel @ s_1) / (s_1 @ s_1)
    k_2 = (rel @ s_2) / (s_2 @ s_2)

    # Image-plane coordinates.
    f_1 = k_1 * ff / (a + L - ff)
    f_2 = k_2 * ff / (a + L - ff)

    # Interleave to [f_1(S0), f_2(S0), f_1(S1), ...] and scale by 1e3
    # (the original code labels this a conversion to mm).
    results = jnp.stack([f_1, f_2], axis=1).reshape(-1) * 1e3

    # Least-squares functional.
    diff = results - jnp.asarray(f_obs)
    return jnp.sum(diff * diff)


# JIT-compiled version of `calculate`. It recompiles for each new input
# shape, dtype, or pytree structure.
calculate_jit = jit(calculate)

In [3]:
import jax.numpy as jnp
# Sample rotation angles (radians, as consumed by jnp.cos / jnp.sin) for the
# rotation-matrix timing cells below.
alpha, beta, gamma = 1, 2, 3


In [12]:
%%timeit
# Variant 1: build each elementary rotation from nested Python lists (eager,
# un-jitted JAX), then compose R = Rx @ Ry @ Rz.
# NOTE(review): inside the jitted `calculate` this construction is traced once,
# so this eager timing likely says little about its runtime; TODO confirm
# with a jitted benchmark.
Rx = jnp.array([[1, 0, 0], [0, jnp.cos(alpha), -jnp.sin(alpha)], [0, jnp.sin(alpha), jnp.cos(alpha)]])
Ry = jnp.array([[jnp.cos(beta), 0, jnp.sin(beta)], [0, 1, 0], [-jnp.sin(beta), 0, jnp.cos(beta)]])
Rz = jnp.array([[jnp.cos(gamma), -jnp.sin(gamma), 0], [jnp.sin(gamma), jnp.cos(gamma), 0], [0, 0, 1]])
R = jnp.dot(jnp.dot(Rx, Ry), Rz)
# JAX dispatches asynchronously. Wait for R so %%timeit measures the
# computation, not just the dispatch of the last op.
R.block_until_ready()

1.52 ms ± 24.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
%%timeit
# Variant 2: same R as variant 1, but built by writing sin/cos entries into
# identity matrices with functional `.at[...].set(...)` updates. Each update
# returns a new array, which likely explains the extra eager-mode cost.
# Preallocate identity matrices
Rx = jnp.eye(3)
Ry = jnp.eye(3)
Rz = jnp.eye(3)

# Compute sine and cosine of all angles
cos_alpha, sin_alpha = jnp.cos(alpha), jnp.sin(alpha)
cos_beta, sin_beta = jnp.cos(beta), jnp.sin(beta)
cos_gamma, sin_gamma = jnp.cos(gamma), jnp.sin(gamma)

# Input values into the matrices at given positions
Rx = Rx.at[1, 1].set(cos_alpha).at[1, 2].set(-sin_alpha).at[2, 1].set(sin_alpha).at[2, 2].set(cos_alpha)
Ry = Ry.at[0, 0].set(cos_beta).at[0, 2].set(sin_beta).at[2, 0].set(-sin_beta).at[2, 2].set(cos_beta)
Rz = Rz.at[0, 0].set(cos_gamma).at[0, 1].set(-sin_gamma).at[1, 0].set(sin_gamma).at[1, 1].set(cos_gamma)

R = Rx @ Ry @ Rz
# JAX dispatches asynchronously. Wait for R so %%timeit measures the
# computation, not just the dispatch of the last op.
R.block_until_ready()

3.2 ms ± 193 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
import numpy as np
a = np.arange(0, 10, 1)  # the integers 0..9

In [18]:
a = jnp.asarray((1, 2))  # two-element JAX array used by the unpacking check below

In [19]:
# Sequence-unpack a JAX array, as `calculate` does with `input_x`.
[kk, ll] = a

In [1]:
from value_F import value_F