# Attempt to convert code to jax

In [18]:
# Importing required packages
import pandas as pd 
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import jax.numpy as jnp
from jax import grad, jit, vmap
import iminuit 

In [19]:
# Importing data on Antikythera Mechanism (AM) as data_AM
data_AM = pd.read_csv("data/1-Fragment_C_Hole_Measurements.csv")

# Printing data
data_AM  

Unnamed: 0,Section ID,Hole,Inter-hole Distance,Mean(X),Mean(Y)
0,0,1,,15.59931,92.59653
1,1,2,,16.44167,91.50069
2,1,3,1.440694,17.30764,90.34931
3,1,4,1.456973,18.15278,89.16250
4,1,5,1.299821,18.96528,88.14792
...,...,...,...,...,...
76,7,77,1.259985,108.69444,63.42778
77,7,78,1.430105,110.05486,63.86875
78,7,79,1.179606,111.14583,64.31736
79,7,80,1.428043,112.48021,64.82604


In [20]:
sub_data = data_AM[~data_AM['Section ID'].isin([0,4])] # removing sections 0 and 4
sub_data.shape # Quick check

(79, 5)

Original code

In [21]:
def model(x, y, R, phis, xcent, ycent, phase):

    ''' 
    Compute the radial and tangential errors between measured hole position and predicted hole locations
    in the fractured ring model.

    This function models the expected hole positions assuming an originally circular ring with N regularly spaced holes. 
    It then accounts for misaligment by applying transformation (xcent, ycent), in paper refered as $(x_{0j}, y_{0j})$, and a phase
    shift. Finally, it computes the errors between the measuered and predicted hole positions, projecting them into radial and 
    tangential components.

    Parameters:

    x : np.array
        Measured x-coordinates of hole positions for a specific fractured section. 
    y : np.array
        Measured y-coordinates of hole positions for a secific fracured section. 
    R : float
        Estimated radius of the original ring
    phis : np.array
        Angles of the holes in the unbroken ring, evenly spaced between 0 and 2pi
    xcent : float
        x-coordinate of the estimated center if the fractured section
    ycent : float
        y-coordinate of the estimated center of the fractured section
    phase : float
        Phase shift applied to align the hole positions.

    Returns:
    
    rp : np.array
        Radial errors (distance from the expected radius). 
    tp : np.array 
        Tangential errors (misaligment along the circumference of the ring). 

    '''

    phi = phis + phase # Apply phase shift to the angles

    cphi = np.cos(phi) # Compute cosine values
    sphi = np.sin(phi) # Compute sine values

    # compute model points in x,y
    r_x = R*cphi # X-coordinates of prediceted holes
    r_y = R*sphi # Y-coordinates of predicted holes

    # shift data point to be around model x,y
    d_x = x - xcent 
    d_y = y - ycent

    # find error vector between data and model
    e_x = r_x - d_x
    e_y = r_y - d_y

    # project vector into radius and tangent
    rp = e_x*cphi + e_y*sphi # Radial projection
    tp = e_x*sphi - e_y*cphi # Tangential projection

    return rp, tp

In [22]:
def log_likelihood_rt(params, data, N):
    '''
    Compute log-likelihood for the isotropic Gaussian error model.

    Parameters:
    params (list): Model parameters [R, sigma_r, sigma_t, phase1, phase2, ..., xcent1, xcent2, ..., ycent1, ycent2, ...]
    data (list of tuples): Measured hole positions for each fractured section.
    N (int): Total number of holes in the original complete ring.

    Returns:
    float: Log-likelihood value.
    '''

    R, sigma_r, sigma_t = params[:3]
    phases, xcents, ycents = np.split(params[3:], 3)
    #x,y = data

    invsig_r = 1./(2*(sigma_r*sigma_r))
    invsig_t = 1./(2*(sigma_t*sigma_t))

    npoints = np.sum([len(dt) for dt in data])
    prefact = -npoints*np.log(2*np.pi*sigma_t*sigma_r)
    phis = 2*np.pi*np.arange(100)/N

    #k = np.arange(N)
    exp_likelihood = 0
    for i, sect in enumerate(data):
        x,y = sect

        # assume independent r, tangent
        rp, tp = model(x, y, R, phis[:len(x)], xcents[i], ycents[i], phases[i])

        exponent = -invsig_r*(rp**2) - invsig_t*(tp**2)

        exp_likelihood += np.sum(exponent)


    return prefact + exp_likelihood

