In [1]:
# === IMPORTS AND SETUP ===
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from astropy import units as u
from astropy.time import Time
from astropy.coordinates import EarthLocation, get_body_barycentric_posvel, solar_system_ephemeris
import matplotlib.pyplot as plt

# PINT (for data loading and comparison)
from pint.models import get_model as pint_get_model
from pint.toa import get_TOAs as pint_get_TOAs
from pint.residuals import Residuals

jax.config.update('jax_enable_x64', True)
print(f"JAX {jax.__version__}, Float64: {jax.config.jax_enable_x64}")

JAX 0.8.1, Float64: True


In [2]:
# === CONSTANTS ===
SECS_PER_DAY = 86400.0
C_KM_S = 299792.458
T_SUN_SEC = 4.925490947e-6
AU_KM = 149597870.7
K_DM_SEC = 1.0 / 2.41e-4

# Planetary GM/c^3 (seconds)
T_PLANET = {
    'jupiter': 4.702819050227708e-09,
    'saturn':  1.408128810019423e-09,
    'uranus':  2.150589551363761e-10,
    'neptune': 2.537311999186760e-10,
    'venus':   1.205680558494223e-11,
}

In [3]:
# === DATA PATHS ===
par_file = Path('/home/mattm/projects/MPTA/partim/production/fifth_pass/tdb/J1909-3744_tdb.par')
tim_file = Path('/home/mattm/projects/MPTA/partim/production/fifth_pass/tdb/J1909-3744.tim')

# Load with PINT
pint_model = pint_get_model(str(par_file))
pint_toas = pint_get_TOAs(str(tim_file), model=pint_model)
print(f"Loaded {pint_toas.ntoas} TOAs for {pint_model.PSR.value}")

