In [None]:
from training_rl.offline_rl.custom_envs.custom_envs_registration import register_grid_envs
import warnings
from training_rl.offline_rl.load_env_variables import load_env_variables
from training_rl.offline_rl.custom_envs.utils import Grid2DInitialConfig
from training_rl.offline_rl.custom_envs.custom_2d_grid_env.obstacles_2D_grid_register import ObstacleTypes
from training_rl.offline_rl.custom_envs.custom_envs_registration import CustomEnv
from training_rl.offline_rl.scripts.visualizations.utils import snapshot_env
from training_rl.offline_rl.custom_envs.custom_envs_registration import RenderMode
from training_rl.offline_rl.custom_envs.utils import InitialConfigCustom2DGridEnvWrapper
from training_rl.offline_rl.behavior_policies.behavior_policy_registry import BehaviorPolicyType
from training_rl.offline_rl.generate_custom_minari_datasets.generate_minari_dataset_grid_envs import \
    create_combined_minari_dataset
from training_rl.offline_rl.offline_policies.offpolicy_rendering import offpolicy_rendering
from training_rl.offline_rl.scripts.visualizations.utils import get_state_action_data_and_policy_grid_distributions
from training_rl.offline_rl.utils import load_buffer_minari, state_action_histogram
from training_rl.offline_rl.offline_trainings.policy_config_data_class import TrainedPolicyConfig
from training_rl.offline_rl.offline_policies.policy_registry import PolicyName
from training_rl.offline_rl.offline_trainings.offline_training import offline_training
import torch
from training_rl.offline_rl.offline_trainings.policy_config_data_class import get_trained_policy_path
import os
from training_rl.offline_rl.offline_trainings.restore_policy_model import restore_trained_offline_policy
import gymnasium as gym


# Project helper; going by its name, it loads the environment variables the package relies on.
load_env_variables()
# Silence all warnings so the notebook output stays readable.
warnings.filterwarnings(action="ignore")
# TODO: the custom grid environments should be registered automatically on import.
register_grid_envs()

# Exercise

**In this notebook we will deal with another important property that a robust offline RL algorithm should fulfill: stitching, i.e. the reuse of different trajectories contained in the data to obtain the best trajectory supported by our dataset.**


The goal will be to reach a target at (7,7) starting from (0,0). We will again use the 8x8 grid environment. Our dataset contains trajectories covering our space of interest but generated for different tasks (note that previously we collected data for the same task). One comes from a suboptimal policy that moves the agent from (0,0) to (7,0), and the other from a deterministic, optimal one (human expert) that brings the agent from (4,0) to (7,7). We obviously have much more data coming from the suboptimal policy than from the expert one, as it is cheaper to collect.

So we will create the two policies:

I  - **Suboptimal expert policy** (behavior_8x8_moves_downwards_within_strip): moves the agent downwards in a suboptimal way, starting from (0,0) (collect 8000 steps)

II - **Optimal expert policy** (behavior_8x8_deterministic_4_0_to_7_7): moves the agent along the optimal path from (4,0) to (7,7) (collect 300 steps)


In this example we will again use the Deep Q-Network (DQN) algorithm as our off-policy RL algorithm.

Let's set up our configuration and create the environment.

## Environment

In [None]:
# Environment id plus the three ingredients of the grid layout.
ENV_NAME = CustomEnv.Grid_2D_8x8_discrete
OBSTACLE = ObstacleTypes.vertical_object_8x8
INITIAL_STATE, FINAL_STATE = (0, 0), (7, 7)

# Bundle the layout into the configuration object handed to the env wrapper.
env_2D_grid_initial_config = Grid2DInitialConfig(
    target_state=FINAL_STATE,
    initial_state=INITIAL_STATE,
    obstacles=OBSTACLE,
)

# Create the gym environment, wrap it with the layout above and take a snapshot of it.
env = InitialConfigCustom2DGridEnvWrapper(
    gym.make(ENV_NAME, render_mode=RenderMode.RGB_ARRAY_LIST),
    env_config=env_2D_grid_initial_config,
)
snapshot_env(env)

## Configure the two datasets

In [None]:
# Dataset I: large amount of data from the suboptimal "moves downwards" behavior policy.
DATA_SET_IDENTIFIER_I, BEHAVIOR_POLICY_I, NUM_STEPS_I = (
    "_downwards_",
    BehaviorPolicyType.behavior_8x8_moves_downwards_within_strip,
    8000,
)

# Dataset II: small amount of data from the deterministic (4,0) -> (7,7) behavior policy.
DATA_SET_IDENTIFIER_II, BEHAVIOR_POLICY_II, NUM_STEPS_II = (
    "_optimal_",
    BehaviorPolicyType.behavior_8x8_deterministic_4_0_to_7_7,
    300,
)

## Create combined Minari dataset

In [None]:
# Build the combined dataset from both behavior policies. The three tuples are
# aligned position by position: entry 0 describes dataset I, entry 1 dataset II.
config_combined_data = create_combined_minari_dataset(
    env_name=ENV_NAME,
    env_2d_grid_initial_config=env_2D_grid_initial_config,
    combined_dataset_identifier="_stiching",
    dataset_identifiers=(DATA_SET_IDENTIFIER_I, DATA_SET_IDENTIFIER_II),
    behavior_policy_names=(BEHAVIOR_POLICY_I, BEHAVIOR_POLICY_II),
    num_colected_points=(NUM_STEPS_I, NUM_STEPS_II),
)

