In [1]:
import jax 
import jax.numpy as jnp
import math
from jax import jit, vmap
from functools import partial, lru_cache
from src.models.vsh_model import*
from jax import random
import pandas as pd
from iminuit import Minuit # to perform least square
from src.models.configuration import*
from numpyro.infer import MCMC, NUTS
import numpyro
import numpyro.distributions as dist
from src.data.data_utils import*

  from .autonotebook import tqdm as notebook_tqdm


# Plan 

1. Initial Least Square Fit $\rightarrow$ Fit the VSH model using all data.
2. Compute $X^2$ Values $\rightarrow$ For each QSO
3. Compute the Median $\rightarrow$ Compute the median of all $X$ values.
4. Reject Outliers $\rightarrow$ If $X>\kappa \times \text{ median}(X)$, mark the source as outlier.
5. Refit $\rightarrow$ Exclude outliers and re-run the fit.
6. Iterate $\rightarrow$ Repeat steps 2-5 until convergence (i.e. the set of outliers stops changing).

## Procedure In EDR3
1. Perform Least Square $\rightarrow$ estimate VSH coefficients
2. $\kappa$-Clipping $\rightarrow$ remove outliers
3. Perform Bootstrapping $\rightarrow$ quantify the uncertainty of results

## Our approach
Keep steps 1 and 2 from EDR3, but instead of bootstrapping:
1. Perform Least Square $+$ $\kappa$-Clipping $\rightarrow$ filter the data
2. Perform HMC (Bayesian inference) sampling on the filtered dataset $\rightarrow$ obtain posterior samples, which give the VSH coefficient estimates and their uncertainties.

In [2]:
# Load data
# `load_qso_dataframe` is a project helper; presumably it returns the QSO
# catalogue as a pandas DataFrame (it is row-selected with `.loc` further down).
df = load_qso_dataframe()
# Split the table into the three containers used by every later cell.
# They are indexed below as angles -> (ra, dec), obs -> (pmra, pmdec) and
# error -> (sigma_pmra, sigma_pmdec, rho).
angles, obs, error = config_data(df)


In [4]:
lmax = 1
total_params = count_vsh_coeffs(lmax)
limits = vsh_minuit_limits(lmax=lmax, t_bound=0.01, s_bound=0.002)

# Every coefficient starts at zero. Layout of the flat vector:
# [t10, ..., t_lmaxm, s10, ..., s_lmaxm]
theta_init = jnp.zeros(total_params)


def least_square_wrapper(*theta_flat):
    """Adapt `least_square` to Minuit, which passes one scalar per parameter.

    The scalars are packed back into a single coefficient vector; the data
    (`angles`, `obs`, `error`) and `lmax` are read from the notebook namespace.
    """
    packed = jnp.array(theta_flat)
    return least_square(angles, obs, error, packed, lmax=lmax, grid=False)


m = Minuit(least_square_wrapper, *theta_init)
m.errordef = Minuit.LEAST_SQUARES

# Apply the box constraints; `limits` is keyed by Minuit's parameter names.
for name in m.parameters:
    m.limits[name] = limits[name]

# NOTE(review): the saved output of this cell reports an invalid minimum with
# x1, x2 and x5 sitting on their limits -- the bounds may be too tight here.
m.migrad()
print(m.params)
print('Converged to miminum?', m.fmin.is_valid)

┌───┬──────┬───────────┬───────────┬────────────┬────────────┬─────────┬─────────┬───────┐
│   │ Name │   Value   │ Hesse Err │ Minos Err- │ Minos Err+ │ Limit-  │ Limit+  │ Fixed │
├───┼──────┼───────────┼───────────┼────────────┼────────────┼─────────┼─────────┼───────┤
│ 0 │ x0   │  -3.1e-3  │  0.9e-3   │            │            │  -0.01  │  0.01   │       │
│ 1 │ x1   │  -2.0e-3  │  0.1e-3   │            │            │ -0.002  │  0.002  │       │
│ 2 │ x2   │ -10.0e-3  │  0.7e-3   │            │            │  -0.01  │  0.01   │       │
│ 3 │ x3   │  5.1e-3   │  1.0e-3   │            │            │  -0.01  │  0.01   │       │
│ 4 │ x4   │  -0.0005  │  0.0010   │            │            │ -0.002  │  0.002  │       │
│ 5 │ x5   │ -2.000e-3 │ 0.034e-3  │            │            │ -0.002  │  0.002  │       │
└───┴──────┴───────────┴───────────┴────────────┴────────────┴─────────┴─────────┴───────┘
Converged to miminum? False


