In [None]:
import os
import pandas as pd
from typing import List, Tuple
import numpy as np
import matplotlib
import json
from train_utils import OffPolicyTrainer
from world_cup_env import WorldCupEnv
from agent import DoubleDQNAgent

# Enlarge matplotlib's default font; matplotlib.rc updates the global rcParams,
# so every figure created after this cell picks up the 16-pt size.
font = dict(size=16)
matplotlib.rc("font", **font)

In [None]:
# --- Input data ---------------------------------------------------------------
# Directory of per-workload CSV files, and the two columns read from each file.
DATASET_PATH = "../dataset"
INDEX_FIELD = "timestamp"
DATA_FIELD = "num_request"
# Directory holding one "<workload_name>.json" candidate change-point file per workload.
CPD_CANDIDATE_ROOT = "../05_binseg_series/binseg_cpd_candidate"

# --- Environment window sizes (passed to WorldCupEnv) --------------------------
# NOTE(review): presumably look-back / prediction horizons in time steps — confirm.
N_LOOKBACK = 4
N_PREDICT = 2

# --- Output directories --------------------------------------------------------
SAVED_REWARD_ROOT = "saved_reward"  # per-workload reward-per-episode JSON files
STATE_DICT_ROOT = "state_dict"      # per-workload agent weight files (.pth)

In [None]:
def get_data_file_list(dataset_path: str, suffix: str = "") -> List[str]:
    """Return the sorted names of the data files directly inside ``dataset_path``.

    ``os.listdir`` yields entries in arbitrary, filesystem-dependent order and
    also includes sub-directories and hidden entries (e.g. ``.ipynb_checkpoints``,
    ``.DS_Store``) that ``pd.read_csv`` cannot parse. Sorting makes the order in
    which workloads are visited reproducible. Filtering keeps only regular,
    non-hidden files.

    Args:
        dataset_path: Directory to scan (not recursive).
        suffix: If non-empty, only names ending with this suffix (e.g. ``".csv"``)
            are returned. The default ``""`` applies no suffix filter.

    Returns:
        Bare file names (not full paths), sorted lexicographically.
    """
    return sorted(
        name
        for name in os.listdir(dataset_path)
        if not name.startswith(".")
        and name.endswith(suffix)
        and os.path.isfile(os.path.join(dataset_path, name))
    )

In [None]:
def read_dataset(csv_path: str, index_field: str, data_field: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load two columns of a CSV file as NumPy arrays.

    Args:
        csv_path: Path of the CSV file to read.
        index_field: Name of the column returned first.
        data_field: Name of the column returned second.

    Returns:
        ``(index_values, data_values)``: each column converted with
        ``Series.to_numpy()``, in file row order.

    Raises:
        KeyError: if either column is missing from the file.
    """
    frame = pd.read_csv(csv_path)
    index_values = frame[index_field].to_numpy()
    data_values = frame[data_field].to_numpy()
    return index_values, data_values

In [None]:
def read_candidate_cpds(path: str) -> List[int]:
    """Load the candidate change points stored as JSON at ``path``.

    Args:
        path: Path of a JSON file, e.g. ``<CPD_CANDIDATE_ROOT>/<workload>.json``.

    Returns:
        The decoded JSON value. It is not validated here;
        NOTE(review): presumably a list of integer positions into the workload
        series passed to WorldCupEnv — confirm against the producer notebook.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

In [None]:
# Workload names (file name without extension) excluded from training.
# Every date listed falls on a Saturday or Sunday (1998-07-04 is a Saturday);
# NOTE(review): presumably weekend traffic is left out on purpose — confirm.
workload_to_skip_list = [
    "workload_1998-06-13",
    "workload_1998-06-14",
    "workload_1998-06-20",
    "workload_1998-06-21",
    "workload_1998-06-27",
    "workload_1998-06-28",
    "workload_1998-07-04",
]

In [None]:
# Train one DoubleDQN agent over every workload file in turn. For each workload,
# save its reward-per-episode curve (JSON) and the agent's weights (.pth).

# Divisor applied to the raw request counts before differencing.
# NOTE(review): 20000.0 looks like a hand-picked scale for num_request — confirm.
WORKLOAD_SCALE = 20000.0

# exist_ok=True replaces the check-then-create pattern (no race, no-op if present).
os.makedirs(SAVED_REWARD_ROOT, exist_ok=True)
os.makedirs(STATE_DICT_ROOT, exist_ok=True)

# os.listdir order is filesystem-dependent. Sort it so the order in which the
# shared agent is trained on workloads is reproducible.
data_file_list = sorted(get_data_file_list(DATASET_PATH))
state_list = None
action_list = None
env = None
# A single agent/trainer pair is reused for every workload.
# NOTE(review): this presumably carries learned weights from one workload to the
# next — confirm trainer.train() does not reset the agent. The meaning of the
# argument 2 (presumably the action count) is defined in agent.py.
# TODO(review): no random seed is set, so training runs are not reproducible.
agent = DoubleDQNAgent(2)
trainer = OffPolicyTrainer(
    env,
    agent,
    num_episodes=150,
    replay_buffer_size=128,
    batch_size=32,
    discount_factor=0.9,
    epsilon_start=0.5,
    epsilon_end=0.1,
    epsilon_step=20,
    learning_rate_start=1e-3,
    learning_rate_end=1e-4,
    learning_rate_step=100,
    tau=0.05,
)
for file_name in data_file_list:
    # splitext strips only the final extension; split(".")[0] would truncate
    # any name containing additional dots.
    workload_name = os.path.splitext(file_name)[0]
    if workload_name in workload_to_skip_list:
        continue
    print("read %s" % (file_name))
    np_index, np_data = read_dataset(os.path.join(DATASET_PATH, file_name), INDEX_FIELD, DATA_FIELD)
    np_data = np_data / WORKLOAD_SCALE
    # First differences of the scaled series, shaped as an (n - 1, 1) column.
    workload_diff = np.diff(np_data).reshape((-1, 1))
    candidate_cpds = read_candidate_cpds(os.path.join(CPD_CANDIDATE_ROOT, workload_name + ".json"))
    env = WorldCupEnv(workload_diff, candidate_cpds, N_LOOKBACK, N_PREDICT)
    trainer.set_env(env)
    _, reward_per_episode = trainer.train()
    state_list, action_list = trainer.eval()
    # NOTE(review): assumes reward_per_episode holds plain Python numbers
    # (numpy scalars are not JSON-serializable) — confirm in train_utils.
    with open(os.path.join(SAVED_REWARD_ROOT, workload_name + ".json"), "w", encoding="utf-8") as f:
        json.dump(reward_per_episode, f, indent=4)
    agent.save(os.path.join(STATE_DICT_ROOT, workload_name + ".pth"))