In [1]:
import gymnasium as gym
import dill
import torch
import numpy as np
import random
import os

from four_room.env import FourRoomsEnv
from four_room.wrappers import gym_wrapper
from four_room.shortest_path import find_all_action_values
from four_room.utils import obs_to_state
from four_room_extensions import fourrooms_dataset_gen
from d3rlpy.algos import DiscreteBCConfig
from d3rlpy.metrics import EnvironmentEvaluator, TDErrorEvaluator, DiscreteActionMatchEvaluator, evaluate_transformer_with_environment
from d3rlpy.datasets import MDPDataset
from d3rlpy.logging import WanDBAdapterFactory
from d3rlpy.ope import FQEConfig, DiscreteFQE
from d3rlpy import load_learnable
import wandb
import utils
from datetime import datetime
import imageio
from functools import partial
from tqdm import tqdm
from utils import get_DQN_checkpoints, create_env
import pickle
from four_room_extensions.fourrooms_dataset_gen import get_mixed_policy_dataset


ModuleNotFoundError: No module named 'numpy.random.seed'

In [None]:
# --- Reproducibility ---------------------------------------------------------
# Seed every RNG used in this notebook (NumPy, PyTorch, Python's `random`).
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

# --- Optimiser settings --------------------------------------------------------
batch_size = 100  # default value from library
learning_rate = 0.001  # default value from library

# --- Training schedule ---------------------------------------------------------
n_epochs = 100
n_steps_per_epoch = 50  # Evaluation is done after this many steps, but we can change that logic

# --- Task configurations, rendering and logging --------------------------------
train_config_path = 'train'
reachable_test_config_path = 'test_100'
unreachable_test_config_path = 'test_0'
render = False
wandb_project_name = "BC"
# True selects the GPU when CUDA is available; otherwise None is used.
device = True if torch.cuda.is_available() else None


## if using stored mixed data, use these
# NOTE: `mixed_data_file` must be set to the name of a stored dataset file;
# with the empty string the path below points at the directory itself.
mixed_data_file = ""
wandb_run_name = f"{mixed_data_file}-BC_mixed"
DQN_mixed_data_path = os.path.join("/kaggle/working/offline_multi_task_rl", "datasets", "dataset_from_models_", mixed_data_file)

## if simulating mixed_data, use these
# best_policy = True
# DQN_models_path = os.path.join("/kaggle/working/offline_multi_task_rl", "four_room_extensions", "DQN_models", "performance_per_model.txt")
# episode_length = [0, 25, 50, 75, 100]
# wandb_run_name = f"{episode_length}-best_policy-100epochs-50stepsPerEpoch"


# Hyper-parameters recorded with the Weights & Biases run.
wandb_config = {
    "n_epochs": n_epochs,
    "n_steps_per_epoch": n_steps_per_epoch,
}

In [None]:
def _as_mdp_dataset(raw_dataset):
    """
    Build a d3rlpy ``MDPDataset`` from a raw dataset mapping.

    Args:
        raw_dataset: Object exposing ``.get`` for the keys ``"observations"``,
            ``"actions"``, ``"rewards"`` and ``"terminals"``.

    Returns:
        An ``MDPDataset`` constructed from those four entries.
    """
    return MDPDataset(
        observations=raw_dataset.get("observations"),
        actions=raw_dataset.get("actions"),
        rewards=raw_dataset.get("rewards"),
        terminals=raw_dataset.get("terminals"),
    )


# Expert dataset for the training tasks.
train_config = fourrooms_dataset_gen.get_config(train_config_path)
train_dataset, train_env, tasks_finished, tasks_failed = fourrooms_dataset_gen.get_expert_dataset_from_config(
    train_config, render=render, render_name="DT_train_expert"
)
train_dataset = _as_mdp_dataset(train_dataset)

