## Amazon SageMaker RL Result Evaluation

This notebook evaluates a SageMaker RL training job that has already completed.

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import sagemaker
from misc import get_execution_role, wait_for_s3_object
import ray
from ray.rllib.agents import dqn
import gym
from battery_env_sm import SimpleBattery

from nbutils import  plot_analysis

In [None]:
# Open a SageMaker session and build the S3 URI of its default bucket.
sage_session = sagemaker.session.Session()
s3_bucket = sage_session.default_bucket()
s3_output_path = f"s3://{s3_bucket}/"
print(f"S3 bucket path: {s3_output_path}")

## Get Model Checkpoint

In [None]:
# Name of the training job to evaluate; the cells below use it to build the
# S3 keys and the local /tmp folder. Replace it with your own job name.
job_name = "rl-battery-2021-05-10-09-49-21-387"

In [None]:
# Derive the S3 locations that belong to the training job.
print("Job name: {}".format(job_name))

s3_url = "s3://{}/{}".format(s3_bucket, job_name)

intermediate_folder_key = "{}/output/intermediate/".format(job_name)
intermediate_url = "s3://{}/{}".format(s3_bucket, intermediate_folder_key)

print("S3 job path: {}".format(s3_url))
print("Intermediate folder path: {}".format(intermediate_url))

# Local scratch folder for this job's artifacts. os.makedirs(exist_ok=True)
# keeps the cell re-runnable: shelling out to `mkdir` failed once the folder
# existed, and the failure went unnoticed because os.system's return code
# was ignored.
tmp_dir = "/tmp/{}".format(job_name)
os.makedirs(tmp_dir, exist_ok=True)
print("Create local folder {}".format(tmp_dir))

Download the model checkpoint from S3 into `/tmp`.

In [None]:
import subprocess

# S3 key of the packaged model and the local folder it is unpacked into.
model_tar_key = "{}/output/model.tar.gz".format(job_name)
local_checkpoint_dir = "{}/model".format(tmp_dir)
model_tar_path = "{}/model.tar.gz".format(tmp_dir)

# NOTE(review): presumably blocks until the object exists in S3 and downloads
# it into tmp_dir -- confirm against misc.wait_for_s3_object.
wait_for_s3_object(s3_bucket, model_tar_key, tmp_dir, training_job_name=job_name)

if not os.path.isfile(model_tar_path):
    raise FileNotFoundError("File model.tar.gz not found")

# Create the target folder (idempotent) and unpack the archive. Passing an
# argument list avoids building a shell string, and check=True raises
# CalledProcessError if tar fails instead of silently continuing with an
# empty or partial checkpoint folder.
os.makedirs(local_checkpoint_dir, exist_ok=True)
subprocess.run(
    ["tar", "-xvzf", model_tar_path, "-C", local_checkpoint_dir],
    check=True,
)

print("Checkpoint directory {}".format(local_checkpoint_dir))

# Path handed to agent.restore() in the evaluation section.
checkpoint_path = f"{local_checkpoint_dir}/checkpoint"
print("checkpoint_path", checkpoint_path)

## Evaluation

In [None]:
import ray
from ray import tune
from ray.rllib.agents import dqn
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List
import pandas as pd
import warnings

def get_agent(checkpoint_path, env_config=None):
    """
    Build a DQN trainer on the battery environment and restore a checkpoint.

    Input:
        checkpoint_path: path of the checkpoint passed to ``agent.restore``.
        env_config: optional dict forwarded to ``SimpleBattery``. When omitted,
            the notebook's default configuration (168 steps per episode, local
            NSW 2020 price/demand CSV) is used.

    Returns:
        The restored ``dqn.DQNTrainer`` with exploration switched off.

    Note:
        Ray is shut down and re-initialised in local mode on every call.
    """
    if env_config is None:
        env_config = {"MAX_STEPS_PER_EPISODE": 168, "LOCAL": True, "FILEPATH": "../refdata/PRICE_AND_DEMAND_2020FULL_NSW1.csv"}

    # Start from the DQN defaults and disable exploration for evaluation.
    config = dqn.DEFAULT_CONFIG.copy()
    config["num_workers"] = 1
    config["explore"] = False
    config["evaluation_config"] = {"explore": False}

    # Restart Ray so this function can be called repeatedly from the notebook.
    ray.shutdown()
    ray.init(local_mode=True)

    # The trainer accepts either an env class or the name of a registered env.
    # region: HAHA begin HACK
    # The subclass ignores whatever arguments the trainer passes to the
    # constructor and always builds the environment from `env_config` above.
    warnings.warn('HAHA: Hack to pass custom env_config to DQNTrainer')

    class _SimpleBattery(SimpleBattery):
        def __init__(self, *args, **kwargs):
            super().__init__(env_config=env_config)

    agent = dqn.DQNTrainer(config=config, env=_SimpleBattery)
    # endregion: HAHA end HACK

    # Load trained model
    agent.restore(checkpoint_path)

    return agent


def evaluate_episode(agent, env=None):
    """
    Run evaluation over a single episode.

    Input:
        agent: trained agent exposing ``compute_action(state)``.
        env: optional environment exposing ``reset()``, ``step(action)`` and
            an ``index`` attribute. When omitted, a ``SimpleBattery`` is built
            from the notebook's default configuration.

    Returns:
        A pandas DataFrame with one row per step: the step reward, the running
        total reward, the chosen action, and the eight state values observed
        *before* that action was taken.
    """
    if env is None:
        env_config = {"MAX_STEPS_PER_EPISODE": 168, "LOCAL": True, "FILEPATH": "../refdata/PRICE_AND_DEMAND_2020FULL_NSW1.csv"}
        env = SimpleBattery(env_config)

    evaluation_list: List = []
    done = False
    state = env.reset()
    print(f"Index: {env.index}")
    total_rewards = 0

    while not done:
        action = agent.compute_action(state)
        next_state, reward, done, _ = env.step(action)
        total_rewards += reward
        # list(state) makes the row a flat list whatever sequence type the
        # env returns; `list + tuple` raises TypeError and `list + ndarray`
        # triggers NumPy broadcasting instead of concatenation.
        evaluation_list.append([reward, total_rewards, action] + list(state))
        state = next_state

    df_cols = [
        "reward",
        "total_reward",
        "action",
        "energy",
        "average_energy_cost",
        "market_electric_price",
        "price_t1",
        "price_t2",
        "price_t3",
        "price_t4",
        "price_t5",
    ]
    df_eval = pd.DataFrame(evaluation_list, columns=df_cols)
    return df_eval



## Analyze

In [None]:
# Build the DQN trainer and restore the downloaded checkpoint into it.
agent = get_agent(checkpoint_path)

In [None]:
import numpy as np
# Seed NumPy's global RNG before the rollout.
# NOTE(review): presumably SimpleBattery draws its episode start index from
# np.random, which would make this run repeatable -- confirm in battery_env_sm.
np.random.seed(2)
# Roll out one episode with the restored agent; one DataFrame row per step.
df_eval = evaluate_episode(agent)
# Plot the episode using the notebook helper imported from nbutils.
fig = plot_analysis(df_eval)

In [None]:
# Save the per-step evaluation records to result_dqn.csv in the current
# working directory, omitting the DataFrame index column.
df_eval.to_csv("result_dqn.csv", index=False)