In [23]:
import numpy as np
from iminuit import Minuit

def neg_log_likelihood_rt(N, R, sigma_r, sigma_t,
                       phase_0, phase_1, phase_2, phase_3, phase_4, phase_5,
                       xcent_0, xcent_1, xcent_2, xcent_3, xcent_4, xcent_5,
                       ycent_0, ycent_1, ycent_2, ycent_3, ycent_4, ycent_5):
    """
    Compute the negative log-likelihood for 6 sections dynamically.

    Parameters:

    N, R, sigma_r, sigma_t : float
        Model parameters for ring radius and error standard deviations.
    phase_i, xcent_i, ycent_i : float
        Phase shifts and translations for each of the 6 sections.

    Returns:
    float
        Negative log-likelihood value (to be minimized by iminuit).
    """
    # Convert parameters into arrays to avoid shape issues
    phases = np.array([phase_0, phase_1, phase_2, phase_3, phase_4, phase_5])
    xcents = np.array([xcent_0, xcent_1, xcent_2, xcent_3, xcent_4, xcent_5])
    ycents = np.array([ycent_0, ycent_1, ycent_2, ycent_3, ycent_4, ycent_5])

    # Ensure data is a NumPy array before passing it
    global data
    data = np.array(data, dtype=object)  # Convert data list to a NumPy array (dtype=object keeps tuples)

    # Compute log-likelihood and negate it
    params = np.array([R, sigma_r, sigma_t, *phases, *xcents, *ycents])  # Convert to array to avoid issues
    return -log_likelihood_rt(params, data, N)  # Negate log-likelihood for minimization

# Example setup for 6 sections

num_sections = 6  # Now using 6 sections


data = [(sub_data['Mean(X)'], sub_data['Mean(Y)'])]


# Initial parameter estimates for 6 sections
N = 355
R_init = 77
sigma_r_init = 0.04
sigma_t_init = 0.1
np.random.seed(1)
phases_init = [-2.53, -2.53, -2.53, -2.54, -2.55, -2.55]
xcents_init = [79, 79, 79, 81, 81, 83]
ycents_init = [136, 135, 135, 136, 135, 136]


# Combine parameters into one list
init_params = [N, R_init, sigma_r_init, sigma_t_init, *phases_init, *xcents_init, *ycents_init]

m_rt = Minuit(neg_log_likelihood_rt, 
           N=init_params[0], R=init_params[1], sigma_r=init_params[2], sigma_t=init_params[3],
           phase_0=init_params[4], phase_1=init_params[5], phase_2=init_params[6], 
           phase_3=init_params[7], phase_4=init_params[8], phase_5=init_params[9],
           xcent_0=init_params[10], xcent_1=init_params[11], xcent_2=init_params[12], 
           xcent_3=init_params[13], xcent_4=init_params[14], xcent_5=init_params[15],
           ycent_0=init_params[16], ycent_1=init_params[17], ycent_2=init_params[18], 
           ycent_3=init_params[19], ycent_4=init_params[20], ycent_5=init_params[21])

# Define likelihood optimization settings
m_rt.errordef = Minuit.LIKELIHOOD
m_rt.limits['N'] = (345, 365)  # Set realistic bounds for N
m_rt.limits['R'] = (60, 80)
m_rt.limits['sigma_r'] = (0, None)  # sigma_r must be positive
m_rt.limits['sigma_t'] = (0, None)  # sigma_t must be positive

# Run the minimization
m_rt.migrad()

Migrad,Migrad.1
FCN = 7.307,Nfcn = 1522
EDM = 2.19e-05 (Goal: 0.0001),time = 0.2 sec
Valid Minimum,Below EDM threshold (goal x 10)
No parameters at limit,Below call limit
Hesse ok,Covariance accurate

