# Simulator

In [8]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import my_nb_path  # isort: skip
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import whatif as wi
from stable_baselines3.common.evaluation import evaluate_policy
from whatif.nbtools import pprint, print  # Enable color outputs when rich is installed.

# from whatif.utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
%%time

################################################################################
# Load dataset
################################################################################
wi_df = wi.read_csv_dataset(wi.sample_dataset_path("chiller"))
wi_df.add_value()

# Speed up training for demo purpose
wi_df = wi_df.iloc[:1000]
tokenizer = wi.AutoTokenizer(wi_df, block_size_row=2)
# display(
#     tokenizer.df.head(2),
#     tokenizer.df_tokenized.head(2),
# )


################################################################################
# Train simulator
################################################################################
# Default hyperparam is located at src/whatif/config.yaml. Alternative you can
# (1) specify your own configuration file using config_dir and config_name, or
# (2) passing in the configuration as parameter config.
# Refer to GPTWrapper for more info.

from habana_frameworks.torch.utils.library_loader import load_habana_module

load_habana_module()

model_dir = "./"
model_name = "model_chiller.pt"
wrapper = wi.GPTWrapper(model_dir, model_name, tokenizer)
wrapper.fit()

[32m2022-05-19 08:29:10.964[0m | [1mINFO    [0m | [36mwhatif.simulator[0m:[36mfit[0m:[36m417[0m - [1m{'sequences': 40, 'epochs': 5, 'batch_size': 512, 'embedding_dim': 512, 'gpt_n_layer': 1, 'gpt_n_head': 1, 'learning_rate': '6e-4', 'num_workers': 4, 'lr_decay': True}[0m


Loading Habana modules from /usr/local/lib64/python3.8/site-packages/habana_frameworks/torch/lib


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

[32m2022-05-19 08:29:38.339[0m | [1mINFO    [0m | [36mwhatif.simulator[0m:[36mfit[0m:[36m445[0m - [1mTraining time in mins: 0.46[0m


CPU times: user 1min 2s, sys: 14.4 s, total: 1min 17s
Wall time: 30.8 s



[1;35mGPT[0m[1m([0m
  [1m([0mtok_emb[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m341[0m, [1;36m512[0m[1m)[0m
  [1m([0mdrop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.1[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mblocks[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mBlock[0m[1m([0m
      [1m([0mln1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m512[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mln2[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m512[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mattn[1m)[0m: [1;35mCausalSelfAttention[0m[1m([0m
        [1m([0mkey[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
     

In [None]:
################################################################################
# Plot
################################################################################
# Evaluate the trained model from a 5-step context over a 50-step horizon,
# first with sample=False and then with sample=True.
for use_sampling in (False, True):
    wrapper.evaluate(context_len=5, sample=use_sampling, horizon=50)

## Get Recommendation



In [None]:
# Wrap the tokenizer and the trained GPT model (`wrapper.model`) into a Simulator,
# then show the first two rows of the tokenized dataframe it works on.
simulator = wi.Simulator(tokenizer, wrapper.model)
simulator.tokenizer.df_tokenized.head(2)

Get a custom context sequence. 

**Note:** The sequence should end with a state, i.e. (s, a, r, ..., s).

In [None]:
# Use the first 7 entries of the tokenized sequence as the context.
# NOTE(review): 7 presumably lands on a state boundary (s, a, r, ..., s) for this
# dataset's column layout — confirm; another dataset may need another length.
custom_context = tokenizer.df_tokenized_seq[:7]
custom_context

### One step sample

`sample` returns a dataframe whose columns are (actions, reward, value, next states) given the
context. The contents of the dataframe are in the original (untokenized) space, approximated.

In [None]:
# One-step sample for `custom_context`; as_token=False requests values in the
# original (untokenized) space rather than tokens.
# NOTE(review): max_size presumably caps the number of returned rows — confirm.
recommendation_df = simulator.sample(custom_context, max_size=10, as_token=False)
recommendation_df

## Build Your Own Planner

If you want to build your own planner, `whatif` provides a few lower-level APIs.

### Get valid actions

`get_valid_actions` returns a dataframe of potential actions (in tokenized form) given the context.

Take a custom context (assumed to always end at the current state) and find up to `max_size` candidate next actions.

In [None]:
# Candidate next actions (in tokenized form) for `custom_context`.
# NOTE(review): max_size presumably caps how many candidates are returned — confirm.
valid_actions = simulator.get_valid_actions(custom_context, max_size=2)
valid_actions

### One step lookahead

`lookahead` returns the reward and next states, given the context and action.

Let's pick an action and simulate the resulting reward and next states. This API does not change the simulator's internal counter or state.

In [None]:
# Simulate a single step from a hand-written context. A dedicated name is used so
# that `custom_context` (consumed by the sample / get_valid_actions cells above)
# is not silently replaced if those cells are re-run after this one.
# NOTE(review): values presumably are token ids — confirm.
lookahead_context = np.array([0, 100])
# Action taken from the row labelled 0 of `valid_actions`.
# NOTE(review): that candidate was proposed for `custom_context`, not for
# `lookahead_context` — confirm this mismatch is intended for the demo.
action_seq = [valid_actions.loc[0, "chiller_configuration"]]
print(f"Given the context: {lookahead_context} and action: {action_seq}\n")

# `lookahead` does not advance the simulator's internal counter or state.
reward, next_states = simulator.lookahead(lookahead_context, action_seq)
print(f"{reward=}")
print(f"{next_states=}")

## Gym

Get a gym compatible simulator using `SimulatorWrapper`.

In [None]:
# Wrap the simulator so it exposes a gym-compatible environment interface.
sim_wrapper = wi.SimulatorWrapper(env=simulator)

Get the mapping from gym actions to the simulator's action encodings. Gym expects actions to be contiguous integers.

In [None]:
# Mapping between gym's integer actions and the simulator's action encodings.
sim_wrapper.gym_action_to_enc

In [None]:
# Reset the environment before stepping; its return value is shown as cell output.
sim_wrapper.reset()

In [None]:
# Take one step with gym action 0 (passed as a one-element list) and show the
# resulting observation and reward. Note: this rebinds `reward`, which the
# lookahead cell above also assigned.
obs, reward, done, info = sim_wrapper.step([0])
obs, reward

## 3rd Party Tools 

Use with 3rd-party packages such as `stable_baselines3`.

In [None]:
%%time

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

model = PPO(MlpPolicy, sim_wrapper, verbose=0)
mean_reward, std_reward = evaluate_policy(model, sim_wrapper, n_eval_episodes=1)

print(f"Mean reward:{mean_reward:.2f} +/- {std_reward:.2f}")