Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow easier programmatic use by extracting code from translate.py #1107

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 81 additions & 53 deletions sockeye/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,49 +24,30 @@

from sockeye.lexicon import load_restrict_lexicon, RestrictLexicon
from sockeye.log import setup_main_logger
from sockeye.model import load_models
from sockeye.model import load_models, SockeyeModel
from sockeye.output_handler import get_output_handler, OutputHandler
from sockeye.utils import log_basic_info, check_condition, grouper, smart_open, seed_rngs
from . import arguments
from . import arguments, vocab
from . import constants as C
from . import inference
from . import utils

logger = logging.getLogger(__name__)


def main():
def parse_translation_arguments(args=None):
params = arguments.ConfigArgumentParser(description='Translate CLI')
arguments.add_translate_cli_args(params)
args = params.parse_args()
run_translate(args)
return params.parse_args(args)


def run_translate(args: argparse.Namespace):
# Seed randomly unless a seed has been passed
seed_rngs(args.seed if args.seed is not None else int(time.time()))

if args.output is not None:
setup_main_logger(console=not args.quiet,
file_logging=not args.no_logfile,
path="%s.%s" % (args.output, C.LOG_NAME),
level=args.loglevel)
else:
setup_main_logger(file_logging=False, level=args.loglevel)
def main():
    """Command-line entry point: parse the translation arguments and run translation."""
    run_translate(parse_translation_arguments())

log_basic_info(args)

if args.nbest_size > 1:
if args.output_type != C.OUTPUT_HANDLER_JSON:
logger.warning(
"For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
C.OUTPUT_HANDLER_JSON, args.output_type)
args.output_type = C.OUTPUT_HANDLER_JSON
output_handler = get_output_handler(args.output_type,
args.output)

def load_models_from_args(args: argparse.Namespace):
device = utils.init_device(args)
logger.info(f"Translate Device: {device}")

models, source_vocabs, target_vocabs = load_models(device=device,
model_folders=args.models,
Expand All @@ -76,6 +57,15 @@ def run_translate(args: argparse.Namespace):
inference_only=True,
knn_index=args.knn_index)

for model in models:
model.eval()

return models, source_vocabs, target_vocabs


def restrict_lexicon_from_args(args: argparse.Namespace,
source_vocabs: List[vocab.Vocab],
target_vocabs: List[vocab.Vocab]):
restrict_lexicon = None # type: Optional[Union[RestrictLexicon, Dict[str, RestrictLexicon]]]
if args.restrict_lexicon is not None:
logger.info(str(args.restrict_lexicon))
Expand All @@ -94,6 +84,10 @@ def run_translate(args: argparse.Namespace):
lexicon = load_restrict_lexicon(path, source_vocabs[0], target_vocabs[0], k=args.restrict_lexicon_topk)
restrict_lexicon[key] = lexicon

return restrict_lexicon


def load_scorer_from_args(args: argparse.Namespace, models: List[SockeyeModel]):
brevity_penalty_weight = args.brevity_penalty_weight
if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
if args.brevity_penalty_constant_length_ratio > 0.0:
Expand All @@ -110,38 +104,72 @@ def run_translate(args: argparse.Namespace):
else:
raise ValueError("Unknown brevity penalty type %s" % args.brevity_penalty_type)

for model in models:
model.eval()

scorer = inference.CandidateScorer(
length_penalty_alpha=args.length_penalty_alpha,
length_penalty_beta=args.length_penalty_beta,
brevity_penalty_weight=brevity_penalty_weight)
scorer.to(models[0].dtype)

translator = inference.Translator(device=device,
ensemble_mode=args.ensemble_mode,
scorer=scorer,
batch_size=args.batch_size,
beam_size=args.beam_size,
beam_search_stop=args.beam_search_stop,
nbest_size=args.nbest_size,
models=models,
source_vocabs=source_vocabs,
target_vocabs=target_vocabs,
restrict_lexicon=restrict_lexicon,
strip_unknown_words=args.strip_unknown_words,
sample=args.sample,
output_scores=output_handler.reports_score(),
constant_length_ratio=constant_length_ratio,
knn_lambda=args.knn_lambda,
max_output_length_num_stds=args.max_output_length_num_stds,
max_input_length=args.max_input_length,
max_output_length=args.max_output_length,
prevent_unk=args.prevent_unk,
greedy=args.greedy,
skip_nvs=args.skip_nvs,
nvs_thresh=args.nvs_thresh)
return scorer, constant_length_ratio


def load_translator_from_args(args: argparse.Namespace, output_scores: bool):
    """
    Build an ``inference.Translator`` from parsed command-line arguments.

    Initializes the device, loads the models and vocabularies, the optional
    restrict lexicon, and the candidate scorer, then wires them together.

    :param args: Parsed translation arguments.
    :param output_scores: Forwarded to the translator as ``output_scores``.
    :return: The constructed ``inference.Translator``.
    """
    device = utils.init_device(args)
    logger.info(f"Translate Device: {device}")

    # NOTE(review): load_models_from_args also calls utils.init_device(args) and
    # logs the device, so both happen twice per call -- confirm this is intended.
    loaded_models, src_vocabs, tgt_vocabs = load_models_from_args(args)
    lexicon = restrict_lexicon_from_args(args, src_vocabs, tgt_vocabs)
    candidate_scorer, length_ratio = load_scorer_from_args(args, loaded_models)

    translator = inference.Translator(
        device=device,
        ensemble_mode=args.ensemble_mode,
        scorer=candidate_scorer,
        batch_size=args.batch_size,
        beam_size=args.beam_size,
        beam_search_stop=args.beam_search_stop,
        nbest_size=args.nbest_size,
        models=loaded_models,
        source_vocabs=src_vocabs,
        target_vocabs=tgt_vocabs,
        restrict_lexicon=lexicon,
        strip_unknown_words=args.strip_unknown_words,
        sample=args.sample,
        output_scores=output_scores,
        constant_length_ratio=length_ratio,
        knn_lambda=args.knn_lambda,
        max_output_length_num_stds=args.max_output_length_num_stds,
        max_input_length=args.max_input_length,
        max_output_length=args.max_output_length,
        prevent_unk=args.prevent_unk,
        greedy=args.greedy,
        skip_nvs=args.skip_nvs,
        nvs_thresh=args.nvs_thresh,
    )
    return translator


def run_translate(args: argparse.Namespace):
# Seed randomly unless a seed has been passed
seed_rngs(args.seed if args.seed is not None else int(time.time()))

if args.output is not None:
setup_main_logger(console=not args.quiet,
file_logging=not args.no_logfile,
path="%s.%s" % (args.output, C.LOG_NAME),
level=args.loglevel)
else:
setup_main_logger(file_logging=False, level=args.loglevel)

log_basic_info(args)

if args.nbest_size > 1:
if args.output_type != C.OUTPUT_HANDLER_JSON:
logger.warning(
"For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
C.OUTPUT_HANDLER_JSON, args.output_type)
args.output_type = C.OUTPUT_HANDLER_JSON
output_handler = get_output_handler(args.output_type,
args.output)

translator = load_translator_from_args(args, output_handler.reports_score())

read_and_translate(translator=translator,
output_handler=output_handler,
Expand Down