0,1,2,3,4,5,6,7,8
,Name,Value,Hesse Error,Minos Error-,Minos Error+,Limit-,Limit+,Fixed
0.0,N,357.0,3.4,,,345,365,
1.0,R,79.1,0.4,,,60,80,
2.0,sigma_r,0.79,0.25,,,0,,
3.0,sigma_t,2.4,0.7,,,0,,
4.0,phase_0,-2.522,0.008,,,,,
5.0,phase_1,-2.5,1.0,,,,,
6.0,phase_2,-2.5,1.0,,,,,
7.0,phase_3,-2.5,1.0,,,,,
8.0,phase_4,-2.6,1.0,,,,,

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22
,N,R,sigma_r,sigma_t,phase_0,phase_1,phase_2,phase_3,phase_4,phase_5,xcent_0,xcent_1,xcent_2,xcent_3,xcent_4,xcent_5,ycent_0,ycent_1,ycent_2,ycent_3,ycent_4,ycent_5
N,12,0.94 (0.681),0.63 (0.732),1.7 (0.744),22.55e-3 (0.849),0,0,0,0,0,0.358 (0.617),0,0,0,0,0,0.99 (0.668),0,0,0,0,0
R,0.94 (0.681),0.158,0.10 (0.995),0.27 (0.993),1.77e-3 (0.580),0.00,0.00,0.00,0.00,0.00,0.046 (0.687),0.00,0.00,0.00,0.00,0.00,0.17 (0.980),0.00,0.00,0.00,0.00,0.00
sigma_r,0.63 (0.732),0.10 (0.995),0.0619,0.17 (1.000),1.22e-3 (0.641),0.00,0.00,0.00,0.00,0.00,0.029 (0.697),0.00,0.00,0.00,0.00,0.00,0.11 (0.987),0.00,0.00,0.00,0.00,0.00
sigma_t,1.7 (0.744),0.27 (0.993),0.17 (1.000),0.458,3.42e-3 (0.658),0.0,0.0,0.0,0.0,0.0,0.079 (0.692),0.0,0.0,0.0,0.0,0.0,0.29 (0.986),0.0,0.0,0.0,0.0,0.0
phase_0,22.55e-3 (0.849),1.77e-3 (0.580),1.22e-3 (0.641),3.42e-3 (0.658),5.91e-05,0,0,0,0,0,0.35e-3 (0.273),0,0,0,0,0,1.94e-3 (0.589),0,0,0,0,0
phase_1,0,0.00,0.00,0.0,0,1,0,0,0,0,0.000,0,0,0,0,0,0.00,0,0,0,0,0
phase_2,0,0.00,0.00,0.0,0,0,1,0,0,0,0.000,0,0,0,0,0,0.00,0,0,0,0,0
phase_3,0,0.00,0.00,0.0,0,0,0,1,0,0,0.000,0,0,0,0,0,0.00,0,0,0,0,0
phase_4,0,0.00,0.00,0.0,0,0,0,0,1,0,0.000,0,0,0,0,0,0.00,0,0,0,0,0


Trying to convert to jax