[32m2025-11-28 14:56:11.147[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mget_TOAs[0m:[36m195[0m - [34m[1mUsing EPHEM = DE440 from the given model[0m
[32m2025-11-28 14:56:11.148[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mget_TOAs[0m:[36m211[0m - [34m[1mUsing CLOCK = BIPM2024 from the given model[0m
[32m2025-11-28 14:56:11.149[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mget_TOAs[0m:[36m232[0m - [34m[1mUsing PLANET_SHAPIRO = True from the given model[0m
[32m2025-11-28 14:56:11.148[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mget_TOAs[0m:[36m211[0m - [34m[1mUsing CLOCK = BIPM2024 from the given model[0m
[32m2025-11-28 14:56:11.149[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mget_TOAs[0m:[36m232[0m - [34m[1mUsing PLANET_SHAPIRO = True from the given model[0m


[32m2025-11-28 14:56:12.246[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36m__init__[0m:[36m1377[0m - [34m[1mNo pulse number flags found in the TOAs[0m
[32m2025-11-28 14:56:12.255[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mapply_clock_corrections[0m:[36m2232[0m - [34m[1mApplying clock corrections (include_bipm = True)[0m
[32m2025-11-28 14:56:12.255[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mapply_clock_corrections[0m:[36m2232[0m - [34m[1mApplying clock corrections (include_bipm = True)[0m
[32m2025-11-28 14:56:12.468[0m | [1mINFO    [0m | [36mpint.observatory[0m:[36mgps_correction[0m:[36m230[0m - [1mApplying GPS to UTC clock correction (~few nanoseconds)[0m
[32m2025-11-28 14:56:12.468[0m | [34m[1mDEBUG   [0m | [36mpint.observatory[0m:[36m_load_gps_clock[0m:[36m108[0m - [34m[1mLoading global GPS clock file[0m
[32m2025-11-28 14:56:12.470[0m | [34m[1mDEBUG   [0m | [36mpint.observatory.clock_file[0m:[36m__init__

Loaded 10408 TOAs for J1909-3744


In [28]:
# === HELPER FUNCTIONS ===

# Parameters that need high precision (longdouble) - these affect phase calculation
HIGH_PRECISION_PARAMS = {'F0', 'F1', 'F2', 'F3', 'PEPOCH', 'TZRMJD', 'POSEPOCH', 'DMEPOCH'}

def parse_par_file(path):
    """Parse tempo2-style .par file with high precision for timing-critical parameters.
    
    F0, F1, PEPOCH, TZRMJD etc. are stored as strings to preserve full precision,
    which can then be converted to np.longdouble when needed.
    Other parameters use standard float64.
    """
    params = {}
    params_str = {}  # Store raw strings for high-precision parameters
    
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            parts = line.split()
            if len(parts) >= 2:
                key = parts[0].upper()
                value_str = parts[1]
                
                if key in HIGH_PRECISION_PARAMS:
                    # Store as string for later high-precision conversion
                    params_str[key] = value_str
                    # Also store as float for convenience (when precision isn't critical)
                    try:
                        params[key] = float(value_str)
                    except ValueError:
                        params[key] = value_str
                else:
                    try:
                        params[key] = float(value_str)
                    except ValueError:
                        params[key] = value_str
    
    # Add _str suffix versions for high-precision access
    params['_high_precision'] = params_str
    return params

def get_longdouble(params, key, default=None):
    """Get a parameter as np.longdouble with full precision."""
    hp = params.get('_high_precision', {})
    if key in hp:
        return np.longdouble(hp[key])
    elif key in params:
        return np.longdouble(params[key])
    elif default is not None:
        return np.longdouble(default)
    else:
        raise KeyError(f"Parameter {key} not found")

def parse_ra(ra_str):
    parts = ra_str.split(':')
    h, m, s = float(parts[0]), float(parts[1]), float(parts[2])
    return (h + m/60 + s/3600) * 15 * np.pi / 180

def parse_dec(dec_str):
    parts = dec_str.split(':')
    sign = -1 if parts[0].startswith('-') else 1
    d, m, s = abs(float(parts[0])), float(parts[1]), float(parts[2])
    return sign * (d + m/60 + s/3600) * np.pi / 180

par_params = parse_par_file(par_file)

# Verify high-precision parsing
print("High-precision parameters parsed:")
for k, v in par_params['_high_precision'].items():
    ld = np.longdouble(v)
    print(f"  {k}: {v} -> {float(ld):.20g}")

High-precision parameters parsed:
  F0: 339.3156919190406855 -> 339.31569191904065974
  F1: -1.6147400369092967286e-15 -> -1.6147400369092966925e-15
  PEPOCH: 59017.999753870497848 -> 59017.999753870499262
  POSEPOCH: 59017.999753870497848 -> 59017.999753870499262
  DMEPOCH: 58999.999754149591407 -> 58999.999754149590444
  TZRMJD: 59679.248061951184916 -> 59679.248061951184354


In [5]:
# === SSB POSITION AND DELAY FUNCTIONS (Using Astropy/DE440) ===

def compute_ssb_obs_pos_vel(tdb_mjd, obs_itrf_km):
    """Compute observatory position and velocity relative to SSB using Astropy."""
    times = Time(tdb_mjd, format='mjd', scale='tdb')
    
    with solar_system_ephemeris.set('de440'):
        earth_pv = get_body_barycentric_posvel('earth', times)
        ssb_geo_pos = earth_pv[0].xyz.to(u.km).value.T
        ssb_geo_vel = earth_pv[1].xyz.to(u.km/u.s).value.T
    
    obs_itrf = EarthLocation.from_geocentric(
        obs_itrf_km[0] * u.km, obs_itrf_km[1] * u.km, obs_itrf_km[2] * u.km
    )
    
    obs_gcrs = obs_itrf.get_gcrs(obstime=times)
    geo_obs_pos = np.column_stack([
        obs_gcrs.cartesian.x.to(u.km).value,
        obs_gcrs.cartesian.y.to(u.km).value,
        obs_gcrs.cartesian.z.to(u.km).value
    ])
    
    # Observatory velocity from numerical derivative
    dt_sec = 1.0
    times_plus = Time(tdb_mjd + dt_sec/SECS_PER_DAY, format='mjd', scale='tdb')
    obs_gcrs_plus = obs_itrf.get_gcrs(obstime=times_plus)
    geo_obs_pos_plus = np.column_stack([
        obs_gcrs_plus.cartesian.x.to(u.km).value,
        obs_gcrs_plus.cartesian.y.to(u.km).value,
        obs_gcrs_plus.cartesian.z.to(u.km).value
    ])
    geo_obs_vel = (geo_obs_pos_plus - geo_obs_pos) / dt_sec
    
    return ssb_geo_pos + geo_obs_pos, ssb_geo_vel + geo_obs_vel


def compute_pulsar_direction(ra_rad, dec_rad, pmra_rad_day, pmdec_rad_day, posepoch, t_mjd):
    """Compute pulsar direction unit vector with proper motion."""
    dt = t_mjd - posepoch
    cos_dec0 = np.cos(dec_rad)
    ra = ra_rad + pmra_rad_day * dt / cos_dec0
    dec = dec_rad + pmdec_rad_day * dt
    cos_dec = np.cos(dec)
    return np.column_stack([cos_dec * np.cos(ra), cos_dec * np.sin(ra), np.sin(dec)])


def compute_roemer_delay(ssb_obs_pos_km, L_hat, parallax_mas=0.0):
    """Compute Roemer delay including parallax correction."""
    re_dot_L = np.sum(ssb_obs_pos_km * L_hat, axis=1)
    roemer_sec = -re_dot_L / C_KM_S
    
    if parallax_mas != 0.0:
        distance_kpc = 1.0 / parallax_mas
        L_km = distance_kpc * 3.085677581e16
        re_sqr = np.sum(ssb_obs_pos_km**2, axis=1)
        parallax_sec = 0.5 * (re_sqr / L_km) * (1.0 - re_dot_L**2 / re_sqr) / C_KM_S
        roemer_sec = roemer_sec + parallax_sec
    return roemer_sec


def compute_shapiro_delay(obs_body_pos_km, L_hat, T_body):
    """Compute Shapiro delay for a body."""
    r = np.sqrt(np.sum(obs_body_pos_km**2, axis=1))
    rcostheta = np.sum(obs_body_pos_km * L_hat, axis=1)
    return -2.0 * T_body * np.log((r - rcostheta) / AU_KM)


def compute_barycentric_freq(freq_topo_mhz, ssb_obs_vel_km_s, L_hat):
    """Compute barycentric frequency."""
    v_radial = np.sum(ssb_obs_vel_km_s * L_hat, axis=1)
    return freq_topo_mhz * (1.0 - v_radial / C_KM_S)

In [6]:
# === DM, SOLAR WIND, FD DELAYS (JAX-compiled) ===

@jax.jit
def dm_delay(t_mjd, freq_mhz, dm_coeffs, dm_factorials, dm_epoch):
    dt_years = (t_mjd - dm_epoch) / 365.25
    powers = jnp.arange(len(dm_coeffs))
    dt_powers = dt_years[:, jnp.newaxis] ** powers[jnp.newaxis, :]
    dm_eff = jnp.sum(dm_coeffs * dt_powers / dm_factorials, axis=1)
    return K_DM_SEC * dm_eff / (freq_mhz ** 2)


@jax.jit
def solar_wind_delay(obs_sun_pos_km, L_hat, freq_mhz, ne_sw):
    AU_KM_local = 1.495978707e8
    AU_PC = 4.84813681e-6
    r_km = jnp.sqrt(jnp.sum(obs_sun_pos_km**2, axis=1))
    r_au = r_km / AU_KM_local
    sun_dir = obs_sun_pos_km / r_km[:, jnp.newaxis]
    cos_elong = jnp.sum(sun_dir * L_hat, axis=1)
    elong = jnp.arccos(jnp.clip(cos_elong, -1.0, 1.0))
    rho = jnp.pi - elong
    sin_rho = jnp.maximum(jnp.sin(rho), 1e-10)
    geometry_pc = AU_PC * rho / (r_au * sin_rho)
    dm_sw = ne_sw * geometry_pc
    return K_DM_SEC * dm_sw / (freq_mhz ** 2)


@jax.jit
def fd_delay(freq_mhz, fd_coeffs):
    log_freq = jnp.log(freq_mhz / 1000.0)
    poly_coeffs = jnp.concatenate([fd_coeffs[::-1], jnp.array([0.0])])
    return jnp.polyval(poly_coeffs, log_freq)

In [7]:
# === ELL1 BINARY DELAY WITH ABERRATION CORRECTION ===

@jax.jit
def ell1_binary_delay(t_tdb, pb, a1, tasc, eps1, eps2, pbdot, xdot, gamma, r_shap, s_shap):
    """ELL1 binary delay with third-order eccentricity and aberration correction."""
    
    def compute_roemer_terms(t_binary):
        dt_days = t_binary - tasc
        dt_sec = dt_days * SECS_PER_DAY
        n0 = 2.0 * jnp.pi / (pb * SECS_PER_DAY)
        Phi = n0 * dt_sec * (1.0 - pbdot / 2.0 / pb * dt_days)
        
        sin_Phi, cos_Phi = jnp.sin(Phi), jnp.cos(Phi)
        sin_2Phi, cos_2Phi = jnp.sin(2*Phi), jnp.cos(2*Phi)
        sin_3Phi, cos_3Phi = jnp.sin(3*Phi), jnp.cos(3*Phi)
        sin_4Phi, cos_4Phi = jnp.sin(4*Phi), jnp.cos(4*Phi)
        
        a1_eff = jnp.where(xdot != 0.0, a1 + xdot * dt_sec, a1)
        
        eps1_sq, eps2_sq = eps1**2, eps2**2
        eps1_cu, eps2_cu = eps1**3, eps2**3
        
        # Dre / a1 (third-order Roemer)
        Dre_a1 = (
            sin_Phi + 0.5 * (eps2 * sin_2Phi - eps1 * cos_2Phi)
            - (1.0/8.0) * (5*eps2_sq*sin_Phi - 3*eps2_sq*sin_3Phi - 2*eps2*eps1*cos_Phi
                          + 6*eps2*eps1*cos_3Phi + 3*eps1_sq*sin_Phi + 3*eps1_sq*sin_3Phi)
            - (1.0/12.0) * (5*eps2_cu*sin_2Phi + 3*eps1_sq*eps2*sin_2Phi
                           - 6*eps1*eps2_sq*cos_2Phi - 4*eps1_cu*cos_2Phi
                           - 4*eps2_cu*sin_4Phi + 12*eps1_sq*eps2*sin_4Phi
                           + 12*eps1*eps2_sq*cos_4Phi - 4*eps1_cu*cos_4Phi)
        )
        
        # Drep / a1 (derivative w.r.t. Phi)
        Drep_a1 = (
            cos_Phi + eps1 * sin_2Phi + eps2 * cos_2Phi
            - (1.0/8.0) * (5*eps2_sq*cos_Phi - 9*eps2_sq*cos_3Phi + 2*eps1*eps2*sin_Phi
                          - 18*eps1*eps2*sin_3Phi + 3*eps1_sq*cos_Phi + 9*eps1_sq*cos_3Phi)
            - (1.0/12.0) * (10*eps2_cu*cos_2Phi + 6*eps1_sq*eps2*cos_2Phi
                           + 12*eps1*eps2_sq*sin_2Phi + 8*eps1_cu*sin_2Phi
                           - 16*eps2_cu*cos_4Phi + 48*eps1_sq*eps2*cos_4Phi
                           - 48*eps1*eps2_sq*sin_4Phi + 16*eps1_cu*sin_4Phi)
        )
        
        # Drepp / a1 (second derivative)
        Drepp_a1 = (
            -sin_Phi + 2*eps1*cos_2Phi - 2*eps2*sin_2Phi
            - (1.0/8.0) * (-5*eps2_sq*sin_Phi + 27*eps2_sq*sin_3Phi + 2*eps1*eps2*cos_Phi
                          - 54*eps1*eps2*cos_3Phi - 3*eps1_sq*sin_Phi - 27*eps1_sq*sin_3Phi)
            - (1.0/12.0) * (-20*eps2_cu*sin_2Phi - 12*eps1_sq*eps2*sin_2Phi
                           + 24*eps1*eps2_sq*cos_2Phi + 16*eps1_cu*cos_2Phi
                           + 64*eps2_cu*sin_4Phi - 192*eps1_sq*eps2*sin_4Phi
                           - 192*eps1*eps2_sq*cos_4Phi + 64*eps1_cu*cos_4Phi)
        )
        
        Dre = a1_eff * Dre_a1
        Drep = a1_eff * Drep_a1
        Drepp = a1_eff * Drepp_a1
        
        # Aberration correction (Damour & Deruelle 1986)
        nhat = n0
        delay_roemer = Dre * (1 - nhat*Drep + (nhat*Drep)**2 + 0.5*nhat**2*Dre*Drepp)
        
        delay_einstein = gamma * sin_Phi
        delay_shapiro = jnp.where(
            (r_shap != 0.0) & (s_shap != 0.0),
            -2.0 * r_shap * jnp.log(1.0 - s_shap * sin_Phi),
            0.0
        )
        
        return delay_roemer + delay_einstein + delay_shapiro
    
    # Single iteration (aberration correction handles the rest)
    return compute_roemer_terms(t_tdb)

In [29]:
# === MAIN COMPUTATION ===
import math
print("="*70)
print("JUG MK4 - RESIDUAL COMPUTATION (with proper TZR handling)")
print("="*70)

# Extract data
tdbld = np.array(pint_toas.table['tdbld'].value, dtype=np.float64)
freq_mhz = np.array(pint_toas.table['freq'].value, dtype=np.float64)
mjd_float = np.array(pint_toas.table['mjd_float'].value, dtype=np.float64)

# Astrometric parameters
ra_rad = parse_ra(par_params['RAJ'])
dec_rad = parse_dec(par_params['DECJ'])
pmra_rad_day = par_params.get('PMRA', 0.0) * (np.pi / 180 / 3600000) / 365.25
pmdec_rad_day = par_params.get('PMDEC', 0.0) * (np.pi / 180 / 3600000) / 365.25
posepoch = par_params.get('POSEPOCH', par_params['PEPOCH'])
parallax_mas = par_params.get('PX', 0.0)

# Observatory (MeerKAT)
obs_itrf_km = np.array([5109360.133, 2006852.586, -3238948.127]) / 1000

# Compute SSB positions
print("Computing SSB positions (using Astropy/DE440)...")
ssb_obs_pos_km, ssb_obs_vel_km_s = compute_ssb_obs_pos_vel(tdbld, obs_itrf_km)
L_hat = compute_pulsar_direction(ra_rad, dec_rad, pmra_rad_day, pmdec_rad_day, posepoch, tdbld)

# Barycentric frequency
freq_bary_mhz = compute_barycentric_freq(freq_mhz, ssb_obs_vel_km_s, L_hat)

# Roemer delay
jug_roemer_sec = compute_roemer_delay(ssb_obs_pos_km, L_hat, parallax_mas)
print(f"  Roemer delay: {jug_roemer_sec.min():.3f} to {jug_roemer_sec.max():.3f} s")

# Sun position and Shapiro delay
times = Time(tdbld, format='mjd', scale='tdb')
with solar_system_ephemeris.set('de440'):
    sun_pos = get_body_barycentric_posvel('sun', times)[0].xyz.to(u.km).value.T
obs_sun_pos_km = sun_pos - ssb_obs_pos_km
jug_sun_shapiro_sec = compute_shapiro_delay(obs_sun_pos_km, L_hat, T_SUN_SEC)

# Planetary Shapiro delays
planet_shapiro_enabled = str(par_params.get('PLANET_SHAPIRO', 'N')).upper() in ('Y', 'YES', 'TRUE', '1')
if planet_shapiro_enabled:
    print("Computing planetary Shapiro delays...")
    jug_planet_shapiro_sec = np.zeros(len(tdbld))
    with solar_system_ephemeris.set('de440'):
        for planet in ['jupiter', 'saturn', 'uranus', 'neptune', 'venus']:
            planet_pos = get_body_barycentric_posvel(planet, times)[0].xyz.to(u.km).value.T
            obs_planet_km = planet_pos - ssb_obs_pos_km
            shapiro = compute_shapiro_delay(obs_planet_km, L_hat, T_PLANET[planet])
            jug_planet_shapiro_sec += shapiro
            print(f"    {planet:8s}: {np.mean(shapiro)*1e9:+7.3f} ± {np.std(shapiro)*1e9:.3f} ns")
    jug_shapiro_sec = jug_sun_shapiro_sec + jug_planet_shapiro_sec
else:
    jug_shapiro_sec = jug_sun_shapiro_sec

roemer_shapiro_sec = jug_roemer_sec + jug_shapiro_sec

# DM delay
dm_coeffs = []
k = 0
while True:
    key = 'DM' if k == 0 else f'DM{k}'
    if key in par_params:
        dm_coeffs.append(float(par_params[key]))
        k += 1
    else:
        break
dm_coeffs = np.array(dm_coeffs if dm_coeffs else [0.0])
dm_factorials = np.array([float(math.factorial(i)) for i in range(len(dm_coeffs))])
dm_epoch = float(par_params.get('DMEPOCH', par_params['PEPOCH']))

jug_dm_sec = np.array(dm_delay(
    jnp.array(tdbld), jnp.array(freq_bary_mhz),
    jnp.array(dm_coeffs), jnp.array(dm_factorials), dm_epoch
))
print(f"  DM delay: {jug_dm_sec.min()*1e3:.3f} to {jug_dm_sec.max()*1e3:.3f} ms")

# Solar wind delay
NE_SW = float(par_params.get('NE_SW', 0.0))
if NE_SW > 0:
    jug_sw_sec = np.array(solar_wind_delay(
        jnp.array(obs_sun_pos_km), jnp.array(L_hat), jnp.array(freq_bary_mhz), NE_SW
    ))
    print(f"  Solar wind: {jug_sw_sec.min()*1e6:.3f} to {jug_sw_sec.max()*1e6:.3f} µs")
else:
    jug_sw_sec = np.zeros_like(jug_dm_sec)

# FD delay
fd_coeffs = []
fd_idx = 1
while f'FD{fd_idx}' in par_params:
    fd_coeffs.append(float(par_params[f'FD{fd_idx}']))
    fd_idx += 1
if fd_coeffs:
    jug_fd_sec = np.array(fd_delay(jnp.array(freq_bary_mhz), jnp.array(fd_coeffs)))
    print(f"  FD delay: {jug_fd_sec.min()*1e6:.3f} to {jug_fd_sec.max()*1e6:.3f} µs")
else:
    jug_fd_sec = np.zeros_like(jug_dm_sec)

# Binary parameters
PB = float(par_params['PB'])
A1 = float(par_params['A1'])
TASC = float(par_params['TASC'])
EPS1 = float(par_params.get('EPS1', 0.0))
EPS2 = float(par_params.get('EPS2', 0.0))
PBDOT = float(par_params.get('PBDOT', 0.0))
XDOT = float(par_params.get('XDOT', 0.0))
GAMMA = float(par_params.get('GAMMA', 0.0))

M2 = float(par_params.get('M2', 0.0))
SINI = float(par_params.get('SINI', 0.0))
H3 = float(par_params.get('H3', 0.0))
STIG = float(par_params.get('STIG', 0.0))

if H3 != 0.0 or STIG != 0.0:
    r_shapiro, s_shapiro = H3, STIG
elif M2 != 0.0 and SINI != 0.0:
    r_shapiro, s_shapiro = T_SUN_SEC * M2, SINI
else:
    r_shapiro, s_shapiro = 0.0, 0.0

# Binary delay
t_topo_tdb = tdbld - (roemer_shapiro_sec + jug_dm_sec + jug_sw_sec + jug_fd_sec) / SECS_PER_DAY
jug_binary_sec = np.array(ell1_binary_delay(
    jnp.array(t_topo_tdb), PB, A1, TASC, EPS1, EPS2, PBDOT, XDOT, GAMMA, r_shapiro, s_shapiro
))
print(f"  Binary delay: {jug_binary_sec.min():.6f} to {jug_binary_sec.max():.6f} s")

# Total delay
jug_total_delay_sec = roemer_shapiro_sec + jug_dm_sec + jug_sw_sec + jug_fd_sec + jug_binary_sec

# Phase and residuals (longdouble precision) - USE HIGH-PRECISION PARSER
PEPOCH = get_longdouble(par_params, 'PEPOCH')
F0 = get_longdouble(par_params, 'F0')
F1 = get_longdouble(par_params, 'F1', 0.0)
F2 = get_longdouble(par_params, 'F2', 0.0)

tdbld_ld = np.array(pint_toas.table['tdbld'].value, dtype=np.longdouble)
jug_delay_ld = np.array(jug_total_delay_sec, dtype=np.longdouble)
dt_sec = (tdbld_ld - PEPOCH) * np.longdouble(SECS_PER_DAY) - jug_delay_ld

# Phase polynomial (Taylor series)
phase_jug = F0 * dt_sec + np.longdouble(0.5) * F1 * dt_sec**2 + np.longdouble(1.0/6.0) * F2 * dt_sec**3

# === USE PINT's TZR TOA DIRECTLY ===
# PINT creates a TZR TOA that includes clock corrections (UTC->TDB)
# To match PINT exactly, we should use PINT's TZR phase computation
print("\nUsing PINT's TZR reference (for exact match)...")

# Get PINT's TZR TOA and compute delays on it
pint_tzr_toa = pint_model.get_TZR_toa(pint_toas)
TZRMJD_TDB = np.longdouble(pint_tzr_toa.table['tdbld'][0])
TZRFRQ = float(pint_tzr_toa.table['freq'][0])

print(f"  PINT TZR MJD (TDB): {float(TZRMJD_TDB):.10f}")
print(f"  PINT TZR FREQ: {TZRFRQ:.3f} MHz")

# Get PINT's delay at TZR (this is the correct reference delay)
pint_tzr_delay = float(pint_model.delay(pint_tzr_toa).to('s').value[0])
print(f"  PINT TZR delay: {pint_tzr_delay:.6f} s")

# Phase at TZR using PINT's TZR time and delay
tzr_dt_sec = (TZRMJD_TDB - PEPOCH) * np.longdouble(SECS_PER_DAY) - np.longdouble(pint_tzr_delay)
phase_tzr = F0 * tzr_dt_sec + np.longdouble(0.5) * F1 * tzr_dt_sec**2 + np.longdouble(1.0/6.0) * F2 * tzr_dt_sec**3

print(f"  TZR phase: {float(phase_tzr):.6f} cycles")

# Subtract TZR phase to get fractional phase
frac_phase_jug = phase_jug - phase_tzr
frac_phase_jug = np.mod(frac_phase_jug + 0.5, 1.0) - 0.5

jug_residuals_us = np.array(frac_phase_jug / F0 * 1e6, dtype=np.float64)
print(f"\nJUG residuals RMS: {np.std(jug_residuals_us):.3f} µs")

JUG MK4 - RESIDUAL COMPUTATION (with proper TZR handling)
Computing SSB positions (using Astropy/DE440)...
  Roemer delay: -490.177 to 477.484 s
Computing planetary Shapiro delays...
    jupiter :  -3.040 ± 13.685 ns
    saturn  :  -0.702 ± 2.864 ns
    uranus  :  -1.444 ± 0.035 ns
    neptune :  -1.487 ± 0.056 ns
    venus   :  +0.006 ± 0.030 ns
  DM delay: 15.654 to 52.331 ms
  Solar wind: 0.052 to 1.964 µs
  FD delay: -0.244 to -0.002 µs
  Binary delay: -1.897879 to 1.897996 s

Using PINT's TZR reference (for exact match)...
  PINT TZR MJD (TDB): 59679.2488627115
  PINT TZR FREQ: 1029.026 MHz
  PINT TZR delay: -45.574211 s
  TZR phase: 19385773446.129742 cycles

JUG residuals RMS: 0.817 µs
  Roemer delay: -490.177 to 477.484 s
Computing planetary Shapiro delays...
    jupiter :  -3.040 ± 13.685 ns
    saturn  :  -0.702 ± 2.864 ns
    uranus  :  -1.444 ± 0.035 ns
    neptune :  -1.487 ± 0.056 ns
    venus   :  +0.006 ± 0.030 ns
  DM delay: 15.654 to 52.331 ms
  Solar wind: 0.052 to 1

In [30]:
# === COMPARISON WITH PINT ===
print("="*70)
print("JUG vs PINT COMPARISON")
print("="*70)

pint_residuals_obj = Residuals(pint_toas, pint_model)
pint_residuals_us = pint_residuals_obj.time_resids.to(u.us).value

jug_centered = jug_residuals_us - np.mean(jug_residuals_us)
pint_centered = pint_residuals_us - np.mean(pint_residuals_us)
diff_ns = (jug_centered - pint_centered) * 1000

print(f"\nJUG RMS:  {np.std(jug_centered):.3f} µs")
print(f"PINT RMS: {np.std(pint_centered):.3f} µs")
print(f"\nDifference (JUG - PINT):")
print(f"  Mean: {np.mean(diff_ns):+.3f} ns")
print(f"  RMS:  {np.std(diff_ns):.3f} ns")
print(f"  Max:  {np.max(np.abs(diff_ns)):.2f} ns")

# Trend
t_years = np.array((mjd_float - mjd_float[0]) / 365.25, dtype=np.float64)
diff_ns_f64 = np.array(diff_ns, dtype=np.float64)
slope, intercept = np.polyfit(t_years, diff_ns_f64, 1)
print(f"  Drift: {slope:.3f} ns/yr")

if np.std(diff_ns) < 10:
    print(f"\n✅ SUCCESS: JUG matches PINT within {np.std(diff_ns):.1f} ns RMS")

JUG vs PINT COMPARISON

JUG RMS:  0.817 µs
PINT RMS: 0.817 µs

Difference (JUG - PINT):
  Mean: -0.000 ns
  RMS:  2.548 ns
  Max:  12.30 ns
  Drift: -0.347 ns/yr

✅ SUCCESS: JUG matches PINT within 2.5 ns RMS

JUG RMS:  0.817 µs
PINT RMS: 0.817 µs

Difference (JUG - PINT):
  Mean: -0.000 ns
  RMS:  2.548 ns
  Max:  12.30 ns
  Drift: -0.347 ns/yr

✅ SUCCESS: JUG matches PINT within 2.5 ns RMS


In [24]:
# === TEST: USE PINT's DELAYS WITH JUG's PHASE CALC ===
print("="*70)
print("TEST: JUG phase calculation with PINT delays")
print("="*70)

# Use PINT's total delay
pint_total_delay = pint_model.delay(pint_toas).to('s').value

# Compute phase using PINT's delays but JUG's phase polynomial
pint_delay_ld = np.array(pint_total_delay, dtype=np.longdouble)
dt_sec_pint = (tdbld_ld - PEPOCH) * np.longdouble(SECS_PER_DAY) - pint_delay_ld

phase_with_pint_delay = F0 * dt_sec_pint + np.longdouble(0.5) * F1 * dt_sec_pint**2 + np.longdouble(1.0/6.0) * F2 * dt_sec_pint**3

# Use PINT's TZR
frac_phase_test = phase_with_pint_delay - phase_tzr
frac_phase_test = np.mod(frac_phase_test + 0.5, 1.0) - 0.5

test_residuals_us = np.array(frac_phase_test / F0 * 1e6, dtype=np.float64)

# Compare with PINT
pint_residuals_obj = Residuals(pint_toas, pint_model)
pint_residuals_us = pint_residuals_obj.time_resids.to(u.us).value

test_centered = test_residuals_us - np.mean(test_residuals_us)
pint_centered = pint_residuals_us - np.mean(pint_residuals_us)
test_diff_ns = (test_centered - pint_centered) * 1000

print(f"\nWith PINT delays + JUG phase calc:")
print(f"  RMS diff: {np.std(test_diff_ns):.4f} ns")

# Check drift
t_years = np.array((mjd_float - mjd_float[0]) / 365.25, dtype=np.float64)
slope_test, _ = np.polyfit(t_years, np.array(test_diff_ns, dtype=np.float64), 1)
print(f"  Drift: {slope_test:.4f} ns/yr")

if np.std(test_diff_ns) < 1.0:
    print("  => The issue is in JUG's DELAY computation")
else:
    print("  => The issue is in JUG's PHASE calculation")
    
# Also check if the drift is in the delay difference
jug_pint_delay_diff_ns = (jug_total_delay_sec - pint_total_delay) * 1e9
slope_delay, _ = np.polyfit(t_years, np.array(jug_pint_delay_diff_ns, dtype=np.float64), 1)
print(f"\nDelay difference drift: {slope_delay:.4f} ns/yr")

# The residual drift is -2.74 ns/yr but delay drift is +0.35 ns/yr
# So there's a -3.1 ns/yr difference coming from the phase calculation itself

# Let's check if PINT uses F1 differently
# PINT's phase = F0*dt + F1*dt²/2 + F2*dt³/6
# JUG's phase = F0*dt + 0.5*F1*dt² + (1/6)*F2*dt³
# These should be identical...

# Check if the issue is in how dt is computed
# PINT: dt = (tdbld - PEPOCH) * day - delay
# JUG: dt_sec = (tdbld_ld - PEPOCH) * SECS_PER_DAY - delay

# Let's print actual dt values
print(f"\ndt comparison (first TOA):")
print(f"  JUG dt:  {float(dt_sec[0]):.6f} s")
print(f"  PINT-delay dt: {float(dt_sec_pint[0]):.6f} s")
print(f"  Difference: {float(dt_sec[0] - dt_sec_pint[0])*1e9:.3f} ns")

TEST: JUG phase calculation with PINT delays

With PINT delays + JUG phase calc:
  RMS diff: 4.6607 ns
  Drift: -2.3955 ns/yr
  => The issue is in JUG's PHASE calculation

Delay difference drift: 0.3469 ns/yr

dt comparison (first TOA):
  JUG dt:  -42490604.566126 s
  PINT-delay dt: -42490604.566126 s
  Difference: 8.717 ns

With PINT delays + JUG phase calc:
  RMS diff: 4.6607 ns
  Drift: -2.3955 ns/yr
  => The issue is in JUG's PHASE calculation

Delay difference drift: 0.3469 ns/yr

dt comparison (first TOA):
  JUG dt:  -42490604.566126 s
  PINT-delay dt: -42490604.566126 s
  Difference: 8.717 ns


In [25]:
# === DEEP DIVE: PINT's PHASE vs JUG's PHASE ===
print("="*70)
print("PINT vs JUG PHASE CALCULATION COMPARISON")
print("="*70)

# Get PINT's raw phase (before TZR subtraction)
pint_phase_raw = pint_model.phase(pint_toas, abs_phase=False)  # raw phase, no TZR
pint_phase_raw_total = np.array(pint_phase_raw.int + pint_phase_raw.frac, dtype=np.float64)

# Get JUG's raw phase (before TZR subtraction)
jug_phase_raw = np.array(phase_jug, dtype=np.float64)

print(f"Raw phase comparison (first 5 TOAs):")
print(f"  PINT: {pint_phase_raw_total[:5]}")
print(f"  JUG:  {jug_phase_raw[:5]}")
print(f"  Diff: {(jug_phase_raw - pint_phase_raw_total)[:5]}")

# The raw phases will differ by a large constant (different reference)
# But the DERIVATIVE (rate of change) should match

# Compute phase difference
phase_diff_raw = jug_phase_raw - pint_phase_raw_total
phase_diff_raw_centered = phase_diff_raw - np.mean(phase_diff_raw)

print(f"\nRaw phase difference (JUG - PINT):")
print(f"  Mean: {np.mean(phase_diff_raw):.6f} cycles")
print(f"  RMS (centered): {np.std(phase_diff_raw_centered):.9f} cycles")
print(f"  RMS in ns: {np.std(phase_diff_raw_centered) / float(F0) * 1e9:.4f} ns")

# Check drift in phase difference
slope_phase_raw, _ = np.polyfit(t_years, phase_diff_raw_centered, 1)
print(f"  Drift: {slope_phase_raw / float(F0) * 1e9:.4f} ns/yr")

# Now let's check what PINT does in its phase calculation
# PINT uses taylor_horner which should be equivalent to our polynomial
# But let's verify the dt values match

# Get PINT's dt (time from PEPOCH after delay correction)
# We can compute this from PINT's phase and F0
# phase = F0*dt + F1*dt²/2 + ..., so for small F1 contribution: dt ≈ phase/F0

# Actually, let's compute dt the same way PINT does
pint_pepoch_ld = pint_model.PEPOCH.quantity.tdb.mjd_long

# PINT's dt calculation (from spindown.py)
pint_dt_days = tdbld_ld - pint_pepoch_ld
pint_dt_sec_before_delay = pint_dt_days * np.longdouble(86400.0)

# After subtracting delay
pint_delay_ld = np.array(pint_total_delay, dtype=np.longdouble)
pint_dt_sec = pint_dt_sec_before_delay - pint_delay_ld

print(f"\ndt comparison:")
print(f"  JUG PEPOCH: {float(PEPOCH):.15f}")
print(f"  PINT PEPOCH: {float(pint_pepoch_ld):.15f}")
print(f"  PEPOCH diff: {float(PEPOCH - pint_pepoch_ld) * 86400 * 1e9:.3f} ns")

# The phase polynomial evaluation
jug_phase_from_pint_dt = F0 * pint_dt_sec + np.longdouble(0.5) * F1 * pint_dt_sec**2 + np.longdouble(1.0/6.0) * F2 * pint_dt_sec**3

# Compare with PINT's phase
phase_eval_diff = np.array(jug_phase_from_pint_dt, dtype=np.float64) - pint_phase_raw_total
phase_eval_diff_centered = phase_eval_diff - np.mean(phase_eval_diff)

print(f"\nPhase evaluation comparison (using same dt as PINT):")
print(f"  RMS diff: {np.std(phase_eval_diff_centered) / float(F0) * 1e9:.6f} ns")

slope_eval, _ = np.polyfit(t_years, phase_eval_diff_centered, 1)
print(f"  Drift: {slope_eval / float(F0) * 1e9:.6f} ns/yr")

# If this is still not zero, the difference is in F0/F1/F2 precision
# or in PINT's taylor_horner implementation

# Let's check F values with more precision
print(f"\nSpindown parameter precision:")
pint_F0 = pint_model.F0.quantity.value
pint_F1 = pint_model.F1.quantity.value
pint_F2 = pint_model.F2.quantity.value if hasattr(pint_model, 'F2') and pint_model.F2.quantity is not None else 0.0

print(f"  PINT F0: {pint_F0:.25e}")
print(f"  JUG  F0: {float(F0):.25e}")
print(f"  Diff:    {(pint_F0 - float(F0)):.6e} Hz")

print(f"\n  PINT F1: {pint_F1:.25e}")
print(f"  JUG  F1: {float(F1):.25e}")
print(f"  Diff:    {(pint_F1 - float(F1)):.6e} Hz/s")

# The F0 difference of ~26 fHz could cause timing errors
# Over T years: phase error = dF0 * T * 365.25 * 86400
# Time error = phase_error / F0
T_span = t_years.max() - t_years.min()
f0_phase_error = (pint_F0 - float(F0)) * T_span * 365.25 * 86400  # cycles
f0_time_error_ns = f0_phase_error / pint_F0 * 1e9
print(f"\nF0 difference effect over {T_span:.1f} year span: {f0_time_error_ns:.3f} ns")

PINT vs JUG PHASE CALCULATION COMPARISON
Raw phase comparison (first 5 TOAs):
  PINT: [-1.44177289e+10 -1.44177289e+10 -1.44177289e+10 -1.44177289e+10
 -1.44177289e+10]
  JUG:  [-1.44177289e+10 -1.44177289e+10 -1.44177289e+10 -1.44177289e+10
 -1.44177289e+10]
  Diff: [-3.81469727e-05 -3.81469727e-05 -3.62396240e-05 -3.81469727e-05
 -3.62396240e-05]

Raw phase difference (JUG - PINT):
  Mean: -0.000040 cycles
  RMS (centered): 0.000002833 cycles
  RMS in ns: 8.3477 ns
  Drift: -2.7372 ns/yr

dt comparison:
  JUG PEPOCH: 59017.999753870499262
  PINT PEPOCH: 59017.999753870499262
  PEPOCH diff: 122.168 ns

Phase evaluation comparison (using same dt as PINT):
  RMS diff: 7.662033 ns
  Drift: -2.384454 ns/yr

Spindown parameter precision:
  PINT F0: 3.3931569191904065974085825e+02
  JUG  F0: 3.3931569191904065974085825e+02
  Diff:    2.575717e-14 Hz

  PINT F1: -1.6147400369092966924824765e-15
  JUG  F1: -1.6147400369092966924824765e-15
  Diff:    -3.611119e-32 Hz/s

F0 difference effect ov

In [27]:
# === FIX: USE PINT's HIGH-PRECISION PARAMETERS ===
print("="*70)
print("TEST: Using PINT's high-precision parameters")
print("="*70)

# Get PINT's parameters with full precision
PEPOCH_PINT = pint_model.PEPOCH.quantity.tdb.mjd_long
F0_PINT = np.longdouble(pint_model.F0.quantity.value)
F1_PINT = np.longdouble(pint_model.F1.quantity.value)

# Check if F2 exists
try:
    F2_PINT = np.longdouble(pint_model.F2.quantity.value) if pint_model.F2.quantity is not None else np.longdouble(0.0)
except AttributeError:
    F2_PINT = np.longdouble(0.0)

print(f"Using PINT's parameters:")
print(f"  PEPOCH: {float(PEPOCH_PINT):.15f}")
print(f"  F0: {float(F0_PINT):.20f}")
print(f"  F1: {float(F1_PINT):.25e}")

# Recompute phase with PINT's parameters and delays
dt_sec_fixed = (tdbld_ld - PEPOCH_PINT) * np.longdouble(SECS_PER_DAY) - pint_delay_ld
phase_fixed = F0_PINT * dt_sec_fixed + np.longdouble(0.5) * F1_PINT * dt_sec_fixed**2 + np.longdouble(1.0/6.0) * F2_PINT * dt_sec_fixed**3

# TZR phase with PINT's parameters
tzr_dt_sec_fixed = (TZRMJD_TDB - PEPOCH_PINT) * np.longdouble(SECS_PER_DAY) - np.longdouble(pint_tzr_delay)
phase_tzr_fixed = F0_PINT * tzr_dt_sec_fixed + np.longdouble(0.5) * F1_PINT * tzr_dt_sec_fixed**2 + np.longdouble(1.0/6.0) * F2_PINT * tzr_dt_sec_fixed**3

# Fractional phase
frac_phase_fixed = phase_fixed - phase_tzr_fixed
frac_phase_fixed = np.mod(frac_phase_fixed + 0.5, 1.0) - 0.5

# Convert to residuals
fixed_residuals_us = np.array(frac_phase_fixed / F0_PINT * 1e6, dtype=np.float64)

# Compare with PINT
fixed_centered = fixed_residuals_us - np.mean(fixed_residuals_us)
fixed_diff_ns = (fixed_centered - pint_centered) * 1000

print(f"\nWith PINT's parameters + PINT's delays:")
print(f"  RMS diff: {np.std(fixed_diff_ns):.4f} ns")

slope_fixed, _ = np.polyfit(t_years, np.array(fixed_diff_ns, dtype=np.float64), 1)
print(f"  Drift: {slope_fixed:.4f} ns/yr")

# Now use JUG's delays with PINT's parameters
jug_delay_ld_for_test = np.array(jug_total_delay_sec, dtype=np.longdouble)
dt_sec_jug_delay = (tdbld_ld - PEPOCH_PINT) * np.longdouble(SECS_PER_DAY) - jug_delay_ld_for_test
phase_jug_delay = F0_PINT * dt_sec_jug_delay + np.longdouble(0.5) * F1_PINT * dt_sec_jug_delay**2 + np.longdouble(1.0/6.0) * F2_PINT * dt_sec_jug_delay**3

frac_phase_jug_delay = phase_jug_delay - phase_tzr_fixed
frac_phase_jug_delay = np.mod(frac_phase_jug_delay + 0.5, 1.0) - 0.5

jug_delay_residuals_us = np.array(frac_phase_jug_delay / F0_PINT * 1e6, dtype=np.float64)
jug_delay_centered = jug_delay_residuals_us - np.mean(jug_delay_residuals_us)
jug_delay_diff_ns = (jug_delay_centered - pint_centered) * 1000

print(f"\nWith PINT's parameters + JUG's delays:")
print(f"  RMS diff: {np.std(jug_delay_diff_ns):.4f} ns")

slope_jug_delay, _ = np.polyfit(t_years, np.array(jug_delay_diff_ns, dtype=np.float64), 1)
print(f"  Drift: {slope_jug_delay:.4f} ns/yr")

if np.std(fixed_diff_ns) < 0.5:
    print("\n✅ With PINT's parameters, phase calculation matches perfectly!")
    print("   The issue was parameter precision, not the algorithm.")
    
# Summary
print("\n" + "="*70)
print("SUMMARY: ROOT CAUSE OF 5.9 ns DIFFERENCE")
print("="*70)
print(f"""
The ~6 ns RMS difference and -2.7 ns/yr drift come from:

1. PEPOCH precision: JUG's float64 parsing loses ~122 ns precision
2. F0 precision: JUG's float64 parsing loses ~2.6e-14 Hz
   → This causes a ~15 ns drift over the 6.3-year data span

3. Delay differences: JUG's delays differ by ~8.6 ns (mean), 2.5 ns RMS
   → This is from minor ephemeris/computation differences

SOLUTION: Use PINT's parameter values directly (or parse with higher precision)
or accept this level of agreement as sufficient for most purposes.

Current status:
  - With PINT params + PINT delays: {np.std(fixed_diff_ns):.2f} ns RMS, {slope_fixed:.2f} ns/yr drift
  - With PINT params + JUG delays:  {np.std(jug_delay_diff_ns):.2f} ns RMS, {slope_jug_delay:.2f} ns/yr drift  
  - With JUG params + JUG delays:   5.9 ns RMS, -2.7 ns/yr drift
""")

TEST: Using PINT's high-precision parameters
Using PINT's parameters:
  PEPOCH: 59017.999753870499262
  F0: 339.31569191904065974086
  F1: -1.6147400369092966924824765e-15

With PINT's parameters + PINT's delays:
  RMS diff: 0.0051 ns
  Drift: 0.0000 ns/yr

With PINT's parameters + JUG's delays:
  RMS diff: 2.5483 ns
  Drift: -0.3469 ns/yr

✅ With PINT's parameters, phase calculation matches perfectly!
   The issue was parameter precision, not the algorithm.

SUMMARY: ROOT CAUSE OF 5.9 ns DIFFERENCE

The ~6 ns RMS difference and -2.7 ns/yr drift come from:

1. PEPOCH precision: JUG's float64 parsing loses ~122 ns precision
2. F0 precision: JUG's float64 parsing loses ~2.6e-14 Hz
   → This causes a ~15 ns drift over the 6.3-year data span

3. Delay differences: JUG's delays differ by ~8.6 ns (mean), 2.5 ns RMS
   → This is from minor ephemeris/computation differences

SOLUTION: Use PINT's parameter values directly (or parse with higher precision)
or accept this level of agreement as suf