Update batch_size calculation in keras autolog (mlflow#11224)
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
Signed-off-by: Arthur Jenoudet <arthur.jenoudet@databricks.com>
serena-ruan authored and artjen committed Mar 26, 2024
1 parent 366e6d9 commit 3f9b70d
Showing 2 changed files with 28 additions and 13 deletions.
39 changes: 27 additions & 12 deletions mlflow/keras/autologging.py
@@ -14,6 +14,7 @@
from mlflow.keras.save import log_model
from mlflow.keras.utils import get_model_signature
from mlflow.tracking.context import registry as context_registry
from mlflow.utils import is_iterator
from mlflow.utils.annotations import experimental
from mlflow.utils.autologging_utils import (
    PatchFunction,
@@ -26,17 +27,31 @@
_logger = logging.getLogger(__name__)


def _infer_batch_size(*keras_fit_args, **keras_fit_kwargs):
    if "batch_size" in keras_fit_kwargs:
        return keras_fit_kwargs["batch_size"]

    training_data = keras_fit_kwargs["x"] if "x" in keras_fit_kwargs else keras_fit_args[0]
    batch_size = getattr(training_data, "batch_size", None) or getattr(
        training_data, "_batch_size", None
    )
    if batch_size:
        return batch_size
    return None
def _infer_batch_size(inst, *args, **kwargs):
    batch_size = None
    if "batch_size" in kwargs:
        batch_size = kwargs["batch_size"]
    else:
        training_data = kwargs["x"] if "x" in kwargs else args[0]
        if _batch_size := getattr(training_data, "batch_size", None):
            batch_size = _batch_size
        elif _batch_size := getattr(training_data, "_batch_size", None):
            batch_size = _batch_size if isinstance(_batch_size, int) else _batch_size.numpy()
        elif is_iterator(training_data):
            is_single_input_model = isinstance(inst.input_shape, tuple)
            peek = next(training_data)
            batch_size = len(peek[0]) if is_single_input_model else len(peek[0][0])

            def _restore_generator(prev_generator):
                yield peek
                yield from prev_generator

            restored_generator = _restore_generator(training_data)
            if "x" in kwargs:
                kwargs["x"] = restored_generator
            else:
                args = (restored_generator,) + args[1:]
    return batch_size, args, kwargs


def _check_existing_mlflow_callback(callbacks):
@@ -212,7 +227,7 @@ def __init__(self):
    def _patch_implementation(self, original, inst, *args, **kwargs):
        unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"]

        batch_size = _infer_batch_size(*args, **kwargs)
        batch_size, args, kwargs = _infer_batch_size(inst, *args, **kwargs)

        if batch_size is not None:
            mlflow.log_param("batch_size", batch_size)
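For illustration only (not part of the commit): a minimal sketch of the peek-and-restore pattern the new _infer_batch_size uses when the training data is a plain generator. Peeking with next() consumes the first batch, so the helper re-yields it ahead of the remaining batches before handing the data back to fit(). The batches generator and its shapes below are hypothetical.

import numpy as np

def batches():
    # Hypothetical data generator yielding (features, labels) pairs.
    for _ in range(3):
        yield np.zeros((32, 10)), np.zeros((32, 1))

data = batches()
peek = next(data)          # consumes the first batch
batch_size = len(peek[0])  # 32 for a single-input model

def _restore_generator(prev_generator):
    yield peek                   # re-yield the peeked batch first
    yield from prev_generator    # then the untouched remainder

restored = _restore_generator(data)
assert batch_size == 32
assert sum(1 for _ in restored) == 3  # fit() would still see every batch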
2 changes: 1 addition & 1 deletion mlflow/ml-package-versions.yml
@@ -109,7 +109,7 @@ tensorflow:
maximum: "2.15.0.post1"
requirements:
# Requirements to run tests for keras
">= 0.0.0": ["scikit-learn", "pyspark", "pyarrow", "transformers"]
">= 0.0.0": ["scikit-learn", "pyspark", "pyarrow", "transformers!=4.38.0,!=4.38.1"]
"< 2.7.0": ["pandas>=1.3.5,<2.0"]
">= 2.7.0": ["pandas<2.0"]
# TensorFlow == 2.6.5 are incompatible with SQLAlchemy 2.x due to
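As a side note (not part of the commit): the updated requirement string for the keras tests is a standard PEP 440 specifier that excludes only the two pinned-out transformers releases. A quick sketch checking it, assuming the packaging library is available:

from packaging.specifiers import SpecifierSet

spec = SpecifierSet("!=4.38.0,!=4.38.1")
assert "4.38.0" not in spec  # excluded
assert "4.38.1" not in spec  # excluded
assert "4.38.2" in spec      # any other release is still allowed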
