Shallow fusion #2315

Merged 11 commits on Jun 10, 2021
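This PR adds shallow fusion of an external language model into NMT beam search: at each decoding step, the score of a candidate token combines the translation model's log-probability with a weighted log-probability from a Transformer LM over the target language. A minimal sketch of the scoring rule (illustrative variable names; the actual implementation is in _one_step_forward below):

    # fused score for the next token, given target prefix y and source x:
    #   log P_NMT(token | y, x) + fusion_coef * log P_LM(token | y)
    log_probs = nmt_log_probs + fusion_coef * lm_log_probs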
examples/nlp/machine_translation/nmt_transformer_infer.py (41 changes: 37 additions & 4 deletions)
@@ -27,6 +27,10 @@
import torch

import nemo.collections.nlp as nemo_nlp
from nemo.collections.nlp.modules.common.transformer import (
    BeamSearchSequenceGenerator,
    BeamSearchSequenceGeneratorWithLanguageModel,
)
from nemo.utils import logging


@@ -41,6 +45,9 @@ def main():
parser.add_argument("--max_delta_length", type=int, default=5, help="")
parser.add_argument("--target_lang", type=str, default=None, help="")
parser.add_argument("--source_lang", type=str, default=None, help="")
# shallow fusion specific parameters
parser.add_argument("--lm_model", type=str, default=None, help="")
parser.add_argument("--fusion_coef", type=float, default=0.0, help="")

    args = parser.parse_args()
    torch.set_grad_enabled(False)
@@ -52,13 +59,39 @@ def main():
    else:
        raise NotImplementedError(f"Only support .nemo files, but got: {args.model}")

    model.beam_search.beam_size = args.beam_size
    model.beam_search.len_pen = args.len_pen
    model.beam_search.max_delta_length = args.max_delta_length

    if torch.cuda.is_available():
        model = model.cuda()

    if args.lm_model is not None:
        lm_model = nemo_nlp.models.language_modeling.TransformerLMModel.restore_from(restore_path=args.lm_model).eval()
        model.beam_search = BeamSearchSequenceGeneratorWithLanguageModel(
            embedding=model.decoder.embedding,
            decoder=model.decoder.decoder,
            log_softmax=model.log_softmax,
            bos=model.decoder_tokenizer.bos_id,
            pad=model.decoder_tokenizer.pad_id,
            eos=model.decoder_tokenizer.eos_id,
            language_model=lm_model,
            fusion_coef=args.fusion_coef,
            max_sequence_length=model.decoder.max_sequence_length,
            beam_size=args.beam_size,
            len_pen=args.len_pen,
            max_delta_length=args.max_delta_length,
        )
    else:
        model.beam_search = BeamSearchSequenceGenerator(
            embedding=model.decoder.embedding,
            decoder=model.decoder.decoder,
            log_softmax=model.log_softmax,
            bos=model.decoder_tokenizer.bos_id,
            pad=model.decoder_tokenizer.pad_id,
            eos=model.decoder_tokenizer.eos_id,
            max_sequence_length=model.decoder.max_sequence_length,
            beam_size=args.beam_size,
            len_pen=args.len_pen,
            max_delta_length=args.max_delta_length,
        )

logging.info(f"Translating: {args.srctext}")

count = 0
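To try shallow fusion end to end, the script can be invoked with an LM checkpoint and a fusion weight; the checkpoint names and coefficient below are placeholders, not values from the PR, and without --lm_model the script falls back to plain beam search:

    python nmt_transformer_infer.py --model en_de.nemo --srctext input.en \
        --lm_model target_lm.nemo --fusion_coef 0.07 --beam_size 4 --len_pen 0.6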
@@ -22,6 +22,7 @@
"GreedySequenceGenerator",
"TopKSequenceGenerator",
"BeamSearchSequenceGenerator",
"BeamSearchSequenceGeneratorWithLanguageModel",
]


@@ -374,3 +375,167 @@ def _forward(self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None):
        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses)

        return tgt.squeeze(1)


class BeamSearchSequenceGeneratorWithLanguageModel(GreedySequenceGenerator):
    def __init__(
        self, embedding, decoder, log_softmax, language_model, beam_size=1, len_pen=0, fusion_coef=0.0, **kwargs
    ):
        """
        Beam search sequence generator based on the decoder followed by log_softmax,
        with shallow fusion of an external language model.

        Args:
            *all args of the BeamSearchSequenceGenerator class
            language_model: NeMo TransformerLMModel
            fusion_coef: coefficient on the language model score; the resulting score is
                score = log P_NMT(y|x) + fusion_coef * log P_LM(y)
        Kwargs:
            all remaining parameters of the GreedySequenceGenerator class
        """

        super().__init__(embedding, decoder, log_softmax, **kwargs)
        self.language_model = language_model
        self.beam_size = beam_size
        self.len_pen = len_pen
        self.fusion_coef = fusion_coef

    def _one_step_forward(
        self,
        decoder_input_ids=None,
        encoder_hidden_states=None,
        encoder_input_mask=None,
        decoder_mems_list=None,
        lm_mems_list=None,
        pos=0,
    ):

        nmt_log_probs, decoder_mems_list = super()._one_step_forward(
            decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos,
        )
        # run the external LM over the same target prefix; lm_mems_list caches its
        # activations so that each step only has to process the newest token
        input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()
        lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos)

        lm_mems_list = self.language_model.encoder.encoder.forward(
            lm_hidden_states, input_mask, lm_mems_list, return_mems=True,
        )
        lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:])

        # shallow fusion: interpolate the NMT and LM log-probabilities
        log_probs = nmt_log_probs + self.fusion_coef * lm_log_probs

        return log_probs, decoder_mems_list, lm_mems_list

    @staticmethod
    def compute_len_penalty(lengths, alpha):
        """Returns length penalty according to https://arxiv.org/pdf/1609.08144.pdf"""
        # e.g. length 7 with alpha=0.6 gives ((5 + 7) / 6) ** 0.6 = 2 ** 0.6 ≈ 1.52;
        # dividing (negative) log-prob scores by this favors longer hypotheses
        return ((5 + lengths) / 6).pow(alpha)

    def _forward(self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None):

        tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)

        # generate initial buffer of beam_size prefix hypotheses
        log_probs, decoder_mems_list, lm_mems_list = self._one_step_forward(
            tgt, encoder_hidden_states, encoder_input_mask, None, None, 0
        )
        scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)
        scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1)

        # repeat initial target prefixes and cached memory states beam_size times
        prefixes = torch.cat((tgt.repeat(1, self.beam_size).view(-1, 1), prefixes), dim=1)
        for j in range(len(decoder_mems_list)):
            decoder_mems_list[j] = decoder_mems_list[j].repeat(self.beam_size, 1, 1)
        for j in range(len(lm_mems_list)):
            lm_mems_list[j] = lm_mems_list[j].repeat(self.beam_size, 1, 1)

        # repeat source sequence beam_size times for beam search
        if encoder_hidden_states is not None:
            _, src_length, hidden_size = encoder_hidden_states.size()
            encoder_input_mask = encoder_input_mask.repeat(1, self.beam_size).view(-1, src_length)
            encoder_hidden_states = encoder_hidden_states.repeat(1, self.beam_size, 1).view(
                -1, src_length, hidden_size
            )
        else:
            hidden_size = decoder_mems_list[0].size(2)
        lm_hidden_size = lm_mems_list[0].size(2)

        # pad_profile tracks finished hypotheses so that only <pad> tokens
        # are generated once <eos> or <pad> has appeared
        pad_profile = torch.zeros_like(scores).long()

        # prefixes_len tracks lengths of generated hypotheses for the
        # length penalty correction
        prefixes_len = torch.zeros_like(scores).fill_(prefixes.size(1) + 1)

        for i in range(max_generation_length):

            # mask all finished hypotheses to exclude them from beam
            pad_mask = pad_profile.repeat(1, self.beam_size)

            # generate and score candidates for prefix continuations
            log_probs, decoder_mems_list, lm_mems_list = self._one_step_forward(
                prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, lm_mems_list, i + 1
            )
            scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)

            # for all prefixes ending with <eos> or <pad> replace generated
            # continuations with <pad>
            prefixes_i = self.pad * pad_mask + prefixes_i * (1 - pad_mask)

            # force all hypotheses but one generated from already finished
            # hypotheses to have extremely low score, so they will not be
            # considered during beam re-ranking
            pad_mask[:, 1:] = pad_mask[:, 1:] * NEG_INF
            scores = scores + scores_i * (1 - pad_mask).to(scores.dtype)

            # choose top-k hypotheses with length penalty applied
            len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
            scores = scores / len_penalties
            scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1)
            scores = scores.view(-1, 1) * len_penalties

            # select prefixes which correspond to the chosen hypotheses
            prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1)
            prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2)
            prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1)
            p_len = prefixes.size(2)
            prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
            prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)

            # reshuffle cached decoder memory states to restore the order
            # of hypotheses broken after top-k selection; dividing the flat
            # candidate index by beam_size recovers the parent hypothesis
            mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
            for j in range(len(decoder_mems_list)):
                decoder_mems_list[j] = (
                    decoder_mems_list[j]
                    .view(-1, self.beam_size, p_len - 1, hidden_size)
                    .gather(1, mems_ids)
                    .view(-1, p_len - 1, hidden_size)
                )
            # the LM caches are reshuffled the same way, with the LM's hidden size
            lm_mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, lm_hidden_size) // self.beam_size
            for j in range(len(lm_mems_list)):
                lm_mems_list[j] = (
                    lm_mems_list[j]
                    .view(-1, self.beam_size, p_len - 1, lm_hidden_size)
                    .gather(1, lm_mems_ids)
                    .view(-1, p_len - 1, lm_hidden_size)
                )

            # update prefixes_len and pad_profile
            not_eos_pad = prefixes.ne(self.eos) & prefixes.ne(self.pad)
            prefixes_len = 1 + not_eos_pad.sum(dim=1, keepdim=True).to(scores.dtype)
            pad_profile = (~not_eos_pad[:, -1:]).long()

            # if all hypotheses end with <eos> or <pad>, interrupt search
            if pad_profile.sum() == batch_size * self.beam_size:
                break

        # select best performing hypotheses in each element of the batch
        len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
        scores = scores / len_penalties
        best_guesses = (
            torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
        )
        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses)

        return tgt.squeeze(1)
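A note on the cache reshuffling above, since the index arithmetic is easy to misread: topk runs over beam_size ** 2 flattened candidates (beam_size continuations for each of beam_size parents), so integer division by beam_size recovers the parent hypothesis whose decoder and LM caches a surviving candidate should inherit. A small self-contained illustration (values are made up):

    import torch

    beam_size = 4
    # top-4 surviving candidates out of 4 * 4 = 16, as picked by torch.topk
    indices_i = torch.tensor([[13, 2, 7, 11]])
    parents = indices_i // beam_size  # tensor([[3, 0, 1, 2]])
    # e.g. candidate 13 is continuation 13 % 4 = 1 of parent hypothesis 3,
    # so its row reuses (gathers) the cached memories of parent 3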