From 1bc4a44ed5139f1ceba2f913e87bc9b3314f07b8 Mon Sep 17 00:00:00 2001
From: AlexGrinch
Date: Thu, 18 Mar 2021 07:50:25 -0700
Subject: [PATCH 1/3] fixed branch in IR tutorial

Signed-off-by: AlexGrinch
---
 .../nlp/Information_Retrieval_MSMARCO.ipynb   | 37 +------------------
 1 file changed, 2 insertions(+), 35 deletions(-)

diff --git a/tutorials/nlp/Information_Retrieval_MSMARCO.ipynb b/tutorials/nlp/Information_Retrieval_MSMARCO.ipynb
index a725825ea578..2ea149d5283f 100644
--- a/tutorials/nlp/Information_Retrieval_MSMARCO.ipynb
+++ b/tutorials/nlp/Information_Retrieval_MSMARCO.ipynb
@@ -1,19 +1,5 @@
 {
  "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "uRLPr0TnIAHO",
-    "scrolled": true
-   },
-   "outputs": [],
-   "source": [
-    "BRANCH = 'ir_tutorial'"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -36,27 +22,8 @@
     "# If you're using Google Colab and not running locally, run this cell\n",
     "\n",
     "# install NeMo\n",
-    "!python -m pip install git+https://github.com/AlexGrinch/NeMo.git@{BRANCH}#egg=nemo_toolkit[nlp]\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    },
-    "scrolled": true
-   },
-   "outputs": [],
-   "source": [
-    "# If you're not using Colab, you might need to upgrade jupyter notebook to avoid the following error:\n",
-    "# 'ImportError: IProgress not found. Please update jupyter and ipywidgets.'\n",
-    "\n",
-    "! pip install ipywidgets\n",
-    "! jupyter nbextension enable --py widgetsnbextension\n",
-    "\n",
-    "# Please restart the kernel after running this cell"
+    "BRANCH = 'r1.0.0rc1'\n",
+    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[nlp]"
    ]
   },
   {

From a77c9b9eef6a949229f4694436a4329931312725 Mon Sep 17 00:00:00 2001
From: AlexGrinch
Date: Thu, 20 May 2021 06:20:46 -0700
Subject: [PATCH 2/3] shallow fusion init commit

Signed-off-by: AlexGrinch
---
 .../nmt_transformer_infer.py                  |  41 ++++-
 .../transformer/transformer_generators.py     | 167 ++++++++++++++++++
 2 files changed, 204 insertions(+), 4 deletions(-)

diff --git a/examples/nlp/machine_translation/nmt_transformer_infer.py b/examples/nlp/machine_translation/nmt_transformer_infer.py
index 38bbae3ca6ed..91f2fe9e9405 100644
--- a/examples/nlp/machine_translation/nmt_transformer_infer.py
+++ b/examples/nlp/machine_translation/nmt_transformer_infer.py
@@ -27,6 +27,10 @@
 import torch
 
 import nemo.collections.nlp as nemo_nlp
+from nemo.collections.nlp.modules.common.transformer import (
+    BeamSearchSequenceGenerator,
+    BeamSearchSequenceGeneratorWithLanguageModel,
+)
 from nemo.utils import logging
 
 
@@ -41,6 +45,9 @@ def main():
     parser.add_argument("--max_delta_length", type=int, default=5, help="")
     parser.add_argument("--target_lang", type=str, default=None, help="")
     parser.add_argument("--source_lang", type=str, default=None, help="")
+    # shallow fusion specific parameters
+    parser.add_argument("--lm_model", type=str, default=None, help="path to .nemo LM checkpoint for shallow fusion")
+    parser.add_argument("--fusion_coef", type=float, default=0.0, help="weight of the LM score during shallow fusion")
     args = parser.parse_args()
 
     torch.set_grad_enabled(False)
@@ -52,13 +59,39 @@ def main():
     else:
         raise NotImplemented(f"Only support .nemo files, but got: {args.model}")
 
-    model.beam_search.beam_size = args.beam_size
-    model.beam_search.len_pen = args.len_pen
-    model.beam_search.max_delta_length = args.max_delta_length
-
     if torch.cuda.is_available():
         model = model.cuda()
 
+    if args.lm_model is not None:
+        lm_model = nemo_nlp.models.language_modeling.TransformerLMModel.restore_from(restore_path=args.lm_model).eval()
+        model.beam_search = BeamSearchSequenceGeneratorWithLanguageModel(
+            embedding=model.decoder.embedding,
+            decoder=model.decoder.decoder,
+            log_softmax=model.log_softmax,
+            bos=model.decoder_tokenizer.bos_id,
+            pad=model.decoder_tokenizer.pad_id,
+            eos=model.decoder_tokenizer.eos_id,
+            language_model=lm_model,
+            fusion_coef=args.fusion_coef,
+            max_sequence_length=model.decoder.max_sequence_length,
+            beam_size=args.beam_size,
+            len_pen=args.len_pen,
+            max_delta_length=args.max_delta_length,
+        )
+    else:
+        model.beam_search = BeamSearchSequenceGenerator(
+            embedding=model.decoder.embedding,
+            decoder=model.decoder.decoder,
+            log_softmax=model.log_softmax,
+            bos=model.decoder_tokenizer.bos_id,
+            pad=model.decoder_tokenizer.pad_id,
+            eos=model.decoder_tokenizer.eos_id,
+            max_sequence_length=model.decoder.max_sequence_length,
+            beam_size=args.beam_size,
+            len_pen=args.len_pen,
+            max_delta_length=args.max_delta_length,
+        )
+
     logging.info(f"Translating: {args.srctext}")
 
     count = 0
diff --git a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py
index ee580ec25b3b..9bcce66dd4e2 100644
--- a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py
+++ b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py
@@ -22,6 +22,7 @@
     "GreedySequenceGenerator",
     "TopKSequenceGenerator",
     "BeamSearchSequenceGenerator",
+    "BeamSearchSequenceGeneratorWithLanguageModel",
 ]
 
 
@@ -373,4 +374,170 @@ def _forward(self, decoder_input_ids=None, encoder_hidden_states=None, encoder_i
         )
         tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses)
 
+        print("SHIIIIIIIIT", tgt)
+
         return tgt.squeeze(1)
+
+
+class BeamSearchSequenceGeneratorWithLanguageModel(GreedySequenceGenerator):
+    def __init__(
+        self, embedding, decoder, log_softmax, language_model, beam_size=1, len_pen=0, fusion_coef=0.0, **kwargs
+    ):
+        """
+        Beam Search sequence generator based on the decoder followed by log_softmax
+        with external language model fusion.
+
+        Args:
+            *all args of BeamSearchSequenceGenerator class
+            language_model: nemo TransformerLMModel
+            fusion_coef: coefficient before the language model score; the resulting score is
+                score = log P_NMT(y|x) + fusion_coef * log P_LM(y)
+        Kwargs:
+            all remaining parameters of GreedySequenceGenerator class
+        """
+
+        super().__init__(embedding, decoder, log_softmax, **kwargs)
+        self.language_model = language_model
+        self.beam_size = beam_size
+        self.len_pen = len_pen
+        self.fusion_coef = fusion_coef
+
+    def _one_step_forward(
+        self,
+        decoder_input_ids=None,
+        encoder_hidden_states=None,
+        encoder_input_mask=None,
+        decoder_mems_list=None,
+        lm_mems_list=None,
+        pos=0,
+    ):
+
+        nmt_log_probs, decoder_mems_list = super()._one_step_forward(
+            decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos,
+        )
+        input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float()
+        lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos)
+        lm_mems_list = self.language_model.encoder.encoder.forward(
+            lm_hidden_states, input_mask, lm_mems_list, return_mems=True,
+        )
+        lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:])
+
+        log_probs = nmt_log_probs + self.fusion_coef * lm_log_probs
+
+        return log_probs, decoder_mems_list, lm_mems_list
+
+    @staticmethod
+    def compute_len_penalty(lengths, alpha):
+        """Returns length penalty according to https://arxiv.org/pdf/1609.08144.pdf"""
+        return ((5 + lengths) / 6).pow(alpha)
+
+    def _forward(self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None):
+
+        tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states)
+
+        # generate initial buffer of beam_size prefixes-hypotheses
+        log_probs, decoder_mems_list, lm_mems_list = self._one_step_forward(
+            tgt, encoder_hidden_states, encoder_input_mask, None, None, 0
+        )
+        scores, prefixes = torch.topk(log_probs.permute(0, 2, 1), self.beam_size, dim=1)
+        scores, prefixes = scores.view(-1, 1), prefixes.view(-1, 1)
+
+        # repeat init target prefixes and cached memory states beam_size times
+        prefixes = torch.cat((tgt.repeat(1, self.beam_size).view(-1, 1), prefixes), dim=1)
+        for j in range(len(decoder_mems_list)):
+            decoder_mems_list[j] = decoder_mems_list[j].repeat(self.beam_size, 1, 1)
+        for j in range(len(lm_mems_list)):
+            lm_mems_list[j] = lm_mems_list[j].repeat(self.beam_size, 1, 1)
+
+        # repeat source sequence beam_size times for beam search
+        if encoder_hidden_states is not None:
+            _, src_length, hidden_size = encoder_hidden_states.size()
+            encoder_input_mask = encoder_input_mask.repeat(1, self.beam_size).view(-1, src_length)
+            encoder_hidden_states = encoder_hidden_states.repeat(1, self.beam_size, 1).view(
+                -1, src_length, hidden_size
+            )
+        else:
+            hidden_size = decoder_mems_list[0].size(2)
+        lm_hidden_size = lm_mems_list[0].size(2)
+
+        # pad_profile tracks finished hypotheses to generate only <pad> tokens
+        # if <eos> or <pad> has been generated
+        pad_profile = torch.zeros_like(scores).long()
+
+        # prefixes_len tracks lengths of generated hypotheses to perform
+        # length penalty correction
+        prefixes_len = torch.zeros_like(scores).fill_(prefixes.size(1) + 1)
+
+        for i in range(max_generation_length):
+
+            # mask all finished hypotheses to exclude them from beam
+            pad_mask = pad_profile.repeat(1, self.beam_size)
+
+            # generate and score candidates for prefixes continuation
+            log_probs, decoder_mems_list, lm_mems_list = self._one_step_forward(
+                prefixes[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, lm_mems_list, i + 1
+            )
+            scores_i, prefixes_i = torch.topk(log_probs[:, -1, :], self.beam_size, dim=-1)
+
+            # for all prefixes ending with <eos> or <pad>, replace generated
+            # continuations with <pad>
+            prefixes_i = self.pad * pad_mask + prefixes_i * (1 - pad_mask)
+
+            # force all hypotheses but one generated from already finished
+            # hypotheses to have extremely low score, so they will not be
+            # considered during beam re-ranking
+            pad_mask[:, 1:] = pad_mask[:, 1:] * NEG_INF
+            scores = scores + scores_i * (1 - pad_mask).to(scores.dtype)
+
+            # choose top-k hypotheses with length penalty applied
+            len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
+            scores = scores / len_penalties
+            scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1)
+            scores = scores.view(-1, 1) * len_penalties
+
+            # select prefixes which correspond to the chosen hypotheses
+            prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1)
+            prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2)
+            prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1)
+            p_len = prefixes.size(2)
+            prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len)
+            prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len)
+
+            # reshuffle cached decoder memory states to restore the order
+            # of hypotheses broken after top-k selection
+            mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, hidden_size) // self.beam_size
+            for j in range(len(decoder_mems_list)):
+                decoder_mems_list[j] = (
+                    decoder_mems_list[j]
+                    .view(-1, self.beam_size, p_len - 1, hidden_size)
+                    .gather(1, mems_ids)
+                    .view(-1, p_len - 1, hidden_size)
+                )
+            lm_mems_ids = indices_i.unsqueeze(2).unsqueeze(3).repeat(1, 1, p_len - 1, lm_hidden_size) // self.beam_size
+            for j in range(len(lm_mems_list)):
+                lm_mems_list[j] = (
+                    lm_mems_list[j]
+                    .view(-1, self.beam_size, p_len - 1, lm_hidden_size)
+                    .gather(1, lm_mems_ids)
+                    .view(-1, p_len - 1, lm_hidden_size)
+                )
+
+            # update prefixes_len and pad_profile
+            not_eos_pad = prefixes.ne(self.eos) & prefixes.ne(self.pad)
+            prefixes_len = 1 + not_eos_pad.sum(dim=1, keepdim=True).to(scores.dtype)
+            pad_profile = (~not_eos_pad[:, -1:]).long()
+
+            # if all hypotheses end with <eos> or <pad>, interrupt search
+            if pad_profile.sum() == batch_size * self.beam_size:
+                break
+
+        # select best performing hypotheses in each element of the batch
+        len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen)
+        scores = scores / len_penalties
+        best_guesses = (
+            torch.argmax(scores.view(-1, self.beam_size), dim=1, keepdim=True).repeat(1, prefixes.size(1)).unsqueeze(1)
+        )
+        tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses)
+
+        return tgt.squeeze(1)

From 686866f701d7b0549eb9e471ae3ef44b57ada3ec Mon Sep 17 00:00:00 2001
From: AlexGrinch
Date: Thu, 20 May 2021 06:22:15 -0700
Subject: [PATCH 3/3] debug info removed

Signed-off-by: AlexGrinch
---
 .../nlp/modules/common/transformer/transformer_generators.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py
index 9bcce66dd4e2..59fe40394112 100644
--- a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py
+++ b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py
@@ -374,8 +374,6 @@ def _forward(self, decoder_input_ids=None, encoder_hidden_states=None, encoder_i
         )
         tgt = prefixes.view(batch_size, self.beam_size, -1).gather(1, best_guesses)
 
-        print("SHIIIIIIIIT", tgt)
-
         return tgt.squeeze(1)
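
Note on the scoring rule added in patch 2: at each decoding step the generator
fuses the two models as score = log P_NMT(y|x) + fusion_coef * log P_LM(y), and
hypotheses are re-ranked with the GNMT length penalty ((5 + length) / 6) ** alpha
from compute_len_penalty. Below is a minimal standalone sketch of those two
operations, not NeMo code: the toy tensors stand in for the models' log_softmax
outputs, and the only assumption carried over from the patch is that the NMT
decoder and the LM share the same target vocabulary.

    import torch

    torch.manual_seed(0)
    beams, vocab, fusion_coef, len_pen = 4, 16, 0.07, 0.6

    # per-step log-probabilities over a shared target vocabulary
    # (stand-ins for the NMT decoder's and the LM's log_softmax outputs)
    nmt_log_probs = torch.log_softmax(torch.randn(beams, vocab), dim=-1)
    lm_log_probs = torch.log_softmax(torch.randn(beams, vocab), dim=-1)

    # shallow fusion: add the weighted LM score to the NMT score
    fused = nmt_log_probs + fusion_coef * lm_log_probs

    # each beam extends with its best fused continuation; in the real search
    # these per-step scores accumulate into running hypothesis scores
    next_tokens = fused.argmax(dim=-1, keepdim=True)
    step_scores = fused.gather(-1, next_tokens)

    # GNMT length penalty, identical to compute_len_penalty above
    lengths = torch.tensor([[7.0], [12.0], [9.0], [5.0]])
    ranked = step_scores / ((5 + lengths) / 6).pow(len_pen)

With the nmt_transformer_infer.py changes applied, the same fusion is switched
on from the command line by passing both new flags alongside the existing
arguments, e.g. --lm_model lm.nemo --fusion_coef 0.07 (the checkpoint path and
coefficient here are placeholders, not values from this patch).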