Skip to content

Commit

Permalink
Alex/declarative stream incremental fix (#14268)
Browse files Browse the repository at this point in the history
* checkout files from test branch

* read_incremental works

* reset to master

* remove dead code

* comment

* fix

* Add test

* comments

* utc

* format

* small fix

* Add test with rfc3339

* remove unused param

* fix test
  • Loading branch information
girarda committed Jul 1, 2022
1 parent efc872f commit e23789b
Show file tree
Hide file tree
Showing 22 changed files with 169 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,15 @@ def name(self) -> str:

@property
def state(self) -> MutableMapping[str, Any]:
return self._retriever.get_state()
return self._retriever.state

@state.setter
def state(self, value: MutableMapping[str, Any]):
"""State setter, accept state serialized by state getter."""
self._retriever.state = value

def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]):
return self.state

@property
def cursor_field(self) -> Union[str, List[str]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation

false_values = {"False", "false", "{}", "[]", "()", "", "0", "0.0", "False", "false"}
false_values = ["False", "false", "{}", "[]", "()", "", "0", "0.0", "False", "false", {}, False, [], (), set()]


class InterpolatedBoolean:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import ast
import datetime
import numbers

from airbyte_cdk.sources.declarative.interpolation.interpolation import Interpolation
from dateutil import parser
from jinja2 import Environment
from jinja2.exceptions import UndefinedError

Expand All @@ -17,21 +20,33 @@ def __init__(self):
self._environment.globals["now_local"] = datetime.datetime.now
self._environment.globals["now_utc"] = lambda: datetime.datetime.now(datetime.timezone.utc)
self._environment.globals["today_utc"] = lambda: datetime.datetime.now(datetime.timezone.utc).date()
self._environment.globals["timestamp"] = (
lambda dt: int(dt)
if isinstance(dt, numbers.Number)
else int(parser.parse(dt).replace(tzinfo=datetime.timezone.utc).timestamp())
)
self._environment.globals["max"] = lambda a, b: max(a, b)

def eval(self, input_str: str, config, default=None, **kwargs):
context = {"config": config, **kwargs}
try:
if isinstance(input_str, str):
result = self._eval(input_str, context)
if result:
return result
return self._literal_eval(result)
else:
# Non-string input is not supported: raise instead of returning it as is
raise Exception(f"Expected a string. got {input_str}")
except UndefinedError:
pass
# If result is empty or resulted in an undefined error, evaluate and return the default string
return self._eval(default, context)
return self._literal_eval(self._eval(default, context))

def _literal_eval(self, result):
try:
return ast.literal_eval(result)
except (ValueError, SyntaxError):
return result

def _eval(self, s: str, context):
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from airbyte_cdk.sources.declarative.requesters.paginators.interpolated_paginator import InterpolatedPaginator
from airbyte_cdk.sources.declarative.requesters.paginators.next_page_url_paginator import NextPageUrlPaginator
from airbyte_cdk.sources.declarative.requesters.paginators.offset_paginator import OffsetPaginator
from airbyte_cdk.sources.declarative.stream_slicers.datetime_stream_slicer import DatetimeStreamSlicer
from airbyte_cdk.sources.streams.http.requests_native_auth.token import TokenAuthenticator

CLASS_TYPES_REGISTRY: Mapping[str, Type] = {
"NextPageUrlPaginator": NextPageUrlPaginator,
"InterpolatedPaginator": InterpolatedPaginator,
"OffsetPaginator": OffsetPaginator,
"TokenAuthenticator": TokenAuthenticator,
"DatetimeStreamSlicer": DatetimeStreamSlicer,
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,19 @@
from airbyte_cdk.sources.declarative.extractors.jello import JelloExtractor
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
from airbyte_cdk.sources.declarative.requesters.http_requester import HttpRequester
from airbyte_cdk.sources.declarative.requesters.paginators.no_pagination import NoPagination
from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator
from airbyte_cdk.sources.declarative.requesters.requester import Requester
from airbyte_cdk.sources.declarative.requesters.retriers.default_retrier import DefaultRetrier
from airbyte_cdk.sources.declarative.requesters.retriers.retrier import Retrier
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
from airbyte_cdk.sources.declarative.schema.json_schema import JsonSchema
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
from airbyte_cdk.sources.declarative.states.dict_state import DictState
from airbyte_cdk.sources.declarative.states.state import State
from airbyte_cdk.sources.declarative.stream_slicers.single_slice import SingleSlice
from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer

DEFAULT_IMPLEMENTATIONS_REGISTRY: Mapping[Type, Type] = {
Requester: HttpRequester,
Expand All @@ -29,4 +35,7 @@
Retrier: DefaultRetrier,
Decoder: JsonDecoder,
JelloExtractor: JelloExtractor,
State: DictState,
StreamSlicer: SingleSlice,
Paginator: NoPagination,
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import copy
import importlib
from typing import Any, Mapping, Type, Union, get_type_hints
from typing import Any, Mapping, Type, Union, get_args, get_origin, get_type_hints

from airbyte_cdk.sources.declarative.create_partial import create
from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
Expand Down Expand Up @@ -107,6 +107,13 @@ def is_object_definition_with_type(definition):
def get_default_type(parameter_name, parent_class):
    """Return the default implementation registered for a constructor parameter.

    The type hint of ``parameter_name`` is read from ``parent_class.__init__``
    and looked up in ``DEFAULT_IMPLEMENTATIONS_REGISTRY``.

    :param parameter_name: name of the ``__init__`` parameter to resolve
    :param parent_class: class whose ``__init__`` type hints are inspected
    :return: the registered default class, or ``None`` when the parameter has
        no type hint or no default implementation is registered for it
    """
    type_hints = get_type_hints(parent_class.__init__)
    interface = type_hints.get(parameter_name)
    if get_origin(interface) is Union:
        # Optional[T] is implemented as Union[T, None]. The interface we are
        # looking for is the first member that is not NoneType, so that
        # Union[None, T] resolves to T as well.
        none_type = type(None)
        interface = next((arg for arg in get_args(interface) if arg is not none_type), None)

    return DEFAULT_IMPLEMENTATIONS_REGISTRY.get(interface)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ def next_page_token(self, response: requests.Response, last_records: List[Mappin
self._config, decoded_response=decoded_response, headers=headers, last_records=last_records
)

non_null_tokens = {k: v for k, v in interpolated_values.items() if v}
non_null_tokens = {k: v for k, v in interpolated_values.items() if v is not None}

return non_null_tokens if non_null_tokens else None
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@


class OffsetPaginator(Paginator):
def __init__(self, page_size: int, state: DictState, offset_key: str = "offset"):
def __init__(self, page_size: int, state: Optional[DictState] = None, offset_key: str = "offset"):
self._limit = page_size
self._state: DictState = state
self._state = state or DictState()
self._offsetKey = offset_key
self._update_state_with_offset(0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def read_records(
def stream_slices(self, *, sync_mode: SyncMode, stream_state: Mapping[str, Any] = None) -> Iterable[Optional[Mapping[str, Any]]]:
pass

@property
@abstractmethod
def get_state(self) -> MutableMapping[str, Any]:
def state(self) -> MutableMapping[str, Any]:
"""State getter, should return state in form that can serialized to a string and send to the output
as a STATE AirbyteMessage.
Expand All @@ -36,3 +37,8 @@ def get_state(self) -> MutableMapping[str, Any]:
State should try to be as small as possible but at the same time descriptive enough to restore
syncing process from the point where it stopped.
"""

@state.setter
@abstractmethod
def state(self, value: MutableMapping[str, Any]):
"""State setter, accept state serialized by state getter."""
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def __init__(
name,
primary_key,
requester: Requester,
paginator: Paginator,
record_selector: HttpSelector,
stream_slicer: Optional[StreamSlicer] = SingleSlice,
paginator: Paginator = None,
stream_slicer: Optional[StreamSlicer] = SingleSlice(),
state: Optional[State] = None,
):
self._name = name
Expand All @@ -34,7 +34,7 @@ def __init__(
self._requester = requester
self._record_selector = record_selector
super().__init__(self._requester.get_authenticator())
self._iterator: StreamSlicer = stream_slicer
self._iterator = stream_slicer
self._state: State = (state or DictState()).deep_copy()
self._last_response = None
self._last_records = None
Expand Down Expand Up @@ -106,7 +106,8 @@ def request_headers(
Specifies request headers.
Authentication headers will overwrite any overlapping headers returned from this method.
"""
return self._requester.request_headers(stream_state, stream_slice, next_page_token)
# Warning: use self.state instead of the stream_state passed as argument!
return self._requester.request_headers(self.state, stream_slice, next_page_token)

def request_body_data(
self,
Expand All @@ -123,7 +124,8 @@ def request_body_data(
At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden.
"""
return self._requester.request_body_data(stream_state, stream_slice, next_page_token)
# Warning: use self.state instead of the stream_state passed as argument!
return self._requester.request_body_data(self.state, stream_slice, next_page_token)

def request_body_json(
self,
Expand All @@ -136,7 +138,8 @@ def request_body_json(
At the same time only one of the 'request_body_data' and 'request_body_json' functions can be overridden.
"""
return self._requester.request_body_json(stream_state, stream_slice, next_page_token)
# Warning: use self.state instead of the stream_state passed as argument!
return self._requester.request_body_json(self.state, stream_slice, next_page_token)

def request_kwargs(
self,
Expand All @@ -149,12 +152,13 @@ def request_kwargs(
Any option listed in https://docs.python-requests.org/en/latest/api/#requests.adapters.BaseAdapter.send can be returned from
this method. Note that these options do not conflict with request-level options such as headers, request params, etc.
"""
return self._requester.request_kwargs(stream_state, stream_slice, next_page_token)
# Warning: use self.state instead of the stream_state passed as argument!
return self._requester.request_kwargs(self.state, stream_slice, next_page_token)

def path(
self, *, stream_state: Mapping[str, Any] = None, stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
) -> str:
return self._requester.get_path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
return self._requester.get_path(stream_state=self.state, stream_slice=stream_slice, next_page_token=next_page_token)

def request_params(
self,
Expand All @@ -167,7 +171,8 @@ def request_params(
E.g: you might want to define query parameters for paging if next_page_token is not None.
"""
return self._requester.request_params(stream_state, stream_slice, next_page_token)
# Warning: use self.state instead of the stream_state passed as argument!
return self._requester.request_params(self.state, stream_slice, next_page_token)

@property
def cache_filename(self):
Expand All @@ -191,9 +196,10 @@ def parse_response(
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
) -> Iterable[Mapping]:
# Warning: use self.state instead of the stream_state passed as argument!
self._last_response = response
records = self._record_selector.select_records(
response=response, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
response=response, stream_state=self.state, stream_slice=stream_slice, next_page_token=next_page_token
)
self._last_records = records
return records
Expand All @@ -219,14 +225,13 @@ def read_records(
stream_slice: Mapping[str, Any] = None,
stream_state: Mapping[str, Any] = None,
) -> Iterable[Mapping[str, Any]]:
records_generator = HttpStream.read_records(self, sync_mode, cursor_field, stream_slice, stream_state)
# Warning: use self.state instead of the stream_state passed as argument!
records_generator = HttpStream.read_records(self, sync_mode, cursor_field, stream_slice, self.state)
for r in records_generator:
self._state.update_state(stream_slice=stream_slice, stream_state=stream_state, last_response=self._last_response, last_record=r)
self._state.update_state(stream_slice=stream_slice, stream_state=self.state, last_response=self._last_response, last_record=r)
yield r
else:
self._state.update_state(
stream_slice=stream_slice, stream_state=stream_state, last_reponse=self._last_response, last_record=None
)
self._state.update_state(stream_slice=stream_slice, stream_state=self.state, last_reponse=self._last_response)
yield from []

def stream_slices(
Expand All @@ -240,8 +245,14 @@ def stream_slices(
:param stream_state:
:return:
"""
# FIXME: this is not passing the cursor field because it is always known at init time
return self._iterator.stream_slices(sync_mode, stream_state)
# Warning: use self.state instead of the stream_state passed as argument!
return self._iterator.stream_slices(sync_mode, self.state)

def get_state(self) -> MutableMapping[str, Any]:
@property
def state(self) -> MutableMapping[str, Any]:
return self._state.get_stream_state()

@state.setter
def state(self, value: MutableMapping[str, Any]):
"""State setter, accept state serialized by state getter."""
self._state.set_state(value)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

from enum import Enum
from typing import Mapping, Union
from typing import Mapping

from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
from airbyte_cdk.sources.declarative.states.state import State
Expand All @@ -25,24 +25,19 @@ class StateType(Enum):
class DictState(State):
stream_state_field = "stream_state"

def __init__(self, initial_mapping: Mapping[str, str] = None, state_type: Union[str, StateType, type] = "STR", config=None):
def __init__(self, initial_mapping: Mapping[str, str] = None, config=None):
if initial_mapping is None:
initial_mapping = dict()
if config is None:
config = dict()
self._templates_to_evaluate = initial_mapping
if type(state_type) == str:
self._state_type = StateType[state_type].value
elif type(state_type) == StateType:
self._state_type = state_type.value
elif type(state_type) == type:
self._state_type = state_type
else:
raise Exception(f"Unexpected type for state_type. Got {state_type}")
self._interpolator = JinjaInterpolation()
self._context = dict()
self._config = config

def set_state(self, state):
self._context[self.stream_state_field] = state

def update_state(self, **kwargs):
stream_state = kwargs.get(self.stream_state_field)
prev_stream_state = self.get_stream_state() or stream_state
Expand All @@ -61,7 +56,7 @@ def _compute_state(self, prev_state):
self._interpolator.eval(name, self._config): self._interpolator.eval(value, self._config, **self._context)
for name, value in self._templates_to_evaluate.items()
}
updated_state = {name: self._state_type(value) for name, value in updated_state.items() if value}
updated_state = {name: value for name, value in updated_state.items() if value}

if prev_state:
next_state = {name: _get_max(name=name, val=value, other_state=prev_state) for name, value in updated_state.items()}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def __init__(
):
self._timezone = datetime.timezone.utc
self._interpolation = JinjaInterpolation()
if isinstance(start_time, str):
start_time = InterpolatedString(start_time)
if isinstance(end_time, str):
end_time = InterpolatedString(end_time)
if isinstance(cursor_value, str):
cursor_value = InterpolatedString(cursor_value)
self._datetime_format = datetime_format
self._start_time = self.parse_date(start_time.eval(config))
self._end_time = self.parse_date(end_time.eval(config))
Expand Down Expand Up @@ -72,6 +78,8 @@ def parse_date(self, date: Any) -> datetime:
return datetime.datetime.fromtimestamp(int(date)).replace(tzinfo=self._timezone)
else:
return datetime.datetime.strptime(date, self._datetime_format).replace(tzinfo=self._timezone)
elif isinstance(date, int):
return datetime.datetime.fromtimestamp(int(date)).replace(tzinfo=self._timezone)
return date

def is_start_date_valid(self, start_date: datetime) -> bool:
Expand Down

0 comments on commit e23789b

Please sign in to comment.