# Setup and Installations

In [None]:
from typing import Iterator, SupportsIndex

import utils
import visualizations_and_metrics as vm
import env_manager as env_manager
import algo_trainer as algo_trainer

In [None]:
def chain_training(manager: "env_manager.EnvManager", generator: Iterator, algo_agent, running_result_grid: list):
    """Run one training round, warm-starting from the previous round if any.

    On the first round (``running_result_grid`` is empty) the agent is trained
    exactly as the caller configured it. On every later round the best
    configuration of the most recent result is loaded into the agent and the
    environment is re-initialised with the next ``(rou, csv)`` pair taken from
    ``generator`` before training.

    Args:
        manager: Environment manager; only ``initialize_env(rou, csv)`` is
            called on it here.
        generator: Iterator yielding ``(rou, csv)`` pairs. It is advanced once
            per round, except on the first round.
        algo_agent: Trainer exposing ``config``, ``from_dict``,
            ``build_config`` and ``train``.
        running_result_grid: Results of the earlier rounds, oldest first. Each
            entry must provide ``get_best_result``.

    Returns:
        Whatever ``algo_agent.train()`` returns.

    Raises:
        StopIteration: If ``generator`` has no further ``(rou, csv)`` pair.
    """
    if running_result_grid:
        # Take the best config from the previous training round.
        best = running_result_grid[-1].get_best_result(
            metric="env_runners/episode_reward_mean", mode="max", scope="all"
        )

        # Point the environment manager at the next route file.
        rou, csv = next(generator)
        manager.initialize_env(rou, csv)

        # Continue the training with the best config.
        algo_agent.config = algo_agent.from_dict(best.config)
        algo_agent.build_config(flag=True)

    return algo_agent.train()

def training(num_intersection: int, experiment_type: str, algo_config: str, env_config: str, num_training: SupportsIndex) -> list:
    """Train an agent on one intersection for several chained rounds.

    Builds the environment manager and its route generator, initialises the
    environment with the first ``(rou, csv)`` pair, builds the algorithm
    trainer and then calls ``chain_training`` ``num_training`` times, feeding
    each round the results collected so far.

    Args:
        num_intersection: Intersection number; used to build the
            ``intersection_<n>`` id and the ``Nets/intersection_<n>/...`` path.
        experiment_type: Identifier such as ``"DQN_SingleAgent"``. The text
            before the first ``_`` is passed on as the algorithm name, and the
            presence of ``"Multi"`` selects the multi-agent environment.
        algo_config: Path of the algorithm configuration file.
        env_config: Path of the environment configuration file.
        num_training: Number of training rounds to run.

    Returns:
        The non-``None`` results of the rounds, in execution order.
    """
    algo_name = experiment_type.split("_")[0]
    sumo_type = "MultiAgent" if "Multi" in experiment_type else "SingleAgent"

    # Initialize the environment manager
    manager = env_manager.EnvManager(f"{sumo_type}Environment", env_config, intersection_id=f"intersection_{num_intersection}")
    generator = manager.env_generator(f"Nets/intersection_{num_intersection}/route_xml_path_intersection_{num_intersection}.txt", algo_name=algo_name)

    # Initialize the environment with the first route file
    rou, csv = next(generator)
    manager.initialize_env(rou, csv)

    algo_agent = algo_trainer.ALGOTrainer(config_path=algo_config, env_manager=manager, experiment_type=experiment_type)
    algo_agent.build_config()

    final_running_result = []
    for _ in range(num_training):
        # Each round sees the results so far, so it can warm-start from the last one.
        chain_result = chain_training(manager=manager, generator=generator, algo_agent=algo_agent, running_result_grid=final_running_result)
        if chain_result is not None:
            final_running_result.append(chain_result)

    print(f"Finished training for intersection: {num_intersection} with {num_training} training rounds")

    return final_running_result

In [None]:
# ---- Experiment selection ----
# Number of the intersection to train on (used to build "intersection_<n>" names).
num_intersection_to_train = 1

# Experiment identifier. Supported values:
#   PPO_SingleAgent | DQN_SingleAgent | DDQN_SingleAgent
#   PPO_MultiAgent  | DQN_MultiAgent  | DDQN_MultiAgent
experiment_type = "DQN_SingleAgent"

# Number of training cycles to run back to back.
num_training_cycles = 2

# ---- Configuration files ----
env_config_file_path = "env_config.json"
dqn_config_file_path = "dqn_config.json"
ppo_config_file_path = "ppo_config.json"

In [None]:
# Use the PPO config for PPO_* experiments; every other experiment type keeps
# using the DQN config, as before.
# NOTE(review): confirm that DDQN_* runs are meant to share dqn_config.json.
algo_config_file_path = ppo_config_file_path if experiment_type.startswith("PPO") else dqn_config_file_path

# Keep the returned list: the evaluation cell further down reads `results[-1]`.
# training() prints the "Finished training ..." message itself, so it is not
# repeated here.
results = training(
    num_intersection=num_intersection_to_train,
    experiment_type=experiment_type,
    algo_config=algo_config_file_path,
    env_config=env_config_file_path,
    num_training=num_training_cycles,
)

In [None]:
# Notify when done
# Import datetime directly instead of reaching into env_manager's namespace,
# which only works as long as that module happens to import it.
from datetime import datetime

experiment_date = datetime.now().strftime("%m.%d-%H:%M:%S")  # local time, e.g. "03.17-14:05:09"
message = f'Training for intersection {num_intersection_to_train} with {experiment_type} and {num_training_cycles} from {experiment_date} is done!'
recipient = 'eviatar109@icloud.com'  # Replace with your iCloud email
utils.send_imessage(message, recipient)

In [None]:
from ray.rllib.algorithms.algorithm import Algorithm

# NOTE: `results` must hold the list returned by training(...); it is not
# assigned inside this cell.
result = results[-1]

# Pick the trial with the highest maximum episode reward of the last round.
best_result = result.get_best_result("env_runners/episode_reward_max", "max")
checkpoint_path = best_result.checkpoint.path
print(f'Best checkpoint path: {checkpoint_path}')

# Load the Algorithm from the checkpoint. It is restored only once: a second
# restore would build an identical instance and leave the first one unused.
algo = Algorithm.from_checkpoint(checkpoint_path)

# Work on a copy of the restored configuration and hand it back to the algorithm.
new_config = algo.config.copy()
new_config["evaluation_duration"] = 2  # Define as many evaluation episodes as you want
algo.config = new_config

# Evaluate the Algorithm
eval_results = algo.evaluate()
print(eval_results)

In [None]:
# Presumably collects the parameters of the finished runs and writes them out;
# the behaviour lives in utils.extract_and_write_all_params — TODO confirm there.
utils.extract_and_write_all_params()