# $\phi(\rho)$

This step defines a standalone `get_phi` function, so that we do not need to import the `enterprise` package.

In [5]:
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import scipy.constants as sc

###### Physical time constants taken from scipy.constants:
######   yr  = seconds per Julian year, day = seconds per day, FYR = 1/yr (in Hz)
yr, day = sc.Julian_year, sc.day
FYR = 1.0 / yr

def _unitvec_from_model_entry(entry):
    """
    entry: pta_model[name] dict that contains at least 'phi' and 'theta' (in radians).
    Convention used here is the standard spherical one:
      x = sin(theta) * cos(phi), y = sin(theta) * sin(phi), z = cos(theta).
    """
    phi   = float(entry['phi'])
    theta = float(entry['theta'])

    th = theta
    ph = phi
    c, s = np.cos, np.sin
    return np.array([s(th)*c(ph), s(th)*s(ph), c(th)], dtype=float)

# ---------------------- Hellings–Downs ORF ----------------------
def hellings_downs_orf(unitvecs: List[np.ndarray]) -> np.ndarray:
    """
    Assemble the symmetric Hellings–Downs overlap-reduction matrix Γ.

    unitvecs: one direction vector per pulsar (length P).
    Returns a (P, P) float array with ones on the diagonal. For each pair a < b:
        cos ζ_ab = u_a · u_b (clipped to [-1, 1]),  x = (1 - cos ζ_ab) / 2,
        Γ_ab = Γ_ba = (3/2) x ln x - x/4 + 1/2,
    except that x <= 0 (coincident directions) yields 1/2, which avoids
    evaluating log(0).
    """
    n_psr = len(unitvecs)
    orf = np.eye(n_psr, dtype=float)

    # Visit every strictly-upper-triangular pair once and mirror the value.
    upper_rows, upper_cols = np.triu_indices(n_psr, k=1)
    for i, j in zip(upper_rows, upper_cols):
        cos_zeta = float(np.clip(np.dot(unitvecs[i], unitvecs[j]), -1.0, 1.0))
        x = 0.5 * (1.0 - cos_zeta)
        if x <= 0:
            gamma_ij = 0.5
        else:
            gamma_ij = 1.5*x*np.log(x) - 0.25*x + 0.5
        orf[i, j] = gamma_ij
        orf[j, i] = gamma_ij
    return orf

# ---------------------- PSD → discrete prior variances ----------------------
def _rho_flat_tail(log10_A, gamma, f_Hz, df_Hz, log10_kappa=None):
    """
    Discrete prior variances for a broken power law with an optional flat floor.

    Written with plain floats/arrays (no astropy objects). Intended as a
    stand-in for enterprise's flat-tail power law; that equivalence is not
    checked in this function.

      PSD(f) = (A^2 / (12π^2)) * FYR^(-3) * (fc/FYR)^(-γ) * [1 + (f/fc)^2]^(-γ/2)   [s^3]
      rho    = max(PSD, flat_psd) * df_Hz                                              [s^2]

    with fc = 10^(-1.879181246047625) * FYR and flat_psd = 10^(2*log10_kappa) * FYR^(-3).
    When log10_kappa is None the floor is skipped. 'f_Hz' and 'df_Hz' are
    expected to be equal-length arrays (each Fourier frequency duplicated for
    its cosine/sine pair); the result is returned via .astype(float).
    """
    log10_fc_over_fyr = -1.879181246047625
    amplitude = 10.0**log10_A
    slope = float(gamma)
    corner_Hz = (10.0**log10_fc_over_fyr) * FYR        # turnover frequency in Hz

    # Power spectral density in s^3: amplitude prefactor times the roll-off shape.
    prefactor = (amplitude*amplitude) / (12.0*np.pi**2) * (FYR**(-3.0)) * ((10.0**log10_fc_over_fyr)**(-slope))
    rolloff = (1.0 + (f_Hz/corner_Hz)**2.0)**(-0.5*slope)
    psd = prefactor * rolloff

    # Optional floor applied at the PSD level, before discretisation.
    if log10_kappa is not None:
        floor = (10.0**(2.0*log10_kappa)) * (FYR**(-3.0))
        psd = np.maximum(psd, floor)

    # Multiply by the frequency bin width (Hz) to get a variance per coefficient.
    variances = psd * df_Hz
    return variances.astype(float)


