Commit a7ba830 (parent 90b2d70)
Copy generator
srush committed Aug 23, 2017
Showing 1 changed file with 39 additions and 30 deletions.

onmt/modules/CopyGenerator.py: 39 additions, 30 deletions
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 import torch
 import torch.cuda
-from torch.autograd import Variable
+from onmt.modules import aeq
@@ -14,39 +14,56 @@ class CopyGenerator(nn.Module):
 
     def __init__(self, opt, src_dict, tgt_dict):
         super(CopyGenerator, self).__init__()
-        self.linear = nn.Linear(opt.rnn_size, tgt_dict.size())
+        self.linear = nn.Linear(opt.rnn_size, len(tgt_dict))
         self.linear_copy = nn.Linear(opt.rnn_size, 1)
         self.src_dict = src_dict
         self.tgt_dict = tgt_dict
+        self.sm = nn.Softmax()
 
-    def forward(self, hidden, attn, verbose=False):
+    def forward(self, hidden, attn, src_map, verbose=False):
         """
         Computes p(w) = p(z=1) p_{copy}(w|z=1) + p(z=0) p_{softmax}(w|z=0)
         Args:
             hidden (FloatTensor): (tgt_len*batch) x hidden
             attn (FloatTensor):   (tgt_len*batch) x src_len
         Returns:
             dynamic_probs (FloatTensor): (tgt_len*batch) x (vocab + cvocab)
         """
+        # CHECKS
+        batch_by_tlen, _ = hidden.size()
+        batch_by_tlen_, slen = attn.size()
+        slen_, batch, cvocab = src_map.size()
+        aeq(batch_by_tlen, batch_by_tlen_)
+        aeq(slen, slen_)
 
         # Original probabilities.
         logits = self.linear(hidden)
-        logits[:, onmt.Constants.UNK] = -float('inf')
-        logits[:, onmt.Constants.PAD] = -float('inf')
-        prob = F.softmax(logits)
-
-        # Probability of copying p(z=1)
-        copy = F.sigmoid(self.linear_copy(hidden))
-
-        # Probability of not copying: p_{word}(w) * (1 - p(z))
-        out_prob = torch.mul(prob, 1 - copy.expand_as(prob))
-        mul_attn = torch.mul(attn, copy.expand_as(attn))
-        return out_prob, mul_attn
+        if True:
+            logits[:, self.tgt_dict.stoi[onmt.IO.PAD_WORD]] = -float('inf')
+            prob = self.sm(logits)
+            # Probability of copying p(z=1)
+            copy = F.sigmoid(self.linear_copy(hidden))
+
+            # Probability of not copying: p_{word}(w) * (1 - p(z))
+            out_prob = torch.mul(prob, 1 - copy.expand_as(prob))
+            mul_attn = torch.mul(attn, copy.expand_as(attn))
+            copy_prob = torch.bmm(mul_attn.view(-1, batch, slen)
+                                  .transpose(0, 1),
+                                  src_map.transpose(0, 1)).transpose(0, 1)
+            copy_prob = copy_prob.contiguous().view(-1, cvocab)
+            dynamic_probs = torch.cat([out_prob, copy_prob], 1)
+        else:
+            copy_logit = torch.bmm(attn.view(-1, batch, slen).transpose(0, 1),
+                                   src_map.transpose(0, 1)).transpose(0, 1)
+            copy_logit = copy_logit.contiguous().view(-1, cvocab)
+            # A zero logit means the copy-vocab word never occurs in the
+            # source, so mask it out before the joint softmax.
+            copy_logit.data.masked_fill_(copy_logit.data.eq(0), -1e20)
+            dynamic_logits = torch.cat([logits, copy_logit], 1)
+            dynamic_probs = self.sm(dynamic_logits.contiguous())
+        return dynamic_probs
 
-    def _debug_copy(self, src, copy, prob, out_prob, attn, mul_attn):
+    def k_debug_copy(self, src, copy, prob, out_prob, attn, mul_attn):
         v, mid = prob[0].data.max(0)
         print("Initial:", self.tgt_dict.getLabel(mid[0], "FAIL"), v[0])
         print("COPY %3f" % copy.data[0][0])
@@ -58,11 +75,3 @@ def _debug_copy(self, src, copy, prob, out_prob, attn, mul_attn):
                            j,
                            attn[0, j].data[0],
                            mul_attn[0, j].data[0]))
-
-
-def CopyCriterion(probs, attn, targ, align, eps=1e-12):
-    copies = attn.mul(Variable(align)).sum(-1).add(eps)
-    # Can't use UNK, must copy.
-    out = torch.log(probs.gather(1, targ.view(-1, 1)).view(-1) + copies + eps)
-    out = out.mul(targ.ne(onmt.Constants.PAD).float())
-    return -out.sum()
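Discussion: the substantive change is that forward() no longer returns the
(out_prob, mul_attn) pair; it now folds the copy attention into a single
distribution over an extended vocabulary of len(tgt_dict) + cvocab entries,
using the new src_map argument. Below is a self-contained sketch of the
tensor algebra in the `if True:` branch. The sizes are made up, and it uses
the current PyTorch API (no Variable) rather than the 2017 one; it is an
illustration, not the module's code.

    import torch

    # Toy sizes: 2 decoding steps, batch of 2, source length 3,
    # fixed target vocab of 10, copy vocab of 5 distinct source words.
    tgt_len, batch, slen, vocab, cvocab = 2, 2, 3, 10, 5
    rows = tgt_len * batch               # forward() works on (tgt_len*batch) rows

    prob = torch.softmax(torch.randn(rows, vocab), dim=1)  # p_softmax(w|z=0)
    attn = torch.softmax(torch.randn(rows, slen), dim=1)   # attention over source
    copy = torch.sigmoid(torch.randn(rows, 1))             # copy gate p(z=1)

    # src_map one-hot encodes which copy-vocab word sits at each source
    # position: (slen, batch, cvocab).
    src_words = torch.randint(0, cvocab, (slen, batch, 1))
    src_map = torch.zeros(slen, batch, cvocab).scatter_(2, src_words, 1.0)

    out_prob = prob * (1 - copy)         # p(z=0) * p_softmax
    mul_attn = attn * copy               # p(z=1) * attention

    # bmm sums attention mass per copy-vocab word:
    # (batch, tgt_len, slen) x (batch, slen, cvocab) -> (batch, tgt_len, cvocab)
    copy_prob = torch.bmm(mul_attn.view(tgt_len, batch, slen).transpose(0, 1),
                          src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(rows, cvocab)

    dynamic_probs = torch.cat([out_prob, copy_prob], 1)    # rows x (vocab+cvocab)
    assert torch.allclose(dynamic_probs.sum(1), torch.ones(rows))

The final assert holds because each row of prob and attn sums to one and the
gate splits the mass between them, which is exactly the mixture stated in
the docstring.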

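The new CHECKS block leans on aeq, now imported from onmt.modules at the top
of the file. Roughly, it is an assert-all-equal guard for tensor dimensions;
a sketch of the idea, not the exact OpenNMT-py source:

    def aeq(*args):
        # All arguments should be equal, e.g. matching tensor dimensions
        # checked before the bmm in forward().
        first = args[0]
        assert all(arg == first for arg in args[1:]), \
            "arguments do not all match: " + str(args)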
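Note that CopyCriterion is deleted outright: it consumed the old two-tensor
output and the removed Variable import, and no replacement appears in this
file. Purely for illustration, here is a hedged sketch of an equivalent
negative log-likelihood over the new concatenated distribution. The name
copy_nll, the align tensor, and the convention that align index 0 means
"not copyable" are assumptions of this sketch, not part of the commit:

    import torch

    def copy_nll(dynamic_probs, target, align, vocab_size, pad_idx, eps=1e-12):
        # dynamic_probs: (tgt_len*batch) x (vocab_size + cvocab), from forward()
        # target: (tgt_len*batch,) gold indices into the fixed vocabulary
        # align:  (tgt_len*batch,) indices into the copy vocabulary,
        #         0 meaning "this token cannot be copied" (assumed convention)
        gen_p = dynamic_probs.gather(1, target.unsqueeze(1)).squeeze(1)
        copy_p = dynamic_probs.gather(1, align.unsqueeze(1) + vocab_size).squeeze(1)
        copy_p = copy_p * align.ne(0).float()    # drop mass for non-copyable tokens
        loss = -torch.log(gen_p + copy_p + eps)  # either path may emit the token
        return loss.masked_select(target.ne(pad_idx)).sum()

As in the deleted CopyCriterion, padding positions are masked out of the sum
and eps keeps the log finite when both paths assign zero probability.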