Skip to content

Commit

Permalink
Complete RST -> Google style migration (mlflow#11242)
Browse files Browse the repository at this point in the history
Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: Arthur Jenoudet <arthur.jenoudet@databricks.com>
  • Loading branch information
harupy authored and artjen committed Mar 26, 2024
1 parent ea9fac1 commit 2a6602d
Show file tree
Hide file tree
Showing 17 changed files with 137 additions and 123 deletions.
23 changes: 0 additions & 23 deletions dev/clint/src/clint/linter.py
Expand Up @@ -75,26 +75,6 @@ def json(self) -> dict[str, str | int | None]:
"Builtin modules must be imported at the top level.",
)

# TODO: Remove this once we convert all docstrings to Google style.
NO_RST_IGNORE = {
"mlflow/gateway/client.py",
"mlflow/gateway/providers/utils.py",
"mlflow/keras/callback.py",
"mlflow/metrics/base.py",
"mlflow/metrics/genai/base.py",
"mlflow/models/utils.py",
"mlflow/projects/databricks.py",
"mlflow/projects/kubernetes.py",
"mlflow/store/_unity_catalog/registry/rest_store.py",
"mlflow/store/artifact/azure_data_lake_artifact_repo.py",
"mlflow/store/artifact/gcs_artifact_repo.py",
"mlflow/store/model_registry/rest_store.py",
"mlflow/store/tracking/rest_store.py",
"mlflow/utils/docstring_utils.py",
"mlflow/utils/rest_utils.py",
"tests/utils/test_docstring_utils.py",
}


