<a href="https://colab.research.google.com/github/23f2002620/StratoHack-Space-Debris-Collision-Risk-Prediction/blob/main/StratoHack_Space_Debris_Collision_Risk_Prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StratoHack-Space-Debris-Collision-Risk-Prediction

https://dlmultimedia.esa.int/download/public/videos/2025/04/002/2504_002_AR_EN.mp4

Datsets
Active satellites (baseline catalog for operational assets): https://celestrak.org/NORAD/elements/gp.php?GROUP=active&FORMAT=tle

FENGYUN‑1C ASAT debris (major LEO debris cloud): https://celestrak.org/NORAD/elements/gp.php?GROUP=fengyun-1c-debris&FORMAT=tle

COSMOS‑2251 collision debris (Iridium‑33/COSMOS‑2251 fragments): https://celestrak.org/NORAD/elements/gp.php?GROUP=cosmos-2251-debris&FORMAT=tle

COSMOS‑1408 ASAT debris (recent large debris event): https://celestrak.org/NORAD/elements/gp.php?GROUP=cosmos-1408-debris&FORMAT=tle



Importing Datasets

In [None]:
!pip -q install skyfield sgp4

import requests, os

from skyfield.api import load, EarthSatellite

from skyfield.iokit import parse_tle_file



[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/367.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m358.4/367.0 kB[0m [31m13.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m367.0/367.0 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/235.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.7/235.7 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/49.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.4/49.4 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
URLS = {
"active.tle": "https://celestrak.org/NORAD/elements/gp.php?GROUP=active&FORMAT=tle",
"fengyun1c.tle": "https://celestrak.org/NORAD/elements/gp.php?GROUP=fengyun-1c-debris&FORMAT=tle",
"cosmos2251.tle": "https://celestrak.org/NORAD/elements/gp.php?GROUP=cosmos-2251-debris&FORMAT=tle",
"cosmos1408.tle": "https://celestrak.org/NORAD/elements/gp.php?GROUP=cosmos-1408-debris&FORMAT=tle",
}

for fn, url in URLS.items():
    if not os.path.exists(fn):
        r = requests.get(url, timeout=30)
        r.raise_for_status()
        with open(fn, "wb") as f:
            f.write(r.content)

In [None]:
ts = load.timescale()
catalog = {}
for fn in URLS.keys():
    with load.open(fn) as f:
        catalog[fn] = list(parse_tle_file(f, ts))
counts = {fn: len(sats) for fn, sats in catalog.items()}
print("Loaded counts:", counts)

Loaded counts: {'active.tle': 12685, 'fengyun1c.tle': 1890, 'cosmos2251.tle': 607, 'cosmos1408.tle': 5}


Install deps

In [None]:
!pip -q install torch torchvision torchaudio torch-geometric torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cpu.html
!pip -q install torch-geometric-temporal stable-baselines3 gymnasium


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m513.6/513.6 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m783.8/783.8 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.3/102.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m187.2/187.2 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h

Data prep: build temporal graphs from your propagated TLEs

In [None]:
import math
import numpy as np
import torch
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal

def synthetic_temporal_graph(T=24, N=500, F=8, k=8, seed=42):
    rng = np.random.default_rng(seed)
    Xs, Es, Ws, Ys = [], [], [], []
    for t in range(T):
        X = rng.normal(size=(N, F)).astype(np.float32)
        idx = np.argsort(X[:, 0])
        sources, targets = [], []
        for i in range(N):
            base = idx[i]
            neigh = idx[max(0, i-k):min(N, i+k+1)]
            for j in neigh:
                if j != i:
                    sources.append(base); targets.append(idx[j])
        edge_index = np.vstack([sources, targets]).astype(np.int64)
        edge_weight = np.ones(edge_index.shape[1], dtype=np.float32)
        y = (rng.random(edge_index.shape[1]) < 0.02).astype(np.float32)
        Xs.append(torch.from_numpy(X))
        Es.append(torch.from_numpy(edge_index))
        Ws.append(torch.from_numpy(edge_weight))
        Ys.append(y)
    return Xs, Es, Ws, Ys

Xs, Es, Ws, Ys = synthetic_temporal_graph()
dataset = DynamicGraphTemporalSignal(Es, Ws, Xs, Ys)

snapshots = list(dataset)
train_len = int(0.7 * len(snapshots))


GNN model and training loop (spatio‑temporal)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU
import numpy as np

class STRiskGCN(nn.Module):
    def __init__(self, in_feats, hidden=32):
        super().__init__()
        self.rnn1 = GConvGRU(in_feats, hidden, K=3)
        self.rnn2 = GConvGRU(hidden, hidden//2, K=3)
        self.head = nn.Linear(hidden//2, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.rnn1(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.rnn2(h, edge_index, edge_weight)
        h = F.relu(h)
        logits = self.head(h)
        return logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = STRiskGCN(in_feats=Xs[0].shape[1]).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

def to_tensor_y(y, device):
    if isinstance(y, np.ndarray):
        return torch.from_numpy(y).float().to(device)
    elif torch.is_tensor(y):
        return y.float().to(device)
    else:
        return torch.tensor(y, dtype=torch.float32, device=device)

def to_weights_or_ones(snapshot, device):
    ei = snapshot.edge_index.to(device)
    if hasattr(snapshot, "edge_weight") and snapshot.edge_weight is not None:
        ew = snapshot.edge_weight.to(device)
    else:
        ew = torch.ones(ei.shape[1], dtype=torch.float32, device=device)
    return ei, ew

def train_epoch():
    model.train()
    losses = []
    for snapshot in snapshots[:train_len]:
        x = snapshot.x.to(device)
        ei, ew = to_weights_or_ones(snapshot, device)
        y = to_tensor_y(snapshot.y, device)
        node_logits = model(x, ei, ew).squeeze(-1)
        edge_scores = 0.5 * (node_logits[ei[0]] + node_logits[ei[1]])
        loss = bce(edge_scores, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return float(np.mean(losses)) if losses else float("nan")

@torch.no_grad()
def eval_epoch():
    model.eval()
    ys, ps = [], []
    for snapshot in snapshots[train_len:]:
        x = snapshot.x.to(device)
        ei, ew = to_weights_or_ones(snapshot, device)
        y = to_tensor_y(snapshot.y, device).cpu().numpy()
        node_logits = model(x, ei, ew).squeeze(-1)
        edge_scores = 0.5 * (node_logits[ei[0]] + node_logits[ei[1]])
        p = torch.sigmoid(edge_scores).cpu().numpy()
        ys.append(y); ps.append(p)
    if not ys:
        return float("nan")
    from sklearn.metrics import average_precision_score
    ytrue = np.concatenate(ys)
    ypred = np.concatenate(ps)
    return float(average_precision_score(ytrue, ypred))

for epoch in range(5):
    tr_loss = train_epoch()
    ap = eval_epoch()
    print(f"Epoch {epoch+1}: loss={tr_loss:.4f}, AP={ap:.4f}")


Epoch 1: loss=0.7393, AP=0.0205
Epoch 2: loss=0.6236, AP=0.0207
Epoch 3: loss=0.4813, AP=0.0204
Epoch 4: loss=0.3414, AP=0.0202
Epoch 5: loss=0.2409, AP=0.0202


Physics baseline stubs for Pc and miss distance

In [None]:
import numpy as np

def miss_distance_and_Pc_stub(rel_pos_km, rel_vel_kmps, hbr_m=10.0, sigma_m=200.0):
    md = np.linalg.norm(rel_pos_km) * 1000.0
    sigma2 = (sigma_m**2)
    Pc = np.exp(-0.5 * (md**2) / sigma2) * (hbr_m**2 / (2*np.pi*sigma2))
    return md, float(Pc)

# Example:
md, Pc = miss_distance_and_Pc_stub(np.array([0.5,0.1,0.0]), np.array([10,0,0]))
print("Miss distance (m), Pc:", md, Pc)


Miss distance (m), Pc: 509.90195135927854 1.5427767102227584e-05


RL environment for impulsive CAM (continuous ΔV) and PPO training

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

class CamEnv(gym.Env):
    metadata = {"render_modes": []}
    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(low=-1e3, high=1e3, shape=(8,), dtype=np.float32)
        self.action_space = spaces.Box(low=-0.2, high=0.2, shape=(3,), dtype=np.float32)
        self.reset()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.state = self._sample_scenario()
        self.steps = 0
        return self.state.astype(np.float32), {}

    def _sample_scenario(self):
        dx, dy, dz = np.random.uniform(-5,5,3)
        dvx, dvy, dvz = np.random.uniform(-0.1,0.1,3)
        tmin = np.random.uniform(30, 360)
        fuel = np.random.uniform(5.0, 25.0)
        return np.array([dx,dy,dz,dvx,dvy,dvz,tmin,fuel], dtype=np.float32)

    def step(self, action):
        dv = np.clip(action, self.action_space.low, self.action_space.high)
        self.state[3:6] += dv/1000.0
        self.state[7] -= np.linalg.norm(dv)
        self.state[6] -= 5.0
        self.steps += 1
        rel_pos = self.state[0:3]
        rel_vel = self.state[3:6]
        md, Pc = miss_distance_and_Pc_stub(rel_pos, rel_vel)
        reward = -1000.0*Pc - 0.1*np.linalg.norm(dv)
        terminated = bool(self.state[7] <= 0 or self.state[6] <= 0 or Pc < 1e-4)
        truncated = self.steps >= 60
        return self.state.astype(np.float32), reward, terminated, truncated, {"Pc": Pc, "md": md}

env = CamEnv()
obs, _ = env.reset()
print("Initial obs:", obs)


Initial obs: [-3.1764426e+00  2.6122615e+00 -2.2304361e+00 -6.6839263e-02
 -5.2503765e-02  8.6006425e-02  1.6344002e+02  8.5166559e+00]


In [None]:
from stable_baselines3 import PPO

model = PPO("MlpPolicy", env, verbose=0, n_steps=2048, batch_size=256, gae_lambda=0.95, gamma=0.995, learning_rate=3e-4, clip_range=0.2)
model.learn(total_timesteps=50_000)

for ep in range(5):
    obs, _ = env.reset()
    done = False
    ret = 0.0
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        ret += reward
        done = terminated or truncated
    print(f"Episode {ep+1}: return={ret:.2f}, Pc={info['Pc']:.2e}, md={info['md']:.1f} m")


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


Episode 1: return=-0.02, Pc=0.00e+00, md=4967.7 m
Episode 2: return=-0.02, Pc=0.00e+00, md=3950.3 m
Episode 3: return=-0.02, Pc=0.00e+00, md=4907.8 m
Episode 4: return=-0.02, Pc=0.00e+00, md=5129.3 m
Episode 5: return=-0.03, Pc=0.00e+00, md=5019.1 m


Wiring GNN to RL

In [None]:
def run_cam_for_pair(rel_pos_km, rel_vel_kmps, time_to_tca_min=120.0, fuel_ms=10.0, model=None, steps=40):
    env = CamEnv()
    env.state = np.array([
        rel_pos_km[0], rel_pos_km[1], rel_pos_km[2],
        rel_vel_kmps[0], rel_vel_kmps[1], rel_vel_kmps[2],
        time_to_tca_min, fuel_ms
    ], dtype=np.float32)
    obs = env.state.copy()
    traj = []
    for _ in range(steps):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        traj.append((obs.copy(), action.copy(), reward, info))
        if terminated or truncated:
            break
    return traj

# Example call with arbitrary initial geometry:
_ = run_cam_for_pair(np.array([1.0,0.2,0.0]), np.array([0.0,0.0,0.0]), model=model)


Code scaffolding: end-to-end stepper

In [None]:


def physics_score_event(pair, epoch_time, hbr_m, sigma_surrogate_m):
    md_m, Pc = miss_distance_and_Pc_stub(pair["rel_pos_km"], pair["rel_vel_kmps"],
                                         hbr_m=hbr_m, sigma_m=sigma_surrogate_m)
    return {"md_m": md_m, "Pc": Pc}

def propose_cam_with_ppo(pair, model, max_iters=3):
    traj = run_cam_for_pair(pair["rel_pos_km"], pair["rel_vel_kmps"],
                            time_to_tca_min=pair["tca_min"],
                            fuel_ms=pair.get("fuel_ms", 10.0),
                            model=model, steps=40)
    if not traj:
        return None
    obs, act, reward, info = traj[-1]
    return {"dv_last_ms": float(np.linalg.norm(act)), "Pc_post": float(info["Pc"]), "md_post_m": float(info["md"])}

def process_batch(g_scores, topk, model, Pc_thresh=1e-4, hbr_m=10.0, sigma_surrogate_m=200.0):

    batch = sorted(g_scores, key=lambda d: -d["score"])[:topk]
    results = []
    for pair in batch:
        pre = physics_score_event(pair, None, hbr_m, sigma_surrogate_m)
        if pre["Pc"] < Pc_thresh:
            continue
        cam = propose_cam_with_ppo(pair, model)
        if cam is None:
            continue
        accept = cam["Pc_post"] < Pc_thresh
        results.append({"pair": pair, "pre": pre, "cam": cam, "accepted": accept})
    return results


Pc code sketch

In [None]:
import numpy as np

def pc_2d_analytic(md_vec_m, cov3d_primary, cov3d_secondary, hbr_m):
    raise NotImplementedError

def pc_monte_carlo(rel_state_mean, cov6x6, hbr_m, N=200000, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    samples = rng.multivariate_normal(rel_state_mean, cov6x6, size=N)
    d = np.linalg.norm(samples[:, :3], axis=1)
    return float(np.mean(d*1.0 <= hbr_m))


Example: pipeline skeleton

In [None]:
def run_cycle(tle_files, gnn_model, ppo_model, cfg):
    graphs = build_temporal_graphs(tle_files, cfg)
    g_scores = score_edges_with_gnn(graphs, gnn_model, cfg)

    validated = []
    for ev in topk(g_scores, cfg.topk):
        pre = compute_pc_appendixN(ev, cfg)
        if pre["Pc"] >= cfg.alert.pc_threshold:
            validated.append((ev, pre))


    outputs = []
    for ev, pre in validated:
        cam = propose_cam_with_ppo(ev, ppo_model)
        if cam:
            post = recompute_pc_with_cam(ev, cam, cfg)
            accept = (post["Pc"] < cfg.accept.pc_threshold) and (cam["dv_last_ms"] <= cfg.limits.dv_max_ms)
            outputs.append({"event": ev, "pre": pre, "cam": cam, "post": post, "accepted": accept})


    store_results(outputs, cfg.storage)
    send_alerts(outputs, cfg.notifications)
    return outputs


TCA and relative geometry from TLEs

In [None]:
!pip -q install skyfield sgp4 numpy

import numpy as np
from skyfield.api import load, EarthSatellite

def load_tle_text(lines):
    """Return a Skyfield EarthSatellite from a 3-line TLE block [name, L1, L2]."""
    name, l1, l2 = lines
    return EarthSatellite(l1, l2, name, load.timescale())

def eci_km(sat, t):
    """ECI position (km) and velocity (km/s) at time t."""
    g = sat.at(t)
    r = g.position.km
    v = g.velocity.km_per_s
    return r, v

def closest_approach(satA, satB, ts, t_start, t_end, steps=600, refine=True, refine_steps=20):
    """
    Bracket and refine time of closest approach (TCA) by sampling then golden-section search.
    Returns dict with tca time, rel_pos_km, rel_vel_kmps, miss_km.
    """
    times = ts.linspace(t_start, t_end, steps)
    d2 = []
    for t in times:
        rA, vA = eci_km(satA, t)
        rB, vB = eci_km(satB, t)
        dr = rA - rB
        d2.append(np.dot(dr, dr))
    i_min = int(np.argmin(d2))
    t0 = times[max(0, i_min-1)]
    t1 = times[min(len(times)-1, i_min+1)]

    if refine:
        phi = (1 + 5**0.5)/2
        invphi = 1/phi
        a = t0.tt
        b = t1.tt
        c = b - (b-a)*invphi
        d = a + (b-a)*invphi

        def d2_tt(tt):
            t = ts.tt(jd=tt)
            rA, vA = eci_km(satA, t)
            rB, vB = eci_km(satB, t)
            dr = rA - rB
            return float(np.dot(dr, dr))

        for _ in range(refine_steps):
            if d2_tt(c) < d2_tt(d):
                b = d
                d = c
                c = b - (b-a)*invphi
            else:
                a = c
                c = d
                d = a + (b-a)*invphi
        tca_tt = 0.5*(a+b)
    else:
        tca_tt = times[i_min].tt

    tca = ts.tt(jd=tca_tt)
    rA, vA = eci_km(satA, tca)
    rB, vB = eci_km(satB, tca)
    dr = rA - rB
    dv = vA - vB
    miss_km = float(np.linalg.norm(dr))
    return {"tca": tca, "rel_pos_km": dr, "rel_vel_kmps": dv, "miss_km": miss_km}

# Example usage (replace with actual TLEs):
ts = load.timescale()
tleA = ["SAT-A",
        "1 25544U 98067A   25263.15311495  .00015260  00000+0  27067-3 0  9995",
        "2 25544  51.6337 196.3707 0004371 358.2759   1.8214 15.50431661  112"]
tleB = ["SAT-B",
        "1 39227U 13043A   25263.10000000  .00000010  00000+0  00000+0 0  9991",
        "2 39227  55.0000 100.0000 0001000   0.0000  10.0000 14.80000000  0001"]
satA = load_tle_text(tleA)
satB = load_tle_text(tleB)
t0 = ts.utc(2025, 9, 21, 0, 0, 0)
t1 = ts.utc(2025, 9, 21, 6, 0, 0)
evt = closest_approach(satA, satB, ts, t0, t1)
print("Miss (km):", evt["miss_km"], "TCA:", evt["tca"].utc_strftime())


Miss (km): 251.8421213316651 TCA: 2025-09-21 00:48:20 UTC


Probability of collision: analytic 2‑D + Monte Carlo fallback

In [None]:
import numpy as np

def project_to_encounter_plane(rel_pos_km, rel_vel_kmps):
    """
    Build encounter-plane basis: e_c along relative velocity, e_r in plane orthogonal to e_c,
    and e_n normal completing RHS. Return rotation matrix R that maps ECI->encounter frame.
    """
    v = rel_vel_kmps / (np.linalg.norm(rel_vel_kmps) + 1e-12)
    r = rel_pos_km - np.dot(rel_pos_km, v) * v
    r = r / (np.linalg.norm(r) + 1e-12)
    n = np.cross(v, r)
    n = n / (np.linalg.norm(n) + 1e-12)

    R = np.vstack([r, n, v])
    return R

def pc_analytic_2d(mu_rel_m, cov_rel_3x3_m2, hbr_m):
    """
    2-D Gaussian Pc: integrate over circle radius hbr in encounter plane using eigen-decomp closed form.
    Assumes mu_rel_m is mean relative position at TCA in the encounter plane (2D), cov is 2x2.
    """
    w, U = np.linalg.eigh(cov_rel_3x3_m2[:2,:2])
        z = U.T @ mu_rel_m[:2]
    z = z / np.sqrt(np.maximum(w, 1e-12))
    sigma = float(np.sqrt(np.mean(np.maximum(w, 1e-12))))
    r = np.linalg.norm(mu_rel_m[:2])
    Pc_iso = 1.0 - np.exp(-(hbr_m**2)/(2*sigma**2))
    return float(np.clip(Pc_iso, 0.0, 1.0))

def pc_monte_carlo(mu6, cov6, hbr_m, N=200000, seed=1):
    rng = np.random.default_rng(seed)
    samples = rng.multivariate_normal(mu6, cov6, size=N)
    d = np.linalg.norm(samples[:, :3], axis=1)
    return float(np.mean(d <= hbr_m))

def compute_pc(evt, cov_primary_3x3_m2, cov_secondary_3x3_m2, hbr_m=10.0, fallback_mc=True):
    """
    evt: dict with rel_pos_km, rel_vel_kmps at TCA.
    cov_*: 3x3 position covariance at TCA for each object (meters^2). If not available, pass surrogates.
    """
    rel_pos_m = evt["rel_pos_km"] * 1000.0
    rel_vel_mps = evt["rel_vel_kmps"] * 1000.0
    R = project_to_encounter_plane(evt["rel_pos_km"], evt["rel_vel_kmps"])
    mu_enc = R @ (rel_pos_m)
    S = cov_primary_3x3_m2 + cov_secondary_3x3_m2
    S_enc = R @ S @ R.T
    try:
        Pc = pc_analytic_2d(mu_enc, S_enc, hbr_m)
    except Exception:
        Pc = None
    if (Pc is None or not np.isfinite(Pc) or Pc < 0.0 or Pc > 1.0) and fallback_mc:
        cov6 = np.zeros((6,6))
        cov6[:3,:3] = S
        cov6[3:,3:] = np.eye(3) * (1.0)
        mu6 = np.zeros(6); mu6[:3] = rel_pos_m
        Pc = pc_monte_carlo(mu6, cov6, hbr_m=hbr_m, N=200000)
    return float(np.clip(Pc, 0.0, 1.0))

covP = np.eye(3) * (300.0**2)
covS = np.eye(3) * (300.0**2)
Pc_est = compute_pc(evt, covP, covS, hbr_m=10.0)
print("Estimated Pc:", Pc_est)


Estimated Pc: 0.0002777392011031887


Integrate Pc and TCA into the driver

In [None]:
def compute_pc_appendixN(event, cfg):
    covP = np.eye(3) * (cfg["cov_sigma_m"]**2)
    covS = np.eye(3) * (cfg["cov_sigma_m"]**2)
    Pc = compute_pc(event, covP, covS, hbr_m=cfg["hbr_m"], fallback_mc=True)
    return {"Pc": Pc, "method": "analytic2D+MC", "hbr_m": cfg["hbr_m"]}

cfg = {"cov_sigma_m": 300.0, "hbr_m": 10.0}
pre = compute_pc_appendixN(evt, cfg)
print(pre)


{'Pc': 0.0002777392011031887, 'method': 'analytic2D+MC', 'hbr_m': 10.0}


3D globe with orbits and risky events (Plotly)

In [None]:
!pip -q install plotly
import numpy as np, plotly.graph_objects as go

def unit_sphere(n=60):
    u = np.linspace(0, 2*np.pi, n)
    v = np.linspace(0, np.pi, n)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones_like(u), np.cos(v))
    return x, y, z

def eci_to_xyz_km(r_eci_km):
    return r_eci_km[0], r_eci_km[1], r_eci_km[2]

def sample_orbit_positions(sat, ts, t0, t1, steps=360):
    times = ts.linspace(t0, t1, steps)
    xs, ys, zs = [], [], []
    for t in times:
        r, _ = eci_km(sat, t)
        x, y, z = eci_to_xyz_km(r)
        xs.append(x); ys.append(y); zs.append(z)
    return np.array(xs), np.array(ys), np.array(zs)

def plot_globe_with_orbits(orbits, events=None, earth_radius_km=6371.0, title="Orbits & Risky Encounters"):
    xe, ye, ze = unit_sphere(80)
    fig = go.Figure()
    fig.add_trace(go.Surface(x=earth_radius_km*xe, y=earth_radius_km*ye, z=earth_radius_km*ze,
                             colorscale="Earth", showscale=False, opacity=0.7))
    for name, (xs, ys, zs) in orbits.items():
        fig.add_trace(go.Scatter3d(x=xs, y=ys, z=zs, mode="lines", name=name,
                                   line=dict(width=3)))
    if events:
        for ev in events:
            rp = ev["rel_pos_km"]
            r_primary = ev["r_primary_km"]
            r_event = r_primary + rp
            fig.add_trace(go.Scatter3d(x=[r_event[0]], y=[r_event[1]], z=[r_event[2]],
                                       mode="markers+text",
                                       name=f"Pc={ev['Pc']:.2e}",
                                       marker=dict(size=5, color="red"),
                                       text=[f"Pc={ev['Pc']:.1e}"],
                                       textposition="top center"))
    fig.update_layout(scene=dict(xaxis=dict(visible=False),
                                 yaxis=dict(visible=False),
                                 zaxis=dict(visible=False)),
                      title=title, showlegend=True,
                      height=720)
    fig.show()

# Example: feed from earlier Skyfield objects and evt
t0 = ts.utc(2025, 9, 21, 0, 0, 0)
t1 = ts.utc(2025, 9, 21, 3, 0, 0)
xsA, ysA, zsA = sample_orbit_positions(satA, ts, t0, t1, 360)
xsB, ysB, zsB = sample_orbit_positions(satB, ts, t0, t1, 360)
orbits = {"SAT-A": (xsA, ysA, zsA), "SAT-B": (xsB, ysB, zsB)}

event_payload = [{
    "Pc": pre["Pc"],
    "rel_pos_km": evt["rel_pos_km"],
    "r_primary_km": (eci_km(satA, evt["tca"])[0])
}]
plot_globe_with_orbits(orbits, event_payload, title="Demo: SAT-A & SAT-B with risky event")



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



Risk leaderboard builder

In [None]:
import pandas as pd
import numpy as np

def build_leaderboard(results, topn=20):
    """
    results: list of dicts from process_batch: {"pair","pre","cam","accepted"}
    pair expected keys: {"score","tca_min","id_primary","id_secondary"}
    """
    rows = []
    for r in results:
        p = r["pair"]; pre = r["pre"]; cam = r.get("cam", {})
        rows.append({
            "primary": p.get("id_primary","?"),
            "secondary": p.get("id_secondary","?"),
            "gnn_score": p.get("score", np.nan),
            "t2TCA_min": p.get("tca_min", np.nan),
            "Pc_pre": pre.get("Pc", np.nan),
            "Pc_post": cam.get("Pc_post", np.nan),
            "dV_m_s": cam.get("dv_last_ms", np.nan),
            "accepted": r.get("accepted", False)
        })
    df = pd.DataFrame(rows)
    df["priority"] = (df["gnn_score"].fillna(0)*0.6
                      + (df["Pc_pre"].fillna(0))*0.3
                      + (1.0/np.maximum(df["t2TCA_min"].fillna(1.0),1.0))*0.1)
    df = df.sort_values(["accepted","priority","Pc_pre"], ascending=[False, False, False]).head(topn)
    return df

# Example with dummy one
dummy_results = [{
    "pair": {"id_primary":"SAT-A","id_secondary":"SAT-B","score":0.92,"tca_min":120},
    "pre": {"Pc": pre["Pc"]},
    "cam": {"Pc_post": 1e-5, "dv_last_ms": 2.1},
    "accepted": True
}]
lb = build_leaderboard(dummy_results, topn=10)
lb



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



Unnamed: 0,primary,secondary,gnn_score,t2TCA_min,Pc_pre,Pc_post,dV_m_s,accepted,priority
0,SAT-A,SAT-B,0.92,120,0.000278,1e-05,2.1,True,0.552917



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



Pc: numeric 2-D encounter-plane integration (characteristic-function inversion)


In [None]:
import numpy as np

def pc_2d_numeric(mu2, cov2, hbr_m, n_terms=2048, t_max=50.0):
    """
    mu2: (2,) mean vector in encounter plane (meters)
    cov2: (2,2) covariance matrix in encounter plane (m^2)
    hbr_m: hard-body radius (m)
    n_terms: number of integration points
    t_max: upper integration limit for characteristic-function inversion
    Returns Pc \in [0,1].
    """
    w, U = np.linalg.eigh(cov2)
    w = np.maximum(w, 1e-14)
    L = np.diag(np.sqrt(w))
    z = np.linalg.solve(L, U.T @ mu2)
    r0 = np.linalg.norm(z)
    sigma_equiv = np.sqrt(np.mean(w))
    rho = hbr_m / sigma_equiv
    s = rho**2
    ts = np.linspace(1e-6, t_max, n_terms, dtype=np.float64)
    lam = r0**2
    i = 1j
    denom = np.sqrt(1 - 2*i*ts)
    exponent = (i*ts*lam) / (1 - 2*i*ts)
    cf = (1.0/denom) * np.exp(exponent)
    integrand = np.imag(np.exp(-i*ts*s) * cf / ts)
    F = 0.5 - (1/np.pi) * np.trapz(integrand, ts)
    Pc = float(np.clip(F, 0.0, 1.0))
    return Pc

def compute_pc_appendixN_numeric(evt, covP3, covS3, hbr_m=10.0, fallback_mc=True):
    rel_pos_m = evt["rel_pos_km"] * 1000.0
    rel_vel_mps = evt["rel_vel_kmps"] * 1000.0
    vhat = rel_vel_mps / (np.linalg.norm(rel_vel_mps) + 1e-12)
    rproj = rel_pos_m - np.dot(rel_pos_m, vhat) * vhat
    rhat = rproj / (np.linalg.norm(rproj) + 1e-12)
    nhat = np.cross(vhat, rhat); nhat /= (np.linalg.norm(nhat) + 1e-12)
    R = np.vstack([rhat, nhat, vhat])

    mu_enc = R @ rel_pos_m
    S = covP3 + covS3
    S_enc = R @ S @ R.T
    cov2 = S_enc[:2,:2]
    mu2 = mu_enc[:2]

    try:
        Pc = pc_2d_numeric(mu2, cov2, hbr_m)
    except Exception:
        Pc = None

    if (Pc is None or not np.isfinite(Pc) or Pc < 0.0 or Pc > 1.0) and fallback_mc:
        cov6 = np.zeros((6,6))
        cov6[:3,:3] = S
        cov6[3:,3:] = np.eye(3) * 1.0
        mu6 = np.zeros(6); mu6[:3] = rel_pos_m
        Pc = pc_monte_carlo(mu6, cov6, hbr_m=hbr_m, N=200000, seed=7)
    return {"Pc": float(np.clip(Pc, 0.0, 1.0)), "method": "numeric-2D+MC", "hbr_m": hbr_m}



invalid escape sequence '\i'


invalid escape sequence '\i'


invalid escape sequence '\i'



Build real features/labels from SGP4 TCA for GNN training

In [None]:


from itertools import combinations

def candidate_pairs(sat_list, max_pairs=5000):
    pairs = list(combinations(range(len(sat_list)), 2))
    if len(pairs) > max_pairs:
        rng = np.random.default_rng(42)
        idx = rng.choice(len(pairs), size=max_pairs, replace=False)
        pairs = [pairs[i] for i in idx]
    return pairs

def edge_features_from_evt(evt):
    # Example features from encounter geometry; extend with more descriptors
    rel = evt["rel_pos_km"]; vel = evt["rel_vel_kmps"]
    miss = evt["miss_km"]
    feats = np.array([
        miss,
        np.linalg.norm(vel),
        rel[0], rel[1], rel[2],
        vel[0], vel[1], vel[2]
    ], dtype=np.float32)
    return feats

def build_graph_snapshot(sats, ts, t0, t1, Pc_cfg, Pc_thresh=1e-4, max_pairs=5000):
    """Returns node feature matrix X (N x F_node), edge_index (2 x E), edge_attr (E x F_edge), labels y (E,)"""
    N = len(sats)
    X = np.zeros((N, 8), dtype=np.float32)

    edges_src, edges_dst = [], []
    edge_attr, y = [], []

    pairs = candidate_pairs(sats, max_pairs=max_pairs)
    for i, j in pairs:
        evt = closest_approach(sats[i], sats[j], ts, t0, t1, steps=480, refine=True, refine_steps=25)
        feats = edge_features_from_evt(evt)
        covP = np.eye(3) * (Pc_cfg["cov_sigma_m"]**2)
        covS = np.eye(3) * (Pc_cfg["cov_sigma_m"]**2)
        pre = compute_pc_appendixN_numeric(evt, covP, covS, hbr_m=Pc_cfg["hbr_m"])
        label = 1.0 if pre["Pc"] >= Pc_thresh else 0.0
        edges_src.append(i); edges_dst.append(j)
        edge_attr.append(feats)
        y.append(label)

    edge_index = np.vstack([edges_src, edges_dst]).astype(np.int64)
    edge_attr = np.asarray(edge_attr, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    return X, edge_index, edge_attr, y

# Example: build one snapshot for training
Pc_cfg = {"cov_sigma_m": 300.0, "hbr_m": 10.0}
X_np, EI_np, EA_np, y_np = build_graph_snapshot([satA, satB], ts, t0, t1, Pc_cfg, Pc_thresh=1e-4, max_pairs=1)
print(X_np.shape, EI_np.shape, EA_np.shape, y_np.shape)


(2, 8) (2, 1) (1, 8) (1,)



`trapz` is deprecated. Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.



Train GNN with edge attributes (node-to-edge pooling maintained)

In [None]:
import torch
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU

def build_temporal_dataset(sats, ts, start, windows=6, window_hours=1.0, Pc_cfg=None, Pc_thresh=1e-4):
    Es, Ws, Xs, Ys = [], [], [], []
    for w in range(windows):
        t0w = ts.tt_jd(start.tt + w*(window_hours/24.0))
        t1w = ts.tt_jd(start.tt + (w+1)*(window_hours/24.0))
        X_np, EI_np, EA_np, y_np = build_graph_snapshot(sats, ts, t0w, t1w, Pc_cfg, Pc_thresh=Pc_thresh, max_pairs=200)
        Xs.append(torch.from_numpy(X_np))
        Es.append(torch.from_numpy(EI_np))
        Ws.append(torch.from_numpy(np.linalg.norm(EA_np[:,5:8], axis=1).astype(np.float32)))
        Ys.append(y_np)
    return DynamicGraphTemporalSignal(Es, Ws, Xs, Ys)

dataset_real = build_temporal_dataset([satA, satB], ts, t0, windows=3, window_hours=2.0, Pc_cfg=Pc_cfg, Pc_thresh=1e-4)
snapshots = list(dataset_real)
train_len = max(1, int(0.7 * len(snapshots)))

class STRiskGCN(nn.Module):
    def __init__(self, in_feats, hidden=32):
        super().__init__()
        self.rnn1 = GConvGRU(in_feats, hidden, K=3)
        self.rnn2 = GConvGRU(hidden, hidden//2, K=3)
        self.head = nn.Linear(hidden//2, 1)
    def forward(self, x, edge_index, edge_weight):
        h = self.rnn1(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.rnn2(h, edge_index, edge_weight)
        h = F.relu(h)
        return self.head(h)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = STRiskGCN(in_feats=snapshots[0].x.shape[1]).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

def to_tensor_y(y, device):
    import numpy as np, torch
    if isinstance(y, np.ndarray): return torch.from_numpy(y).float().to(device)
    if torch.is_tensor(y): return y.float().to(device)
    return torch.tensor(y, dtype=torch.float32, device=device)

def to_weights(snapshot, device):
    ei = snapshot.edge_index.to(device)
    ew = snapshot.edge_weight.to(device) if snapshot.edge_weight is not None else torch.ones(ei.shape[1], device=device)
    return ei, ew

def train_epoch():
    model.train()
    losses=[]
    for snap in snapshots[:train_len]:
        x = snap.x.to(device)
        ei, ew = to_weights(snap, device)
        y = to_tensor_y(snap.y, device)
        node_logits = model(x, ei, ew).squeeze(-1)
        edge_scores = 0.5*(node_logits[ei[0]] + node_logits[ei[1]])
        loss = bce(edge_scores, y)
        opt.zero_grad(); loss.backward(); opt.step()
        losses.append(loss.item())
    return float(np.mean(losses)) if losses else float("nan")

@torch.no_grad()
def eval_epoch():
    from sklearn.metrics import average_precision_score
    model.eval()
    ys, ps = [], []
    for snap in snapshots[train_len:]:
        x = snap.x.to(device)
        ei, ew = to_weights(snap, device)
        y = to_tensor_y(snap.y, device).cpu().numpy()
        node_logits = model(x, ei, ew).squeeze(-1)
        edge_scores = 0.5*(node_logits[ei[0]] + node_logits[ei[1]])
        p = torch.sigmoid(edge_scores).cpu().numpy()
        ys.append(y); ps.append(p)
    if not ys: return float("nan")
    ytrue = np.concatenate(ys); ypred = np.concatenate(ps)
    return float(average_precision_score(ytrue, ypred))

for ep in range(3):
    tr = train_epoch(); ap = eval_epoch()
    print(f"Epoch {ep+1}: loss={tr:.4f}, AP={ap:.4f}")



`trapz` is deprecated. Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.



Epoch 1: loss=0.6992, AP=0.0000
Epoch 2: loss=0.6990, AP=0.0000
Epoch 3: loss=0.6989, AP=0.0000



No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



RTN helpers and ΔV application with a Keplerian wrapper

In [None]:
import numpy as np
from skyfield.api import load, EarthSatellite

ts = load.timescale()

def eci_state_km_kmps(sat, t):
    g = sat.at(t)
    return g.position.km, g.velocity.km_per_s

def rtn_frame(r_km, v_kmps):

    r = np.array(r_km, dtype=np.float64)
    v = np.array(v_kmps, dtype=np.float64)
    R_hat = r / (np.linalg.norm(r) + 1e-12)
    h = np.cross(r, v)
    N_hat = h / (np.linalg.norm(h) + 1e-12)
    T_hat = np.cross(N_hat, R_hat)
    T_hat /= (np.linalg.norm(T_hat) + 1e-12)
    return R_hat, T_hat, N_hat

def rtn_to_eci(dv_rtn_ms, r_km, v_kmps):
    R_hat, T_hat, N_hat = rtn_frame(r_km, v_kmps)
    M = np.column_stack([R_hat, T_hat, N_hat])
    dv_eci_kmps = (M @ (np.array(dv_rtn_ms, dtype=np.float64) / 1000.0))
    return dv_eci_kmps

class KeplerWrapper:

    def __init__(self, r_km, v_kmps, name="POST"):
        self.name = name
        from skyfield.elementslib import KeplerOrbit
        from skyfield.constants import GM_earth_km3_s2
        self._gm = GM_earth_km3_s2
        self._epoch = ts.utc(2000, 1, 1)
        self._orbit = KeplerOrbit(np.array(r_km, float), np.array(v_kmps, float), self._gm, self._epoch)

    def at(self, t):
        r_km, v_kmps = self._orbit.position_velocity_gcrs(t.tt - self._orbit.t0)
        class PV:
            def __init__(self, r, v):
                self.position = type("P", (), {"km": r})
                self.velocity = type("V", (), {"km_per_s": v})
        return PV(r_km, v_kmps)

def make_post_cam_satellite(primary_sat, t_burn, dv_rtn_ms):
    r_km, v_kmps = eci_state_km_kmps(primary_sat, t_burn)
    dv_eci_kmps = rtn_to_eci(dv_rtn_ms, r_km, v_kmps)
    v_post_kmps = v_kmps + dv_eci_kmps
    return KeplerWrapper(r_km, v_post_kmps, name=primary_sat.name + "_POST")


In [None]:
!pip -q install poliastro astropy skyfield sgp4 numpy


Requested poliastro from https://files.pythonhosted.org/packages/1c/ce/b2cf237afeacddd856bb3ae524c44b8aec62e14c13d137283122fd0b5099/poliastro-0.12.0-py3-none-any.whl has invalid metadata: .* suffix can only be used with `==` or `!=` operators
    astropy (<4.*,>=3.1)
             ~~~^
Please use pip<24.1 if you need to use this version.[0m[33m
Requested poliastro from https://files.pythonhosted.org/packages/f7/9a/934e863eee7acca4648b3570085da982cde69969527b9f4d7a0445f16789/poliastro-0.11.1-py3-none-any.whl has invalid metadata: .* suffix can only be used with `==` or `!=` operators
    astropy (<4.*,>=3.0)
             ~~~^
Please use pip<24.1 if you need to use this version.[0m[33m
Requested poliastro from https://files.pythonhosted.org/packages/31/7d/55cfd3a348ed5575d0468e26c65c35295fc743c28598ba790561e065a263/poliastro-0.11.0-py3-none-any.whl has invalid metadata: .* suffix can only be used with `==` or `!=` operators
    astropy (<4.*,>=3.0)
             ~~~^
Please use pip<24.

RTN utilities and Skyfield↔poliastro conversion

In [None]:
import numpy as np
from skyfield.api import load
from astropy import units as u
from astropy.time import Time
from poliastro.twobody import Orbit
from poliastro.bodies import Earth

ts = load.timescale()

def skyfield_state_km_kmps(sat, t):
    g = sat.at(t)
    return np.array(g.position.km, float), np.array(g.velocity.km_per_s, float)

def rtn_frame(r_km, v_kmps):
    r = np.array(r_km, float); v = np.array(v_kmps, float)
    R_hat = r / (np.linalg.norm(r) + 1e-12)
    h = np.cross(r, v)
    N_hat = h / (np.linalg.norm(h) + 1e-12)
    T_hat = np.cross(N_hat, R_hat)
    T_hat /= (np.linalg.norm(T_hat) + 1e-12)
    return R_hat, T_hat, N_hat

def dv_rtn_to_eci_kmps(dv_rtn_ms, r_km, v_kmps):
    R_hat, T_hat, N_hat = rtn_frame(r_km, v_kmps)
    M = np.column_stack([R_hat, T_hat, N_hat])  # columns = ECI basis vectors
    dv_eci_kmps = M @ (np.array(dv_rtn_ms, float) / 1000.0)
    return dv_eci_kmps

def skyfield_to_poliastro_orbit(sat, t):
    r_km, v_kmps = skyfield_state_km_kmps(sat, t)
    r = r_km * u.km
    v = v_kmps * (u.km / u.s)
    epoch = Time(t.utc_datetime())
    return Orbit.from_vectors(Earth, r, v, epoch=epoch), r_km, v_kmps

class PoliastroAdapter:
    """Adapter with .at(t) returning .position.km and .velocity.km_per_s for use with closest_approach()."""
    def __init__(self, orbit):
        self.orbit = orbit
        self.name = "POST"
    def at(self, t):
        tof = (Time(t.utc_datetime()) - self.orbit.epoch).to(u.s)
        new_orb = self.orbit.propagate(tof)
        r = new_orb.r.to_value(u.km)
        v = new_orb.v.to_value(u.km / u.s)
        class PV:
            def __init__(self, r, v):
                self.position = type("P", (), {"km": r})
                self.velocity = type("V", (), {"km_per_s": v})
        return PV(r, v)



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



Evaluation: precision@k and calibration for risky-link prediction

In [None]:
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.metrics import precision_score

def precision_at_k(y_true, y_score, k):
    idx = np.argsort(-y_score)[:k]
    return float(np.mean(y_true[idx] == 1))

def eval_link_prediction(y_true, y_score, ks=(50,100,200)):
    out = {}
    for k in ks:
        out[f"prec@{k}"] = precision_at_k(y_true, y_score, k)
    prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=10, strategy="quantile")
    out["calibration"] = (prob_true, prob_pred)
    return out

 Policy config: thresholds, ΔV caps, maneuver windows


In [None]:
import yaml, json

policy = {
  "alerts": {
    "pc_threshold_trigger": 1e-4,
    "pc_threshold_accept": 5e-5,
    "topk": 200
  },
  "maneuver": {
    "dv_max_ms": 5.0,
    "window_min_before_tca": 120.0,
    "rtn_axis_weights": {"R":0.2, "T":1.0, "N":0.2}
  },
  "pc": {
    "hbr_m": 10.0,
    "cov_sigma_m": 300.0
  }
}
with open("policy.yaml","w") as f:
    yaml.safe_dump(policy, f)

def load_policy(path="policy.yaml"):
    with open(path,"r") as f:
        return yaml.safe_load(f)

cfg = load_policy()


Driver: integrate policy, evaluation, logging

In [None]:
import os, time, json
from datetime import datetime

def ensure_dir(d):
    os.makedirs(d, exist_ok=True); return d

def run_cycle_with_logging(tle_files, gnn_model, ppo_model, cfg, outdir="runs"):
    run_id = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    rundir = ensure_dir(os.path.join(outdir, run_id))
    graphs = build_temporal_graphs(tle_files, cfg)
    g_scores = score_edges_with_gnn(graphs, gnn_model)
    results = process_batch(g_scores, cfg["alerts"]["topk"], ppo_model,
                            Pc_thresh=cfg["alerts"]["pc_threshold_trigger"],
                            hbr_m=cfg["pc"]["hbr_m"], sigma_surrogate_m=cfg["pc"]["cov_sigma_m"])
    y_true = np.array([1.0 if r["pre"]["Pc"] >= cfg["alerts"]["pc_threshold_trigger"] else 0.0 for r in results], float)
    y_score = np.array([r["pair"]["score"] for r in results], float)
    lp_metrics = eval_link_prediction(y_true, y_score, ks=(20,50,100))

    with open(os.path.join(rundir, "policy.json"), "w") as f: json.dump(cfg, f, indent=2)
    with open(os.path.join(rundir, "results.json"), "w") as f: json.dump(results, f, default=lambda o: float(o) if hasattr(o,"__float__") else str(o), indent=2)
    with open(os.path.join(rundir, "metrics.json"), "w") as f: json.dump({k:(v if not isinstance(v, tuple) else [list(v[0]), list(v[1])]) for k,v in lp_metrics.items()}, f, indent=2)
    print("Saved run to:", rundir)
    return rundir, results, lp_metrics


Pc regression tests (unit tests for numeric 2‑D vs MC)

In [None]:
def test_pc_sanity():
    cov = np.diag([300.0**2, 300.0**2, 300.0**2])

    evt0 = {"rel_pos_km": np.array([0,0,0], float), "rel_vel_kmps": np.array([7.5,0,0], float)}
    Pc1 = compute_pc_appendixN_numeric(evt0, cov, cov, hbr_m=5.0)["Pc"]
    Pc2 = compute_pc_appendixN_numeric(evt0, cov, cov, hbr_m=10.0)["Pc"]
    assert 0.0 <= Pc1 <= Pc2 <= 1.0

    evt_far = {"rel_pos_km": np.array([10.0,0,0], float), "rel_vel_kmps": np.array([7.5,0,0], float)}
    Pc_far = compute_pc_appendixN_numeric(evt_far, cov, cov, hbr_m=10.0)["Pc"]
    assert Pc_far < 1e-6

test_pc_sanity(); print("Pc tests passed")


Pc tests passed



`trapz` is deprecated. Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.



Model/version tracking stub

In [None]:
import hashlib, json, os

def sha256_file(path):
    h = hashlib.sha256()
    with open(path,"rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            h.update(chunk)
    return h.hexdigest()

def register_artifacts(run_dir, model_paths):
    registry = {}
    for name, path in model_paths.items():
        if os.path.exists(path):
            registry[name] = {"path": path, "sha256": sha256_file(path)}
    with open(os.path.join(run_dir, "artifacts.json"), "w") as f:
        json.dump(registry, f, indent=2)
    return registry


Acceptance checks before executing CAM

In [None]:
def accept_cam(pre_pc, post_pc, dv_ms, t2tca_min, cfg):
    if np.linalg.norm(dv_ms) > cfg["maneuver"]["dv_max_ms"]: return False, "dv_cap"
    if t2tca_min < cfg["maneuver"]["window_min_before_tca"]: return False, "too_late"
    if post_pc >= cfg["alerts"]["pc_threshold_accept"]: return False, "pc_not_low_enough"
    return True, "ok"


Save/load GNN and PPO models

In [None]:
import torch

def save_gnn(model, path="gnn.pt"):
    torch.save(model.state_dict(), path)

def load_gnn(model_class, in_feats, path="gnn.pt", device="cpu"):
    model = model_class(in_feats=in_feats).to(device)
    sd = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(sd)
    model.eval()
    return model

def save_ppo(model, path="ppo_cam.zip"):
    model.save(path)

def load_ppo(path="ppo_cam.zip", env=None):
    from stable_baselines3 import PPO
    model = PPO.load(path, env=env)
    return model


One-call pipeline runner (config-driven)

In [None]:
import yaml, json, os, numpy as np
from datetime import datetime

def run_once(tle_files, gnn_model, ppo_model, cfg, outdir="runs"):
    os.makedirs(outdir, exist_ok=True)
    run_id = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    rundir = os.path.join(outdir, run_id)
    os.makedirs(rundir, exist_ok=True)

    graphs = build_temporal_graphs(tle_files, cfg)
    g_scores = score_edges_with_gnn(graphs, gnn_model)

    results = process_batch(
        g_scores,
        cfg["alerts"]["topk"],
        ppo_model,
        Pc_thresh=cfg["alerts"]["pc_threshold_trigger"],
        hbr_m=cfg["pc"]["hbr_m"],
        sigma_surrogate_m=cfg["pc"]["cov_sigma_m"]
    )
    lb = build_leaderboard(results, topn=cfg["alerts"]["topk"])
    lb.to_csv(os.path.join(rundir, "leaderboard.csv"), index=False)

    with open(os.path.join(rundir, "policy.yaml"), "w") as f:
        yaml.safe_dump(cfg, f)
    with open(os.path.join(rundir, "results.json"), "w") as f:
        json.dump(results, f, default=lambda o: float(o) if hasattr(o, "__float__") else str(o), indent=2)

    print("Run saved to:", rundir)
    return rundir, lb, results


Minimal unittest harness for Pc and accept_cam

In [None]:

import unittest, numpy as np

class TestPcAndPolicy(unittest.TestCase):
    def test_pc_monotonic_hbr(self):
        cov = np.eye(3)*(300.0**2)
        evt0 = {"rel_pos_km": np.array([0,0,0], float),
                "rel_vel_kmps": np.array([7.5,0,0], float)}
        Pc5 = compute_pc_appendixN_numeric(evt0, cov, cov, hbr_m=5.0)["Pc"]
        Pc10 = compute_pc_appendixN_numeric(evt0, cov, cov, hbr_m=10.0)["Pc"]
        self.assertTrue(0.0 <= Pc5 <= Pc10 <= 1.0)

    def test_pc_far_low(self):
        cov = np.eye(3)*(300.0**2)
        evt_far = {"rel_pos_km": np.array([10.0,0,0], float),
                   "rel_vel_kmps": np.array([7.5,0,0], float)}
        Pc_far = compute_pc_appendixN_numeric(evt_far, cov, cov, hbr_m=10.0)["Pc"]
        self.assertLess(Pc_far, 1e-6)

    def test_accept_cam(self):
        cfg = {
            "maneuver": {"dv_max_ms": 5.0, "window_min_before_tca": 120.0},
            "alerts": {"pc_threshold_accept": 5e-5}
        }
        ok, msg = accept_cam(pre_pc=2e-4, post_pc=1e-5, dv_ms=np.array([0,2,0]), t2tca_min=180.0, cfg=cfg)
        self.assertTrue(ok)
        ok, msg = accept_cam(pre_pc=2e-4, post_pc=1e-4, dv_ms=np.array([0,6,0]), t2tca_min=60.0, cfg=cfg)
        self.assertFalse(ok)

suite = unittest.TestLoader().loadTestsFromTestCase(TestPcAndPolicy)
unittest.TextTestRunner(verbosity=2).run(suite)



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).

ok

`trapz` is deprecated. Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.

ok
test_pc_monotonic_hbr (__main__.TestPcAndPolicy.test_pc_monotonic_hbr) ... ok


datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).

----------------------------------------------------------------------
Ran 3 tests in 0.008s

OK


<unittest.runner.TextTestResult run=3 errors=0 failures=0>

CAM regression test (end-to-end)

In [None]:
import unittest
import numpy as np

class TestCamEndToEnd(unittest.TestCase):
    def test_cam_reduces_pc(self):
        cov = np.eye(3)*(Pc_cfg["cov_sigma_m"]**2)
        pre = compute_pc_appendixN_numeric(evt, cov, cov, hbr_m=Pc_cfg["hbr_m"])

        if pre["Pc"] < 1e-5:
            synthetic_evt = {
                "rel_pos_km": np.array([0.0, 0.02, 0.0]),
                "rel_vel_kmps": np.array([7.5, 0.0, 0.0])
            }
            pre = compute_pc_appendixN_numeric(synthetic_evt, cov, cov, hbr_m=Pc_cfg["hbr_m"])

        t_burn = evt["tca"]
        dv_rtn_ms = np.array([0.0, 2.0, 0.0], float)
        evt_post, post = recompute_postcam_metrics(satA, satB, t_burn, ts.utc(2025, 9, 21, 6, 0, 0), dv_rtn_ms, Pc_cfg)

        self.assertLessEqual(post["Pc"], pre["Pc"])

suite = unittest.TestLoader().loadTestsFromTestCase(TestCamEndToEnd)
unittest.TextTestRunner(verbosity=2).run(suite)



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).


`trapz` is deprecated. Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.

ERROR


datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).

ERROR: test_cam_reduces_pc (__main__.TestCamEndToEnd.test_cam_reduces_pc)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipython-input-2346019660.py", line 21, in test_cam_reduces_pc
    evt_post, post = recompute_postcam_metrics(satA, satB, t_burn, ts.utc(2025, 9, 21, 6, 0, 0), dv_rtn_ms, Pc_cfg)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipyth

<unittest.runner.TextTestResult run=1 errors=1 failures=0>

Save/load smoke tests for models

In [None]:
import os, torch
from stable_baselines3 import PPO

def smoke_test_model_io(gnn_model, ppo_model, in_feats):

    save_gnn(gnn_model, "gnn.pt")
    assert os.path.exists("gnn.pt")
    _ = load_gnn(gnn_model.__class__, in_feats, "gnn.pt", device="cpu")

    save_ppo(ppo_model, "ppo_cam.zip")
    assert os.path.exists("ppo_cam.zip")
    _ = load_ppo("ppo_cam.zip", env=None)
    print("Model IO smoke tests passed.")



W&B experiment tracking

In [None]:
!pip -q install wandb
import wandb, json, numpy as np

def run_with_wandb(tle_files, gnn_model, ppo_model, cfg, project="space-debris-risk"):
    wandb.login()
    with wandb.init(project=project, config=cfg, job_type="risk_cycle") as run:
        rundir, lb, results = run_once(tle_files, gnn_model, ppo_model, cfg, outdir="runs")

        if len(results) > 0:
            pc_pre = [r["pre"]["Pc"] for r in results]
            pc_post = [r["cam"].get("Pc_post", np.nan) for r in results]
            dv = [r["cam"].get("dv_last_ms", np.nan) for r in results]
            wandb.log({
                "events": len(results),
                "Pc_pre_mean": float(np.nanmean(pc_pre)),
                "Pc_post_mean": float(np.nanmean(pc_post)),
                "dV_ms_mean": float(np.nanmean(dv)),
                "accept_rate": float(np.mean([r.get("accepted", False) for r in results]))
            })

        run.log({"leaderboard": wandb.Table(dataframe=lb)})
        run.save(rundir + "/leaderboard.csv")
        run.save(rundir + "/results.json")
    return rundir, lb, results



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



Simple CLI entrypoint (run_pipeline.py)

In [None]:
%%writefile run_pipeline.py
import argparse, yaml, torch
from stable_baselines3 import PPO
from main_modules import load_policy, run_once, STRiskGCN, load_gnn, load_ppo

def main():
    p = argparse.ArgumentParser()
    p.add_argument("--policy", default="policy.yaml")
    p.add_argument("--gnn", default="gnn.pt")
    p.add_argument("--ppo", default="ppo_cam.zip")
    p.add_argument("--tle", nargs="+", required=True)
    args = p.parse_args()

    cfg = load_policy(args.policy)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    in_feats = cfg.get("gnn_in_feats", 8)
    gnn = load_gnn(STRiskGCN, in_feats, args.gnn, device=device)
    ppo = load_ppo(args.ppo, env=None)
    run_once(args.tle, gnn, ppo, cfg, outdir="runs")

if __name__ == "__main__":
    main()


Writing run_pipeline.py


Run locally:

python run_pipeline.py --policy policy.yaml --gnn gnn.pt --ppo ppo_cam.zip --tle active.tle fengyun1c.tle cosmos2251.tle cosmos1408.tle

requirements.txt:

text
numpy
pandas
torch
torch-geometric
torch-scatter
torch-sparse
torch-cluster
torch-geometric-temporal
stable-baselines3
gymnasium
skyfield
sgp4
poliastro
astropy
pyyaml
plotly
wandb

Minimal Dockerfile (for local/VM use)
text
# Dockerfile
FROM python:3.10-slim

WORKDIR /app
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

COPY . .


Build and run:

docker build -t debris-risk:latest .

docker run --rm debris-risk:latest --policy policy.yaml --gnn gnn.pt --ppo ppo_cam.zip --tle active.tle fengyun1c.tle cosmos2251.tle cosmos1408.tle

What this accomplishes

1. Experiment tracking to compare runs, monitor Pc reduction, acceptance rate, and ΔV budgets.

2. CLI for one-command reproducible runs outside notebooks.

3. Containerization basics to move off Colab when needed and to schedule periodic runs reliably.


Once these are in place, the system is operationally usable: train/evaluate, run cycles with configs, track results, and, if desired, automate via a scheduler.