Skip to content

Commit

Permalink
Refactoring candgen and planner
Browse files Browse the repository at this point in the history
- Removing obsolete code (sampling planner)
- Making get_all_successors() stateful (keeps track of the DA
  for which we are generating). This makes the API simpler, allows
  to make several methods private.
  • Loading branch information
tuetschek committed Nov 16, 2015
1 parent 8ffec59 commit 4a13d85
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 93 deletions.
79 changes: 37 additions & 42 deletions tgen/candgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def __init__(self, cfg):
# do the same also for DA slots?
self.compatible_slots = cfg.get('compatible_slots', False)

# cache fields for generating successors:
self.cur_da = None
self.cur_cdfs = None
self.cur_limits = None

@staticmethod
def load_from_file(fname):
log_info('Loading model from ' + fname)
Expand Down Expand Up @@ -129,10 +134,10 @@ def train(self, da_file, t_file):
# prune counts
if self.prune_threshold > 1:
for dai, forms in child_type_counts.items():
self.prune(forms)
self._prune(forms)
if not forms:
del child_type_counts[dai]
self.prune(child_num_counts)
self._prune(child_num_counts)

# transform counts
self.child_type_counts = child_type_counts
Expand Down Expand Up @@ -196,17 +201,28 @@ def _parent_node_id(self, node):
return (node.t_lemma, node.formeme)
return node.formeme

def prune(self, counts):
def _prune(self, counts):
"""Prune a counts dictionary, keeping only items with counts above
the prune_threshold given in the constructor."""
the prune_threshold given in the constructor.
"""
for parent_type, child_types in counts.items():
for child_form, child_count in child_types.items():
if child_count < self.prune_threshold:
del child_types[child_form]
if not counts[parent_type]:
del counts[parent_type]

def get_merged_child_type_cdfs(self, da):
def init_run(self, da):
"""TODO
cdfs: Merged CDFs of children given the current DA (obtained using _get_merged_child_type_cdfs)
node_limits: limits on the number of nodes (total and on different child_depth levels, \
obtained via get_merged_limits)
"""
self.cur_da = da
self.cur_cdfs = self._get_merged_child_type_cdfs(da)
self.cur_limits = self.get_merged_limits(da)

def _get_merged_child_type_cdfs(self, da):
"""Get merged child CDFs (i.e. lists of possible children, given parent IDs) for the
given DA.
Expand All @@ -231,7 +247,7 @@ def get_merged_child_type_cdfs(self, da):
# minimum compatibility DAIs is not similar to the current DA)
for _, counts in merged_counts.items():
for node in counts.keys():
if not self.compatible(da, NodeData(t_lemma=node[1], formeme=node[0])):
if not self._compatible(da, NodeData(t_lemma=node[1], formeme=node[0])):
del counts[node]

# log_info('Node types after pruning: %d' % sum(len(c.keys()) for c in merged_counts.values()))
Expand All @@ -240,7 +256,7 @@ def get_merged_child_type_cdfs(self, da):

return self.cdfs_from_counts(merged_counts)

def compatible(self, da, node):
def _compatible(self, da, node):
"""This limits the possibility of candidates (if compatible_dais are used). It returns
true only if the given node is "compatible enough" with the given DA.
Expand Down Expand Up @@ -272,7 +288,7 @@ def get_merged_limits(self, da):
"""Return merged limits on node counts (total and on each tree level). Uses a
maximum for all DAIs in the given DA.
Returns none if the given candidate generator does not have any node limits.
Returns None if the given candidate generator does not have any node limits.
@param da: the current dialogue act
@rtype: defaultdict(Counter)
Expand Down Expand Up @@ -304,7 +320,8 @@ def cdfs_from_counts(self, counts):

