In [3]:
import os
import numpy as np
from stable_baselines3 import DQN
from src.environment import AircraftDisruptionEnv
from scripts.visualizations import StatePlotter
from scripts.utils import load_scenario_data
from src.config import *
import re
import torch
import time
import ipywidgets as widgets
from IPython.display import display, clear_output, Image as IPImage
from io import BytesIO
import matplotlib.pyplot as plt

from scripts.utils import NumpyEncoder
from scripts.logger import *

from scripts.logger import create_new_id, log_inference_metadata, get_config_variables, find_corresponding_training_id, convert_to_serializable
import src.config as config

from scripts.utils import load_json, get_training_metadata, check_conflicts_between_training_and_current_config

# Load the model and run inference
def run_inference_dqn_single(model_path, scenario_folder, env_type, seed):
    """
    Runs inference on a single scenario and logs detailed results.

    Args:
        model_path (str): Path to the trained model.
        scenario_folder (str): Path to the scenario folder.
        env_type (str): Type of environment ("myopic" or "proactive").
        seed (int): Seed for reproducibility.
        inference_id (str): Unique ID for the inference session.
    """

    # Load the scenario data
    data_dict = load_scenario_data(scenario_folder)
    aircraft_dict = data_dict['aircraft']
    flights_dict = data_dict['flights']
    rotations_dict = data_dict['rotations']
    alt_aircraft_dict = data_dict['alt_aircraft']
    config_dict = data_dict['config']

    # Initialize the environment
    env = AircraftDisruptionEnv(
        aircraft_dict, 
        flights_dict, 
        rotations_dict, 
        alt_aircraft_dict, 
        config_dict,
        env_type=env_type
    )

    # Load the trained model and configure
    model = DQN.load(model_path)
    model.set_env(env)
    model.policy.set_training_mode(False)
    model.exploration_rate = 0.0

    # Set random seed for reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)

    print(f"seed: {seed}")

    # Create StatePlotter object for visualizing the environment state
    state_plotter = StatePlotter(
        aircraft_dict=env.aircraft_dict,
        flights_dict=env.flights_dict,
        rotations_dict=env.rotations_dict,
        alt_aircraft_dict=env.alt_aircraft_dict,
        start_datetime=env.start_datetime,
        end_datetime=env.end_datetime,
        uncertain_breakdowns=env.uncertain_breakdowns,
    )

    # Reset the environment for inference
    obs, _ = env.reset()
    done_flag = False
    total_reward = 0
    step_num = 0
    max_steps = 1000  # Set a maximum number of steps to prevent infinite loops

    # List to collect images
    plots = []

    # Scenario-level data
    scenario_log = {
        "scenario_folder": scenario_folder,
        "env_type": env_type,
        "seed": seed,
        "total_reward": 0,
        "steps": [],
    }


    while not done_flag and step_num < max_steps:
        # Visualize the current state
        print(f"Step {step_num}:")

        # Extract necessary information from the environment for plotting
        swapped_flights = env.swapped_flights
        environment_delayed_flights = env.environment_delayed_flights
        current_datetime = env.current_datetime

        # Retrieve the updated dictionaries from the environment
        updated_flights_dict = env.flights_dict
        updated_rotations_dict = env.rotations_dict
        updated_alt_aircraft_dict = env.alt_aircraft_dict
        cancelled_flights = env.penalized_cancelled_flights

        if DEBUG_MODE_VISUALIZATION:
            print("Flights Dict:")
            print(updated_flights_dict)
            print("Alt Aircraft Dict:")
            print(updated_alt_aircraft_dict)
            print("Swapped Flights:")
            print(swapped_flights)
            print("Environment Delayed Flights:")
            print(environment_delayed_flights)
            print("Cancelled Flights:")
            print(cancelled_flights)
            print("Unavailabilities:")
            print(env.alt_aircraft_dict)
            print("Uncertain Breakdowns:")
            for key, value in env.uncertain_breakdowns.items():
                print(f"{key}: {value}")
            print("Current Breakdowns:")
            print(env.current_breakdowns)
            print("")

        # Update the StatePlotter's dictionaries with the updated ones
        state_plotter.alt_aircraft_dict = updated_alt_aircraft_dict
        state_plotter.flights_dict = updated_flights_dict
        state_plotter.rotations_dict = updated_rotations_dict

        if 'reward' not in locals():
            reward = 0
            action = 0
        # Collect the plot as an image
        fig = state_plotter.plot_state(
            updated_flights_dict, 
            swapped_flights, 
            environment_delayed_flights, 
            cancelled_flights, 
            current_datetime, 
            title_appendix=env_type,
            show_plot=False,
            reward_and_action=(reward, env.map_index_to_action(action), total_reward)
        )
        # Convert the figure to an image buffer
        buf = BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        img = IPImage(data=buf.read(), format='png', embed=True)
        plots.append(img)
        plt.close(fig)  # Close the figure to prevent automatic display

        # Get the action mask from the environment
        action_mask = obs['action_mask']

        # Convert observation to float32
        obs = {key: np.array(value, dtype=np.float32) for key, value in obs.items()}

        # Get the action mask from the observation
        action_mask = obs.get('action_mask', None)
        if action_mask is None:
            raise ValueError("Action mask is missing in the observation!")

        # Get the Q-values and apply the action mask
        obs_tensor = model.policy.obs_to_tensor(obs)[0]
        q_values = model.policy.q_net(obs_tensor).detach().cpu().numpy().squeeze()

        # Mask invalid actions by setting their Q-values to -inf
        masked_q_values = q_values.copy()
        masked_q_values[action_mask == 0] = -np.inf

        # Predict the action using the masked Q-values
        action = np.argmax(masked_q_values)

        # Verify if the action is valid
        if action_mask[action] == 0:
            raise ValueError(f"Invalid action selected by the model: {action}")

        # Take action in the environment
        obs, reward, terminated, truncated, info = env.step(action)
        # Accumulate the reward
        total_reward += reward

        action_mapped = env.map_index_to_action(action)
        print("action index:")
        print(action)
        print("action mapped:")
        print(action_mapped)
        print(f"Action taken: {action_mapped}, Reward: {reward}")

        # Combine terminated and truncated flags
        done_flag = terminated or truncated

        # Log step data
        step_log = {
            "step_num": step_num,
            "action": action,
            "flight_action": action_mapped[0],
            "aircraft_action": action_mapped[1],
            "reward": reward,
            "total_reward": total_reward,
            "q_values": q_values.tolist(),
            "masked_q_values": masked_q_values.tolist(),
            "action_mask": action_mask.tolist(),
            "done_flag": done_flag,
            "info_after_step": convert_to_serializable(env.info_after_step),  # Capturing all detailed reward components
        }


        scenario_log["steps"].append(step_log)

        step_num += 1

    print("================================================")
    print("Final state:")

    # Plot the final state and collect it
    fig = state_plotter.plot_state(
        updated_flights_dict, 
        swapped_flights, 
        environment_delayed_flights, 
        cancelled_flights, 
        current_datetime, 
        title_appendix=env_type,
        show_plot=False,
        reward_and_action=(reward, env.map_index_to_action(action), total_reward)
    )
    # Convert the figure to an image buffer
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    img = IPImage(data=buf.read(), format='png', embed=True)
    plots.append(img)
    plt.close(fig)  # Close the figure to prevent automatic display

    print(f"Total Reward: {total_reward}")
    print(f"Total Steps: {step_num}")

    # Create an interactive slider to display the plots
    def update_plot(index):
        with output:
            clear_output(wait=True)
            display(plots[index])

    slider = widgets.IntSlider(
        value=0, min=0, max=len(plots)-1, step=1, description='Step:'
    )
    output = widgets.Output()

    slider.observe(lambda change: update_plot(change['new']), names='value')

    # Display the initial plot
    update_plot(0)

    # Display the slider and output
    display(slider, output)
    scenario_log["total_reward"] = total_reward


    return total_reward, step_num, scenario_log


def run_inference_dqn_folder(model_path, scenario_folder, env_type, seed):

    inference_config_variables = get_config_variables(config)

    # Generate unique ID for training
    inference_id = create_new_id("inference")
    print(f"Inference ID: {inference_id}")
    runtime_start_in_seconds = time.time()

    # Get the config variables during training based on the model path and the logs/training/training_{id}.json file
    # first find the corresponding training id from the logs/ids.json file
    print(f"Checking for corresponding training ID for model: {model_path} and environment type: {env_type}")
    training_id = find_corresponding_training_id(model_path, env_type)
    print(f"Training ID: {training_id}")

    training_logs_path = f"logs/training/training_{training_id}.json"
    matching_variables, conflicting_variables = check_conflicts_between_training_and_current_config(training_logs_path, 
                                                                                                    env_type, 
                                                                                                    inference_config_variables)

    if len(conflicting_variables) > 0:
        print(f"Conflicting Variables: {conflicting_variables}")
        raise ValueError("Conflicting variables found between training and current config!")
    elif len(matching_variables) == 0:
        raise ValueError("No matching variables found between training and current config!")
    else:
        print("No conflicting variables found between training and current config!")

    inference_metadata = {
        "inference_id": inference_id,
        "runtime_start_in_seconds": runtime_start_in_seconds,
        "model_path": model_path,
        "scenario_folder": scenario_folder,
        "env_type": env_type,
        "seed": seed,
        **matching_variables,

    }

    # Log the inference metadata
    log_inference_metadata(inference_id, inference_metadata)

    complete_inference_log = {}
    for scenario in os.listdir(scenario_folder):
        scenario_path = os.path.join(scenario_folder, scenario)
        # Verify folder and model exist
        if not os.path.exists(scenario_path):
            raise FileNotFoundError(f"Scenario folder not found: {scenario_path}")

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        total_reward, step_num, scenario_log = run_inference_dqn_single(model_path, scenario_path, env_type, seed)

        complete_inference_log[scenario] = scenario_log

    log_scenario_data(inference_id, complete_inference_log)



