Event triggers function operator (#17)
* trying something out

* actually save the dag

* base commit

* working with event now

* add example DAGs

* allowing function to pick downstream DAG

* Function triggers event listener working

* change error message

* correct file name

* remove unused packages

Co-authored-by: Dylan Storey <dylan.storey@astronomer.io>
TJaniF and dylanbstorey committed Dec 15, 2022
1 parent 8e365bb commit f4e3bb0
Showing 20 changed files with 307 additions and 49 deletions.
7 changes: 3 additions & 4 deletions airflow_provider_kafka/__init__.py
@@ -3,14 +3,13 @@
 # multiple places, but at this point it's the only workaround if you'd like your
 # custom conn type to show up in the Airflow UI.
 
-__version__= "0.1.1"
+__version__ = "0.1.1"
 
 
 def get_provider_info():
     return {
         "package-name": "airflow-provider-kafka",  # Required
         "name": "Airflow Provider Kafka",  # Required
         "description": "Airflow hooks and operators for Kafka",  # Required
         "versions": [__version__],  # Required
     }
18 changes: 14 additions & 4 deletions airflow_provider_kafka/hooks/admin_client.py
@@ -12,7 +12,11 @@ class KafkaAdminClientHook(BaseHook):
 
     default_conn_name = "kafka_default"
 
-    def __init__(self, kafka_conn_id: Optional[str] = None, config: Optional[Dict[Any, Any]] = None) -> None:
+    def __init__(
+        self,
+        kafka_conn_id: Optional[str] = None,
+        config: Optional[Dict[Any, Any]] = None,
+    ) -> None:
         super().__init__()
 
         self.kafka_conn_id = kafka_conn_id
@@ -25,10 +29,14 @@ def __init__(self, kafka_conn_id: Optional[str] = None, config: Optional[Dict[An
         self.extra_configs = {"bootstrap.servers": conn}
 
         if not (self.config.get("bootstrap.servers", None) or self.kafka_conn_id):
-            raise AirflowException("One of config['bootstrap.servers'] or kafka_conn_id must be provided.")
+            raise AirflowException(
+                "One of config['bootstrap.servers'] or kafka_conn_id must be provided."
+            )
 
         if self.config.get("bootstrap.servers", None) and self.kafka_conn_id:
-            raise AirflowException("One of config['bootstrap.servers'] or kafka_conn_id must be provided.")
+            raise AirflowException(
+                "One of config['bootstrap.servers'] or kafka_conn_id must be provided."
+            )
 
     def get_admin_client(self) -> AdminClient:
         """
@@ -47,7 +55,9 @@ def create_topic(
 
         admin_client = self.get_admin_client()
 
-        new_topics = [NewTopic(t[0], num_partitions=t[1], replication_factor=t[2]) for t in topics]
+        new_topics = [
+            NewTopic(t[0], num_partitions=t[1], replication_factor=t[2]) for t in topics
+        ]
 
         futures = admin_client.create_topics(new_topics)
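
For context, a minimal usage sketch of the reflowed create_topic call (the hook class and import path come from this file; the topic tuples are placeholder values, and it is assumed that create_topic accepts the topic list as its topics argument):

from airflow_provider_kafka.hooks.admin_client import KafkaAdminClientHook

hook = KafkaAdminClientHook(kafka_conn_id="kafka_default")
# each tuple maps to NewTopic(t[0], num_partitions=t[1], replication_factor=t[2])
hook.create_topic(topics=[("test_topic_1", 3, 1), ("test_topic_2", 1, 1)])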
8 changes: 6 additions & 2 deletions airflow_provider_kafka/hooks/consumer.py
@@ -30,10 +30,14 @@ def __init__(
         )
 
         if not (self.config.get("bootstrap.servers", None) or self.kafka_conn_id):
-            raise AirflowException("One of config['bootstrap.servers'] or kafka_conn_id must be provided.")
+            raise AirflowException(
+                "One of config['bootstrap.servers'] or kafka_conn_id must be provided."
+            )
 
         if self.config.get("bootstrap.servers", None) and self.kafka_conn_id:
-            raise AirflowException("One of config['bootstrap.servers'] or kafka_conn_id must be provided.")
+            raise AirflowException(
+                "One of config['bootstrap.servers'] or kafka_conn_id must be provided."
+            )
 
         self.extra_configs = {}
         if self.kafka_conn_id:
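
The two checks above make the broker source mutually exclusive: exactly one of kafka_conn_id or config["bootstrap.servers"] must be set. A sketch of the two accepted forms (the hook class name is assumed from the file path; the broker address is a placeholder):

from airflow_provider_kafka.hooks.consumer import KafkaConsumerHook  # class name assumed

# broker address resolved from an Airflow connection
KafkaConsumerHook(kafka_conn_id="kafka_default")

# broker address passed directly in the client config
KafkaConsumerHook(config={"bootstrap.servers": "localhost:9092"})

# passing both, or neither, raises AirflowException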
8 changes: 6 additions & 2 deletions airflow_provider_kafka/hooks/producer.py
@@ -24,10 +24,14 @@ def __init__(
         self.config: Dict[Any, Any] = config or {}
 
         if not (self.config.get("bootstrap.servers", None) or self.kafka_conn_id):
-            raise AirflowException("One of config['bootstrap.servers'] or kafka_conn_id must be provided.")
+            raise AirflowException(
+                "One of config['bootstrap.servers'] or kafka_conn_id must be provided."
+            )
 
         if self.config.get("bootstrap.servers", None) and self.kafka_conn_id:
-            raise AirflowException("One of config['bootstrap.servers'] or kafka_conn_id must be provided.")
+            raise AirflowException(
+                "One of config['bootstrap.servers'] or kafka_conn_id must be provided."
+            )
 
         self.extra_configs = {}
         if self.kafka_conn_id:
9 changes: 7 additions & 2 deletions airflow_provider_kafka/operators/await_message.py
@@ -45,8 +45,13 @@ class AwaitKafkaMessageOperator(BaseOperator):
     BLUE = "#ffefeb"
     ui_color = BLUE
 
-    template_fields = ('topics', 'apply_function', 'apply_function_args', 'apply_function_kwargs')
-
+    template_fields = (
+        "topics",
+        "apply_function",
+        "apply_function_args",
+        "apply_function_kwargs",
+    )
+
     def __init__(
         self,
         topics: Sequence[str],
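
With topics and the apply-function fields in template_fields, they can now be rendered with Jinja at runtime. A hypothetical task definition (the Variable name and module path are illustrative; other required arguments are omitted):

AwaitKafkaMessageOperator(
    task_id="await_message",
    topics=["{{ var.value.kafka_topic }}"],  # rendered from an Airflow Variable
    apply_function="my_package.my_module.check_message",
    kafka_conn_id="kafka_default",
)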
18 changes: 15 additions & 3 deletions airflow_provider_kafka/operators/consume_from_topic.py
@@ -49,7 +49,13 @@ class ConsumeFromTopicOperator(BaseOperator):
 
     BLUE = "#ffefeb"
     ui_color = BLUE
-    template_fields = ('topics', 'apply_function', 'apply_function_args', 'apply_function_kwargs')
+    template_fields = (
+        "topics",
+        "apply_function",
+        "apply_function_args",
+        "apply_function_kwargs",
+    )
+
     def __init__(
         self,
         topics: Sequence[str],
@@ -103,15 +109,21 @@ def execute(self, context) -> Any:
         self.apply_function = get_callable(self.apply_function)
 
         apply_callable = self.apply_function
-        apply_callable = partial(apply_callable, *self.apply_function_args, **self.apply_function_kwargs)
+        apply_callable = partial(
+            apply_callable, *self.apply_function_args, **self.apply_function_kwargs
+        )
 
         messages_left = self.max_messages
         messages_processed = 0
 
         while messages_left > 0:  # bool(True > 0) == True
 
             if not isinstance(messages_left, bool):
-                batch_size = self.max_batch_size if messages_left > self.max_batch_size else messages_left
+                batch_size = (
+                    self.max_batch_size
+                    if messages_left > self.max_batch_size
+                    else messages_left
+                )
             else:
                 batch_size = self.max_batch_size
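
The bool(True > 0) comment is the key to this loop: max_messages=True means "consume indefinitely", and because bool is a subclass of int, the while condition never becomes false. A standalone sketch of the batch-size rule reformatted above:

def next_batch_size(messages_left, max_batch_size):
    if isinstance(messages_left, bool):  # max_messages=True -> no numeric budget
        return max_batch_size
    return min(messages_left, max_batch_size)  # never exceed the remaining budget

assert next_batch_size(True, 1000) == 1000
assert next_batch_size(250, 1000) == 250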
120 changes: 120 additions & 0 deletions airflow_provider_kafka/operators/event_triggers_function.py
@@ -0,0 +1,120 @@
+from typing import Any, Callable, Dict, Optional, Sequence
+
+from airflow.models import BaseOperator
+
+from airflow_provider_kafka.triggers.await_message import AwaitMessageTrigger
+
+VALID_COMMIT_CADENCE = {"never", "end_of_batch", "end_of_operator"}
+
+
+class EventTriggersFunctionOperator(BaseOperator):
+    """An Airflow operator that defers until a specific message is published to Kafka,
+    then processes that message with a user-supplied callable.
+
+    The behavior of the consumer for this trigger is as follows:
+    - poll the Kafka topics for a message
+    - if no message is returned, sleep
+    - process the message with the provided callable and commit the message offset
+    - if the callable returns any data, raise a TriggerEvent with the returned data
+    - else continue to the next message
+    - return the event (as the default xcom)
+
+    :param topics: Topics (or topic regex) to use for reading from
+    :type topics: Sequence[str]
+    :param apply_function: The function to apply to messages to determine if an event occurred. As a dot
+        notation string.
+    :type apply_function: str
+    :param event_triggered_function: The callable to run on the event payload each time the trigger fires
+    :type event_triggered_function: Callable
+    :param apply_function_args: Arguments to be applied to the processing function, defaults to None
+    :type apply_function_args: Optional[Sequence[Any]], optional
+    :param apply_function_kwargs: Keyword arguments to be applied to the processing function, defaults to None
+    :type apply_function_kwargs: Optional[Dict[Any, Any]], optional
+    :param kafka_conn_id: The Airflow connection storing the Kafka broker address, defaults to None
+    :type kafka_conn_id: Optional[str], optional
+    :param kafka_config: The config dictionary for the Kafka client (additional information available in the
+        confluent-kafka-python documentation), defaults to None
+    :type kafka_config: Optional[Dict[Any, Any]], optional
+    :param poll_timeout: How long the Kafka consumer should wait for a message to arrive from the Kafka
+        cluster, defaults to 1
+    :type poll_timeout: float, optional
+    :param poll_interval: How long the Kafka consumer should sleep after reaching the end of the Kafka log,
+        defaults to 5
+    :type poll_interval: float, optional
+    """
+
+    BLUE = "#ffefeb"
+    ui_color = BLUE
+
+    template_fields = (
+        "topics",
+        "apply_function",
+        "apply_function_args",
+        "apply_function_kwargs",
+    )
+
+    def __init__(
+        self,
+        topics: Sequence[str],
+        apply_function: str,
+        event_triggered_function: Callable,
+        apply_function_args: Optional[Sequence[Any]] = None,
+        apply_function_kwargs: Optional[Dict[Any, Any]] = None,
+        kafka_conn_id: Optional[str] = None,
+        kafka_config: Optional[Dict[Any, Any]] = None,
+        poll_timeout: float = 1,
+        poll_interval: float = 5,
+        **kwargs: Any,
+    ) -> None:
+
+        super().__init__(**kwargs)
+
+        self.topics = topics
+        self.apply_function = apply_function
+        self.apply_function_args = apply_function_args
+        self.apply_function_kwargs = apply_function_kwargs
+        self.kafka_conn_id = kafka_conn_id
+        self.kafka_config = kafka_config
+        self.poll_timeout = poll_timeout
+        self.poll_interval = poll_interval
+        self.event_triggered_function = event_triggered_function
+
+        if not callable(self.event_triggered_function):
+            raise TypeError(
+                f"parameter event_triggered_function is expected to be of type callable, got {type(event_triggered_function)}"
+            )
+
+    def execute(self, context, event=None) -> Any:
+
+        self.defer(
+            trigger=AwaitMessageTrigger(
+                topics=self.topics,
+                apply_function=self.apply_function,
+                apply_function_args=self.apply_function_args,
+                apply_function_kwargs=self.apply_function_kwargs,
+                kafka_conn_id=self.kafka_conn_id,
+                kafka_config=self.kafka_config,
+                poll_timeout=self.poll_timeout,
+                poll_interval=self.poll_interval,
+            ),
+            method_name="execute_complete",
+        )
+
+        return event
+
+    def execute_complete(self, context, event=None):
+
+        self.event_triggered_function(event, **context)
+
+        self.defer(
+            trigger=AwaitMessageTrigger(
+                topics=self.topics,
+                apply_function=self.apply_function,
+                apply_function_args=self.apply_function_args,
+                apply_function_kwargs=self.apply_function_kwargs,
+                kafka_conn_id=self.kafka_conn_id,
+                kafka_config=self.kafka_config,
+                poll_timeout=self.poll_timeout,
+                poll_interval=self.poll_interval,
+            ),
+            method_name="execute_complete",
+        )
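
Taken together, execute and execute_complete form a loop: the task defers on AwaitMessageTrigger, runs event_triggered_function on the event payload when the trigger fires, then defers again. A hypothetical DAG using the new operator (topic, connection id, and callable paths are illustrative, not taken from this commit):

from pendulum import datetime

from airflow.decorators import dag
from airflow_provider_kafka.operators.event_triggers_function import (
    EventTriggersFunctionOperator,
)


def pick_downstream(event, **context):
    # react to the TriggerEvent payload, e.g. decide which downstream DAG to trigger
    print(f"event received: {event}")


@dag(start_date=datetime(2022, 12, 1), schedule=None, catchup=False)
def kafka_listener():
    EventTriggersFunctionOperator(
        task_id="listen_for_message",
        topics=["my_topic"],
        apply_function="my_package.my_module.await_function",  # dot-notation string
        event_triggered_function=pick_downstream,
        kafka_conn_id="kafka_default",
    )


kafka_listener()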
17 changes: 13 additions & 4 deletions airflow_provider_kafka/operators/produce_to_topic.py
@@ -46,8 +46,13 @@ class ProduceToTopicOperator(BaseOperator):
     :type poll_timeout: float, optional
     :raises AirflowException: _description_
     """
 
-    template_fields = ('topic', 'producer_function', 'producer_function_args', 'producer_function_kwargs')
-
+    template_fields = (
+        "topic",
+        "producer_function",
+        "producer_function_args",
+        "producer_function_kwargs",
+    )
+
     def __init__(
         self,
@@ -100,12 +105,16 @@ def execute(self, context) -> Any:
 
         producer_callable = self.producer_function
         producer_callable = partial(
-            producer_callable, *self.producer_function_args, **self.producer_function_kwargs
+            producer_callable,
+            *self.producer_function_args,
+            **self.producer_function_kwargs,
         )
 
         # For each returned k/v in the callable : publish and flush if needed.
         for k, v in producer_callable():
-            producer.produce(self.topic, key=k, value=v, on_delivery=self.delivery_callback)
+            producer.produce(
+                self.topic, key=k, value=v, on_delivery=self.delivery_callback
+            )
             producer.poll(self.poll_timeout)
             if self.synchronous:
                 producer.flush()
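
The loop above publishes every (key, value) pair yielded by the producer function. A minimal sketch of a compatible producer_function (the name and payloads are illustrative):

def producer_function():
    for i in range(5):
        # each yielded pair becomes one produce(topic, key=k, value=v) call
        yield (f"key-{i}", f"payload-{i}")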
4 changes: 3 additions & 1 deletion airflow_provider_kafka/triggers/await_message.py
@@ -93,7 +93,9 @@ async def run(self):
         async_commit = sync_to_async(consumer.commit)
 
         processing_call = get_callable(self.apply_function)
-        processing_call = partial(processing_call, *self.apply_function_args, **self.apply_function_kwargs)
+        processing_call = partial(
+            processing_call, *self.apply_function_args, **self.apply_function_kwargs
+        )
         async_message_process = sync_to_async(processing_call)
         while True:
 
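
The trigger binds apply_function_args and apply_function_kwargs with partial, then awaits the callable once per consumed message. A sketch of a compatible apply function (the single-message signature is an assumption based on this binding):

def await_function(message):
    # message is a confluent-kafka Message; value() returns the raw bytes payload
    value = message.value()
    if value is not None and b"trigger" in value:
        return value.decode("utf-8")  # returning data fires the TriggerEvent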
4 changes: 2 additions & 2 deletions dev/Dockerfile
@@ -1,10 +1,10 @@
-ARG IMAGE_NAME="quay.io/astronomer/ap-airflow:2.2.3"
+ARG IMAGE_NAME="quay.io/astronomer/ap-airflow:2.4.3"
 FROM ${IMAGE_NAME}
 
 USER root
 COPY airflow_provider_kafka ${AIRFLOW_HOME}/airflow_provider_kafka
 COPY setup.cfg ${AIRFLOW_HOME}/airflow_provider_kafka/setup.cfg
 COPY setup.py ${AIRFLOW_HOME}/airflow_provider_kafka/setup.py
 
-RUN pip install ${AIRFLOW_HOME}/airflow_provider_kafka[dev]
+RUN pip install -e "${AIRFLOW_HOME}/airflow_provider_kafka[dev]"
 USER astro
5 changes: 2 additions & 3 deletions dev/docker-compose.yaml
@@ -5,12 +5,12 @@ x-airflow-common:
   # In order to add custom dependencies or upgrade provider packages you can use your extended image.
   # Comment the image line, place your Dockerfile in the directory where you placed the docker-compose.yaml
   # and uncomment the "build" line below, Then run `docker-compose build` to build the images.
-  image: astronomer-operators-dev
+  # image: astronomer-operators-dev
   build:
     context: ..
     dockerfile: dev/Dockerfile
     args:
-      IMAGE_NAME: ${IMAGE_NAME:-quay.io/astronomer/ap-airflow:2.2.3}
+      IMAGE_NAME: ${IMAGE_NAME:-quay.io/astronomer/ap-airflow:2.4.3}
   environment:
     &airflow-common-env
     DB_BACKEND: postgres
@@ -26,7 +26,6 @@ x-airflow-common:
   volumes:
     - ../example_dags:/usr/local/airflow/dags
     - ./logs:/usr/local/airflow/logs
-    - ../kafka_provider:/usr/local/airflow/kafka_provider
   depends_on:
     &airflow-common-depends-on
     redis: