In [None]:
import warnings
from examples.offline_RL_workshop import load_env_variables
import minari
from examples.offline_RL_workshop.behavior_policies.behavior_policy_registry import BehaviorPolicyType
from examples.offline_RL_workshop.custom_envs.custom_2d_grid_env.obstacles_2D_grid_register import ObstacleTypes
from examples.offline_RL_workshop.custom_envs.utils import Grid2DInitialConfig
from examples.offline_RL_workshop.generate_custom_minari_datasets.generate_minari_dataset_grid_envs import \
    create_minari_datasets, MinariDatasetConfig
from examples.offline_RL_workshop.generate_custom_minari_datasets.utils import generate_compatible_minari_dataset_name, \
    get_dataset_name_2D_grid
from examples.offline_RL_workshop.custom_envs.custom_envs_registration import CustomEnv
from examples.offline_RL_workshop.custom_envs.custom_envs_registration import RenderMode
from examples.offline_RL_workshop.offline_policies.policy_registry import PolicyName
from examples.offline_RL_workshop.offline_trainings.offline_training import offline_training
from examples.offline_RL_workshop.offline_trainings.policy_config_data_class import TrainedPolicyConfig
from examples.offline_RL_workshop.utils import state_action_histogram
from examples.offline_RL_workshop.visualizations.utils import get_state_action_data_and_policy_grid_distributions
from examples.offline_RL_workshop.custom_envs.utils import InitialConfigCustom2DGridEnvWrapper
from examples.offline_RL_workshop.custom_envs.custom_envs_registration import register_grid_envs
import gymnasium as gym
from examples.offline.utils import load_buffer_minari
from examples.offline_RL_workshop.behavior_policies.behavior_policy_rendering import render_rgb_frames, snapshot_env
import cv2
import torch
from examples.offline_RL_workshop.offline_trainings.policy_config_data_class import get_trained_policy_path
import os
from examples.offline_RL_workshop.offline_trainings.restore_policy_model import restore_trained_offline_policy
from examples.offline_RL_workshop.utils import compare_state_action_histograms
from tianshou.data import Collector
from examples.offline_RL_workshop.behavior_policies.behavior_policy_rendering import behavior_policy_rendering
from examples.offline_RL_workshop.offline_policies.offpolicy_rendering import offpolicy_rendering
from examples.offline_RL_workshop.generate_custom_minari_datasets.generate_minari_dataset_grid_envs import create_minari_config_from_dict
from examples.offline_RL_workshop.utils import delete_minari_data_if_exists
from minari import combine_datasets
from copy import copy
from examples.offline_RL_workshop.load_env_variables import load_env_variables

# Load the workshop's environment variables before running anything else.
load_env_variables()

# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings("ignore")

# Register the custom 2D grid environments; gym.make(ENV_NAME) below relies on this.
# TODO: this registration should be loaded automatically.
register_grid_envs()

**In this notebook we will deal with another important property that a robust offline RL algorithm should fulfill: stitching, i.e. the reuse of different trajectories contained in the data to compose the best possible trajectory from pieces of the trajectories in our dataset.**


The goal is to reach a target at (7,7) starting from (0,0). We will again use the 8x8 grid environment. Our dataset contains trajectories covering our space of interest, but they were generated for different tasks. One comes from a suboptimal policy that moves the agent from (0,0) to (7,0); the other comes from a deterministic, optimal policy (a human expert) that brings the agent from (4,0) to (7,7). Naturally, we have much more data from the suboptimal policy than from the expert one, since it is cheaper to collect.

So we will create the two policies:

I  - **Suboptimal policy**: moves the agent downwards in a suboptimal way, starting from (0,0) (collect 3000 steps)

II - **Optimal expert policy**: moves the agent along the optimal path from (4,0) to (7,7) (collect 300 steps)


In this example we will again use the Deep Q-Network (DQN) algorithm as our off-policy RL algorithm.

Let's set up our configuration and create the environment.

In [None]:
# --- Environment -------------------------------------------------------------
ENV_NAME = CustomEnv.Grid_2D_8x8_discrete

# --- Behavior policies used to generate the offline data ----------------------
# I: suboptimal policy -> cheap, so we collect many steps.
BEHAVIOR_POLICY_I = BehaviorPolicyType.behavior_8x8_moves_downwards_within_strip
NUM_STEPS_I = 3000
# II: deterministic expert from (4, 0) to (7, 7) -> expensive, so only a few steps.
BEHAVIOR_POLICY_II = BehaviorPolicyType.behavior_8x8_deterministic_4_0_to_7_7
NUM_STEPS_II = 300

# --- Offline RL algorithm (DQN) ----------------------------------------------
OFFLINE_POLICY = PolicyName.dqn

# --- Grid configuration --------------------------------------------------------
OBSTACLE = ObstacleTypes.verical_object_8x8  # (sic) member name as defined in ObstacleTypes
INITIAL_STATE = (0, 0)
FINAL_STATE = (7, 7)

env_2D_grid_initial_config = Grid2DInitialConfig(
    obstacles=OBSTACLE,
    initial_state=INITIAL_STATE,
    target_state=FINAL_STATE,
)

