In [1]:
import os
import numpy as np
from stable_baselines3 import DQN
from src.environment import AircraftDisruptionEnv
from scripts.visualizations import StatePlotter
from scripts.utils import load_scenario_data
from src.config import *
import re
import torch
import time
import ipywidgets as widgets
from IPython.display import display, clear_output, Image as IPImage
from io import BytesIO
import matplotlib.pyplot as plt

# Load the model and run inference
def _figure_to_image(fig):
    """Render a matplotlib figure to an embedded PNG image and close the figure."""
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    img = IPImage(data=buf.read(), format='png', embed=True)
    plt.close(fig)  # Close the figure to prevent automatic display
    return img


def _print_debug_state(env):
    """Dump the environment's bookkeeping structures to stdout."""
    print("Flights Dict:")
    print(env.flights_dict)
    print("Alt Aircraft Dict:")
    print(env.alt_aircraft_dict)
    print("Swapped Flights:")
    print(env.swapped_flights)
    print("Environment Delayed Flights:")
    print(env.environment_delayed_flights)
    print("Cancelled Flights:")
    print(env.penalized_cancelled_flights)
    print("Unavailabilities:")
    print(env.alt_aircraft_dict)
    print("Uncertain Breakdowns:")
    for key, value in env.uncertain_breakdowns.items():
        print(f"{key}: {value}")
    print("Current Breakdowns:")
    print(env.current_breakdowns)
    print("")


def _render_state(state_plotter, env, env_type, reward, action, total_reward):
    """Plot the environment's *current* state and return it as an image.

    The plotter's dictionaries are re-pointed at the environment's
    dictionaries before plotting, and every plotted value is read from
    ``env`` at call time so the frame never shows stale data.
    """
    state_plotter.alt_aircraft_dict = env.alt_aircraft_dict
    state_plotter.flights_dict = env.flights_dict
    state_plotter.rotations_dict = env.rotations_dict

    fig = state_plotter.plot_state(
        env.flights_dict,
        env.swapped_flights,
        env.environment_delayed_flights,
        env.penalized_cancelled_flights,
        env.current_datetime,
        title_appendix=env_type,
        show_plot=False,
        reward_and_action=(reward, env.map_index_to_action(action), total_reward),
    )
    return _figure_to_image(fig)


def run_inference_dqn(model_path, scenario_folder, env_type, seed):
    """Run a trained DQN greedily on one scenario and show the episode in a slider.

    At each step the Q-values of the loaded model are masked with the
    observation's ``action_mask`` and the best remaining action is taken.
    A plot of the state is captured before every step and once more after
    the episode ends; the frames are shown through an ipywidgets slider.

    Args:
        model_path: Path of the saved model, passed to ``DQN.load``.
        scenario_folder: Folder passed to ``load_scenario_data``.
        env_type: Forwarded to ``AircraftDisruptionEnv`` and used as the
            plot title appendix.
        seed: Seed applied to the global NumPy and torch RNGs.

    Returns:
        Tuple ``(total_reward, step_num)``: the accumulated reward and the
        number of environment steps taken.

    Raises:
        ValueError: If the observation has no ``action_mask`` entry, or if
            the selected action is masked out (e.g. every action is invalid).
    """
    # Load the scenario data
    data_dict = load_scenario_data(scenario_folder)

    # Initialize the environment from the scenario's components
    env = AircraftDisruptionEnv(
        data_dict['aircraft'],
        data_dict['flights'],
        data_dict['rotations'],
        data_dict['alt_aircraft'],
        data_dict['config'],
        env_type=env_type
    )

    # Load the trained model and set the environment
    model = DQN.load(model_path)
    model.set_env(env)

    # Set model to evaluation mode
    model.policy.set_training_mode(False)
    model.exploration_rate = 0.0

    # Set random seed for reproducibility
    # NOTE(review): env.reset() below is called without a seed, so only the
    # global NumPy/torch RNGs are seeded here - confirm the environment draws
    # its randomness from those.
    np.random.seed(seed)
    torch.manual_seed(seed)

    print(f"seed: {seed}")

    # Create StatePlotter object for visualizing the environment state
    state_plotter = StatePlotter(
        aircraft_dict=env.aircraft_dict,
        flights_dict=env.flights_dict,
        rotations_dict=env.rotations_dict,
        alt_aircraft_dict=env.alt_aircraft_dict,
        start_datetime=env.start_datetime,
        end_datetime=env.end_datetime,
        uncertain_breakdowns=env.uncertain_breakdowns,
    )

    # Reset the environment for inference
    obs, _ = env.reset()
    done_flag = False
    total_reward = 0
    step_num = 0
    max_steps = 1000  # Set a maximum number of steps to prevent infinite loops

    # Placeholders shown on the very first frame, before any action is taken
    reward = 0
    action = 0

    # List to collect images
    plots = []

    while not done_flag and step_num < max_steps:
        print(f"Step {step_num}:")

        if DEBUG_MODE_VISUALIZATION:
            _print_debug_state(env)

        # Capture the state as it is before this step's action
        plots.append(
            _render_state(state_plotter, env, env_type, reward, action, total_reward)
        )

        # Convert observation to float32
        obs = {key: np.array(value, dtype=np.float32) for key, value in obs.items()}

        # Get the action mask from the observation
        action_mask = obs.get('action_mask')
        if action_mask is None:
            raise ValueError("Action mask is missing in the observation!")

        # Get the Q-values; no gradients are needed for inference
        with torch.no_grad():
            obs_tensor = model.policy.obs_to_tensor(obs)[0]
            q_values = model.policy.q_net(obs_tensor).cpu().numpy().squeeze()

        # Mask invalid actions by setting their Q-values to -inf
        masked_q_values = q_values.copy()
        masked_q_values[action_mask == 0] = -np.inf

        # Predict the action using the masked Q-values
        action = np.argmax(masked_q_values)

        # Verify if the action is valid (fails when every action is masked,
        # because argmax over all -inf falls back to index 0)
        if action_mask[action] == 0:
            raise ValueError(f"Invalid action selected by the model: {action}")

        # Take action in the environment
        obs, reward, terminated, truncated, info = env.step(action)
        # Accumulate the reward
        total_reward += reward

        action_mapped = env.map_index_to_action(action)
        print("action index:")
        print(action)
        print("action mapped:")
        print(action_mapped)
        print(f"Action taken: {action_mapped}, Reward: {reward}")

        # Combine terminated and truncated flags
        done_flag = terminated or truncated

        step_num += 1

    print("================================================")
    print("Final state:")

    # Plot the final state from the environment *after* the last step, so the
    # frame matches the reward/action it is labelled with
    plots.append(
        _render_state(state_plotter, env, env_type, reward, action, total_reward)
    )

    print(f"Total Reward: {total_reward}")
    print(f"Total Steps: {step_num}")

    # Create an interactive slider to display the plots
    output = widgets.Output()
    slider = widgets.IntSlider(
        value=0, min=0, max=len(plots)-1, step=1, description='Step:'
    )

    def update_plot(index):
        with output:
            clear_output(wait=True)
            display(plots[index])

    slider.observe(lambda change: update_plot(change['new']), names='value')

    # Display the initial plot
    update_plot(0)

    # Display the slider and output
    display(slider, output)

    return total_reward, step_num

