In [30]:
import json
import time
import torch
import math
from typing import Optional
import numpy as np
import pickle

from pearl.replay_buffers import BasicReplayBuffer
from pearl.utils.instantiations.spaces import DiscreteActionSpace
from pearl import PearlAgent
from pearl.utils.functional_utils.train_and_eval.learning_logger import LearningLogger, null_learning_logger
from pearl.utils.functional_utils.experimentation.set_seed import set_seed
from pearl.utils.functional_utils.train_and_eval.offline_learning_and_evaluation import TRAINING_TAG
from pearl.replay_buffers import TransitionBatch, ReplayBuffer
from pearl.policy_learners.sequential_decision_making.double_dqn import DoubleDQN
from pearl.neural_networks.sequential_decision_making import VanillaContinuousActorNetwork
from pearl.policy_learners.sequential_decision_making import ImplicitQLearning
from pearl.utils.functional_utils.train_and_eval.offline_learning_and_evaluation import offline_evaluation,offline_learning
from pearl.neural_networks.sequential_decision_making.q_value_networks import VanillaQValueNetwork, VanillaQValueMultiHeadNetwork
from pearl.action_representation_modules.one_hot_action_representation_module import (
    OneHotActionTensorRepresentationModule,
)


# Global seed for this notebook, applied through Pearl's set_seed helper before any
# network is constructed so weight initialisation is repeatable.
NOTEBOOK_SEED = 42
set_seed(NOTEBOOK_SEED)


In [31]:
# --- Observation / action geometry ---------------------------------------------
# Maximum number of candidate frontiers the agent can choose between.
FRONTIER_COUNT: int = 6
# Number of scalar features recorded for each candidate frontier.
FRONTIER_FEATURES: int = 6
# Extra scalar observations besides the per-frontier features. One of them is the
# current % of explored area; the other two are not described here — TODO document.
OTHER_FRONTIER_INPUTS: int = 3

# Flat observation length: every frontier's feature block plus the extra inputs.
OBSERVATION_SPACE: int = FRONTIER_FEATURES * FRONTIER_COUNT + OTHER_FRONTIER_INPUTS
# One discrete action per candidate frontier.
ACTION_SPACE: int = FRONTIER_COUNT

In [32]:
# --- Run configuration ---------------------------------------------------------
# Logged exploration runs, keyed by map name (see the loading cell below).
data_file_path = "data/rl-run-data.json"
# Capacity of the offline replay buffer (one million transitions).
replay_buffer_size = 10**6
# Frontier selection is a discrete choice, one action per candidate frontier.
is_action_continuous = False
max_number_actions_if_discrete = ACTION_SPACE
# NOTE(review): not referenced by any visible cell — confirm whether it is still needed.
device = "cpu"

In [33]:
# Load the logged exploration runs. JSON text is UTF-8 by specification, so request
# it explicitly instead of relying on the platform's default text encoding.
with open(data_file_path, "r", encoding="utf-8") as f:
    data_transitions = json.load(f)

# The buffer-building cell iterates map names and indexes a list of runs by each
# name, so the top level must be a JSON object; fail here with a clear message.
if not isinstance(data_transitions, dict):
    raise TypeError(
        f"expected a JSON object keyed by map name in {data_file_path!r}, "
        f"got {type(data_transitions).__name__}"
    )

In [44]:

# Fill a Pearl replay buffer with every logged (state, action, reward, next_state)
# step from every run of every map in `data_transitions`.
offline_data_replay_buffer = BasicReplayBuffer(replay_buffer_size)
if is_action_continuous:
    # NOTE(review): sets a private attribute of the buffer — confirm it is still
    # honoured by the installed Pearl version.
    offline_data_replay_buffer._is_action_continuous = True

# Every logged step offers exactly ACTION_SPACE discrete actions (asserted below), so
# one shared action space replaces the two identical objects previously built per step.
# The same object serves as both the current and next available-action set.
all_frontier_actions = DiscreteActionSpace(
    actions=list(torch.arange(ACTION_SPACE).view(-1, 1))
)

