# Setup and Installations

In [None]:
from collections.abc import Iterator
from typing import SupportsIndex

import utils
import visualizations_and_metrics as vm
import env_manager as env_manager
import algo_trainer as algo_trainer

In [None]:
def chain_training(
    manager: "env_manager.EnvManager",
    generator: Iterator[tuple],
    algo_agent,
    running_result: list,
):
    """Run one training cycle, continuing from the previous cycle if there is one.

    When ``running_result`` is empty the agent is trained exactly as the
    caller configured it. Otherwise the best result of the most recent cycle
    is looked up, the environment is re-initialized with the next
    ``(rou, csv)`` pair taken from ``generator``, and the agent is rebuilt
    with that best configuration before training.

    Args:
        manager: Environment manager; ``initialize_env(rou, csv)`` is called
            on it when a new pair is drawn.
        generator: Iterator yielding ``(rou, csv)`` pairs (as produced by
            ``EnvManager.env_generator``).
        algo_agent: Trainer object exposing a ``config`` attribute and the
            ``build_config()`` and ``train()`` methods.
        running_result: Results of the previous cycles, oldest first. It is
            only read here; appending the new result is left to the caller.

    Returns:
        Whatever ``algo_agent.train()`` returns.

    Raises:
        StopIteration: If ``generator`` has no more pairs when one is needed.
    """
    if running_result:
        # Best trial of the previous cycle, ranked by the
        # "env_runners/episode_reward_max" metric in "max" mode.
        best = running_result[-1].get_best_result("env_runners/episode_reward_max", "max")

        # Point the environment at the next route file before rebuilding.
        rou, csv = next(generator)
        manager.initialize_env(rou, csv)

        # Carry the best configuration over into the next cycle.
        algo_agent.config = best.config
        algo_agent.build_config()

    return algo_agent.train()

def training(num_intersection: int, experiment_type: str, algo_config: str, env_config: str, num_training: SupportsIndex):
    """Set up the environment and trainer, then run chained training cycles.

    Args:
        num_intersection: Number of the intersection; used to build the
            ``intersection_<n>`` id and the route-list path under ``Nets/``.
        experiment_type: Experiment name such as ``"DQN_SingleAgent"``. The
            part before the first underscore is passed to the environment
            generator as the algorithm name, and the presence of ``"Multi"``
            selects the multi-agent environment.
        algo_config: Path of the algorithm configuration handed to
            ``ALGOTrainer``.
        env_config: Path of the environment configuration handed to
            ``EnvManager``.
        num_training: Number of training cycles to run.

    Returns:
        List with the result of every cycle that returned something other
        than ``None``, in execution order.
    """
    algo_name = experiment_type.split("_")[0]
    sumo_type = "MultiAgent" if "Multi" in experiment_type else "SingleAgent"
    intersection = f"intersection_{num_intersection}"

    # Initialize the environment manager and its route-file generator.
    manager = env_manager.EnvManager(f"{sumo_type}Environment", env_config, intersection_id=intersection)
    generator = manager.env_generator(
        f"Nets/{intersection}/route_xml_path_{intersection}.txt",
        algo_name=algo_name,
    )

    # The first cycle trains on the first pair drawn from the generator.
    rou, csv = next(generator)
    manager.initialize_env(rou, csv)

    algo_agent = algo_trainer.ALGOTrainer(config_path=algo_config, env_manager=manager, experiment_type=experiment_type)
    algo_agent.build_config()

    running_result = []
    for _ in range(num_training):
        chain_result = chain_training(
            manager=manager,
            generator=generator,
            algo_agent=algo_agent,
            running_result=running_result,
        )
        if chain_result is not None:
            running_result.append(chain_result)

    return running_result

In [None]:
# Which intersection to train on.
num_intersection_to_train = 4

# Experiment to run; one of:
#   PPO_SingleAgent | DQN_SingleAgent | DDQN_SingleAgent
#   PPO_MultiAgent  | DQN_MultiAgent  | DDQN_MultiAgent
experiment_type = "DQN_SingleAgent"

# How many training cycles to chain one after another.
num_training_cycles = 1

# Configuration files for the environment and for the PPO / DQN algorithms.
env_config_file_path = "env_config.json"
ppo_config_file_path = "ppo_config.json"
dqn_config_file_path = "dqn_config.json"

In [None]:
# Use the config file that matches the chosen experiment: PPO experiments get
# the PPO file; everything else (DQN / DDQN) keeps using the DQN file.
algo_config_file_path = ppo_config_file_path if experiment_type.startswith("PPO") else dqn_config_file_path

results = training(
    num_intersection=num_intersection_to_train,
    experiment_type=experiment_type,
    algo_config=algo_config_file_path,
    env_config=env_config_file_path,
    num_training=num_training_cycles,
)

In [None]:
# Send an iMessage once the training cell above has finished.
message = "if you got this message, the training is done! send whatsapp to matan"
recipient = "eviatar109@icloud.com"  # put your own iCloud email here
utils.send_imessage(message, recipient)

In [None]:
# Write the collected training results out to a CSV file.
# NOTE(review): cycle_index is hard-coded to 2 here — confirm it matches the
# run configured above.
utils.save_custom_metrics_to_csv(
    results,
    num_intersection_to_train,
    experiment_type,
    cycle_index=2,
)
