In [38]:
import sys
import os

# Make the parent of the current working directory importable, so that
# packages living one level above this notebook can be imported below.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Only append when missing: re-running this cell would otherwise add the
# same entry to sys.path again on every execution.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [39]:
from env.extended_states_generator import ExtendedStatesGenerator
import copy

In [41]:
from utils.config import load_yaml_config, merge_configs, dict_to_namespace
import random
import numpy as np
import torch
from env.cades_env import CadesEnv
from stable_baselines3.common.env_checker import check_env

# Paths of the YAML files that make up this experiment's configuration.
default_yaml = "../utils/configs/default.yaml"
problem_3_yaml = "../utils/configs/problem_3.yaml"
trnc_c_yaml = "../utils/configs/experiment_trnc_c.yaml"

# Extend this list to bring in additional configuration files; they are
# loaded one by one and handed to merge_configs in this order.
custom_config_files = [default_yaml, problem_3_yaml, trnc_c_yaml]
configs = list(map(load_yaml_config, custom_config_files))
yaml = merge_configs(*configs)

# Hook for manual overrides: adjust the merged `yaml` object here (for
# instance the seed or max_num_tasks) before it is converted below.

config = dict_to_namespace(yaml)

# Seed every random number source used in this notebook with the same value
# so that repeated runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(config.seed)

# Build the environment and run the stable-baselines3 environment checker on it.
env = CadesEnv(config)
check_env(env)

In [42]:
from models.ppo import PPOModel  # A recurrent model could be loaded here instead

# Checkpoint of the p3_trnc_c early-termination run.
best_model_path = "../mlruns/295592906390268527/34f0c5cb22ce42bb8e74802b3a2bb8ae/artifacts/models/best_model.zip"
model = PPOModel.load(best_model_path, env, config)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [79]:
def change_num_tasks_available(num_tasks, target_env=None, max_critical_tasks=3):
    """Pin an environment's states generator to an exact number of tasks.

    Sets both ``min_num_tasks`` and ``max_num_tasks`` of the generator to
    ``num_tasks`` and derives the critical-task and replica settings from it:
    the number of critical tasks is ``min(num_tasks, max_critical_tasks)`` and
    the number of replicas is the remaining task count floor-divided by the
    number of critical tasks.

    Args:
        num_tasks: Exact number of tasks the generator should produce.
            Must be at least 1.
        target_env: Object exposing a ``states_generator`` attribute. When
            omitted, the notebook-level ``env`` is used.
        max_critical_tasks: Upper bound on the number of critical tasks.
            Must be at least 1. Defaults to 3.

    Raises:
        ValueError: If ``num_tasks`` or ``max_critical_tasks`` is below 1.
    """
    # Validate before touching the generator so a rejected call leaves it
    # unchanged (a zero value would otherwise end in a ZeroDivisionError).
    if num_tasks < 1:
        raise ValueError(f"num_tasks must be at least 1, got {num_tasks}")
    if max_critical_tasks < 1:
        raise ValueError(
            f"max_critical_tasks must be at least 1, got {max_critical_tasks}"
        )

    generator = (env if target_env is None else target_env).states_generator

    critical_tasks = min(num_tasks, max_critical_tasks)
    replicas = (num_tasks - critical_tasks) // critical_tasks

    # Equal min and max force the generator to the requested task count.
    generator.min_num_tasks = num_tasks
    generator.max_num_tasks = num_tasks
    generator.num_critical_tasks = critical_tasks
    generator.num_replicas = replicas
    

In [83]:
num_evaluations = 3
min_eval_tasks = 4  # smallest task count sampled for an evaluation
max_tasks = env.config.max_num_tasks  # fixed task dimension the states are padded to

# NOTE: every iteration reconfigures env.states_generator in place; the
# settings of the last iteration remain active after this cell finishes.
for i in range(num_evaluations):
    print(f"Evaluation {i}:")

    ## Change the number of tasks and nodes
    change_num_tasks_available(random.randint(min_eval_tasks, max_tasks))
    ## Generate states
    states = env.generate_states(training=False)

    ## Pad the states to the maximum number of tasks.
    ## The pad width comes from the length of the array actually returned,
    ## not from the generator's configured size, so states that are already
    ## full-size are left alone and shorter ones are padded exactly up to
    ## max_tasks.
    missing = max_tasks - len(states['tasks'])
    if missing < 0:
        raise ValueError(
            f"Generated {len(states['tasks'])} task slots, "
            f"more than config.max_num_tasks={max_tasks}"
        )
    if missing > 0:
        # Zero-pad at the end of each task-indexed axis.
        states['tasks'] = np.pad(states['tasks'], (0, missing), mode='constant')
        states['critical_mask'] = np.pad(states['critical_mask'], (0, missing), mode='constant')
        states['communications'] = np.pad(
            states['communications'], ((0, missing), (0, missing)), mode='constant'
        )

    ## Evaluate the model
    print("Before:")
    print(states)
    result = model.evaluate(states)
    print("After:")
    print(result)

Evaluation 0:
Before:
{'tasks': array([5, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'num_tasks': 1, 'critical_mask': array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 'nodes': array([11, 10, 10, 10, 12, 10, 10, 12]), 'num_nodes': 8, 'communications': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=uint8), 'num_communications': 0}
After:
{'obs': {'tasks': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 'critical_mask': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 'nodes': array([0.91666667, 0.83333333, 0.83333333, 0.83333333, 0.58333333,
       0.83333333, 0.83333333, 1.        ]), 'communications': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0