In [None]:
# === CONFIGURATION ===
# Pulsar to process; the par/tim file names under DATA_DIR are built from it.
PULSAR_NAME = "J1909-3744"
# NOTE(review): absolute, machine-specific path -- must be edited to run elsewhere.
DATA_DIR = "/home/mattm/projects/MPTA/partim/production/fifth_pass/tdb"
CLOCK_DIR = "data/clock"  # Directory containing clock files

In [None]:
# === IMPORTS ===
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Any, List, Tuple
from bisect import bisect_left

from astropy import units as u
from astropy.time import Time, TimeDelta
from astropy.coordinates import EarthLocation

# Switch JAX to 64-bit mode before any JAX arrays are created.
jax.config.update('jax_enable_x64', True)
print(f"JAX {jax.__version__} ready (float64 enabled)")

In [None]:
# === CONSTANTS ===
SECS_PER_DAY = 86400.0  # seconds per day; converts MJD values to seconds
MJD_JD_OFFSET = 2400000.5  # JD = MJD + 2400000.5
# Dispersion constant: compute_delays multiplies DM / freq**2 by this value.
# Presumably in MHz^2 pc^-1 cm^3 s (frequencies in MHz) -- TODO confirm.
K_DM_SEC = 1.0 / 2.41e-4

# Geocentric X, Y, Z of each observatory. The raw numbers are divided by 1000
# and later passed to EarthLocation with units of km.
OBSERVATORIES = {
    'meerkat': np.array([5109360.133, 2006852.586, -3238948.127]) / 1000,
}

# Par-file keys intended to be kept at better-than-float64 precision.
HIGH_PRECISION_PARAMS = {'F0', 'F1', 'F2', 'PEPOCH', 'TZRMJD'}

In [None]:
# === FILE PARSING ===

def parse_par_file(path: Path, high_precision_keys: Any = None) -> Dict[str, Any]:
    """Parse a par file into a ``{KEY: value}`` dictionary.

    Each non-blank, non-``#`` line is split on whitespace; the first token
    (upper-cased) is the key and the second token is the value. Any further
    tokens (fit flags, uncertainties) are ignored. If a key occurs more than
    once, the last occurrence wins.

    Values are converted as follows:

    * keys listed in ``high_precision_keys`` are parsed straight from the text
      into ``np.longdouble`` so no digits are lost by going through a Python
      float first;
    * other numeric values become Python ``float``;
    * Fortran-style exponents (``1.5D-10``) are accepted for both of the above;
    * anything that is not numeric is kept as the original string.

    Args:
        path: Location of the par file.
        high_precision_keys: Iterable of key names to parse as
            ``np.longdouble``. Defaults to the module-level
            ``HIGH_PRECISION_PARAMS``.

    Returns:
        Dictionary mapping upper-cased parameter names to parsed values.
    """
    if high_precision_keys is None:
        high_precision_keys = HIGH_PRECISION_PARAMS
    precise = {str(name).upper() for name in high_precision_keys}

    params: Dict[str, Any] = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            parts = line.split()
            if len(parts) < 2:
                continue

            key, raw = parts[0].upper(), parts[1]
            convert = np.longdouble if key in precise else float

            # Try the text as written first, then with a Fortran 'D' exponent
            # rewritten as 'E'. Pure strings such as "DD" or "DE440" fail both
            # attempts and are stored untouched.
            candidates = [raw]
            fortran = raw.replace('D', 'E').replace('d', 'e')
            if fortran != raw:
                candidates.append(fortran)

            value: Any = raw
            for text in candidates:
                try:
                    value = convert(text)
                except (ValueError, TypeError):
                    continue
                break
            params[key] = value
    return params


def parse_ra(ra_str: str) -> float:
    """Convert a right ascension written as ``HH:MM:SS.sss`` to radians.

    The string is split on ``:`` and the first three fields are read as
    hours, minutes and seconds; any further fields are ignored.
    """
    fields = ra_str.split(':')
    hours = float(fields[0])
    minutes = float(fields[1])
    seconds = float(fields[2])

    total_hours = hours + minutes/60 + seconds/3600
    # 15 degrees per hour of right ascension, then degrees -> radians.
    return total_hours * 15 * np.pi / 180


