In [1]:
from environment.realuser import RealUserEnvironment
from server.nlu import NLU
from typing import Tuple
from data.dataset import GraphDataset, ReimburseGraphDataset, DataAugmentationLevel
from data.parsers.parserValueProvider import ReimbursementRealValueBackend
from data.parsers.answerTemplateParser import AnswerTemplateParser
from data.parsers.systemTemplateParser import SystemTemplateParser
from data.parsers.logicParser import LogicTemplateParser
from utils.utils import AutoSkipMode
from algorithm.dqn.dqn import CustomDQN
import torch
from data.cache import Cache
from gymnasium import spaces, Env
from encoding.state import StateEncoding

from hydra import compose, initialize
from omegaconf import OmegaConf
from hydra.core.config_store import ConfigStore
from config import register_configs

cs = ConfigStore.instance()
register_configs()

In [2]:
cfg_name = "reimburse_realdata_terminalobs"
ckpt_path = '/mount/arbeitsdaten/asr-2/vaethdk/cts_newcodebase_weights/run_1694965093/best_eval/weights/tmp'

In [3]:
def load_env(data: GraphDataset) -> RealUserEnvironment:
    # setup data & parsers
    answerParser = AnswerTemplateParser()
    logicParser = LogicTemplateParser()
    sysParser = SystemTemplateParser()
    valueBackend = ReimbursementRealValueBackend(a1_laender=data.a1_countries, data=data.hotel_costs)
    nlu = NLU()

    # setup env
    env = RealUserEnvironment(dataset=data, nlu=nlu,
                        sys_token="SYSTEM", usr_token="USER", sep_token="",
                        max_steps=50, max_reward=150, user_patience=2,
                        answer_parser=answerParser, logic_parser=logicParser, value_backend=valueBackend,
                        auto_skip=AutoSkipMode.NONE, stop_on_invalid_skip=False)
    return env

In [4]:
def to_class(path:str):
    from pydoc import locate
    class_instance = locate(path)
    return class_instance

In [5]:

## NOTE: assumes already unzipped checkpoint!
from config import DialogLogLevel, WandbLogLevel
from algorithm.dqn.her import HindsightExperienceReplayWrapper
import gymnasium as gym

