In [None]:
import os
import pandas as pd
from typing import List, Tuple
import numpy as np
import matplotlib
import json
from train_utils import OffPolicyTrainer
from world_cup_env import WorldCupEnv
from agent import DoubleDQNAgent

# Enlarge the default font used by every matplotlib figure in this notebook.
font = dict(size=16)
matplotlib.rc("font", **font)

In [None]:
# Directory holding the input CSV files; each file is read by read_dataset().
DATASET_PATH = "signal"
# Window sizes passed to WorldCupEnv. Presumably the number of past steps the
# agent observes and the number of future steps it predicts — confirm against
# WorldCupEnv.
N_LOOKBACK, N_PREDICT = 4, 2

# Output directories: per-file reward histories (.json) and agent weights (.pth).
REWARD_SAVE_ROOT = "rl_reward"
WEIGHT_SAVE_ROOT = "rl_weight"

In [None]:
def get_data_file_list(dataset_path: str) -> List[str]:
    """Return the names of the data files in ``dataset_path``, sorted.

    Only regular, non-hidden files are returned. Sub-directories (e.g.
    ``.ipynb_checkpoints``) and dot-files (e.g. ``.DS_Store``) are skipped
    because they are not datasets and would make ``pd.read_csv`` fail.
    The names are sorted because ``os.listdir``/``os.scandir`` yield entries
    in arbitrary order, and a deterministic order makes sequential training
    reproducible.

    Args:
        dataset_path: Directory to scan.

    Returns:
        Bare file names (not full paths) in lexicographic order; an empty list
        if the directory contains no matching files.

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
    """
    with os.scandir(dataset_path) as entries:
        return sorted(
            entry.name
            for entry in entries
            if entry.is_file() and not entry.name.startswith(".")
        )

In [None]:
def read_dataset(csv_path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load one time series and its change-point labels from a CSV file.

    Args:
        csv_path: Path to a CSV file containing (at least) the columns
            ``signal`` and ``is_change_point``.

    Returns:
        ``(signal, is_change_point)``: the ``signal`` column as a NumPy array
        (dtype inferred by pandas) and the ``is_change_point`` column cast to
        ``np.int32``.

    Raises:
        KeyError: If a required column is missing.
    """
    frame = pd.read_csv(csv_path)
    values = frame["signal"].to_numpy()
    labels = frame["is_change_point"].to_numpy(dtype=np.int32)
    return values, labels

In [None]:
# Train one Double-DQN agent over every dataset file in turn. After each file,
# save that run's per-episode rewards (JSON) and a snapshot of the agent weights.
# NOTE(review): `agent` and `trainer` are shared across files, so each saved
# weight file reflects training on that file AND all files processed before it;
# presumably intentional (continual training) — confirm.
os.makedirs(REWARD_SAVE_ROOT, exist_ok=True)
os.makedirs(WEIGHT_SAVE_ROOT, exist_ok=True)

TRAIN_FRACTION = 0.7  # leading fraction of each series used for training
SIGNAL_SCALE = 10.0   # divisor applied to the raw signal before differencing


def _json_default(obj):
    """Fallback for json.dump: convert NumPy scalars/arrays via ``tolist()``."""
    # Guards against reward histories containing np.float32/np.ndarray values,
    # which the json module cannot serialize natively.
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


state_list = None
action_list = None
env = None  # a real environment is attached per file via trainer.set_env()
agent = DoubleDQNAgent(2)
trainer = OffPolicyTrainer(
    env,
    agent,
    num_episodes=150,
    replay_buffer_size=128,
    batch_size=32,
    discount_factor=0.9,
    epsilon_start=0.5,
    epsilon_end=0.1,
    epsilon_step=20,
    learning_rate_start=1e-3,
    learning_rate_end=1e-4,
    learning_rate_step=100,
    tau=0.05,
)
for file_name in get_data_file_list(DATASET_PATH):
    # splitext strips only the final extension ("a.v2.csv" -> "a.v2"), so files
    # whose names share a prefix up to the first dot no longer overwrite each
    # other's outputs, as they did with file_name.split(".")[0].
    stem = os.path.splitext(file_name)[0]
    signal, is_change_point = read_dataset(os.path.join(DATASET_PATH, file_name))
    # Use only the leading part of the data to train the RL agent; set
    # TRAIN_FRACTION = 1.0 to train on the whole series.
    n_train = int(TRAIN_FRACTION * len(signal))
    signal = signal[:n_train]
    is_change_point = is_change_point[:n_train]
    candidate_cpds = np.where(is_change_point == 1)[0]
    signal = signal / SIGNAL_SCALE
    # First differences as a single-feature column: shape (len(signal) - 1, 1).
    workload_diff = np.diff(signal).reshape((-1, 1))
    env = WorldCupEnv(workload_diff, candidate_cpds, N_LOOKBACK, N_PREDICT)
    trainer.set_env(env)
    _, reward_per_episode = trainer.train()
    # Overwritten every iteration: after the loop these hold the evaluation
    # of the last file only.
    state_list, action_list = trainer.eval()
    with open(os.path.join(REWARD_SAVE_ROOT, stem + ".json"), "w") as f:
        json.dump(reward_per_episode, f, indent=4, default=_json_default)
    agent.save(os.path.join(WEIGHT_SAVE_ROOT, stem + ".pth"))