class Linter(ast.NodeVisitor):
def __init__(self, path: str, ignore: dict[str, set[int]]):
Expand All @@ -120,9 +100,6 @@ def _docstring(
return None

def _no_rst(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
if self.path in NO_RST_IGNORE:
return

if (nd := self._docstring(node)) and (
PARAM_REGEX.search(nd.s) or RETURN_REGEX.search(nd.s)
):
Expand Down
7 changes: 4 additions & 3 deletions mlflow/gateway/client.py
Expand Up @@ -35,9 +35,10 @@ class MlflowGatewayClient:
"""
Client for interacting with the MLflow Gateway API.
:param gateway_uri: Optional URI of the gateway. If not provided, attempts to resolve from
first the stored result of `set_gateway_uri()`, then the environment variable
`MLFLOW_GATEWAY_URI`.
Args:
gateway_uri: Optional URI of the gateway. If not provided, attempts to resolve from
first the stored result of `set_gateway_uri()`, then the environment variable
`MLFLOW_GATEWAY_URI`.
"""

def __init__(self, gateway_uri: Optional[str] = None):
Expand Down
35 changes: 23 additions & 12 deletions mlflow/gateway/providers/utils.py
Expand Up @@ -22,12 +22,17 @@ async def send_request(headers: Dict[str, str], base_url: str, path: str, payloa
"""
Send an HTTP request to a specific URL path with given headers and payload.
:param headers: The headers to include in the request.
:param base_url: The base URL where the request will be sent.
:param path: The specific path of the URL to which the request will be sent.
:param payload: The payload (or data) to be included in the request.
:return: The server's response as a JSON object.
:raise: HTTPException if the HTTP request fails.
Args:
headers: The headers to include in the request.
base_url: The base URL where the request will be sent.
path: The specific path of the URL to which the request will be sent.
payload: The payload (or data) to be included in the request.
Returns:
The server's response as a JSON object.
Raises:
HTTPException: If the HTTP request fails.
"""
from fastapi import HTTPException

Expand Down Expand Up @@ -56,12 +61,18 @@ async def send_stream_request(
) -> AsyncGenerator[bytes, None]:
"""
Send an HTTP request to a specific URL path with given headers and payload.
:param headers: The headers to include in the request.
:param base_url: The base URL where the request will be sent.
:param path: The specific path of the URL to which the request will be sent.
:param payload: The payload (or data) to be included in the request.
:return: The server's response as a JSON object.
:raise: HTTPException if the HTTP request fails.
Args:
headers: The headers to include in the request.
base_url: The base URL where the request will be sent.
path: The specific path of the URL to which the request will be sent.
payload: The payload (or data) to be included in the request.
Returns:
The server's response as a JSON object.
Raises:
HTTPException: If the HTTP request fails.
"""
async with _aiohttp_post(headers, base_url, path, payload) as response:
async for line in response.content:
Expand Down
9 changes: 5 additions & 4 deletions mlflow/keras/callback.py
Expand Up @@ -14,10 +14,11 @@ class MLflowCallback(keras.callbacks.Callback, metaclass=ExceptionSafeClass):
This callback logs model metadata at training begins, and logs training metrics every epoch or
every n steps (defined by the user) to MLflow.
:param log_every_epoch: bool, defaults to True. If True, log metrics every epoch. If False,
log metrics every n steps.
:param log_every_n_steps: int, defaults to None. If set, log metrics every n steps. If None,
log metrics every epoch. Must be `None` if `log_every_epoch=True`.
Args:
log_every_epoch: bool, defaults to True. If True, log metrics every epoch. If False,
log metrics every n steps.
log_every_n_steps: int, defaults to None. If set, log metrics every n steps. If None,
log metrics every epoch. Must be `None` if `log_every_epoch=True`.
.. code-block:: python
:caption: Example
Expand Down
8 changes: 5 additions & 3 deletions mlflow/metrics/base.py
Expand Up @@ -21,9 +21,11 @@ class MetricValue:
"""
The value of a metric.
:param scores: The value of the metric per row
:param justifications: The justification (if applicable) for the respective score
:param aggregate_results: A dictionary mapping the name of the aggregation to its value
Args:
scores: The value of the metric per row
justifications: The justification (if applicable) for the respective score
aggregate_results: A dictionary mapping the name of the aggregation to its value
"""

scores: Union[List[str], List[float]] = None
Expand Down
15 changes: 8 additions & 7 deletions mlflow/metrics/genai/base.py
Expand Up @@ -11,13 +11,14 @@ class EvaluationExample:
"""
Stores a sample example used for few-shot learning during LLM evaluation
:param input: The input provided to the model
:param output: The output generated by the model
:param score: The score given by the evaluator
:param justification: The justification given by the evaluator
:param grading_context: The grading_context provided to the evaluator for evaluation. Either
a dictionary of grading context column names and grading context strings
or a single grading context string.
Args:
input: The input provided to the model
output: The output generated by the model
score: The score given by the evaluator
justification: The justification given by the evaluator
grading_context: The grading_context provided to the evaluator for evaluation. Either
a dictionary of grading context column names and grading context strings
or a single grading context string.
.. code-block:: python
:caption: Example for creating an EvaluationExample
Expand Down
17 changes: 10 additions & 7 deletions mlflow/models/utils.py
Expand Up @@ -1016,13 +1016,16 @@ def _enforce_pyspark_dataframe_schema(
DataFrame that are declared in the model's input schema. Any extra columns in the original
DataFrame are dropped. Note that this function does not modify the original DataFrame.
:param original_pf_input: Original input PySpark DataFrame.
:param pf_input_as_pandas: Input DataFrame converted to pandas.
:param input_schema: Expected schema of the input DataFrame.
:param flavor: Optional model flavor. If specified, it is used to handle specific behaviors
for different model flavors. Currently, only the '_FEATURE_STORE_FLAVOR' is
handled specially.
:return: New PySpark DataFrame that conforms to the model's input schema.
Args:
original_pf_input: Original input PySpark DataFrame.
pf_input_as_pandas: Input DataFrame converted to pandas.
input_schema: Expected schema of the input DataFrame.
flavor: Optional model flavor. If specified, it is used to handle specific behaviors
for different model flavors. Currently, only the '_FEATURE_STORE_FLAVOR' is
handled specially.
Returns:
New PySpark DataFrame that conforms to the model's input schema.
"""
if not HAS_PYSPARK:
raise MlflowException("PySpark is not installed. Cannot handle a PySpark DataFrame.")
Expand Down
14 changes: 9 additions & 5 deletions mlflow/projects/databricks.py
Expand Up @@ -85,7 +85,9 @@ def before_run_validations(tracking_uri, backend_config):
class DatabricksJobRunner:
"""
Helper class for running an MLflow project as a Databricks Job.
:param databricks_profile: Optional Databricks CLI profile to use to fetch hostname &
Args:
databricks_profile: Optional Databricks CLI profile to use to fetch hostname &
authentication information when making Databricks API requests.
"""

Expand Down Expand Up @@ -404,10 +406,12 @@ class DatabricksSubmittedRun(SubmittedRun):
Instance of SubmittedRun corresponding to a Databricks Job run launched to run an MLflow
project. Note that run_id may be None, e.g. if we did not launch the run against a tracking
server accessible to the local client.
:param databricks_run_id: Run ID of the launched Databricks Job.
:param mlflow_run_id: ID of the MLflow project run.
:param databricks_job_runner: Instance of ``DatabricksJobRunner`` used to make Databricks API
requests.
Args:
databricks_run_id: Run ID of the launched Databricks Job.
mlflow_run_id: ID of the MLflow project run.
databricks_job_runner: Instance of ``DatabricksJobRunner`` used to make Databricks API
requests.
"""

# How often to poll run status when waiting on a run
Expand Down
8 changes: 5 additions & 3 deletions mlflow/projects/kubernetes.py
Expand Up @@ -92,9 +92,11 @@ class KubernetesSubmittedRun(SubmittedRun):
"""
Instance of SubmittedRun corresponding to a Kubernetes Job run launched to run an MLflow
project.
:param mlflow_run_id: ID of the MLflow project run.
:param job_name: Kubernetes job name.
:param job_namespace: Kubernetes job namespace.
Args:
mlflow_run_id: ID of the MLflow project run.
job_name: Kubernetes job name.
job_namespace: Kubernetes job namespace.
"""

# How often to poll run status when waiting on a run
Expand Down
9 changes: 5 additions & 4 deletions mlflow/store/_unity_catalog/registry/rest_store.py
Expand Up @@ -207,10 +207,11 @@ class UcModelRegistryStore(BaseRestStore):
"""
Client for a remote model registry server accessed via REST API calls
:param store_uri: URI with scheme 'databricks-uc'
:param tracking_uri: URI of the Databricks MLflow tracking server from which to fetch
run info and download run artifacts, when creating new model
versions from source artifacts logged to an MLflow run.
Args:
store_uri: URI with scheme 'databricks-uc'
tracking_uri: URI of the Databricks MLflow tracking server from which to fetch
run info and download run artifacts, when creating new model
versions from source artifacts logged to an MLflow run.
"""

def __init__(self, store_uri, tracking_uri):
Expand Down
5 changes: 3 additions & 2 deletions mlflow/store/artifact/azure_data_lake_artifact_repo.py
Expand Up @@ -72,8 +72,9 @@ class AzureDataLakeArtifactRepository(CloudArtifactRepository):
This repository is used with URIs of the form
``abfs[s]://file_system@account_name.dfs.core.windows.net/<path>/<path>``.
:param credential: Azure credential (see options in https://learn.microsoft.com/en-us/python/api/azure-core/azure.core.credentials?view=azure-python)
to use to authenticate to storage
Args:
credential: Azure credential (see options in https://learn.microsoft.com/en-us/python/api/azure-core/azure.core.credentials?view=azure-python)
to use to authenticate to storage
"""

def __init__(self, artifact_uri, credential):
Expand Down
9 changes: 5 additions & 4 deletions mlflow/store/artifact/gcs_artifact_repo.py
Expand Up @@ -29,10 +29,11 @@ class GCSArtifactRepository(ArtifactRepository, MultipartUploadMixin):
"""
Stores artifacts on Google Cloud Storage.
:param artifact_uri: URI of GCS bucket
:param client: Optional. The client to use for GCS operations; a default
client object will be created if unspecified, using default
credentials as described in https://google-cloud.readthedocs.io/en/latest/core/auth.html
Args:
artifact_uri: URI of GCS bucket
client: Optional. The client to use for GCS operations; a default
client object will be created if unspecified, using default
credentials as described in https://google-cloud.readthedocs.io/en/latest/core/auth.html
"""

def __init__(self, artifact_uri, client=None):
Expand Down
7 changes: 4 additions & 3 deletions mlflow/store/model_registry/rest_store.py
Expand Up @@ -44,9 +44,10 @@ class RestStore(BaseRestStore):
"""
Client for a remote model registry server accessed via REST API calls
:param get_host_creds: Method to be invoked prior to every REST request to get the
:py:class:`mlflow.rest_utils.MlflowHostCreds` for the request. Note that this
is a function so that we can obtain fresh credentials in the case of expiry.
Args:
get_host_creds: Method to be invoked prior to every REST request to get the
:py:class:`mlflow.rest_utils.MlflowHostCreds` for the request. Note that this
is a function so that we can obtain fresh credentials in the case of expiry.
"""

def _get_response_from_method(self, method):
Expand Down
7 changes: 4 additions & 3 deletions mlflow/store/tracking/rest_store.py
Expand Up @@ -44,9 +44,10 @@ class RestStore(AbstractStore):
"""
Client for a remote tracking server accessed via REST API calls
:param get_host_creds: Method to be invoked prior to every REST request to get the
:py:class:`mlflow.rest_utils.MlflowHostCreds` for the request. Note that this
is a function so that we can obtain fresh credentials in the case of expiry.
Args:
get_host_creds: Method to be invoked prior to every REST request to get the
:py:class:`mlflow.rest_utils.MlflowHostCreds` for the request. Note that this
is a function so that we can obtain fresh credentials in the case of expiry.
"""

def __init__(self, get_host_creds):
Expand Down
17 changes: 10 additions & 7 deletions mlflow/utils/docstring_utils.py
Expand Up @@ -108,8 +108,9 @@ def format_docstring(self, docstring: str) -> str:
>>> pd = ParamDocs(p1="doc1", p2="doc2
doc2 second line")
>>> docstring = '''
... :param p1: {{ p1 }}
... :param p2: {{ p2 }}
... Args:
... p1: {{ p1 }}
... p2: {{ p2 }}
... '''.strip()
>>> print(pd.format_docstring(docstring))
"""
Expand Down Expand Up @@ -143,14 +144,16 @@ def format_docstring(param_docs):
>>> @format_docstring(param_docs)
... def func(p1, p2):
... '''
... :param p1: {{ p1 }}
... :param p2: {{ p2 }}
... Args:
... p1: {{ p1 }}
... p2: {{ p2 }}
... '''
>>> import textwrap
>>> print(textwrap.dedent(func.__doc__).strip())
:param p1: doc1
:param p2: doc2
doc2 second line
Args:
p1: doc1
p2: doc2
doc2 second line
"""
param_docs = ParamDocs(param_docs)

Expand Down
48 changes: 25 additions & 23 deletions mlflow/utils/rest_utils.py
Expand Up @@ -239,29 +239,31 @@ def call_endpoints(host_creds, endpoints, json_body, response_proto, extra_heade
class MlflowHostCreds:
"""
Provides a hostname and optional authentication for talking to an MLflow tracking server.
:param host: Hostname (e.g., http://localhost:5000) to MLflow server. Required.
:param username: Username to use with Basic authentication when talking to server.
If this is specified, password must also be specified.
:param password: Password to use with Basic authentication when talking to server.
If this is specified, username must also be specified.
:param token: Token to use with Bearer authentication when talking to server.
If provided, user/password authentication will be ignored.
:param aws_sigv4: If true, we will create a signature V4 to be added for any outgoing request.
Keys for signing the request can be passed via ENV variables,
or will be fetched via boto3 session.
:param auth: If set, the auth will be added for any outgoing request.
Keys for signing the request can be passed via ENV variables,
:param ignore_tls_verification: If true, we will not verify the server's hostname or TLS
certificate. This is useful for certain testing situations, but should never be
true in production.
If this is set to true ``server_cert_path`` must not be set.
:param client_cert_path: Path to ssl client cert file (.pem).
Sets the cert param of the ``requests.request``
function (see https://requests.readthedocs.io/en/master/api/).
:param server_cert_path: Path to a CA bundle to use.
Sets the verify param of the ``requests.request``
function (see https://requests.readthedocs.io/en/master/api/).
If this is set ``ignore_tls_verification`` must be false.
Args:
host: Hostname (e.g., http://localhost:5000) to MLflow server. Required.
username: Username to use with Basic authentication when talking to server.
If this is specified, password must also be specified.
password: Password to use with Basic authentication when talking to server.
If this is specified, username must also be specified.
token: Token to use with Bearer authentication when talking to server.
If provided, user/password authentication will be ignored.
aws_sigv4: If true, we will create a signature V4 to be added for any outgoing request.
Keys for signing the request can be passed via ENV variables,
or will be fetched via boto3 session.
auth: If set, the auth will be added for any outgoing request.
Keys for signing the request can be passed via ENV variables.
ignore_tls_verification: If true, we will not verify the server's hostname or TLS
certificate. This is useful for certain testing situations, but should never be
true in production.
If this is set to true ``server_cert_path`` must not be set.
client_cert_path: Path to ssl client cert file (.pem).
Sets the cert param of the ``requests.request``
function (see https://requests.readthedocs.io/en/master/api/).
server_cert_path: Path to a CA bundle to use.
Sets the verify param of the ``requests.request``
function (see https://requests.readthedocs.io/en/master/api/).
If this is set ``ignore_tls_verification`` must be false.
"""

def __init__(
Expand Down

0 comments on commit 2a6602d

Please sign in to comment.