In [1]:
import time
from datetime import datetime
import os
import sys
import logging # 日志记录用
import gymnasium as gym
import selfhealing_env
from typing import Optional, Tuple, Union 
from stable_baselines3 import DQN
import torch

In [2]:
logging.DEBUG  # Scratch check of the numeric DEBUG level (the cell output shows 10); not used elsewhere in view.

10

In [3]:
def logger_obj(logger_name: str, level: int = logging.DEBUG, verbose: int = 0) -> logging.Logger:
    """
    Return a logger named ``logger_name`` that also writes to a file of that name.

    The logger name doubles as the log-file path (opened in append mode).
    ``logging.getLogger`` returns the same object for the same name, so
    handlers are attached only if an equivalent one is not already present.
    Calling this again (e.g. re-running a notebook cell) therefore does not
    duplicate every log line. This could be moved into a utils module.

    Args:
        logger_name: Logger name and path of the log file.
        level: Level set on the logger (re-applied on every call).
        verbose: If 1, additionally mirror records to ``sys.stdout``.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(level)
    format_string = ("%(asctime)s - %(levelname)s - %(funcName)s (%(lineno)d):  %(message)s")
    datefmt = '%Y-%m-%d %I:%M:%S %p'
    log_format = logging.Formatter(format_string, datefmt)

    # FileHandler stores the absolute path in ``baseFilename``; compare against
    # it so a repeated call does not open a second handler on the same file.
    log_file = os.path.abspath(logger_name)
    has_file_handler = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file
        for handler in logger.handlers
    )
    if not has_file_handler:
        file_handler = logging.FileHandler(logger_name, mode='a')
        file_handler.setFormatter(log_format)
        logger.addHandler(file_handler)

    if verbose == 1:
        # FileHandler subclasses StreamHandler, so exclude it explicitly when
        # looking for an existing console handler.
        has_console_handler = any(
            isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            and handler.stream is sys.stdout
            for handler in logger.handlers
        )
        if not has_console_handler:
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setFormatter(log_format)
            logger.addHandler(console_handler)

    return logger

In [None]:
class TrainManager():
    """
    Train a Stable-Baselines3 DQN agent on ``env`` and log progress to a
    timestamped log file.

    ``train`` runs ``episode_num`` iterations. Each iteration calls
    ``DQN.learn`` for ``time_steps`` steps (the environment's
    ``system_data.NT``) and then ``test_agent``.
    """

    def __init__(self,
                 env:gym.Env,
                 log_output_path:Optional[str],
                 log_level:int = logging.DEBUG,
                 timer_enable: bool = False,
                 episode_num:int = 1000,
                 learning_rate: float = 0.001,
                 buffer_size: int = 100000,
                 learning_starts:int = 200,
                 batch_size:int = 50,
                 tau:float = 0.1,
                 gamma:float = 0.95,
                 exploration_final_eps:float = 0.05,
                 verbose:int = 1,
                 device:torch.device = torch.device("cpu"),
                 seed:Optional[int] = None,
                 policy:Union[str, type] = "MlpPolicy"
                 ):
        """
        Build the DQN agent and the file logger.

        Args:
            env: Environment to train on. Its unwrapped instance must expose
                ``system_data.NT``, used as the number of steps per episode.
            log_output_path: Directory for the log file (created if missing).
                ``None`` means the current working directory.
            log_level: Level of the file logger.
            timer_enable: If True, ``test_agent`` prints a timing banner.
            episode_num: Number of iterations run by ``train``.
            learning_rate, buffer_size, learning_starts, batch_size, tau, gamma,
            exploration_final_eps, verbose, device, seed: Forwarded to ``DQN``.
            policy: DQN policy, which SB3 requires. ``"MlpPolicy"`` suits flat
                ``Box`` observations. NOTE(review): use ``"MultiInputPolicy"``
                if the observation space is a ``Dict``.
        """
        self.env = env
        self.device = device
        self.seed = seed
        self.episode_num = episode_num
        # Read through ``unwrapped``: Gymnasium wrappers do not reliably
        # forward custom attributes such as ``system_data``.
        self.time_steps = self.env.unwrapped.system_data.NT
        self.timer_enable = timer_enable

        self.log_output_path = log_output_path
        self.log_level = log_level

        self.agent = DQN(policy,
                         env = self.env,
                         learning_rate = learning_rate,
                         buffer_size = buffer_size,
                         learning_starts = learning_starts,
                         batch_size = batch_size,
                         tau = tau,
                         gamma = gamma,
                         exploration_final_eps = exploration_final_eps,
                         verbose = verbose,
                         device = self.device,
                         seed = self.seed
                         )
        # Backward-compatible alias for code that used the original attribute name.
        self.Agent = self.agent

        self.set_loggers()

    def set_loggers(self):
        """
        Create ``self.logger``, which writes to ``log<timestamp>.log``.

        The file goes into ``self.log_output_path`` (created if missing), or
        into the current working directory when that is ``None``. The
        minute-resolution timestamp is stored in ``self.dt_string``.
        """
        self.dt_string = datetime.now().strftime("__%Y_%m_%d_%H_%M")
        if self.log_output_path is None:
            # No directory given: save the log in the current working directory.
            log_dir = os.getcwd()
        else:
            log_dir = self.log_output_path
            os.makedirs(log_dir, exist_ok=True)
        log_output_path_name = os.path.join(log_dir, "log" + self.dt_string + ".log")

        # Log for debugging.
        self.logger = logger_obj(logger_name=log_output_path_name, level=self.log_level)
        # TODO: add a separate success logger for performance evaluation.

    def test_agent(self):
        """
        Evaluate the current agent on one test scenario (work in progress).

        Resets the environment with ``Specific_Disturbance=None`` and
        ``self.seed``, then logs the resulting disturbance. The benchmark
        computation and the agent roll-out are not implemented yet.
        """
        test_options = {"Specific_Disturbance": None}
        s0, _ = self.env.reset(options=test_options, seed=self.seed)
        if self.timer_enable:
            print("Begin solving benchmark")
            stime = time.time()  # TODO: use once the benchmark is solved here

        # ================ calculate Benchmark value to normalize the restoration as ratio ==================
        self.logger.info("-------------------Run_test begin--------------------")
        self.logger.info("The testing disturbance is {}".format(self.env.unwrapped.disturbance))
        # As a training benchmark only tie-line actions are used, which matches
        # this agent's action space.
        # TODO: implement the benchmark and the agent evaluation roll-out.

    def train(self, print_interval: int = 10):
        """
        Run ``self.episode_num`` training iterations, testing after each one.

        ``learn`` is called with ``reset_num_timesteps=False`` so SB3's step
        counter accumulates across calls. With the default (True), the counter
        restarts at 0 on every call. Because ``time_steps`` is smaller than
        ``learning_starts``, no gradient update would ever run.

        Args:
            print_interval: Print a progress banner every this many episodes.

        Raises:
            ValueError: If ``print_interval`` is less than 1.
        """
        if print_interval < 1:
            raise ValueError("print_interval must be >= 1")

        flag_convergence = False  # TODO: reserved for a convergence check
        tic = time.perf_counter()  # start clock
        for idx_episode in range(self.episode_num):
            if idx_episode % print_interval == 0:  # controls how often progress is printed
                toc = time.perf_counter()
                print("===================================================")
                print(f"Training time: {toc - tic:0.4f} seconds; Mission {idx_episode:d} of {self.episode_num:d}")
                print("===================================================")
            self.logger.info(f"=============== Mission {idx_episode:d} of {self.episode_num:d} =================")
            # Run the policy in the environment and perform Deep Q-learning.
            # time_steps is the env's NT (noted as 5 in the current setup).
            # NOTE(review): the exploration schedule is relative to each call's
            # total_timesteps, so epsilon reaches exploration_final_eps almost
            # immediately. Consider a single learn(episode_num * time_steps)
            # call with a callback instead.
            # NOTE(review): test_agent resets the same env the agent steps
            # through, while SB3 keeps its own last observation across calls.
            # A separate evaluation env would avoid that mismatch.
            self.agent.learn(total_timesteps=self.time_steps, reset_num_timesteps=False)
            # Test: execute the learned policy on the environment.
            self.test_agent()
            