Refactored to avoid code duplication
tuetschek committed May 14, 2015
1 parent 19ab079 commit 3050c57
Showing 1 changed file with 58 additions and 75 deletions.
tgen/rank_nn.py (58 additions, 75 deletions)
@@ -16,7 +16,60 @@
 from tgen.logf import log_debug, log_info
 
 
-class SimpleNNRanker(FeaturesPerceptronRanker):
+class NNRanker(BasePerceptronRanker):
+    """Set of methods to be used in any NN ranker."""
+
+    def store_iter_weights(self):
+        """Remember the current weights to be used for averaged perceptron."""
+        self.w_after_iter.append(self.nn.get_param_values())
+
+    def update_weights_sum(self):
+        """Update the current weights sum figure."""
+        vals = self.nn.get_param_values()
+        # only use the last layer for summation (w, b)
+        self.w_sum = np.sum(vals[-2]) + np.sum(vals[-1])
+
+    def get_weights_sum(self):
+        """Return the sum of weights (at start of current iteration) to be used to weigh future
+        promise."""
+        return self.w_sum
+
+    def get_weights(self):
+        """Return the current neural net weights."""
+        return self.nn.get_param_values()
+
+    def set_weights(self, w):
+        """Set new neural network weights."""
+        self.nn.set_param_values(w)
+
+    def set_weights_average(self, wss):
+        """Set the weights as the average of the given array of weights (used in parallel training)."""
+        self.nn.set_param_values(np.average(wss, axis=0))
+
+    def set_weights_iter_average(self):
+        """Average the remembered weights."""
+        self.nn.set_param_values(np.average(self.w_after_iter, axis=0))
+
+    def _update_weights(self, da, good_tree, bad_tree, good_feats, bad_feats):
+        """Update NN weights, given a DA, a good and a bad tree, and their features."""
+        if self.diffing_trees:
+            good_sts, bad_sts = good_tree.diffing_trees(bad_tree, symmetric=True)
+            for good_st, bad_st in zip(good_sts, bad_sts):
+                good_feats = self._extract_feats(good_st, da)
+                bad_feats = self._extract_feats(bad_st, da)
+                subtree_w = 1
+                if self.diffing_trees.endswith('weighted'):
+                    subtree_w = (len(good_st) + len(bad_st)) / float(len(good_tree) + len(bad_tree))
+                self._update_nn(bad_feats, good_feats, subtree_w * self.alpha)
+        else:
+            self._update_nn(bad_feats, good_feats, self.alpha)
+
+    def _update_nn(self, bad_feats, good_feats, rate):
+        """Direct call to NN weights update."""
+        self.nn.update(bad_feats, good_feats, rate)
+
+
+class SimpleNNRanker(FeaturesPerceptronRanker, NNRanker):
     """A simple ranker using a neural network on top of the usual features; using the same
     updates as the original perceptron as far as possible."""
 
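The new NNRanker base class collects every method that touches only self.nn, so both concrete rankers can pull it in as a mixin (SimpleNNRanker lists FeaturesPerceptronRanker first, so feature extraction keeps precedence in the MRO). Below is a minimal runnable sketch, not part of the commit, of the averaged-perceptron bookkeeping behind store_iter_weights and set_weights_iter_average; the parameter list is a dummy stand-in for what nn.get_param_values() would return:

import numpy as np

# Dummy stand-in for nn.get_param_values(): one array per parameter,
# e.g. the last layer's weight matrix and bias vector.
params = [np.zeros((3, 2)), np.zeros(2)]

w_after_iter = []                                    # one snapshot per training pass
for it in range(5):
    params = [p + 0.1 for p in params]               # dummy perceptron update
    w_after_iter.append([p.copy() for p in params])  # store_iter_weights

# set_weights_iter_average: element-wise mean of all snapshots. The commit
# calls np.average(self.w_after_iter, axis=0) on the nested list in one go;
# averaging each parameter separately, as here, is the shape-safe equivalent.
averaged = [np.average(snaps, axis=0) for snaps in zip(*w_after_iter)]
assert np.allclose(averaged[0], 0.3 * np.ones((3, 2)))  # mean of the 0.1 .. 0.5 steps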
@@ -54,50 +107,6 @@ def _init_neural_network(self):
         self.nn = NN([[FeedForwardLayer('perc', self.train_feats.shape[1], 1,
                                         None, self.initialization)]])
 
-    def _update_weights(self, da, good_tree, bad_tree, good_feats, bad_feats):
-        if self.diffing_trees:
-            good_sts, bad_sts = good_tree.diffing_trees(bad_tree, symmetric=True)
-            for good_st, bad_st in zip(good_sts, bad_sts):
-                good_feats = self._extract_feats(good_st, da)
-                bad_feats = self._extract_feats(bad_st, da)
-                subtree_w = 1
-                if self.diffing_trees.endswith('weighted'):
-                    subtree_w = (len(good_st) + len(bad_st)) / float(len(good_tree) + len(bad_tree))
-                self.nn.update(bad_feats, good_feats, subtree_w * self.alpha)
-        else:
-            self.nn.update(bad_feats, good_feats, self.alpha)
-
-    def get_weights(self):
-        """Return the current neural net weights."""
-        return self.nn.get_param_values()
-
-    def set_weights(self, w):
-        """Set new neural network weights."""
-        self.nn.set_param_values(w)
-
-    def set_weights_average(self, wss):
-        """Set the weights as the average of the given array of weights (used in parallel training)."""
-        self.nn.set_param_values(np.average(wss, axis=0))
-
-    def store_iter_weights(self):
-        """Remember the current weights to be used for averaged perceptron."""
-        self.w_after_iter.append(self.nn.get_param_values())
-
-    def set_weights_iter_average(self):
-        """Average the remembered weights."""
-        self.nn.set_param_values(np.average(self.w_after_iter, axis=0))
-
-    def get_weights_sum(self):
-        """Return the sum of weights (at start of current iteration) to be used to weigh future
-        promise."""
-        return self.w_sum
-
-    def update_weights_sum(self):
-        """Update the current weights sum figure."""
-        vals = self.nn.get_param_values()
-        # only use the last layer for summation (w, b)
-        self.w_sum = np.sum(vals[-2]) + np.sum(vals[-1])
-
     # def __getstate__(self):
     #     state = dict(self.__dict__)
     #     w = self.nn.get_param_values()
@@ -115,7 +124,7 @@ def update_weights_sum(self):
     #     self.set_weights(w)
 
 
-class EmbNNRanker(BasePerceptronRanker):
+class EmbNNRanker(NNRanker):
 
     UNK_SLOT = 0
     UNK_VALUE = 1
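With the shared machinery hoisted into NNRanker, EmbNNRanker only has to override the one hook that differs, _update_nn (see the last hunk below). A minimal sketch, with hypothetical class names, of the template-method pattern this creates:

class Base(object):
    def _update_weights(self, bad_feats, good_feats, rate):
        # shared control flow lives in the base class and delegates
        # the actual parameter update to the _update_nn hook
        self._update_nn(bad_feats, good_feats, rate)

    def _update_nn(self, bad_feats, good_feats, rate):
        print('plain update: %s %s %s' % (bad_feats, good_feats, rate))

class EmbStyle(Base):
    def _update_nn(self, bad_feats, good_feats, rate):
        # embedding-style ranker: feats are tuples of inputs, splatted into one call
        print('array update: %s' % str(bad_feats + good_feats + (rate,)))

EmbStyle()._update_weights((1, 2), (3, 4), 0.1)   # -> array update: (1, 2, 3, 4, 0.1)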
@@ -220,32 +229,6 @@ def _init_neural_network(self):
                                input_num=2,
                                input_type=T.ivector)
 
-    # TODO -- merge the stuff below with SimpleNNRanker !!!
-
-    def update_weights_sum(self):
-        """Update the current weights sum figure."""
-        vals = self.nn.get_param_values()
-        # only use the last layer for summation (w, b)
-        self.w_sum = np.sum(vals[-2]) + np.sum(vals[-1])
-
-    def get_weights_sum(self):
-        """Return the sum of weights (at start of current iteration) to be used to weigh future
-        promise."""
-        return self.w_sum
-
-    def _update_weights(self, da, good_tree, bad_tree, good_feats, bad_feats):
-        if self.diffing_trees:
-            good_sts, bad_sts = good_tree.diffing_trees(bad_tree, symmetric=True)
-            for good_st, bad_st in zip(good_sts, bad_sts):
-                good_feats = self._extract_feats(good_st, da)
-                bad_feats = self._extract_feats(bad_st, da)
-                subtree_w = 1
-                if self.diffing_trees.endswith('weighted'):
-                    subtree_w = (len(good_st) + len(bad_st)) / float(len(good_tree) + len(bad_tree))
-                self.nn.update(*(bad_feats + good_feats + (subtree_w * self.alpha,)))
-        else:
-            self.nn.update(*(bad_feats + good_feats + (self.alpha,)))
-
-    def store_iter_weights(self):
-        """Remember the current weights to be used for averaged perceptron."""
-        self.w_after_iter.append(self.nn.get_param_values())
+    def _update_nn(self, bad_feats, good_feats, rate):
+        """Changing the NN update call to support arrays of parameters."""
+        self.nn.update(*(bad_feats + good_feats + (rate,)))
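The override relies on bad_feats and good_feats being tuples of inputs for the embedding network, so plain tuple concatenation plus argument splatting feeds them all to one variadic call. A toy stand-in for nn.update, assuming tuple-valued features, shows what it receives:

def update(*args):
    # toy stand-in for nn.update: the last positional argument is the learning rate
    inputs, rate = args[:-1], args[-1]
    print('inputs=%s rate=%s' % (str(inputs), rate))

bad_feats, good_feats = ('b1', 'b2'), ('g1', 'g2')   # hypothetical input arrays
update(*(bad_feats + good_feats + (0.1,)))
# -> inputs=('b1', 'b2', 'g1', 'g2') rate=0.1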