# Time-based residuals (JAX)

Minimal notebook that computes Tempo2/PINT-style time-domain residuals using JAX/JIT for the J1909-3744 par/tim pair.

In [1]:
import sys
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np

# Use local PINT checkout to parse par/tim
sys.path.insert(0, str(Path('../PINT/src').resolve()))
from pint import models
from pint.toa import get_TOAs
from pint import residuals as pint_residuals

SECS_PER_DAY = 86400.0

par_path = Path('/home/mattm/projects/MPTA/partim/production/fifth_pass/J1909-3744.par')
tim_path = Path('/home/mattm/projects/MPTA/partim/production/fifth_pass/J1909-3744.tim')

model_pint = models.get_model(str(par_path), allow_tcb=True)
toas = get_TOAs(str(tim_path), model=model_pint)
print(f"Loaded {len(toas)} TOAs")

# PINT residuals for comparison
pint_res = pint_residuals.Residuals(toas, model_pint)
pint_time_us = pint_res.time_resids.to('us').value
print(f"PINT RMS (us): {np.sqrt(np.mean(pint_time_us**2)):.6f}")
print(f"PINT first 5 (us): {pint_time_us[:5]}")

  warn("PINT does not support 'DILATEFREQ Y'")
  warn("PINT only supports 'TIMEEPH FB90'")
[32m2025-11-27 02:48:42.135[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mget_TOAs[0m:[36m195[0m - [34m[1mUsing EPHEM = DE440 from the given model[0m
[32m2025-11-27 02:48:42.135[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mget_TOAs[0m:[36m211[0m - [34m[1mUsing CLOCK = BIPM2024 from the given model[0m
[32m2025-11-27 02:48:42.135[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mget_TOAs[0m:[36m232[0m - [34m[1mUsing PLANET_SHAPIRO = True from the given model[0m
[32m2025-11-27 02:48:43.135[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36m__init__[0m:[36m1377[0m - [34m[1mNo pulse number flags found in the TOAs[0m
[32m2025-11-27 02:48:43.145[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mapply_clock_corrections[0m:[36m2232[0m - [34m[1mApplying clock corrections (include_bipm = True)[0m
[32m2025-11-27 02:48:43.332[0m | [1mINFO    [0m | 

Loaded 10408 TOAs


[32m2025-11-27 02:48:45.952[0m | [34m[1mDEBUG   [0m | [36mpint.models.absolute_phase[0m:[36mget_TZR_toa[0m:[36m101[0m - [34m[1mCreating and dealing with the single TZR_toa for absolute phase[0m
[32m2025-11-27 02:48:45.953[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36m__init__[0m:[36m1377[0m - [34m[1mNo pulse number flags found in the TOAs[0m
[32m2025-11-27 02:48:45.953[0m | [34m[1mDEBUG   [0m | [36mpint.toa[0m:[36mapply_clock_corrections[0m:[36m2232[0m - [34m[1mApplying clock corrections (include_bipm = True)[0m
[32m2025-11-27 02:48:45.954[0m | [1mINFO    [0m | [36mpint.observatory[0m:[36mgps_correction[0m:[36m230[0m - [1mApplying GPS to UTC clock correction (~few nanoseconds)[0m
[32m2025-11-27 02:48:45.955[0m | [1mINFO    [0m | [36mpint.observatory[0m:[36mbipm_correction[0m:[36m245[0m - [1mApplying TT(TAI) to TT(BIPM2023) clock correction (~27 us)[0m
[32m2025-11-27 02:48:45.955[0m | [1mINFO    [0m | [36mpint.observ

PINT RMS (us): 0.820289
PINT first 5 (us): [-1.95642748 -0.93402091 -0.98108717 -1.19099971 -0.12538321]


In [2]:
# Extract barycentric TDB times and a simple spin model
bary_mjd = np.array([t.value for t in toas.get_mjds(high_precision=True)], dtype=np.float64)
freq_mhz = np.array(toas.table['freq'].to('MHz').value, dtype=np.float64)

class SpinModel:
    def __init__(self, f0, f1, tref_mjd):
        self.f0 = float(f0)
        self.f1 = float(f1)
        self.tref_mjd = float(tref_mjd)

spin_model = SpinModel(
    f0=model_pint.F0.value,
    f1=getattr(model_pint, 'F1', 0.0).value if hasattr(model_pint, 'F1') else 0.0,
    tref_mjd=model_pint.PEPOCH.value if hasattr(model_pint, 'PEPOCH') else model_pint.POSEPOCH.value,
)
spin_model.__dict__

{'f0': 339.31569191904083,
 'f1': -1.6147500686892461e-15,
 'tref_mjd': 59017.9997538705}

In [3]:
@jax.jit
def residuals_time_domain(t_bary_mjd, model: SpinModel):
    """Tempo2/PINT-style time residuals: frac(phase)/F0 at barycentric time."""
    dt_sec = (t_bary_mjd - model.tref_mjd) * SECS_PER_DAY
    phase = model.f0 * dt_sec + 0.5 * model.f1 * dt_sec**2
    frac_phase = jnp.mod(phase + 0.5, 1.0) - 0.5
    return frac_phase / model.f0

t_bary_jax = jnp.array(bary_mjd)
res_jax_s = residuals_time_domain(t_bary_jax, spin_model)
res_jax_us = np.array(res_jax_s) * 1e6

rms_jax = np.sqrt(np.mean(res_jax_us**2))
corr = np.corrcoef(res_jax_us, pint_time_us)[0, 1]

print(f"JAX time-based RMS (us): {rms_jax:.6f}")
print(f"JAX first 5 (us): {res_jax_us[:5]}")
print(f"Correlation vs PINT: {corr:.6f}")

W1127 02:48:46.311410  869611 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1127 02:48:46.312621  869421 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


TypeError: Error interpreting argument to <function residuals_time_domain at 0x7e3bf3e379c0> as an abstract array. The problematic value is of type <class '__main__.SpinModel'> and was passed to the function at path model.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.