In [None]:
import sxs
import matplotlib.pyplot as plt
from pathlib import Path
import json
import numpy as np
import re
import scipy as sp
import pandas as pd

# One-time setup (left commented out): configure sxs with downloading and caching
# enabled and auto-superseding disabled. Uncomment only if the config needs changing.
# sxs.write_config(download=True, cache=True, auto_supersede=False)
# Last expression of the cell, so the current sxs configuration is displayed.
sxs.read_config()

# Functions

In [None]:
def extract_levs(strings):
    """Collect the distinct ``Lev<N>`` tokens that occur anywhere in ``strings``.

    Each string is scanned for ``Lev`` followed by one or more digits. Duplicates
    are dropped and the result is ordered by the integer ``N``, so ``Lev10``
    sorts after ``Lev9``.
    """
    lev_pattern = re.compile(r"Lev\d+")
    unique_levs = {token for text in strings for token in lev_pattern.findall(text)}

    def lev_number(lev):
        # Numeric part of the token, used as the sort key.
        return int(re.search(r"\d+", lev).group())

    return sorted(unique_levs, key=lev_number)


def get_center_diff(key, lev_low, lev_high):
    """Time-integrated coordinate-center difference between two resolutions of a run.

    For each horizon (``A`` then ``B``), the high-resolution
    ``coord_center_inertial`` series is interpolated onto the low-resolution
    time samples. The per-sample Euclidean norm of the difference is then
    integrated over time with Simpson's rule.

    Parameters
    ----------
    key : str
        Simulation id passed to ``sxs.load`` (e.g. ``"SXS:BBH:3864"``).
    lev_low, lev_high : int or str
        Resolution numbers, formatted into ``f"{key}/Lev{lev}"``.

    Returns
    -------
    tuple of float
        ``(int_diff_A, int_diff_B)``.
    """
    high_lev = sxs.load(f"{key}/Lev{lev_high}").horizons
    low_lev = sxs.load(f"{key}/Lev{lev_low}").horizons

    def _integrated_diff(high_center, low_center):
        # Put the high-resolution trajectory on the low-resolution time grid so
        # the two can be subtracted sample by sample.
        # NOTE(review): if the high-Lev run spans a shorter time range, this asks
        # interpolate() for times outside its data; confirm how sxs handles that.
        t = low_center.time
        diff = high_center.interpolate(t) - low_center
        # Presumably (N, 3) positions; norm over axis 1 gives one distance per sample.
        return sp.integrate.simpson(np.linalg.norm(diff, axis=1), diff.time)

    int_diff_A = _integrated_diff(
        high_lev.A.coord_center_inertial, low_lev.A.coord_center_inertial
    )
    int_diff_B = _integrated_diff(
        high_lev.B.coord_center_inertial, low_lev.B.coord_center_inertial
    )
    return int_diff_A, int_diff_B


def get_mismatch(key, mis_dict):
    """Return the mismatch between the two highest resolutions stored for ``key``.

    ``mis_dict[key]`` is keyed by strings containing ``Lev<N>`` tokens, such as
    ``"(Lev2, Lev3)"``. The two highest-numbered levels are located and the entry
    stored under ``"(Lev<low>, Lev<high>)"`` is returned.

    Returns
    -------
    tuple
        ``(lev_low, lev_high, mismatch)`` with the levels as ints.

    Raises
    ------
    ValueError
        If fewer than two distinct levels are present for ``key``.
    KeyError
        If ``key`` or the ``"(Lev<low>, Lev<high>)"`` entry is missing.
    """
    levs = extract_levs(mis_dict[key].keys())
    if len(levs) < 2:
        raise ValueError(
            f"Need at least two levels for key {key} to get a mismatch, found {levs}."
        )
    low_lev, high_lev = levs[-2:]

    # Plain (non-4d) entry. An earlier variant read the "... 4d" key and its
    # "mismatch" field instead.
    mis_val = mis_dict[key][f"({low_lev}, {high_lev})"]

    # Parse the full numeric suffix ("Lev10" -> 10). Taking only the last
    # character would map Lev10 to 0.
    return int(low_lev[3:]), int(high_lev[3:]), mis_val


def get_mismatch_lev(key, mis_dict, lev_low, lev_high):
    """Look up the stored mismatch for ``key`` between two explicit levels.

    ``lev_low`` and ``lev_high`` are inserted verbatim into the lookup key
    ``"(<lev_low>, <lev_high>)"``. Pass them in the form used by ``mis_dict``
    (e.g. ``"Lev2"``, ``"Lev3"``).
    """
    per_run = mis_dict[key]
    return per_run[f"({lev_low}, {lev_high})"]