In [5]:


def compute_X2(alpha, delta, mu_a_obs, mu_d_obs, s_mu_a, s_mu_d, rho, theta, lmax):
    """Return the X^2 statistic of the proper-motion residual for each source.

    For one source the residual D = (mu_a - V_alpha, mu_d - V_delta) between
    the observed values and the model is contracted with the inverse of the
    2x2 matrix A = [[s_a^2, rho*s_a*s_d], [rho*s_a*s_d, s_d^2]]:

        X^2 = D^T A^{-1} D

    Parameters
    ----------
    alpha, delta : positions handed to `basis_vectors` and `model_vsh`.
    mu_a_obs, mu_d_obs : observed proper-motion components.
    s_mu_a, s_mu_d, rho : uncertainties of the two components and their
        correlation coefficient, used to build A.
    theta : coefficient vector forwarded unchanged to `model_vsh`.
    lmax : maximum degree forwarded unchanged to `model_vsh`.

    Returns
    -------
    Array of X^2 values, one per (broadcast) input element.
    """

    def _x2_single(a, d, pm_a, pm_d, sig_a, sig_d, corr):
        # Model vector at this position, projected on the two basis vectors
        # (real part of the conjugated dot product).
        unit_a, unit_d = basis_vectors(a, d)
        field = model_vsh(a, d, theta, lmax, grid=False)
        model_a = jnp.vdot(field, unit_a).real
        model_d = jnp.vdot(field, unit_d).real

        # 2x2 matrix of the two components, off-diagonal term shared.
        cross = corr * sig_a * sig_d
        cov = jnp.array([
            [sig_a**2, cross],
            [cross, sig_d**2],
        ])

        resid = jnp.array([pm_a - model_a, pm_d - model_d])
        # resid is 1-D, so no transpose is needed for the quadratic form.
        return resid @ jnp.linalg.inv(cov) @ resid

    # All seven inputs are scalars per call; vectorize maps over the arrays.
    vectorised = jnp.vectorize(_x2_single, signature='(),(),(),(),(),(),()->()')
    return vectorised(alpha, delta, mu_a_obs, mu_d_obs, s_mu_a, s_mu_d, rho)


In [7]:
# Unpack the data containers into one array per quantity.
ra, dec = angles
pmra_obs, pmdec_obs = obs
s_pmra, s_pmdec, rho = error

# Best-fit values of the Minuit fit above, in Minuit's parameter order.
coeff = jnp.array([m.values[name] for name in m.parameters])

In [8]:
# X^2 of every source under the initial (unclipped) fit. `lmax` is taken from
# the notebook variable set in the fit cell rather than hard-coded, so it stays
# consistent with the degree that `coeff` was fitted at.
X2 = compute_X2(ra, dec, pmra_obs, pmdec_obs, s_pmra, s_pmdec, rho, coeff, lmax=lmax)

In [13]:
# Sample size, median, maximum and minimum of X^2, one value per line.
print(len(X2), jnp.median(X2), jnp.max(X2), jnp.min(X2), sep="\n")

1215942
1.5344973
25.034359
7.8481844e-07