def parse_dec(dec_str: str) -> float:
    """Convert a declination written as ``DD:MM:SS.sss`` to radians.

    The sign is read from the text of the degrees field rather than from its
    numeric value, so a declination such as ``-00:30:00`` stays negative.
    """
    fields = dec_str.split(':')

    if fields[0].startswith('-'):
        sign = -1
    else:
        sign = 1

    degrees = abs(float(fields[0]))
    arcmin = float(fields[1])
    arcsec = float(fields[2])

    magnitude_deg = degrees + arcmin/60 + arcsec/3600
    return sign * magnitude_deg * np.pi / 180


@dataclass
class SimpleTOA:
    """Minimal record for one TOA line of a TIM file (built by parse_tim_file)."""
    # Digits of the MJD before the decimal point.
    mjd_int: int
    # Digits of the MJD after the decimal point, as a fraction of a day (0.0 if none).
    mjd_frac: float
    # Value of the fourth whitespace-separated field of the TOA line;
    # presumably the observing frequency in MHz -- TODO confirm against the TIM format in use.
    freq_mhz: float


def parse_tim_file(path: Path) -> List[SimpleTOA]:
    """Read the TOA lines of a TIM file.

    Blank lines, lines starting with one of the recognised comment/command
    prefixes, and lines with fewer than five whitespace-separated fields are
    skipped. For every remaining line the third field is taken as the MJD and
    the fourth as the frequency.

    The MJD text is split at the decimal point so the integer day and the
    day fraction are converted separately instead of through a single float.

    Returns:
        One ``SimpleTOA`` per accepted line, in file order.
    """
    skip_prefixes = ('#', 'FORMAT', 'C ', 'JUMP', 'PHASE', 'MODE', 'INCLUDE')
    records = []

    with open(path) as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text == '' or text.startswith(skip_prefixes):
                continue

            fields = text.split()
            if len(fields) < 5:
                continue

            mjd_token = fields[2]
            if '.' not in mjd_token:
                day = int(mjd_token)
                day_fraction = 0.0
            else:
                whole, decimals = mjd_token.split('.')
                day = int(whole)
                day_fraction = float('0.' + decimals)

            frequency = float(fields[3])
            records.append(SimpleTOA(day, day_fraction, frequency))

    return records


def parse_clock_file(path: Path) -> Dict[str, np.ndarray]:
    """Parse a tempo2-style clock file into parallel MJD / offset arrays.

    Blank lines and lines starting with ``#`` are ignored. For every other
    line the first two whitespace-separated fields are read as the MJD and
    the clock offset; lines with fewer than two fields, or where either field
    is not a number, are skipped as a whole.

    Args:
        path: Location of the clock file.

    Returns:
        ``{'mjd': ndarray, 'offset': ndarray}`` with one entry per accepted
        line, in file order. Both arrays always have the same length.
    """
    mjds, offsets = [], []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            parts = line.split()
            if len(parts) < 2:
                continue
            try:
                # Convert both fields before storing either one, so a bad
                # offset cannot leave an orphan MJD and misalign the arrays.
                mjd, offset = float(parts[0]), float(parts[1])
            except ValueError:
                continue
            mjds.append(mjd)
            offsets.append(offset)
    return {'mjd': np.array(mjds), 'offset': np.array(offsets)}


print("File parsing functions defined")

In [None]:
# === TDB CALCULATION ===