# Wrap the raw gym env so it starts/ends according to the grid configuration above.
raw_grid_env = gym.make(ENV_NAME, render_mode=RenderMode.RGB_ARRAY_LIST)
env = InitialConfigCustom2DGridEnvWrapper(raw_grid_env, env_config=env_2D_grid_initial_config)

Let's take a look at our environment and at the configurations of both policies.

In [None]:
# Render the suboptimal behavior policy (I) in the configured grid.
behavior_policy_rendering(
    env_or_env_name=ENV_NAME,
    behavior_policy_name=BEHAVIOR_POLICY_I,
    env_2d_grid_initial_config=env_2D_grid_initial_config,
    render_mode=RenderMode.RGB_ARRAY_LIST,
    num_frames=1000,
)

In [None]:
# Render the deterministic expert behavior policy (II) in the configured grid.
behavior_policy_rendering(
    env_or_env_name=ENV_NAME,
    behavior_policy_name=BEHAVIOR_POLICY_II,
    env_2d_grid_initial_config=env_2D_grid_initial_config,
    render_mode=RenderMode.RGB_ARRAY_LIST,
    num_frames=1000,
)

### EXERCISE:

Create both minari datasets and combine them into a single one. This is what you will need to do in realistic situations!

To combine the datasets you will need to use:

```python
combined_dataset = combine_datasets(
    minari_datasets, new_dataset_id=name_combined_dataset
)
```

where `minari_datasets` is a list containing your two datasets.

Let's collect the data

In [None]:
DATA_SET_NAME = "data"
DATA_SET_IDENTIFIER_I = "_move_downwards"
DATA_SET_IDENTIFIER_II = "_move_deterministic"
VERSION_DATA_SET = "v0"


def _make_dataset_config(identifier, num_steps, behavior_policy):
    """Build the Minari metadata config for one behavior-policy dataset.

    Everything except the identifier, the number of steps and the behavior
    policy is shared between the two datasets of this notebook.
    """
    return create_minari_config_from_dict(
        env_name=ENV_NAME,
        dataset_name=DATA_SET_NAME,
        data_set_identifier=identifier,
        version_dataset=VERSION_DATA_SET,
        num_steps=num_steps,
        behavior_policy_name=behavior_policy,
        env_2d_grid_initial_config=env_2D_grid_initial_config,
    )


# Set I: suboptimal policy (plenty of cheap data).
minari_dataset_config_I = _make_dataset_config(DATA_SET_IDENTIFIER_I, NUM_STEPS_I, BEHAVIOR_POLICY_I)
# Set II: expert policy (few, expensive demonstrations).
minari_dataset_config_II = _make_dataset_config(DATA_SET_IDENTIFIER_II, NUM_STEPS_II, BEHAVIOR_POLICY_II)

# Collect the data for both sets (set I first, then set II).
for dataset_config in (minari_dataset_config_I, minari_dataset_config_II):
    create_minari_datasets(dataset_config)

We will now combine both datasets into a single one, as we would in a realistic scenario.

In [None]:
# Dataset ids produced by the previous cell. They appear to encode env name, obstacle,
# start/target states, identifier and version, so update them if those settings change.
MINARI_DATASET_I = "Grid_2D_8x8_discrete-data_verical_object_8x8_start_0_0_target_7_7_move_downwards-v0"
MINARI_DATASET_II = "Grid_2D_8x8_discrete-data_verical_object_8x8_start_0_0_target_7_7_move_deterministic-v0"
NAME_COMBINED_DATASET = "combined_data_sets_vertical_obstacle"

# Use the same version tag ("v0") as the individual datasets instead of a differently
# cased "V0", so all dataset ids in this notebook follow one convention.
name_combined_dataset = generate_compatible_minari_dataset_name(
    env_name=ENV_NAME,
    data_set_name=NAME_COMBINED_DATASET,
    version=VERSION_DATA_SET,
)

# Remove a previous combined dataset with the same id so re-running starts from a clean slate.
delete_minari_data_if_exists(name_combined_dataset)

list_dataset_names = [MINARI_DATASET_I, MINARI_DATASET_II]
minari_datasets = [
    minari.load_dataset(dataset_id) for dataset_id in list_dataset_names
]

combined_dataset = combine_datasets(
    minari_datasets, new_dataset_id=name_combined_dataset
)

print(
    f"Number of episodes in dataset I: {len(minari_datasets[0])}, in dataset II: {len(minari_datasets[1])} "
    f"and in the combined dataset: {len(combined_dataset)}"
)

# Metadata for the combined dataset: reuse set I's config for simplicity and override
# only the total number of steps and the dataset name.
# NOTE(review): any other field (presumably e.g. the behavior policy) still describes set I.
minari_combined_dataset = MinariDatasetConfig.load_from_file(MINARI_DATASET_I)
minari_combined_dataset.num_steps = NUM_STEPS_I + NUM_STEPS_II
minari_combined_dataset.data_set_name = name_combined_dataset
minari_combined_dataset.save_to_file()


Let's take a look at the state-action distribution.

