In [1]:
import os
import random

import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import simpy
import torch
import sb3_contrib
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3 import DQN

from LEOEnvironmentRL import initialize, load_route_from_csv  # Use RL version
from HandoverEnvironment import LEOEnv as LEOEnvPPO
from HandoverEnvironment import mask_fn, predict_valid_action
from HandoverEnvironment_DQN import LEOEnv as LEOEnvDQN
from HandoverEnvironment_DQN import predict_valid_action as predict_valid_action_dqn
from HandoverEnvironment_ODT import LEOEnv as LEOEnvODT
from HandoverEnvironment_ODT import predict_valid_action_dt
from ODT import OnlineDecisionTransformer
from LEOEnvironment import LEOEnv as LEOEnvBase

['ARS', 'CrossQ', 'MaskablePPO', 'QRDQN', 'RecurrentPPO', 'TQC', 'TRPO', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', 'ars', 'common', 'crossq', 'file_handler', 'os', 'ppo_mask', 'ppo_recurrent', 'qrdqn', 'tqc', 'trpo', 'version_file']


In [2]:
# Create PPO environment
# Build the handover environment for the constellation/route given in the
# input files and wrap it in ActionMasker so MaskablePPO can query valid actions.
inputParams = pd.read_csv("input.csv")
constellation_name = inputParams['Constellation'][0]
route, route_duration = load_route_from_csv('route.csv', skip_rows=3)
ppo_env = LEOEnvPPO(constellation_name, route)
ppo_env = ActionMasker(ppo_env, mask_fn)

# Reset once and report how many actions are valid in the initial state.
# `action_masks()` is the ActionMasker API (it returns mask_fn(env)); the old
# `hasattr(ppo_env, 'action_mask')` probe looked for an attribute that this
# notebook never uses and so could only ever report 'No mask attr'.
obs, info = ppo_env.reset()
print(f"Initial mask sum: {np.sum(ppo_env.action_masks())}")

Loading flight route from csv
UTF-8 decoding error: 'utf-8' codec can't decode byte 0xa0 in position 163: invalid start byte. Trying latin1 encoding.
[{'Time (EDT)': 'Tue 10:48:57 AM', 'Latitude': '29.9569', 'Longitude': '-95.3369', 'Course': '? 133°', 'kts': '159', 'mph': '183', 'feet': '1,075', 'Rate': '1,922\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KIAH)'}, {'Time (EDT)': 'Tue 10:49:13 AM', 'Latitude': '29.9482', 'Longitude': '-95.3254', 'Course': '? 131°', 'kts': '174', 'mph': '200', 'feet': '1,475', 'Rate': '1,313\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KIAH)'}, {'Time (EDT)': 'Tue 10:49:29 AM', 'Latitude': '29.9394', 'Longitude': '-95.3129', 'Course': '? 128°', 'kts': '194', 'mph': '223', 'feet': '1,775', 'Rate': '984\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KAXH)'}, {'Time (EDT)': 'Tue 10:49:45 AM', 'Latitude': '29.9303', 'Longitude': '-95.2989', 'Course': '? 126°', 'kts': '218', 'mph': '251', 'feet': '2,000', 'R

In [3]:
# Load PPO Agent
# `MaskablePPO.load` is a classmethod that constructs and *returns* a new model.
# Calling it on an already-built instance (`agent.load(path)`) throws the loaded
# model away and leaves `agent` with freshly initialized, untrained weights, so
# the returned model must be assigned. Passing `env` attaches the masked
# environment (SB3 checks its spaces against the saved model).
ppo_agent = MaskablePPO.load("handover_ppo_agent", env=ppo_env)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


<sb3_contrib.ppo_mask.ppo_mask.MaskablePPO at 0x28b638880>

In [4]:
# Create DQN Environment
# Same input files as the PPO environment, but using the DQN variant of the
# handover environment, wrapped in ActionMasker so valid actions can be queried.
inputParams = pd.read_csv("input.csv")
constellation_name = inputParams['Constellation'][0]
route, route_duration = load_route_from_csv('route.csv', skip_rows=3)
dqn_env = LEOEnvDQN(constellation_name, route)
# NOTE(review): mask_fn comes from HandoverEnvironment (the PPO module); this
# assumes it works unchanged on the DQN environment — confirm.
dqn_env = ActionMasker(dqn_env, mask_fn)

# Reset once and report how many actions are valid in the initial state via
# the ActionMasker `action_masks()` API (the old `hasattr(..., 'action_mask')`
# probe could only report 'No mask attr').
obs, info = dqn_env.reset()
print(f"Initial mask sum: {np.sum(dqn_env.action_masks())}")


Loading flight route from csv
UTF-8 decoding error: 'utf-8' codec can't decode byte 0xa0 in position 163: invalid start byte. Trying latin1 encoding.
[{'Time (EDT)': 'Tue 10:48:57 AM', 'Latitude': '29.9569', 'Longitude': '-95.3369', 'Course': '? 133°', 'kts': '159', 'mph': '183', 'feet': '1,075', 'Rate': '1,922\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KIAH)'}, {'Time (EDT)': 'Tue 10:49:13 AM', 'Latitude': '29.9482', 'Longitude': '-95.3254', 'Course': '? 131°', 'kts': '174', 'mph': '200', 'feet': '1,475', 'Rate': '1,313\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KIAH)'}, {'Time (EDT)': 'Tue 10:49:29 AM', 'Latitude': '29.9394', 'Longitude': '-95.3129', 'Course': '? 128°', 'kts': '194', 'mph': '223', 'feet': '1,775', 'Rate': '984\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KAXH)'}, {'Time (EDT)': 'Tue 10:49:45 AM', 'Latitude': '29.9303', 'Longitude': '-95.2989', 'Course': '? 126°', 'kts': '218', 'mph': '251', 'feet': '2,000', 'R

In [5]:
# Load DQN Agent
# `DQN.load` is a classmethod returning a new model; calling it on an instance
# discards the loaded weights and leaves the instance untrained, so assign the
# result. Extra keyword arguments to `load` override saved attributes before the
# model is set up, so `buffer_size=100` keeps the tiny replay buffer the original
# code asked for (presumably to save memory, since this notebook only evaluates).
dqn_agent = DQN.load("handover_dqn_agent", env=dqn_env, buffer_size=100)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


<stable_baselines3.dqn.dqn.DQN at 0x28b8d9430>

In [6]:
# Create ODT Environment
# Same input files as the other environments, using the ODT variant of the
# handover environment, wrapped in ActionMasker so valid actions can be queried.
inputParams = pd.read_csv("input.csv")
constellation_name = inputParams['Constellation'][0]
route, route_duration = load_route_from_csv('route.csv', skip_rows=3)
odt_env = LEOEnvODT(constellation_name, route)
# NOTE(review): mask_fn comes from HandoverEnvironment (the PPO module); this
# assumes it works unchanged on the ODT environment — confirm.
odt_env = ActionMasker(odt_env, mask_fn)

# Reset once and report how many actions are valid in the initial state via
# the ActionMasker `action_masks()` API (the old `hasattr(..., 'action_mask')`
# probe could only report 'No mask attr').
obs, info = odt_env.reset()
print(f"Initial mask sum: {np.sum(odt_env.action_masks())}")

Loading flight route from csv
UTF-8 decoding error: 'utf-8' codec can't decode byte 0xa0 in position 163: invalid start byte. Trying latin1 encoding.
[{'Time (EDT)': 'Tue 10:48:57 AM', 'Latitude': '29.9569', 'Longitude': '-95.3369', 'Course': '? 133°', 'kts': '159', 'mph': '183', 'feet': '1,075', 'Rate': '1,922\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KIAH)'}, {'Time (EDT)': 'Tue 10:49:13 AM', 'Latitude': '29.9482', 'Longitude': '-95.3254', 'Course': '? 131°', 'kts': '174', 'mph': '200', 'feet': '1,475', 'Rate': '1,313\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KIAH)'}, {'Time (EDT)': 'Tue 10:49:29 AM', 'Latitude': '29.9394', 'Longitude': '-95.3129', 'Course': '? 128°', 'kts': '194', 'mph': '223', 'feet': '1,775', 'Rate': '984\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KAXH)'}, {'Time (EDT)': 'Tue 10:49:45 AM', 'Latitude': '29.9303', 'Longitude': '-95.2989', 'Course': '? 126°', 'kts': '218', 'mph': '251', 'feet': '2,000', 'R

In [7]:
# Load ODT Agent
# The network's input/output sizes are taken from the ODT environment's
# observation and action spaces; the remaining values are the architecture
# hyper-parameters the checkpoint is loaded into.
model_path = 'decision_transformer_final.pth'
odt_agent = OnlineDecisionTransformer(
    state_dim=odt_env.observation_space.shape[0],
    action_dim=odt_env.action_space.n,
    max_length=20,
    embed_dim=64,
    num_layers=2,
    target_return=1.0,
)

# NOTE(review): unlike SB3's classmethod `load`, this assumes
# OnlineDecisionTransformer.load restores the weights into this instance — confirm.
odt_agent.load(model_path)

In [None]:
# Initialize the baseline environment from the same input files as the agents.
# Its step() takes no action argument, so it serves as the non-learning reference.
inputParams = pd.read_csv("input.csv")
constellation_name = inputParams['Constellation'].iloc[0]
route, route_duration = load_route_from_csv('route.csv', skip_rows=3)
base_env = LEOEnvBase(constellation_name, route)

# Reset so the baseline is ready to be stepped.
obs, info = base_env.reset()

Loading flight route from csv
UTF-8 decoding error: 'utf-8' codec can't decode byte 0xa0 in position 163: invalid start byte. Trying latin1 encoding.
[{'Time (EDT)': 'Tue 10:48:57 AM', 'Latitude': '29.9569', 'Longitude': '-95.3369', 'Course': '? 133°', 'kts': '159', 'mph': '183', 'feet': '1,075', 'Rate': '1,922\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KIAH)'}, {'Time (EDT)': 'Tue 10:49:13 AM', 'Latitude': '29.9482', 'Longitude': '-95.3254', 'Course': '? 131°', 'kts': '174', 'mph': '200', 'feet': '1,475', 'Rate': '1,313\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KIAH)'}, {'Time (EDT)': 'Tue 10:49:29 AM', 'Latitude': '29.9394', 'Longitude': '-95.3129', 'Course': '? 128°', 'kts': '194', 'mph': '223', 'feet': '1,775', 'Rate': '984\xa0Climbing', 'Reporting Facility': '\xa0FlightAware ADS-B\xa0(KAXH)'}, {'Time (EDT)': 'Tue 10:49:45 AM', 'Latitude': '29.9303', 'Longitude': '-95.2989', 'Course': '? 126°', 'kts': '218', 'mph': '251', 'feet': '2,000', 'R

: 

In [None]:
# Roll out every controller (PPO, DQN, ODT and the action-free baseline) for one
# episode on the same route, recording each observation so the per-step metrics
# can be compared in the plots below.


def _set_eval_mode():
    """Set ``earth.Training = False`` on all four environments (enables saving plots)."""
    base_env.earth.Training = False
    dqn_env.env.earth.Training = False
    ppo_env.env.earth.Training = False
    odt_env.env.earth.Training = False


# Set training to false before reset (as before) *and* again after it:
# reset() appears to re-initialize the Earth model (it prints "Initialized
# Earth"), which would otherwise silently drop the flag.
# NOTE(review): confirm how reset() rebuilds `earth`.
_set_eval_mode()

done_ppo = done_dqn = done_base = done_odt = False
truncated_ppo = truncated_dqn = truncated_base = truncated_odt = False
step_count = 0

obs_ppo, info_ppo = ppo_env.reset()
obs_dqn, info_dqn = dqn_env.reset()
obs_base, info_base = base_env.reset()
obs_odt, info_odt = odt_env.reset()
_set_eval_mode()

obs_ppo_list = []
obs_dqn_list = []
obs_base_list = []
obs_odt_list = []

# Stop as soon as any environment terminates *or* truncates. All four are
# stepped in lock-step, so the observation lists stay the same length.
while not (done_ppo or truncated_ppo or done_dqn or truncated_dqn
           or done_base or truncated_base or done_odt or truncated_odt):
    print(f"\n--- Step {step_count} ---")

    # PPO Agent Step
    mask_ppo = ppo_env.env._get_action_mask()
    print(f"PPO Valid actions: {np.sum(mask_ppo)}")
    action_ppo = predict_valid_action(ppo_agent, obs_ppo, mask_ppo)
    print(f"PPO Action: {action_ppo}, Valid: {mask_ppo[action_ppo]}")
    # Step through the wrapper, like the DQN and ODT environments below.
    obs_ppo, reward_ppo, done_ppo, truncated_ppo, info_ppo = ppo_env.step(action_ppo)
    obs_ppo_list.append(obs_ppo)

    # DQN Agent Step
    mask_dqn = dqn_env.env._get_action_mask()
    print(f"DQN Valid actions: {np.sum(mask_dqn)}")
    action_dqn = predict_valid_action_dqn(dqn_agent, obs_dqn, mask_dqn)
    print(f"DQN Action: {action_dqn}, Valid: {mask_dqn[action_dqn]}")
    obs_dqn, reward_dqn, done_dqn, truncated_dqn, info_dqn = dqn_env.step(action_dqn)
    obs_dqn_list.append(obs_dqn)

    # ODT Agent Step
    mask_odt = odt_env.env._get_action_mask()
    print(f"ODT Valid actions: {np.sum(mask_odt)}")
    action_odt = predict_valid_action_dt(odt_agent, obs_odt, mask_odt)
    print(f"ODT Action: {action_odt}, Valid: {mask_odt[action_odt]}")
    prev_obs_odt = obs_odt
    obs_odt, reward_odt, done_odt, truncated_odt, info_odt = odt_env.step(action_odt)
    # Record the transition with the state the action was chosen from; the old
    # call passed the *new* observation as both state and next_state.
    # NOTE(review): assumes OnlineDecisionTransformer.step takes
    # (state, action, reward, next_state, done) — confirm against ODT.py.
    odt_agent.step(prev_obs_odt, action_odt, reward_odt, obs_odt, done_odt or truncated_odt)
    obs_odt_list.append(obs_odt)

    # Baseline Environment Step (the baseline chooses its own action)
    obs_base, reward_base, done_base, truncated_base, info_base = base_env.step()
    obs_base_list.append(obs_base)

    step_count += 1

Aircraft A-380 initialized at (29.96, -95.34)
Using OneWeb constellation design
Initialized Earth
total divisions in x = 1920
 total divisions in y = 906
 total cells = 1739520
 window of operation (longitudes) = (0, 1920)
 window of operation (latitudes) = (0, 906)

candidates found: 2
Aircraft A-380 initialized at (29.96, -95.34)
Using OneWeb constellation design
Initialized Earth
total divisions in x = 1920
 total divisions in y = 906
 total cells = 1739520
 window of operation (longitudes) = (0, 1920)
 window of operation (latitudes) = (0, 906)

candidates found: 2
Aircraft A-380 initialized at (29.96, -95.34)
Using OneWeb constellation design
Initialized Earth
total divisions in x = 1920
 total divisions in y = 906
 total cells = 1739520
 window of operation (longitudes) = (0, 1920)
 window of operation (latitudes) = (0, 906)
Aircraft A-380 initialized at (29.96, -95.34)
Using OneWeb constellation design
Initialized Earth
total divisions in x = 1920
 total divisions in y = 906
 to

In [None]:
# Stack each run's per-step observations into a 2-D array (rows = steps) so
# metric columns can be sliced with [:, col] in the cells below.
obs_base_list = np.asarray(obs_base_list)
obs_dqn_list = np.asarray(obs_dqn_list)
obs_ppo_list = np.asarray(obs_ppo_list)
obs_odt_list = np.asarray(obs_odt_list)

In [None]:
# Observation column 6, labelled "handovers", for each run over the episode.
# (matplotlib.pyplot is imported as `plt` in the top import cell.)
HANDOVERS_COL = 6

fig, ax = plt.subplots()
ax.plot(obs_base_list[:, HANDOVERS_COL], label='baseline handovers')
ax.plot(obs_ppo_list[:, HANDOVERS_COL], label='ppo handovers')
ax.plot(obs_dqn_list[:, HANDOVERS_COL], label='dqn handovers')
ax.plot(obs_odt_list[:, HANDOVERS_COL], label='odt handovers')
ax.set(title='Total number of handovers across runs', xlabel='Step', ylabel='Handovers')
ax.legend()
plt.show()

In [None]:
# Observation column 7, labelled "allocated bandwidth", for each run over the episode.
BANDWIDTH_COL = 7

fig, ax = plt.subplots()
ax.plot(obs_base_list[:, BANDWIDTH_COL], label='baseline allocated bandwidth')
ax.plot(obs_ppo_list[:, BANDWIDTH_COL], label='ppo allocated bandwidth')
ax.plot(obs_dqn_list[:, BANDWIDTH_COL], label='dqn allocated bandwidth')
ax.plot(obs_odt_list[:, BANDWIDTH_COL], label='odt allocated bandwidth')
ax.set(title='Total allocated bandwidth across runs', xlabel='Step', ylabel='Allocated bandwidth')
ax.legend()
plt.show()

In [None]:
# Running (cumulative) mean of allocation-to-demand (observation column 8) for
# each run over the steps of the episode. Element i is the mean of steps 0..i
# *inclusive*. The previous `[:i]` slice excluded the current step, which made
# element 0 the mean of an empty slice (NaN plus a RuntimeWarning).
ALLOC_TO_DEMAND_COL = 8
n_steps = len(obs_base_list)
assert all(len(run) == n_steps for run in (obs_ppo_list, obs_dqn_list, obs_odt_list)), \
    "all runs must have the same number of steps"
step_counts = np.arange(1, n_steps + 1)


def _running_mean(obs_array):
    """Return the cumulative mean of column ALLOC_TO_DEMAND_COL over the first n_steps rows, as a list."""
    values = np.asarray(obs_array)[:n_steps, ALLOC_TO_DEMAND_COL]
    return (np.cumsum(values) / step_counts).tolist()


avg_allocated_to_demand_base = _running_mean(obs_base_list)
avg_allocated_to_demand_ppo = _running_mean(obs_ppo_list)
avg_allocated_to_demand_dqn = _running_mean(obs_dqn_list)
avg_allocated_to_demand_odt = _running_mean(obs_odt_list)

In [None]:
# Running mean of allocation-to-demand for each run, computed in the previous cell.
fig, ax = plt.subplots()
ax.plot(avg_allocated_to_demand_base, label='baseline average allocation to demand')
ax.plot(avg_allocated_to_demand_ppo, label='ppo average allocation to demand')
ax.plot(avg_allocated_to_demand_dqn, label='dqn average allocation to demand')
ax.plot(avg_allocated_to_demand_odt, label='odt average allocation to demand')
ax.set(title='Average Allocation to Demand Across Runs', xlabel='Step',
       ylabel='Running mean allocation to demand')
ax.legend()
plt.show()

In [None]:
# Per-step allocation-to-demand (observation column 8) for each run, with no
# averaging; the running mean is computed in an earlier cell.
ALLOC_TO_DEMAND_COL = 8
n_steps = len(obs_base_list)
assert all(len(run) == n_steps for run in (obs_ppo_list, obs_dqn_list, obs_odt_list)), \
    "all runs must have the same number of steps"

allocated_to_demand_base = np.asarray(obs_base_list)[:n_steps, ALLOC_TO_DEMAND_COL].tolist()
allocated_to_demand_ppo = np.asarray(obs_ppo_list)[:n_steps, ALLOC_TO_DEMAND_COL].tolist()
allocated_to_demand_dqn = np.asarray(obs_dqn_list)[:n_steps, ALLOC_TO_DEMAND_COL].tolist()
allocated_to_demand_odt = np.asarray(obs_odt_list)[:n_steps, ALLOC_TO_DEMAND_COL].tolist()

In [None]:
# Per-step (not averaged) allocation-to-demand for each run. The legend and title
# previously said "average", which mislabelled this figure as a copy of the
# running-mean plot.
fig, ax = plt.subplots()
ax.plot(allocated_to_demand_base, label='baseline allocation to demand')
ax.plot(allocated_to_demand_ppo, label='ppo allocation to demand')
ax.plot(allocated_to_demand_dqn, label='dqn allocation to demand')
ax.plot(allocated_to_demand_odt, label='odt allocation to demand')
ax.set(title='Per-Step Allocation to Demand Across Runs', xlabel='Step',
       ylabel='Allocation to demand')
ax.legend()
plt.show()