# ------------ PTA_Lite reproduces enterprise's φ structure at H0/H1 ------------
@dataclass
class PTA_Lite:
    """
    Lightweight builder of the Fourier-coefficient prior covariance φ, laid out
    the way the original author describes enterprise's construction:
      - Per pulsar: 2K Fourier coefficients (cos/sin pairing for K positive freqs).
      - RN (intrinsic red noise): uses all K freqs → 2K coeffs.
      - GW (common red process): only first G freqs are active → 2G coeffs, then zero-pad to 2K.
      - H0 (mode='curn'): returns a list of per-pulsar diagonal vectors = ρ_RN + ρ_GW_padded.
      - H1 (mode='hd')  : returns the full matrix Φ_C = I⊗diag(ρ_RN) + Γ⊗diag(ρ_GW_padded).
    Any other 'mode' value takes the full-matrix path with Γ = I.

    Each pta_model[name] entry must provide 'Tspan' (multiplied by seconds per
    Julian year here, so presumably given in years — confirm with the producer
    of pta_model) and 'nfrequencies'. 'phi' and 'theta' (radians) are only
    consulted when mode='hd'.

    Attributes set in __post_init__: pulsars, Tspan, f_rep, df_rep, f_rep_gw,
    df_rep_gw, Gamma. 'components' is overwritten with the model's
    'nfrequencies', and 'gw_components' is normalised to int.
    """
    pta_model: Dict[str, Dict]

    components: int = 10         # K; replaced by pta_model's 'nfrequencies' in __post_init__
    gw_components: int = 5       # G (how many GW freqs are nonzero); must satisfy 0 <= G <= K
    rn_name: str = "red_noise"
    mode: str = "curn"           # 'curn' (no cross-pulsar corr) or 'hd' (HD-ORF)

    def __post_init__(self):
        """
        Freeze the pulsar order and precompute the frequency grids and ORF.

        Raises ValueError if pta_model is empty, if pulsars disagree on
        'nfrequencies', or if gw_components lies outside [0, K].
        """
        if not self.pta_model:
            raise ValueError("pta_model must contain at least one pulsar.")

        ###### Fix pulsar order and compute a unified Tspan (seconds)
        self.pulsars = list(self.pta_model.keys())        # e.g. ['J0613-0200', 'J1713+0747', ...]
        entries = [self.pta_model[name] for name in self.pulsars]
        self.Tspan = float(max(d['Tspan'] for d in entries)) * sc.Julian_year

        ###### Ensure all pulsars share the same number of Fourier freqs K
        K = int(entries[0]['nfrequencies'])
        if not all(int(d['nfrequencies']) == K for d in entries):
            raise ValueError("All pulsars must share the same nfrequencies (K).")
        self.components = K

        ###### Validate G up front: G > K would otherwise surface later in get_phi as an
        ###### opaque negative-dimension error when padding the GW spectrum.
        g = int(self.gw_components)
        if not 0 <= g <= K:
            raise ValueError(
                f"gw_components ({g}) must be between 0 and nfrequencies ({K})."
            )
        self.gw_components = g

        ###### Build the frequency grid and step sizes in Hz, then duplicate each for cos/sin → length 2K
        base    = np.arange(1, K + 1, dtype=float) / self.Tspan          # ν_k = k / Tspan
        df_base = np.diff(np.concatenate(([0.0], base)))                  # Δν_k = ν_k - ν_{k-1}
        self.f_rep  = np.repeat(base,    2)  # [f1,f1, f2,f2, ..., fK,fK]
        self.df_rep = np.repeat(df_base, 2)  # [Δf1,Δf1, ..., ΔfK,ΔfK]

        ###### Select the first G freqs (2G coeffs) for the common GW process
        self.f_rep_gw  = self.f_rep[:2*g]
        self.df_rep_gw = self.df_rep[:2*g]

        ###### ORF Γ: Hellings–Downs for 'hd' when sky positions exist, identity otherwise
        self.Gamma = self._build_orf()

    def _build_orf(self) -> np.ndarray:
        """Return Γ: the HD matrix for mode='hd' with full sky positions, else the identity."""
        n_psr = len(self.pulsars)
        if self.mode != "hd":
            # Γ is not used on the 'curn' path, so sky positions are not required.
            return np.eye(n_psr, dtype=float)

        entries = [self.pta_model[name] for name in self.pulsars]
        have_positions = all(
            d.get('phi') is not None and d.get('theta') is not None for d in entries
        )
        if not have_positions:
            # Check the keys before converting, so the documented fallback is
            # actually reachable instead of failing with a KeyError.
            print("[PTA_Lite] Warning: missing/ambiguous sky positions; Γ=I.")
            return np.eye(n_psr, dtype=float)
        return hellings_downs_orf([_unitvec_from_model_entry(d) for d in entries])

    def get_phi(self, par_dict: Dict[str, float]):
        """
        Return φ_N (H0, list of per-psr diagonals) when mode='curn';
        Return φ_C (H1, full matrix) otherwise (Γ is Hellings–Downs for mode='hd').
        Naming convention for per-psr RN parameters in 'par_dict' (all three required):
          <psrname>_<rn_name>_log10_A,  <psrname>_<rn_name>_gamma,  <psrname>_<rn_name>_log10_kappa
        Global GW parameters: 'gw_log10_A', 'gw_gamma'.

        Raises KeyError when a GW or per-pulsar RN parameter is absent.
        """
        ###### Common GW spectrum (shared A, γ), with NO flat floor (log10_kappa=None)
        try:
            gw_log10_A = float(par_dict["gw_log10_A"])
            gw_gamma   = float(par_dict["gw_gamma"])
        except KeyError as e:
            raise KeyError(f"missing {e} (need 'gw_log10_A' & 'gw_gamma')") from e
        rho_gw_head = _rho_flat_tail(gw_log10_A, gw_gamma, self.f_rep_gw, self.df_rep_gw, None)  # (2G,)
        # Pad from the actual array sizes so the result is always length 2K.
        pad_len = self.f_rep.size - rho_gw_head.size
        rho_gw_padded = np.concatenate([rho_gw_head, np.zeros(pad_len, dtype=float)])  # (2K,)

        ###### Per-pulsar RN: each has its own (A, γ, κ)
        rho_rn_all = []
        for name in self.pulsars:
            prefix = f"{name}_{self.rn_name}"
            keys = (f"{prefix}_log10_A", f"{prefix}_gamma", f"{prefix}_log10_kappa")
            missing = [k for k in keys if k not in par_dict]
            if missing:
                # Report exactly what is absent; all three keys are mandatory.
                raise KeyError(f"missing RN keys for {name}: {missing}")
            rn_log10_A, rn_gamma, rn_log10_kappa = (float(par_dict[k]) for k in keys)
            rho_rn_all.append(
                _rho_flat_tail(rn_log10_A, rn_gamma, self.f_rep, self.df_rep, rn_log10_kappa)  # (2K,)
            )

        ###### H0 (curn): return a list (length P) of per-psr diagonals (each length 2K)
        if self.mode == "curn":
            return [rho_rn + rho_gw_padded for rho_rn in rho_rn_all]

        ###### H1 (hd): Φ_C = I⊗diag(ρ_RN) + Γ⊗diag(ρ_GW_padded)
        # Dense per-pulsar blocks are only materialised on this path.
        phi = _block_diag(*[np.diag(rho_rn) for rho_rn in rho_rn_all])   # I ⊗ diag(ρ_RN)
        phi += np.kron(self.Gamma, np.diag(rho_gw_padded))                # Γ ⊗ diag(ρ_GW_padded)
        return phi

