In [1]:
from environment.realuser import RealUserEnvironment
from typing import Tuple
from data.dataset import GraphDataset, DataAugmentationLevel
from data.parsers.parserValueProvider import RealValueBackend
from data.parsers.answerTemplateParser import AnswerTemplateParser
from data.parsers.systemTemplateParser import SystemTemplateParser
from data.parsers.logicParser import LogicTemplateParser
from utils.utils import AutoSkipMode
from algorithm.dqn.dqn import CustomDQN
import torch
from data.cache import Cache
from gymnasium import spaces, Env
from encoding.state import StateEncoding
from stable_baselines3.common.save_util import load_from_zip_file


In [10]:
def load_env(data: GraphDataset) -> RealUserEnvironment:
    """
    Build the environment in which a real user interacts with the dialog policy.

    Args:
        data: Dialog graph dataset. Its `a1_countries` and `hotel_costs` attributes
            are handed to the `RealValueBackend`.

    Returns:
        A `RealUserEnvironment` created with "SYSTEM" / "USER" speaker tokens, an empty
        separator token, `max_steps=50`, `max_reward=150`, `user_patience=2`, and with
        `auto_skip` and `stop_on_invalid_skip` both set to False.
    """
    # Only the answer parser, the logic parser and the value backend are passed to the
    # environment, so nothing else is instantiated here.
    answer_parser = AnswerTemplateParser()
    logic_parser = LogicTemplateParser()
    value_backend = RealValueBackend(a1_laender=data.a1_countries, data=data.hotel_costs)

    return RealUserEnvironment(
        dataset=data,
        sys_token="SYSTEM", usr_token="USER", sep_token="",
        max_steps=50, max_reward=150, user_patience=2,
        answer_parser=answer_parser, logic_parser=logic_parser, value_backend=value_backend,
        auto_skip=False, stop_on_invalid_skip=False,
    )

In [3]:
def load_model(ckpt_path: str, device: str, data: GraphDataset) -> Tuple[CustomDQN, StateEncoding]:
    """
    Restore a `CustomDQN` policy and its matching state encoding from a checkpoint.

    Args:
        ckpt_path: Path of the checkpoint archive, read via `load_from_zip_file`.
        device: Device string used for loading the checkpoint, for the encoding
            `Cache` and for the model itself, so that all three agree.
        data: Dataset handed to the `Cache` and the `StateEncoding`.

    Returns:
        `(model, encoding)`: the model with the checkpoint's policy weights loaded and
        the policy's training mode switched off, plus the `StateEncoding` built from
        the state/action configuration stored in the checkpoint.
    """
    # unzip checkpoint (the third returned value is not needed here)
    ckpt_data, ckpt_params, _ = load_from_zip_file(
        ckpt_path,
        device=device,
        custom_objects=None,
        print_system_info=True)
    state_cfg = ckpt_data['configuration'].experiment.state
    action_cfg = ckpt_data['configuration'].experiment.actions

    # load encodings
    cache = Cache(device=device, data=data, state_config=state_cfg, torch_compile=False)
    encoding = StateEncoding(cache=cache, state_config=state_cfg, action_config=action_cfg, data=data)

    class CustomEnv(Env):
        """Placeholder env that only carries the observation and action spaces from the checkpoint."""

        def __init__(self, observation_space, action_space) -> None:
            self.observation_space = observation_space
            self.action_space = action_space

    env = CustomEnv(observation_space=ckpt_data['observation_space'], action_space=ckpt_data['action_space'])

    # Hyperparameters stored in the checkpoint under the same name as the
    # CustomDQN keyword argument they are passed to.
    restored_keys = (
        'policy_kwargs', 'target', 'batch_size', 'learning_rate',
        'exploration_initial_eps', 'exploration_final_eps', 'exploration_fraction',
        'buffer_size', 'learning_starts', 'gamma', 'train_freq', 'gradient_steps',
        'target_update_interval', 'max_grad_norm',
        'replay_buffer_class', 'optimize_memory_usage', 'replay_buffer_kwargs',
        'action_masking', 'actions_in_state_space',
    )
    restored_kwargs = {key: ckpt_data[key] for key in restored_keys}

    # setup model & restore weights
    model = CustomDQN(configuration=ckpt_data['configuration'],
                      policy=ckpt_data['policy_class'],
                      seed=None,
                      env=env,
                      verbose=1,
                      # honour the requested device instead of always building on the CPU,
                      # so the model lives on the same device as the encoding cache
                      device=device,
                      tensorboard_log=None,
                      **restored_kwargs)

    model.policy.load_state_dict(ckpt_params['policy'])
    model.policy.set_training_mode(False)
    return model, encoding

In [11]:
# NOTE(review): absolute, machine-specific checkpoint location -- adjust before running elsewhere.
CKPT_PATH = "/mount/arbeitsdaten/asr-2/vaethdk/cts_newcodebase_weights/o8ga0gru/best_eval/weights/ckpt_1.pt"
DEVICE = 'cpu'

# Test graph and answers with answer synonyms enabled and no data augmentation.
data = GraphDataset(
    'en/test_graph.json',
    'en/test_answers.json',
    use_answer_synonyms=True,
    augmentation=DataAugmentationLevel.NONE,
    augmentation_version=0,
)