# Expert dataset for the "reachable" test tasks.
test_config_reachable = fourrooms_dataset_gen.get_config(reachable_test_config_path)
test_dataset_reachable, test_env_reachable, tasks_finished, tasks_failed = fourrooms_dataset_gen.get_expert_dataset_from_config(
    test_config_reachable, render=render, render_name="DT_test_expert_reachable"
)
test_dataset_reachable = _as_mdp_dataset(test_dataset_reachable)

# Expert dataset for the "unreachable" test tasks.
# Note: `tasks_finished` / `tasks_failed` are overwritten by each call, so after
# this point they hold the values from this last call only.
test_config_unreachable = fourrooms_dataset_gen.get_config(unreachable_test_config_path)
test_dataset_unreachable, test_env_unreachable, tasks_finished, tasks_failed = fourrooms_dataset_gen.get_expert_dataset_from_config(
    test_config_unreachable, render=render, render_name="DT_test_expert_unreachable"
)
test_dataset_unreachable = _as_mdp_dataset(test_dataset_unreachable)

# The env returned with the expert dataset is replaced by one built from the same config.
train_env = create_env(train_config)
# checkpoints = get_DQN_checkpoints(DQN_models_path, episode_length, best_policy=best_policy)
# mixed_dataset, finished, failed = get_mixed_policy_dataset(train_config, train_env, checkpoints)

# Fail early with an actionable message instead of an opaque error from `open`.
if not os.path.isfile(DQN_mixed_data_path):
    raise FileNotFoundError(
        f"Mixed-policy dataset not found at {DQN_mixed_data_path!r}; "
        "set `mixed_data_file` to the name of a stored dataset file."
    )

# SECURITY: pickle.load can execute arbitrary code; only load dataset files
# that were produced by this project.
with open(DQN_mixed_data_path, 'rb') as f:
    mixed_dataset = pickle.load(f)

mixed_dataset = _as_mdp_dataset(mixed_dataset)

In [None]:
def model_saver_d3rlpy_callback(algo, epoch, total_step, n_epochs, n_steps_per_epoch, title_addition = ""):
    """
    Epoch-end hook that writes the current model to a timestamped ``.d3`` file.

    The file name has the form
    ``dt_<title_addition>_model_at_step_<total_step>_<YYYYmmdd-HHMMSS>.d3``.

    Args:
        algo: Object exposing ``save(path)``; it is asked to save itself
        epoch: The current epoch (not used in the file name)
        total_step: Step counter embedded in the file name
        n_epochs: The total number of epochs (not used in the file name)
        n_steps_per_epoch: The number of steps in each epoch (not used in the file name)
        title_addition: Free-form label embedded in the file name
    """
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    checkpoint_name = f"dt_{title_addition}_model_at_step_{total_step}_{timestamp}.d3"
    algo.save(checkpoint_name)
        