def interpolate_clock(clock_data: Dict, mjd: float) -> float:
    """Linearly interpolate a clock correction table at ``mjd``.

    Args:
        clock_data: Mapping with ``'mjd'`` (increasing sample times) and
            ``'offset'`` (correction at each sample), as returned by
            ``parse_clock_file``.
        mjd: Epoch to evaluate. A scalar gives a scalar result; an array is
            also accepted and gives an array of the same shape.

    Returns:
        The interpolated offset. Epochs before the first or after the last
        sample get the first / last offset (no extrapolation). An empty table
        yields ``0.0`` (or zeros for array input).

    Raises:
        ValueError: If the ``'mjd'`` and ``'offset'`` arrays differ in length.
    """
    mjds = np.asarray(clock_data['mjd'], dtype=np.float64)
    offsets = np.asarray(clock_data['offset'], dtype=np.float64)

    if mjds.size == 0:
        if np.ndim(mjd) == 0:
            return 0.0
        return np.zeros(np.shape(mjd), dtype=np.float64)

    # np.interp clamps to the end values outside the table, which is exactly
    # the behaviour the explicit boundary checks used to provide.
    return np.interp(mjd, mjds, offsets)


def compute_tdb(mjd_ints: np.ndarray, mjd_fracs: np.ndarray,
                mk_clock: Dict, gps_clock: Dict, bipm_clock: Dict,
                location: EarthLocation) -> np.ndarray:
    """Compute TDB from UTC using standalone clock chain.

    Args:
        mjd_ints: Integer part of each UTC MJD.
        mjd_fracs: Fractional part of each UTC MJD (same length).
        mk_clock: Clock table (``'mjd'`` / ``'offset'`` arrays) applied first.
        gps_clock: Clock table applied second.
        bipm_clock: Clock table whose interpolated value, minus 32.184 s, is
            applied third. Presumably the file tabulates TT(BIPM) - TAI, so
            the subtraction removes the constant TT - TAI offset that astropy
            adds itself -- TODO confirm against the clock file header.
        location: Observatory position attached to the astropy ``Time``.

    Returns:
        ``np.longdouble`` array of TDB MJDs, assembled from astropy's two-part
        Julian date so the result is not rounded to a single float64.

    Raises:
        ValueError: From ``np.interp`` if ``bipm_clock`` is empty.
    """
    day_int = np.asarray(mjd_ints, dtype=np.float64)
    day_frac = np.asarray(mjd_fracs, dtype=np.float64)
    mjd_vals = day_int + day_frac

    def _table_correction(clock: Dict) -> np.ndarray:
        # Vectorised lookup; an empty table contributes no correction.
        if len(clock['mjd']) == 0:
            return np.zeros_like(mjd_vals)
        return np.interp(mjd_vals, clock['mjd'], clock['offset'])

    # Clock corrections (seconds)
    mk_corrs = _table_correction(mk_clock)
    gps_corrs = _table_correction(gps_clock)
    bipm_corrs = np.interp(mjd_vals, bipm_clock['mjd'], bipm_clock['offset']) - 32.184
    total_corrs = mk_corrs + gps_corrs + bipm_corrs

    # 'pulsar_mjd' is not a built-in astropy format (it is registered by
    # external packages such as PINT). Use it when present, otherwise fall
    # back to plain 'mjd'; the two differ only on days with a leap second.
    time_format = 'pulsar_mjd' if 'pulsar_mjd' in Time.FORMATS else 'mjd'

    # Apply corrections and convert to TDB
    time_utc = Time(val=day_int, val2=day_frac,
                    format=time_format, scale='utc',
                    location=location, precision=9)
    time_utc = time_utc + TimeDelta(total_corrs, format='sec')

    # Time.mjd would collapse (jd1, jd2) into one float64; combine the two
    # parts in extended precision instead.
    tdb = time_utc.tdb
    jd1 = np.asarray(tdb.jd1, dtype=np.longdouble)
    jd2 = np.asarray(tdb.jd2, dtype=np.longdouble)
    return (jd1 - np.longdouble(MJD_JD_OFFSET)) + jd2


print("TDB calculation functions defined")

In [None]:
# === JAX DELAY FUNCTIONS ===