In [None]:
# Replay buffer built from the combined Minari dataset.
buffer_data = load_buffer_minari(name_combined_dataset)

# State-action counts of the data only (the policy part of the result is not needed yet).
state_action_count_data, _ = get_state_action_data_and_policy_grid_distributions(buffer_data, env)
state_action_histogram(
    state_action_count_data,
    title="State-Action data distribution",
    inset_pos_xy=(-0.1, -0.005),
)

# Snapshot of the environment for reference next to the histogram.
snapshot_env(env)

Let's train the DQN policy!

In [None]:
NAME_EXPERT_DATA = name_combined_dataset

# The offline policy to be trained. Use the algorithm selected in the configuration
# cell (OFFLINE_POLICY = DQN) instead of a hard-coded CQL, so the training matches the
# narrative and the conclusions of this notebook.
POLICY_NAME = OFFLINE_POLICY

NUM_EPOCHS = 50
BATCH_SIZE = 128
UPDATE_PER_EPOCH = 100

# After every epoch, test statistics are collected from NUMBER_TEST_ENVS independent envs.
NUMBER_TEST_ENVS = 1
EXPLORATION_NOISE = True
SEED = None  # previously 1626
# NOTE(review): BATCH_SIZE, EXPLORATION_NOISE and SEED are not passed to offline_training
# below — confirm whether they should be.

# TrainedPolicyConfig bundles the policy configuration data (dataset, algorithm, rendering, device).
offline_policy_config = TrainedPolicyConfig(
    name_expert_data=NAME_EXPERT_DATA,
    policy_name=POLICY_NAME,
    render_mode=RenderMode.RGB_ARRAY_LIST,
    device="cpu",
)

# Run the training from scratch (restore_training=False).
offline_training(
    offline_policy_config=offline_policy_config,
    num_epochs=NUM_EPOCHS,
    number_test_envs=NUMBER_TEST_ENVS,
    update_per_epoch=UPDATE_PER_EPOCH,
    restore_training=False,
)

Let's restore the policy

In [None]:
POLICY_FILE = "policy_final.pth"

# Rebuild a policy object with the same configuration that was used for training...
policy = restore_trained_offline_policy(offline_policy_config)

# ...then locate the saved weights for this (dataset, algorithm) pair and load them.
name_expert_data = offline_policy_config.name_expert_data
log_name = os.path.join(name_expert_data, POLICY_NAME)
log_path = get_trained_policy_path(log_name)
weights_file = os.path.join(log_path, POLICY_FILE)
# NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
state_dict = torch.load(weights_file, map_location="cpu")
policy.load_state_dict(state_dict)

In [None]:
# Evaluate the trained policy from a different start cell.
# NOTE: this mutates the shared grid config in place, so the cells below, and any earlier
# cell re-run after this one, will also use (0, 5) as the initial state.
env_2D_grid_initial_config.initial_state = (0, 5)

env = InitialConfigCustom2DGridEnvWrapper(
    gym.make(ENV_NAME, render_mode=RenderMode.RGB_ARRAY_LIST),
    env_config=env_2D_grid_initial_config,
)

offpolicy_rendering(
    env_or_env_name=env,
    policy_model=policy,
    env_2d_grid_initial_config=env_2D_grid_initial_config,
    render_mode=RenderMode.RGB_ARRAY_LIST,
    num_frames=1000,
    imitation_policy_sampling=False,
)

Let's take a look at the policy's state-action distribution.

In [None]:
NUM_EPISODES_FOR_STATISTICS = 100  # the more episodes, the better the statistics

# State-action counts of the data and of the trained policy
# (presumably gathered over NUM_EPISODES_FOR_STATISTICS policy episodes in `env`).
state_action_count_data, state_action_count_policy = get_state_action_data_and_policy_grid_distributions(
    buffer_data,
    env,
    policy,
    num_episodes=NUM_EPISODES_FOR_STATISTICS,
    logits_sampling=True,
)

# Plot data and policy distributions, their comparison, and the environment.
for state_action_counts in (state_action_count_data, state_action_count_policy):
    state_action_histogram(state_action_counts)
compare_state_action_histograms(state_action_count_data, state_action_count_policy)
snapshot_env(env)

Let's visualize the policy

In [None]:
# TODO: solve the error in the DQN visualization. The Tianshou Collector-based rendering
# (Collector(policy, env, exploration_noise=EXPLORATION_NOISE).collect(n_episode=20, render=1 / 35))
# is disabled for now; offpolicy_rendering is used instead.
offpolicy_rendering(
    env_or_env_name=env,
    policy_model=policy,
    env_2d_grid_initial_config=env_2D_grid_initial_config,
    render_mode=RenderMode.RGB_ARRAY_LIST,
    num_frames=1000,
    imitation_policy_sampling=False,
)


Conclusion:

1 - DQN cannot find the optimal solution encoded in the few expert episodes generated by policy II.
2 - As expected, DQN drastically reduces the effective size of the dataset. This is fine in this simple case but, as mentioned before, state-action pairs that are under-represented in the dataset could still appear during inference. **TODO:** clarify this point.
3 - **TODO:** add further conclusions.