def get_mismatch_and_center_diff(key, mis_dict, min_lev=None):
    """Pair the top-resolution mismatch of ``key`` with its center differences.

    The two highest levels come from ``get_mismatch``. When ``min_lev`` is
    given, a run whose lower level is below it is rejected.

    Returns ``(mismatch, int_diff_A, int_diff_B)``.

    Raises ValueError if the lower of the two levels is below ``min_lev``.
    """
    lev_low, lev_high, mis_val = get_mismatch(key, mis_dict)

    too_coarse = min_lev is not None and lev_low < min_lev
    if too_coarse:
        raise ValueError(
            f"Mismatch level {lev_low} is below minimum level {min_lev}."
        )

    int_diff_A, int_diff_B = get_center_diff(key, lev_low, lev_high)
    return mis_val, int_diff_A, int_diff_B


def get_mismatch_and_center_diff_between_levs(key, mis_dict, min_num_lev=3):
    """Mismatch and integrated center differences for every consecutive Lev pair.

    Parameters
    ----------
    key : str
        Run id present in ``mis_dict``.
    mis_dict : dict
        Maps run id -> {"(LevI, LevJ)": mismatch, ...} (non-4d entries).
    min_num_lev : int
        Minimum number of distinct levels the run must have.

    Returns
    -------
    dict
        ``{"(LevI, LevJ)": {"mismatch": ..., "int_diff_A": ..., "int_diff_B": ...}}``
        for each adjacent pair in ascending level order.

    Raises
    ------
    ValueError
        If the run has fewer than ``min_num_lev`` levels.
    KeyError
        If a consecutive-pair entry is missing from ``mis_dict[key]``.
    """
    levs = extract_levs(mis_dict[key].keys())
    if len(levs) < min_num_lev:
        raise ValueError(
            f"Not enough levels for key {key}. Found {len(levs)}, expected at least {min_num_lev}."
        )

    mis_int_dict = {}
    for low, high in zip(levs[:-1], levs[1:]):
        key_levs = f"({low}, {high})"
        # Look up the mismatch first so a missing entry fails before the
        # (expensive) horizon data is loaded.
        mismatch = mis_dict[key][key_levs]
        # Pass the full level number ("Lev10" -> 10), not just its last character.
        int_diff_A, int_diff_B = get_center_diff(key, int(low[3:]), int(high[3:]))
        mis_int_dict[key_levs] = {
            "mismatch": mismatch,
            "int_diff_A": int_diff_A,
            "int_diff_B": int_diff_B,
        }

    return mis_int_dict


# Work Area

## Load mismatch dict

In [None]:
# Per-run mismatch table; keys are run ids such as 'SXS:BBH:1359', and each value
# maps "(LevI, LevJ)" pairs to mismatches.
mismatch_data = Path("./data/data_mismatch.json")
if not mismatch_data.exists():
    raise FileNotFoundError(f"Data mismatch file not found: {mismatch_data}")

# JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
mis_dict = json.loads(mismatch_data.read_text(encoding="utf-8"))
len(mis_dict)

In [None]:
# Pick one run to inspect (alternative: base_key = 'SXS:BBH:1359').
base_key = [*mis_dict][10]
print(base_key)
mis_dict[base_key].keys(), mis_dict[base_key]

## Plot center diff vs mismatch

In [None]:
# Mismatch (top two Levs) vs. integrated center difference of horizon A for the
# last 25 runs in mis_dict, visited in reverse order.
mis_arr, center_diff_arr = [], []

for key in list(mis_dict)[-25:][::-1]:
    mis_val, int_diff_A, int_diff_B = get_mismatch_and_center_diff(key, mis_dict)
    print(f"{key}: {mis_val}, {int_diff_A:.3e}, {int_diff_B:.3e}")
    mis_arr += [mis_val]
    center_diff_arr += [int_diff_A]

In [None]:
# center_diff_arr holds horizon A's integrated center difference.
plt.scatter(center_diff_arr, mis_arr)
plt.xlabel("Center Diff (A)")
plt.ylabel("Mismatch")
plt.yscale("log")
plt.xscale("log")

## SXS catalog

In [None]:
# Load the SXS catalog dataframe pinned to tag 3.0.0 (indexed by run id, e.g. 'SXS:BBH:3864').
df = sxs.load("dataframe", tag="3.0.0")

In [None]:
# Low-eccentricity BBH runs with common_horizon_time > 6000, mass ratio < 5 and
# both spin magnitudes < 0.4. (Optional upper bound: common_horizon_time < 200000.0.)
fdf = (
    df.copy()
    .loc[lambda d: d['reference_eccentricity'] < 1e-3]
    .loc[lambda d: d['object_types'] == "BHBH"]
    .loc[lambda d: d['common_horizon_time'] > 6000.0]
    .loc[lambda d: d['reference_mass_ratio'] < 5]
    .loc[lambda d: d['reference_dimensionless_spin1_mag'] < 0.4]
    .loc[lambda d: d['reference_dimensionless_spin2_mag'] < 0.4]
)
len(fdf)

