Convert classes to dataclasses (#969)
fhieber committed Sep 30, 2021
1 parent a0bc3f0 commit 33eb25f
Showing 2 changed files with 78 additions and 144 deletions.
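
The change applies one pattern across both files: hand-written `__slots__` declarations and boilerplate `__init__` methods are replaced by `@dataclass` field annotations, from which Python generates `__init__`, `__repr__`, and `__eq__`. A minimal before/after sketch with an invented two-field class (not taken from this commit):

from dataclasses import dataclass


# Before: __slots__ plus a hand-written __init__.
class PairOld:
    __slots__ = ['left', 'right']

    def __init__(self, left: int, right: int) -> None:
        self.left = left
        self.right = right


# After: the field annotations drive a generated __init__, __repr__ and __eq__.
# Note that a plain @dataclass defines no __slots__ (dataclasses only gained
# slots=True in Python 3.10), so the conversion trades the per-instance memory
# savings of __slots__ for less boilerplate.
@dataclass
class Pair:
    left: int
    right: int


assert Pair(1, 2) == Pair(1, 2)  # generated __eq__
print(Pair(1, 2))                # Pair(left=1, right=2), generated __repr__
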
sockeye/data_io.py (19 changes: 8 additions & 11 deletions)
@@ -1990,18 +1990,15 @@ def load_state(self, fname: str):
        self.data = self.data.permute(self.data_permutations)


@dataclass
class Batch:

    __slots__ = ['source', 'source_length', 'target', 'target_length', 'labels', 'samples', 'tokens']

    def __init__(self, source, source_length, target, target_length, labels, samples, tokens):
        self.source = source
        self.source_length = source_length
        self.target = target
        self.target_length = target_length
        self.labels = labels
        self.samples = samples
        self.tokens = tokens
    source: mx.nd.NDArray
    source_length: mx.nd.NDArray
    target: mx.nd.NDArray
    target_length: mx.nd.NDArray
    labels: Dict[str, mx.nd.NDArray]
    samples: int
    tokens: int

    def split_and_load(self, ctx: List[mx.context.Context]) -> 'Batch':
        source = mx.gluon.utils.split_and_load(self.source, ctx, batch_axis=0)
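
Field order matters here: the generated `__init__` takes the fields positionally in declaration order, which matches the old hand-written signature, so existing `Batch(...)` call sites keep working. A hedged construction sketch (the shapes and the label key are invented; assumes `Batch` is imported from `sockeye.data_io`):

import mxnet as mx

from sockeye.data_io import Batch

# Keyword construction via the generated __init__; positional arguments
# would also line up with the old hand-written signature.
batch = Batch(source=mx.nd.zeros((2, 5, 1)),
              source_length=mx.nd.array([5, 3]),
              target=mx.nd.zeros((2, 6, 1)),
              target_length=mx.nd.array([6, 4]),
              labels={'target_label': mx.nd.zeros((2, 6))},
              samples=2,
              tokens=10)
print(batch.samples, batch.tokens)  # fields behave exactly as before
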
sockeye/inference.py (203 changes: 70 additions & 133 deletions)
@@ -18,8 +18,9 @@
import itertools
import json
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, NamedTuple, Set, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union

import mxnet as mx
import numpy as np
@@ -115,6 +116,7 @@ def get_max_output_length(input_length: int):
SentenceId = Union[int, str]


@dataclass
class TranslatorInput:
"""
Object required by Translator.translate().
@@ -123,38 +125,15 @@ class TranslatorInput:
    If `--output-type json` is selected, all such fields that are not fields used or changed by
    Sockeye will be included in the output JSON object. This provides a mechanism for passing
    fields through the call to Sockeye.
    :param sentence_id: Sentence id.
    :param tokens: List of input tokens.
    :param factors: Optional list of additional factor sequences.
    :param restrict_lexicon: Optional lexicon for vocabulary selection.
    :param constraints: Optional list of target-side constraints.
    :param pass_through_dict: Optional raw dictionary of arbitrary input data.
    """

    __slots__ = ('sentence_id',
                 'tokens',
                 'factors',
                 'restrict_lexicon',
                 'constraints',
                 'avoid_list',
                 'pass_through_dict')

    def __init__(self,
                 sentence_id: SentenceId,
                 tokens: Tokens,
                 factors: Optional[List[Tokens]] = None,
                 restrict_lexicon: Optional[lexicon.TopKLexicon] = None,
                 constraints: Optional[List[Tokens]] = None,
                 avoid_list: Optional[List[Tokens]] = None,
                 pass_through_dict: Optional[Dict] = None) -> None:
        self.sentence_id = sentence_id
        self.tokens = tokens
        self.factors = factors
        self.restrict_lexicon = restrict_lexicon
        self.constraints = constraints
        self.avoid_list = avoid_list
        self.pass_through_dict = pass_through_dict
    sentence_id: SentenceId
    tokens: Tokens
    factors: Optional[List[Tokens]] = None
    restrict_lexicon: Optional[lexicon.TopKLexicon] = None
    constraints: Optional[List[Tokens]] = None
    avoid_list: Optional[List[Tokens]] = None
    pass_through_dict: Optional[Dict] = None

    def __str__(self):
        return 'TranslatorInput(%s, %s, factors=%s, constraints=%s, avoid=%s)' \
@@ -190,7 +169,8 @@ def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]:
            # Constrained decoding is not supported for chunked TranslatorInputs. As a fall-back, constraints are
            # assigned to the first chunk
            constraints = self.constraints if chunk_id == 0 else None
            pass_through_dict = self.pass_through_dict if chunk_id == 0 else None
            pass_through_dict = copy.deepcopy(self.pass_through_dict) \
                if (chunk_id == 0 and self.pass_through_dict is not None) else None
            yield TranslatorInput(sentence_id=self.sentence_id,
                                  tokens=self.tokens[i:i + chunk_size],
                                  factors=factors,
@@ -389,64 +369,37 @@ def make_input_from_multiple_strings(sentence_id: SentenceId, strings: List[str]
    return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors)


@dataclass
class TranslatorOutput:
"""
Output structure from Translator.
:param sentence_id: Sentence id.
:param translation: Translation string without sentence boundary tokens.
:param tokens: List of translated tokens.
:param score: Negative log probability of generated translation.
:param pass_through_dict: Dictionary of key/value pairs to pass through when working with JSON.
:param beam_histories: List of beam histories. The list will contain more than one
history if it was split due to exceeding max_length.
:param nbest_translations: List of nbest translations as strings.
:param nbest_tokens: List of nbest translations as lists of tokens.
:param nbest_scores: List of nbest scores, one for each nbest translation.
:param factor_translations: List of factor outputs.
:param factor_tokens: List of list of secondary factor tokens.
sentence_id: Sentence id.
translation: Translation string without sentence boundary tokens.
tokens: List of translated tokens.
score: Negative log probability of generated translation.
pass_through_dict: Dictionary of key/value pairs to pass through when working with JSON.
beam_histories: List of beam histories. The list will contain more than one
history if it was split due to exceeding max_length.
nbest_translations: List of nbest translations as strings.
nbest_tokens: List of nbest translations as lists of tokens.
nbest_scores: List of nbest scores, one for each nbest translation.
factor_translations: List of factor outputs.
factor_tokens: List of list of secondary factor tokens.
"""
    __slots__ = ('sentence_id',
                 'translation',
                 'tokens',
                 'score',
                 'pass_through_dict',
                 'beam_histories',
                 'nbest_translations',
                 'nbest_tokens',
                 'nbest_scores',
                 'factor_translations',
                 'factor_tokens',
                 'nbest_factor_translations',
                 'nbest_factor_tokens')

    def __init__(self,
                 sentence_id: SentenceId,
                 translation: str,
                 tokens: Tokens,
                 score: float,
                 pass_through_dict: Optional[Dict[str, Any]] = None,
                 beam_histories: Optional[List[BeamHistory]] = None,
                 nbest_translations: Optional[List[str]] = None,
                 nbest_tokens: Optional[List[Tokens]] = None,
                 nbest_scores: Optional[List[float]] = None,
                 factor_translations: Optional[List[str]] = None,
                 factor_tokens: Optional[List[Tokens]] = None,
                 nbest_factor_translations: Optional[List[List[str]]] = None,
                 nbest_factor_tokens: Optional[List[List[Tokens]]] = None) -> None:
        self.sentence_id = sentence_id
        self.translation = translation
        self.tokens = tokens
        self.score = score
        self.pass_through_dict = copy.deepcopy(pass_through_dict) if pass_through_dict else {}
        self.beam_histories = beam_histories
        self.nbest_translations = nbest_translations
        self.nbest_tokens = nbest_tokens
        self.nbest_scores = nbest_scores
        self.factor_translations = factor_translations
        self.factor_tokens = factor_tokens
        self.nbest_factor_translations = nbest_factor_translations
        self.nbest_factor_tokens = nbest_factor_tokens
    sentence_id: SentenceId
    translation: str
    tokens: Tokens
    score: float
    pass_through_dict: Optional[Dict[str, Any]] = None
    beam_histories: Optional[List[BeamHistory]] = None
    nbest_translations: Optional[List[str]] = None
    nbest_tokens: Optional[List[Tokens]] = None
    nbest_scores: Optional[List[float]] = None
    factor_translations: Optional[List[str]] = None
    factor_tokens: Optional[List[Tokens]] = None
    nbest_factor_translations: Optional[List[List[str]]] = None
    nbest_factor_tokens: Optional[List[List[Tokens]]] = None

    def json(self) -> Dict:
        """
@@ -458,7 +411,7 @@ def json(self) -> Dict:
        :return: A dictionary.
        """
        _d = self.pass_through_dict  # type: Dict[str, Any]
        _d = copy.deepcopy(self.pass_through_dict) if self.pass_through_dict is not None else {}  # type: Dict[str, Any]
        _d['sentence_id'] = self.sentence_id
        _d['translation'] = self.translation
        _d['score'] = self.score
@@ -480,35 +433,19 @@
        return _d
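
With the copy in place, `json()` no longer pollutes the stored `pass_through_dict` with output keys, and repeated calls stay consistent. A hedged usage sketch (field values invented; assumes `TranslatorOutput` is imported from `sockeye.inference`):

from sockeye.inference import TranslatorOutput

output = TranslatorOutput(sentence_id=0,
                          translation='hallo welt',
                          tokens=['hallo', 'welt'],
                          score=-1.2,
                          pass_through_dict={'client_id': 'abc'})
d = output.json()
# d carries the pass-through key plus Sockeye's own output fields.
assert d['client_id'] == 'abc' and d['translation'] == 'hallo welt'
# The stored dict is untouched, so a second call returns the same result.
assert output.pass_through_dict == {'client_id': 'abc'}
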


@dataclass
class NBestTranslations:
    __slots__ = ('target_ids_list',
                 'scores')

    def __init__(self,
                 target_ids_list: List[TokenIds],
                 scores: List[float]) -> None:
        self.target_ids_list = target_ids_list
        self.scores = scores
    target_ids_list: List[TokenIds]
    scores: List[float]


@dataclass
class Translation:
    __slots__ = ('target_ids',
                 'score',
                 'beam_histories',
                 'nbest_translations',
                 'estimated_reference_length')

    def __init__(self,
                 target_ids: TokenIds,
                 score: float,
                 beam_histories: List[BeamHistory] = None,
                 nbest_translations: NBestTranslations = None,
                 estimated_reference_length: Optional[float] = None) -> None:
        self.target_ids = target_ids
        self.score = score
        self.beam_histories = beam_histories if beam_histories is not None else []
        self.nbest_translations = nbest_translations
        self.estimated_reference_length = estimated_reference_length
    target_ids: TokenIds
    score: float
    beam_histories: Optional[List[BeamHistory]] = None
    nbest_translations: Optional[NBestTranslations] = None
    estimated_reference_length: Optional[float] = None
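
One behavioural nuance of this particular conversion: the old `__init__` normalized `beam_histories=None` to `[]`, while the dataclass stores whatever is passed. If the old normalization were wanted, `dataclasses.field(default_factory=...)` provides it; a minimal sketch on an invented class (not part of the commit):

from dataclasses import dataclass, field
from typing import List


@dataclass
class TranslationSketch:
    target_ids: List[int]
    score: float
    # default_factory creates a fresh list per instance, mirroring the old
    # `beam_histories if beam_histories is not None else []` normalization
    # (a bare `= []` default would be rejected as a mutable default).
    beam_histories: List[list] = field(default_factory=list)


t = TranslationSketch(target_ids=[1, 2], score=-0.5)
assert t.beam_histories == []
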


def empty_translation(add_nbest: bool = False) -> Translation:
@@ -522,32 +459,32 @@ def empty_translation(add_nbest: bool = False) -> Translation:
                       nbest_translations=NBestTranslations([], []) if add_nbest else None)


IndexedTranslatorInput = NamedTuple('IndexedTranslatorInput', [
    ('input_idx', int),
    ('chunk_idx', int),
    ('translator_input', TranslatorInput)
])
"""
Translation of a chunk of a sentence.
@dataclass
class IndexedTranslatorInput:
    """
    Translation of a chunk of a sentence.
:param input_idx: Internal index of translation requests to keep track of the correct order of translations.
:param chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks.
:param input: The translator input.
"""
    input_idx: Internal index of translation requests to keep track of the correct order of translations.
    chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks.
    translator_input: The translator input.
"""
input_idx: int
chunk_idx: int
translator_input: TranslatorInput


IndexedTranslation = NamedTuple('IndexedTranslation', [
    ('input_idx', int),
    ('chunk_idx', int),
    ('translation', Translation)
])
"""
Translation of a chunk of a sentence.
@dataclass(order=True)
class IndexedTranslation:
    """
    Translation of a chunk of a sentence.
:param input_idx: Internal index of translation requests to keep track of the correct order of translations.
:param chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks.
:param translation: The translation of the input chunk.
"""
    input_idx: Internal index of translation requests to keep track of the correct order of translations.
    chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks.
    translation: The translation of the input chunk.
    """
    input_idx: int
    chunk_idx: int
    translation: Translation
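
`IndexedTranslation` gains `order=True`, which generates comparison methods over the fields in declaration order, so `(input_idx, chunk_idx)` effectively becomes the sort key when reassembling chunked translations. A sketch of the effect on an invented stand-in class (the `payload` field stands in for the `Translation`):

from dataclasses import dataclass


@dataclass(order=True)
class IndexedItem:
    input_idx: int
    chunk_idx: int
    payload: str


items = [IndexedItem(1, 0, 'b0'), IndexedItem(0, 1, 'a1'), IndexedItem(0, 0, 'a0')]
# order=True compares field tuples in declaration order, so sorting
# restores (input_idx, chunk_idx) order without an explicit key function.
assert sorted(items) == [IndexedItem(0, 0, 'a0'), IndexedItem(0, 1, 'a1'),
                         IndexedItem(1, 0, 'b0')]
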


def _concat_nbest_translations(translations: List[Translation],
