Skip to content

Commit

Permalink
CDK: Add base pydantic model for connector config and schemas (#8485)
Browse files Browse the repository at this point in the history
* add base spec model

* fix usage of state_checkpoint_interval in case it is dynamic

* add schema base models, fix spelling, signatures and polishing

Co-authored-by: Eugene Kulak <kulak.eugene@gmail.com>
  • Loading branch information
keu and eugene-kulak committed Dec 7, 2021
1 parent 161a1ee commit aa67604
Show file tree
Hide file tree
Showing 12 changed files with 336 additions and 54 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.1.42
Add base pydantic model for connector config and schemas.

## 0.1.41
Fix build error

Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/airbyte_cdk/destinations/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]:
return
config = self.read_config(config_path=parsed_args.config)
if self.check_config_against_spec or cmd == "check":
check_config_against_spec_or_exit(config, spec, self.logger)
check_config_against_spec_or_exit(config, spec)

if cmd == "check":
yield self._run_check(config=config)
Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def run(self, parsed_args: argparse.Namespace) -> Iterable[str]:
# jsonschema's additionalProperties flag won't fail the validation
config, internal_config = split_config(config)
if self.source.check_config_against_spec or cmd == "check":
check_config_against_spec_or_exit(config, source_spec, self.logger)
check_config_against_spec_or_exit(config, source_spec)
# Put internal flags back to config dict
config.update(internal_config.dict())

Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def _read_incremental(
if stream_state:
logger.info(f"Setting state of {stream_name} stream to {stream_state}")

checkpoint_interval = stream_instance.state_checkpoint_interval
slices = stream_instance.stream_slices(
cursor_field=configured_stream.cursor_field, sync_mode=SyncMode.incremental, stream_state=stream_state
)
Expand All @@ -186,6 +185,7 @@ def _read_incremental(
for record_counter, record_data in enumerate(records, start=1):
yield self._as_airbyte_record(stream_name, record_data)
stream_state = stream_instance.get_updated_state(stream_state, record_data)
checkpoint_interval = stream_instance.state_checkpoint_interval
if checkpoint_interval and record_counter % checkpoint_interval == 0:
yield self._checkpoint_state(stream_name, stream_state, connector_state, logger)

Expand Down
65 changes: 65 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
#

from typing import Any, Dict, List, MutableMapping, Optional

from jsonschema import RefResolver
from pydantic import BaseModel


class BaseConfig(BaseModel):
    """Base class for connector spec, adds the following behaviour:

    - resolve $ref and replace it with definition
    - replace all occurrences of anyOf with oneOf
    - drop description
    """

    @classmethod
    def _rename_key(cls, schema: Any, old_key: str, new_key: str) -> None:
        """Recursively walk a nested mapping and rename ``old_key`` to ``new_key``.

        Used to replace anyOf with oneOf.

        :param schema: schema that will be patched in place
        :param old_key: name of the key to replace
        :param new_key: new name of the key
        """
        if not isinstance(schema, MutableMapping):
            return

        for nested_value in schema.values():
            cls._rename_key(nested_value, old_key, new_key)
        if old_key in schema:
            schema[new_key] = schema.pop(old_key)

    @classmethod
    def _expand_refs(cls, schema: Any, ref_resolver: Optional[RefResolver] = None) -> None:
        """Recursively replace every occurrence of $ref with its resolved definition.

        :param schema: schema that will be patched in place
        :param ref_resolver: resolver used to look up definitions; built from
            ``schema`` itself when not supplied
        """
        if ref_resolver is None:
            ref_resolver = RefResolver.from_schema(schema)

        if isinstance(schema, MutableMapping):
            if "$ref" in schema:
                ref_url = schema.pop("$ref")
                _, definition = ref_resolver.resolve(ref_url)
                # Definitions may themselves contain refs, so expand those as well.
                cls._expand_refs(definition, ref_resolver=ref_resolver)
                schema.update(definition)
            else:
                for nested_value in schema.values():
                    cls._expand_refs(nested_value, ref_resolver=ref_resolver)
        elif isinstance(schema, List):
            for element in schema:
                cls._expand_refs(element, ref_resolver=ref_resolver)

    @classmethod
    def schema(cls, **kwargs) -> Dict[str, Any]:
        """Generate the jsonschema and apply the post-processing described on the class."""
        generated = super().schema(**kwargs)
        cls._rename_key(generated, old_key="anyOf", new_key="oneOf")  # UI supports only oneOf
        cls._expand_refs(generated)  # UI and destination don't support $ref's
        generated.pop("definitions", None)  # remove definitions created by $ref
        generated.pop("description", None)  # description added from the docstring
        return generated
90 changes: 44 additions & 46 deletions airbyte-cdk/python/airbyte_cdk/sources/utils/schema_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,49 @@
from typing import Any, ClassVar, Dict, Mapping, Tuple

import jsonref
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import ConnectorSpecification
from jsonschema import validate
from jsonschema.exceptions import ValidationError
from pydantic import BaseModel, Field


class JsonFileLoader:
    """
    Custom json file loader to resolve references to resources located in "shared" directory.
    We need this for compatibility with existing schemas because all of them have references
    pointing to shared_schema.json file instead of shared/shared_schema.json.
    """

    def __init__(self, uri_base: str, shared: str):
        # Base directory of the schema files and the name of the shared sub-directory.
        self.shared = shared
        self.uri_base = uri_base

    def __call__(self, uri: str) -> Dict[str, Any]:
        """Load and parse the referenced json file, redirecting the lookup into the shared dir.

        :param uri: reference uri as it appears in the schema files
        :return: parsed json content of the referenced file
        """
        uri = uri.replace(self.uri_base, f"{self.uri_base}/{self.shared}/")
        # Fix: the original `json.load(open(uri))` leaked the file handle; use a context
        # manager and an explicit encoding (json files are utf-8 by spec).
        with open(uri, encoding="utf-8") as shared_file:
            return json.load(shared_file)


def resolve_ref_links(obj: Any) -> Dict[str, Any]:
    """
    Scan resolved schema and convert jsonref.JsonRef objects into plain JSON-serializable dicts.

    :param obj: jsonschema object with ref fields already resolved.
    :return: JSON serializable object with references inlined and no external dependencies.
    """
    if isinstance(obj, jsonref.JsonRef):
        resolved = resolve_ref_links(obj.__subject__)
        # The external resource's own definitions are no longer needed once inlined.
        resolved.pop("definitions", None)
        return resolved
    if isinstance(obj, dict):
        return {key: resolve_ref_links(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [resolve_ref_links(element) for element in obj]
    return obj


class ResourceSchemaLoader:
"""JSONSchema loader from package resources"""

Expand All @@ -42,10 +78,8 @@ def get_schema(self, name: str) -> dict:
raise IOError(f"Cannot find file {schema_filename}")
try:
raw_schema = json.loads(raw_file)
except ValueError:
# TODO use proper logging
print(f"Invalid JSON file format for file {schema_filename}")
raise
except ValueError as err:
raise RuntimeError(f"Invalid JSON file format for file {schema_filename}") from err

return self.__resolve_schema_references(raw_schema)

Expand All @@ -57,58 +91,20 @@ def __resolve_schema_references(self, raw_schema: dict) -> dict:
:return JSON serializable object with references without external dependencies.
"""

class JsonFileLoader:
"""
Custom json file loader to resolve references to resources located in "shared" directory.
We need this for compatability with existing schemas cause all of them have references
pointing to shared_schema.json file instead of shared/shared_schema.json
"""

def __init__(self, uri_base: str, shared: str):
self.shared = shared
self.uri_base = uri_base

def __call__(self, uri: str) -> Dict[str, Any]:
uri = uri.replace(self.uri_base, f"{self.uri_base}/{self.shared}/")
return json.load(open(uri))

package = importlib.import_module(self.package_name)
base = os.path.dirname(package.__file__) + "/"

def resolve_ref_links(obj: Any) -> Dict[str, Any]:
"""
Scan resolved schema and convert jsonref.JsonRef object to JSON
serializable dict.
:param obj - jsonschema object with ref field resolved.
:return JSON serializable object with references without external dependencies.
"""
if isinstance(obj, jsonref.JsonRef):
obj = resolve_ref_links(obj.__subject__)
# Omit existing definitions for external resource since
# we don't need it anymore.
obj.pop("definitions", None)
return obj
elif isinstance(obj, dict):
return {k: resolve_ref_links(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [resolve_ref_links(item) for item in obj]
else:
return obj

resolved = jsonref.JsonRef.replace_refs(raw_schema, loader=JsonFileLoader(base, "schemas/shared"), base_uri=base)
resolved = resolve_ref_links(resolved)
return resolved


def check_config_against_spec_or_exit(config: Mapping[str, Any], spec: ConnectorSpecification, logger: AirbyteLogger):
def check_config_against_spec_or_exit(config: Mapping[str, Any], spec: ConnectorSpecification):
"""
Check config object against spec. In case of spec is invalid, throws
an exception with validation error description.
:param config - config loaded from file specified over command line
:param spec - spec object generated by connector
:param logger - Airbyte logger for reporting validation error
"""
spec_schema = spec.connectionSpecification
try:
Expand All @@ -122,8 +118,10 @@ class InternalConfig(BaseModel):
limit: int = Field(None, alias="_limit")
page_size: int = Field(None, alias="_page_size")

def dict(self):
return super().dict(by_alias=True, exclude_unset=True)
def dict(self, *args, **kwargs):
kwargs["by_alias"] = True
kwargs["exclude_unset"] = True
return super().dict(*args, **kwargs)


def split_config(config: Mapping[str, Any]) -> Tuple[dict, InternalConfig]:
Expand Down
76 changes: 76 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/utils/schema_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
#

from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, Extra
from pydantic.main import ModelMetaclass
from pydantic.typing import resolve_annotations


class AllOptional(ModelMetaclass):
    """
    Metaclass that wraps every declared field of a Pydantic model in ``typing.Optional``.

    Declaring a model with this metaclass:
    '''
    class MyModel(BaseModel, metaclass=AllOptional):
        a: str
        b: str
    '''
    is equivalent to:
    '''
    class MyModel(BaseModel):
        a: Optional[str]
        b: Optional[str]
    '''
    It makes the code clearer and eliminates a lot of manual work.
    """

    def __new__(mcs, name, bases, namespaces, **kwargs):
        """
        Collect annotations from the class body and its bases, then mark each field Optional.
        """
        merged = dict(resolve_annotations(namespaces.get("__annotations__", {}), namespaces.get("__module__", None)))
        for base in bases:
            # NOTE(review): base annotations override the class body's own here — confirm intended.
            merged.update(getattr(base, "__annotations__", {}))
        namespaces["__annotations__"] = {
            field: annotation if field.startswith("__") else Optional[annotation]
            for field, annotation in merged.items()
        }
        return super().__new__(mcs, name, bases, namespaces, **kwargs)


class BaseSchemaModel(BaseModel):
    """
    Base class for all schema models; applies extra jsonschema post-processing.
    Can be used in combination with the AllOptional metaclass.
    """

    class Config:
        extra = Extra.allow

        @classmethod
        def schema_extra(cls, schema: Dict[str, Any], model: Type[BaseModel]) -> None:
            """Modify the generated jsonschema: strip "title", "description" and "required".

            Pydantic doesn't treat Union[None, Any] correctly when generating jsonschema,
            so a field can't be declared nullable (accepting both null and non-null values)
            directly; the nullable jsonschema value is produced manually here instead.

            :param schema: generated jsonschema, mutated in place
            :param model: model class the schema was generated for
            """
            for redundant in ("title", "description", "required"):
                schema.pop(redundant, None)
            for name, prop in schema.get("properties", {}).items():
                prop.pop("title", None)
                prop.pop("description", None)
                if not model.__fields__[name].allow_none:
                    continue
                if "type" in prop:
                    # Nullable field: jsonschema expresses this as a type list with "null".
                    prop["type"] = ["null", prop["type"]]
                elif "$ref" in prop:
                    # $ref can't carry sibling keywords, so wrap it in a oneOf with null.
                    ref = prop.pop("$ref")
                    prop["oneOf"] = [{"type": "null"}, {"$ref": ref}]
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.1.41",
version="0.1.42",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_run_check(self, mocker, destination: Destination, tmp_path):
# Affirm to Mypy that this is indeed a method on this mock
destination.check.assert_called_with(logger=ANY, config=dummy_config) # type: ignore
# Check if config validation has been called
validate_mock.assert_called_with(dummy_config, spec_msg, destination.logger)
validate_mock.assert_called_with(dummy_config, spec_msg)

# verify output was correct
assert _wrapped(expected_check_result) == returned_check_result
Expand Down Expand Up @@ -224,7 +224,7 @@ def test_run_write(self, mocker, destination: Destination, tmp_path, monkeypatch
input_messages=OrderedIterableMatcher(mocked_input),
)
# Check if config validation has been called
validate_mock.assert_called_with(dummy_config, spec_msg, destination.logger)
validate_mock.assert_called_with(dummy_config, spec_msg)

# verify output was correct
assert expected_write_result == returned_write_result
Expand Down
Loading

0 comments on commit aa67604

Please sign in to comment.