Skip to content

Commit

Permalink
CDK: Add initial Destination abstraction and tests (#4719)
Browse files Browse the repository at this point in the history
Co-authored-by: Eugene Kulak <widowmakerreborn@gmail.com>
  • Loading branch information
sherifnada and keu committed Jul 13, 2021
1 parent 066db10 commit cb4fe72
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 4 deletions.
1 change: 0 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(self, spec_string):


class Connector(ABC):

# can be overridden to change an input config
def configure(self, config: Mapping[str, Any], temp_dir: str) -> Mapping[str, Any]:
"""
Expand Down
3 changes: 3 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/destinations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Re-export Destination so it can be imported directly from the package.
from .destination import Destination

# Names exported by `from airbyte_cdk.destinations import *`.
__all__ = ["Destination"]
98 changes: 96 additions & 2 deletions airbyte-cdk/python/airbyte_cdk/destinations/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,103 @@
# SOFTWARE.
#

import argparse
import io
import sys
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Mapping

from airbyte_cdk import AirbyteLogger
from airbyte_cdk.connector import Connector
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type
from pydantic import ValidationError


class Destination(Connector):
pass # TODO
class Destination(Connector, ABC):
    """Abstract destination connector driven by the ``spec``, ``check`` and ``write`` commands.

    Subclasses must implement :meth:`write`. ``spec``, ``check`` and ``read_config`` are
    called on ``self`` but are not defined in this class (presumably inherited from ``Connector``).
    """

    logger = AirbyteLogger()

    @abstractmethod
    def write(
        self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage]
    ) -> Iterable[AirbyteMessage]:
        """Implement to define how the connector writes data to the destination"""

    def _run_spec(self) -> AirbyteMessage:
        """Return the result of ``self.spec`` wrapped in a SPEC message."""
        specification = self.spec(self.logger)
        return AirbyteMessage(type=Type.SPEC, spec=specification)

    def _run_check(self, config_path: str) -> AirbyteMessage:
        """Load the config at ``config_path``, run ``self.check`` on it and wrap the result in a CONNECTION_STATUS message."""
        loaded_config = self.read_config(config_path=config_path)
        status = self.check(self.logger, loaded_config)
        return AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=status)

    def _parse_input_stream(self, input_stream: io.TextIOWrapper) -> Iterable[AirbyteMessage]:
        """Lazily turn each line of ``input_stream`` into an Airbyte message.

        Lines that fail validation are logged at info level and skipped.
        """
        for raw_line in input_stream:
            try:
                yield AirbyteMessage.parse_raw(raw_line)
            except ValidationError:
                self.logger.info(f"ignoring input which can't be deserialized as Airbyte Message: {raw_line}")

    def _run_write(self, config_path: str, configured_catalog_path: str, input_stream: io.TextIOWrapper) -> Iterable[AirbyteMessage]:
        """Load config and catalog from disk, then stream the messages produced by ``self.write``.

        :param config_path: path to the JSON configuration file
        :param configured_catalog_path: path to the configured catalog JSON file
        :param input_stream: text stream holding one serialized message per line
        :return: the messages yielded by ``self.write``
        """
        loaded_config = self.read_config(config_path=config_path)
        loaded_catalog = ConfiguredAirbyteCatalog.parse_file(configured_catalog_path)
        messages = self._parse_input_stream(input_stream)
        self.logger.info("Begin writing to the destination...")
        yield from self.write(config=loaded_config, configured_catalog=loaded_catalog, input_messages=messages)
        self.logger.info("Writing complete.")

    def parse_args(self, args: List[str]) -> argparse.Namespace:
        """
        Parse the command line of a destination connector.

        :param args: commandline arguments
        :return: namespace holding ``command`` plus the options of that command
        :raises Exception: if no command was given
        """
        shared_parent = argparse.ArgumentParser(add_help=False)
        root_parser = argparse.ArgumentParser()
        commands = root_parser.add_subparsers(title="commands", dest="command")

        # spec takes no options
        commands.add_parser("spec", help="outputs the json configuration specification", parents=[shared_parent])

        # check requires --config
        check_command = commands.add_parser("check", help="checks the config can be used to connect", parents=[shared_parent])
        check_required = check_command.add_argument_group("required named arguments")
        check_required.add_argument("--config", type=str, required=True, help="path to the json configuration file")

        # write requires --config and --catalog
        write_command = commands.add_parser("write", help="Writes data to the destination", parents=[shared_parent])
        write_required_group = write_command.add_argument_group("required named arguments")
        write_required_group.add_argument("--config", type=str, required=True, help="path to the JSON configuration file")
        write_required_group.add_argument("--catalog", type=str, required=True, help="path to the configured catalog JSON file")

        namespace = root_parser.parse_args(args)
        command = namespace.command
        if not command:
            raise Exception("No command entered. ")
        if command not in ("spec", "check", "write"):
            # Technically unreachable: argparse itself rejects unknown subcommands.
            # Kept as an explicit safeguard because that is not obvious to readers.
            raise Exception(f"Unknown command entered: {command}")

        return namespace

    def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]:
        """Lazily execute the command named by ``parsed_args.command`` and yield its output messages.

        :raises Exception: on first iteration, if the command is not ``spec``, ``check`` or ``write``
        """
        command = parsed_args.command
        if command == "spec":
            yield self._run_spec()
            return
        if command == "check":
            yield self._run_check(config_path=parsed_args.config)
            return
        if command != "write":
            raise Exception(f"Unrecognized command: {command}")

        # Wrap in UTF-8 to override any other input encodings
        utf8_stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8")
        yield from self._run_write(
            config_path=parsed_args.config, configured_catalog_path=parsed_args.catalog, input_stream=utf8_stdin
        )

    def run(self, args: List[str]):
        """Parse ``args``, run the requested command and print each output message as JSON."""
        for message in self.run_cmd(self.parse_args(args)):
            print(message.json(exclude_unset=True))
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

setup(
name="airbyte-cdk",
version="0.1.5",
version="0.1.6-rc1",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
246 changes: 246 additions & 0 deletions airbyte-cdk/python/unit_tests/destinations/test_destination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
#
# MIT License
#
# Copyright (c) 2020 Airbyte
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

import argparse
import io
import json
from os import PathLike
from typing import Any, Dict, Iterable, List, Mapping, Union
from unittest.mock import ANY

import pytest
from airbyte_cdk.destinations import Destination
from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateMessage,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
ConnectorSpecification,
DestinationSyncMode,
Status,
SyncMode,
Type,
)


@pytest.fixture(name="destination")
def destination_fixture(mocker) -> Destination:
    """Provide a ``Destination`` instance whose abstract-method registry has been emptied."""
    # Destination is abstract; clearing __abstractmethods__ lets it be instantiated
    # without implementing its abstract methods.
    mocker.patch("airbyte_cdk.destinations.Destination.__abstractmethods__", set())
    # Mypy complains about instantiating an abstract class, hence the ignore.
    instance = Destination()  # type: ignore
    return instance


class TestArgParsing:
    """Tests of ``Destination.parse_args`` for accepted and rejected command lines."""

    @pytest.mark.parametrize(
        ("arg_list", "expected_output"),
        [
            (["spec"], {"command": "spec"}),
            (["check", "--config", "bogus_path/"], {"command": "check", "config": "bogus_path/"}),
            (
                ["write", "--config", "config_path1", "--catalog", "catalog_path1"],
                {"command": "write", "config": "config_path1", "catalog": "catalog_path1"},
            ),
        ],
    )
    def test_successful_parse(self, arg_list: List[str], expected_output: Mapping[str, Any], destination: Destination):
        """Each valid command line parses into exactly the expected attribute mapping."""
        namespace = destination.parse_args(arg_list)
        actual = vars(namespace)
        assert actual == expected_output, f"Expected parsing {arg_list} to return parsed args {expected_output} but instead found {actual}"

    @pytest.mark.parametrize(
        "arg_list",
        [
            # Invalid commands
            [],
            ["not-a-real-command"],
            [""],
            # Incorrect parameters
            ["spec", "--config", "path"],
            ["check"],
            ["check", "--catalog", "path"],
            ["check", "path"],
        ],
    )
    def test_failed_parse(self, arg_list: List[str], destination: Destination):
        """Each invalid command line makes ``parse_args`` raise."""
        # BaseException covers both SystemExit (raised by argparse on a failed parse)
        # and the plain exceptions raised by the extra semantic checks.
        with pytest.raises(BaseException):
            destination.parse_args(arg_list)


def _state(state: Dict[str, Any]) -> AirbyteStateMessage:
    """Build an ``AirbyteStateMessage`` whose ``data`` is the given dict."""
    state_message = AirbyteStateMessage(data=state)
    return state_message


def _record(stream: str, data: Dict[str, Any]) -> AirbyteRecordMessage:
    """Build an ``AirbyteRecordMessage`` for ``stream`` carrying ``data``, with ``emitted_at`` fixed to 0."""
    record_message = AirbyteRecordMessage(stream=stream, data=data, emitted_at=0)
    return record_message


def _spec(schema: Dict[str, Any]) -> ConnectorSpecification:
    """Build a ``ConnectorSpecification`` using ``schema`` as its ``connectionSpecification``."""
    specification = ConnectorSpecification(connectionSpecification=schema)
    return specification


def write_file(path: PathLike, content: Union[str, Mapping]):
    """Write ``content`` to ``path``, serializing mappings to JSON first.

    :param path: location of the file to create or overwrite
    :param content: text written as-is, or a mapping dumped with ``json.dumps``
    """
    text = json.dumps(content) if isinstance(content, Mapping) else content
    # Pin the encoding so the bytes written do not depend on the platform's locale default.
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)


def _wrapped(
    msg: Union[AirbyteRecordMessage, AirbyteStateMessage, AirbyteCatalog, ConnectorSpecification, AirbyteConnectionStatus]
) -> AirbyteMessage:
    """Wrap a payload object in an ``AirbyteMessage`` of the matching type.

    :raises Exception: if ``msg`` is not an instance of any supported payload class
    """
    # Checked in order: (payload class, message type, AirbyteMessage field that carries the payload).
    envelopes = (
        (AirbyteRecordMessage, Type.RECORD, "record"),
        (AirbyteStateMessage, Type.STATE, "state"),
        (AirbyteCatalog, Type.CATALOG, "catalog"),
        (AirbyteConnectionStatus, Type.CONNECTION_STATUS, "connectionStatus"),
        (ConnectorSpecification, Type.SPEC, "spec"),
    )
    for payload_cls, message_type, field_name in envelopes:
        if isinstance(msg, payload_cls):
            return AirbyteMessage(type=message_type, **{field_name: msg})
    raise Exception(f"Invalid Airbyte Message: {msg}")


class OrderedIterableMatcher(Iterable):
    """
    Equality helper: compares the wrapped iterable against another iterable
    element by element, in order.
    """

    def __init__(self, iterable: Iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

    def __eq__(self, other):
        # Non-iterables never match; otherwise both sides are materialized and compared as lists.
        return isinstance(other, Iterable) and list(self) == list(other)

    def attempt_consume(self, iterator):
        """Return the next item of ``iterator``, or ``None`` once it is exhausted."""
        return next(iterator, None)


class TestRun:
    """Tests of ``Destination.run_cmd`` for the spec, check and write commands and for invalid input."""

    def test_run_spec(self, mocker, destination: Destination):
        """The spec command yields the connector's spec wrapped in a SPEC message."""
        args = {"command": "spec"}
        parsed_args = argparse.Namespace(**args)

        expected_spec = ConnectorSpecification(connectionSpecification={"json_schema": {"prop": "value"}})
        mocker.patch.object(destination, "spec", return_value=expected_spec, autospec=True)

        spec_message = next(iter(destination.run_cmd(parsed_args)))

        # Mypy doesn't understand magicmock so it thinks spec doesn't have assert_called_once attr
        destination.spec.assert_called_once()  # type: ignore

        # verify the output of spec was returned
        assert _wrapped(expected_spec) == spec_message

    def test_run_check(self, mocker, destination: Destination, tmp_path):
        """The check command reads the config file, calls ``check`` with it and wraps the returned status."""
        file_path = tmp_path / "config.json"
        dummy_config = {"user": "sherif"}
        write_file(file_path, dummy_config)
        args = {"command": "check", "config": file_path}

        parsed_args = argparse.Namespace(**args)

        # NOTE: run_cmd is a generator function, so nothing executes until it is iterated.
        # It is therefore only invoked below, after the mock has been installed.
        expected_check_result = AirbyteConnectionStatus(status=Status.SUCCEEDED)
        mocker.patch.object(destination, "check", return_value=expected_check_result, autospec=True)

        returned_check_result = next(iter(destination.run_cmd(parsed_args)))
        # verify method call with the correct params
        # Affirm to Mypy that this is indeed a method on this mock
        destination.check.assert_called_once()  # type: ignore
        # Affirm to Mypy that this is indeed a method on this mock
        destination.check.assert_called_with(logger=ANY, config=dummy_config)  # type: ignore

        # verify output was correct
        assert _wrapped(expected_check_result) == returned_check_result

    def test_run_write(self, mocker, destination: Destination, tmp_path, monkeypatch):
        """The write command feeds parsed stdin messages to ``write`` and yields whatever ``write`` returns."""
        config_path, dummy_config = tmp_path / "config.json", {"user": "sherif"}
        write_file(config_path, dummy_config)

        dummy_catalog = ConfiguredAirbyteCatalog(
            streams=[
                ConfiguredAirbyteStream(
                    stream=AirbyteStream(name="mystream", json_schema={"type": "object"}),
                    sync_mode=SyncMode.full_refresh,
                    destination_sync_mode=DestinationSyncMode.overwrite,
                )
            ]
        )
        catalog_path = tmp_path / "catalog.json"
        write_file(catalog_path, dummy_catalog.json(exclude_unset=True))

        args = {"command": "write", "config": config_path, "catalog": catalog_path}
        parsed_args = argparse.Namespace(**args)

        expected_write_result = [_wrapped(_state({"k1": "v1"})), _wrapped(_state({"k2": "v2"}))]
        mocker.patch.object(
            destination, "write", return_value=iter(expected_write_result), autospec=True  # convert to iterator to mimic real usage
        )
        # mock input is a record followed by some state messages
        mocked_input: List[AirbyteMessage] = [_wrapped(_record("s1", {"k1": "v1"})), *expected_write_result]
        mocked_stdin_string = "\n".join([record.json(exclude_unset=True) for record in mocked_input])
        mocked_stdin_string += "\n add this non-serializable string to verify the destination does not break on malformed input"
        mocked_stdin = io.TextIOWrapper(io.BytesIO(bytes(mocked_stdin_string, "utf-8")))

        monkeypatch.setattr("sys.stdin", mocked_stdin)

        returned_write_result = list(destination.run_cmd(parsed_args))
        # verify method call with the correct params
        # Affirm to Mypy that assert_called_once is indeed a method on this mock
        destination.write.assert_called_once()  # type: ignore
        # Affirm to Mypy that assert_called_with is indeed a method on this mock
        destination.write.assert_called_with(  # type: ignore
            config=dummy_config,
            configured_catalog=dummy_catalog,
            # Stdin is internally consumed as a generator so we use a custom matcher
            # that iterates over two iterables to check equality
            input_messages=OrderedIterableMatcher(mocked_input),
        )

        # verify output was correct
        assert expected_write_result == returned_write_result

    @pytest.mark.parametrize("args", [{}, {"command": "fake"}])
    def test_run_cmd_with_incorrect_args_fails(self, args, destination: Destination):
        """A missing or unknown command makes iterating ``run_cmd`` raise."""
        with pytest.raises(Exception):
            # list() forces the generator to run; merely calling run_cmd would execute nothing.
            list(destination.run_cmd(parsed_args=argparse.Namespace(**args)))

0 comments on commit cb4fe72

Please sign in to comment.