Commit

✨ [destination-DuckDB] Improve performance, use pyarrow batch insert as replacement of `executemany` (#36715)
hrl20 committed Apr 22, 2024
1 parent dfc933a commit d4944a3
Showing 6 changed files with 277 additions and 42 deletions.
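
For context, the core of the change is replacing row-by-row `executemany` calls with a single Arrow-backed insert. Below is a minimal, self-contained sketch of that pattern, with illustrative table and column values rather than the connector's exact code:

# Sketch of the pattern this commit adopts: buffer records column-wise,
# build a pyarrow.Table, and insert it in one statement instead of calling
# executemany row by row.
import datetime
import json
import uuid

import duckdb
import pyarrow as pa

con = duckdb.connect()  # in-memory database, just for the sketch
con.execute(
    "CREATE TABLE raw_records ("
    "_airbyte_ab_id VARCHAR, _airbyte_emitted_at VARCHAR, _airbyte_data VARCHAR)"
)

# Column-oriented buffer, mirroring the diff's defaultdict(lambda: defaultdict(list)).
buffer = {"_airbyte_ab_id": [], "_airbyte_emitted_at": [], "_airbyte_data": []}
for record in ({"k": 1}, {"k": 2}, {"k": 3}):
    buffer["_airbyte_ab_id"].append(str(uuid.uuid4()))
    buffer["_airbyte_emitted_at"].append(datetime.datetime.now().isoformat())
    buffer["_airbyte_data"].append(json.dumps(record))

try:
    # One Arrow table for the whole batch; DuckDB's replacement scan resolves
    # `pa_table` from the enclosing Python scope, so no per-row bind parameters
    # are needed.
    pa_table = pa.Table.from_pydict(buffer)
    con.sql("INSERT INTO raw_records SELECT * FROM pa_table")
except Exception:
    # Fallback equivalent to the pre-change behaviour: row-by-row executemany.
    con.executemany(
        "INSERT INTO raw_records VALUES (?, ?, ?)",
        list(zip(buffer["_airbyte_ab_id"], buffer["_airbyte_emitted_at"], buffer["_airbyte_data"])),
    )

print(con.sql("SELECT count(*) FROM raw_records").fetchone())  # (3,)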
@@ -9,9 +9,10 @@
import uuid
from collections import defaultdict
from logging import getLogger
from typing import Any, Iterable, Mapping
from typing import Any, Dict, Iterable, List, Mapping

import duckdb
import pyarrow as pa
from airbyte_cdk import AirbyteLogger
from airbyte_cdk.destinations import Destination
from airbyte_cdk.models import AirbyteConnectionStatus, AirbyteMessage, ConfiguredAirbyteCatalog, DestinationSyncMode, Status, Type
@@ -109,53 +110,58 @@ def write(

con.execute(query)

buffer = defaultdict(list)
buffer = defaultdict(lambda: defaultdict(list))

for message in input_messages:
if message.type == Type.STATE:
# flush the buffer
for stream_name in buffer.keys():
logger.info(f"flushing buffer for state: {message}")
table_name = f"_airbyte_raw_{stream_name}"
query = f"""
INSERT INTO {schema_name}.{table_name}
(_airbyte_ab_id, _airbyte_emitted_at, _airbyte_data)
VALUES (?,?,?)
"""
con.executemany(query, buffer[stream_name])
DestinationDuckdb._safe_write(con=con, buffer=buffer, schema_name=schema_name, stream_name=stream_name)

con.commit()
buffer = defaultdict(list)
buffer = defaultdict(lambda: defaultdict(list))

yield message
elif message.type == Type.RECORD:
data = message.record.data
stream = message.record.stream
if stream not in streams:
logger.debug(f"Stream {stream} was not present in configured streams, skipping")
stream_name = message.record.stream
if stream_name not in streams:
logger.debug(f"Stream {stream_name} was not present in configured streams, skipping")
continue

# add to buffer
buffer[stream].append(
(
str(uuid.uuid4()),
datetime.datetime.now().isoformat(),
json.dumps(data),
)
)
buffer[stream_name]["_airbyte_ab_id"].append(str(uuid.uuid4()))
buffer[stream_name]["_airbyte_emitted_at"].append(datetime.datetime.now().isoformat())
buffer[stream_name]["_airbyte_data"].append(json.dumps(data))

else:
logger.info(f"Message type {message.type} not supported, skipping")

# flush any remaining messages
for stream_name in buffer.keys():
table_name = f"_airbyte_raw_{stream_name}"
DestinationDuckdb._safe_write(con=con, buffer=buffer, schema_name=schema_name, stream_name=stream_name)

@staticmethod
def _safe_write(*, con: duckdb.DuckDBPyConnection, buffer: Dict[str, Dict[str, List[Any]]], schema_name: str, stream_name: str):
table_name = f"_airbyte_raw_{stream_name}"
try:
pa_table = pa.Table.from_pydict(buffer[stream_name])
except:
logger.exception(
f"Writing with pyarrow view failed, falling back to writing with executemany. Expect some performance degradation."
)
query = f"""
INSERT INTO {schema_name}.{table_name}
(_airbyte_ab_id, _airbyte_emitted_at, _airbyte_data)
VALUES (?,?,?)
"""

con.executemany(query, buffer[stream_name])
con.commit()
entries_to_write = buffer[stream_name]
con.executemany(
query, zip(entries_to_write["_airbyte_ab_id"], entries_to_write["_airbyte_emitted_at"], entries_to_write["_airbyte_data"])
)
else:
# DuckDB will automatically find and SELECT from the `pa_table`
# local variable defined above.
con.sql(f"INSERT INTO {schema_name}.{table_name} SELECT * FROM pa_table")

def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
"""
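The inline comment in `_safe_write` relies on DuckDB's replacement scans: when a table name in a query does not exist in the catalog, the Python client resolves it against variables in the calling scope, so an in-memory Arrow table can be queried or inserted from directly. A tiny, hypothetical demonstration of that mechanism:

import duckdb
import pyarrow as pa

pa_table = pa.Table.from_pydict({"x": [1, 2, 3]})
con = duckdb.connect()
# `pa_table` is not a catalog table; DuckDB finds the local Arrow object instead.
print(con.sql("SELECT sum(x) AS total FROM pa_table").fetchone())  # (6,)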
@@ -9,9 +9,10 @@
import random
import string
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
from typing import Any, Callable, Dict, Generator, Iterable
from unittest.mock import MagicMock

import duckdb
@@ -30,6 +31,7 @@
)
from destination_duckdb import DestinationDuckdb
from destination_duckdb.destination import CONFIG_MOTHERDUCK_API_KEY
from faker import Faker

CONFIG_PATH = "integration_tests/config.json"
SECRETS_CONFIG_PATH = (
@@ -96,6 +98,12 @@ def test_table_name() -> str:
return f"airbyte_integration_{rand_string}"


@pytest.fixture
def test_large_table_name() -> str:
letters = string.ascii_lowercase
rand_string = "".join(random.choice(letters) for _ in range(10))
return f"airbyte_integration_{rand_string}"

@pytest.fixture
def table_schema() -> str:
schema = {"type": "object", "properties": {"column1": {"type": ["null", "string"]}}}
@@ -104,7 +112,7 @@ def table_schema() -> str:

@pytest.fixture
def configured_catalogue(
test_table_name: str, table_schema: str
test_table_name: str, test_large_table_name: str, table_schema: str,
) -> ConfiguredAirbyteCatalog:
append_stream = ConfiguredAirbyteStream(
stream=AirbyteStream(
@@ -115,7 +123,16 @@ def configured_catalogue(
sync_mode=SyncMode.incremental,
destination_sync_mode=DestinationSyncMode.append,
)
return ConfiguredAirbyteCatalog(streams=[append_stream])
append_stream_large = ConfiguredAirbyteStream(
stream=AirbyteStream(
name=test_large_table_name,
json_schema=table_schema,
supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental],
),
sync_mode=SyncMode.incremental,
destination_sync_mode=DestinationSyncMode.append,
)
return ConfiguredAirbyteCatalog(streams=[append_stream, append_stream_large])


@pytest.fixture
@@ -206,3 +223,101 @@ def test_write(
assert len(result) == 2
assert result[0][2] == json.dumps(airbyte_message1.record.data)
assert result[1][2] == json.dumps(airbyte_message2.record.data)

def _airbyte_messages(n: int, batch_size: int, table_name: str) -> Generator[AirbyteMessage, None, None]:
fake = Faker()
Faker.seed(0)

for i in range(n):
if i != 0 and i % batch_size == 0:
yield AirbyteMessage(
type=Type.STATE, state=AirbyteStateMessage(data={"state": str(i // batch_size)})
)
else:
message = AirbyteMessage(
type=Type.RECORD,
record=AirbyteRecordMessage(
stream=table_name,
data={"key1": fake.first_name() , "key2": fake.ssn()},
emitted_at=int(datetime.now().timestamp()) * 1000,
),
)
yield message


def _airbyte_messages_with_inconsistent_json_fields(n: int, batch_size: int, table_name: str) -> Generator[AirbyteMessage, None, None]:
fake = Faker()
Faker.seed(0)
random.seed(0)

for i in range(n):
if i != 0 and i % batch_size == 0:
yield AirbyteMessage(
type=Type.STATE, state=AirbyteStateMessage(data={"state": str(i // batch_size)})
)
else:
message = AirbyteMessage(
type=Type.RECORD,
record=AirbyteRecordMessage(
stream=table_name,
# Throw in empty nested objects and see how pyarrow deals with them.
data={"key1": fake.first_name() ,
"key2": fake.ssn() if random.random()< 0.5 else random.randrange(1000,9999999999999),
"nested1": {} if random.random()< 0.1 else {
"key3": fake.first_name() ,
"key4": fake.ssn() if random.random()< 0.5 else random.randrange(1000,9999999999999),
"dictionary1":{} if random.random()< 0.1 else {
"key3": fake.first_name() ,
"key4": "True" if random.random() < 0.5 else True
}
}
}
if random.random() < 0.9 else {},

emitted_at=int(datetime.now().timestamp()) * 1000,
),
)
yield message


TOTAL_RECORDS = 5_000
BATCH_WRITE_SIZE = 1000

@pytest.mark.slow
@pytest.mark.parametrize("airbyte_message_generator,explanation",
[(_airbyte_messages, "Test writing a large number of simple json objects."),
(_airbyte_messages_with_inconsistent_json_fields, "Test writing a large number of json messages with inconsistent schema.")] )
def test_large_number_of_writes(
config: Dict[str, str],
request,
configured_catalogue: ConfiguredAirbyteCatalog,
test_large_table_name: str,
test_schema_name: str,
airbyte_message_generator: Callable[[int, int, str], Iterable[AirbyteMessage]],
explanation: str,
):
destination = DestinationDuckdb()
generator = destination.write(
config,
configured_catalogue,
airbyte_message_generator(TOTAL_RECORDS, BATCH_WRITE_SIZE, test_large_table_name),
)

result = list(generator)
assert len(result) == TOTAL_RECORDS // (BATCH_WRITE_SIZE + 1)
motherduck_api_key = str(config.get(CONFIG_MOTHERDUCK_API_KEY, ""))
duckdb_config = {}
if motherduck_api_key:
duckdb_config["motherduck_token"] = motherduck_api_key
duckdb_config["custom_user_agent"] = "airbyte_intg_test"

con = duckdb.connect(
database=config.get("destination_path"), read_only=False, config=duckdb_config
)
with con:
cursor = con.execute(
"SELECT count(1) "
f"FROM {test_schema_name}._airbyte_raw_{test_large_table_name}"
)
result = cursor.fetchall()
assert result[0][0] == TOTAL_RECORDS - TOTAL_RECORDS // (BATCH_WRITE_SIZE + 1)
@@ -4,7 +4,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: 94bd199c-2ff0-4aa2-b98e-17f0acb72610
dockerImageTag: 0.3.3
dockerImageTag: 0.3.4
dockerRepository: airbyte/destination-duckdb
githubIssueLabel: destination-duckdb
icon: duckdb.svg
