In [None]:
import os
import pandas as pd
from typing import List, Tuple
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import json
import torch
from t_test_cpd import TTestCPDetector
from agent import DoubleDQNAgent
from world_cup_env import WorldCupEnv
from agent import DoubleDQNAgent
from train_utils import OffPolicyTrainer

# Global matplotlib font settings; applied through the rc mechanism so that
# figures created after this cell pick up the larger font size.
font = dict(size=24)
matplotlib.rc("font", **font)

In [None]:
# --- Inputs -----------------------------------------------------------------
# Directory whose entries are iterated as workload CSV files.
DATASET_PATH = "../dataset"
# Names of the two CSV columns extracted by read_dataset().
INDEX_FIELD = "timestamp"
DATA_FIELD = "num_request"
# Directory holding one "<workload_name>.json" candidate change-point file per workload.
CPD_CANDIDATE_ROOT = "../05_binseg_series/binseg_cpd_candidate"
# Directory holding one "<workload_name>.pth" agent checkpoint per workload.
AGENT_STATE_DICT = "../06_rl_warmup/state_dict"

# --- Environment parameters (forwarded to WorldCupEnv) ------------------------
N_LOOKBACK = 4
N_PREDICT = 2

# --- Outputs ----------------------------------------------------------------
# Detected change points are written as JSON here, plots as PDF there.
CPD_RESULT_ROOT = "ocpd_result"
CPD_IMG_ROOT = "ocpd_img"

# Torch device string: GPU when CUDA is available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
def get_data_file_list(dataset_path: str) -> List[str]:
    """Return the names of the entries in ``dataset_path`` in sorted order.

    ``os.listdir`` yields entries in arbitrary order, so the result is sorted
    to make the processing order reproducible across runs and machines.

    Args:
        dataset_path: Directory to list.

    Returns:
        Entry names (not full paths), sorted lexicographically.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. the directory is missing).
    """
    return sorted(os.listdir(dataset_path))