In [39]:
def model(x, y, R, phis, xcent, ycent, phase):

    ''' 
    Compute the radial and tangential errors between measured hole position and predicted hole locations
    in the fractured ring model.

    This function models the expected hole positions assuming an originally circular ring with N regularly spaced holes. 
    It then accounts for misaligment by applying transformation (xcent, ycent), in paper refered as $(x_{0j}, y_{0j})$, and a phase
    shift. Finally, it computes the errors between the measuered and predicted hole positions, projecting them into radial and 
    tangential components.

    Parameters:

    x : np.array
        Measured x-coordinates of hole positions for a specific fractured section. 
    y : np.array
        Measured y-coordinates of hole positions for a secific fracured section. 
    R : float
        Estimated radius of the original ring
    phis : np.array
        Angles of the holes in the unbroken ring, evenly spaced between 0 and 2pi
    xcent : float
        x-coordinate of the estimated center if the fractured section
    ycent : float
        y-coordinate of the estimated center of the fractured section
    phase : float
        Phase shift applied to align the hole positions.

    Returns:
    
    rp : np.array
        Radial errors (distance from the expected radius). 
    tp : np.array 
        Tangential errors (misaligment along the circumference of the ring). 

    '''

    x, y = jnp.asarray(x), jnp.asarray(y)
    phase = jnp.asarray(phase)

    phi = phis + phase # Apply phase shift to the angles

    cphi = jnp.cos(phi) # Compute cosine values
    sphi = jnp.sin(phi) # Compute sine values

    # compute model points in x,y
    r_x = R*cphi # X-coordinates of prediceted holes
    r_y = R*sphi # Y-coordinates of predicted holes

    # shift data point to be around model x,y
    d_x = x - xcent 
    d_y = y - ycent

    # find error vector between data and model
    e_x = r_x - d_x
    e_y = r_y - d_y

    # project vector into radius and tangent
    rp = e_x*cphi + e_y*sphi # Radial projection
    tp = e_x*sphi - e_y*cphi # Tangential projection

    return rp, tp

In [None]:
import jax
import jax.numpy as jnp

def log_likelihood_rt_jax(params, data, N):
    '''
    Compute log-likelihood for the isotropic Gaussian error model using JAX.
    
    Parameters:
    params (array): Model parameters [R, sigma_r, sigma_t, phase1, phase2, ..., xcent1, xcent2, ..., ycent1, ycent2, ...]
    data (list of tuples): Measured hole positions for each fractured section.
    N (int): Total number of holes in the original complete ring.
    
    Returns:
    float: Log-likelihood value.
    '''
    
    R, sigma_r, sigma_t = params[:3]
    phases, xcents, ycents = jnp.split(params[3:], 3)
    
    invsig_r = 1./(2*(sigma_r*sigma_r))
    invsig_t = 1./(2*(sigma_t*sigma_t))

    # Compute the total number of points for log-likelihood normalization
    npoints = sum([len(sect[0]) for sect in data])
    prefact = -npoints * jnp.log(2 * jnp.pi * sigma_t * sigma_r)
    
    phis = 2 * jnp.pi * jnp.arange(100) / N  # Compute angles

    # Vectorized computation for all sections
    def section_log_likelihood(sect, i):
        x, y = sect  # Extract x, y measurements for the section
        rp, tp = model(x, y, R, phis[:len(x)], xcents[i], ycents[i], phases[i])  # Vectorized model call
        return jnp.sum(-invsig_r * (rp ** 2) - invsig_t * (tp ** 2))

    # Apply across all sections using `vmap`
    indices = jnp.arange(len(data))
    exp_likelihoods = jax.vmap(section_log_likelihood, in_axes=(0, 0))(data, indices)
    
    return jnp.sum(exp_likelihoods) + prefact


In [50]:
import jax
import jax.numpy as jnp

def log_likelihood_rt(params, data, N):
    """
    Compute log-likelihood for the isotropic Gaussian error model using JAX.

    Args:
        params (array): Model parameters [R, sigma_r, sigma_t, phase1, ..., xcent1, ..., ycent1, ...]
        data (list of tuples): Measured hole positions (x, y) for each fractured section.
        N (int): Total number of holes in the original complete ring.

    Returns:
        float: Log-likelihood value.
    """
    
    # Extract model parameters
    R, sigma_r, sigma_t = params[:3]
    phases, xcents, ycents = jnp.split(params[3:], 3)

    # Precompute inverse variances for efficiency
    invsig_r = 1. / (2 * (sigma_r ** 2))
    invsig_t = 1. / (2 * (sigma_t ** 2))

    # Compute total number of observed holes for log-likelihood prefactor
    npoints = sum([len(sect[0]) for sect in data])
    prefact = -npoints * jnp.log(2 * jnp.pi * sigma_t * sigma_r)

    # Compute expected hole positions using vectorized angles
    phis = 2 * jnp.pi * jnp.arange(100) / N  # Angle positions for 100 holes

    # Define function for section-wise log-likelihood computation
    def section_log_likelihood(sect, xc, yc, phase):
        x, y = sect  # Extract hole positions for this section
        rp, tp = model(x, y, R, phis[:len(x)], xc, yc, phase)  # Expected positions

        # Compute likelihood exponent in a vectorized way
        exponent = -invsig_r * (rp**2) - invsig_t * (tp**2)
        return jnp.sum(exponent)  # Sum across section

    # Apply to all sections using `vmap`
    indices = jnp.arange(len(data))  # Indices for mapping
    exp_likelihoods = jax.vmap(section_log_likelihood, in_axes=(0, 0, 0, 0))(
        data, xcents, ycents, phases
    )

    # Return total log-likelihood
    return jnp.sum(exp_likelihoods) + prefact


