In [None]:
import sxs
import matplotlib.pyplot as plt
from pathlib import Path
import json
import numpy as np
import re
import scipy as sp

# One-time setup (uncomment and run once if needed): writes sxs settings with
# download/cache enabled and auto_supersede disabled.
# NOTE(review): presumably persists across sessions -- confirm against sxs docs.
# sxs.write_config(download=True, cache=True, auto_supersede=False)
# Last expression of the cell, so its result (presumably the active sxs
# configuration) is displayed for the reader.
sxs.read_config()

# Functions

In [None]:
def extract_levs(strings):
    """Collect the distinct ``Lev<N>`` tokens appearing in ``strings``.

    Parameters
    ----------
    strings : iterable of str
        Labels to scan (e.g. the keys of one mismatch-dict entry).

    Returns
    -------
    list of str
        Unique tokens such as ``"Lev1"``, ordered by their integer suffix, so
        ``"Lev10"`` sorts after ``"Lev9"`` rather than lexically.
    """
    found = {token for label in strings for token in re.findall(r"Lev\d+", label)}

    def lev_number(token):
        # Every token matched ``Lev\d+``, so the text after the prefix is all digits.
        return int(token[len("Lev"):])

    return sorted(found, key=lev_number)


def _integrated_center_diff(center_high, center_low):
    """Time-integrate the distance between two coordinate-center trajectories.

    ``center_high`` is interpolated onto ``center_low.time`` before subtracting,
    so the integral runs over the lower-resolution time samples.
    """
    # NOTE(review): if center_low covers times outside center_high's range this
    # relies on interpolate()'s extrapolation behaviour -- confirm that is intended.
    diff = center_high.interpolate(center_low.time) - center_low
    # Distance per time sample: the norm over axis 1 collapses the spatial
    # components (presumably shape (n_times, 3) -- TODO confirm).
    distance = np.linalg.norm(diff, axis=1)
    return sp.integrate.simpson(distance, x=diff.time)


def get_center_diff(key, lev_low, lev_high):
    """Integrated horizon-center differences between two resolution levels of one run.

    Parameters
    ----------
    key : str
        Simulation key; combined as ``"{key}/Lev{N}"`` for ``sxs.load``.
    lev_low, lev_high : int
        Lower and higher resolution level numbers to compare.

    Returns
    -------
    tuple
        ``(int_diff_A, int_diff_B)``. For horizon A and horizon B, each value is the
        Simpson-rule integral, over the ``lev_low`` time samples, of
        ``|center_high(t) - center_low(t)|``, built from ``coord_center_inertial``.
    """
    high_lev = sxs.load(f"{key}/Lev{lev_high}").horizons
    low_lev = sxs.load(f"{key}/Lev{lev_low}").horizons

    int_diff_A = _integrated_center_diff(
        high_lev.A.coord_center_inertial, low_lev.A.coord_center_inertial
    )
    int_diff_B = _integrated_center_diff(
        high_lev.B.coord_center_inertial, low_lev.B.coord_center_inertial
    )
    return int_diff_A, int_diff_B


def get_mismatch(key, mis_dict):
    """Look up the mismatch between the two highest resolution levels of a run.

    Parameters
    ----------
    key : str
        Simulation key; must be a key of ``mis_dict``.
    mis_dict : dict
        Maps a simulation key to a dict whose labels contain ``Lev<N>`` tokens
        (e.g. ``"(Lev2, Lev3) 4d"``). The entry for the top pair must hold a
        ``"mismatch"`` value.

    Returns
    -------
    tuple
        ``(lev_low, lev_high, mismatch)``, with the two level numbers as ints.

    Raises
    ------
    KeyError
        If ``key`` is missing or there is no ``"(LevA, LevB) 4d"`` entry for the
        two highest levels.
    ValueError
        If fewer than two distinct levels appear in the entry's labels.
    """
    highest_two_levs = extract_levs(mis_dict[key].keys())[-2:]
    if len(highest_two_levs) < 2:
        raise ValueError(
            f"{key}: need at least two resolution levels, found {highest_two_levs}."
        )
    lev_low, lev_high = highest_two_levs
    mismatch_key = f"({lev_low}, {lev_high}) 4d"

    mis_val = mis_dict[key][mismatch_key]["mismatch"]

    # Parse the whole numeric suffix ("Lev10" -> 10). Taking only the last
    # character would silently turn Lev10 into 0.
    return int(lev_low[len("Lev"):]), int(lev_high[len("Lev"):]), mis_val


def get_mismatch_and_center_diff(key, mis_dict, min_lev=None):
    """Return the top-two-level mismatch alongside the integrated center differences.

    Parameters
    ----------
    key : str
        Simulation key present in ``mis_dict``.
    mis_dict : dict
        Mismatch data, as consumed by ``get_mismatch``.
    min_lev : int, optional
        If given, reject runs whose lower compared level is below this number.

    Returns
    -------
    tuple
        ``(mis_val, int_diff_A, int_diff_B)``.

    Raises
    ------
    ValueError
        When ``min_lev`` is set and the lower compared level is below it.
    """
    low, high, mismatch = get_mismatch(key, mis_dict)
    if min_lev is not None and low < min_lev:
        raise ValueError(f"Mismatch level {low} is below minimum level {min_lev}.")
    diff_a, diff_b = get_center_diff(key, low, high)
    return mismatch, diff_a, diff_b


# Work Area

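## Load mismatch data

Read the precomputed resolution-level mismatches from `data/data_mismatch.json` into `mis_dict`, keyed by simulation.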

In [None]:
# Path is relative to the notebook's working directory.
mismatch_data = Path("./data/data_mismatch.json")
if not mismatch_data.exists():
    raise FileNotFoundError(f"Data mismatch file not found: {mismatch_data}")

with mismatch_data.open() as mismatch_file:
    mis_dict = json.load(mismatch_file)

# Number of runs with mismatch data.
len(mis_dict)

In [None]:
# Peek at one arbitrary entry to see which level pairs were compared.
# To inspect a specific run instead, set e.g. base_key = 'SXS:BBH:1359'.
base_key = list(mis_dict)[10]
print(base_key)
base_entry = mis_dict[base_key]
base_entry.keys(), base_entry

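## Plot center difference vs. mismatch

Take the last 25 entries of `mis_dict`. For each run, compare two quantities between its two highest resolution levels:

- the time-integrated distance between horizon A's coordinate center at those levels;
- the mismatch between those levels.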

In [None]:
N_RECENT = 25  # how many entries to take from the end of mis_dict
recent_keys = list(mis_dict)[-N_RECENT:][::-1]  # last N_RECENT keys, newest-listed first

mis_arr = []
center_diff_arr = []  # horizon A only; B is printed but not collected

for key in recent_keys:
    mis_val, int_diff_A, int_diff_B = get_mismatch_and_center_diff(key, mis_dict)
    print(f"{key}: {mis_val}, {int_diff_A:.3e}, {int_diff_B:.3e}")
    mis_arr.append(mis_val)
    center_diff_arr.append(int_diff_A)

In [None]:
# Same axes/labels as the catalog plots below, so the figures are comparable.
fig, ax = plt.subplots()
ax.scatter(center_diff_arr, mis_arr)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Center Diff (A)")
ax.set_ylabel("Mismatch")
ax.set_title("Mismatch vs. integrated center difference (horizon A)")
plt.show()

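## SXS catalog

Restrict the SXS catalog to binary-black-hole runs with low reference eccentricity, mass ratio below 5, both spin magnitudes below 0.4 and a late common-horizon time. Then repeat the comparison for both horizons (A and B).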

In [None]:
CATALOG_TAG = "3.0.0"  # SXS catalog tag passed to sxs.load; keep fixed so cuts below are comparable
df = sxs.load("dataframe", tag=CATALOG_TAG)

In [None]:
# Cuts applied one after another; each sees only the rows kept by the previous ones.
# (An upper bound common_horizon_time < 200000.0 was tried and is currently disabled.)
catalog_cuts = [
    ("reference_eccentricity", lambda col: col < 1e-3),
    ("object_types", lambda col: col == "BHBH"),
    ("common_horizon_time", lambda col: col > 6000.0),
    ("reference_mass_ratio", lambda col: col < 5),
    ("reference_dimensionless_spin1_mag", lambda col: col < 0.4),
    ("reference_dimensionless_spin2_mag", lambda col: col < 0.4),
]

fdf = df.copy()
for column, keep in catalog_cuts:
    fdf = fdf[keep(fdf[column])]

# Number of runs surviving all cuts.
len(fdf)

In [None]:
mis_arr = []
center_diff_arr_A = []
center_diff_arr_B = []
kept_keys = []  # run keys that produced a point, index-aligned with the arrays above

min_lev = 3  # Minimum level to consider for mismatch

for filtered_key in fdf.index.to_list():
    if filtered_key not in mis_dict:
        print(f"Key {filtered_key} not found in mismatch data, skipping.")
        continue
    try:
        mis_val, int_diff_A, int_diff_B = get_mismatch_and_center_diff(
            filtered_key, mis_dict, min_lev=min_lev
        )
    except Exception as e:
        # Deliberately best-effort: skip runs that fail (missing levels, load
        # errors, below min_lev). Report the real exception type instead of
        # labelling every failure a KeyError.
        print(f"{type(e).__name__} for {filtered_key}: {e}, skipping.")
        continue
    print(f"{filtered_key}: {mis_val}, {int_diff_A:.3e}, {int_diff_B:.3e}")
    mis_arr.append(mis_val)
    center_diff_arr_A.append(int_diff_A)
    center_diff_arr_B.append(int_diff_B)
    kept_keys.append(filtered_key)

In [None]:
fig, ax = plt.subplots()
ax.scatter(center_diff_arr_A, mis_arr)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Center Diff (A)")
ax.set_ylabel("Mismatch")
ax.set_title("Catalog subset: mismatch vs. center difference (horizon A)")
plt.show()

In [None]:
# plt.show() ends the cell so no stray Text(...) repr from the label call is printed.
fig, ax = plt.subplots()
ax.scatter(center_diff_arr_B, mis_arr)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Center Diff (B)")
ax.set_ylabel("Mismatch")
ax.set_title("Catalog subset: mismatch vs. center difference (horizon B)")
plt.show()

In [None]:
# Full catalog record for one run of interest.
# NOTE(review): presumably picked from the scatter plots above -- record why.
inspect_key = 'SXS:BBH:4434'
df.loc[inspect_key].to_dict()