diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py index cc41957..1f1b660 100644 --- a/recml/layers/linen/sparsecore.py +++ b/recml/layers/linen/sparsecore.py @@ -276,10 +276,12 @@ class SparsecorePreprocessor: sparsecore_config: The sparsecore config used to create the tables. global_batch_size: The global batch size across all devices to partition the inputs across. + _batch_number: The batch number for preprocessing, incremented on each call. """ sparsecore_config: SparsecoreConfig global_batch_size: int + _batch_number: int = dataclasses.field(init=False, default=0) def __post_init__(self): self.sparsecore_config.init_feature_specs(self.global_batch_size) @@ -328,6 +330,7 @@ def _to_np(x: Any) -> np.ndarray: if weights[key] is not None: weights[key] = np.reshape(weights[key], (-1, 1)) + self._batch_number += 1 csr_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( features=features, features_weights=weights, @@ -337,6 +340,7 @@ def _to_np(x: Any) -> np.ndarray: num_sc_per_device=self.sparsecore_config.num_sc_per_device, sharding_strategy=self.sparsecore_config.sharding_strategy, allow_id_dropping=False, + batch_number=self._batch_number, ) processed_inputs = {