@jax.jit
def compute_delays(tdbld, freq_bary, dm_coeffs, dm_factorials, dm_epoch):
    """Compute DM delay (simplified - add more delays as needed).

    The DM at each epoch is a polynomial in the time since ``dm_epoch``
    (in units of 365.25 days): coefficient ``k`` is multiplied by
    ``dt**k`` and divided by ``dm_factorials[k]``. The delay is that DM
    times ``K_DM_SEC`` divided by the squared frequency.

    Args:
        tdbld: 1-D array of epochs (it is indexed as ``[:, newaxis]``).
        freq_bary: Frequencies, broadcast against the per-epoch DM.
        dm_coeffs: 1-D array of DM polynomial coefficients.
        dm_factorials: Divisors matching ``dm_coeffs`` element by element.
        dm_epoch: Reference epoch subtracted from ``tdbld``.

    Returns:
        Array with one delay per epoch.
    """
    elapsed_years = (tdbld - dm_epoch) / 365.25
    exponents = jnp.arange(len(dm_coeffs))

    # Rows are epochs, columns are polynomial orders.
    elapsed_pow = elapsed_years[:, None] ** exponents[None, :]
    taylor_terms = dm_coeffs * elapsed_pow / dm_factorials
    dm_at_epoch = jnp.sum(taylor_terms, axis=1)

    return K_DM_SEC * dm_at_epoch / (freq_bary ** 2)


print("JAX delay functions defined")

In [None]:
# === RESIDUAL CALCULATOR ===

