In [1]:
# cd /content/drive/MyDrive/Second_English_Data/data

In [2]:
from collections import OrderedDict
import os
import torch
import torch.nn as nn
from itertools import chain

# NN

In [3]:
# coding = utf-8
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import numpy as np
#from src.NNModule.utils import Attn, batch_embedding_lookup
#from src.config import FRAUD, NON_FRAUD, ManagerRewardDiscount, WorkerRewardDiscount, EPS, Pad_Query_Node


class Agent(nn.Module):
    def __init__(self, score_method, agent_state_size, answer_node_emb_size):
        super(Agent, self).__init__()
        self.policy_network = Attn(score_method, answer_node_emb_size, agent_state_size)
        self.value_network = nn.Sequential(
            nn.Linear(agent_state_size, agent_state_size // 2),
            nn.Tanh(),
            nn.Linear(agent_state_size // 2, 1)
        )
        # two vector to represent the fraud and non-fraud actions
        self.fraud_embed = Parameter(torch.Tensor(answer_node_emb_size))
        self.fraud_embed.data.uniform_(-1, 1)
        self.non_fraud_embed = Parameter(torch.Tensor(answer_node_emb_size))
        self.non_fraud_embed.data.uniform_(-1, 1)

    def forward(self, agent_state, answer_nodes, graph_node_embedding):
        """
        :param agent_state: (batch_size, agent_state_size)
        :param answer_nodes: (batch_size, answer_node_num)
        :param graph_node_embedding: (batch_size, node_num, node_feature_size)
        :return:
        values: (batch_size,)
        logits: (batch, answer_node_num + 2)
        """
        values = self.value_network(agent_state).squeeze(-1)

        batch_size = agent_state.shape[0]
        answer_node_embedding = batch_embedding_lookup(graph_node_embedding, answer_nodes)
        actions_embedding = torch.cat((answer_node_embedding,
                                       self.fraud_embed.repeat(batch_size, 1, 1),
                                       self.non_fraud_embed.repeat(batch_size, 1, 1)),
                                      dim=1)
        logits = self.policy_network(actions_embedding, agent_state)

        return values, logits


class Manager(nn.Module):
    def __init__(self, score_method, manager_state_size, worker_state_size):
        super(Manager, self).__init__()
        self.policy_network = Attn(score_method, worker_state_size, manager_state_size)
        self.value_network = nn.Sequential(
            nn.Linear(manager_state_size, manager_state_size // 2),
            nn.Tanh(),
            nn.Linear(manager_state_size // 2, 1)
        )

        # two vector to represent the fraud and non-fraud actions
        self.fraud_embed = Parameter(torch.Tensor(worker_state_size))
        self.fraud_embed.data.uniform_(-1, 1)
        self.non_fraud_embed = Parameter(torch.Tensor(worker_state_size))
        self.non_fraud_embed.data.uniform_(-1, 1)

    def forward(self, manager_state, workers_state):
        """
        :param manager_state: (batch_size, manager_state_size)
        :param workers_state: (batch_size, personal_node_num, worker_sate_size)
        :return:
        values: (batch_size,)
        logits: (batch_size, personal_node_num + 2)
        """
        values = self.value_network(manager_state).squeeze(-1)

        batch_size = manager_state.shape[0]
        actions_embedding = torch.cat((workers_state,
                                       self.fraud_embed.repeat(batch_size, 1, 1),
                                       self.non_fraud_embed.repeat(batch_size, 1, 1)),
                                      dim=1)
        logits = self.policy_network(actions_embedding, manager_state)

        return values, logits


class Workers(nn.Module):
    def __init__(self, score_method, worker_state_size, answer_node_emb_size):
        super(Workers, self).__init__()
        self.policy_networks = Attn(score_method, answer_node_emb_size, worker_state_size)
        self.value_network = nn.Sequential(
            nn.Linear(worker_state_size, worker_state_size // 2),
            nn.Tanh(),
            nn.Linear(worker_state_size // 2, 1)
        )

        # two vector to represent the fraud and non-fraud actions
        self.fraud_embed = Parameter(torch.Tensor(answer_node_emb_size))
        self.fraud_embed.data.uniform_(-1, 1)
        self.non_fraud_embed = Parameter(torch.Tensor(answer_node_emb_size))
        self.non_fraud_embed.data.uniform_(-1, 1)

    def forward(self, workers_state, answer_nodes, graph_node_embedding):
        """
        :param workers_state: (batch_size, personal_node_num, worker_state_size)
        :param answer_nodes: (batch_size, personal_node_num, answer_node_num)
        :param graph_node_embedding: (batch_size, node_num, node_feature_size)
        :return:
        values: (batch_size, personal_node_num)
        logits: (batch_size, personal_node_num, answer_node_num + 2)
        """
        values = self.value_network(workers_state).squeeze(-1)

        batch_size = answer_nodes.shape[0]
        personal_node_num = answer_nodes.shape[1]
        answer_node_num = answer_nodes.shape[2]

        answer_nodes = answer_nodes.reshape(batch_size, -1)
        answer_node_embedding = batch_embedding_lookup(graph_node_embedding, answer_nodes)
        answer_node_embedding = answer_node_embedding.reshape(batch_size, personal_node_num, answer_node_num, -1)
        actions_embedding = torch.cat((answer_node_embedding,
                                       self.fraud_embed.repeat(batch_size, personal_node_num, 1, 1),
                                       self.non_fraud_embed.repeat(batch_size, personal_node_num, 1, 1)),
                                      dim=2)

        workers_state = workers_state.reshape(batch_size * personal_node_num, -1)
        actions_embedding = actions_embedding.reshape(batch_size * personal_node_num, answer_node_num + 2, -1)

        logits = self.policy_networks(actions_embedding, workers_state)
        logits = logits.reshape(batch_size, personal_node_num, -1)
        return values, logits


def sample_from_prob_matrix(prob_matrix, sample_flag):
    """
    Sample n times based on prob matrix. The prob in i-th experiments is the i-th row of prob matrix.
    :param prob_matrix: (n_times, m_items)
    :param sample_flag: str, max, random, sample
    :return: choices: (n_times,)
    """
    if sample_flag == "sample":
        cumulative_prob = prob_matrix.cumsum(axis=1)
        uniform = np.random.rand(len(cumulative_prob), 1)
        choices = (uniform < cumulative_prob).argmax(axis=1)
    elif sample_flag == "max":
        choices = prob_matrix.argmax(axis=1)
    elif sample_flag == "random":
        new_prob_matrix = np.asarray(prob_matrix > 1000 * EPS, dtype=np.float32)
        items_matrix = new_prob_matrix.sum(axis=1, keepdims=True) + EPS
        choices = sample_from_prob_matrix(new_prob_matrix / items_matrix, sample_flag="sample")
    else:
        raise ValueError('Unknown Sample Flag.')
    return choices


def sample_hierarchy_rl(manager_action_probs, workers_action_probs, manager_actions, workers_actions, sample_flag):
    """
    Assume all inputs are np.array
    :param manager_action_probs: (batch_size, personal_node_num + 2)
    :param workers_action_probs: (batch_size, personal_node_num, answer_node_num + 2)
    :param manager_actions: (batch_size, personal_node_num)
    :param workers_actions: (batch_size, personal_node_num, answer_node_num)
    :param sample_flag:
    :return:
    manager_sample_idx: (batch_size,)
    manager_sample_result: (batch_size,)
    workers_sample_idx: (batch_size, personal_node_num)
    workers_sample_result: (batch_size, personal_node_num)
    """
    manager_sample_idx = sample_from_prob_matrix(manager_action_probs, sample_flag)

    batch_size = workers_action_probs.shape[0]
    personal_node_num = workers_action_probs.shape[1]
    workers_action_probs = workers_action_probs.reshape(batch_size * personal_node_num, -1)
    workers_sample_idx = sample_from_prob_matrix(workers_action_probs, sample_flag)

    manager_terminal_actions = np.tile(np.asarray([FRAUD, NON_FRAUD]), (batch_size, 1))
    manager_actions = np.concatenate((manager_actions, manager_terminal_actions), axis=1)
    manager_sample_result = manager_actions[np.arange(manager_actions.shape[0]), manager_sample_idx]

    worker_terminal_actions = np.tile(np.asarray([FRAUD, NON_FRAUD]), (batch_size, personal_node_num, 1))
    workers_actions = np.concatenate((workers_actions, worker_terminal_actions), axis=2)
    workers_actions = workers_actions.reshape(batch_size * personal_node_num, -1)
    workers_sample_result = workers_actions[np.arange(workers_actions.shape[0]), workers_sample_idx]

    workers_sample_result = workers_sample_result.reshape(batch_size, personal_node_num)
    workers_sample_idx = workers_sample_idx.reshape(batch_size, personal_node_num)

    return manager_sample_idx, manager_sample_result, workers_sample_idx, workers_sample_result


def sample_flatten_rl(action_probs, answer_nodes, query_nodes, sample_flag):
    """
    Assume all inputs are np.array
    :param action_probs: (batch_size, answer_node_num + 2)
    :param answer_nodes: (batch_size, answer_node_num)
    :param query_nodes: (batch_size, answer_node_num)
    :param sample_flag:
    :return:
    sample_idx: (batch_size,)
    sample_content: (batch_size, 2)
    """
    sample_idx = sample_from_prob_matrix(action_probs, sample_flag)
    batch_size = action_probs.shape[0]

    terminal_actions = np.tile(np.asarray([FRAUD, NON_FRAUD]), (batch_size, 1))
    pad_query_nodes = np.tile(np.asarray([Pad_Query_Node, Pad_Query_Node]), (batch_size, 1))
    answer_nodes = np.concatenate((answer_nodes, terminal_actions), axis=1)
    query_nodes = np.concatenate((query_nodes, pad_query_nodes), axis=1)
    sample_query = query_nodes[np.arange(query_nodes.shape[0]), sample_idx]
    sample_result = answer_nodes[np.arange(answer_nodes.shape[0]), sample_idx]

    sample_content = np.concatenate((sample_query[:, np.newaxis], sample_result[:, np.newaxis]), axis=1)

    return sample_idx, sample_content


def get_hierarchy_rl_returns(rewards, masks):
    """
    :param rewards: (batch_size, 1 + workers_num, time_steps)
    :param masks: (batch_size, 1 + workers_num, time_steps)
    :return: returns: (batch_size, 1 + workers_num, time_steps)
    """
    masks = masks.to(torch.uint8)

    gamma = torch.zeros_like(rewards)
    gamma[:, 0, :] = ManagerRewardDiscount
    gamma[:, 1:, :] = WorkerRewardDiscount
    gamma[1 - masks] = torch.ones_like(rewards)[1 - masks]

    batch_size = rewards.shape[0]
    agent_num = rewards.shape[1]
    time_steps = rewards.shape[2]
    rewards = rewards.reshape(batch_size * agent_num, -1)
    masks = masks.reshape(batch_size * agent_num, -1)
    gamma = gamma.reshape(batch_size * agent_num, -1)

    returns = torch.zeros_like(rewards)
    running_returns = torch.zeros_like(rewards[:, 0])
    for t in reversed(range(0, time_steps)):
        running_returns = rewards[:, t] * masks[:, t].to(running_returns.dtype) + gamma[:, t] * running_returns
        returns[:, t] = running_returns

    returns = returns.reshape(batch_size, agent_num, time_steps)
    masks = masks.reshape(batch_size, agent_num, time_steps)

    # Mask the invalid items
    returns = returns * masks.to(returns.dtype)

    return returns


def get_flatten_rl_returns(rewards, masks):
    """
    :param rewards: (batch_size, time_steps)
    :param masks: (batch_size, time_steps)
    :return: returns: (batch_size, time_steps)
    """
    gamma = ManagerRewardDiscount
    time_steps = rewards.shape[1]

    returns = torch.zeros_like(rewards)
    running_returns = torch.zeros_like(rewards[:, 0])
    for t in reversed(range(0, time_steps)):
        running_returns = rewards[:, t] * masks[:, t].to(running_returns.dtype) + gamma * running_returns
        returns[:, t] = running_returns

    # Mask the invalid items
    returns = returns * masks.to(returns.dtype)
    return returns


In [4]:
# coding = utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
#from src.config import EPS


def batch_embedding_lookup(embeddings, indices):
    """
    Look up from a batch of embedding matrices.
    :param embeddings: (batch_size, num_words, embedding_size)
    :param indices: (batch_size, num_inds)
    :return:
    """
    shape = embeddings.shape
    batch_size = shape[0]
    num_words = shape[1]
    embed_size = shape[2]

    offset = torch.reshape(torch. arange(batch_size) * num_words, (batch_size, 1)).to(dtype=indices.dtype,
                                                                                     device=indices.device)
    flat_embeddings = torch.reshape(embeddings, (-1, embed_size)) # first as it is.
    flat_indices = torch.reshape(indices + offset, (-1,))
    embeds = torch.reshape(F.embedding(flat_indices, flat_embeddings), (batch_size, -1, embed_size)) # from pytorch
    return embeds


def mask_softmax(input_tensor, mask, dim):
    """
    Softmax with mask to input tensor.
    :param input_tensor: Tensor with any shape
    :param mask: same shape with input tensor
    :param dim: a dimension along which softmax will be computed
    :return:
    """
    exps = torch.exp(input_tensor)
    masked_exps = exps * mask.to(exps.dtype)
    masked_sums = masked_exps.sum(dim, keepdim=True) + EPS
    return masked_exps / masked_sums


class Attn(nn.Module):
    def __init__(self, method, encode_hidden_size, decode_hidden_size):
        super(Attn, self).__init__()
        if method.lower() not in ["dotted", "general", "concat"]:
            raise RuntimeError("Attention methods should be dotted, general or concat but get {}!".format(method))
        if method.lower() == "dotted" and encode_hidden_size != decode_hidden_size:
            raise RuntimeError("In dotted attention, the encode_hidden_size should equal to decode_hidden_size.")

        self.method = method.lower()
        self.encode_hidden_size = encode_hidden_size
        self.decode_hidden_size = decode_hidden_size

        if self.method == "general":
            self.attn = nn.Linear(self.encode_hidden_size, self.decode_hidden_size)
        elif self.method == "concat":
            self.attn = nn.Sequential(
                nn.Linear((self.encode_hidden_size + self.decode_hidden_size),
                          (self.encode_hidden_size + self.decode_hidden_size) // 2),
                nn.Tanh(),
                nn.Linear((self.encode_hidden_size + self.decode_hidden_size) // 2, 1)
            )

    def forward(self, encode_outputs, decode_state):
        """
        :param encode_outputs: (batch, output_length, encode_hidden_size)
        :param decode_state: (batch, decode_hidden_size)
        :return: energy: (batch, output_length)
        """
        output_length = encode_outputs.size(1)
        if self.method == "concat":
            decode_state_temp = decode_state.unsqueeze(1)
            decode_state_temp = decode_state_temp.expand(-1, output_length, -1)
            cat_encode_decode = torch.cat([encode_outputs, decode_state_temp], 2)
            energy = self.attn(cat_encode_decode).squeeze(-1)
        elif self.method == "general":
            decode_state_temp = decode_state.unsqueeze(1)
            mapped_encode_outputs = self.attn(encode_outputs)
            energy = torch.sum(decode_state_temp * mapped_encode_outputs, 2)
        else:
            decode_state_temp = decode_state.unsqueeze(1)
            energy = torch.sum(decode_state_temp * encode_outputs, 2)
        return energy


In [5]:
# coding = utf-8
import torch
import torch.nn as nn
#from src.NNModule.utils import batch_embedding_lookup
#from src.config import EPS


class GNN(nn.Module):
    """ A pytorch implementation of Message Passing Network """

    def __init__(self,
                 node_emb_size_list,
                 msg_agg):
        super(GNN, self).__init__()

        self.mp_iters = len(node_emb_size_list) - 1
        self.linear_cells = nn.ModuleList([nn.Linear(in_features, out_features) for in_features, out_features in
                                           zip(node_emb_size_list[:-1], node_emb_size_list[1:])])
        self.msg_agg = msg_agg

    def embed_edge(self, node_embedding, edges, iter_idx):
        """
        Compute embedding of a edge (sender_node, relation_label).
        :param node_embedding: (batch_size, max_node_num, feature_size)
        :param edges: each edge is a tuple of (sender_node, relation_label, receiver_node).
        (batch_size, max_edge_num, 3)
        :param iter_idx: the iter_idx-th iteration of graph
        :return: edge_embeds: (batch_size, max_edge_num, feature_size)
        """
        sender_embeds = batch_embedding_lookup(node_embedding, edges[:, :, 0])
        edge_embeds = torch.tanh(self.linear_cells[iter_idx](sender_embeds))
        return edge_embeds

    def pass_message(self, node_edges, node_edge_mask, edge_embeds):
        """
        Compute new node embedding by summing edge embeds (message) of neighboring nodes.
        :param node_edges: ids of neighboring edges of each node where id is row index in edge_embeds
        (batch_size, max_node_num, max_node_edge_num)
        :param node_edge_mask: mask for node_edges. (batch_size, max_node_num, max_node_edge_num)
        :param edge_embeds: (batch_size, max_edge_num, feature_size)
        :return:
        """
        num_neighbors = torch.sum(node_edge_mask.to(torch.float32), 2, keepdim=True) + EPS

        shape = node_edges.shape
        batch_size = shape[0]
        node_num = shape[1]
        edge_embed_size = edge_embeds.shape[-1]

        # Gather neighboring edge embeddings
        neighbors = torch.reshape(node_edges, (batch_size, -1))  # (batch_size, max_node_num * max_node_edge_num)
        embeds = batch_embedding_lookup(edge_embeds,
                                        neighbors)  # (batch_size, max_node_num * max_node_edge_num, feature_size)
        embeds = torch.reshape(embeds, (batch_size, node_num, -1, edge_embed_size))
        mask = torch.unsqueeze(node_edge_mask, 3)  # (batch_size, max_node_num, max_node_edge_num, 1)
        embeds = embeds * mask.to(embeds.dtype)

        # (batch_size, max_node_num, feature_size)
        if self.msg_agg == 'sum':
            new_node_embed = torch.sum(embeds, 2)
        elif self.msg_agg == 'avg':
            new_node_embed = torch.sum(embeds, 2) / num_neighbors
        elif self.msg_agg == 'max':
            new_node_embed, _ = torch.max(embeds, 2)
        else:
            raise ValueError('Unknown message aggregation method')

        return new_node_embed

    def mp(self, curr_node_embedding, edges, iter_idx, node_edges, node_edge_mask):
        edge_embeds = self.embed_edge(curr_node_embedding, edges, iter_idx)
        new_node_embed = self.pass_message(node_edges, node_edge_mask, edge_embeds)
        return new_node_embed

    def forward(self, initial_node_embed, edges, node_edges, node_edge_mask):
        """
        :param initial_node_embed: (batch_size, max_node_num, feature_size)
        :param edges: (batch_size, max_edge_num, 3)
        :param node_edges: (batch_size, max_node_num, max_node_edge_num)
        :param node_edge_mask: (batch_size, max_node_num, max_node_edge_num)
        :return: final_node_embed: (batch_size, max_node_num, feature_size)
        """
        node_embed_list = [initial_node_embed]
        for iter_idx in range(self.mp_iters):
            node_embed_list.append(self.mp(node_embed_list[-1], edges, iter_idx, node_edges, node_edge_mask))
        final_node_embed = torch.cat(node_embed_list, 2)
        return final_node_embed


class WorkersStateTracker(nn.Module):
    """ Concat personal nodes embedding and hand-crafted features to get the workers dialogue state """

    def __init__(self):
        super(WorkersStateTracker, self).__init__()

    def forward(self, known_one_hot, unknown_one_hot, known_differ_one_hot, workers_qa_turn_one_hot,
                workers_max_qa_turn_one_hot, personal_nodes, final_node_embed):
        """
        :param known_one_hot: (batch_size, personal_node_num, feature_size)
        :param unknown_one_hot: (batch_size, personal_node_num, feature_size)
        :param known_differ_one_hot: (batch_size, personal_node_num, feature_size)
        :param workers_qa_turn_one_hot: (batch_size, personal_node_num, feature_size)
        :param workers_max_qa_turn_one_hot: (batch_size, personal_node_num, feature_size)
        :param personal_nodes: (batch_size, personal_node_num)
        :param final_node_embed: (batch_size, feature_size)
        :return: workers_state: (batch_size, personal_node_num, feature_size)
        """
        personal_node_embed = batch_embedding_lookup(final_node_embed, personal_nodes)

        workers_state = torch.cat((known_one_hot,
                                   unknown_one_hot,
                                   known_differ_one_hot,
                                   workers_qa_turn_one_hot,
                                   workers_max_qa_turn_one_hot,
                                   personal_node_embed),
                                  2)
        return workers_state


class ManagerStateTracker(nn.Module):
    """ Aggregate personal nodes embedding and hand-crafted features to get the manager dialogue state """

    def __init__(self, personal_node_emb_size, manager_agg_size, msg_agg):
        super(ManagerStateTracker, self).__init__()

        self.msg_agg = msg_agg
        self.transfer = nn.Sequential(nn.Linear(personal_node_emb_size, manager_agg_size),
                                      nn.Tanh())

    def forward(self, feasible_personal_info_nodes, workers_decision,
                known_one_hot, unknown_one_hot, known_differ_one_hot,
                total_qa_turn_one_hot, personal_nodes, final_node_embed):
        """
        :param feasible_personal_info_nodes: (batch_size, personal_node_num)
        :param workers_decision: (batch_size, personal_node_num, 2)
        :param known_one_hot: (batch_size, personal_node_num, feature_size)
        :param unknown_one_hot: (batch_size, personal_node_num, feature_size)
        :param known_differ_one_hot: (batch_size, personal_node_num, feature_size)
        :param total_qa_turn_one_hot: (batch_size, feature_size)
        :param personal_nodes: (batch_size, personal_node_num)
        :param final_node_embed: (batch_size, max_node_num, feature_size)
        :return: manager_state: (batch_size, feature_size)
        """
        batch_size = personal_nodes.shape[0]
        personal_node_num = personal_nodes.shape[1]

        personal_node_embed = batch_embedding_lookup(final_node_embed, personal_nodes)
        transferred_personal_node_embed = self.transfer(personal_node_embed)

        if self.msg_agg == 'sum':
            agg_state = torch.sum(transferred_personal_node_embed, 1)
        elif self.msg_agg == 'avg':
            agg_state = torch.sum(transferred_personal_node_embed, 1) / personal_node_num
        elif self.msg_agg == 'max':
            agg_state, _ = torch.max(transferred_personal_node_embed, 1)
        else:
            raise ValueError('Unknown message aggregation method')

        manager_state = torch.cat((feasible_personal_info_nodes,
                                   workers_decision.reshape(batch_size, -1),
                                   known_one_hot.reshape(batch_size, -1),
                                   unknown_one_hot.reshape(batch_size, -1),
                                   known_differ_one_hot.reshape(batch_size, -1),
                                   total_qa_turn_one_hot,
                                   agg_state), 1)
        return manager_state


## **Constants**

In [6]:
DATA_ROOT = r'C:\Users\Mohamed Taha\Desktop\FinalWork\data'

# coding = utf-8
import torch
import os

#DATA_ROOT = os.path.abspath(r'C:\Users\M.user\Desktop\translate\translated')

# for data set split
train_size = 706
test_size = 100
dev_size = 100
# for pad
# NodePad depends on the valid num of nodes
RelationPad = 0
EdgePad = (0, 0, 0)

PLACE_HOLDER = None

# for manager and worker terminal actions, to keep in accordance to other actions
# we insert [FRAUD, NON_FRAUD] in the tail of action space
FRAUD = -2
NON_FRAUD = -1

# simulate user know or do not know about a triple
Known = 1
UnKnown = -1
NotClear = 0
ShowUnknown = "Unknown"
UnKnownUtterance = "我不清楚"

NegativeSampledAnswerNum = 2
Options = ["A", "B", "C", "D"]

# for user type sample
User_Type_Weights = {"Type-4 Fraud": 1,
                     "Type-3 Fraud": 1,
                     "Type-2 Fraud": 1,
                     "Type-1 Fraud": 1,
                     "Non-Fraud": 4}
Personal_Information_Fraud_Weights = {"company": 2, "university": 2, "live_in": 1, "born_in": 1}

EPS = 1e-6

# Turn exploring time step constrain
# Note: the exploring time steps are not the same as dialogue turns,
# punish exploring time steps will punish dialogue turns, not vice versa
MaxExploringTimeStep = 40
MaxWorkerExploringTimeStep = 10
MinFlattenRLQATurn = 8
MaxFlattenRuleQATurn = 10

# For rule based warm ups
MinDifference = 3
MinWorkerQATurn = 3

# Rewards
WorkerBonus = 0
ManagerBonus = 0
# give bonus in hrl model setting, if not, the rl will collapse, just a trick!
# WorkerBonus = 2
# ManagerBonus = 1.5
ManagerRecognitionCorrect = 3
ManagerRecognitionWrong = -3
WorkerRecognitionCorrect = 1
WorkerRecognitionWrong = -1
ExploringPunish = -0.1
WorkerRewardDiscount = 0.99
ManagerRewardDiscount = 0.999

# config for torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# for node feature
PERSONAL_NODE_FILED = 4
ONE_STEP_NODE_FILED = 1
SE_FREQS_FILED = 10
DEGREE_FILED = 10
STATIC_FEATURE_SIZE = PERSONAL_NODE_FILED + ONE_STEP_NODE_FILED + SE_FREQS_FILED + DEGREE_FILED
DYNAMIC_FEATURE_SIZE = 7
Init_Node_Feature_Size = STATIC_FEATURE_SIZE + DYNAMIC_FEATURE_SIZE
ManagerStateRest = 4 + 2 * 4 + (MaxWorkerExploringTimeStep + 1) * 3 * 4 + (MaxExploringTimeStep + 1)
WorkersStateRest = (MaxWorkerExploringTimeStep + 1) * 3 + (MaxWorkerExploringTimeStep + 1) + (
        MaxWorkerExploringTimeStep + 1)

# for flatten rl, this node idx is the padded query node
Pad_Query_Node = 1000


# State Tracker:

## Node *Feature*

In [7]:
# coding = utf-8
from itertools import chain
import os
import json
import math
import numpy as np
#from src.config import DATA_ROOT, DEGREE_FILED, SE_FREQS_FILED, DYNAMIC_FEATURE_SIZE

with open(os.path.join(DATA_ROOT, "se_freqs_bins.json"), "r") as f:
    se_freqs_bins = json.load(f)


class NodeFeature(object):
    def __init__(self):
        self.se_freqs_bins = se_freqs_bins

    def static_feature(self,
                       nodes,
                       personal_information,
                       one_step_nodes,
                       node_se_freqs,
                       node_degree,
                       node_in_degree,
                       node_out_degree):
        """  Get feature before the dialogue.  """
        answer_nodes = list(chain(*one_step_nodes.values()))
        static_feature_matrix = list()
        for node, se_freqs, degree, in_degree, out_degree in zip(nodes,
                                                                 node_se_freqs,
                                                                 node_degree,
                                                                 node_in_degree,
                                                                 node_out_degree):
            vector = list()

            # the personal information type of this node
            for value in personal_information.values():
                if node == value:
                    vector.append(1)
                else:
                    vector.append(0)

            # if the node is the answer node
            if node in answer_nodes:
                vector.append(1)
            else:
                vector.append(0)

            for i, max_se_freqs in enumerate(self.se_freqs_bins):
                if math.log(se_freqs + 1) <= max_se_freqs:
                    vector.extend(self.one_hot(i, SE_FREQS_FILED))
                    break

            vector.extend(self.one_hot(degree, DEGREE_FILED))
            # vector.extend(self.one_hot(in_degree, IN_DEGREE_FILED))
            # vector.extend(self.one_hot(out_degree, OUT_DEGREE_FILED))
            static_feature_matrix.append(vector)
        return static_feature_matrix

    @staticmethod
    def one_hot(idx, max_length): 
        v = [0 for _ in range(max_length)]
        if idx > max_length - 1:
            v[-1] = 1
        else:
            v[idx] = 1
        return v

    @staticmethod
    def dialogue_feature(max_node_num,
                         nodes,
                         explored_nodes,
                         last_turn_q_node,
                         last_turn_a_node,
                         not_explored_nodes,
                         known_nodes,
                         unknown_nodes,
                         not_answered_nodes):
        """  Get feature during the dialogue. Support batch.  """
        dialogue_feature_matrix = np.zeros((max_node_num, DYNAMIC_FEATURE_SIZE))

        for node in nodes:
            vector = list()
            if node in explored_nodes:
                vector.append(1)
            else:
                vector.append(0)

            if node == last_turn_q_node:
                vector.append(1)
            else:
                vector.append(0)

            if node == last_turn_a_node:
                vector.append(1)
            else:
                vector.append(0)

            if node in not_explored_nodes:
                vector.append(1)
            else:
                vector.append(0)

            if node in known_nodes:
                vector.append(1)
            else:
                vector.append(0)

            if node in unknown_nodes:
                vector.append(1)
            else:
                vector.append(0)

            if node in not_answered_nodes:
                vector.append(1)
            else:
                vector.append(0)

            dialogue_feature_matrix[node] = np.asarray(vector, dtype=np.float32)

        return dialogue_feature_matrix


## State Tracker:

In [8]:
def to_one_hot(vec, size):
    """
    :param vec: any shape vector
    :param size: the one hot size
    :return: one hot vector of tensor
    """
    shape = vec.shape
    vec_flat = vec.reshape(-1)

    one_hot_vec = np.zeros((vec_flat.shape[0], size))
    #try:
    one_hot_vec[np.arange(vec_flat.shape[0]), vec_flat] = 1
    #except Exception as e:
      #print(vec_flat.shape[0], vec_flat)
      #print('one hot error', e)
    one_hot_vec = one_hot_vec.reshape((*shape, size))
    return one_hot_vec

In [9]:
# coding = utf-8
import copy
import numpy as np
#from src.config import NON_FRAUD, FRAUD, Known, UnKnown, PLACE_HOLDER, MaxExploringTimeStep, \
    #MaxWorkerExploringTimeStep, ManagerRecognitionCorrect, ManagerRecognitionWrong, WorkerRecognitionCorrect, \
    #WorkerRecognitionWrong, ExploringPunish, MinDifference, MinWorkerQATurn, WorkerBonus, ManagerBonus, \
    #MaxFlattenRuleQATurn, MinFlattenRLQATurn
#from src.Graph.node_feature import NodeFeature


class StateTracker(object):
    def __init__(self, state_tracker_field):
        self.init_episode(state_tracker_field)

    def generate_recent_dialogue_feature(self):
        max_node_num = self.state_tracker_field[-1]["max_node_num"]
        nodes = self.state_tracker_field[-1]["nodes"]
        explored_nodes = self.state_tracker_field[-1]["explored_nodes"]
        last_turn_q_node = self.state_tracker_field[-1]["last_turn_q_node"]
        last_turn_a_node = self.state_tracker_field[-1]["last_turn_a_node"]
        not_explored_nodes = self.state_tracker_field[-1]["not_explored_nodes"]
        known_nodes = self.state_tracker_field[-1]["known_nodes"]
        unknown_nodes = self.state_tracker_field[-1]["unknown_nodes"]
        not_answered_nodes = self.state_tracker_field[-1]["not_answered_nodes"]
        self.state_tracker_field[-1]["dialogue_feature"] = self.node_feat_generator.dialogue_feature(max_node_num,
                                                                                                     nodes,
                                                                                                     explored_nodes,
                                                                                                     last_turn_q_node,
                                                                                                     last_turn_a_node,
                                                                                                     not_explored_nodes,
                                                                                                     known_nodes,
                                                                                                     unknown_nodes,
                                                                                                     not_answered_nodes)

    def init_episode(self, state_tracker_field):
        self.state_tracker_field = state_tracker_field
        self.node_feat_generator = NodeFeature()
        self.generate_recent_dialogue_feature()



   
              

class StateTrackerHRL(StateTracker):
    def __init__(self, state_tracker_field):
        super(StateTrackerHRL, self).__init__(state_tracker_field)
    

    # NEW FN:
    def generate_system_action(self,
                            manager_sample_idx,
                            manager_sample_content,
                            workers_sample_idx,
                            workers_sample_content,
                            #user_identity_state,
                            #user_sub_identity_state_dict
                            ):
      policy_mask = self.state_tracker_field[-1]["policy_mask"]
      policy_idx = policy_mask.argmax()

      if policy_idx == 0:
          system_action = {"manager": {"sample_idx": manager_sample_idx,
                                        "sample_content": manager_sample_content}}

          if manager_sample_content in [FRAUD, NON_FRAUD]:
            self.state_tracker_field[-1]["manager_decision"] = manager_sample_content


      else:
        worker_idx = policy_idx - 1

        system_action = {"worker": {"sample_idx": workers_sample_idx[worker_idx],
                                    "sample_content": workers_sample_content[worker_idx]}}

        if workers_sample_content[worker_idx] in [FRAUD, NON_FRAUD]:
          self.state_tracker_field[-1]["workers_decision"][worker_idx] = workers_sample_content[worker_idx]
      
      self.state_tracker_field[-1]["system_action"] = system_action
      
      return system_action



  

    def generate_workers_reward(self):
        # call this function after the worker choose a terminal action
        policy_mask = self.state_tracker_field[-1]["policy_mask"]
        worker_idx = policy_mask.argmax() - 1

        worker_decision = self.state_tracker_field[-1]["workers_decision"][worker_idx]
        worker_success_state = self.state_tracker_field[-1]["workers_success_state"][worker_idx]

        if worker_success_state is True:
            if worker_decision == FRAUD:
                self.state_tracker_field[-1]["workers_reward"] += (WorkerRecognitionCorrect + WorkerBonus)
            else:
                self.state_tracker_field[-1]["workers_reward"] += WorkerRecognitionCorrect
        else:
            if worker_decision != FRAUD:
                self.state_tracker_field[-1]["workers_reward"] += (WorkerRecognitionWrong - WorkerBonus)
            else:
                self.state_tracker_field[-1]["workers_reward"] += WorkerRecognitionWrong

        # rollback to give turn punishment to each worker action
        for i in range(-2, -(len(self.state_tracker_field) + 1), -1):
            if self.state_tracker_field[i]["policy_mask"].argmax() == 0:
                break
            self.state_tracker_field[i]["workers_reward"] += ExploringPunish

    def _worker_detect_a_fraud(self):
        flag = False
        for worker_decision, worker_success_state in zip(self.state_tracker_field[-1]["workers_decision"],
                                                         self.state_tracker_field[-1]["workers_success_state"]):
            if worker_decision == FRAUD and worker_success_state is True:
                flag = True
                break
        return flag

    def _worker_detect_all_non_fraud(self):
        flag = True
        for worker_decision, worker_success_state in zip(self.state_tracker_field[-1]["workers_decision"],
                                                         self.state_tracker_field[-1]["workers_success_state"]):
            if not ((worker_decision == PLACE_HOLDER and worker_success_state == PLACE_HOLDER) or (
                    worker_decision == NON_FRAUD and worker_success_state is True)):
                flag = False
                break
        return flag

    def generate_manager_reward(self):
        # call this function after the manager execute an action
        manager_decision = self.state_tracker_field[-1]["manager_decision"]
        manager_success_state = self.state_tracker_field[-1]["manager_success_state"]

        if manager_decision in [FRAUD, NON_FRAUD]:
            if manager_success_state is True:
                if not (self._worker_detect_a_fraud() or self._worker_detect_all_non_fraud()):
                    self.state_tracker_field[-1]["manager_reward"] += (ManagerRecognitionCorrect + ManagerBonus)
                else:
                    self.state_tracker_field[-1]["manager_reward"] += ManagerRecognitionCorrect
            else:
                if self._worker_detect_a_fraud() or self._worker_detect_all_non_fraud():
                    self.state_tracker_field[-1]["manager_reward"] += (ManagerRecognitionWrong - ManagerBonus)
                else:
                    self.state_tracker_field[-1]["manager_reward"] += ManagerRecognitionWrong

        # rollback to give turn punishment to the last manager action
        worker_exploring_time = self.state_tracker_field[-1]["current_worker_exploring_turn"]

        if len(self.state_tracker_field) > 1:
            for i in range(-2, -(len(self.state_tracker_field) + 1), -1):
                if self.state_tracker_field[i]["policy_mask"].argmax() == 0:
                    self.state_tracker_field[i]["manager_reward"] += worker_exploring_time * ExploringPunish
                    break

    def move_a_step(self,
                    language_generator,
                    manager_sample_idx,
                    manager_sample_content,
                    manager_action_prob,
                    workers_sample_idx,
                    workers_sample_content,
                    workers_action_prob,
                    mode):
        """
        In current state S_{t}, the system execute the system action.
        And we get the new dialogue state S_{t+1}.
        Then, we generate all mask for the new state S_{t+1} based on dialogue context.
        :param language_generator: LanguageGenerator Class
        :param user: UserSimulator Class
        :param dialogue_recorder: DialogueRecorder Class or None
        :param manager_sample_idx:
        :param manager_sample_content:
        :param manager_action_prob:
        :param workers_sample_idx:
        :param workers_sample_content:
        :param workers_action_prob:
        :param mode: indicate rule based warm up or RL
        :return:
        """
        # just for debug
        self.state_tracker_field[-1]["manager_action_prob"] = manager_action_prob
        self.state_tracker_field[-1]["workers_action_prob"] = workers_action_prob

        # pad to the longest episode
        self.state_tracker_field[-1]["manager_sample_idx"] = manager_sample_idx
        self.state_tracker_field[-1]["workers_sample_idx"] = workers_sample_idx

        if self.state_tracker_field[-1]["episode_not_end"] is False:
            print('self.state_tracker_field[-1]["episode_not_end"] is False')
            # pad to the longest episode
            self.state_tracker_field.append(copy.deepcopy(self.state_tracker_field[-1]))
        else:
            # generate the system action in the current valid dialogue state first  #### CHANGE
            system_action = self.generate_system_action(manager_sample_idx, manager_sample_content,
                                                        workers_sample_idx, workers_sample_content)
            #print(system_action)
            # and then inherit information of it
            state_in_new_step = copy.deepcopy(self.state_tracker_field[-1])

            state_in_new_step["total_exploring_turn"] += 1
            if state_in_new_step["total_exploring_turn"] >= MaxExploringTimeStep:
                # terminal hrl recursively in the next step by force
                state_in_new_step["rl_manager_action_mask"][:-2] = 0
                state_in_new_step["rl_manager_action_mask"][-2:] = 1
                state_in_new_step["rl_workers_action_mask"][:, :-2] = 0
                state_in_new_step["rl_workers_action_mask"][:, -2:] = 1

            # generate question
            q_node, a_node = language_generator.generate_question(system_action)
            questions, choices  =  new_answer(q_node, a_node, language_generator)
            
            
        return questions, choices , state_in_new_step , system_action , language_generator , a_node , q_node

    def continue_a_step(self,
#                     language_generator,
#                     manager_sample_idx,
#                     manager_sample_content,
#                     manager_action_prob,
#                     workers_sample_idx,
#                     workers_sample_content,
#                     workers_action_prob,
#                     mode,
                         user_answer,state_in_new_step,system_action,language_generator,choices, a_node,q_node):
      
#             add node2idx to graph a7sn
        if (user_answer is not None ) and (choices is not None):

            idx = list(language_generator.language_generation_filed['idx2node'].keys())
            nodes = list(language_generator.language_generation_filed['idx2node'].values())

            for sym, answer in choices:
                if sym == user_answer:
                    try:
                        pos = nodes.index(answer)
                        user_answer = int(idx[pos])

                    except:
                        user_answer = -1
        if self.state_tracker_field[-1]["episode_not_end"] is True:            

            if user_answer is not None:

              state_in_new_step["total_qa_turn"] += 1

            state_in_new_step["dialogue_feature"] = PLACE_HOLDER
            state_in_new_step["manager_sample_idx"] = PLACE_HOLDER
            state_in_new_step["workers_sample_idx"] = PLACE_HOLDER
            state_in_new_step["system_action"] = PLACE_HOLDER
            state_in_new_step["reward_mask"] = PLACE_HOLDER
            state_in_new_step["manager_reward"] = np.zeros((1,), dtype=np.float32)
            state_in_new_step["workers_reward"] = np.zeros((len(state_in_new_step["personal_nodes"]),),
                                                          dtype=np.float32)
              
            manager_action = system_action.get("manager", {"sample_idx": None, "sample_content": None})
            worker_action = system_action.get("worker", {"sample_idx": None, "sample_content": None})


            manager_action_idx = manager_action["sample_idx"]
            manager_action_content = manager_action["sample_content"]
            worker_action_idx = worker_action["sample_idx"]
            worker_action_content = worker_action["sample_content"]


            if manager_action_content in [FRAUD, NON_FRAUD]:
              state_in_new_step["episode_not_end"] = False
              #print('ended' * 10)
              if manager_action_content == FRAUD:
                print('FRAUD !!!!!')
                return 'fraud'
                
              else:
                print('NON-Fraud')
                return 'nonfraud'
            
            elif worker_action_idx is None:
              state_in_new_step["valid_workers_num"] += 1

              state_in_new_step["current_worker_exploring_turn"] = 0

              # generate for dialogue feature
              state_in_new_step["explored_nodes"].add(manager_action_content)
              state_in_new_step["last_turn_q_node"] = None
              state_in_new_step["last_turn_a_node"] = None
              if manager_action_content in state_in_new_step["not_explored_nodes"]:
                  state_in_new_step["not_explored_nodes"].remove(manager_action_content)

              state_in_new_step["rl_manager_action_mask"][manager_action_idx] = 0

              policy_mask = [0] + [0 for _ in state_in_new_step["personal_nodes"]]
              worker_idx = state_in_new_step["personal_nodes"].index(manager_action_content)
              policy_mask[worker_idx + 1] = 1
              state_in_new_step["policy_mask"] = np.asarray(policy_mask, dtype=np.int32)
              #print('right' * 10)

            elif user_answer is None:

                if worker_action_content == FRAUD or state_in_new_step["rl_manager_action_mask"].sum() == 0:
                    state_in_new_step["rl_manager_action_mask"][-2:] = 1

                state_in_new_step["last_turn_q_node"] = None
                state_in_new_step["last_turn_a_node"] = None

                # get new policy mask
                policy_mask = [1] + [0 for _ in state_in_new_step["personal_nodes"]]
                state_in_new_step["policy_mask"] = np.asarray(policy_mask, dtype=np.int32)
                #print('wrong 1' * 15)

            else:
                """  low level explore (Worker)  """
                state_in_new_step["current_worker_exploring_turn"] += 1

                worker_idx = state_in_new_step["policy_mask"].argmax() - 1
                state_in_new_step["workers_qa_turn"][worker_idx] += 1

                # generate for dialogue feature
                state_in_new_step["explored_nodes"].add(worker_action_content)
                if worker_action_content in state_in_new_step["not_explored_nodes"]:
                    state_in_new_step["not_explored_nodes"].remove(worker_action_content)
                if worker_action_content in state_in_new_step["not_answered_nodes"]:
                    state_in_new_step["not_answered_nodes"].remove(worker_action_content)
                if user_answer == a_node:
                    print('right answer!!! ' ,user_answer)
                    state_in_new_step["known_nodes"].add(worker_action_content)
                    state_in_new_step["workers_counter"][worker_idx]["Known"] += 1
                elif user_answer != a_node:
                    print('wrong answer!!! ' ,user_answer )
                    state_in_new_step["unknown_nodes"].add(worker_action_content)
                    state_in_new_step["workers_counter"][worker_idx]["UnKnown"] += 1
                
                state_in_new_step["last_turn_q_node"] = q_node
                state_in_new_step["last_turn_a_node"] = a_node
                #print('wrong 2' *15)
                state_in_new_step["rl_workers_action_mask"][worker_idx, worker_action_idx] = 0
                if state_in_new_step["rl_workers_action_mask"][worker_idx].sum() == 0 or \
                        state_in_new_step["workers_qa_turn"][worker_idx] >= MinWorkerQATurn:
                    state_in_new_step["rl_workers_action_mask"][worker_idx, -2:] = 1
                if state_in_new_step["current_worker_exploring_turn"] >= MaxWorkerExploringTimeStep:
                    # terminal current worker in the next step by force
                    state_in_new_step["rl_workers_action_mask"][worker_idx, :-2] = 0


            self.state_tracker_field.append(state_in_new_step)
            self.generate_recent_dialogue_feature()
            return 'continue'




# Build graph embed inputs:

In [10]:
def build_graph_embed_inputs(graph_embed_field, state_tracker_field, rollout=True):


  edges = graph_embed_field["edges"]
  node_edges = graph_embed_field["node_edges"]
  node_edge_mask = graph_embed_field["node_edge_mask"]
  static_feature = graph_embed_field["static_feature"]
  dialogue_feature = np.concatenate([[field[-1]["dialogue_feature"]] for field in state_tracker_field])

  initial_node_embed = np.concatenate([static_feature, dialogue_feature], axis=-1)
  feed_dict = dict()
  feed_dict["initial_node_embed"] = torch.Tensor(initial_node_embed).to(device=device, dtype=torch.float)
  feed_dict["edges"] = torch.Tensor(edges).to(device=device, dtype=torch.long)
  feed_dict["node_edges"] = torch.Tensor(node_edges).to(device=device, dtype=torch.long)
  feed_dict["node_edge_mask"] = torch.Tensor(node_edge_mask).to(device=device, dtype=torch.uint8)
  return feed_dict


# Build manager state tracker inputs:


In [11]:
def build_graph_embed_inputs(graph_embed_field, state_tracker_field, rollout=True):


  edges = graph_embed_field["edges"]
  node_edges = graph_embed_field["node_edges"]
  node_edge_mask = graph_embed_field["node_edge_mask"]
  static_feature = graph_embed_field["static_feature"]
  dialogue_feature = np.concatenate([[field[-1]["dialogue_feature"]] for field in state_tracker_field])

  initial_node_embed = np.concatenate([static_feature, dialogue_feature], axis=-1)
  feed_dict = dict()
  feed_dict["initial_node_embed"] = torch.Tensor(initial_node_embed).to(device=device, dtype=torch.float)
  feed_dict["edges"] = torch.Tensor(edges).to(device=device, dtype=torch.long)
  feed_dict["node_edges"] = torch.Tensor(node_edges).to(device=device, dtype=torch.long)
  feed_dict["node_edge_mask"] = torch.Tensor(node_edge_mask).to(device=device, dtype=torch.uint8)
  return feed_dict

# Build manager state tracker inputs:


In [12]:
def build_manager_state_tracker_inputs(graph_embed_field, final_node_embed, workers_decision, state_tracker_field,
                                       rollout=True):
  

  feasible_personal_info_nodes = graph_embed_field["feasible_personal_info_nodes"]
  personal_nodes = graph_embed_field["personal_nodes"]
  known = np.concatenate(
            [[np.concatenate([[item["Known"]] for item in field[-1]["workers_counter"]])] for field in
             state_tracker_field])
  unknown = np.concatenate(
            [[np.concatenate([[item["UnKnown"]] for item in field[-1]["workers_counter"]])] for field in
             state_tracker_field])
  known_differ = np.concatenate(
            [[np.concatenate(
                [[(item["Known"] - item["UnKnown"]) if (item["Known"] - item["UnKnown"]) > 0 else 0] for item in
                 field[-1]["workers_counter"]])] for field in state_tracker_field])
  total_qa_turn = np.concatenate([[field[-1]["total_qa_turn"]] for field in state_tracker_field])

  known_one_hot = to_one_hot(known, size=(MaxWorkerExploringTimeStep + 1))
  unknown_one_hot = to_one_hot(unknown, size=(MaxWorkerExploringTimeStep + 1))
  known_differ_one_hot = to_one_hot(known_differ, size=(MaxWorkerExploringTimeStep + 1))
  total_qa_turn_one_hot = to_one_hot(total_qa_turn, size=(MaxExploringTimeStep + 1))

  feed_dict = dict()
  feed_dict["feasible_personal_info_nodes"] = torch.Tensor(feasible_personal_info_nodes).to(device=device,
                                                                                            dtype=torch.float)
  feed_dict["workers_decision"] = workers_decision
  feed_dict["known_one_hot"] = torch.Tensor(known_one_hot).to(device=device, dtype=torch.float)
  feed_dict["unknown_one_hot"] = torch.Tensor(unknown_one_hot).to(device=device, dtype=torch.float)
  feed_dict["known_differ_one_hot"] = torch.Tensor(known_differ_one_hot).to(device, dtype=torch.float)
  feed_dict["total_qa_turn_one_hot"] = torch.Tensor(total_qa_turn_one_hot).to(device, dtype=torch.float)
  feed_dict["personal_nodes"] = torch.Tensor(personal_nodes).to(device=device, dtype=torch.long)
  feed_dict["final_node_embed"] = final_node_embed
  return feed_dict


# Build worker state tracker inputs:


In [13]:
def build_workers_state_tracker_inputs(state_tracker_field, policy_field, final_node_embed, rollout=True):

  personal_nodes = policy_field["manager_actions"]
  workers_max_qa_turn = policy_field["workers_max_qa_turn"]
  known = np.concatenate(
          [[np.concatenate([[item["Known"]] for item in field[-1]["workers_counter"]])] for field in
            state_tracker_field])
  unknown = np.concatenate(
          [[np.concatenate([[item["UnKnown"]] for item in field[-1]["workers_counter"]])] for field in
            state_tracker_field])
  known_differ = np.concatenate(
          [[np.concatenate(
              [[(item["Known"] - item["UnKnown"]) if (item["Known"] - item["UnKnown"]) > 0 else 0] for item in
                field[-1]["workers_counter"]])] for field in state_tracker_field])
  workers_qa_turn = np.concatenate([[field[-1]["workers_qa_turn"]] for field in state_tracker_field])

  known_one_hot = to_one_hot(known, size=(MaxWorkerExploringTimeStep + 1))
  unknown_one_hot = to_one_hot(unknown, size=(MaxWorkerExploringTimeStep + 1))
  known_differ_one_hot = to_one_hot(known_differ, size=(MaxWorkerExploringTimeStep + 1))
  workers_qa_turn_one_hot = to_one_hot(workers_qa_turn, size=(MaxWorkerExploringTimeStep + 1))
  workers_max_qa_turn_one_hot = to_one_hot(workers_max_qa_turn, size=(MaxWorkerExploringTimeStep + 1))

  feed_dict = dict()
  feed_dict["known_one_hot"] = torch.Tensor(known_one_hot).to(device=device, dtype=torch.float)
  feed_dict["unknown_one_hot"] = torch.Tensor(unknown_one_hot).to(device=device, dtype=torch.float)
  feed_dict["known_differ_one_hot"] = torch.Tensor(known_differ_one_hot).to(device, dtype=torch.float)
  feed_dict["workers_qa_turn_one_hot"] = torch.Tensor(workers_qa_turn_one_hot).to(device, dtype=torch.float)
  feed_dict["workers_max_qa_turn_one_hot"] = torch.Tensor(workers_max_qa_turn_one_hot).to(device, dtype=torch.float)
  feed_dict["personal_nodes"] = torch.Tensor(personal_nodes).to(device=device, dtype=torch.long)
  feed_dict["final_node_embed"] = final_node_embed
  return feed_dict
  


# Build hierarchy action masks


In [14]:
def build_hierarchy_action_masks(state_tracker_field, mode, rollout=True):

  rl_manager_action_mask = np.concatenate(
            [[field[-1]["rl_manager_action_mask"]] for field in state_tracker_field])
  rl_workers_action_mask = np.concatenate(
            [[field[-1]["rl_workers_action_mask"]] for field in state_tracker_field])
  if mode == "RuleWarmUp":
        warm_up_manager_action_mask = np.concatenate(
            [[field[-1]["warm_up_manager_action_mask"]] for field in state_tracker_field])
        warm_up_workers_action_mask = np.concatenate(
            [[field[-1]["warm_up_workers_action_mask"]] for field in state_tracker_field])
        
  feed_dict = dict()
  feed_dict["rl_manager_action_mask"] = torch.Tensor(rl_manager_action_mask).to(device=device, dtype=torch.uint8)
  feed_dict["rl_workers_action_mask"] = torch.Tensor(rl_workers_action_mask).to(device=device, dtype=torch.uint8)
  if mode == "RuleWarmUp":
      feed_dict["warm_up_manager_action_mask"] = torch.Tensor(warm_up_manager_action_mask).to(device=device,
                                                                                              dtype=torch.uint8)
      feed_dict["warm_up_workers_action_mask"] = torch.Tensor(warm_up_workers_action_mask).to(device=device,
                                                                                              dtype=torch.uint8)

  return feed_dict


# build_hierarchy_manager_inputs 

In [15]:
def build_hierarchy_manager_inputs(manager_state, workers_state, policy_field, rollout=True):

  personal_nodes = policy_field["manager_actions"]
  feed_dict = dict()
  feed_dict["manager_state"] = manager_state
  feed_dict["workers_state"] = workers_state
  feed_dict["personal_nodes"] = torch.Tensor(personal_nodes).to(device=device, dtype=torch.long)
  return feed_dict

# build_workers_inputs:

In [16]:
def build_workers_inputs(workers_state, policy_field, final_node_embed, rollout=True):

  answer_nodes = policy_field["workers_actions"]
  feed_dict = dict()
  feed_dict["workers_state"] = workers_state
  feed_dict["answer_nodes"] = torch.Tensor(answer_nodes).to(device=device, dtype=torch.long)
  feed_dict["graph_node_embedding"] = final_node_embed
  return feed_dict



# Models:

In [17]:
def build_init_models():
    #if self.args.new_node_emb_size_list is not None:
       # node_emb_size_list = [Init_Node_Feature_Size] + self.args.new_node_emb_size_list
    
    node_emb_size_list = [Init_Node_Feature_Size] + [40, 50 ]
    graph_node_emb_size = sum(node_emb_size_list)
    manager_state_size = 100 + ManagerStateRest
    models = dict()
    models["gnn"] = GNN(node_emb_size_list, 'max').to(device)
    worker_state_size = graph_node_emb_size + WorkersStateRest
    models["manager_state_tracker"] = ManagerStateTracker(graph_node_emb_size, 100,
                                                          'max').to(device)
    models["workers_state_tracker"] = WorkersStateTracker().to(device)
    models["manager"] = Manager('concat', manager_state_size, worker_state_size).to(device)
    models["workers"] = Workers('concat', worker_state_size, graph_node_emb_size).to(device)
    
    return models


In [18]:
path = r'C:\Users\Mohamed Taha\Desktop\FinalWork\data\checkpoints\ghrl\RL\epoch_201_success_0.85_turn_9.71953125'


In [19]:
models = build_init_models()
models

{'gnn': GNN(
   (linear_cells): ModuleList(
     (0): Linear(in_features=32, out_features=40, bias=True)
     (1): Linear(in_features=40, out_features=50, bias=True)
   )
 ),
 'manager_state_tracker': ManagerStateTracker(
   (transfer): Sequential(
     (0): Linear(in_features=122, out_features=100, bias=True)
     (1): Tanh()
   )
 ),
 'workers_state_tracker': WorkersStateTracker(),
 'manager': Manager(
   (policy_network): Attn(
     (attn): Sequential(
       (0): Linear(in_features=462, out_features=231, bias=True)
       (1): Tanh()
       (2): Linear(in_features=231, out_features=1, bias=True)
     )
   )
   (value_network): Sequential(
     (0): Linear(in_features=285, out_features=142, bias=True)
     (1): Tanh()
     (2): Linear(in_features=142, out_features=1, bias=True)
   )
 ),
 'workers': Workers(
   (policy_networks): Attn(
     (attn): Sequential(
       (0): Linear(in_features=299, out_features=149, bias=True)
       (1): Tanh()
       (2): Linear(in_features=149, out_f

In [20]:
for model_name in models.keys():
    model_path = os.path.join(path, model_name +  '.pkl')
    models[model_name].load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))

# Answer:

In [21]:
def new_answer(q_node, a_node, language_generator):
    if q_node is None and a_node is None:
        print ('Decision step')
        return  None, None
    questions, choices = language_generator.generate_sys_nl(q_node, a_node)
    # print("plz answer the following question:")
    # print(a)
    # print(b)


    #ans = input('Enter Answer Symbol: (A-B-C-D): '  )

#     # add node2idx to graph a7sn
#     idx = list(language_generator.language_generation_filed['idx2node'].keys())
#     nodes = list(language_generator.language_generation_filed['idx2node'].values())

#     for sym, answer in choices:
#         if sym == ans:
#             try:
#                 pos = nodes.index(answer)
#                 aa_node = idx[pos]

#             except:
#                 aa_node = -1


    return questions, choices 

# Graph Preprocess

In [22]:
class GraphPreprocess(object):
    def __init__(self, data_set_dir="preprocessed_graphs", batch_size=32):
        # preprocess data set
        self.batch_size = batch_size
        self.train_set = self.split_batch(data_set_dir, "test.json", batch_size)
        self.test_set = self.split_batch(data_set_dir, "test.json", batch_size)
        self.dev_set = self.split_batch(data_set_dir, "test.json", batch_size)

    def split_batch(self, data_dir, data_file, batch_size):
        with open(os.path.join(DATA_ROOT, data_dir, data_file), "r", encoding = 'utf-8') as f:
            data_set = json.load(f)

        for _ in range(batch_size - len(data_set) % batch_size):
            data_set.append(copy.deepcopy(random.choice(data_set)))
        data_set.sort(key=lambda graph: len(graph["edges"]))

        split_data_set = list()
        for batch_idx in range(len(data_set) // batch_size):
            batch = list()
            for i in range(batch_size):
                batch.append(data_set[batch_idx * batch_size + i])
            split_data_set.append(self.preprocess_batch(batch))
        return split_data_set

    @staticmethod
    def get_graph_embed_filed(batch):
        """  Used in Graph Embed  """
        max_node_num = 0
        max_edge_num = 0
        max_node_edge_num = 0
        max_answer_node_num = 0
        max_qa_pair_num = 0
        for graph in batch:
            current_node_num = len(graph["nodes"])
            current_edge_num = len(graph["edges"])
            current_max_node_edge_num = max(map(len, graph["node_edges"]))
            current_max_answer_node_num = max(map(len, list(graph["one_step_nodes"].values())))
            current_qa_pair_num = sum(map(len, list(graph["one_step_nodes"].values())))

            if max_node_num < current_node_num:
                max_node_num = current_node_num
            if max_edge_num < current_edge_num:
                max_edge_num = current_edge_num
            if max_node_edge_num < current_max_node_edge_num:
                max_node_edge_num = current_max_node_edge_num
            if max_answer_node_num < current_max_answer_node_num:
                max_answer_node_num = current_max_answer_node_num
            if max_qa_pair_num < current_qa_pair_num:
                max_qa_pair_num = current_qa_pair_num

        # add 1 for node pad
        # the padded node idx is the current_node_num
        # but pad with range(current_node_num, max_node_num)
        max_node_num += 1

        # add 1 for edge pad
        # the padded edge idx is the current_edge_num
        # but pad with EdgePad
        max_edge_num += 1

        batch_node_num = list()
        batch_qa_pair_num = list()
        batch_personal_nodes = list()
        batch_feasible_personal_info_nodes = list()
        batch_static_feature = list()
        batch_edges = list()
        batch_node_edge_mask = list()
        batch_node_edges = list()

        for graph in batch:
            # **************** personal_nodes and node static feature ****************
            current_node_num = len(graph["nodes"])
            batch_node_num.append(current_node_num)
            current_qa_pair_num = sum(map(len, list(graph["one_step_nodes"].values())))
            batch_qa_pair_num.append(current_qa_pair_num)

            personal_nodes = np.asarray(graph["personal_nodes"], dtype=np.int32)
            batch_personal_nodes.append(personal_nodes)

            feasible_personal_info_nodes = np.zeros((len(graph["personal_nodes"]),), dtype=np.float32)
            for worker_idx, answer_nodes in graph["one_step_nodes"].items():
                if len(answer_nodes) > 0:
                    feasible_personal_info_nodes[graph["personal_nodes"].index(int(worker_idx))] = 1
            batch_feasible_personal_info_nodes.append(feasible_personal_info_nodes)

            original_static_feature = np.asarray(graph["static_feature"], dtype=np.float32)
            static_feature = np.zeros((max_node_num, STATIC_FEATURE_SIZE), dtype=np.float32)
            static_feature[:original_static_feature.shape[0]] = original_static_feature
            batch_static_feature.append(static_feature)

            # **************** edges ****************
            edges = np.zeros((max_edge_num, 3), dtype=np.int32)
            current_edge_num = len(graph["edges"])
            edges[:current_edge_num] = np.asarray(graph["edges"], dtype=np.int32)
            edges[current_edge_num:] = np.tile(np.asarray(EdgePad, dtype=np.int32),
                                               (max_edge_num - current_edge_num, 1))
            batch_edges.append(edges)

            # **************** pad node_edges and get mask ****************
            current_node_edge_num = map(len, graph["node_edges"])
            node_edge_mask = np.zeros((max_node_num, max_node_edge_num), dtype=np.int32)
            for i, num in enumerate(current_node_edge_num):
                node_edge_mask[i, :num] = 1
            batch_node_edge_mask.append(node_edge_mask)

            node_edges = np.full((max_node_num, max_node_edge_num), current_edge_num, dtype=np.int32)
            for i, x in enumerate(graph["node_edges"]):
                node_edges[i, :len(x)] = np.asarray(x, dtype=np.int32)
            batch_node_edges.append(node_edges)

        graph_embed_field = dict()
        graph_embed_field["personal_nodes"] = np.concatenate(
            [[personal_nodes] for personal_nodes in batch_personal_nodes])
        graph_embed_field["feasible_personal_info_nodes"] = np.concatenate(
            [[feasible_personal_info_nodes] for feasible_personal_info_nodes in batch_feasible_personal_info_nodes])
        graph_embed_field["static_feature"] = np.concatenate(
            [[static_feature] for static_feature in batch_static_feature])
        graph_embed_field["edges"] = np.concatenate([[edges] for edges in batch_edges])
        graph_embed_field["node_edge_mask"] = np.concatenate(
            [[node_edge_mask] for node_edge_mask in batch_node_edge_mask])
        graph_embed_field["node_edges"] = np.concatenate([[node_edges] for node_edges in batch_node_edges])

        return graph_embed_field, batch_node_num, max_node_num, max_answer_node_num, max_qa_pair_num, batch_qa_pair_num

    @staticmethod
    def get_knowledge_sampler_field(batch):
        """  Used in User Simulators  """
        batch_knowledge_sampler_filed = list()
        for graph in batch:
            personal_information = copy.deepcopy(graph["personal_information"])
            one_step_node_edges = copy.deepcopy(graph["one_step_node_edges"])
            adj_matrix = np.asarray(graph["adj_matrix"], dtype=bool)
            edge_se_freqs_matrix = np.asarray(graph["edge_se_freqs_matrix"], dtype=np.float32)
            edges = copy.deepcopy(graph["edges"])
            identity_dict_keys = list(graph["one_step_node_edges"].keys())
            idx2node = copy.deepcopy(graph["idx2node"])
            knowledge_sampler_filed = dict()
            knowledge_sampler_filed["personal_information"] = personal_information
            knowledge_sampler_filed["one_step_node_edges"] = one_step_node_edges
            knowledge_sampler_filed["adj_matrix"] = adj_matrix
            knowledge_sampler_filed["edge_se_freqs_matrix"] = edge_se_freqs_matrix
            knowledge_sampler_filed["edges"] = edges
            knowledge_sampler_filed["identity_dict_keys"] = identity_dict_keys
            knowledge_sampler_filed["idx2node"] = idx2node
            knowledge_sampler_filed["identity_dict_values"] = PLACE_HOLDER
            knowledge_sampler_filed["results"] = PLACE_HOLDER
            batch_knowledge_sampler_filed.append(knowledge_sampler_filed)
        return batch_knowledge_sampler_filed

    @staticmethod
    def get_language_generation_field(batch):
        batch_language_generation_filed = list()
        for graph in batch:
            language_generation_filed = dict()
            language_generation_filed["idx2node"] = copy.deepcopy(graph["idx2node"])
            h_t_to_r = dict()
            for edge in graph["edges"]:
                h = edge[2]
                t = edge[0]
                r = edge[1]
                key = str(h) + " " + str(t)
                value = str(r)
                h_t_to_r[key] = value
            language_generation_filed["h_t_to_r"] = h_t_to_r
            batch_language_generation_filed.append(language_generation_filed)
        return batch_language_generation_filed

    def generator(self, data_set_name, shuffle=True):
        if data_set_name == "train":
            data_set = self.train_set
        elif data_set_name == "dev":
            data_set = self.dev_set
        else:
            data_set = self.test_set

        size = len(data_set)
        data_set_idx = list(range(size))
        if shuffle:
            random.shuffle(data_set_idx)

        for idx in data_set_idx:
            yield {"graph_embed_field": data_set[idx]["graph_embed_field"],
                   "policy_field": data_set[idx]["policy_field"],
                   "knowledge_sampler_filed": copy.deepcopy(data_set[idx]["knowledge_sampler_filed"]),
                   "language_generation_filed": data_set[idx]["language_generation_filed"],
                   "state_tracker_field": copy.deepcopy(data_set[idx]["state_tracker_field"])}


class GraphPreprocessHRL(GraphPreprocess):
    def __init__(self, data_set_dir="preprocessed_graphs", batch_size=32):
        super(GraphPreprocessHRL, self).__init__(data_set_dir, batch_size)

    def preprocess_batch(self, batch):
        graph_embed_field, batch_node_num, max_node_num, max_answer_node_num, _, _ = self.get_graph_embed_filed(batch)
        policy_field = self.get_policy_field(batch, batch_node_num, max_answer_node_num)
        knowledge_sampler_filed = self.get_knowledge_sampler_field(batch)
        language_generation_filed = self.get_language_generation_field(batch)
        state_tracker_field = self.get_state_tracker_field(batch, max_node_num, max_answer_node_num)
        preprocessed_batch = {"graph_embed_field": graph_embed_field,
                              "policy_field": policy_field,
                              "knowledge_sampler_filed": knowledge_sampler_filed,
                              "language_generation_filed": language_generation_filed,
                              "state_tracker_field": state_tracker_field}
        return preprocessed_batch

    @staticmethod
    def get_state_tracker_field(batch, max_node_num, max_answer_node_num):
        """  A state tracker for each time step of a dialogue.
             It records current state,
             the mask information,
             the sample information of manager and workers,
             execute an action based on the state and get some reward  """
        batch_state_tracker_field = list()
        for graph in batch:
            # store state information of each time step in a list
            state_tracker_field = list()

            # this is the initial dialogue state
            # for each time step, we generate a same data structure and append it.
            state_in_initial_step = dict()

            # some constant
            # used in generate dialogue feature
            state_in_initial_step["max_node_num"] = max_node_num
            state_in_initial_step["nodes"] = copy.deepcopy(graph["nodes"])
            # used in move a new step
            state_in_initial_step["personal_nodes"] = copy.deepcopy(graph["personal_nodes"])

            # for all nodes
            state_in_initial_step["explored_nodes"] = set()
            state_in_initial_step["last_turn_q_node"] = None
            state_in_initial_step["last_turn_a_node"] = None
            state_in_initial_step["not_explored_nodes"] = set(copy.deepcopy(graph["nodes"]))

            # only for the answer nodes
            state_in_initial_step["known_nodes"] = set()
            state_in_initial_step["unknown_nodes"] = set()
            state_in_initial_step["not_answered_nodes"] = set(
                list(chain(*graph["one_step_nodes"].values())))

            # dialogue feature of current state, calculate before all NN operations
            state_in_initial_step["dialogue_feature"] = PLACE_HOLDER

            # episode end flag
            state_in_initial_step["episode_not_end"] = True

            # exploring turn counter
            state_in_initial_step["total_exploring_turn"] = 0
            state_in_initial_step["current_worker_exploring_turn"] = 0

            # policy mask, mark the running policy of current state
            policy_mask = [1] + [0 for _ in graph["personal_nodes"]]
            state_in_initial_step["policy_mask"] = np.asarray(policy_mask, dtype=np.int32)

            # manager action mask
            # create two counterparts, one for RL, one for RuleWarmUp
            manager_action_mask = [0 for _ in graph["personal_nodes"]]
            for worker_idx, answer_nodes in graph["one_step_nodes"].items():
                if len(answer_nodes) > 0:
                    manager_action_mask[graph["personal_nodes"].index(int(worker_idx))] = 1
            # append [0, 0] for two terminal actions
            manager_action_mask += [0, 0]
            state_in_initial_step["rl_manager_action_mask"] = np.asarray(manager_action_mask, dtype=np.int32)
            state_in_initial_step["warm_up_manager_action_mask"] = np.asarray(manager_action_mask, dtype=np.int32)

            # the workers and manager decision and if they are right
            # used to get the reward Bonus
            state_in_initial_step["workers_decision"] = [PLACE_HOLDER for _ in graph["personal_nodes"]]
            state_in_initial_step["workers_success_state"] = [PLACE_HOLDER for _ in graph["personal_nodes"]]
            state_in_initial_step["manager_decision"] = PLACE_HOLDER
            state_in_initial_step["manager_success_state"] = PLACE_HOLDER

            # worker action mask
            # similar to manager action mask, the terminal action is available only after giving a few questions
            # create two counterparts, one for RL, one for RuleWarmUp
            workers_action_mask = np.zeros((len(graph["personal_nodes"]), max_answer_node_num + 2), dtype=np.int32)
            answer_node_num = list()
            for worker in graph["personal_nodes"]:
                answer_node_num.append(len(graph["one_step_nodes"][str(worker)]))
            valid_worker_action_idx = list()
            row, col = workers_action_mask.shape
            for i, num in enumerate(answer_node_num):
                for idx in range(num):
                    valid_worker_action_idx.append(idx + i * col)
            workers_action_mask = workers_action_mask.reshape(-1)
            workers_action_mask[valid_worker_action_idx] = 1
            workers_action_mask = workers_action_mask.reshape((row, col))
            state_in_initial_step["rl_workers_action_mask"] = workers_action_mask
            state_in_initial_step["warm_up_workers_action_mask"] = workers_action_mask.copy()

            # the manager sample action idx, used for NN update
            state_in_initial_step["manager_sample_idx"] = PLACE_HOLDER

            # the workers sample action idx, used for NN update
            state_in_initial_step["workers_sample_idx"] = PLACE_HOLDER

            # the system action execute in current state
            state_in_initial_step["system_action"] = PLACE_HOLDER

            # reward recorder, include reward mask, manager reward, workers reward
            state_in_initial_step["reward_mask"] = np.copy(state_in_initial_step["policy_mask"])
            state_in_initial_step["manager_reward"] = np.zeros((1,), dtype=np.float32)
            state_in_initial_step["workers_reward"] = np.zeros((len(state_in_initial_step["personal_nodes"]),),
                                                               dtype=np.float32)

            # qa statistical information
            state_in_initial_step["total_qa_turn"] = 0

            # workers qa statistical information
            state_in_initial_step["workers_qa_turn"] = np.zeros((len(graph["personal_nodes"]),), dtype=np.int32)

            # debug
            state_in_initial_step["manager_action_prob"] = PLACE_HOLDER
            state_in_initial_step["workers_action_prob"] = PLACE_HOLDER

            # for Rule based warm up
            state_in_initial_step["workers_counter"] = [{"Known": 0, "UnKnown": 0},
                                                        {"Known": 0, "UnKnown": 0},
                                                        {"Known": 0, "UnKnown": 0},
                                                        {"Known": 0, "UnKnown": 0}]

            # for divide in loss
            state_in_initial_step["valid_workers_num"] = 0

            state_tracker_field.append(state_in_initial_step)
            batch_state_tracker_field.append(state_tracker_field)

        return batch_state_tracker_field

    @staticmethod
    def get_policy_field(batch, batch_node_num, max_answer_node_num):
        batch_manager_actions = list()
        batch_workers_actions = list()
        batch_workers_max_qa_turn = list()

        for graph, current_node_num in zip(batch, batch_node_num):
            manager_actions = np.asarray(graph["personal_nodes"], dtype=np.int32)
            batch_manager_actions.append(manager_actions)

            workers_actions = np.full((len(graph["personal_nodes"]), max_answer_node_num), current_node_num,
                                      dtype=np.int32)
            workers_max_qa_turn = np.zeros((len(graph["personal_nodes"]),), dtype=np.int32)
            for worker_idx, answer_nodes in graph["one_step_nodes"].items():
                if len(answer_nodes) > 0:
                    workers_actions[int(worker_idx), :len(answer_nodes)] = np.asarray(answer_nodes, dtype=np.int32)
                    workers_max_qa_turn[int(worker_idx)] = len(answer_nodes) if len(
                        answer_nodes) < MaxWorkerExploringTimeStep else MaxWorkerExploringTimeStep

            batch_workers_actions.append(workers_actions)
            batch_workers_max_qa_turn.append(workers_max_qa_turn)

        policy_field = dict()
        policy_field["manager_actions"] = np.concatenate(
            [[manager_actions] for manager_actions in batch_manager_actions])
        policy_field["workers_actions"] = np.concatenate(
            [[workers_actions] for workers_actions in batch_workers_actions])
        policy_field["workers_max_qa_turn"] = np.concatenate(
            [[workers_max_qa_turn] for workers_max_qa_turn in batch_workers_max_qa_turn])

        return policy_field


class GraphPreprocessRL(GraphPreprocess):
    def __init__(self, data_set_dir="preprocessed_graphs", batch_size=32):
        super(GraphPreprocessRL, self).__init__(data_set_dir, batch_size)

    def preprocess_batch(self, batch):
        graph_embed_field, batch_node_num, max_node_num, _, max_qa_pair_num, batch_qa_pair_num = self.get_graph_embed_filed(
            batch)
        policy_field = self.get_policy_field(batch, batch_node_num, max_qa_pair_num)
        knowledge_sampler_filed = self.get_knowledge_sampler_field(batch)
        language_generation_filed = self.get_language_generation_field(batch)
        state_tracker_field = self.get_state_tracker_field(batch, max_node_num, max_qa_pair_num, batch_qa_pair_num)
        preprocessed_batch = {"graph_embed_field": graph_embed_field,
                              "policy_field": policy_field,
                              "knowledge_sampler_filed": knowledge_sampler_filed,
                              "language_generation_filed": language_generation_filed,
                              "state_tracker_field": state_tracker_field}
        return preprocessed_batch

    @staticmethod
    def get_policy_field(batch, batch_node_num, max_qa_pair_num):
        batch_actions = list()
        batch_workers_max_qa_turn = list()

        for graph, current_node_num in zip(batch, batch_node_num):
            actions = list()
            for q_node, answer_nodes in graph["one_step_nodes"].items():
                for answer_node in answer_nodes:
                    actions.append((int(q_node), int(answer_node)))
            for _ in range(max_qa_pair_num - len(actions)):
                actions.append((int(current_node_num), int(current_node_num)))
            batch_actions.append(actions)

            workers_max_qa_turn = np.zeros((len(graph["personal_nodes"]),), dtype=np.int32)
            batch_workers_max_qa_turn.append(workers_max_qa_turn)

        policy_field = dict()
        policy_field["actions"] = np.concatenate(
            [[actions] for actions in batch_actions])
        policy_field["workers_max_qa_turn"] = np.concatenate(
            [[workers_max_qa_turn] for workers_max_qa_turn in batch_workers_max_qa_turn])

        return policy_field

    @staticmethod
    def get_state_tracker_field(batch, max_node_num, max_qa_pair_num, batch_qa_pair_num):
        batch_state_tracker_field = list()
        for graph, qa_pair_num in zip(batch, batch_qa_pair_num):
            state_tracker_field = list()
            state_in_initial_step = dict()

            state_in_initial_step["max_node_num"] = max_node_num
            state_in_initial_step["nodes"] = copy.deepcopy(graph["nodes"])

            # for all nodes
            state_in_initial_step["explored_nodes"] = set()
            state_in_initial_step["last_turn_q_node"] = None
            state_in_initial_step["last_turn_a_node"] = None
            state_in_initial_step["not_explored_nodes"] = set(copy.deepcopy(graph["nodes"]))

            # only for the answer nodes
            state_in_initial_step["known_nodes"] = set()
            state_in_initial_step["unknown_nodes"] = set()
            state_in_initial_step["not_answered_nodes"] = set(
                list(chain(*graph["one_step_nodes"].values())))

            # dialogue feature of current state, calculate before all NN operations
            state_in_initial_step["dialogue_feature"] = PLACE_HOLDER

            # episode end flag
            state_in_initial_step["episode_not_end"] = True

            # exploring turn counter
            state_in_initial_step["total_exploring_turn"] = 0

            # rl action mask
            rl_action_mask = np.zeros((max_qa_pair_num + 2,), dtype=np.int32)
            rl_action_mask[:qa_pair_num] = 1
            state_in_initial_step["rl_action_mask"] = rl_action_mask
            state_in_initial_step["warm_up_action_mask"] = rl_action_mask.copy()

            state_in_initial_step["decision"] = PLACE_HOLDER
            state_in_initial_step["success_state"] = PLACE_HOLDER

            state_in_initial_step["sample_idx"] = PLACE_HOLDER

            state_in_initial_step["system_action"] = PLACE_HOLDER

            state_in_initial_step["reward"] = np.zeros((), dtype=np.float32)

            state_in_initial_step["total_qa_turn"] = 0

            # debug
            state_in_initial_step["action_prob"] = PLACE_HOLDER

            # for Rule based warm up
            state_in_initial_step["workers_counter"] = [{"Known": 0, "UnKnown": 0},
                                                        {"Known": 0, "UnKnown": 0},
                                                        {"Known": 0, "UnKnown": 0},
                                                        {"Known": 0, "UnKnown": 0}]

            state_tracker_field.append(state_in_initial_step)
            batch_state_tracker_field.append(state_tracker_field)

        return batch_state_tracker_field

# Language Generator:


In [23]:
# coding = utf-8
import os
import json
import re
import random
import copy
#from src.config import DATA_ROOT, FRAUD, NON_FRAUD, NegativeSampledAnswerNum, Options, ShowUnknown, UnKnownUtterance

with open(os.path.join(DATA_ROOT, "idx2r.json"), "r", encoding= 'utf-8') as f:
    idx2r = json.load(f)

with open(os.path.join(DATA_ROOT, "answersLibrary.json"), "r", encoding= 'utf-8') as f:
    answers_library = json.load(f)

with open(os.path.join(DATA_ROOT, "languageTemplates.json"), "r", encoding= 'utf-8') as f:
    templates = json.load(f)





class LanguageGenerator(object):
    def __init__(self, language_generation_filed):
        self.idx2r = idx2r
        self.language_generation_filed = language_generation_filed
        self.query_entity = None
        self.answers_library = answers_library
        self.templates = templates

    def generate_question(self, system_action):
        if "worker" in system_action.keys() and system_action["worker"]["sample_content"] not in [FRAUD, NON_FRAUD]:
            return self.query_entity["node_id"], system_action["worker"]["sample_content"]
        elif "manager" in system_action.keys() and system_action["manager"]["sample_content"] not in [FRAUD, NON_FRAUD]:
            query_entity_node_id = system_action["manager"]["sample_content"]
            node = self.language_generation_filed["idx2node"][str(query_entity_node_id)]
            self.query_entity = {"node": node, "node_id": query_entity_node_id}
            return None, None
        else:
            return None, None

    def _generate_sys_nl(self, h, r, t):
        """
        Assume the h r t have been transferred to NL
        return: natural language question, candidates, the correct answer option
        """
        question = self.templates[str(r)]
        question = random.choice(question)
        question = re.sub(r"\$\S\$", h, question)
        

        # To avoid user using exclusive method,
        # the sampled negative answers should have similar appearance to the correct answer.
        all_candidates = copy.deepcopy(self.answers_library[r])
        print(all_candidates , t)
        all_candidates.remove(t)
        for c in all_candidates:
            if t.find(c) != -1 or c.find(t) != -1:
                all_candidates.remove(c)

        answers_candidates = random.sample(all_candidates, NegativeSampledAnswerNum)
        answers_candidates.append(t)
        random.shuffle(answers_candidates)
        answers_candidates.append("Not Sure")
        #candidates = "  ".join([" ".join([option, answer]) for option, answer in zip(Options, answers_candidates)])
        candidates = [(option, answer) for option, answer in zip(Options, answers_candidates)]

        return question, candidates

    def generate_sys_nl(self, h, t):
        r = self.language_generation_filed["h_t_to_r"][str(h) + " " + str(t)]
        h = self.language_generation_filed["idx2node"][str(h)]
        t = self.language_generation_filed["idx2node"][str(t)]
        r = self.idx2r[str(r)]

        question, candidates = self._generate_sys_nl(h, r, t)
        #return "    ".join([question, candidates])
        return question, candidates
    def return_user_answer(choice):
      #ans = input()

      pass

    @staticmethod
    def generate_user_nl(user_answer):
        if user_answer != ShowUnknown:
            return user_answer
        else:
            return UnKnownUtterance


# Rollout:

In [24]:
def return_questions(graph_embed_field,
                             policy_field,
                             state_tracker_field,
                             state_trackers,
                             #users,
                             language_generators,
                             #dialogue_recorders,
                             sample_flag,
                             **models):
  
  
  
  # declare system mode
  if sample_flag == "random":
      mode = "RuleWarmUp"
  else:
      mode = "RL"

  # rollout
  if np.asarray([field[-1]["episode_not_end"] for field in state_tracker_field], dtype=np.bool).any():
      # forward for GNN
      graph_embed_feed_dict = build_graph_embed_inputs(graph_embed_field, state_tracker_field)
      final_node_embed = models["gnn"](graph_embed_feed_dict["initial_node_embed"],
                                        graph_embed_feed_dict["edges"],
                                        graph_embed_feed_dict["node_edges"],
                                        graph_embed_feed_dict["node_edge_mask"])

      # forward for workers state tracker
      workers_state_tracker_feed_dict = build_workers_state_tracker_inputs(state_tracker_field, policy_field,
                                                                            final_node_embed)
      workers_state = models["workers_state_tracker"](workers_state_tracker_feed_dict["known_one_hot"],
                                                      workers_state_tracker_feed_dict["unknown_one_hot"],
                                                      workers_state_tracker_feed_dict["known_differ_one_hot"],
                                                      workers_state_tracker_feed_dict["workers_qa_turn_one_hot"],
                                                      workers_state_tracker_feed_dict["workers_max_qa_turn_one_hot"],
                                                      workers_state_tracker_feed_dict["personal_nodes"],
                                                      workers_state_tracker_feed_dict["final_node_embed"])

      # forward for workers
      workers_feed_dict = build_workers_inputs(workers_state, policy_field, final_node_embed)
      _, workers_logits = models["workers"](workers_feed_dict["workers_state"],
                                            workers_feed_dict["answer_nodes"],
                                            workers_feed_dict["graph_node_embedding"])

      # get the action mask
      action_masks_dict = build_hierarchy_action_masks(state_tracker_field, mode)

      # get the RL workers decision (batch_size, personal_node_num, 2)
      rl_workers_probs = mask_softmax(workers_logits, action_masks_dict["rl_workers_action_mask"], dim=2).detach()
      rl_workers_decision = rl_workers_probs[:, :, -2:]

      # forward for manager state tracker
      manager_state_tracker_feed_dict = build_manager_state_tracker_inputs(graph_embed_field, final_node_embed,
                                                                            rl_workers_decision,
                                                                            state_tracker_field)
      manager_state = models["manager_state_tracker"](manager_state_tracker_feed_dict["feasible_personal_info_nodes"],
                                                      manager_state_tracker_feed_dict["workers_decision"],
                                                      manager_state_tracker_feed_dict["known_one_hot"],
                                                      manager_state_tracker_feed_dict["unknown_one_hot"],
                                                      manager_state_tracker_feed_dict["known_differ_one_hot"],
                                                      manager_state_tracker_feed_dict["total_qa_turn_one_hot"],
                                                      manager_state_tracker_feed_dict["personal_nodes"],
                                                      manager_state_tracker_feed_dict["final_node_embed"])

      # forward for manager
      manager_feed_dict = build_hierarchy_manager_inputs(manager_state, workers_state, policy_field)
      _, manager_logits = models["manager"](manager_feed_dict["manager_state"],
                                            manager_feed_dict["workers_state"])

      # get current manager and workers policy distribution
      if mode == "RL":
          manager_probs = mask_softmax(manager_logits, action_masks_dict["rl_manager_action_mask"], dim=1)
          workers_probs = mask_softmax(workers_logits, action_masks_dict["rl_workers_action_mask"], dim=2)
      else:
          manager_probs = mask_softmax(manager_logits, action_masks_dict["warm_up_manager_action_mask"], dim=1)
          workers_probs = mask_softmax(workers_logits, action_masks_dict["warm_up_workers_action_mask"], dim=2)

      # sample from probs
      manager_action_probs = manager_probs.cpu().detach().numpy()
      workers_action_probs = workers_probs.cpu().detach().numpy()
      manager_actions = manager_feed_dict["personal_nodes"].cpu().detach().numpy()
      workers_actions = workers_feed_dict["answer_nodes"].cpu().detach().numpy()
      manager_sample_idxs, manager_sample_results, workers_sample_idxs, workers_sample_results = sample_hierarchy_rl(
          manager_action_probs,
          workers_action_probs,
          manager_actions,
          workers_actions,
          sample_flag)
      
      for state_tracker, language_generator, manager_sample_idx, \
            manager_sample_result, manager_action_prob, workers_sample_idx, \
            workers_sample_result, workers_action_prob in zip(state_trackers,
                                                              language_generators,
                                                              manager_sample_idxs,
                                                              manager_sample_results,
                                                              manager_action_probs,
                                                              workers_sample_idxs,
                                                              workers_sample_results,
                                                              workers_action_probs):
              
            questions, choices , state_in_new_step , system_action , language_generator , a_node , q_node  = state_tracker.move_a_step(
                                          language_generator,
                                          manager_sample_idx,
                                          manager_sample_result,
                                          manager_action_prob,
                                          workers_sample_idx,
                                          workers_sample_result,
                                          workers_action_prob,  mode = "RL")
            return questions, choices , state_in_new_step , system_action  , state_tracker , language_generator ,a_node ,q_node
            
  else:
    for state_tracker in state_trackers:
      del state_tracker.state_tracker_field[-1]


    

# def after_answer(questions, choices , state_in_new_step , system_action , state_tracker_field):
      

# Try:

In [25]:
# graph_preprocess = GraphPreprocessHRL(batch_size=1)

In [26]:
def initialize():
    graph_preprocess = GraphPreprocessHRL(batch_size=1)
    for batch in graph_preprocess.generator('test', shuffle=False):
      graph_embed_field = batch["graph_embed_field"]
      policy_field = batch["policy_field"]
      state_tracker_field = batch["state_tracker_field"]
      state_trackers = [StateTrackerHRL(field) for field in state_tracker_field]
      language_generators = [LanguageGenerator(field) for field in batch["language_generation_filed"]]
      #print(batch['language_generation_filed'][1]['h_t_to_r']['0 9'])
    return graph_embed_field , policy_field , state_tracker_field , state_trackers , language_generators



#Flask

In [27]:
# # flask_ngrok_example.py
# # !pwd
# from flask import Flask, render_template, request
# # from flask_ngrok import run_with_ngrok

# app = Flask(__name__)
# # run_with_ngrok(app)  # Start ngrok when app is run

# user_answer= ''
# state_in_new_step = { }
# system_action =  { }
# state_tracker = 0
# language_generator = 0
# ####################################################################################
# # @app.route('/', methods=["GET", "POST"])
# # def App():

# #   return render_template( 'chatapp.html' )


# @app.route('/', methods=["GET"])
# def show_questions():
 
#    questions, choices , state_in_new_step2 , system_action2  , state_tracker2 , language_generator2   = return_questions(graph_embed_field,
#                              policy_field,
#                              state_tracker_field,
#                              state_trackers,
#                              #users,
#                              language_generators,
#                              #dialogue_recorders,
#                              "max",
#                              **models)
#    print(language_generator2)
  
#    global state_in_new_step
#    state_in_new_step= state_in_new_step2
#    global system_action 
#    system_action= system_action2
#    global state_tracker 
#    state_tracker = state_tracker2
#    global language_generator
#    language_generator  = language_generator2
   

#    return render_template('chatapp.html', 
#                           question= questions,
#                           answer= choices
                         
#                          )
  
# @app.route('/answer' , methods= ["POST"])
# def get_answer():
#   if request.method =='POST':
#     if 'submit_button' in request.form:
#       user_answer=request.form['options']
      
#       global state_in_new_step
#       state_in_new_step3 =  state_in_new_step 
#       global system_action 
#       system_action3 =  system_action
#       global state_tracker
#       state_tracker3 = state_tracker
#       global language_generator
#       language_generator3 = language_generator
#       print(list(language_generator.language_generation_filed['idx2node'].keys()))
      
#       state_tracker3.continue_a_step(user_answer,state_in_new_step3,system_action3,language_generator3,choices)

#       questions, choices , state_in_new_step4 , system_action4  , state_tracker4 , language_generator4  = return_questions(graph_embed_field,
#                              policy_field,
#                              state_tracker_field,
#                              state_trackers,
#                              #users,
#                              language_generators,
#                              #dialogue_recorders,
#                              "max",
#                              **models)
      
#       state_in_new_step = state_in_new_step4
#       system_action = system_action4
#       state_tracker = state_tracker4
#       language_generator = language_generator4
      
#       return (questions, choices)
  
# # @app.route('/' , methods= ["POST"])
# # def get_answer():
# #   if request.method =='POST':
# #     if 'submit_button' in request.form:
# #       user_answer=request.form['options']
      
# #       global state_in_new_step
# #       state_in_new_step3 =  state_in_new_step 
# #       global system_action
# #       system_action3 =  system_action 
# #       global state_tracker_field
# #       state_tracker_field3 =  state_tracker_field

# #       state_tracker_field4 = after_answer(user_answer,  state_in_new_step3 , system_action3 , state_tracker_field3)

# #       questions, choices , state_in_new_step4 , system_action4 , state_tracker_field4  = return_questions(graph_embed_field,
# #                              policy_field,
# #                              state_tracker_field4,
# #                              state_trackers,
# #                              #users,
# #                              language_generators,
# #                              #dialogue_recorders,
# #                              sample_flag,
# #                              **models)
      
# #       state_in_new_step = state_in_new_step4
# #       system_action = system_action4
# #       state_tracker_field = state_tracker_field4
      
# #       return (questions, choices)

# ######################################################################################################################

#   # new_rollout(graph_embed_field,
#   #             policy_field,
#   #             state_tracker_field,
#   #             state_trackers,
#   #             language_generators,
#   #             'max',
#   #             **models
#   #             )
#   # user_answer= ''

#   # if request.method =='POST':
#   #     if 'submit_button' in request.form:
#   #         user_answer=request.form['options']
#   #         print(user_answer) 
#   # #ans = request.form['options']
#   # #print(user_answer)
#   # return render_template('chatapp.html', 
#   #                        question= question,
#   #                        answer= answer,
#   #                        ans = user_answer)

# if __name__ == '__main__' :
#   app.run()

In [28]:
from flask import Flask, render_template, request

app = Flask(__name__)
state_in_new_step = 0
system_action = 0
state_tracker = 0
language_generator = 0
choices = 0
a_node = 0
q_node = 0
graph_embed_field , policy_field , state_tracker_field , state_trackers , language_generators = 0,0,0,0,0
@app.route('/', methods=["GET"])
def get_question():
    global graph_embed_field , policy_field , state_tracker_field , state_trackers , language_generators
    graph_embed_field , policy_field , state_tracker_field , state_trackers , language_generators = initialize()
    
    global state_in_new_step
    global system_action
    global state_tracker
    global language_generator
    global choices
    global a_node 
    global q_node
    questions = None 
    while questions is None:
        questions, choices , state_in_new_step , system_action  , state_tracker , language_generator ,a_node ,q_node = return_questions(
                                 graph_embed_field,
                                 policy_field,
                                 state_tracker_field,
                                 state_trackers,
                                 #users,
                                 language_generators,
                                 #dialogue_recorders,
                                 "max",
                                 **models)
        
        if questions is None:
            case  = state_tracker.continue_a_step(None,state_in_new_step,system_action,language_generator,choices ,a_node,q_node)     
            if case == 'continue':
                pass
            elif case == 'fraud':
                return render_template('Fraud.html')
            elif case == 'nonfraud':
                return render_template('Non-Fraud.html')

   
        
    return render_template('chatapp.html', question= questions , answer = choices)

@app.route('/answer', methods=["POST"])
def get_answer():
    global state_in_new_step
    global system_action
    global state_tracker
    global language_generator
    global choices
    global a_node 
    global q_node
    user_answer= ''
    if request.method =='POST':
        if 'submit_button' in request.form:
            user_answer=request.form['options']
            print(user_answer , 'USERANSWER') 
    case = state_tracker.continue_a_step(user_answer,state_in_new_step,system_action,language_generator, choices ,a_node,q_node)
    if case == 'continue':
        pass
    elif case == 'fraud':
            return render_template('Fraud.html')
    elif case == 'nonfraud':
            return render_template('Non-Fraud.html')

            
            
    questions = None 
    while questions is None:
        if (case == 'fraud') or (case == 'nonfraud'):
            break
        questions, choices , state_in_new_step , system_action  , state_tracker , language_generator ,a_node ,q_node = return_questions(
                                 graph_embed_field,
                                 policy_field,
                                 state_tracker_field,
                                 state_trackers,
                                 #users,
                                 language_generators,
                                 #dialogue_recorders,
                                 "max",
                                 **models)
        if questions is None:
            case = state_tracker.continue_a_step(None,state_in_new_step,system_action,language_generator,choices ,a_node,q_node)
            if case == 'continue':
                pass
            elif case == 'fraud':
                    return render_template('Fraud.html')
            elif case == 'nonfraud':
                    
                    return render_template('Non-Fraud.html')

    
    return render_template('chatapp.html', question= questions,answer= choices)

# state_tracker.continue_a_step(user_answer,state_in_new_step,system_action,language_generator,choices ,a_node,q_node)
if __name__ == '__main__' :
  app.run()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [22/Jan/2022 14:39:16] "[37mGET / HTTP/1.1[0m" 200 -


Decision step
['اسواق عباد الرحمن', 'كارفور ماركت', 'مكتب تموين جمعيتي', 'أسواق الشيتانى'] كارفور ماركت


127.0.0.1 - - [22/Jan/2022 14:39:17] "[33mGET /favicon.ico HTTP/1.1[0m" 404 -
127.0.0.1 - - [22/Jan/2022 14:39:20] "[37mPOST /answer HTTP/1.1[0m" 200 -


C USERANSWER
wrong answer!!!  22
['مركز براعم للاطفال', 'مستشفى صندلا المركزى.. Central Hospital of Sundela', 'عيادة د عبدالقادر حجازي ش ابراهيم المغازي ناصية ش الصداقة', 'Elaf Egypt'] Elaf Egypt


127.0.0.1 - - [22/Jan/2022 14:39:31] "[37mPOST /answer HTTP/1.1[0m" 200 -


A USERANSWER
wrong answer!!!  17
['سوبر ماركت الربيع', 'سليم ماركت', 'منصة مكوك للمتاجر الالكترونية'] منصة مكوك للمتاجر الالكترونية


127.0.0.1 - - [22/Jan/2022 14:39:36] "[37mPOST /answer HTTP/1.1[0m" 200 -


B USERANSWER
wrong answer!!!  21
Decision step
Decision step
FRAUD !!!!!