In [None]:
def read_dataset(csv_path: str, index_field: str, data_field: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load a CSV file and return two of its columns as NumPy arrays.

    Args:
        csv_path: Path of the CSV file to read.
        index_field: Name of the column returned as the first array.
        data_field: Name of the column returned as the second array.

    Returns:
        ``(index_values, data_values)``, each obtained with ``Series.to_numpy()``.
    """
    frame = pd.read_csv(csv_path)
    index_values = frame[index_field].to_numpy()
    data_values = frame[data_field].to_numpy()
    return index_values, data_values

In [None]:
def read_candidate_cpds(path: str) -> List[int]:
    """Decode the JSON file at ``path`` and return its content.

    The return annotation declares a list of integers; the decoded value is
    returned as-is without validation.
    """
    with open(path, "r") as handle:
        return json.load(handle)

In [None]:
def ocpd(workload: np.ndarray, is_cpd_near_idx: List, t_max_warmup: int = 10, t_init_warmup: int = 1):
    """Feed ``workload`` to a ``TTestCPDetector`` sample by sample and collect change indices.

    Before each sample is passed to the detector, its warm-up value is managed
    as follows:

    * While not frozen, and the sample index is not flagged, the warm-up value
      is increased by one, capped at ``t_max_warmup``.
    * When a flagged index is reached, the warm-up value is frozen at whatever
      it currently is.
    * When the detector reports a change, the freeze is lifted and the warm-up
      value is reset to ``t_init_warmup``.

    Args:
        workload: Samples passed one at a time to ``detector.predict_next``.
        is_cpd_near_idx: Sample indices at which the warm-up value gets frozen
            (list or array of integers).
        t_max_warmup: Cap for the warm-up value; also passed as the first
            constructor argument of the detector.
        t_init_warmup: Warm-up value set initially and after every reported change.

    Returns:
        List of sample indices for which ``predict_next`` reported a change.
    """
    # The flagged indices used to be searched linearly for every sample; a
    # frozenset gives the same membership answers with O(1) lookups.
    flagged_idx = frozenset(np.asarray(list(is_cpd_near_idx)).ravel().tolist())

    change_idx = []
    # NOTE(review): the second constructor argument (0.10) is presumably the
    # significance level of the t-test -- confirm against TTestCPDetector.
    detector = TTestCPDetector(t_max_warmup, 0.10)
    detector.set_t_warmup(t_init_warmup)
    is_frozen = False  # True while the warm-up value must not grow
    for idx, x in enumerate(workload):
        if idx in flagged_idx:
            is_frozen = True
        elif not is_frozen:
            detector.set_t_warmup(min(t_max_warmup, detector.get_t_warmup() + 1))
        _, is_change = detector.predict_next(x)
        if is_change:
            # A change was reported: restart with the initial warm-up value.
            is_frozen = False
            detector.set_t_warmup(t_init_warmup)
            change_idx.append(idx)
    return change_idx

In [None]:
# Workload names (file name without extension) that the processing loop skips.
workload_to_skip_list = [
    "workload_1998-06-13",
    "workload_1998-06-14",
    "workload_1998-06-20",
    "workload_1998-06-21",
    "workload_1998-06-27",
    "workload_1998-06-28",
    "workload_1998-07-04",
]

In [None]:
# Factor by which the raw series is divided before it is differenced and
# handed to the RL environment.
WORKLOAD_SCALE = 20000.0

# Create the output directories. exist_ok=True replaces the separate
# exists()-then-makedirs() check, which could race with another process.
os.makedirs(CPD_RESULT_ROOT, exist_ok=True)
os.makedirs(CPD_IMG_ROOT, exist_ok=True)

data_file_list = get_data_file_list(DATASET_PATH)
for file_name in data_file_list:
    # Workload name = file name up to the first dot.
    workload_name = file_name.split(".")[0]
    if workload_name in workload_to_skip_list:
        continue
    print("read %s" % (file_name))

    # --- Load the series, its candidate change points and the agent ----------
    np_index, np_data = read_dataset(os.path.join(DATASET_PATH, file_name), INDEX_FIELD, DATA_FIELD)
    candidate_cpds = read_candidate_cpds(os.path.join(CPD_CANDIDATE_ROOT, workload_name+".json"))
    agent = DoubleDQNAgent(2)  # NOTE(review): 2 is presumably the number of actions -- confirm
    agent.load(os.path.join(AGENT_STATE_DICT, workload_name+".pth"))

    # Keep the raw series and a separate scaled copy. The series used to be
    # divided by the scale in place and multiplied back afterwards, which
    # introduces floating-point round-off into the values given to ocpd().
    np_data = np.asarray(np_data, dtype=np.float64)
    workload_scaled = np_data / WORKLOAD_SCALE

    # --- Let the agent label the differenced, scaled series -------------------
    workload_diff = np.diff(workload_scaled).reshape((-1, 1))
    env = WorldCupEnv(workload_diff, candidate_cpds, N_LOOKBACK, N_PREDICT)
    trainer = OffPolicyTrainer(
        env,
        agent,
        num_episodes=150,
        replay_buffer_size=128,
        batch_size=32,
        discount_factor=0.9,
        epsilon_start=0.5,
        epsilon_end=0.1,
        epsilon_step=20,
        learning_rate_start=1e-3,
        learning_rate_end=1e-4,
        learning_rate_step=100,
        tau=0.05,
    )
    trainer.set_env(env)
    state_list, action_list = trainer.eval()

    # Positions where the agent chose action 1, shifted by the look-back length.
    # NOTE(review): the environment runs on the differenced series, which is one
    # sample shorter than np_data -- confirm that +N_LOOKBACK alone aligns these
    # positions with indices into np_data.
    is_cpd_near_idx = np.where(np.array(action_list) == 1)[0]+N_LOOKBACK

    # --- Online change-point detection on the raw series ----------------------
    ocpd_result = ocpd(np_data, is_cpd_near_idx)
    with open(os.path.join(CPD_RESULT_ROOT, workload_name+".json"), "w") as f:
        json.dump(ocpd_result, f, indent=4)

    # --- Plot the series with one dashed vertical line per detected change ----
    fig, ax = plt.subplots()
    fig.set_size_inches(14, 7)
    ax.plot(np_data/10, color="#3F51B5")
    for cp in ocpd_result:
        ax.axvline(x=cp, color="#FF5722", linestyle='--', linewidth=1)
    ax.set_xlabel('time (min)')
    ax.set_ylabel('request number (x10 requests)')
    ax.set_title(workload_name.replace("_", " "))
    ax.grid(True, linestyle="--")
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    fig.savefig(os.path.join(CPD_IMG_ROOT, workload_name+".pdf"))
    # Close this specific figure so figures do not accumulate across iterations.
    plt.close(fig)