Skip to content

Commit

Permalink
Add FileTrigger (#29265)
Browse files Browse the repository at this point in the history
Contributes back one of the core Triggers from https://github.com/astronomer/astronomer-providers so that it can be used to create an operator /sensor or used within taskflow API
  • Loading branch information
kaxil committed Jan 31, 2023
1 parent 51d9633 commit 42dd812
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
73 changes: 73 additions & 0 deletions airflow/triggers/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import datetime
import os
import typing
from glob import glob
from typing import Any

from airflow.triggers.base import BaseTrigger, TriggerEvent


class FileTrigger(BaseTrigger):
"""
A trigger that fires exactly once after it finds the requested file or folder.
:param filepath: File or folder name (relative to the base path set within the connection), can
be a glob.
:param recursive: when set to ``True``, enables recursive directory matching behavior of
``**`` in glob filepath parameter. Defaults to ``False``.
"""

def __init__(
self,
filepath: str,
recursive: bool = False,
poll_interval: float = 5.0,
):
super().__init__()
self.filepath = filepath
self.recursive = recursive
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes FileTrigger arguments and classpath."""
return (
"airflow.triggers.file.FileTrigger",
{
"filepath": self.filepath,
"recursive": self.recursive,
"poll_interval": self.poll_interval,
},
)

async def run(self) -> typing.AsyncIterator["TriggerEvent"]:
"""Loop until the relevant files are found."""
while True:
for path in glob(self.filepath, recursive=self.recursive):
if os.path.isfile(path):
mod_time_f = os.path.getmtime(path)
mod_time = datetime.datetime.fromtimestamp(mod_time_f).strftime("%Y%m%d%H%M%S")
self.log.info("Found File %s last modified: %s", str(path), str(mod_time))
yield TriggerEvent(True)
for _, _, files in os.walk(self.filepath):
if len(files) > 0:
yield TriggerEvent(True)
await asyncio.sleep(self.poll_interval)
64 changes: 64 additions & 0 deletions tests/triggers/test_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio

import pytest

from airflow.triggers.file import FileTrigger


class TestFileTrigger:
FILE_PATH = "/files/dags/example_async_file.py"

def test_serialization(self):
"""Asserts that the trigger correctly serializes its arguments and classpath."""
trigger = FileTrigger(filepath=self.FILE_PATH, poll_interval=5)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.triggers.file.FileTrigger"
assert kwargs == {
"filepath": self.FILE_PATH,
"poll_interval": 5,
"recursive": False,
}

@pytest.mark.asyncio
async def test_task_file_trigger(self, tmp_path):
"""Asserts that the trigger only goes off on or after file is found"""
tmp_dir = tmp_path / "test_dir"
tmp_dir.mkdir()
p = tmp_dir / "hello.txt"

trigger = FileTrigger(
filepath=str(p.resolve()),
poll_interval=0.2,
)

task = asyncio.create_task(trigger.run().__anext__())
await asyncio.sleep(0.5)

# It should not have produced a result
assert task.done() is False

p.touch()

await asyncio.sleep(0.5)
assert task.done() is True

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()

0 comments on commit 42dd812

Please sign in to comment.