Skip to content

Commit

Permalink
Fixed init argument order and converted tabs to spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianGehrmann committed Mar 7, 2018
1 parent a5942d8 commit 56e51d2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 62 deletions.
4 changes: 2 additions & 2 deletions onmt/translate/Beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def score(self, beam, logprobs):
Rescores a prediction based on penalty functions
"""
penalty = self.cov_penalty(beam,
beam.global_state["coverage"],
self.beta)
beam.global_state["coverage"],
self.beta)
normalized_probs = self.length_penalty(beam,
logprobs,
self.alpha)
Expand Down
120 changes: 60 additions & 60 deletions onmt/translate/Penalties.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,77 @@
import torch

class PenaltyBuilder(object):
    """
    Returns the Length and Coverage Penalty function for Beam Search.

    Args:
        cov_pen (str): option name of cov pen
        length_pen (str): option name of length pen
    """
    def __init__(self, cov_pen, length_pen):
        self.length_pen = length_pen
        self.cov_pen = cov_pen

    def coverage_penalty(self):
        """Return the coverage-penalty method selected by ``self.cov_pen``.

        Unknown option names fall back to the no-op ``coverage_none``.
        """
        if self.cov_pen == "wu":
            return self.coverage_wu
        elif self.cov_pen == "summary":
            return self.coverage_summary
        else:
            return self.coverage_none

    def length_penalty(self):
        """Return the length-penalty method selected by ``self.length_pen``.

        Unknown option names fall back to the no-op ``length_none``.
        """
        if self.length_pen == "wu":
            return self.length_wu
        elif self.length_pen == "avg":
            return self.length_average
        else:
            return self.length_none

    """
    Below are all the different penalty terms implemented so far
    """

    def coverage_wu(self, beam, cov, beta=0.):
        """
        NMT coverage re-ranking score from
        "Google's Neural Machine Translation System" :cite:`wu2016google`.

        Args:
            beam: unused; kept so all coverage penalties share one signature.
            cov (Tensor): accumulated attention (coverage) per source position.
            beta (float): coverage penalty weight.
        """
        # Clamp coverage at 1.0 before log so over-attended positions are
        # not rewarded; only under-covered positions are penalized.
        return -beta * torch.min(cov, cov.clone().fill_(1.0)).log().sum(1)

    def coverage_summary(self, beam, cov, beta=0.):
        """
        Our summary penalty.

        Penalizes positions attended to MORE than once (total mass above
        1.0 per source position), the inverse emphasis of ``coverage_wu``.
        """
        return beta * (torch.max(cov, cov.clone().fill_(1.0)).sum(1)
                       - cov.size(1))

    def coverage_none(self, beam, cov, beta=0.):
        """
        returns zero as penalty
        """
        # Shape/device follow beam.scores so the caller can add it directly.
        return beam.scores.clone().fill_(0.0)

    def length_wu(self, beam, logprobs, alpha=0.):
        """
        NMT length re-ranking score from
        "Google's Neural Machine Translation System" :cite:`wu2016google`.

        Args:
            beam: provides ``next_ys`` whose length is the hypothesis length.
            logprobs (Tensor or float): cumulative log-probabilities.
            alpha (float): length penalty strength; 0 disables the penalty.
        """
        modifier = (((5 + len(beam.next_ys)) ** alpha) /
                    ((5 + 1) ** alpha))
        return (logprobs / modifier)

    def length_average(self, beam, logprobs, alpha=0.):
        """
        Returns the average probability of tokens in a sequence.
        """
        return logprobs / len(beam.next_ys)

    def length_none(self, beam, logprobs, alpha=0., beta=0.):
        """
        Returns unmodified scores.
        """
        return logprobs



0 comments on commit 56e51d2

Please sign in to comment.