# Driver: pick a model checkpoint and a scenario, then run inference.
latest = True
env_type = "myopic"
# env_type = "proactive"

MODELS_DIR = "../trained_models/dqn"

if latest:
    # Select the highest-numbered checkpoint named "<env_type>_3ac-<N>.zip".
    # Anchored matching skips stray files (wrong extension, extra suffixes)
    # that would otherwise break the integer parse or yield a bogus path.
    checkpoint_pattern = re.compile(rf"{re.escape(env_type)}_3ac-(\d+)\.zip")
    checkpoints = [
        (int(m.group(1)), m.group(0))
        for m in map(checkpoint_pattern.fullmatch, os.listdir(MODELS_DIR))
        if m
    ]
    if not checkpoints:
        raise FileNotFoundError(
            f"No '{env_type}_3ac-<N>.zip' checkpoints found in {MODELS_DIR}"
        )
    # Use the real filename so zero-padded numbers still resolve to a file
    _, latest_checkpoint = max(checkpoints)
    MODEL_PATH = f"{MODELS_DIR}/{latest_checkpoint}"
else:
    MODEL_PATH = f"{MODELS_DIR}/_perfect_{env_type}_3ac-2.zip"

print(f"Model Path: {MODEL_PATH}")

# seed = 42
# Time-based seed: every run differs; the value is printed by
# run_inference_dqn so a run can be reproduced by hard-coding it here.
seed = int(time.time())


# PROACTIVE EXAMPLE
SCENARIO_FOLDER = "../data/Locked/alpha/Scenario_01"

# NOT WORKING
# MODEL_PATH = "../trained_models/dqn/myopic_3ac-16.zip"

# ACTUALLY WORKING
# MODEL_PATH = "../trained_models/dqn/myopic_3ac-1.zip"

# Extract the env_type using regex
# (currently informational only: the override below is commented out, so the
# env_type chosen at the top of this cell is the one that is used)
match = re.search(r'/(myopic|proactive)_', MODEL_PATH)
# env_type = match.group(1) if match else None

print(f"Environment Type: {env_type}")

# Verify folder and model exist
if not os.path.exists(SCENARIO_FOLDER):
    raise FileNotFoundError(f"Scenario folder not found: {SCENARIO_FOLDER}")

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

# Run the fixed inference loop (left as the cell's last expression so the
# notebook echoes the returned (total_reward, step_num) tuple)
run_inference_dqn(MODEL_PATH, SCENARIO_FOLDER, env_type, seed)


Model Path: ../trained_models/dqn/myopic_3ac-21.zip
Environment Type: myopic
seed: 1732805446
Step 0:

Reward for action: flight 6, aircraft 0
  +0 for resolving 0 conflicts (excluding cancellations): set()
  -0.0 for delays (0 minutes)
  -1000 penalty for cancelled flights
  -0 for inaction with conflicts
  +60.5 bonus for proactive action (605.0 minutes ahead)
  -60.0 penalty for time passed
_______________
-999.5 total reward for action: flight 6, aircraft 0
action index:
24
action mapped:
(6, 0)
Action taken: (6, 0), Reward: -999.5
Step 1:

Reward for action: flight 2, aircraft 0
  +0 for resolving 0 conflicts (excluding cancellations): set()
  -0.0 for delays (0 minutes)
  -2000 penalty for cancelled flights
  -0 for inaction with conflicts
  +19.700000000000003 bonus for proactive action (197.0 minutes ahead)
  -120.0 penalty for time passed
_______________
-2100.3 total reward for action: flight 2, aircraft 0
action index:
8
action mapped:
(2, 0)
Action taken: (2, 0), Reward: -2

IntSlider(value=0, description='Step:', max=3)

Output()

(6739.8, 3)