count = 0
for map_runs in data_transitions.values():
    for run_record in map_runs:
        run = run_record["run"]
        for j in range(len(run["obs"])):
            action_one_hot = run["action"][j]
            observation = run["obs"][j]
            next_observation = run["next_obs"][j]
            reward = run["reward"][j]

            # Validate the row before it reaches the buffer.
            assert len(action_one_hot) == ACTION_SPACE
            assert len(observation) == OBSERVATION_SPACE
            assert len(next_observation) == OBSERVATION_SPACE
            assert len(reward) == 1

            offline_data_replay_buffer.push(
                state=observation,
                # Actions are logged one-hot; the buffer takes the action index.
                action=int(np.argmax(action_one_hot)),
                # Rewards are logged as 1-element lists, while Pearl's push expects a
                # scalar reward. NOTE(review): passing the list likely produced a
                # (1, 1) per-step reward tensor that broadcasts wrongly against the
                # per-sample TD target — verify against the installed Pearl version.
                reward=float(reward[0]),
                next_state=next_observation,
                curr_available_actions=all_frontier_actions,
                next_available_actions=all_frontier_actions,
                # TODO(review): every step, including the last step of each run, is
                # stored as non-terminal. Confirm whether run ends should be terminated.
                terminated=False,
                truncated=False,
                max_number_actions=max_number_actions_if_discrete,
            )
            count += 1
print(f"{count} transitions saved")

488915


In [45]:
# Q network: state in, one output head per discrete action (output_dim=ACTION_SPACE).
Q_value_network = VanillaQValueMultiHeadNetwork(
    state_dim=OBSERVATION_SPACE,
    action_dim=ACTION_SPACE,  # width of the one-hot action representation below
    hidden_dims=[1024, 1024],
    output_dim=ACTION_SPACE,
)

# Discrete action set: action ids 0..ACTION_SPACE-1, each as a (1,)-shaped tensor.
action_space = DiscreteActionSpace(
    actions=list(torch.arange(ACTION_SPACE).view(-1, 1))
)

# Encodes each action id as a one-hot vector of length ACTION_SPACE.
one_hot_action_module = OneHotActionTensorRepresentationModule(
    max_number_actions=ACTION_SPACE
)

double_dqn_learner = DoubleDQN(
    state_dim=OBSERVATION_SPACE,
    action_space=action_space,
    batch_size=512,
    training_rounds=10,
    # NOTE(review): 0.75 is an unusually large soft-update tau — confirm intended.
    soft_update_tau=0.75,
    network_instance=Q_value_network,  # reuse the network built above
    action_representation_module=one_hot_action_module,
)

agent = PearlAgent(policy_learner=double_dqn_learner)


In [36]:
# Train the agent purely from the logged replay buffer (no environment is passed in).
# The captured output of an earlier run shows it was interrupted early, with a
# long ETA; expect this cell to run for a long time.
experiment_seed = 100  # seed handed to Pearl's offline_learning
training_epochs = 50_000  # number of offline training epochs

offline_learning(
    offline_agent=agent,
    data_buffer=offline_data_replay_buffer,
    seed=experiment_seed,
    training_epochs=training_epochs,
)

  1%|          | 647/117188 [00:09<28:24, 68.37it/s]


KeyboardInterrupt: 

In [28]:
def _pickle_to(path, obj) -> None:
    """Write ``obj`` to ``path`` using pickle's default protocol."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)


# Persist the whole agent, then the underlying Q-network module on its own.
# NOTE: only ever unpickle these files from a trusted source (pickle can run code).
_pickle_to("agent.pkl", agent)
_pickle_to("Q_func.pkl", agent.policy_learner._Q._model)

In [21]:
# Smoke test: take one stored state and ask the agent for its greedy (exploit) action.
sampled_state = offline_data_replay_buffer.sample(1).state[0]
agent.reset(sampled_state, action_space)
greedy_action = agent.act(exploit=True)
greedy_action

tensor([3])

In [27]:
# Display the Q network's inner module (reaches into private Pearl attributes).
q_network_module = agent.policy_learner._Q._model
q_network_module

KeyboardInterrupt: 