In [None]:
import os
import random
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from base_rl.eval_policy import EvalDiscreteStatePolicy
from dynamic_programming.mdp_model import MDPModel
from dynamic_programming.policy import DPPolicy
from envs.env_creator import env_creator
from envs.plot import plot_industrial_benchmark_trajectories
from rmin.train import RMinTrainer
from experiments.offline_experiment_configs import RMinExperimentConfig

In [None]:
# Display / formatting settings.
plt.rcParams["figure.figsize"] = [20, 12]
fixed_digits = 6

# Seed the global Python and NumPy RNGs so re-running the notebook is more
# repeatable. NOTE(review): whether the environment / trainer actually draw from
# these RNGs is not visible here; if they use torch (device='cpu' below hints at
# it), seed torch as well — TODO confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# parameters
total_epochs = 500  # epochs passed to the R-Min solver's train()

# Which quantization models to run, and where all artifacts live.
model_names = ['model_aeq-20bits3']
root_path = 'tmp'

# Paired sweep: entry k of `training_episodes` and entry k of `min_count`
# together define one experiment, so both lists must have the same length.
training_episodes = [10, 100, 1000, 10000]
min_count = [1, 2, 3, 5]

## Configure Experiments (dataset, MDP and policy paths)

In [None]:
# Build one RMinExperimentConfig per (model, training-episode count) pair.
# training_episodes[k] and min_count[k] describe the same run, hence the length check.
if len(training_episodes) != len(min_count):
    raise ValueError('training_episodes and min_count must be of same length')

# One offline dataset per model; reused as each config's dataset_path below.
trajectory_paths = [os.path.join(root_path, "offline_rl_trajectories", model, "rl_dataset.npy") for model in model_names]
# Turns an episode count into a dataset size (number of steps) below.
steps_per_episode = 1000


experiment_configs = []
device = 'cpu'
# NOTE(review): mdp_path and policy_path are keyed by training_episode only. The MDP
# does not depend on r_min, but two runs with the same episode count and different
# min_count values would overwrite each other's policy.pkl.
for model_name, dataset_path in zip(model_names, trajectory_paths):
    for i, training_episode in enumerate(training_episodes):
        experiment_configs.append(
            RMinExperimentConfig(
                model_name=model_name,
                model_path=os.path.join(root_path, 'state_quantization', model_name),
                dataset_path=dataset_path,
                mdp_path=os.path.join(root_path, 'rmin', 'mdp', model_name, f'{training_episode}', 'mdp_model.pkl'),
                policy_path=os.path.join(root_path, 'rmin', model_name, f'{training_episode}', 'policy.pkl'),
                dataset_size=training_episode * steps_per_episode,
                r_min=min_count[i],
            )
        )



## Create MDP Models

In [None]:
from dynamic_programming.mdp_model import create_mdp_models

# Build and save one state-action MDP model per experiment, reading the config's
# dataset and limiting it to the config's dataset_size.
for cfg in experiment_configs:
    create_mdp_models(
        load_path=cfg.dataset_path,
        mdp_save_path=cfg.mdp_path,
        reward_function_type='state_action',
        device=device,
        dataset_size=cfg.dataset_size,
    )

## Train

In [None]:
def train_r_min(mdp_path, policy_save_path, r_min, epochs=None, gamma=0.995):
    """Train an R-Min policy on a saved MDP model and save it as a DPPolicy.

    Args:
        mdp_path: path of the MDP model written by ``create_mdp_models``.
        policy_save_path: destination for the trained ``DPPolicy``.
        r_min: threshold passed to ``RMinTrainer`` as ``min_count`` (compared
            against the model's state-action counts). Must come from the
            experiment config; previously this was read from a leaked global
            loop index, so every run silently used the last ``min_count``.
        epochs: training epochs; ``None`` means the notebook-level ``total_epochs``.
        gamma: discount factor passed to ``RMinTrainer.train``.
    """
    if epochs is None:
        # Resolved at call time so edits to total_epochs are picked up on re-run.
        epochs = total_epochs
    mdp_model = MDPModel.load(mdp_path)
    solver = RMinTrainer(reward_function=mdp_model.reward_function, transition_model=mdp_model.transition_model,
                         count_state_action=mdp_model.count_state_action, min_count=r_min)
    solver.train(epochs=epochs, gamma=gamma)
    trained_policy = DPPolicy(policy_table=solver.get_policy(), state_to_index=mdp_model.state_to_index,
                              index_to_action=mdp_model.index_to_actions)
    trained_policy.save(policy_save_path)


for config in experiment_configs:
    print(f'{config.mdp_path} (r_min={config.r_min})')
    train_r_min(mdp_path=config.mdp_path, policy_save_path=config.policy_path, r_min=config.r_min)




## Evaluate

In [None]:
from benchmarks.policy_benchmarks import PolicyBenchmarks

# Evaluate every trained policy in the environment and benchmark them together.
# `steps_per_episode` is reused from the configuration cell rather than being
# redefined here with a duplicated magic number, so the two cannot drift apart.
eval_epochs = 10  # passed to PolicyBenchmarks as `epochs`

evaluators = []
for config in experiment_configs:
    print(config.__dict__)
    eval_policy = DPPolicy.load(config.policy_path)
    env_kwargs = {'steps_per_episode': steps_per_episode, 'device': device, 'model_path': config.model_path}
    evaluator = EvalDiscreteStatePolicy(policy=eval_policy, env_creator=env_creator, env_kwargs=env_kwargs,
                                        tag=f'{config.model_name}/{config.dataset_size}')
    evaluators.append(evaluator)

policy_benchmarks = PolicyBenchmarks(evaluators=evaluators, epochs=eval_epochs)
policy_benchmarks.benchmark()

In [None]:
# Inspect a single run: the second-to-last entry of policy_benchmarks.evaluators.
inspected_evaluator = policy_benchmarks.evaluators[-2]
first_trajectory_info = inspected_evaluator.eval_trajectories[0]['info']
plot_industrial_benchmark_trajectories(first_trajectory_info)
# Mean over the per-epoch evaluation rewards (displayed as the cell output).
np.mean(inspected_evaluator.eval_rewards_per_epoch)

In [None]:
# Collect the benchmark metrics into `df` (the plotting cell below reads it)
# and display it transposed.
df = pd.DataFrame(data=policy_benchmarks.benchmark_metrics)
df.transpose()

In [None]:
# One labelled bar chart of the benchmark metrics per model.
# (Figure size comes from the rcParams set in the configuration cell.)
for model_name in model_names:
    # Keep only the columns whose label contains the model name. This is a plain
    # substring test (regex=False), the same semantics as the former
    # `np.core.defchararray.find(...) >= 0`, but without the private `np.core`
    # namespace, which is deprecated in NumPy 2.
    column_mask = df.columns.astype(str).str.contains(model_name, regex=False)
    model_df = df.loc[:, column_mask]
    if model_df.empty:
        # plot.bar() raises on a frame with no data; report and move on instead.
        print(f'No benchmark columns match {model_name!r}; skipping plot')
        continue

    fig, ax = plt.subplots()
    model_df.plot.bar(ax=ax)
    ax.set_title(model_name)
    for container in ax.containers:
        ax.bar_label(container)
    plt.show()