## Rendering behavioral policy

In [None]:
# Render behavior policy I (the suboptimal one) on the configured grid.
suboptimal_rendering_kwargs = dict(
    env_or_env_name=ENV_NAME,
    render_mode=RenderMode.RGB_ARRAY_LIST,
    behavior_policy_name=BEHAVIOR_POLICY_I,
    env_2d_grid_initial_config=env_2D_grid_initial_config,
    num_frames=1000,
)
offpolicy_rendering(**suboptimal_rendering_kwargs)

In [None]:
# Render behavior policy II (the expert one) on the same grid configuration.
expert_rendering_kwargs = dict(
    env_or_env_name=ENV_NAME,
    render_mode=RenderMode.RGB_ARRAY_LIST,
    behavior_policy_name=BEHAVIOR_POLICY_II,
    env_2d_grid_initial_config=env_2D_grid_initial_config,
    num_frames=1000,
)
offpolicy_rendering(**expert_rendering_kwargs)

## State-action distribution

In [None]:
# Name under which the combined dataset was stored.
name_combined_dataset = config_combined_data.data_set_name

# Load the combined Minari dataset into a buffer.
buffer_data = load_buffer_minari(name_combined_dataset)

# Count state-action occurrences in the data; the second return value is not needed here.
state_action_count_data, _ = get_state_action_data_and_policy_grid_distributions(
    buffer_data,
    env,
)
state_action_histogram(
    state_action_count_data,
    title="State-Action data distribution",
    inset_pos_xy=(-0.1, -0.012),
)

# Show the environment again for reference next to the histogram.
snapshot_env(env)

## Policy to train

In [None]:
# Algorithm to train and the dataset it is trained on.
POLICY_NAME = PolicyName.dqn
NAME_EXPERT_DATA = name_combined_dataset

# TrainedPolicyConfig gathers the policy configuration data in a single object
# that the training and restore steps below consume.
offline_policy_config = TrainedPolicyConfig(
    policy_name=POLICY_NAME,
    name_expert_data=NAME_EXPERT_DATA,
    device="cpu",
    render_mode=RenderMode.RGB_ARRAY_LIST,
)


## Training

In [None]:
# Training schedule.
NUM_EPOCHS = 50
UPDATE_PER_EPOCH = 100
BATCH_SIZE = 256

# After every epoch, test statistics are collected from the policy on
# NUMBER_TEST_ENVS independent envs.
NUMBER_TEST_ENVS = 1
EXPLORATION_NOISE = True
SEED = None  # e.g. 1626

# NOTE(review): BATCH_SIZE, EXPLORATION_NOISE and SEED are defined above but are not
# forwarded to offline_training in this cell -- confirm whether that is intended.
offline_training(
    offline_policy_config=offline_policy_config,
    num_epochs=NUM_EPOCHS,
    update_per_epoch=UPDATE_PER_EPOCH,
    number_test_envs=NUMBER_TEST_ENVS,
    restore_training=False,
)

## Restore policy

In [None]:
POLICY_FILE = "policy.pth"

# Rebuild a policy object from the same configuration that was used for training.
policy = restore_trained_offline_policy(offline_policy_config)

# Locate the weights file: <trained policy path for "<dataset name>/<policy name>">/policy.pth
name_expert_data = offline_policy_config.name_expert_data
log_name = os.path.join(name_expert_data, POLICY_NAME)
log_path = get_trained_policy_path(log_name)
policy_weights_path = os.path.join(log_path, POLICY_FILE)

# Load the stored weights onto the CPU and copy them into the policy.
policy_state_dict = torch.load(policy_weights_path, map_location="cpu")
policy.load_state_dict(policy_state_dict)


## Let's visualize the policy

In [None]:
# Visualize the trained policy on an obstacle-free 8x8 grid with the same start and target.
#
# A separate configuration object is built instead of mutating
# `env_2D_grid_initial_config` in place, so re-running earlier cells (dataset
# creation, behavior-policy rendering) still sees the original obstacle layout.
# The obstacle type is read from the `ObstacleTypes` class itself: reaching one
# enum member through another (`OBSTACLE.obst_free_8x8`) is not supported on
# every Python version.
env_2D_grid_obstacle_free_config = Grid2DInitialConfig(
    obstacles=ObstacleTypes.obst_free_8x8,
    initial_state=INITIAL_STATE,
    target_state=FINAL_STATE,
)
env = InitialConfigCustom2DGridEnvWrapper(
    gym.make(ENV_NAME, render_mode=RenderMode.RGB_ARRAY_LIST),
    env_config=env_2D_grid_obstacle_free_config,
)

# Roll out the restored policy in that environment and render it.
offpolicy_rendering(
    env_or_env_name=env,
    render_mode=RenderMode.RGB_ARRAY_LIST,
    policy_model=policy,
    env_2d_grid_initial_config=env_2D_grid_obstacle_free_config,
    num_frames=1000,
    imitation_policy_sampling=False,
)

## Questions:

1 - What do you notice? What happens if you increase the collected data? Is it better? 

2 - Try again with the offline BCQ algorithm as we did in the previous example. What happens now?