def eval_model(policy, env, n_episodes):
    """
    Roll out ``policy`` in ``env`` for ``n_episodes`` episodes and average the results.

    Each episode starts with ``policy.reset()`` and ``env.reset(seed=seed)``,
    where ``seed`` is the module-level seed. The policy is then queried with
    the latest observation and the reward of the previous step (``0.0`` on the
    first step) until the environment reports ``terminated`` or ``truncated``.

    Args:
        policy: Object exposing ``reset()`` and ``predict(observation, reward)``.
        env: Environment whose ``reset(seed=...)`` returns a sequence with the
            observation first, and whose ``step(action)`` returns
            ``(observation, reward, terminated, truncated, info)``.
        n_episodes: Number of episodes to run; must be positive.

    Returns:
        Tuple ``(mean total reward per episode, mean number of steps per episode)``.

    Raises:
        ValueError: If ``n_episodes`` is not positive (previously this surfaced
            as a ``ZeroDivisionError`` for ``n_episodes == 0``).
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be a positive integer, got {n_episodes!r}")

    total_reward = 0.0
    n_steps_taken = 0
    for _ in range(n_episodes):
        policy.reset()
        # NOTE(review): every episode resets with the same seed; presumably the
        # env advances to the next task on reset by itself - confirm.
        observation, reward = env.reset(seed=seed)[0], 0.0

        done = False
        while not done:
            # take action
            n_steps_taken += 1
            action = policy.predict(observation, reward)

            observation, _reward, terminated, truncated, _ = env.step(action)
            reward = float(_reward)
            total_reward += reward
            done = terminated or truncated

    return total_reward / n_episodes, n_steps_taken / n_episodes


In [None]:
BC = DiscreteBCConfig(batch_size=batch_size, learning_rate=learning_rate).create(device=device)

# Checkpoint callback with the run-level arguments pre-bound.
model_saver_d3rlpy_callback_partial = partial(model_saver_d3rlpy_callback, n_epochs=n_epochs, n_steps_per_epoch=n_steps_per_epoch, title_addition=wandb_run_name)

# NOTE: re-running this cell without re-running the dataset cell wraps the envs a second time.
train_env = utils.ObservationFlattenerWrapper(train_env)
test_env_reachable = utils.ObservationFlattenerWrapper(test_env_reachable)
test_env_unreachable = utils.ObservationFlattenerWrapper(test_env_unreachable)


class _StatelessPolicyAdapter:
    """
    Give an algorithm with a batch ``predict(observations)`` method the
    ``reset()`` / ``predict(observation, reward)`` interface used by ``eval_model``.
    """

    def __init__(self, algo):
        self._algo = algo

    def reset(self):
        """Nothing to reset: the adapter keeps no per-episode state."""

    def predict(self, observation, reward):
        """Return the algorithm's action for one observation; ``reward`` is ignored."""
        # The algorithm's predict works on a batch, so add and strip a batch axis.
        batch = np.expand_dims(np.asarray(observation), axis=0)
        return int(self._algo.predict(batch)[0])


def _make_eval_policy(algo):
    """
    Return an object usable as ``policy`` in ``eval_model``.

    Transformer-style algorithms expose ``as_stateful_wrapper``; when ``algo``
    has it, it is used exactly as before. Otherwise (the expected case for
    behaviour cloning) the algorithm is wrapped in ``_StatelessPolicyAdapter``.
    """
    if hasattr(algo, "as_stateful_wrapper"):
        return algo.as_stateful_wrapper(target_return=1, action_sampler=None)
    return _StatelessPolicyAdapter(algo)


# Log label -> (environment, number of evaluation episodes).
evaluation_suites = {
    "Train": (train_env, len(train_config["topologies"])),
    "Test_reachable": (test_env_reachable, len(test_config_reachable["topologies"])),
    "Test_unreachable": (test_env_unreachable, len(test_config_unreachable["topologies"])),
}

with wandb.init(project=wandb_project_name, name=wandb_run_name, config=wandb_config):
    for epoch in tqdm(range(n_epochs)):
        steps_before_fit = epoch * n_steps_per_epoch

        def save_checkpoint(algo, _epoch_in_call, step_in_call, _offset=steps_before_fit, _epoch=epoch + 1):
            # Each fit() call below runs a single epoch and counts its steps from
            # zero, so add the steps of earlier calls to name the checkpoint by
            # the global training step.
            model_saver_d3rlpy_callback_partial(algo, _epoch, _offset + step_in_call)

        BC.fit(
            mixed_dataset,
            n_steps=n_steps_per_epoch,
            n_steps_per_epoch=n_steps_per_epoch,
            epoch_callback=save_checkpoint,
            show_progress=False,
            save_interval=1000,
        )

        cumulative_reward = {}
        steps_taken = {}
        for suite_name, (suite_env, suite_episodes) in evaluation_suites.items():
            cumulative_reward[suite_name], steps_taken[suite_name] = eval_model(
                _make_eval_policy(BC), suite_env, suite_episodes
            )

        # Both metric groups belong to the same training step, so log them together.
        wandb.log(
            {"Cumulative Reward": cumulative_reward, "Number of steps taken": steps_taken},
            step=(epoch + 1) * n_steps_per_epoch,
        )


# save final model
BC.save(f"dt_final_model_{datetime.now().strftime('%Y%m%d-%H%M%S')}.d3")