latest = True
# env_type = "myopic"
env_type = "proactive"

if latest:
    MODEL_PATH = f"../trained_models/dqn/{env_type}_3ac-{max(int(model.split('-')[1].split('.')[0]) for model in os.listdir('../trained_models/dqn') if model.startswith(f'{env_type}_3ac-'))}.zip"
else:
    MODEL_PATH = f"../trained_models/dqn/_perfect_{env_type}_3ac-2.zip"

print(f"Model Path: {MODEL_PATH}")

# seed = 42
seed = int(time.time())


# PROACTIVE EXAMPLE
SCENARIO_FOLDER = "../data/Locked/alpha/"

# NOT WORKING
# MODEL_PATH = "../trained_models/dqn/myopic_3ac-16.zip"

# ACTUALLY WORKING
# MODEL_PATH = "../trained_models/dqn/myopic_3ac-1.zip"

# Extract the env_type using regex
match = re.search(r'/(myopic|proactive)_', MODEL_PATH)
# env_type = match.group(1) if match else None

print(f"Environment Type: {env_type}")

# Run the fixed inference loop
run_inference_dqn_folder(MODEL_PATH, SCENARIO_FOLDER, env_type, seed)


Model Path: ../trained_models/dqn/proactive_3ac-50.zip
Environment Type: proactive
Inference ID: 0209
Checking for corresponding training ID for model: ../trained_models/dqn/proactive_3ac-50.zip and environment type: proactive
Training ID: 0202
Training Config Variables: {'myopic_or_proactive': 'proactive', 'model_type': 'dqn', 'training_id': '0202', 'MODEL_SAVE_PATH': '../trained_models/dqn/proactive_3ac-50.zip', 'N_EPISODES': 10, 'num_scenarios_training': 10, 'results_dir': '../results/dqn/20241129-11-32', 'CROSS_VAL_FLAG': 1, 'CROSS_VAL_INTERVAL': 0.2, 'np': "<module 'numpy' from '/Users/pieterbecking/Desktop/Boeing-ADM-DRL-Github/.venv/lib/python3.10/site-packages/numpy/__init__.py'>", 'MAX_AIRCRAFT': 3, 'MAX_FLIGHTS_PER_AIRCRAFT': 12, 'ROWS_STATE_SPACE': 4, 'COLUMNS_STATE_SPACE': 39, 'ACTION_SPACE_SIZE': 52, 'DEPARTURE_AFTER_END_RECOVERY': 1, 'BREAKDOWN_PROBABILITY': 0.9, 'BREAKDOWN_DURATION': 197.65962950781957, 'TIMESTEP_HOURS': 1, 'DUMMY_VALUE': -999, 'RESOLVED_CONFLICT_REWARD'

IntSlider(value=0, description='Step:', max=3)

Output()

NameError: name 'log_scenario_data' is not defined