def exp_from_cdfs(self, cdfs):
"""Given a dictionary of CDFs (with numeric subkeys), create a dictionary of
corresponding expected values. Used for children counts."""
corresponding expected values. Used for children counts.
"""
exps = {}
for key, cdf in cdfs.iteritems():
# convert the CDF -- array of tuples (value, cumulative probability) into an
Expand All @@ -323,40 +340,20 @@ def exp_from_cdfs(self, cdfs):
exps[key] = sum(1 - cdf_val for cdf_val in cdf_arr)
return exps

def sample(self, cdf):
"""Return a sample from the distribution, given a CDF (as a list)."""
total = cdf[-1][1]
rand = rnd.random() * total # get a random number in [0,total)
for key, ubound in cdf:
if ubound > rand:
return key
raise Exception('Unable to generate from CDF!')

def get_number_of_children(self, parent_id):
if parent_id not in self.child_num_cdfs:
return 0
return self.sample(self.child_num_cdfs[parent_id])

def get_best_child(self, parent, da, cdf):
return self.sample(cdf)

def get_all_successors(self, cand_tree, cdfs, node_limits=None):
def get_all_successors(self, cand_tree):
"""Get all possible successors of a candidate tree, given CDFS and node number limits.
NB: This assumes projectivity (will never create a non-projective tree).
@param cand_tree: The current candidate tree to be expanded
@param cdfs: Merged CDFs of children given the current DA (obtained using get_merged_child_type_cdfs)
@param node_limits: limits on the number of nodes (total and on different child_depth levels, \
obtained via get_merged_limits)
"""
# TODO possibly avoid creating TreeNode instances for iterating
nodes = TreeNode(cand_tree).get_descendants(add_self=1, ordered=1)
nodes_on_level = defaultdict(int)
res = []
if node_limits is not None:
if self.cur_limits is not None:
# stop if maximum number of nodes is reached
if len(nodes) >= node_limits['total']:
if len(nodes) >= self.cur_limits['total']:
return []
# remember number of nodes on all levels
for node in nodes:
Expand All @@ -367,15 +364,15 @@ def get_all_successors(self, cand_tree, cdfs, node_limits=None):
# skip nodes that can't have more children
parent_id = self._parent_node_id(node)
if (len(node.get_children()) >= self.max_children.get(parent_id, 0) or
parent_id not in cdfs):
parent_id not in self.cur_cdfs):
continue
# skip nodes above child_depth levels where the maximum number of nodes has been reached
if node_limits is not None:
if self.cur_limits is not None:
child_depth = node.get_depth() + 1
if nodes_on_level[child_depth] >= node_limits[child_depth]:
if nodes_on_level[child_depth] >= self.cur_limits[child_depth]:
continue
# try all formeme/t-lemma/direction variants of a new child under the given parent node
for formeme, t_lemma, right in map(lambda item: item[0], cdfs[parent_id]):
for formeme, t_lemma, right in map(lambda item: item[0], self.cur_cdfs[parent_id]):
# place the child directly following/preceding the parent
succ_tree = cand_tree.clone()
succ_tree.create_child(node_num, right, NodeData(t_lemma, formeme))
Expand Down Expand Up @@ -410,8 +407,7 @@ def can_generate(self, tree, da):
Tries if get_all_successors always returns a successor that leads to the given tree
(puts on the open list only successors that are subtrees of the given tree).
"""
cdfs = self.get_merged_child_type_cdfs(da)
node_limits = self.get_merged_limits(da)
self.init_run(da)
open_list = CandidateList({TreeData(): 1})
found = False
tree_no = 0
Expand All @@ -421,7 +417,7 @@ def can_generate(self, tree, da):
if cur_st == tree:
found = True
break
for succ in self.get_all_successors(cur_st, cdfs, node_limits):
for succ in self.get_all_successors(cur_st):
tree_no += 1
# only push on the open list if the successor is still a subtree of the target tree
if tree.common_subtree_size(succ) == len(succ):
Expand All @@ -441,14 +437,13 @@ def can_generate_greedy(self, tree, da):
Uses `get_all_successors` and always goes on with the first one that increases coverage
of the current tree.
"""
cdfs = self.get_merged_child_type_cdfs(da)
node_limits = self.get_merged_limits(da)
self.init_run(da)
cur_subtree = TreeData()
found = True

while found and cur_subtree != tree:
found = False
for succ in self.get_all_successors(cur_subtree, cdfs, node_limits):
for succ in self.get_all_successors(cur_subtree):
# use the first successor that is still a subtree of the target tree
if tree.common_subtree_size(succ) == len(succ):
cur_subtree = succ
Expand Down
53 changes: 2 additions & 51 deletions tgen/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,51 +180,6 @@ def get_target_zone(self, gen_doc):
return zone


class SamplingPlanner(SentencePlanner):
"""Random t-tree generator given DAs.
TODO: This is obsolete, it will not work after the introduction of TreeData.
Fix it or remove it (there's probably no point in having it now).
"""

MAX_TREE_SIZE = 50

def __init__(self, cfg):
super(SamplingPlanner, self).__init__(cfg)
# ranker (selecting the best candidate)
self.ranker = None
if 'ranker' in cfg:
self.ranker = cfg['ranker']

def generate_tree(self, da, gen_doc=None):
root = TreeNode(TreeData())
cdfs = self.candgen.get_merged_child_type_cdfs(da)
nodes = deque([self.generate_child(root, da, cdfs[root.formeme])])
treesize = 1
while nodes and treesize < self.MAX_TREE_SIZE:
node = nodes.popleft()
if node.formeme not in cdfs: # skip weirdness
continue
for _ in xrange(self.candgen.get_number_of_children(node.formeme)):
child = self.generate_child(node, da, cdfs[node.formeme])
nodes.append(child)
treesize += 1
if gen_doc:
zone = self.get_target_zone(gen_doc)
zone.ttree = root.create_ttree()
return
return root.tree

def generate_child(self, parent, da, cdf):
"""Generate one t-node, given its parent and the CDF for the possible children."""
if self.ranker:
formeme, t_lemma, right = self.ranker.get_best_child(parent, da, cdf)
else:
formeme, t_lemma, right = self.candgen.sample(cdf)
child = parent.create_child(right, NodeData(t_lemma, formeme))
return child


class ASearchPlanner(SentencePlanner):
"""Sentence planner using A*-search."""

Expand Down Expand Up @@ -269,8 +224,6 @@ def reset(self):
self.defic_iter = 0
self.num_iter = -1
self.input_da = None
self.cdfs = None
self.node_limits = None

def init_run(self, input_da, max_iter=None, max_defic_iter=None, beam_size=None):
"""Init the A*-search generation for the given input DA, with the given parameters
Expand All @@ -294,8 +247,7 @@ def init_run(self, input_da, max_iter=None, max_defic_iter=None, beam_size=None)
self.input_da = input_da
self.defic_iter = 0
self.num_iter = 0
self.cdfs = self.candgen.get_merged_child_type_cdfs(input_da)
self.node_limits = self.candgen.get_merged_limits(input_da)
self.candgen.init_run(input_da)

if max_iter is not None:
self.max_iter = max_iter
Expand Down Expand Up @@ -331,8 +283,7 @@ def run_iter(self):
self.close_list.push(cand, score[1]) # only use score without future promise
log_debug("-- IT %4d: O %5d S %12.5f -- %s" %
(self.num_iter, len(self.open_list), -score[1], unicode(cand)))
successors = [succ for succ
in self.candgen.get_all_successors(cand, self.cdfs, self.node_limits)
successors = [succ for succ in self.candgen.get_all_successors(cand)
if succ not in self.close_list]

if successors:
Expand Down

0 comments on commit 4a13d85

Please sign in to comment.