In [None]:

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
from astropy.time import Time

from adam_core.coordinates import KeplerianCoordinates, CartesianCoordinates, Origin, OriginCodes
from adam_core.dynamics.moid import calculate_moid
from adam_core.orbits import Orbits
from adam_core.time import Timestamp
from adam_core.utils.spice import get_perturber_state

from adam_impact_study.analysis import summarize_impact_study_results
from adam_impact_study.analysis.population import compute_orbital_element_recovery_statistics, plot_orbital_element_recovery_statistics
from adam_impact_study.analysis.utils import collect_all_window_results
from adam_impact_study.types import ImpactorOrbits


# Directory holding the per-run impact study outputs, and the directory the
# summary products are written to.
RUN_DIR = "../results/results"
OUT_DIR = "../results_summary"

# The impactor orbit catalogue lives one level above RUN_DIR.
impactor_orbits_path = os.path.join(RUN_DIR, "../impactor_orbits.parquet")
impactor_orbits = ImpactorOrbits.from_parquet(impactor_orbits_path)

# Aggregate all runs into a per-orbit summary and per-window results table.
# Plot generation is disabled here; the cells below draw their own figures.
summary, window_results = summarize_impact_study_results(
    RUN_DIR,
    OUT_DIR,
    max_processes=60,
    summary_plots=False,
    per_object_plots=False,
)

In [None]:
# Restrict the summary to rows flagged by `summary.complete()` (presumably
# orbits whose study finished — confirm against the summary type), then keep
# only the per-window results belonging to those same orbits.
completion_mask = summary.complete()
summary_completed = summary.apply_mask(completion_mask)
completed_orbits = summary_completed.orbit.orbit_id.unique()

in_completed_orbits = pc.is_in(window_results.orbit_id, completed_orbits)
window_results_completed = window_results.apply_mask(in_completed_orbits)
window_results_completed.to_dataframe()

In [None]:
from astropy.time import Time

def create_subplot_figure(quantity, max_cols=2, **kwargs):
    """Create a grid of subplots with one panel per element of ``quantity``.

    Panels are laid out row by row with ``max_cols`` columns; extra rows are
    added until every element of ``quantity`` has a panel.

    Parameters
    ----------
    quantity : sized collection
        The items to be plotted; only ``len(quantity)`` is used.
    max_cols : int, optional
        Number of columns in the grid (default 2).
    **kwargs
        Forwarded to ``plt.subplots`` (e.g. ``dpi``, ``figsize``). Any
        ``squeeze`` argument is ignored because the axes are always returned
        as a flat array.

    Returns
    -------
    fig : matplotlib Figure
    ax : numpy.ndarray
        1-D array of ``num_rows * max_cols`` Axes in row-major order. It can
        hold more Axes than ``len(quantity)`` when the last row is not full.

    Raises
    ------
    ValueError
        If ``quantity`` is empty (a grid with zero rows cannot be created).
    """
    num_quantities = len(quantity)
    if num_quantities == 0:
        raise ValueError("quantity must contain at least one element")

    # Ceiling division: enough rows to hold every quantity.
    num_rows = (num_quantities + max_cols - 1) // max_cols

    # squeeze=False guarantees a 2-D ndarray of Axes even for a 1x1 grid,
    # where plt.subplots would otherwise return a bare Axes (no .flatten()).
    kwargs.pop("squeeze", None)
    fig, ax = plt.subplots(num_rows, max_cols, squeeze=False, **kwargs)
    return fig, ax.flatten()

# One panel per impactor diameter. Convert to Python values so titles and
# `select` comparisons use plain floats (consistent with the later cells).
sizes = impactor_orbits.diameter.unique().sort().to_pylist()

# Epoch from which the total time available before impact is measured.
REFERENCE_EPOCH = Timestamp.from_astropy(Time("2025-05-05T00:00:00", scale="utc"))
# Loop-invariant: compute the reference MJD once instead of once per size.
reference_mjd = REFERENCE_EPOCH.mjd().to_numpy(zero_copy_only=False)[0]

fig, ax = create_subplot_figure(sizes, max_cols=2, dpi=200, figsize=(6, 12))
fig.subplots_adjust(
    wspace=0.2,
    hspace=0.2,
    top=0.9,
    bottom=0.05,
    left=0.05,
    right=0.95
)

for i, size in enumerate(sizes):
    summary_completed_size = summary_completed.select("orbit.diameter", size)

    impact_mjd = summary_completed_size.orbit.impact_time.mjd().to_numpy(zero_copy_only=False)
    discovery_mjd = summary_completed_size.discovery_time.mjd().to_numpy(zero_copy_only=False)

    # NaN discovery times (presumably objects that were never discovered)
    # count as zero warning time. The second positional argument of
    # np.nan_to_num is `copy`, not the fill value, so pass `nan=` explicitly.
    warning_time = np.nan_to_num(impact_mjd - discovery_mjd, nan=0.0)
    time_until_impact = impact_mjd - reference_mjd

    # Fraction of the time between the reference epoch and impact that is
    # still left as warning time once the object is discovered.
    performance_metric = warning_time / time_until_impact

    ax[i].set_title(f"{size} km")
    ax[i].hist(1 - performance_metric, bins=100, density=True, range=(0, 1), cumulative=True)
    ax[i].set_xlim(0, 1)
    ax[i].set_xlabel("1 - warning time / time until impact")
    ax[i].set_ylabel("Cumulative fraction")

# Hide the leftover panel(s) when the number of sizes does not fill the grid.
for unused_ax in ax[len(sizes):]:
    unused_ax.set_visible(False)
    