def load_model(ckpt_path: str, cfg_name: str, device: str, data: GraphDataset) -> Tuple[CustomDQN, StateEncoding]:
    # load config
    cfg_path = "./conf/"

    with initialize(version_base=None, config_path=cfg_path):
        # parse config
        print("Parsing config...")
        cfg = compose(config_name=cfg_name)
        # print(OmegaConf.to_yaml(cfg))

        # disable logging
        cfg.experiment.logging.dialog_log = DialogLogLevel.NONE
        cfg.experiment.logging.wandb_log = WandbLogLevel.NONE
        cfg.experiment.logging.log_interval = 9999999
        cfg.experiment.logging.keep_checkpoints = 9

        # load encodings
        print("Loading encodings...")
        state_cfg = cfg.experiment.state
        action_cfg = cfg.experiment.actions
        cache = Cache(device=device, data=data, state_config=state_cfg, torch_compile=False)
        encoding = StateEncoding(cache=cache, state_config=state_cfg, action_config=action_cfg, data=data)

        # setup spaces
        action_space = gym.spaces.Discrete(encoding.space_dims.num_actions)
        if encoding.action_config.in_state_space == True:
            # state space: max. node degree (#actions) x state dim
            observation_space = gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(encoding.space_dims.num_actions, encoding.space_dims.state_vector,)) #, dtype=np.float32)
        else:
            observation_space = gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(encoding.space_dims.state_vector,)) #, dtype=np.float32)

        class CustomEnv(Env):
            def __init__(self, observation_space, action_space) -> None:
                self.observation_space = observation_space
                self.action_space = action_space
        dummy_env = CustomEnv(observation_space=observation_space, action_space=action_space)

        # setup model
        print("Settung up model...")
        net_arch = OmegaConf.to_container(cfg.experiment.policy.net_arch)
        net_arch['state_dims'] = encoding.space_dims # patch arguments
        optim = OmegaConf.to_container(cfg.experiment.optimizer)
        optim_class = to_class(optim.pop('class_path'))
        lr = optim.pop('lr')
        print("Optim ARGS:", optim_class, lr, optim)
        policy_kwargs = {
            "activation_fn": to_class(cfg.experiment.policy.activation_fn),   
            "net_arch": net_arch,
            "torch_compile": cfg.experiment.torch_compile,
            "optimizer_class": optim_class,
            "optimizer_kwargs": optim
        }
        replay_buffer_kwargs = {
            "num_train_envs": cfg.experiment.environment.num_train_envs,
            "batch_size": cfg.experiment.algorithm.dqn.batch_size,
            "dataset": data,
            "append_ask_action": False,
            # "state_encoding": state_encoding,
            "auto_skip": AutoSkipMode.NONE,
            "normalize_rewards": True,
            "stop_when_reaching_goal": cfg.experiment.environment.stop_when_reaching_goal,
            "stop_on_invalid_skip": cfg.experiment.environment.stop_on_invalid_skip,
            "max_steps": cfg.experiment.environment.max_steps,
            "user_patience": cfg.experiment.environment.user_patience,
            "sys_token": cfg.experiment.environment.sys_token,
            "usr_token": cfg.experiment.environment.usr_token,
            "sep_token": cfg.experiment.environment.sep_token,
            "alpha": cfg.experiment.algorithm.dqn.buffer.backend.alpha,
            "beta": cfg.experiment.algorithm.dqn.buffer.backend.beta,
            "use_lap": cfg.experiment.algorithm.dqn.buffer.backend.use_lap 
        }
        replay_buffer_class = HindsightExperienceReplayWrapper
        dqn_target_cls =  to_class(cfg.experiment.algorithm.dqn.targets._target_)
        dqn_target_args = {'gamma': cfg.experiment.algorithm.dqn.gamma}
        dqn_target_args.update(cfg.experiment.algorithm.dqn.targets) 
        model = CustomDQN(policy=to_class(cfg.experiment.policy._target_), policy_kwargs=policy_kwargs,
                    target=dqn_target_cls(**dqn_target_args),
                    seed=cfg.experiment.seed,
                    env=dummy_env, 
                    batch_size=cfg.experiment.algorithm.dqn.batch_size,
                    verbose=1, device=cfg.experiment.device,  
                    learning_rate=lr, 
                    exploration_initial_eps=cfg.experiment.algorithm.dqn.eps_start, exploration_final_eps=cfg.experiment.algorithm.dqn.eps_end, exploration_fraction=cfg.experiment.algorithm.dqn.exploration_fraction,
                    buffer_size=1, 
                    learning_starts=cfg.experiment.algorithm.dqn.warmup_turns,
                    gamma=cfg.experiment.algorithm.dqn.gamma,
                    train_freq=1, # how many rollouts to perform before training once (one rollout = num_train_envs steps)
                    gradient_steps=max(cfg.experiment.environment.num_train_envs // cfg.experiment.training.every_steps, 1),
                    target_update_interval=cfg.experiment.algorithm.dqn.target_network_update_frequency * cfg.experiment.environment.num_train_envs,
                    max_grad_norm=cfg.experiment.algorithm.dqn.max_grad_norm,
                    tensorboard_log=None,
                    replay_buffer_class=replay_buffer_class,
                    optimize_memory_usage=False,
                    replay_buffer_kwargs=replay_buffer_kwargs,
                    action_masking=cfg.experiment.actions.action_masking,
                    actions_in_state_space=cfg.experiment.actions.in_state_space
                ) 
        
        # restore weights
        print("Restoring weights...")
        ckpt_params = torch.load(f"{ckpt_path}/policy.pth", map_location=device)
        model.policy.load_state_dict(ckpt_params)
        model.policy.set_training_mode(False)
        model.policy.eval()
    return model, encoding

In [6]:
data = ReimburseGraphDataset('en/reimburse/test_graph.json', 'en/reimburse/test_answers.json', use_answer_synonyms=True, augmentation=DataAugmentationLevel.NONE, resource_dir='resources')
user_env = load_env(data)

===== Dataset Statistics =====
- files:  en/reimburse/test_graph.json en/reimburse/test_answers.json
- synonyms: True
- depth: 20  - degree: 13
- answers: 162
- questions: 173
- loaded original data: True
- loaded generated data: False


In [7]:
model, state_encoding = load_model(ckpt_path=ckpt_path, cfg_name=cfg_name, device='cpu', data=data)


Parsing config...
Loading encodings...
Loading Embedding (caching: False) encoding.text.sbert.SentenceEmbeddings ...


  from .autonotebook import tqdm as notebook_tqdm


Building tree embedding for nodes...
Done
Space dimensions: StateDims(state_vector=3932, action_vector=1, state_action_subvector=783, num_actions=14)
Settung up model...
Optim ARGS: <class 'torch.optim.adam.Adam'> 0.0001 {}
Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
HER BUFFER BACKEND PrioritizedLAPReplayBuffer
HER ENV!! TOKENS: SYSTEM: USER: 
ARCHITECUTRE OptimizedModule(
  (_orig_mod): CustomDuelingQNetworkWithIntentPrediction(
    (shared_net): ModuleList(
      (0): Linear(in_features=3149, out_features=4096, bias=True)
      (1): ReLU()
      (2): Linear(in_features=4096, out_features=4096, bias=True)
      (3): ReLU()
      (4): Linear(in_features=4096, out_features=4096, bias=True)
      (5): ReLU()
    )
    (action_input_net): ModuleList(
      (0): Linear(in_features=783, out_features=4096, bias=True)
      (1): ReLU()
      (2): Linear(in_features=4096, out_features=4096, bias=True)
      (3): ReLU()
      (4): Linear(in_

In [8]:
def next_action(obs: dict) -> Tuple[int, bool]:
    # encode observation
    s = state_encoding.batch_encode(observation=[obs], sys_token="SYSTEM", usr_token="USER", sep_token="") 
    # predict action & intent
    action, intent = model.predict(observation=s, deterministic=True)
    action = int(action)
    intent = intent.item()

    return action, intent


In [20]:
obs = user_env.reset()
done = False

while not done and user_env.current_user_utterance != "exit":
    action, intent = next_action(obs)
    print(f"  (policy: action {action}, is faq: {intent})")
    obs, reward, done = user_env.step(action=action)
    print(f"  (done: {done})")

What topic do you have questions about? You can either click on an answer from the suggested topics or enter your own text.
  (policy: action 0, is faq: True)
ASKING What topic do you have questions about? You can either click on an answer from the suggested topics or enter your own text.
  (done: False)
  (policy: action 1, is faq: True)
SKIPPING
-> TO Are you going on an intracity trip or a business trip?
  (done: False)
  (policy: action 1, is faq: True)
SKIPPING
-> TO Did you get written permission from your supervisor?
  (done: False)
  (policy: action 2, is faq: True)
SKIPPING
-> TO What country are you traveling to?
  (done: False)
  (policy: action 0, is faq: True)
ASKING What country are you traveling to?
  (done: False)
  (policy: action 1, is faq: True)
SKIPPING
-> TO What city are you traveling to?
  (done: False)
  (policy: action 1, is faq: True)
SKIPPING
-> TO {{COUNTRY
  (done: False)
  (policy: action 1, is faq: True)
SKIPPING
  (done: False)
  (policy: action 1, is fa

VisitError: Error trying to process rule "ge":

'>=' not supported between instances of 'str' and 'float'

In [13]:
user_env.current_episode_log

AttributeError: 'RealUserEnvironment' object has no attribute 'current_episode_log'

In [14]:
user_env.current_node

DialogNode.LOGIC(key: 16378316272591567, answers: 2, questions: 0)
        - connected_node: None
        - text: {{PRIVATE_EXTENSION
        