Shallow fusion (#2315)
* fixed branch in IR tutorial

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* shallow fusion init commit

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

* debug info removed

Signed-off-by: AlexGrinch <grinchuk.alexey@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
3 people committed Jun 10, 2021
1 parent 638539f commit fa76d45
Showing 2 changed files with 202 additions and 4 deletions.
41 changes: 37 additions & 4 deletions examples/nlp/machine_translation/nmt_transformer_infer.py
@@ -27,6 +27,10 @@
import torch

import nemo.collections.nlp as nemo_nlp
from nemo.collections.nlp.modules.common.transformer import (
    BeamSearchSequenceGenerator,
    BeamSearchSequenceGeneratorWithLanguageModel,
)
from nemo.utils import logging


@@ -41,6 +45,9 @@ def main():
parser.add_argument("--max_delta_length", type=int, default=5, help="")
parser.add_argument("--target_lang", type=str, default=None, help="")
parser.add_argument("--source_lang", type=str, default=None, help="")
# shallow fusion specific parameters
parser.add_argument("--lm_model", type=str, default=None, help="")
parser.add_argument("--fusion_coef", type=float, default=0.0, help="")

    args = parser.parse_args()
    torch.set_grad_enabled(False)
@@ -52,13 +59,39 @@
    else:
        raise NotImplementedError(f"Only support .nemo files, but got: {args.model}")

    model.beam_search.beam_size = args.beam_size
    model.beam_search.len_pen = args.len_pen
    model.beam_search.max_delta_length = args.max_delta_length

    if torch.cuda.is_available():
        model = model.cuda()

    if args.lm_model is not None:
        lm_model = nemo_nlp.models.language_modeling.TransformerLMModel.restore_from(restore_path=args.lm_model).eval()
        model.beam_search = BeamSearchSequenceGeneratorWithLanguageModel(
            embedding=model.decoder.embedding,
            decoder=model.decoder.decoder,
            log_softmax=model.log_softmax,
            bos=model.decoder_tokenizer.bos_id,
            pad=model.decoder_tokenizer.pad_id,
            eos=model.decoder_tokenizer.eos_id,
            language_model=lm_model,
            fusion_coef=args.fusion_coef,
            max_sequence_length=model.decoder.max_sequence_length,
            beam_size=args.beam_size,
            len_pen=args.len_pen,
            max_delta_length=args.max_delta_length,
        )
    else:
        model.beam_search = BeamSearchSequenceGenerator(
            embedding=model.decoder.embedding,
            decoder=model.decoder.decoder,
            log_softmax=model.log_softmax,
            bos=model.decoder_tokenizer.bos_id,
            pad=model.decoder_tokenizer.pad_id,
            eos=model.decoder_tokenizer.eos_id,
            max_sequence_length=model.decoder.max_sequence_length,
            beam_size=args.beam_size,
            len_pen=args.len_pen,
            max_delta_length=args.max_delta_length,
        )

logging.info(f"Translating: {args.srctext}")

    count = 0
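For context, a hypothetical invocation of the updated inference script with shallow fusion enabled might look like the following. The checkpoint paths, input file, and fusion coefficient are placeholders for illustration rather than values from this commit, and any additional arguments the script requires (such as an output path) are omitted here:

python examples/nlp/machine_translation/nmt_transformer_infer.py \
    --model en_de_nmt.nemo \
    --lm_model de_transformer_lm.nemo \
    --fusion_coef 0.05 \
    --srctext input.en \
    --source_lang en \
    --target_lang de \
    --beam_size 4 \
    --len_pen 0.6

When --lm_model is omitted, the script falls back to the plain BeamSearchSequenceGenerator, so existing translation workflows are unchanged.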
@@ -22,6 +22,7 @@
"GreedySequenceGenerator",
"TopKSequenceGenerator",
"BeamSearchSequenceGenerator",
"BeamSearchSequenceGeneratorWithLanguageModel",
]


@@ -374,3 +375,167 @@ def _forward(self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None):
        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses)

        return tgt.squeeze(1)


class BeamSearchSequenceGeneratorWithLanguageModel(GreedySequenceGenerator):
    def __init__(
        self, embedding, decoder, log_softmax, language_model, beam_size=1, len_pen=0, fusion_coef=0.0, **kwargs
    ):
"""
Beam Search sequence generator based on the decoder followed by log_softmax
with external language model fusion.
Args:
*all args of BeamSearchSequenceGenerator class
language_model: nemo TransformerLMModel
fusion_coef: coefficient before language model score, the resulting score is
score = log P_NMT(y|x) + fusion_coef * log P_LM(y|x)
Kwargs:
all remaining parameters of GreedySequenceGenerator class
"""

        super().__init__(embedding, decoder, log_softmax, **kwargs)
        self.language_model = language_model
        self.beam_size = beam_size
        self.len_pen = len_pen
        self.fusion_coef = fusion_coef

    def _one_step_forward(
        self,
        decoder_input_ids=None,
        encoder_hidden_states=None,
        encoder_input_mask=None,
        decoder_mems_list=None,
        lm_mems_list=None,
        pos=0,
    ):

        nmt_log_probs, decoder_mems_list = super()._one_step_forward(
            decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos,
        )
        input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()
        lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos)

        lm_mems_list = self.language_model.encoder.encoder.forward(
            lm_hidden_states, input_mask, lm_mems_list, return_mems=True,
        )
        lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:])

        log_probs = nmt_log_probs + self.fusion_coef * lm_log_probs

        return log_probs, decoder_mems_list, lm_mems_list

    @staticmethod
    def compute_len_penalty(lengths, alpha):
        """Returns length penalty according to https://arxiv.org/pdf/1609.08144.pdf"""
        return ((5 + lengths) / 6).pow(alpha)

    def _forward(self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None):

        tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)

        # generate initial buffer of beam_size prefixes-hypotheses
        log_probs, decoder_mems_list, lm_mems_list = self._one_step_forward(
            tgt, encoder_hidden_states, encoder_input_mask, None, None, 0
        )
        scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)
        scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1)

        # repeat init target prefixes and cached memory states beam_size times
        prefixes = torch.cat((tgt.repeat(1, self.beam_size).view(-1, 1), prefixes), dim=1)
        for j in range(len(decoder_mems_list)):
            decoder_mems_list[j] = decoder_mems_list[j].repeat(self.beam_size, 1, 1)
        for j in range(len(lm_mems_list)):
            lm_mems_list[j] = lm_mems_list[j].repeat(self.beam_size, 1, 1)

        # repeat source sequence beam_size times for beam search
        if encoder_hidden_states is not None:
            _, src_length, hidden_size = encoder_hidden_states.size()
            encoder_input_mask = encoder_input_mask.repeat(1, self.beam_size).view(-1, src_length)
            encoder_hidden_states = encoder_hidden_states.repeat(1, self.beam_size, 1).view(
                -1, src_length, hidden_size
            )
        else:
            hidden_size = decoder_mems_list[0].size(2)
        lm_hidden_size = lm_mems_list[0].size(2)

        # pad_profile tracks finished hypotheses to generate only <pad> tokens
        # if <eos> or <pad> has been generated
        pad_profile = torch.zeros_like(scores).long()

        # prefixes_len tracks lengths of generated hypotheses to perform
        # length penalty correction
        prefixes_len = torch.zeros_like(scores).fill_(prefixes.size(1) + 1)

        for i in range(max_generation_length):

            # mask all finished hypotheses to exclude them from beam
            pad_mask = pad_profile.repeat(1, self.beam_size)

            # generate and score candidates for prefixes continuation
            log_probs, decoder_mems_list, lm_mems_list = self._one_step_forward(
                prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, lm_mems_list, i + 1
            )
            scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)

            # for all prefixes ending with <eos> or <pad> replace generated
            # continuations with <pad>
            prefixes_i = self.pad * pad_mask + prefixes_i * (1 - pad_mask)

            # force all hypotheses but one generated from already finished
            # hypotheses to have extremely low score, so they will not be
            # considered during beam re-ranking
            pad_mask[:, 1:] = pad_mask[:, 1:] * NEG_INF
            scores = scores + scores_i * (1 - pad_mask).to(scores.dtype)

            # choose top-k hypotheses with length penalty applied
            len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
            scores = scores / len_penalties
            scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1)
            scores = scores.view(-1, 1) * len_penalties

            # select prefixes which correspond to the chosen hypotheses
            prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1)
            prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2)
            prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1)
            p_len = prefixes.size(2)
            prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
            prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)

            # reshuffle cached decoder memory states to restore the order
            # of hypotheses broken after top-k selection
            mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
            for j in range(len(decoder_mems_list)):
                decoder_mems_list[j] = (
                    decoder_mems_list[j]
                    .view(-1, self.beam_size, p_len - 1, hidden_size)
                    .gather(1, mems_ids)
                    .view(-1, p_len - 1, hidden_size)
                )
            lm_mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, lm_hidden_size) // self.beam_size
            for j in range(len(lm_mems_list)):
                lm_mems_list[j] = (
                    lm_mems_list[j]
                    .view(-1, self.beam_size, p_len - 1, lm_hidden_size)
                    .gather(1, lm_mems_ids)
                    .view(-1, p_len - 1, lm_hidden_size)
                )

            # update prefixes_len and pad_profile
            not_eos_pad = prefixes.ne(self.eos) & prefixes.ne(self.pad)
            prefixes_len = 1 + not_eos_pad.sum(dim=1, keepdim=True).to(scores.dtype)
            pad_profile = (~not_eos_pad[:, -1:]).long()

            # if all hypotheses end with <eos> or <pad>, interrupt search
            if pad_profile.sum() == batch_size * self.beam_size:
                break

        # select best performing hypotheses in each element of the batch
        len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
        scores = scores / len_penalties
        best_guesses = (
            torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
        )
        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses)

        return tgt.squeeze(1)
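For intuition, here is a minimal standalone sketch of the fusion rule applied in _one_step_forward above: the NMT and language model log-probabilities over the target vocabulary are combined as log P_NMT + fusion_coef * log P_LM before beam re-ranking. The tensor shapes and the coefficient value are illustrative assumptions, with random tensors standing in for real model outputs:

import torch

batch_size, vocab_size = 2, 8
fusion_coef = 0.05  # hypothetical value; in practice tuned on a validation set

# stand-ins for the decoder and LM outputs at one decoding step: [batch, 1, vocab]
nmt_log_probs = torch.log_softmax(torch.randn(batch_size, 1, vocab_size), dim=-1)
lm_log_probs = torch.log_softmax(torch.randn(batch_size, 1, vocab_size), dim=-1)

# shallow fusion: add the scaled LM scores to the NMT scores
fused_log_probs = nmt_log_probs + fusion_coef * lm_log_probs

# a greedy choice is shown for simplicity; the generator above instead keeps
# the top beam_size continuations and re-ranks them with a length penalty
next_tokens = fused_log_probs.argmax(dim=-1)
print(next_tokens.shape)  # torch.Size([2, 1])

Note that the fused scores are no longer normalized log-probabilities; this does not matter for beam search, which only needs to rank competing hypotheses.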