class JUGResidualCalculator:
    """Fast, independent pulsar timing residual calculator.

    The constructor converts the TOAs to TDB, caches everything needed for
    the phase calculation and evaluates the reference (TZR) phase once.
    ``compute_residuals`` then evaluates the spin phase at every TOA.

    Precision notes:
        * Spin parameters, PEPOCH and the TDB epochs are held as
          ``np.longdouble``; only the DM delay goes through JAX, which has no
          extended-precision dtype, so its inputs are explicitly float64.
        * ``TZRMJD`` is used as-is as a TDB MJD; it is not passed through the
          clock chain.
        * NOTE(review): the TZR delay is evaluated at ``freq_mhz[0]`` rather
          than at a TZR-specific frequency from the par file -- confirm this
          is intended.

    Attributes:
        F0, F1, F2: Spin frequency and derivatives (``np.longdouble``).
        PEPOCH_sec: PEPOCH expressed in seconds (MJD * 86400).
        tdbld_mjd: TDB MJD of every TOA (``np.longdouble`` array).
        l_hat: Unit vector built from RAJ/DECJ; stored but not used by
            ``compute_residuals``.
        tzr_phase: Spin phase at the TZR epoch.
    """

    def __init__(self, par_params: Dict[str, Any], utc_mjd: np.ndarray,
                 freq_mhz: np.ndarray, location: EarthLocation,
                 clock_data: Tuple[Dict, Dict, Dict]):
        """Prepare all cached quantities.

        Args:
            par_params: Parsed par file; must contain F0, PEPOCH, RAJ, DECJ
                and TZRMJD. F1, F2, DM and DMEPOCH are optional.
            utc_mjd: UTC MJD of each TOA.
            freq_mhz: Frequency of each TOA, same length as ``utc_mjd``.
            location: Observatory location passed on to ``compute_tdb``.
            clock_data: ``(mk_clock, gps_clock, bipm_clock)`` tables.

        Raises:
            ValueError: If TZRMJD is missing from ``par_params``.
            KeyError: If another required parameter is missing.
        """
        # Fail before the comparatively expensive TDB conversion.
        if 'TZRMJD' not in par_params:
            raise ValueError("TZRMJD parameter required in par file")

        # Timing parameters
        self.F0 = np.longdouble(par_params['F0'])
        self.F1 = np.longdouble(par_params.get('F1', 0.0))
        self.F2 = np.longdouble(par_params.get('F2', 0.0))
        pepoch_mjd = np.longdouble(par_params['PEPOCH'])
        self.PEPOCH_sec = pepoch_mjd * np.longdouble(SECS_PER_DAY)

        # Compute TDB
        utc_mjd_int = np.floor(utc_mjd).astype(int)
        utc_mjd_frac = utc_mjd - utc_mjd_int
        mk_clock, gps_clock, bipm_clock = clock_data
        tdb_mjd = compute_tdb(utc_mjd_int, utc_mjd_frac, mk_clock, gps_clock, bipm_clock, location)

        # Store as longdouble MJD
        self.tdbld_mjd = np.array(tdb_mjd, dtype=np.longdouble)

        # Pulsar direction
        ra_rad = parse_ra(par_params['RAJ'])
        dec_rad = parse_dec(par_params['DECJ'])
        self.l_hat = np.array([
            np.cos(dec_rad) * np.cos(ra_rad),
            np.cos(dec_rad) * np.sin(ra_rad),
            np.sin(dec_rad)
        ])

        # DM parameters. Everything handed to JAX is converted to float64
        # explicitly: longdouble values (e.g. the PEPOCH fallback for
        # DMEPOCH) are not a dtype JAX can hold.
        dm = par_params.get('DM', 0.0)
        self.dm_coeffs = jnp.array([float(dm)], dtype=jnp.float64)
        self.dm_factorials = jnp.array([1.0], dtype=jnp.float64)
        self.dm_epoch = jnp.array(float(par_params.get('DMEPOCH', pepoch_mjd)),
                                  dtype=jnp.float64)

        # Convert to JAX arrays
        self.tdbld_jax = jnp.array(self.tdbld_mjd, dtype=jnp.float64)
        self.freq_jax = jnp.array(freq_mhz, dtype=jnp.float64)

        # TZR calculation: delay at the reference epoch, then reference phase.
        TZRMJD_TDB = np.longdouble(par_params['TZRMJD'])
        tzr_delay = float(compute_delays(
            jnp.array([float(TZRMJD_TDB)], dtype=jnp.float64),
            jnp.array([freq_mhz[0]], dtype=jnp.float64),
            self.dm_coeffs, self.dm_factorials, self.dm_epoch
        )[0])

        tzr_dt_sec = TZRMJD_TDB * np.longdouble(SECS_PER_DAY) - self.PEPOCH_sec - np.longdouble(tzr_delay)
        self.tzr_phase = self.F0 * tzr_dt_sec + (self.F1/2) * tzr_dt_sec**2 + (self.F2/6) * tzr_dt_sec**3

        # Warmup JAX: run the delay function once on the real arrays so the
        # first compute_residuals() call does not pay for it.
        _ = compute_delays(self.tdbld_jax, self.freq_jax,
                          self.dm_coeffs, self.dm_factorials, self.dm_epoch).block_until_ready()

    def compute_residuals(self) -> np.ndarray:
        """Compute timing residuals in microseconds.

        The phase relative to the TZR phase is wrapped into [-0.5, 0.5)
        cycles and divided by F0 to give a time.

        Returns:
            float64 array with one residual per TOA.
        """
        # Compute delays
        delay_jax = compute_delays(self.tdbld_jax, self.freq_jax,
                                   self.dm_coeffs, self.dm_factorials, self.dm_epoch)
        delay_ld = np.asarray(delay_jax, dtype=np.longdouble)

        # Convert TDB to seconds in longdouble and remove the delay
        tdbld_sec = self.tdbld_mjd * np.longdouble(SECS_PER_DAY)
        dt_sec = tdbld_sec - self.PEPOCH_sec - delay_ld

        # Spin phase: Taylor expansion about PEPOCH
        phase = self.F0 * dt_sec + (self.F1/2) * dt_sec**2 + (self.F2/6) * dt_sec**3

        # Wrap to the nearest pulse and convert cycles -> microseconds
        frac_phase = np.mod(phase - self.tzr_phase + 0.5, 1.0) - 0.5
        residuals_us = np.asarray(frac_phase * 1e6 / self.F0, dtype=np.float64)

        return residuals_us


print("JUG residual calculator defined")

In [None]:
# === LOAD DATA ===

print("Loading data...")

