Add code snippet formatting in docstrings via Ruff (#36262)
This was made available [as part of v0.1.8 of the Ruff Formatter](https://astral.sh/blog/ruff-v0.1.8#formatting-code-snippets-in-docstrings). This commit enables the corresponding config option for the `ruff-format` pre-commit hook; a short illustration of the effect follows the change summary below.
josh-fell committed Dec 17, 2023 · 1 parent f7f7183 · commit e9ba37b
Showing 36 changed files with 272 additions and 303 deletions.
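
The linked Ruff release teaches the formatter to also reformat code embedded in docstrings (doctests and literal code blocks), which is what produces the quote and line-wrapping churn in the hunks below. Per the release notes, the enabling setting appears to be `docstring-code-format = true` under `[tool.ruff.format]`; the configuration change itself presumably sits in one of the 36 changed files but is not among the hunks excerpted here. As a rough sketch of the effect, using a hypothetical `add_days` helper rather than anything from this commit:

    # Hypothetical example, not taken from this commit.
    #
    # Before formatting, the doctest call line used single quotes:
    #
    #     >>> add_days('2015-01-01', 5)
    #     '2015-01-06'
    #
    # Running `ruff format` with docstring code formatting enabled rewrites the
    # call line to the formatter's preferred double quotes. The expected-output
    # line keeps its single quotes because it is program output, not code, which
    # matches the read_block and ds_add hunks below.
    from datetime import datetime, timedelta


    def add_days(ds: str, days: int) -> str:
        """Add ``days`` days to a ``YYYY-MM-DD`` date string.

        >>> add_days("2015-01-01", 5)
        '2015-01-06'
        """
        return (datetime.strptime(ds, "%Y-%m-%d") + timedelta(days=days)).strftime("%Y-%m-%d")

Locally, the same rewrite can presumably be reproduced with `ruff format .`, or, since the commit message refers to a `ruff-format` pre-commit hook, with `pre-commit run ruff-format --all-files`.
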
4 changes: 2 additions & 2 deletions airflow/io/path.py
@@ -263,11 +263,11 @@ def read_block(self, offset: int, length: int, delimiter=None):
--------
>>> read_block(0, 13)
b'Alice, 100\\nBo'
>>> read_block(0, 13, delimiter=b'\\n')
>>> read_block(0, 13, delimiter=b"\\n")
b'Alice, 100\\nBob, 200\\n'
Use ``length=None`` to read to the end of the file.
>>> read_block(0, None, delimiter=b'\\n')
>>> read_block(0, None, delimiter=b"\\n")
b'Alice, 100\\nBob, 200\\nCharlie, 300'
See Also
8 changes: 4 additions & 4 deletions airflow/macros/__init__.py
@@ -49,9 +49,9 @@ def ds_add(ds: str, days: int) -> str:
:param ds: anchor date in ``YYYY-MM-DD`` format to add to
:param days: number of days to add to the ds, you can use negative values
>>> ds_add('2015-01-01', 5)
>>> ds_add("2015-01-01", 5)
'2015-01-06'
>>> ds_add('2015-01-06', -5)
>>> ds_add("2015-01-06", -5)
'2015-01-01'
"""
if not days:
@@ -68,9 +68,9 @@ def ds_format(ds: str, input_format: str, output_format: str) -> str:
:param input_format: input string format. E.g. %Y-%m-%d
:param output_format: output string format E.g. %Y-%m-%d
>>> ds_format('2015-01-01', "%Y-%m-%d", "%m-%d-%y")
>>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y")
'01-01-15'
>>> ds_format('1/5/2015', "%m/%d/%Y", "%Y-%m-%d")
>>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d")
'2015-01-05'
"""
return datetime.strptime(str(ds), input_format).strftime(output_format)
18 changes: 3 additions & 15 deletions airflow/models/baseoperator.py
@@ -628,12 +628,7 @@ class derived from this one results in the creation of a task object,
**Example**: to run this task in a specific docker container through
the KubernetesExecutor ::
MyOperator(...,
executor_config={
"KubernetesExecutor":
{"image": "myCustomDockerImage"}
}
)
MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}})
:param do_xcom_push: if True, an XCom is pushed containing the Operator's
result
@@ -1152,9 +1147,7 @@ def set_xcomargs_dependencies(self) -> None:
# This is equivalent to
with DAG(...):
generate_content = GenerateContentOperator(task_id="generate_content")
send_email = EmailOperator(
..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
)
send_email = EmailOperator(..., html_content="{{ task_instance.xcom_pull('generate_content') }}")
generate_content >> send_email
"""
@@ -1866,12 +1859,7 @@ def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
Then you can accomplish like so::
chain_linear(
op1,
[op2, op3],
[op4, op5, op6],
op7
)
chain_linear(op1, [op2, op3], [op4, op5, op6], op7)
:param elements: a list of operators / lists of operators
"""
7 changes: 4 additions & 3 deletions airflow/models/dag.py
@@ -381,11 +381,12 @@ class DAG(LoggingMixin):
**Example**: to avoid Jinja from removing a trailing newline from template strings ::
DAG(dag_id='my-dag',
DAG(
dag_id="my-dag",
jinja_environment_kwargs={
'keep_trailing_newline': True,
"keep_trailing_newline": True,
# some other jinja2 Environment options here
}
},
)
**See**: `Jinja Environment documentation
4 changes: 4 additions & 0 deletions airflow/models/taskinstance.py
@@ -3319,19 +3319,23 @@ def get_relevant_upstream_map_indexes(
def this_task(v): # This is self.task.
return v * 2
@task_group
def tg1(inp):
val = upstream(inp) # This is the upstream task.
this_task(val) # When inp is 1, val here should resolve to 2.
return val
# This val is the same object returned by tg1.
val = tg1.expand(inp=[1, 2, 3])
@task_group
def tg2(inp):
another_task(inp, val) # val here should resolve to [2, 4, 6].
tg2.expand(inp=["a", "b"])
The surrounding mapped task groups of ``upstream`` and ``self.task`` are
4 changes: 2 additions & 2 deletions airflow/models/xcom_arg.py
@@ -59,8 +59,8 @@ class XComArg(ResolveMixin, DependencyMixin):
xcomarg >> op
xcomarg << op
op >> xcomarg # By BaseOperator code
op << xcomarg # By BaseOperator code
op >> xcomarg # By BaseOperator code
op << xcomarg # By BaseOperator code
**Example**: The moment you get a result from any operator (decorated or regular) you can ::
20 changes: 10 additions & 10 deletions airflow/providers/amazon/aws/operators/s3.py
@@ -762,11 +762,11 @@ class S3ListOperator(BaseOperator):
``customers/2018/04/`` key in the ``data`` bucket. ::
s3_file = S3ListOperator(
task_id='list_3s_files',
bucket='data',
prefix='customers/2018/04/',
delimiter='/',
aws_conn_id='aws_customers_conn'
task_id="list_3s_files",
bucket="data",
prefix="customers/2018/04/",
delimiter="/",
aws_conn_id="aws_customers_conn",
)
"""

@@ -843,11 +843,11 @@ class S3ListPrefixesOperator(BaseOperator):
from the S3 ``customers/2018/04/`` prefix in the ``data`` bucket. ::
s3_file = S3ListPrefixesOperator(
task_id='list_s3_prefixes',
bucket='data',
prefix='customers/2018/04/',
delimiter='/',
aws_conn_id='aws_customers_conn'
task_id="list_s3_prefixes",
bucket="data",
prefix="customers/2018/04/",
delimiter="/",
aws_conn_id="aws_customers_conn",
)
"""

13 changes: 5 additions & 8 deletions airflow/providers/amazon/aws/operators/sagemaker.py
@@ -404,14 +404,14 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
If you need to create a SageMaker endpoint based on an existed
SageMaker model and an existed SageMaker endpoint config::
config = endpoint_configuration;
config = endpoint_configuration
If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::
config = {
'Model': model_configuration,
'EndpointConfig': endpoint_config_configuration,
'Endpoint': endpoint_configuration
"Model": model_configuration,
"EndpointConfig": endpoint_config_configuration,
"Endpoint": endpoint_configuration,
}
For details of the configuration parameter of model_configuration see
@@ -579,10 +579,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
If you need to create both SageMaker model and SageMaker Transform job::
config = {
'Model': model_config,
'Transform': transform_config
}
config = {"Model": model_config, "Transform": transform_config}
For details of the configuration parameter of transform_config see
:py:meth:`SageMaker.Client.create_transform_job`
6 changes: 1 addition & 5 deletions airflow/providers/apache/cassandra/hooks/cassandra.py
@@ -48,11 +48,7 @@ class CassandraHook(BaseHook, LoggingMixin):
If SSL is enabled in Cassandra, pass in a dict in the extra field as kwargs for
``ssl.wrap_socket()``. For example::
{
'ssl_options' : {
'ca_certs' : PATH_TO_CA_CERTS
}
}
{"ssl_options": {"ca_certs": PATH_TO_CA_CERTS}}
Default load balancing policy is RoundRobinPolicy. To specify a different
LB policy::
10 changes: 6 additions & 4 deletions airflow/providers/apache/cassandra/sensors/record.py
@@ -38,10 +38,12 @@ class CassandraRecordSensor(BaseSensorOperator):
primary keys 'p1' and 'p2' to be populated in keyspace 'k' and table 't',
instantiate it as follows:
>>> cassandra_sensor = CassandraRecordSensor(table="k.t",
... keys={"p1": "v1", "p2": "v2"},
... cassandra_conn_id="cassandra_default",
... task_id="cassandra_sensor")
>>> cassandra_sensor = CassandraRecordSensor(
... table="k.t",
... keys={"p1": "v1", "p2": "v2"},
... cassandra_conn_id="cassandra_default",
... task_id="cassandra_sensor",
... )
:param table: Target Cassandra table.
Use dot notation to target a specific keyspace.
6 changes: 3 additions & 3 deletions airflow/providers/apache/cassandra/sensors/table.py
@@ -38,9 +38,9 @@ class CassandraTableSensor(BaseSensorOperator):
For example, if you want to wait for a table called 't' to be created
in a keyspace 'k', instantiate it as follows:
>>> cassandra_sensor = CassandraTableSensor(table="k.t",
... cassandra_conn_id="cassandra_default",
... task_id="cassandra_sensor")
>>> cassandra_sensor = CassandraTableSensor(
... table="k.t", cassandra_conn_id="cassandra_default", task_id="cassandra_sensor"
... )
:param table: Target Cassandra table.
Use dot notation to target a specific keyspace.
23 changes: 11 additions & 12 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -582,12 +582,11 @@ def check_for_partition(self, schema: str, table: str, partition: str) -> bool:
:param schema: Name of hive schema (database) @table belongs to
:param table: Name of hive table @partition belongs to
:param partition: Expression that matches the partitions to check for
(eg `a = 'b' AND c = 'd'`)
:param partition: Expression that matches the partitions to check for (e.g. `a = 'b' AND c = 'd'`)
>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> hh.check_for_partition('airflow', t, "ds='2015-01-01'")
>>> t = "static_babynames_partitioned"
>>> hh.check_for_partition("airflow", t, "ds='2015-01-01'")
True
"""
with self.metastore as client:
@@ -606,10 +605,10 @@ def check_for_named_partition(self, schema: str, table: str, partition_name: str
:param partition_name: Name of the partitions to check for (eg `a=b/c=d`)
>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> hh.check_for_named_partition('airflow', t, "ds=2015-01-01")
>>> t = "static_babynames_partitioned"
>>> hh.check_for_named_partition("airflow", t, "ds=2015-01-01")
True
>>> hh.check_for_named_partition('airflow', t, "ds=xxx")
>>> hh.check_for_named_partition("airflow", t, "ds=xxx")
False
"""
with self.metastore as client:
@@ -619,7 +618,7 @@ def get_table(self, table_name: str, db: str = "default") -> Any:
"""Get a metastore table object.
>>> hh = HiveMetastoreHook()
>>> t = hh.get_table(db='airflow', table_name='static_babynames')
>>> t = hh.get_table(db="airflow", table_name="static_babynames")
>>> t.tableName
'static_babynames'
>>> [col.name for col in t.sd.cols]
@@ -649,8 +648,8 @@ def get_partitions(self, schema: str, table_name: str, partition_filter: str | N
For subpartitioned table, the number might easily exceed this.
>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> parts = hh.get_partitions(schema='airflow', table_name=t)
>>> t = "static_babynames_partitioned"
>>> parts = hh.get_partitions(schema="airflow", table_name=t)
>>> len(parts)
1
>>> parts
@@ -765,9 +764,9 @@ def table_exists(self, table_name: str, db: str = "default") -> bool:
Check if table exists.
>>> hh = HiveMetastoreHook()
>>> hh.table_exists(db='airflow', table_name='static_babynames')
>>> hh.table_exists(db="airflow", table_name="static_babynames")
True
>>> hh.table_exists(db='airflow', table_name='does_not_exist')
>>> hh.table_exists(db="airflow", table_name="does_not_exist")
False
"""
try:
6 changes: 3 additions & 3 deletions airflow/providers/apache/hive/macros/hive.py
@@ -39,7 +39,7 @@ def max_partition(
:param field: the field to get the max value from. If there's only
one partition field, this will be inferred
>>> max_partition('airflow.static_babynames_partitioned')
>>> max_partition("airflow.static_babynames_partitioned")
'2015-01-01'
"""
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
@@ -94,8 +94,8 @@ def closest_ds_partition(
:param metastore_conn_id: which metastore connection to use
:returns: The closest date
>>> tbl = 'airflow.static_babynames_partitioned'
>>> closest_ds_partition(tbl, '2015-01-02')
>>> tbl = "airflow.static_babynames_partitioned"
>>> closest_ds_partition(tbl, "2015-01-02")
'2015-01-01'
"""
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
18 changes: 6 additions & 12 deletions airflow/providers/databricks/operators/databricks.py
@@ -580,26 +580,20 @@ class DatabricksRunNowOperator(BaseOperator):
For example ::
json = {
"job_id": 42,
"notebook_params": {
"dry-run": "true",
"oldest-time-to-consider": "1457570074236"
}
"job_id": 42,
"notebook_params": {"dry-run": "true", "oldest-time-to-consider": "1457570074236"},
}
notebook_run = DatabricksRunNowOperator(task_id='notebook_run', json=json)
notebook_run = DatabricksRunNowOperator(task_id="notebook_run", json=json)
Another way to accomplish the same thing is to use the named parameters
of the ``DatabricksRunNowOperator`` directly. Note that there is exactly
one named parameter for each top level parameter in the ``run-now``
endpoint. In this method, your code would look like this: ::
job_id=42
job_id = 42
notebook_params = {
"dry-run": "true",
"oldest-time-to-consider": "1457570074236"
}
notebook_params = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
python_params = ["douglas adams", "42"]
@@ -612,7 +606,7 @@ class DatabricksRunNowOperator(BaseOperator):
notebook_params=notebook_params,
python_params=python_params,
jar_params=jar_params,
spark_submit_params=spark_submit_params
spark_submit_params=spark_submit_params,
)
In the case where both the json parameter **AND** the named parameters
@@ -2090,9 +2090,9 @@ def oauth_user_info_getter(
@appbuilder.sm.oauth_user_info_getter
def my_oauth_user_info(sm, provider, response=None):
if provider == 'github':
me = sm.oauth_remotes[provider].get('user')
return {'username': me.data.get('login')}
if provider == "github":
me = sm.oauth_remotes[provider].get("user")
return {"username": me.data.get("login")}
return {}
"""

2 changes: 1 addition & 1 deletion airflow/providers/ftp/operators/ftp.py
@@ -66,7 +66,7 @@ class FTPFileTransmitOperator(BaseOperator):
remote_filepath="/tmp/tmp1/tmp2/file.txt",
operation="put",
create_intermediate_dirs=True,
dag=dag
dag=dag,
)
"""