In [None]:
mis_arr = []
center_diff_arr_A = []
center_diff_arr_B = []

# Runs whose lower Lev in the top-two pair is below this number are rejected.
min_lev = 3

for filtered_key in fdf.index.to_list():
    if filtered_key not in mis_dict:
        print(f"Key {filtered_key} not found in mismatch data, skipping.")
        continue
    try:
        mis_val, int_diff_A, int_diff_B = get_mismatch_and_center_diff(filtered_key, mis_dict, min_lev=min_lev)
    except Exception as e:
        # Report the real exception type; failures are not only KeyErrors
        # (e.g. ValueError from the min_lev check, or data-loading errors).
        print(f"{type(e).__name__} for {filtered_key}: {e}, skipping.")
        continue
    print(f"{filtered_key}: {mis_val}, {int_diff_A:.3e}, {int_diff_B:.3e}")
    mis_arr.append(mis_val)
    center_diff_arr_A.append(int_diff_A)
    center_diff_arr_B.append(int_diff_B)

In [None]:
plt.scatter(center_diff_arr_A, mis_arr)
ax = plt.gca()
ax.set_xlabel("Center Diff (A)")
ax.set_ylabel("Mismatch")
ax.set_yscale("log")
ax.set_xscale("log")

In [None]:
plt.scatter(center_diff_arr_B, mis_arr)
ax = plt.gca()
ax.set_ylabel("Mismatch")
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("Center Diff (B)")

In [None]:
# Inspect the full catalog row for one run as a plain dict.
df.loc['SXS:BBH:3864'].to_dict()

### Lev trend

In [None]:
# Same catalog load as in the SXS catalog section (tag 3.0.0); repeated so this section stands alone.
df = sxs.load("dataframe", tag="3.0.0")

In [None]:
# Low-eccentricity BBH runs with common_horizon_time < 6000, mass ratio < 2 and
# both spin magnitudes < 0.4. (Optional upper bound: common_horizon_time < 200000.0.)
fdf = (
    df.copy()
    .loc[lambda d: d['reference_eccentricity'] < 1e-3]
    .loc[lambda d: d['object_types'] == "BHBH"]
    .loc[lambda d: d['common_horizon_time'] < 6000.0]
    .loc[lambda d: d['reference_mass_ratio'] < 2]
    .loc[lambda d: d['reference_dimensionless_spin1_mag'] < 0.4]
    .loc[lambda d: d['reference_dimensionless_spin2_mag'] < 0.4]
)
len(fdf)

In [None]:
filtered_data = {}
skipped_keys = {}  # run id -> reason it was left out of filtered_data
min_num_lev = 4  # Minimum number of distinct Lev resolutions a run must have

for filtered_key in fdf.index.to_list():
    if filtered_key not in mis_dict:
        print(f"Key {filtered_key} not found in mismatch data, skipping.")
        continue
    try:
        filtered_data[filtered_key] = get_mismatch_and_center_diff_between_levs(filtered_key, mis_dict, min_num_lev=min_num_lev)
    except Exception as e:
        # Skip the run (e.g. ValueError for too few levels), but keep the reason
        # instead of discarding it, so unexpected failures remain inspectable.
        skipped_keys[filtered_key] = f"{type(e).__name__}: {e}"
        continue
    print(f"{filtered_key}: {filtered_data[filtered_key]}")

print(f"Kept {len(filtered_data)} runs, skipped {len(skipped_keys)} (see skipped_keys).")


In [None]:
# Mismatch vs. integrated center difference across successive Lev pairs of one run.
# Other runs tried: 'SXS:BBH:1132', 'SXS:BBH:0198', 'SXS:BBH:0310', 'SXS:BBH:4434'.
key = 'SXS:BBH:3864'
component = 'B'  # which horizon's center difference goes on the x-axis ('A' or 'B')

mis_int_dict = get_mismatch_and_center_diff_between_levs(key, mis_dict, min_num_lev=min_num_lev)
x = [mis_int_dict[k][f'int_diff_{component}'] for k in mis_int_dict]
y = [mis_int_dict[k]['mismatch'] for k in mis_int_dict]

plt.scatter(x, y)
plt.ylabel("Mismatch")
# Label follows the plotted component (previously B data was labelled "A").
plt.xlabel(f"Center Diff ({component})")
plt.yscale("log")
plt.xscale("log")

In [None]:
# Display the per-Lev-pair mismatch / center-difference results from the previous cell.
mis_int_dict

## Merging time data

