Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Multi-yield transformers #396

Merged
merged 4 commits into from
Jan 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -938,8 +938,15 @@ job.launch()


## List of transformers

Transformers are implemented by subclassing [Transformer](https://github.com/amundsen-io/amundsendatabuilder/blob/master/databuilder/transformer/base_transformer.py#L12 "Transformer") and implementing `transform(self, record)`. A transformer can:

- Modify a record and return it,
- Return `None` to filter a record out,
- Yield multiple records. This is useful for e.g. inferring metadata (such as ownership) from table descriptions.

#### [ChainedTransformer](https://github.com/amundsen-io/amundsendatabuilder/blob/master/databuilder/transformer/base_transformer.py#L41 "ChainedTransformer")
A chanined transformer that can take multiple transformer.
A chained transformer that can take multiple transformers, passing each record through the chain.

#### [RegexStrReplaceTransformer](https://github.com/amundsen-io/amundsendatabuilder/blob/master/databuilder/transformer/regex_str_replace_transformer.py "RegexStrReplaceTransformer")
Generic string replacement transformer using REGEX. User can pass list of tuples where tuple contains regex and replacement pair.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _get_extract_iter(self) -> Iterator[Any]:
if not record:
break # the end.

record = self._transformer.transform(record=record)
record = next(self._transformer.transform(record=record), None)

if not self._is_published_dashboard(record):
continue # filter this one out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def extract(self) -> Any:
if not record:
return None

return self._transformer.transform(record=record)
return next(self._transformer.transform(record=record), None)

def get_scope(self) -> str:
return 'extractor.tableau_dashboard_metadata'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def extract(self) -> Any:
if not record:
return None

return self._transformer.transform(record=record)
return next(self._transformer.transform(record=record), None)

def get_scope(self) -> str:
return 'extractor.tableau_dashboard_last_modified'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def extract(self) -> Any:
if not record:
return None

return self._transformer.transform(record=record)
return next(self._transformer.transform(record=record), None)

def get_scope(self) -> str:
return 'extractor.tableau_dashboard_query'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def extract(self) -> Any:
if not record:
return None

return self._transformer.transform(record=record)
return next(self._transformer.transform(record=record), None)

def get_scope(self) -> str:
return 'extractor.tableau_dashboard_table'
Expand Down
18 changes: 13 additions & 5 deletions databuilder/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Iterator

from pyhocon import ConfigTree

Expand Down Expand Up @@ -48,22 +49,29 @@ def init(self, conf: ConfigTree) -> None:
def run(self) -> None:
"""
Runs a task
:return:
"""
LOGGER.info('Running a task')
try:
record = self.extractor.extract()
count = 1
count = 0
while record:
record = self.transformer.transform(record)
if not record:
# Move on if the transformer filtered the record out
record = self.extractor.extract()
continue
self.loader.load(record)
record = self.extractor.extract()
count += 1

# Support transformers which return one record, or yield multiple
results = record if isinstance(record, Iterator) else [record]
for result in results:
if result:
self.loader.load(result)
count += 1

if count > 0 and count % self._progress_report_frequency == 0:
LOGGER.info(f'Extracted %i records so far', count)

# Prepare the next record
record = self.extractor.extract()
finally:
self._closer.close()
27 changes: 19 additions & 8 deletions databuilder/transformer/base_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import abc
from typing import (
Any, Iterable, Optional,
Any, Iterable, Iterator, List, Optional,
)

from pyhocon import ConfigTree
Expand Down Expand Up @@ -41,7 +41,10 @@ def get_scope(self) -> str:

class ChainedTransformer(Transformer):
"""
A chained transformer that iterates transformers and transforms a record
A chained transformer that iterates transformers and transforms a record.
Transformers implemented using generator functions can yield multiple records,
which all get passed to the next transformer.
Returning None from a transformer filters the record out.
"""

def __init__(self,
Expand All @@ -56,13 +59,21 @@ def init(self, conf: ConfigTree) -> None:
transformer.init(Scoped.get_scoped_conf(conf, transformer.get_scope()))

def transform(self, record: Any) -> Any:
records = [record]
for t in self.transformers:
record = t.transform(record)
# Check filtered record
if not record:
return None

return record
new_records: List[Any] = []
for r in records:
result = t.transform(r)
# Get all records if the transformer returns an Iterator.
if isinstance(result, Iterator):
new_records += list(result)

# Filter the record if it is None
elif result is not None:
new_records.append(result)
records = new_records

yield from records

def get_scope(self) -> str:
return 'transformer.chained'
Expand Down
156 changes: 156 additions & 0 deletions tests/integration/test_chained_trainsformers_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright Contributors to the Amundsen project.
# SPDX-License-Identifier: Apache-2.0

import unittest
from typing import (
Any, Iterable, List, Optional,
)

from pyhocon import ConfigFactory, ConfigTree

from databuilder.extractor.base_extractor import Extractor
from databuilder.job.job import DefaultJob
from databuilder.loader.base_loader import Loader
from databuilder.models.table_metadata import TableMetadata
from databuilder.models.table_owner import TableOwner
from databuilder.task.task import DefaultTask
from databuilder.transformer.base_transformer import (
ChainedTransformer, NoopTransformer, Transformer,
)

TEST_DATA = [
TableMetadata(
database="db1", schema="schema1", name="table1", cluster="prod", description=""
),
TableMetadata(
database="db2", schema="schema2", name="table2", cluster="prod", description=""
),
]

EXPECTED_OWNERS = [
TableOwner(
db_name="db1",
cluster="prod",
schema="schema1",
table_name="table1",
owners=["foo", "bar"],
),
TableOwner(
db_name="db2",
cluster="prod",
schema="schema2",
table_name="table2",
owners=["foo", "bar"],
),
]


class TestChainedTransformerTask(unittest.TestCase):
    """Integration tests running a whole job with transformers that yield several records."""

    def test_multi_yield_task(self) -> None:
        """Every record yielded by a single multi-yield transformer gets loaded, in order."""
        loaded = _run_transformer(AddFakeOwnerTransformer())

        first_table, second_table = TEST_DATA
        first_owner, second_owner = EXPECTED_OWNERS

        # Each table record is followed by the owner record derived from it.
        wanted = [first_table, first_owner, second_table, second_owner]

        # Compared via repr, as the records are compared by their printed form here.
        self.assertEqual(repr(loaded), repr(wanted))

    def test_multi_yield_chained_transformer(self) -> None:
        """
        A ChainedTransformer handles a mix of:
        - transformers which yield multiple records
        - transformers which return a single record
        """
        chain = ChainedTransformer(
            [AddFakeOwnerTransformer(), NoopTransformer(), DuplicateTransformer()]
        )

        loaded = _run_transformer(chain)

        first_table, second_table = TEST_DATA
        first_owner, second_owner = EXPECTED_OWNERS

        # The last transformer in the chain doubles every record it receives.
        wanted = [
            first_table,
            first_table,
            first_owner,
            first_owner,
            second_table,
            second_table,
            second_owner,
            second_owner,
        ]

        self.assertEqual(repr(loaded), repr(wanted))


class AddFakeOwnerTransformer(Transformer):
    """Test transformer that passes a record through and adds a TableOwner for tables."""

    def init(self, conf: ConfigTree) -> None:
        """Nothing to configure."""

    def get_scope(self) -> str:
        return "transformer.fake_owner"

    def transform(self, record: Any) -> Iterable[Any]:
        """Yield ``record`` unchanged; for a TableMetadata, also yield a matching TableOwner."""
        yield record

        # Only table records get a companion owner record.
        if not isinstance(record, TableMetadata):
            return

        fake_owner = TableOwner(
            db_name=record.database,
            schema=record.schema,
            table_name=record.name,
            cluster=record.cluster,
            owners=["foo", "bar"],
        )
        yield fake_owner


class DuplicateTransformer(Transformer):
    """Test transformer that emits every input record two times."""

    def init(self, conf: ConfigTree) -> None:
        """Nothing to configure."""

    def get_scope(self) -> str:
        return "transformer.duplicate"

    def transform(self, record: Any) -> Iterable[Any]:
        """Yield the same ``record`` object twice."""
        for _ in range(2):
            yield record


class ListExtractor(Extractor):
    """Test extractor that hands out the configured ``items`` one record at a time."""

    def init(self, conf: ConfigTree) -> None:
        """Read the records to serve from the ``items`` config key."""
        self.items = conf.get("items")

    def extract(self) -> Optional[Any]:
        """Remove and return the first remaining item, or None when the list is empty."""
        try:
            head = self.items.pop(0)
        except IndexError:
            # Nothing left to hand out.
            return None
        return head

    def get_scope(self) -> str:
        return "extractor.test"


class ListLoader(Loader):
    """Test loader that collects every loaded record in ``self.loaded``."""

    def init(self, conf: ConfigTree) -> None:
        """Start with an empty list of loaded records."""
        self.loaded: List[Any] = list()

    def load(self, record: Any) -> None:
        """Append ``record`` to the end of ``self.loaded``."""
        self.loaded += [record]


def _run_transformer(transformer: Transformer) -> List[Any]:
    """
    Run a complete job over TEST_DATA using ``transformer``.

    :param transformer: the transformer placed between ListExtractor and ListLoader
    :return: the list of records collected by the ListLoader
    """
    sink = ListLoader()
    conf = ConfigFactory.from_dict({"extractor.test.items": TEST_DATA})

    pipeline = DefaultTask(
        extractor=ListExtractor(), transformer=transformer, loader=sink
    )
    DefaultJob(conf=conf, task=pipeline).launch()

    return sink.loaded
50 changes: 34 additions & 16 deletions tests/unit/transformer/test_chained_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@


class TestChainedTransformer(unittest.TestCase):

def test_init_not_called(self) -> None:

mock_transformer1 = MagicMock()
mock_transformer1.transform.return_value = "foo"
mock_transformer2 = MagicMock()
chained_transformer = ChainedTransformer(transformers=[mock_transformer1, mock_transformer2])
mock_transformer2.transform.return_value = "bar"

chained_transformer = ChainedTransformer(
transformers=[mock_transformer1, mock_transformer2]
)

config = ConfigFactory.from_dict({})
chained_transformer.init(conf=config)

chained_transformer.transform(
{
'foo': 'bar'
}
)
next(chained_transformer.transform({"foo": "bar"}))

mock_transformer1.init.assert_not_called()
mock_transformer1.transform.assert_called_once()
Expand All @@ -34,22 +34,40 @@ def test_init_not_called(self) -> None:
def test_init_called(self) -> None:

mock_transformer1 = MagicMock()
mock_transformer1.get_scope.return_value = 'foo'
mock_transformer1.get_scope.return_value = "foo"
mock_transformer1.transform.return_value = "foo"
mock_transformer2 = MagicMock()
mock_transformer2.get_scope.return_value = 'bar'
chained_transformer = ChainedTransformer(transformers=[mock_transformer1, mock_transformer2],
is_init_transformers=True)
mock_transformer2.get_scope.return_value = "bar"
mock_transformer2.transform.return_value = "bar"

chained_transformer = ChainedTransformer(
transformers=[mock_transformer1, mock_transformer2],
is_init_transformers=True,
)

config = ConfigFactory.from_dict({})
chained_transformer.init(conf=config)

chained_transformer.transform(
{
'foo': 'bar'
}
)
next(chained_transformer.transform({"foo": "bar"}))

mock_transformer1.init.assert_called_once()
mock_transformer1.transform.assert_called_once()
mock_transformer2.init.assert_called_once()
mock_transformer2.transform.assert_called_once()

def test_transformer_transforms(self) -> None:

mock_transformer1 = MagicMock()
mock_transformer1.transform.side_effect = lambda s: s + "b"
mock_transformer2 = MagicMock()
mock_transformer2.transform.side_effect = lambda s: s + "c"

chained_transformer = ChainedTransformer(
transformers=[mock_transformer1, mock_transformer2]
)

config = ConfigFactory.from_dict({})
chained_transformer.init(conf=config)

result = next(chained_transformer.transform("a"))
self.assertEqual(result, "abc")