4 changes: 2 additions & 2 deletions airflow/io/path.py
@@ -263,11 +263,11 @@ def read_block(self, offset: int, length: int, delimiter=None):
--------
>>> read_block(0, 13)
b'Alice, 100\\nBo'
>>> read_block(0, 13, delimiter=b'\\n')
>>> read_block(0, 13, delimiter=b"\\n")
b'Alice, 100\\nBob, 200\\n'

Use ``length=None`` to read to the end of the file.
>>> read_block(0, None, delimiter=b'\\n')
>>> read_block(0, None, delimiter=b"\\n")
b'Alice, 100\\nBob, 200\\nCharlie, 300'

See Also
8 changes: 4 additions & 4 deletions airflow/macros/__init__.py
@@ -49,9 +49,9 @@ def ds_add(ds: str, days: int) -> str:
:param ds: anchor date in ``YYYY-MM-DD`` format to add to
:param days: number of days to add to the ds, you can use negative values

>>> ds_add('2015-01-01', 5)
>>> ds_add("2015-01-01", 5)
'2015-01-06'
>>> ds_add('2015-01-06', -5)
>>> ds_add("2015-01-06", -5)
'2015-01-01'
"""
if not days:
@@ -68,9 +68,9 @@ def ds_format(ds: str, input_format: str, output_format: str) -> str:
:param input_format: input string format. E.g. %Y-%m-%d
:param output_format: output string format E.g. %Y-%m-%d

>>> ds_format('2015-01-01', "%Y-%m-%d", "%m-%d-%y")
>>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y")
'01-01-15'
>>> ds_format('1/5/2015', "%m/%d/%Y", "%Y-%m-%d")
>>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d")
'2015-01-05'
"""
return datetime.strptime(str(ds), input_format).strftime(output_format)
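For context on the ``ds_add`` / ``ds_format`` macros reformatted above, a minimal usage sketch, not part of this diff (the templated ``bash_command`` shown in the comment is an illustrative assumption)::

    from airflow.macros import ds_add, ds_format

    # Direct calls, matching the doctests above.
    assert ds_add("2015-01-01", 5) == "2015-01-06"
    assert ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d") == "2015-01-05"

    # The same helpers are exposed to templated fields via the ``macros`` namespace,
    # e.g. in a hypothetical BashOperator:
    #   bash_command="echo {{ macros.ds_add(ds, 5) }}"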
18 changes: 3 additions & 15 deletions airflow/models/baseoperator.py
@@ -628,12 +628,7 @@ class derived from this one results in the creation of a task object,
**Example**: to run this task in a specific docker container through
the KubernetesExecutor ::

MyOperator(...,
executor_config={
"KubernetesExecutor":
{"image": "myCustomDockerImage"}
}
)
MyOperator(..., executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}})

:param do_xcom_push: if True, an XCom is pushed containing the Operator's
result
@@ -1152,9 +1147,7 @@ def set_xcomargs_dependencies(self) -> None:
# This is equivalent to
with DAG(...):
generate_content = GenerateContentOperator(task_id="generate_content")
send_email = EmailOperator(
..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
)
send_email = EmailOperator(..., html_content="{{ task_instance.xcom_pull('generate_content') }}")
generate_content >> send_email

"""
@@ -1866,12 +1859,7 @@ def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):

Then you can accomplish like so::

chain_linear(
op1,
[op2, op3],
[op4, op5, op6],
op7
)
chain_linear(op1, [op2, op3], [op4, op5, op6], op7)

:param elements: a list of operators / lists of operators
"""
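As a reviewer-side sketch of the ``chain_linear`` call reformatted above, a runnable example, not part of this diff (the DAG id, task ids, and use of ``EmptyOperator`` are illustrative assumptions)::

    import pendulum

    from airflow import DAG
    from airflow.models.baseoperator import chain_linear
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="chain_linear_demo", start_date=pendulum.datetime(2024, 1, 1), schedule=None):
        op1, op2, op3, op4, op5, op6, op7 = (EmptyOperator(task_id=f"op{i}") for i in range(1, 8))
        # op1 feeds op2 and op3; each of those feeds op4, op5 and op6; all of which feed op7.
        chain_linear(op1, [op2, op3], [op4, op5, op6], op7)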
7 changes: 4 additions & 3 deletions airflow/models/dag.py
@@ -381,11 +381,12 @@ class DAG(LoggingMixin):

**Example**: to avoid Jinja from removing a trailing newline from template strings ::

DAG(dag_id='my-dag',
DAG(
dag_id="my-dag",
jinja_environment_kwargs={
'keep_trailing_newline': True,
"keep_trailing_newline": True,
# some other jinja2 Environment options here
}
},
)

**See**: `Jinja Environment documentation
4 changes: 4 additions & 0 deletions airflow/models/taskinstance.py
@@ -3319,19 +3319,23 @@ def get_relevant_upstream_map_indexes(
def this_task(v): # This is self.task.
return v * 2


@task_group
def tg1(inp):
val = upstream(inp) # This is the upstream task.
this_task(val) # When inp is 1, val here should resolve to 2.
return val


# This val is the same object returned by tg1.
val = tg1.expand(inp=[1, 2, 3])


@task_group
def tg2(inp):
another_task(inp, val) # val here should resolve to [2, 4, 6].


tg2.expand(inp=["a", "b"])

The surrounding mapped task groups of ``upstream`` and ``self.task`` are
4 changes: 2 additions & 2 deletions airflow/models/xcom_arg.py
@@ -59,8 +59,8 @@ class XComArg(ResolveMixin, DependencyMixin):

xcomarg >> op
xcomarg << op
op >> xcomarg # By BaseOperator code
op << xcomarg # By BaseOperator code
op >> xcomarg # By BaseOperator code
op << xcomarg # By BaseOperator code

**Example**: The moment you get a result from any operator (decorated or regular) you can ::

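For reference, a minimal sketch of the ``XComArg`` wiring described in this docstring, not part of the diff (the DAG id, task ids, and ``BashOperator`` usage are illustrative assumptions)::

    import pendulum

    from airflow import DAG
    from airflow.operators.bash import BashOperator

    with DAG(dag_id="xcomarg_demo", start_date=pendulum.datetime(2024, 1, 1), schedule=None):
        generate = BashOperator(task_id="generate", bash_command="echo hello")
        consume = BashOperator(
            task_id="consume",
            bash_command="echo {{ ti.xcom_pull(task_ids='generate') }}",
        )
        # ``generate.output`` is an XComArg, so the shift operators shown above work on it.
        generate.output >> consume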
20 changes: 10 additions & 10 deletions airflow/providers/amazon/aws/operators/s3.py
@@ -762,11 +762,11 @@ class S3ListOperator(BaseOperator):
``customers/2018/04/`` key in the ``data`` bucket. ::

s3_file = S3ListOperator(
task_id='list_3s_files',
bucket='data',
prefix='customers/2018/04/',
delimiter='/',
aws_conn_id='aws_customers_conn'
task_id="list_3s_files",
bucket="data",
prefix="customers/2018/04/",
delimiter="/",
aws_conn_id="aws_customers_conn",
)
"""

@@ -843,11 +843,11 @@ class S3ListPrefixesOperator(BaseOperator):
from the S3 ``customers/2018/04/`` prefix in the ``data`` bucket. ::

s3_file = S3ListPrefixesOperator(
task_id='list_s3_prefixes',
bucket='data',
prefix='customers/2018/04/',
delimiter='/',
aws_conn_id='aws_customers_conn'
task_id="list_s3_prefixes",
bucket="data",
prefix="customers/2018/04/",
delimiter="/",
aws_conn_id="aws_customers_conn",
)
"""

13 changes: 5 additions & 8 deletions airflow/providers/amazon/aws/operators/sagemaker.py
@@ -404,14 +404,14 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
If you need to create a SageMaker endpoint based on an existed
SageMaker model and an existed SageMaker endpoint config::

config = endpoint_configuration;
config = endpoint_configuration

If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::

config = {
'Model': model_configuration,
'EndpointConfig': endpoint_config_configuration,
'Endpoint': endpoint_configuration
"Model": model_configuration,
"EndpointConfig": endpoint_config_configuration,
"Endpoint": endpoint_configuration,
}

For details of the configuration parameter of model_configuration see
@@ -579,10 +579,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):

If you need to create both SageMaker model and SageMaker Transform job::

config = {
'Model': model_config,
'Transform': transform_config
}
config = {"Model": model_config, "Transform": transform_config}

For details of the configuration parameter of transform_config see
:py:meth:`SageMaker.Client.create_transform_job`
6 changes: 1 addition & 5 deletions airflow/providers/apache/cassandra/hooks/cassandra.py
@@ -48,11 +48,7 @@ class CassandraHook(BaseHook, LoggingMixin):
If SSL is enabled in Cassandra, pass in a dict in the extra field as kwargs for
``ssl.wrap_socket()``. For example::

{
'ssl_options' : {
'ca_certs' : PATH_TO_CA_CERTS
}
}
{"ssl_options": {"ca_certs": PATH_TO_CA_CERTS}}

Default load balancing policy is RoundRobinPolicy. To specify a different
LB policy::
10 changes: 6 additions & 4 deletions airflow/providers/apache/cassandra/sensors/record.py
@@ -38,10 +38,12 @@ class CassandraRecordSensor(BaseSensorOperator):
primary keys 'p1' and 'p2' to be populated in keyspace 'k' and table 't',
instantiate it as follows:

>>> cassandra_sensor = CassandraRecordSensor(table="k.t",
... keys={"p1": "v1", "p2": "v2"},
... cassandra_conn_id="cassandra_default",
... task_id="cassandra_sensor")
>>> cassandra_sensor = CassandraRecordSensor(
... table="k.t",
... keys={"p1": "v1", "p2": "v2"},
... cassandra_conn_id="cassandra_default",
... task_id="cassandra_sensor",
... )

:param table: Target Cassandra table.
Use dot notation to target a specific keyspace.
6 changes: 3 additions & 3 deletions airflow/providers/apache/cassandra/sensors/table.py
@@ -38,9 +38,9 @@ class CassandraTableSensor(BaseSensorOperator):
For example, if you want to wait for a table called 't' to be created
in a keyspace 'k', instantiate it as follows:

>>> cassandra_sensor = CassandraTableSensor(table="k.t",
... cassandra_conn_id="cassandra_default",
... task_id="cassandra_sensor")
>>> cassandra_sensor = CassandraTableSensor(
... table="k.t", cassandra_conn_id="cassandra_default", task_id="cassandra_sensor"
... )

:param table: Target Cassandra table.
Use dot notation to target a specific keyspace.
23 changes: 11 additions & 12 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -582,12 +582,11 @@ def check_for_partition(self, schema: str, table: str, partition: str) -> bool:

:param schema: Name of hive schema (database) @table belongs to
:param table: Name of hive table @partition belongs to
:param partition: Expression that matches the partitions to check for
(eg `a = 'b' AND c = 'd'`)
:param partition: Expression that matches the partitions to check for (e.g. `a = 'b' AND c = 'd'`)

>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> hh.check_for_partition('airflow', t, "ds='2015-01-01'")
>>> t = "static_babynames_partitioned"
>>> hh.check_for_partition("airflow", t, "ds='2015-01-01'")
True
"""
with self.metastore as client:
@@ -606,10 +605,10 @@ def check_for_named_partition(self, schema: str, table: str, partition_name: str
:param partition_name: Name of the partitions to check for (eg `a=b/c=d`)

>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> hh.check_for_named_partition('airflow', t, "ds=2015-01-01")
>>> t = "static_babynames_partitioned"
>>> hh.check_for_named_partition("airflow", t, "ds=2015-01-01")
True
>>> hh.check_for_named_partition('airflow', t, "ds=xxx")
>>> hh.check_for_named_partition("airflow", t, "ds=xxx")
False
"""
with self.metastore as client:
@@ -619,7 +618,7 @@ def get_table(self, table_name: str, db: str = "default") -> Any:
"""Get a metastore table object.

>>> hh = HiveMetastoreHook()
>>> t = hh.get_table(db='airflow', table_name='static_babynames')
>>> t = hh.get_table(db="airflow", table_name="static_babynames")
>>> t.tableName
'static_babynames'
>>> [col.name for col in t.sd.cols]
@@ -649,8 +648,8 @@ def get_partitions(self, schema: str, table_name: str, partition_filter: str | N
For subpartitioned table, the number might easily exceed this.

>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> parts = hh.get_partitions(schema='airflow', table_name=t)
>>> t = "static_babynames_partitioned"
>>> parts = hh.get_partitions(schema="airflow", table_name=t)
>>> len(parts)
1
>>> parts
@@ -765,9 +764,9 @@ def table_exists(self, table_name: str, db: str = "default") -> bool:
Check if table exists.

>>> hh = HiveMetastoreHook()
>>> hh.table_exists(db='airflow', table_name='static_babynames')
>>> hh.table_exists(db="airflow", table_name="static_babynames")
True
>>> hh.table_exists(db='airflow', table_name='does_not_exist')
>>> hh.table_exists(db="airflow", table_name="does_not_exist")
False
"""
try:
6 changes: 3 additions & 3 deletions airflow/providers/apache/hive/macros/hive.py
@@ -39,7 +39,7 @@ def max_partition(
:param field: the field to get the max value from. If there's only
one partition field, this will be inferred

>>> max_partition('airflow.static_babynames_partitioned')
>>> max_partition("airflow.static_babynames_partitioned")
'2015-01-01'
"""
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
@@ -94,8 +94,8 @@ def closest_ds_partition(
:param metastore_conn_id: which metastore connection to use
:returns: The closest date

>>> tbl = 'airflow.static_babynames_partitioned'
>>> closest_ds_partition(tbl, '2015-01-02')
>>> tbl = "airflow.static_babynames_partitioned"
>>> closest_ds_partition(tbl, "2015-01-02")
'2015-01-01'
"""
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
18 changes: 6 additions & 12 deletions airflow/providers/databricks/operators/databricks.py
@@ -580,26 +580,20 @@ class DatabricksRunNowOperator(BaseOperator):
For example ::

json = {
"job_id": 42,
"notebook_params": {
"dry-run": "true",
"oldest-time-to-consider": "1457570074236"
}
"job_id": 42,
"notebook_params": {"dry-run": "true", "oldest-time-to-consider": "1457570074236"},
}

notebook_run = DatabricksRunNowOperator(task_id='notebook_run', json=json)
notebook_run = DatabricksRunNowOperator(task_id="notebook_run", json=json)

Another way to accomplish the same thing is to use the named parameters
of the ``DatabricksRunNowOperator`` directly. Note that there is exactly
one named parameter for each top level parameter in the ``run-now``
endpoint. In this method, your code would look like this: ::

job_id=42
job_id = 42

notebook_params = {
"dry-run": "true",
"oldest-time-to-consider": "1457570074236"
}
notebook_params = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}

python_params = ["douglas adams", "42"]

@@ -612,7 +606,7 @@ class DatabricksRunNowOperator(BaseOperator):
notebook_params=notebook_params,
python_params=python_params,
jar_params=jar_params,
spark_submit_params=spark_submit_params
spark_submit_params=spark_submit_params,
)

In the case where both the json parameter **AND** the named parameters
@@ -2088,9 +2088,9 @@ def oauth_user_info_getter(

@appbuilder.sm.oauth_user_info_getter
def my_oauth_user_info(sm, provider, response=None):
if provider == 'github':
me = sm.oauth_remotes[provider].get('user')
return {'username': me.data.get('login')}
if provider == "github":
me = sm.oauth_remotes[provider].get("user")
return {"username": me.data.get("login")}
return {}
"""

2 changes: 1 addition & 1 deletion airflow/providers/ftp/operators/ftp.py
@@ -66,7 +66,7 @@ class FTPFileTransmitOperator(BaseOperator):
remote_filepath="/tmp/tmp1/tmp2/file.txt",
operation="put",
create_intermediate_dirs=True,
dag=dag
dag=dag,
)
"""

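A fuller, reviewer-side sketch of the operator touched above, not part of this diff; the ``ftp_conn_id`` and ``local_filepath`` parameters are not visible in this hunk and are assumptions about the operator's signature, as are the connection id and file paths::

    from airflow.providers.ftp.operators.ftp import FTPFileTransmitOperator

    # Assumed parameters: ``ftp_conn_id`` and ``local_filepath`` do not appear in the
    # hunk shown here; the remaining arguments mirror the docstring fragment above.
    ftp_put = FTPFileTransmitOperator(
        task_id="ftp_put",
        ftp_conn_id="ftp_default",
        local_filepath="/tmp/file.txt",
        remote_filepath="/tmp/tmp1/tmp2/file.txt",
        operation="put",
        create_intermediate_dirs=True,
    )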