def _block_diag(*mats):
    """
    Manually assemble a block-diagonal matrix from a list of square matrices.
    Equivalent to scipy.linalg.block_diag but without the dependency.
    """
    if not mats: return np.zeros((0,0), float)
    r = sum(M.shape[0] for M in mats); c = sum(M.shape[1] for M in mats)
    out = np.zeros((r,c), float); i=j=0
    for M in mats:
        rr,cc = M.shape; out[i:i+rr, j:j+cc] = M; i+=rr; j+=cc
    return out

In [6]:
# Build both prior covariances for one parameter sample.
# NOTE(review): `pta_model` and `big_array` are defined outside this cell;
# `big_array[-1]` is presumably a dict of parameter values keyed as get_phi
# expects — confirm against the cell that creates it.
K = next(iter(pta_model.values()))['nfrequencies']   # number of Fourier freqs, read from the first pulsar
P = len(pta_model)                                   # number of pulsars
par_dict= big_array[-1]                              # last entry of big_array
# PTA_Lite.__post_init__ resets `components` from pta_model's 'nfrequencies',
# so passing components=K is redundant but harmless.
pta_h0 = PTA_Lite(pta_model, mode='curn', components=K, gw_components=5, rn_name='red_noise')
pta_h1 = PTA_Lite(pta_model, mode='hd',   components=K, gw_components=5, rn_name='red_noise')

phiN_list = pta_h0.get_phi(par_dict)     # List length P, each (2K,)
phiC      = pta_h1.get_phi(par_dict)     # (2K*P, 2K*P)