In [51]:
def neg_log_likelihood_rt(N, R, sigma_r, sigma_t,
                       phase_0, phase_1, phase_2, phase_3, phase_4, phase_5,
                       xcent_0, xcent_1, xcent_2, xcent_3, xcent_4, xcent_5,
                       ycent_0, ycent_1, ycent_2, ycent_3, ycent_4, ycent_5):
    """
    Compute the negative log-likelihood for 6 sections dynamically.
    """

    # Convert parameters into JAX arrays
    phases = jnp.array([phase_0, phase_1, phase_2, phase_3, phase_4, phase_5])
    xcents = jnp.array([xcent_0, xcent_1, xcent_2, xcent_3, xcent_4, xcent_5])
    ycents = jnp.array([ycent_0, ycent_1, ycent_2, ycent_3, ycent_4, ycent_5])

    # Fix: Ensure data is JAX-compatible (without dtype=object issue)
    global data
    data = [(jnp.asarray(x), jnp.asarray(y)) for x, y in data]  # Fix object dtype issue

    # Convert params into a JAX array
    params = jnp.array([R, sigma_r, sigma_t, *phases, *xcents, *ycents])

    # Compute and negate log-likelihood
    return -log_likelihood_rt(params, data, N)

In [52]:
import numpy as np
from iminuit import Minuit


num_sections = 6  # Now using 6 sections


data = [(sub_data['Mean(X)'], sub_data['Mean(Y)'])]


# Initial parameter estimates for 6 sections
N = 355
R_init = 77
sigma_r_init = 0.005
sigma_t_init = 0.13

phases_init = [-2.53, -2.53, -2.53, -2.54, -2.55, -2.55]
xcents_init = [79, 79, 79, 81, 81, 83]
ycents_init = [136, 135, 135, 136, 135, 136]


# Combine parameters into one list
init_params = [N, R_init, sigma_r_init, sigma_t_init, *phases_init, *xcents_init, *ycents_init]

m_rt = Minuit(neg_log_likelihood_rt, 
           N=init_params[0], R=init_params[1], sigma_r=init_params[2], sigma_t=init_params[3],
           phase_0=init_params[4], phase_1=init_params[5], phase_2=init_params[6], 
           phase_3=init_params[7], phase_4=init_params[8], phase_5=init_params[9],
           xcent_0=init_params[10], xcent_1=init_params[11], xcent_2=init_params[12], 
           xcent_3=init_params[13], xcent_4=init_params[14], xcent_5=init_params[15],
           ycent_0=init_params[16], ycent_1=init_params[17], ycent_2=init_params[18], 
           ycent_3=init_params[19], ycent_4=init_params[20], ycent_5=init_params[21])

# Define likelihood optimization settings
m_rt.errordef = Minuit.LIKELIHOOD
m_rt.limits['N'] = (345, 365)  # Set realistic bounds for N
m_rt.limits['R'] = (60, 80)
m_rt.limits['sigma_r'] = (0, None)  # sigma_r must be positive
m_rt.limits['sigma_t'] = (0, None)  # sigma_t must be positive

# Run the minimization
m_rt.migrad()

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (3 of them) had size 6, e.g. axis 0 of argument xc of type float32[6];
  * some axes (2 of them) had size 79, e.g. axis 0 of argument sect[0][0] of type float32[79]