# Parse files
par_file = Path(DATA_DIR) / f"{PULSAR_NAME}_tdb.par"
tim_file = Path(DATA_DIR) / f"{PULSAR_NAME}.tim"

par_params = parse_par_file(par_file)
toas = parse_tim_file(tim_file)
if not toas:
    # Without this, the failure surfaces much later as an opaque IndexError.
    raise ValueError(f"No TOAs found in {tim_file}")

utc_mjd = np.array([toa.mjd_int + toa.mjd_frac for toa in toas], dtype=np.float64)
freq_mhz = np.array([toa.freq_mhz for toa in toas], dtype=np.float64)

print(f"Loaded {len(toas)} TOAs for {par_params.get('PSRJ', PULSAR_NAME)}")

# Load clock files
# The BIPM file is cached in ./clock_files; the other two are read from CLOCK_DIR.
clock_files_dir = Path('clock_files')
clock_files_dir.mkdir(exist_ok=True)

bipm_path = clock_files_dir / 'tai2tt_bipm2024.clk'
if not bipm_path.exists():
    import urllib.request
    url = 'https://raw.githubusercontent.com/ipta/pulsar-clock-corrections/main/T2runtime/clock/tai2tt_bipm2024.clk'
    print(f"Downloading BIPM2024 clock file...")
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer cannot leave a truncated file that later runs
    # would treat as a valid cache.
    partial_path = bipm_path.with_name(bipm_path.name + '.part')
    try:
        urllib.request.urlretrieve(url, partial_path)
        partial_path.replace(bipm_path)
    finally:
        if partial_path.exists():
            partial_path.unlink()

bipm_clock = parse_clock_file(bipm_path)
mk_clock = parse_clock_file(Path(CLOCK_DIR) / 'mk2utc.clk')
gps_clock = parse_clock_file(Path(CLOCK_DIR) / 'gps2utc.clk')

# Observatory location
obs_itrf = OBSERVATORIES['meerkat']
location = EarthLocation.from_geocentric(obs_itrf[0]*u.km, obs_itrf[1]*u.km, obs_itrf[2]*u.km)

print("Data loaded successfully")

In [None]:
# === COMPUTE RESIDUALS ===

import time

# Time the constructor separately: it does the TDB conversion and the
# first (warm-up) call of the delay function.
print("Initializing calculator...")
t_start = time.perf_counter()

calc = JUGResidualCalculator(
    par_params=par_params,
    utc_mjd=utc_mjd,
    freq_mhz=freq_mhz,
    location=location,
    clock_data=(mk_clock, gps_clock, bipm_clock)
)

print(f"Initialization: {time.perf_counter() - t_start:.3f} s")

# Time the residual evaluation on its own.
print("\nComputing residuals...")
t_start = time.perf_counter()
residuals = calc.compute_residuals()
print(f"Computation: {(time.perf_counter() - t_start)*1000:.3f} ms")

# Note: the figure labelled "RMS" is np.std, i.e. the scatter about the mean residual.
print(f"\nResiduals RMS: {np.std(residuals):.3f} µs")
print(f"Residuals range: [{np.min(residuals):.3f}, {np.max(residuals):.3f}] µs")

In [None]:
# === PLOT RESIDUALS ===

import matplotlib.pyplot as plt

# Two panels side by side: residual vs. TOA index, and a histogram.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: residuals in the order the TOAs were read (x axis is the array index, not time).
ax1.plot(residuals, 'b.', alpha=0.5, markersize=3)
ax1.set_xlabel('TOA index')
ax1.set_ylabel('Residual (µs)')
ax1.set_title(f'Timing Residuals (RMS: {np.std(residuals):.3f} µs)')
ax1.grid(True, alpha=0.3)

# Right panel: distribution of the same residuals in 50 bins.
ax2.hist(residuals, bins=50, alpha=0.7, edgecolor='black')
ax2.set_xlabel('Residual (µs)')
ax2.set_ylabel('Count')
ax2.set_title('Residual Distribution')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()