In [None]:
impact_mjds = impactor_orbits.impact_time.mjd().to_numpy(zero_copy_only=False)

# Reference lines every ten years from 2025, converted to MJD in one call.
decade_years = range(2025, 2135, 10)
decade_mjds = Timestamp.from_astropy(
    Time([f"{year}-01-01T00:00:00" for year in decade_years], scale="utc")
).mjd().to_numpy(zero_copy_only=False)

fig, ax = plt.subplots(1, 1, dpi=200)
# density=False: bar heights are raw impactor counts per bin.
# ax.hist returns (counts, bin_edges, patches).
hist_counts, hist_edges, _ = ax.hist(impact_mjds, bins=100, density=False)

for decade_mjd in decade_mjds:
    ax.axvline(decade_mjd, color="k", lw=0.5, alpha=0.5)
ax.set_xlabel("Impact Time (MJD)")
ax.set_ylabel("Number of impactors")
fig.savefig("impact_time_histogram.png", dpi=300, bbox_inches="tight")

In [None]:
sizes = impactor_orbits.diameter.unique().sort().to_pylist()
# Assumes each orbit_id embeds its decade as "_YYYY_" — confirm against the
# impactor naming scheme.
decades = ["2025", "2035", "2045", "2055", "2065", "2075", "2085", "2095", "2105", "2115"]

# One panel per decade (two columns); within a panel, curves are coloured by
# impactor diameter.
fig, ax = plt.subplots((len(decades) + 1) // 2, 2, dpi=200, figsize=(8, 16))
ax = ax.flatten()

# tab10 is a qualitative colormap: one distinct colour per diameter.
colors = plt.cm.tab10(np.linspace(0, 1, len(sizes)))

# Loop-invariant: orbit ids belonging to each diameter.
orbit_ids_by_size = {
    size: impactor_orbits.apply_mask(pc.equal(impactor_orbits.diameter, size)).orbit_id
    for size in sizes
}

for decade_ax, decade in zip(ax, decades):
    decade_mask = pc.match_substring(window_results_completed.orbit_id, f"_{decade}_")
    window_results_decade = window_results_completed.apply_mask(decade_mask)

    for size, color in zip(sizes, colors):
        window_results_size = window_results_decade.apply_mask(
            pc.is_in(window_results_decade.orbit_id, orbit_ids_by_size[size])
        )

        for orbit_id in window_results_size.orbit_id.unique().to_pylist():
            orbit_mask = pc.equal(window_results_size.orbit_id, orbit_id)
            window_results_size_orbit = window_results_size.apply_mask(orbit_mask).sort_by("observation_end")

            observation_end_mjd = window_results_size_orbit.observation_end.mjd().to_numpy(zero_copy_only=False)
            impact_probability = window_results_size_orbit.impact_probability.to_numpy(zero_copy_only=False)

            # Time is measured from the end of this orbit's first window.
            decade_ax.plot(
                observation_end_mjd - observation_end_mjd[0],
                impact_probability,
                c=color,
                alpha=0.5,
                lw=0.5,
            )
            # Horizontal guide at the peak impact probability reached.
            decade_ax.axhline(impact_probability.max(), color=color, lw=0.5, alpha=0.5)

    decade_ax.set_title(decade)
    decade_ax.set_xlabel("Days since first window end")
    decade_ax.set_ylabel("Impact probability")

# Figure-level legend mapping colours to diameters.
legend_handles = [plt.Line2D([], [], color=color, label=f"{size} km") for size, color in zip(sizes, colors)]
fig.legend(handles=legend_handles, loc="upper center", ncol=max(1, min(len(sizes), 5)))
fig.subplots_adjust(top=0.95)

In [None]:
import plotly.graph_objects as go
import plotly.express as px

sizes = impactor_orbits.diameter.unique().sort().to_pylist()

# Interactive drill-down into a single cohort: the 2045 decade and only the
# largest diameter. Assumes orbit_ids embed the decade as "_YYYY_".
interactive_decades = ["2045"]
interactive_sizes = sizes[-1:]

for decade in interactive_decades:
    decade_mask = pc.match_substring(window_results_completed.orbit_id, f"_{decade}_")
    window_results_decade = window_results_completed.apply_mask(decade_mask)

    fig = go.Figure()

    for size in interactive_sizes:
        size_mask = pc.equal(impactor_orbits.diameter, size)
        impactor_orbits_size = impactor_orbits.apply_mask(size_mask)

        window_results_size = window_results_decade.apply_mask(
            pc.is_in(window_results_decade.orbit_id, impactor_orbits_size.orbit_id)
        )

        for orbit_id in window_results_size.orbit_id.unique().to_pylist():
            orbit_mask = pc.equal(window_results_size.orbit_id, orbit_id)
            # Sort by window end so mode='lines' connects points in time order.
            window_results_size_orbit = window_results_size.apply_mask(orbit_mask).sort_by("observation_end")

            fig.add_trace(
                go.Scatter(
                    x=window_results_size_orbit.observation_end.mjd().to_numpy(zero_copy_only=False),
                    y=window_results_size_orbit.impact_probability.to_numpy(zero_copy_only=False),
                    mode='lines',
                    line=dict(color='black', dash='dash', width=0.5),
                    name=orbit_id,
                    showlegend=True  # List each orbit so traces can be toggled
                )
            )

    fig.update_layout(
        title=f"Impact Probability vs Time for {decade}",
        xaxis_title="Time (MJD)",
        yaxis_title="Impact Probability",
        template="plotly_white",  # Clean white background
        hovermode='closest'
    )

    fig.show()