In [None]:
def _horizon_features(horizon_data, prefix, interp_t):
    """Sample one horizon's quantities at ``interp_t``, keyed ``<prefix>_<quantity>``."""
    areal_mass = horizon_data.areal_mass.interpolate(interp_t)
    christodoulou_mass = horizon_data.christodoulou_mass.interpolate(interp_t)
    chi_inertial = horizon_data.chi_inertial.interpolate(interp_t)
    coord_center_inertial = horizon_data.coord_center_inertial.interpolate(interp_t)
    dimensionful_inertial_spin = horizon_data.dimensionful_inertial_spin.interpolate(
        interp_t
    )

    features = {
        "areal_mass": np.array(areal_mass),
        "christodoulou_mass": np.array(christodoulou_mass),
        # Change relative to the horizon's first recorded sample.
        "areal_mass_change": np.array(areal_mass - horizon_data.areal_mass[0]),
        "christodoulou_mass_change": np.array(
            christodoulou_mass - horizon_data.christodoulou_mass[0]
        ),
    }

    def add_components(name, values, with_mag):
        # Split an (N, 3) series into x/y/z columns, optionally adding its norm.
        for axis, label in enumerate("xyz"):
            features[f"{name}_{label}"] = np.array(values[:, axis])
        if with_mag:
            features[f"{name}_mag"] = np.linalg.norm(values, axis=1)

    add_components("chi_inertial", chi_inertial, with_mag=True)
    add_components("coord_center_inertial", coord_center_inertial, with_mag=False)
    add_components("dimensionful_inertial_spin", dimensionful_inertial_spin, with_mag=True)

    return {f"{prefix}_{name}": value for name, value in features.items()}


def get_data_dict(data, num_points_to_skip=10, skip_T_before_merger=1000):
    """Build a flat dict of per-horizon features sampled on a common time grid.

    The grid is every ``num_points_to_skip``-th sample of horizon A's time axis,
    truncated to times earlier than ``skip_T_before_merger`` before A's last
    sample. Both horizons (``data.A`` and ``data.B``) are interpolated onto this
    grid.

    Parameters
    ----------
    data
        Object with ``A`` and ``B`` horizon attributes (e.g. ``sxs.load(id).horizons``).
    num_points_to_skip : int
        Stride used to thin horizon A's time samples.
    skip_T_before_merger : float
        Samples within this much time of A's last sample are dropped.

    Returns
    -------
    dict
        ``A_*`` and ``B_*`` feature arrays (masses, mass changes, x/y/z/mag
        components), followed by ``"time_to_merger"``.
    """
    time_grid = data.A.time
    # Last sample of horizon A, taken as the merger time (assumes A's data stops
    # at common-horizon formation).
    common_horizon_time = time_grid[-1]
    interp_t = time_grid[::num_points_to_skip]
    interp_t = interp_t[interp_t < common_horizon_time - skip_T_before_merger]
    time_to_merger = common_horizon_time - interp_t

    return {
        **_horizon_features(data.A, "A", interp_t),
        **_horizon_features(data.B, "B", interp_t),
        "time_to_merger": time_to_merger,
    }


In [None]:
# Same catalog load as in the SXS catalog section (tag 3.0.0); repeated so this section stands alone.
df = sxs.load("dataframe", tag="3.0.0")

In [None]:
# Low-eccentricity BBH runs with 6000 < common_horizon_time < 20000, mass ratio < 5
# and both spin magnitudes < 0.1.
fdf = (
    df.copy()
    .loc[lambda d: d['reference_eccentricity'] < 1e-3]
    .loc[lambda d: d['object_types'] == "BHBH"]
    .loc[lambda d: d['common_horizon_time'] > 6000.0]
    .loc[lambda d: d['common_horizon_time'] < 20000.0]
    .loc[lambda d: d['reference_mass_ratio'] < 5]
    .loc[lambda d: d['reference_dimensionless_spin1_mag'] < 0.1]
    .loc[lambda d: d['reference_dimensionless_spin2_mag'] < 0.1]
)
len(fdf)

In [None]:
filtered_runs = fdf.index.to_list()
data_dict = {}
for run_id in filtered_runs:
    try:
        data_dict[run_id] = get_data_dict(sxs.load(run_id).horizons)
    except Exception as e:
        # Keep going on failure, but report why so download/data problems are visible.
        print(f"Error loading data for {run_id}: {type(e).__name__}: {e}")
        continue
    print(f"Loaded data for {run_id}, len: {len(data_dict[run_id]['time_to_merger'])}")


In [None]:
# Join all the data into a single DataFrame: each run's per-sample features become
# rows, tagged with the run id in a trailing "run" column.
rows = [
    pd.DataFrame(subdict).assign(run=run_name)
    for run_name, subdict in data_dict.items()
]

big_df = pd.concat(rows, ignore_index=True)

In [None]:
# Display the combined dataset (one row per sampled time per run).
big_df