Skip to content

Commit

Permalink
data_io: make BucketBatchSize a dataclass (#966)
Browse files Browse the repository at this point in the history
  • Loading branch information
fhieber committed Sep 23, 2021
1 parent 4730ca8 commit a0bc3f0
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions sockeye/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,11 @@ def get_bucket(seq_len: int, buckets: List[int]) -> Optional[int]:
return buckets[bucket_idx]


@dataclass
class BucketBatchSize:
"""
:param bucket: The corresponding bucket.
:param batch_size: Number of sequences in each batch.
:param average_target_words_per_batch: Approximate number of target non-padding tokens in each batch.
"""

def __init__(self, bucket: Tuple[int, int], batch_size: int, average_target_words_per_batch: float) -> None:
self.bucket = bucket
self.batch_size = batch_size
self.average_target_words_per_batch = average_target_words_per_batch
bucket: Tuple[int, int] # The corresponding bucket.
batch_size: int # Number of sequences in each batch.
average_target_words_per_batch: float # Approximate number of target non-padding tokens in each batch.


def define_bucket_batch_sizes(buckets: List[Tuple[int, int]],
Expand Down

0 comments on commit a0bc3f0

Please sign in to comment.