# Environment first, then the policy and its state encoding (same order as before).
user_env = load_env(data)
model, state_encoding = load_model(CKPT_PATH, device=DEVICE, data=data)


===== Dataset Statistics =====
- files:  resources/en/test_graph.json resources/en/test_answers.json
- synonyms: True
- depth: 20  - degree: 13
- answers: 162
- questions: 173
- loaded original data: True
- loaded generated data: False
== SAVED MODEL SYSTEM INFO ==
- OS: Linux-6.0.7-100.fc35.x86_64-x86_64-with-glibc2.34 # 1 SMP PREEMPT_DYNAMIC Thu Nov 3 21:31:24 UTC 2022
- Python: 3.10.8
- Stable-Baselines3: 2.0.0a5
- PyTorch: 2.0.1+cu118
- GPU Enabled: True
- Numpy: 1.24.2
- Cloudpickle: 2.2.1
- Gymnasium: 0.28.1
- OpenAI Gym: 0.26.2

Loading Embedding (caching: False) encoding.text.sbert.SentenceEmbeddings ...
Building tree embedding for nodes...
Done
Space dimensions: StateDims(state_vector=3932, action_vector=1, state_action_subvector=783, num_actions=14)
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
HER BUFFER BACKEND PrioritizedLAPReplayBuffer
HER ENV!! TOKENS: SYSTEM: USER: 
ARCHITECUTRE OptimizedModule(
  (_orig_mod): CustomDuelin

In [5]:
def next_action(obs: dict) -> Tuple[int, bool]:
    """
    Ask the loaded policy for its deterministic decision on a single observation.

    Relies on the notebook-level `state_encoding` and `model` objects.

    Args:
        obs: One observation, as returned by the environment.

    Returns:
        `(action, intent)`: the predicted action converted to `int`, and the
        predicted intent unwrapped to a plain Python value via `.item()`.
    """
    # the encoder works on batches, so wrap the single observation in a list
    encoded = state_encoding.batch_encode(
        observation=[obs],
        sys_token="SYSTEM",
        usr_token="USER",
        sep_token="",
    )
    raw_action, raw_intent = model.predict(observation=encoded, deterministic=True)
    return int(raw_action), raw_intent.item()


In [12]:
# Start a new dialog with the real user.
obs = user_env.reset()
done = False

# Keep stepping until the environment reports the episode as done
# or the current user utterance is "exit".
while True:
    if done or user_env.current_user_utterance == "exit":
        break
    action, intent = next_action(obs)
    print(f"  (policy: action {action}, is faq: {intent})")
    obs, reward, done = user_env.step(action=action)
    print(f"  (done: {done})")

What topic do you have questions about? You can either click on an answer from the suggested topics or enter your own text.
  (policy: action 4, is faq: True)
SKIPPING
-> TO What emergency are you experiencing?
  (done: False)
  (policy: action 2, is faq: True)
SKIPPING
-> TO If you are at fault, costs cannot be reimbursed.
  (done: False)
  (policy: action 1, is faq: True)
SKIPPING
-> TO If you are not at fault (e.g., a conference ran over), costs can be reimbursed if a suitable justifi
  (done: False)
  (policy: action 1, is faq: True)
SKIPPING
-> TO What topic do you have questions about? You can either click on an answer from the suggested topics 
  (done: False)
  (policy: action 4, is faq: True)
SKIPPING
-> TO What emergency are you experiencing?
  (done: False)
  (policy: action 2, is faq: True)
SKIPPING
-> TO If you are at fault, costs cannot be reimbursed.
  (done: False)
  (policy: action 1, is faq: True)
SKIPPING
-> TO If you are not at fault (e.g., a conference ran over), c

In [14]:
# Display the log collected for the current episode (goal, constraints, initial utterance,
# per-turn rewards, user utterances and visited nodes, as seen in the output below).
user_env.current_episode_log

 '28479535-1$ GOAL: 0 START',
 '28479535-1$ CONSTRAINTS:',
 '28479535-1$ INITIAL UTTERANCE: Can I bring my own car?',
 '28479535-1$ -> TURN REWARD: -1.0',
 '28479535-1$ -> USER UTTERANCE: ',
 '28479535-1$ TO NODE: userResponseNode - 16460436532310883 - What emergency are you experiencing?',
 '28479535-1$ -> TURN REWARD: -1.0',
 '28479535-1$ -> USER UTTERANCE: ',
 '28479535-1$ TO NODE: infoNode - 16460439592347465 - If you are at fault, costs cannot be reimbursed.',
 '28479535-1$ -> TURN REWARD: -1.0',
 '28479535-1$ -> USER UTTERANCE: ',
 '28479535-1$ TO NODE: infoNode - 16460439966919842 - If you are not at fault (e.g., a conference ran over), costs can be reimbursed if a suitable justifi',
 '28479535-1$ -> TURN REWARD: -1.0',
 '28479535-1$ -> USER UTTERANCE: ',
 '28479535-1$ TO NODE: userResponseNode - 16348058621438633 - What topic do you have questions about? You can either click on an answer from the suggested topics ',
 '28479535-1$ -> TURN REWARD: -1.0',
 '28479535-1$ -> USER UTT