In [16]:
def robust_least_squares_fit(angles, obs, error, theta_init, lmax, t_bound, s_bound, kappa=3.0, max_iter=10):
    """Fit the VSH coefficients by least squares with iterative kappa-clipping.

    Each iteration
      1. minimises `least_square` with Minuit on the currently kept sources,
         starting from the previous solution,
      2. evaluates X^2 for *all* sources with the new coefficients, and
      3. keeps the sources with X^2 < kappa^2 * median(X^2).
    The loop stops as soon as two consecutive iterations give the same mask.

    Parameters
    ----------
    angles : pair (alpha, delta) of per-source arrays.
    obs : pair (mu_a_obs, mu_d_obs) of per-source arrays.
    error : triple (s_mu_a, s_mu_d, rho) of per-source arrays.
    theta_init : initial flat coefficient vector, one entry per Minuit parameter.
    lmax : maximum degree, forwarded to `least_square` and `compute_X2`.
    t_bound, s_bound : bounds forwarded to `vsh_minuit_limits`.
    kappa : clipping factor.
    max_iter : maximum number of fit/clip iterations.

    Returns
    -------
    theta : coefficients of the last fit.
    keep : boolean mask over all input sources, True for sources that passed
        the last clip.

    Warns
    -----
    RuntimeWarning
        If Minuit reports an invalid minimum in some iteration, or if the mask
        has not stabilised after `max_iter` iterations.  In the latter case
        `theta` was fitted on the previous mask, not on the returned `keep`.

    Raises
    ------
    ValueError
        If a clipping step rejects every source (e.g. the median of X^2 is NaN),
        since the next fit would run on an empty sample.
    """
    import warnings  # stdlib; local so the function does not rely on the import cell

    limits = vsh_minuit_limits(lmax=lmax, t_bound=t_bound, s_bound=s_bound)

    alpha, delta = angles
    mu_a_obs, mu_d_obs = obs
    s_mu_a, s_mu_d, rho = error

    # Start with every source kept.
    keep = jnp.ones_like(alpha, dtype=bool)
    theta = theta_init
    prev_keep = None

    for iteration in range(max_iter):
        angles_k = (alpha[keep], delta[keep])
        obs_k = (mu_a_obs[keep], mu_d_obs[keep])
        err_k = (s_mu_a[keep], s_mu_d[keep], rho[keep])

        def least_square_wrapper(*theta_flat):
            # Minuit passes one scalar per parameter; repack into a vector.
            theta_arr = jnp.array(theta_flat)
            return least_square(angles_k, obs_k, err_k, theta_arr, lmax=lmax, grid=False)

        m = Minuit(least_square_wrapper, *theta)
        m.errordef = Minuit.LEAST_SQUARES
        for name in m.parameters:
            m.limits[name] = limits[name]

        m.migrad()
        if not m.fmin.is_valid:
            warnings.warn(
                f"Minuit did not reach a valid minimum in iteration {iteration+1}; "
                "the clipping below is based on that fit.",
                RuntimeWarning,
                stacklevel=2,
            )

        theta = jnp.array([m.values[name] for name in m.parameters])

        # Compute X^2 over full dataset (not just kept subset), so previously
        # rejected sources can re-enter.
        X2 = compute_X2(alpha, delta, mu_a_obs, mu_d_obs, s_mu_a, s_mu_d, rho, theta, lmax)
        median_X = jnp.median(X2)
        keep = X2 < (kappa**2) * median_X

        if not bool(jnp.any(keep)):
            raise ValueError(
                f"kappa-clipping rejected every source in iteration {iteration+1} "
                f"(median X^2 = {median_X})."
            )

        if prev_keep is not None and jnp.array_equal(keep, prev_keep):
            print(f"Converged after {iteration+1} iterations.")
            break
        prev_keep = keep
    else:
        # Loop ran out without a break: the mask was still changing.
        warnings.warn(
            f"Outlier mask did not stabilise within max_iter={max_iter} iterations; "
            "the returned theta was fitted on the previous mask.",
            RuntimeWarning,
            stacklevel=2,
        )

    return theta, keep


In [30]:
# Clipped fit at a higher degree. Note that this rebinds the notebook-level
# `lmax`, `total_params` and `theta_init` used by the cells above.
lmax = 3
total_params = count_vsh_coeffs(lmax)
theta_init = jnp.zeros(total_params)

theta, keep = robust_least_squares_fit(
    angles,
    obs,
    error,
    theta_init,
    lmax,
    t_bound=0.05,
    s_bound=0.01,
)

Converged after 3 iterations.


In [31]:
# The mask is computed from X^2 over the full dataset, so its length is the
# total number of input sources (kept and rejected alike).
len(keep)

1215942

In [32]:
# Keep only the rows of `df` whose mask entry is True.
# NOTE(review): `np` is not imported explicitly in the import cell; presumably it
# comes in through one of the wildcard imports -- confirm, or add `import numpy as np`.
df_clean = df.loc[np.array(keep)]

In [33]:
# Rebuild the (angles, obs, error) containers from the clipped table.
angles_clean, obs_clean, error_clean = config_data(df_clean)

In [34]:
# Number of sources removed by the clipping: full sample size minus clipped sample size.
print(len(angles[0]) - len(angles_clean[0]))

3792
