Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
CopyNet: replace in-place tensor operation with out-of-place equivale…
Browse files Browse the repository at this point in the history
…nt (#2925)

* remove in-place operation

* oops, fixed
  • Loading branch information
epwalsh authored and DeNeutoy committed Jun 6, 2019
1 parent 89700de commit c629093
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions allennlp/models/encoder_decoders/copynet_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def _gather_final_log_probs(self,
source_token_ids = state["source_token_ids"]

# shape: [(batch_size, *)]
modified_log_probs_list: List[torch.Tensor] = [generation_log_probs]
modified_log_probs_list: List[torch.Tensor] = []
for i in range(trimmed_source_length):
# shape: (group_size,)
copy_log_probs_slice = copy_log_probs[:, i]
Expand All @@ -648,7 +648,9 @@ def _gather_final_log_probs(self,
selected_generation_log_probs = generation_log_probs.gather(1, source_to_target_slice.unsqueeze(-1))
combined_scores = util.logsumexp(
torch.cat((selected_generation_log_probs, copy_log_probs_to_add), dim=1))
generation_log_probs.scatter_(-1, source_to_target_slice.unsqueeze(-1), combined_scores.unsqueeze(-1))
generation_log_probs = generation_log_probs.scatter(-1,
source_to_target_slice.unsqueeze(-1),
combined_scores.unsqueeze(-1))
# We have to combine copy scores for duplicate source tokens so that
# we can find the overall most likely source token. So, if this is the first
# occurence of this particular source token, we add the log_probs from all other
Expand Down Expand Up @@ -676,6 +678,7 @@ def _gather_final_log_probs(self,
# shape: (group_size,)
left_over_copy_log_probs = copy_log_probs_slice + (1.0 - copy_log_probs_to_add_mask + 1e-45).log()
modified_log_probs_list.append(left_over_copy_log_probs.unsqueeze(-1))
modified_log_probs_list.insert(0, generation_log_probs)

# shape: (group_size, target_vocab_size + trimmed_source_length)
modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)
Expand Down

0 comments on commit c629093

Please sign in to comment.