From 7824a2259f45929698bce4daa5d48d57aef46c4b Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Thu, 21 May 2026 18:56:44 -0700 Subject: [PATCH] Add native ADBC extension --- CMakeLists.txt | 1 + adbc/CMakeLists.txt | 29 +++ adbc/src/catalog/CMakeLists.txt | 8 + adbc/src/catalog/adbc_catalog.cpp | 53 ++++ adbc/src/catalog/adbc_table_catalog_entry.cpp | 59 +++++ adbc/src/connector/CMakeLists.txt | 7 + adbc/src/connector/adbc_connector.cpp | 235 ++++++++++++++++++ adbc/src/function/CMakeLists.txt | 11 + adbc/src/function/adbc_scan.cpp | 148 +++++++++++ adbc/src/include/catalog/adbc_catalog.h | 28 +++ .../catalog/adbc_table_catalog_entry.h | 30 +++ adbc/src/include/connector/adbc_connector.h | 79 ++++++ adbc/src/include/function/adbc_scan.h | 48 ++++ adbc/src/include/main/adbc_extension.h | 17 ++ adbc/src/include/storage/adbc_storage.h | 21 ++ .../include/storage/attached_adbc_database.h | 31 +++ adbc/src/main/CMakeLists.txt | 7 + adbc/src/main/adbc_extension.cpp | 35 +++ adbc/src/storage/CMakeLists.txt | 7 + adbc/src/storage/adbc_storage.cpp | 43 ++++ adbc/test/test_files/adbc.test | 17 ++ extension_config.cmake | 2 +- 22 files changed, 915 insertions(+), 1 deletion(-) create mode 100644 adbc/CMakeLists.txt create mode 100644 adbc/src/catalog/CMakeLists.txt create mode 100644 adbc/src/catalog/adbc_catalog.cpp create mode 100644 adbc/src/catalog/adbc_table_catalog_entry.cpp create mode 100644 adbc/src/connector/CMakeLists.txt create mode 100644 adbc/src/connector/adbc_connector.cpp create mode 100644 adbc/src/function/CMakeLists.txt create mode 100644 adbc/src/function/adbc_scan.cpp create mode 100644 adbc/src/include/catalog/adbc_catalog.h create mode 100644 adbc/src/include/catalog/adbc_table_catalog_entry.h create mode 100644 adbc/src/include/connector/adbc_connector.h create mode 100644 adbc/src/include/function/adbc_scan.h create mode 100644 adbc/src/include/main/adbc_extension.h create mode 100644 adbc/src/include/storage/adbc_storage.h create mode 100644 
adbc/src/include/storage/attached_adbc_database.h create mode 100644 adbc/src/main/CMakeLists.txt create mode 100644 adbc/src/main/adbc_extension.cpp create mode 100644 adbc/src/storage/CMakeLists.txt create mode 100644 adbc/src/storage/adbc_storage.cpp create mode 100644 adbc/test/test_files/adbc.test diff --git a/CMakeLists.txt b/CMakeLists.txt index 85d2006..43e853e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ function(add_extension_if_enabled_and_skip_32bit extension) endfunction() add_extension_if_enabled_and_skip_32bit("duckdb") +add_extension_if_enabled_and_skip_32bit("adbc") add_extension_if_enabled_and_skip_32bit("postgres") add_extension_if_enabled_and_skip_32bit("sqlite") add_extension_if_enabled_and_skip_32bit("delta") diff --git a/adbc/CMakeLists.txt b/adbc/CMakeLists.txt new file mode 100644 index 0000000..cc11d0d --- /dev/null +++ b/adbc/CMakeLists.txt @@ -0,0 +1,29 @@ +find_path(ADBC_INCLUDE_DIR + NAMES arrow-adbc/adbc.h + HINTS "$ENV{CONDA_PREFIX}/include") + +find_library(ADBC_DRIVER_MANAGER_LIBRARY + NAMES adbc_driver_manager + HINTS "$ENV{CONDA_PREFIX}/lib") + +if (NOT ADBC_INCLUDE_DIR OR NOT ADBC_DRIVER_MANAGER_LIBRARY) + message(FATAL_ERROR "ADBC extension requires libadbc-driver-manager. 
Install the C++ ADBC dependencies with pixi/conda.")
+endif ()
+
+include_directories(
+        src/include
+        ${CMAKE_BINARY_DIR}/src/include
+        ${PROJECT_SOURCE_DIR}/src/include
+        ${ADBC_INCLUDE_DIR})
+
+add_subdirectory(src/connector)
+add_subdirectory(src/function)
+add_subdirectory(src/catalog)
+add_subdirectory(src/storage)
+add_subdirectory(src/main)
+
+build_extension_lib(${BUILD_STATIC_EXTENSION} "adbc")
+
+target_link_libraries(lbug_${EXTENSION_LIB_NAME}_extension
+        PRIVATE
+        ${ADBC_DRIVER_MANAGER_LIBRARY})
diff --git a/adbc/src/catalog/CMakeLists.txt b/adbc/src/catalog/CMakeLists.txt
new file mode 100644
index 0000000..3f693fc
--- /dev/null
+++ b/adbc/src/catalog/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_library(lbug_adbc_catalog
+        OBJECT
+        adbc_catalog.cpp
+        adbc_table_catalog_entry.cpp)
+
+set(ADBC_EXTENSION_OBJECT_FILES
+        ${ADBC_EXTENSION_OBJECT_FILES} $
+        PARENT_SCOPE)
diff --git a/adbc/src/catalog/adbc_catalog.cpp b/adbc/src/catalog/adbc_catalog.cpp
new file mode 100644
index 0000000..4795560
--- /dev/null
+++ b/adbc/src/catalog/adbc_catalog.cpp
@@ -0,0 +1,61 @@
+#include "catalog/adbc_catalog.h"
+
+#include "catalog/adbc_table_catalog_entry.h"
+#include "catalog/catalog_entry/node_table_catalog_entry.h"
+#include "common/exception/runtime.h"
+#include "function/adbc_scan.h"
+#include "main/client_context.h"
+#include "main/database.h"
+#include "storage/storage_manager.h"
+#include "transaction/transaction.h"
+
+namespace lbug {
+namespace adbc_extension {
+
+void ADBCCatalog::init() {
+    for (auto& tableName : connector.getTableNames()) {
+        createForeignTable(tableName);
+    }
+}
+
+// Registers one remote table: an attached entry in this catalog's `tables`, plus an entry in the
+// main database catalog that references it and uses the first remote column as its primary key.
+// Throws RuntimeException when the remote schema has no columns.
+void ADBCCatalog::createForeignTable(const std::string& tableName) {
+    auto columns = connector.getTableSchema(schemaName, tableName);
+    std::vector columnNames;
+    std::vector columnTypes;
+    for (auto& [name, type] : columns) {
+        columnNames.push_back(name);
+        columnTypes.push_back(type.copy());
+    }
+    // columnNames[0] is read below as the primary key, so reject an empty schema up front.
+    if (columnNames.empty()) {
+        throw common::RuntimeException{"ADBC table " + tableName + " has no columns."};
+    }
+    auto scanInfo = std::make_shared(tableName, columnNames,
+        copyVector(columnTypes), connector);
+    auto attachedEntry = std::make_unique(tableName,
+        getADBCScanFunction(scanInfo), scanInfo);
+    for (auto i = 0u; i < columnNames.size(); i++) {
+        attachedEntry->addProperty(binder::PropertyDefinition{
+            binder::ColumnDefinition{columnNames[i], columnTypes[i].copy()}});
+    }
+    auto attachedEntryPtr = attachedEntry.get();
+    tables->createEntry(&transaction::DUMMY_TRANSACTION, std::move(attachedEntry));
+    auto primaryKeyName = columnNames[0];
+    auto mainTableEntry = std::make_unique(tableName,
+        primaryKeyName, tableName, catalog::ShadowTag{});
+    for (auto i = 0u; i < columnNames.size(); i++) {
+        mainTableEntry->addProperty(binder::PropertyDefinition{
+            binder::ColumnDefinition{columnNames[i], columnTypes[i].copy()}});
+    }
+    mainTableEntry->setReferencedEntry(attachedEntryPtr);
+    context->getDatabase()->getCatalog()->addTableEntry(std::move(mainTableEntry));
+    auto mainEntry = context->getDatabase()->getCatalog()->getTableCatalogEntry(
+        &transaction::DUMMY_TRANSACTION, tableName);
+    lbug::storage::StorageManager::Get(*context)->createTable(mainEntry);
+}
+
+} // namespace adbc_extension
+} // namespace lbug
diff --git a/adbc/src/catalog/adbc_table_catalog_entry.cpp b/adbc/src/catalog/adbc_table_catalog_entry.cpp
new file mode 100644
index 0000000..3120066
--- /dev/null
+++ b/adbc/src/catalog/adbc_table_catalog_entry.cpp
@@ -0,0 +1,59 @@
+#include "catalog/adbc_table_catalog_entry.h"
+
+#include "binder/bound_scan_source.h"
+#include "binder/expression/variable_expression.h"
+#include "common/constants.h"
+
+namespace lbug {
+namespace catalog {
+
+ADBCTableCatalogEntry::ADBCTableCatalogEntry(std::string name,
+    std::optional scanFunction,
+    std::shared_ptr scanInfo)
+    : TableCatalogEntry{CatalogEntryType::FOREIGN_TABLE_ENTRY, std::move(name)},
+      scanFunction{std::move(scanFunction)}, scanInfo{std::move(scanInfo)} {}
+
+common::TableType ADBCTableCatalogEntry::getTableType() const {
+    return common::TableType::FOREIGN;
+}
+
+std::unique_ptr ADBCTableCatalogEntry::getBoundScanInfo(
+ 
main::ClientContext* /*context*/, const std::string& nodeUniqueName) { + binder::expression_vector columns; + std::vector scanColumnNames; + std::vector scanColumnTypes; + if (!nodeUniqueName.empty()) { + auto idUniqueName = nodeUniqueName + "." + std::string(common::InternalKeyword::ID); + columns.push_back(std::make_shared(common::LogicalType::INT64(), + idUniqueName, scanInfo->columnNames[0])); + scanColumnNames.push_back(scanInfo->columnNames[0]); + scanColumnTypes.push_back(common::LogicalType::INT64()); + } + for (auto i = 0u; i < scanInfo->columnNames.size(); i++) { + auto uniqueName = nodeUniqueName.empty() ? scanInfo->columnNames[i] : + nodeUniqueName + "." + scanInfo->columnNames[i]; + columns.push_back(std::make_shared( + scanInfo->columnTypes[i].copy(), uniqueName, scanInfo->columnNames[i])); + scanColumnNames.push_back(scanInfo->columnNames[i]); + scanColumnTypes.push_back(scanInfo->columnTypes[i].copy()); + } + auto boundScanInfo = std::make_shared(scanInfo->tableName, + std::move(scanColumnNames), std::move(scanColumnTypes), scanInfo->connector); + auto bindData = std::make_unique(std::move(boundScanInfo), + std::move(columns)); + return std::make_unique(scanFunction, std::move(bindData)); +} + +std::unique_ptr ADBCTableCatalogEntry::copy() const { + auto other = std::make_unique(name, scanFunction, scanInfo); + other->copyFrom(*this); + return other; +} + +std::unique_ptr +ADBCTableCatalogEntry::getBoundExtraCreateInfo(transaction::Transaction*) const { + UNREACHABLE_CODE; +} + +} // namespace catalog +} // namespace lbug diff --git a/adbc/src/connector/CMakeLists.txt b/adbc/src/connector/CMakeLists.txt new file mode 100644 index 0000000..791e5ef --- /dev/null +++ b/adbc/src/connector/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_adbc_connector + OBJECT + adbc_connector.cpp) + +set(ADBC_EXTENSION_OBJECT_FILES + ${ADBC_EXTENSION_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/adbc/src/connector/adbc_connector.cpp 
b/adbc/src/connector/adbc_connector.cpp new file mode 100644 index 0000000..138dd18 --- /dev/null +++ b/adbc/src/connector/adbc_connector.cpp @@ -0,0 +1,235 @@ +#include "connector/adbc_connector.h" + +#include + +#include "common/arrow/arrow_converter.h" +#include "common/exception/runtime.h" +#include "common/string_utils.h" +#include + +namespace lbug { +namespace adbc_extension { + +static constexpr const char* DRIVER_OPTION = "DRIVER"; +static constexpr const char* TABLES_OPTION = "TABLES"; +static constexpr const char* SCHEMA_OPTION = "SCHEMA"; + +static std::vector splitCommaSeparated(const std::string& input) { + std::vector result; + std::stringstream ss{input}; + std::string item; + while (std::getline(ss, item, ',')) { + item = common::StringUtils::ltrim(common::StringUtils::rtrim(item)); + if (!item.empty()) { + result.push_back(std::move(item)); + } + } + return result; +} + +ADBCQueryResult::~ADBCQueryResult() { + if (stream.release) { + stream.release(&stream); + } + AdbcError error{}; + if (statement.private_data) { + AdbcStatementRelease(&statement, &error); + } + if (connectionInitialized) { + AdbcConnectionRelease(&connection, &error); + } + if (error.release) { + error.release(&error); + } +} + +ADBCConnector::ADBCConnector(const binder::AttachOption& attachOption) + : attachOption{attachOption} {} + +ADBCConnector::~ADBCConnector() { + if (connectionInitialized) { + AdbcConnectionRelease(&connection, &error); + } + if (databaseInitialized) { + AdbcDatabaseRelease(&database, &error); + } + if (error.release) { + error.release(&error); + } +} + +bool ADBCConnector::hasOption(const std::string& key) const { + return attachOption.options.contains(key); +} + +std::string ADBCConnector::getStringOption(const std::string& key, + const std::string& defaultValue) const { + if (!hasOption(key)) { + return defaultValue; + } + const auto& value = attachOption.options.at(key); + if (value.getDataType().getLogicalTypeID() != common::LogicalTypeID::STRING) { 
+        throw common::RuntimeException{std::format("Invalid option value for {}", key)};
+    }
+    return value.getValue();
+}
+
+// Turns a failed ADBC status code into a RuntimeException reading "<operation> failed: <text>",
+// where <text> is the message held in the connector's `error` member, or "unknown ADBC error" if
+// none was set. Returns normally when status is ADBC_STATUS_OK.
+void ADBCConnector::checkStatus(AdbcStatusCode status, const std::string& operation) const {
+    if (status == ADBC_STATUS_OK) {
+        return;
+    }
+    // Copy the text first: release() may free the buffer error.message points at.
+    std::string message = error.message == nullptr ? "unknown ADBC error" : error.message;
+    // Release here, as executeQuery's checkQueryStatus does, so this failure's message is neither
+    // kept allocated nor reported again by a later failure that sets no message of its own.
+    if (error.release) {
+        error.release(&error);
+    }
+    // Back to the zero-initialised state the member is declared with.
+    error = AdbcError{};
+    throw common::RuntimeException{std::format("{} failed: {}", operation, message)};
+}
+
+void ADBCConnector::connect(const std::string& uri) {
+    if (!hasOption(DRIVER_OPTION)) {
+        throw common::RuntimeException{"ADBC attach requires a DRIVER option."};
+    }
+    checkStatus(AdbcDatabaseNew(&database, &error), "AdbcDatabaseNew");
+    databaseInitialized = true;
+    auto driverName = getStringOption(DRIVER_OPTION);
+    checkStatus(AdbcDatabaseSetOption(&database, "driver", driverName.c_str(), &error),
+        "AdbcDatabaseSetOption(driver)");
+    checkStatus(AdbcDriverManagerDatabaseSetLoadFlags(&database, ADBC_LOAD_FLAG_DEFAULT, &error),
+        "AdbcDriverManagerDatabaseSetLoadFlags");
+    if (!uri.empty()) {
+        auto uriOption =
+            uri.find("://") != std::string::npos || uri.starts_with("file:") ? 
"uri" : "path"; + checkStatus(AdbcDatabaseSetOption(&database, uriOption, uri.c_str(), &error), + std::format("AdbcDatabaseSetOption({})", uriOption)); + } + for (auto& [key, value] : attachOption.options) { + auto upperKey = common::StringUtils::getUpper(key); + if (upperKey == DRIVER_OPTION || upperKey == TABLES_OPTION || upperKey == SCHEMA_OPTION) { + continue; + } + if (value.getDataType().getLogicalTypeID() != common::LogicalTypeID::STRING) { + throw common::RuntimeException{std::format("Invalid option value for {}", key)}; + } + auto stringValue = value.getValue(); + checkStatus(AdbcDatabaseSetOption(&database, key.c_str(), stringValue.c_str(), &error), + std::format("AdbcDatabaseSetOption({})", key)); + } + checkStatus(AdbcDatabaseInit(&database, &error), "AdbcDatabaseInit"); + checkStatus(AdbcConnectionNew(&connection, &error), "AdbcConnectionNew"); + connectionInitialized = true; + checkStatus(AdbcConnectionInit(&connection, &database, &error), "AdbcConnectionInit"); +} + +std::vector ADBCConnector::getTableNames() const { + auto tables = getStringOption(TABLES_OPTION); + if (tables.empty()) { + throw common::RuntimeException{ + "ADBC attach currently requires TABLES='table1,table2,...' for table discovery."}; + } + return splitCommaSeparated(tables); +} + +std::vector> ADBCConnector::getTableSchema( + const std::string& schemaName, const std::string& tableName) const { + std::lock_guard lock{mtx}; + if (schemaCache.contains(tableName)) { + std::vector> cachedResult; + cachedResult.reserve(schemaCache.at(tableName).size()); + for (auto& [name, type] : schemaCache.at(tableName)) { + cachedResult.emplace_back(name, type.copy()); + } + return cachedResult; + } + ArrowSchemaWrapper schema; + checkStatus(AdbcConnectionGetTableSchema(&connection, nullptr, + schemaName.empty() ? 
nullptr : schemaName.c_str(), tableName.c_str(), &schema, + &error), + std::format("AdbcConnectionGetTableSchema({})", tableName)); + std::vector> result; + result.reserve(schema.n_children); + for (auto i = 0; i < schema.n_children; i++) { + auto child = schema.children[i]; + result.emplace_back(child->name == nullptr ? std::format("column{}", i) : child->name, + common::ArrowConverter::fromArrowSchema(child)); + } + std::vector> cachedResult; + cachedResult.reserve(result.size()); + for (auto& [name, type] : result) { + cachedResult.emplace_back(name, type.copy()); + } + schemaCache.emplace(tableName, std::move(cachedResult)); + return result; +} + +std::unique_ptr ADBCConnector::executeQuery(const std::string& query, + const std::vector& /*columnNames*/, + const std::vector& /*columnTypes*/) const { + auto result = std::make_unique(); + AdbcError queryError{}; + auto checkQueryStatus = [&queryError](AdbcStatusCode status, const std::string& operation) { + if (status == ADBC_STATUS_OK) { + return; + } + std::string message = + queryError.message == nullptr ? 
"unknown ADBC error" : queryError.message; + if (queryError.release) { + queryError.release(&queryError); + } + throw common::RuntimeException{std::format("{} failed: {}", operation, message)}; + }; + { + std::lock_guard lock{mtx}; + checkQueryStatus(AdbcConnectionNew(&result->connection, &queryError), "AdbcConnectionNew"); + result->connectionInitialized = true; + checkQueryStatus(AdbcConnectionInit(&result->connection, + const_cast(&database), &queryError), + "AdbcConnectionInit"); + } + checkQueryStatus(AdbcStatementNew(&result->connection, &result->statement, &queryError), + "AdbcStatementNew"); + checkQueryStatus(AdbcStatementSetSqlQuery(&result->statement, query.c_str(), &queryError), + "AdbcStatementSetSqlQuery"); + int64_t rowsAffected = 0; + checkQueryStatus( + AdbcStatementExecuteQuery(&result->statement, &result->stream, &rowsAffected, &queryError), + "AdbcStatementExecuteQuery"); + if (result->stream.get_schema(&result->stream, &result->schema) != 0) { + auto streamError = result->stream.get_last_error == nullptr ? + "unknown Arrow stream error" : + result->stream.get_last_error(&result->stream); + if (queryError.release) { + queryError.release(&queryError); + } + throw common::RuntimeException{ + std::format("ArrowArrayStream.get_schema failed: {}", streamError)}; + } + // Fully consume the ADBC stream before returning control to the execution engine. Some drivers + // keep query state pending while the stream is open, and later catalog calls can close it. + while (true) { + ArrowArrayWrapper array; + if (result->stream.get_next(&result->stream, &array) != 0) { + auto streamError = result->stream.get_last_error == nullptr ? 
+ "unknown Arrow stream error" : + result->stream.get_last_error(&result->stream); + if (queryError.release) { + queryError.release(&queryError); + } + throw common::RuntimeException{ + std::format("ArrowArrayStream.get_next failed after {} batches: {}", + result->arrays.size(), streamError)}; + } + if (array.release == nullptr) { + break; + } + if (array.length > 0) { + result->arrays.push_back(std::move(array)); + } + } + if (result->stream.release) { + result->stream.release(&result->stream); + } + if (queryError.release) { + queryError.release(&queryError); + } + return result; +} + +} // namespace adbc_extension +} // namespace lbug diff --git a/adbc/src/function/CMakeLists.txt b/adbc/src/function/CMakeLists.txt new file mode 100644 index 0000000..4557d8d --- /dev/null +++ b/adbc/src/function/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(lbug_adbc_function + OBJECT + adbc_scan.cpp + ${PROJECT_SOURCE_DIR}/src/common/arrow/arrow_array_scan.cpp + ${PROJECT_SOURCE_DIR}/src/common/arrow/arrow_converter.cpp + ${PROJECT_SOURCE_DIR}/src/common/arrow/arrow_null_mask_tree.cpp + ${PROJECT_SOURCE_DIR}/src/common/arrow/arrow_type.cpp) + +set(ADBC_EXTENSION_OBJECT_FILES + ${ADBC_EXTENSION_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/adbc/src/function/adbc_scan.cpp b/adbc/src/function/adbc_scan.cpp new file mode 100644 index 0000000..aebe33f --- /dev/null +++ b/adbc/src/function/adbc_scan.cpp @@ -0,0 +1,148 @@ +#include "function/adbc_scan.h" + +#include + +#include "binder/binder.h" +#include "common/arrow/arrow_converter.h" +#include "common/constants.h" +#include "common/exception/runtime.h" +#include "function/table/bind_input.h" +#include "function/table/table_function.h" +#include + +namespace lbug { +namespace adbc_extension { + +static std::string quoteIdentifier(const std::string& value) { + std::string result = "\""; + for (auto ch : value) { + if (ch == '"') { + result += "\"\""; + } else { + result += ch; + } + } + result += "\""; + return result; +} + +static 
std::string joinColumns(const std::vector& columnNames) { + std::string result; + bool first = true; + for (auto& columnName : columnNames) { + if (!first) { + result += ", "; + } + result += quoteIdentifier(columnName); + first = false; + } + return result.empty() ? "*" : result; +} + +std::string ADBCScanBindData::getSQL() const { + auto sql = std::format("SELECT {} FROM {}", joinColumns(scanInfo->columnNames), + quoteIdentifier(scanInfo->tableName)); + if (getLimitNum() != common::INVALID_ROW_IDX) { + sql += std::format(" LIMIT {}", getLimitNum()); + } + return sql; +} + +struct ADBCScanFunction { + static constexpr char NAME[] = "adbc_scan"; + + static common::offset_t tableFunc(const function::TableFuncInput& input, + function::TableFuncOutput& output); + static std::unique_ptr bindFunc( + std::shared_ptr scanInfo, main::ClientContext* context, + const function::TableFuncBindInput* input); + static std::unique_ptr initSharedState( + const function::TableFuncInitSharedStateInput& input); + static std::unique_ptr initLocalState( + const function::TableFuncInitLocalStateInput& input); +}; + +struct ADBCScanLocalState final : function::TableFuncLocalState { + ArrowArrayWrapper array; + uint64_t offset = 0; +}; + +std::unique_ptr ADBCScanFunction::initSharedState( + const function::TableFuncInitSharedStateInput& input) { + auto bindData = input.bindData->constPtrCast(); + return std::make_unique(bindData->scanInfo->connector.executeQuery( + bindData->getSQL(), bindData->scanInfo->columnNames, bindData->scanInfo->columnTypes)); +} + +common::offset_t ADBCScanFunction::tableFunc(const function::TableFuncInput& input, + function::TableFuncOutput& output) { + auto sharedState = input.sharedState->ptrCast(); + auto localState = input.localState->ptrCast(); + auto bindData = input.bindData->constPtrCast(); + while (localState->array.release == nullptr || + localState->offset >= static_cast(localState->array.length)) { + localState->array = ArrowArrayWrapper{}; + 
localState->offset = 0;
+        {
+            std::lock_guard lock{sharedState->mtx};
+            if (sharedState->queryResult->nextArrayIdx >= sharedState->queryResult->arrays.size()) {
+                return 0;
+            }
+            localState->array = std::move(
+                sharedState->queryResult->arrays[sharedState->queryResult->nextArrayIdx++]);
+        }
+        if (localState->array.release == nullptr || localState->array.length == 0) {
+            return 0;
+        }
+    }
+    auto count = std::min(common::DEFAULT_VECTOR_CAPACITY,
+        static_cast(localState->array.length) - localState->offset);
+    auto& schema = sharedState->queryResult->schema;
+    for (auto i = 0u; i < bindData->scanInfo->columnNames.size(); i++) {
+        auto srcOffset = localState->array.children[i]->offset + localState->offset;
+        common::ArrowNullMaskTree mask(schema.children[i], localState->array.children[i], srcOffset,
+            count);
+        common::ArrowConverter::fromArrowArray(schema.children[i], localState->array.children[i],
+            output.dataChunk.getValueVectorMutable(i), &mask, srcOffset, 0, count);
+    }
+    localState->offset += count;
+    return count;
+}
+
+// Builds the bind data for a scan of scanInfo's table, restricted to the names returned by
+// extractYieldVariables. Each name's type is looked up by position in scanInfo->columnNames, and
+// the binder creates one variable per selected column.
+// Throws RuntimeException when a returned name is not one of scanInfo->columnNames.
+std::unique_ptr ADBCScanFunction::bindFunc(
+    std::shared_ptr scanInfo, main::ClientContext* /*context*/,
+    const function::TableFuncBindInput* input) {
+    auto columnNames = function::TableFunction::extractYieldVariables(scanInfo->columnNames,
+        input->yieldVariables);
+    std::vector columnTypes;
+    for (auto& columnName : columnNames) {
+        auto columnIt =
+            std::find(scanInfo->columnNames.begin(), scanInfo->columnNames.end(), columnName);
+        // An unmatched name would otherwise index columnTypes one past its last element.
+        if (columnIt == scanInfo->columnNames.end()) {
+            throw common::RuntimeException{std::format("Column {} does not exist in ADBC table {}.",
+                columnName, scanInfo->tableName)};
+        }
+        columnTypes.push_back(
+            scanInfo->columnTypes[columnIt - scanInfo->columnNames.begin()].copy());
+    }
+    auto columns = input->binder->createVariables(columnNames, columnTypes);
+    auto selectedScanInfo = std::make_shared(scanInfo->tableName,
+        std::move(columnNames), std::move(columnTypes), scanInfo->connector);
+    return std::make_unique(std::move(selectedScanInfo), std::move(columns));
+}
+
+std::unique_ptr ADBCScanFunction::initLocalState(
+    const function::TableFuncInitLocalStateInput&) {
+    return 
std::make_unique(); +} + +function::TableFunction getADBCScanFunction(std::shared_ptr scanInfo) { + auto function = + function::TableFunction(ADBCScanFunction::NAME, std::vector{}); + function.tableFunc = ADBCScanFunction::tableFunc; + function.bindFunc = std::bind(ADBCScanFunction::bindFunc, scanInfo, std::placeholders::_1, + std::placeholders::_2); + function.initSharedStateFunc = ADBCScanFunction::initSharedState; + function.initLocalStateFunc = ADBCScanFunction::initLocalState; + function.supportsPushDownFunc = [] { return false; }; + return function; +} + +} // namespace adbc_extension +} // namespace lbug diff --git a/adbc/src/include/catalog/adbc_catalog.h b/adbc/src/include/catalog/adbc_catalog.h new file mode 100644 index 0000000..2402cb8 --- /dev/null +++ b/adbc/src/include/catalog/adbc_catalog.h @@ -0,0 +1,28 @@ +#pragma once + +#include "binder/ddl/property_definition.h" +#include "connector/adbc_connector.h" +#include "extension/catalog_extension.h" + +namespace lbug { +namespace adbc_extension { + +class ADBCCatalog final : public extension::CatalogExtension { +public: + ADBCCatalog(std::string schemaName, main::ClientContext* context, + const ADBCConnector& connector) + : schemaName{std::move(schemaName)}, context{context}, connector{connector} {} + + void init() override; + +private: + void createForeignTable(const std::string& tableName); + +private: + std::string schemaName; + main::ClientContext* context; + const ADBCConnector& connector; +}; + +} // namespace adbc_extension +} // namespace lbug diff --git a/adbc/src/include/catalog/adbc_table_catalog_entry.h b/adbc/src/include/catalog/adbc_table_catalog_entry.h new file mode 100644 index 0000000..066bf17 --- /dev/null +++ b/adbc/src/include/catalog/adbc_table_catalog_entry.h @@ -0,0 +1,30 @@ +#pragma once + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "function/adbc_scan.h" + +namespace lbug { +namespace catalog { + +class ADBCTableCatalogEntry final : public 
TableCatalogEntry { +public: + ADBCTableCatalogEntry(std::string name, std::optional scanFunction, + std::shared_ptr scanInfo); + + common::TableType getTableType() const override; + std::optional getScanFunction() const override { return scanFunction; } + std::unique_ptr getBoundScanInfo(main::ClientContext* context, + const std::string& nodeUniqueName = "") override; + std::unique_ptr copy() const override; + +private: + std::unique_ptr getBoundExtraCreateInfo( + transaction::Transaction* transaction) const override; + +private: + std::optional scanFunction; + std::shared_ptr scanInfo; +}; + +} // namespace catalog +} // namespace lbug diff --git a/adbc/src/include/connector/adbc_connector.h b/adbc/src/include/connector/adbc_connector.h new file mode 100644 index 0000000..5117109 --- /dev/null +++ b/adbc/src/include/connector/adbc_connector.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include + +#include "binder/bound_attach_info.h" +#include "common/arrow/arrow.h" +#include "common/types/types.h" + +#if __has_include() && __has_include() +#include +#include +#else +#error "ADBC C API and driver manager headers not found" +#endif + +#ifndef ARROW_C_STREAM_INTERFACE +#define ARROW_C_STREAM_INTERFACE +struct ArrowArrayStream { + int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out); + int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out); + const char* (*get_last_error)(struct ArrowArrayStream*); + void (*release)(struct ArrowArrayStream*); + void* private_data; +}; +#endif + +namespace lbug { +namespace adbc_extension { + +struct ADBCQueryResult { + AdbcConnection connection{}; + AdbcStatement statement{}; + ArrowArrayStream stream{}; + ArrowSchemaWrapper schema{}; + std::vector arrays; + uint64_t nextArrayIdx = 0; + bool connectionInitialized = false; + + ADBCQueryResult() = default; + ADBCQueryResult(const ADBCQueryResult&) = delete; + ADBCQueryResult& operator=(const ADBCQueryResult&) = delete; + ~ADBCQueryResult(); +}; + +class 
ADBCConnector final { +public: + explicit ADBCConnector(const binder::AttachOption& attachOption); + ~ADBCConnector(); + + void connect(const std::string& uri); + + std::vector getTableNames() const; + std::vector> getTableSchema( + const std::string& schemaName, const std::string& tableName) const; + std::unique_ptr executeQuery(const std::string& query, + const std::vector& columnNames, + const std::vector& columnTypes) const; + +private: + std::string getStringOption(const std::string& key, const std::string& defaultValue = "") const; + bool hasOption(const std::string& key) const; + void checkStatus(AdbcStatusCode status, const std::string& operation) const; + +private: + const binder::AttachOption& attachOption; + mutable std::mutex mtx; + mutable AdbcError error{}; + AdbcDatabase database{}; + mutable AdbcConnection connection{}; + mutable std::unordered_map>> + schemaCache; + bool databaseInitialized = false; + bool connectionInitialized = false; +}; + +} // namespace adbc_extension +} // namespace lbug diff --git a/adbc/src/include/function/adbc_scan.h b/adbc/src/include/function/adbc_scan.h new file mode 100644 index 0000000..6a9545a --- /dev/null +++ b/adbc/src/include/function/adbc_scan.h @@ -0,0 +1,48 @@ +#pragma once + +#include "common/arrow/arrow.h" +#include "connector/adbc_connector.h" +#include "function/table/bind_data.h" +#include "function/table/table_function.h" + +namespace lbug { +namespace adbc_extension { + +struct ADBCTableScanInfo { + std::string tableName; + std::vector columnNames; + std::vector columnTypes; + const ADBCConnector& connector; + + ADBCTableScanInfo(std::string tableName, std::vector columnNames, + std::vector columnTypes, const ADBCConnector& connector) + : tableName{std::move(tableName)}, columnNames{std::move(columnNames)}, + columnTypes{std::move(columnTypes)}, connector{connector} {} +}; + +struct ADBCScanBindData final : function::TableFuncBindData { + std::shared_ptr scanInfo; + + ADBCScanBindData(std::shared_ptr 
scanInfo, binder::expression_vector columns) + : function::TableFuncBindData{std::move(columns), 0}, scanInfo{std::move(scanInfo)} {} + ADBCScanBindData(const ADBCScanBindData& other) + : function::TableFuncBindData{other}, scanInfo{other.scanInfo} {} + + std::string getSQL() const; + std::string getDescription() const override { return getSQL(); } + std::unique_ptr copy() const override { + return std::make_unique(*this); + } +}; + +struct ADBCScanSharedState final : function::TableFuncSharedState { + explicit ADBCScanSharedState(std::unique_ptr queryResult) + : queryResult{std::move(queryResult)} {} + + std::unique_ptr queryResult; +}; + +function::TableFunction getADBCScanFunction(std::shared_ptr scanInfo); + +} // namespace adbc_extension +} // namespace lbug diff --git a/adbc/src/include/main/adbc_extension.h b/adbc/src/include/main/adbc_extension.h new file mode 100644 index 0000000..b3672bb --- /dev/null +++ b/adbc/src/include/main/adbc_extension.h @@ -0,0 +1,17 @@ +#pragma once + +#include "extension/extension.h" + +namespace lbug { +namespace adbc_extension { + +class ADBCExtension final : public extension::Extension { +public: + static constexpr char EXTENSION_NAME[] = "ADBC"; + +public: + static void load(main::ClientContext* context); +}; + +} // namespace adbc_extension +} // namespace lbug diff --git a/adbc/src/include/storage/adbc_storage.h b/adbc/src/include/storage/adbc_storage.h new file mode 100644 index 0000000..a204cbe --- /dev/null +++ b/adbc/src/include/storage/adbc_storage.h @@ -0,0 +1,21 @@ +#pragma once + +#include "main/database.h" +#include "storage/storage_extension.h" + +namespace lbug { +namespace adbc_extension { + +class ADBCStorageExtension final : public storage::StorageExtension { +public: + static constexpr const char* DB_TYPE = "ADBC"; + + static constexpr const char* DEFAULT_SCHEMA_NAME = "main"; + + explicit ADBCStorageExtension(main::Database& database); + + bool canHandleDB(std::string dbType_) const override; +}; + +} // 
namespace adbc_extension +} // namespace lbug diff --git a/adbc/src/include/storage/attached_adbc_database.h b/adbc/src/include/storage/attached_adbc_database.h new file mode 100644 index 0000000..0182115 --- /dev/null +++ b/adbc/src/include/storage/attached_adbc_database.h @@ -0,0 +1,31 @@ +#pragma once + +#include "connector/adbc_connector.h" +#include "main/attached_database.h" + +namespace lbug { +namespace adbc_extension { + +class AttachedADBCDatabase final : public main::AttachedDatabase { +public: + AttachedADBCDatabase(std::string dbName, std::string dbType, + std::unique_ptr catalog, + std::unique_ptr connector) + : main::AttachedDatabase{std::move(dbName), std::move(dbType), std::move(catalog)}, + connector{std::move(connector)} {} + + std::vector getTableColumnNames(const std::string& tableName) const override { + std::vector result; + for (auto& [name, type] : connector->getTableSchema("", tableName)) { + (void)type; + result.push_back(name); + } + return result; + } + +private: + std::unique_ptr connector; +}; + +} // namespace adbc_extension +} // namespace lbug diff --git a/adbc/src/main/CMakeLists.txt b/adbc/src/main/CMakeLists.txt new file mode 100644 index 0000000..1d66628 --- /dev/null +++ b/adbc/src/main/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(adbc_extension_main + OBJECT + adbc_extension.cpp) + +set(ADBC_EXTENSION_OBJECT_FILES + ${ADBC_EXTENSION_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/adbc/src/main/adbc_extension.cpp b/adbc/src/main/adbc_extension.cpp new file mode 100644 index 0000000..89dfb14 --- /dev/null +++ b/adbc/src/main/adbc_extension.cpp @@ -0,0 +1,35 @@ +#include "main/adbc_extension.h" + +#include "main/client_context.h" +#include "main/database.h" +#include "storage/adbc_storage.h" + +namespace lbug { +namespace adbc_extension { + +void ADBCExtension::load(main::ClientContext* context) { + auto db = context->getDatabase(); + db->registerStorageExtension(EXTENSION_NAME, std::make_unique(*db)); +} + +} // namespace 
adbc_extension +} // namespace lbug + +#if defined(BUILD_DYNAMIC_LOAD) +extern "C" { +// Because we link against the static library on windows, we implicitly inherit LBUG_STATIC_DEFINE, +// which cancels out any exporting, so we can't use LBUG_API. +#if defined(_WIN32) +#define INIT_EXPORT __declspec(dllexport) +#else +#define INIT_EXPORT __attribute__((visibility("default"))) +#endif +INIT_EXPORT void init(lbug::main::ClientContext* context) { + lbug::adbc_extension::ADBCExtension::load(context); +} + +INIT_EXPORT const char* name() { + return lbug::adbc_extension::ADBCExtension::EXTENSION_NAME; +} +} +#endif diff --git a/adbc/src/storage/CMakeLists.txt b/adbc/src/storage/CMakeLists.txt new file mode 100644 index 0000000..fcedbe4 --- /dev/null +++ b/adbc/src/storage/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(lbug_adbc_storage + OBJECT + adbc_storage.cpp) + +set(ADBC_EXTENSION_OBJECT_FILES + ${ADBC_EXTENSION_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/adbc/src/storage/adbc_storage.cpp b/adbc/src/storage/adbc_storage.cpp new file mode 100644 index 0000000..ecbf9ec --- /dev/null +++ b/adbc/src/storage/adbc_storage.cpp @@ -0,0 +1,43 @@ +#include "storage/adbc_storage.h" + +#include "catalog/adbc_catalog.h" +#include "catalog/adbc_table_catalog_entry.h" +#include "common/exception/runtime.h" +#include "common/string_utils.h" +#include "connector/adbc_connector.h" +#include "storage/attached_adbc_database.h" + +namespace lbug { +namespace adbc_extension { + +std::unique_ptr attachADBC(std::string dbName, std::string dbPath, + main::ClientContext* clientContext, const binder::AttachOption& attachOption) { + if (dbName.empty()) { + dbName = "adbc"; + } + std::string schemaName = ADBCStorageExtension::DEFAULT_SCHEMA_NAME; + if (attachOption.options.contains("SCHEMA")) { + auto val = attachOption.options.at("SCHEMA"); + if (val.getDataType().getLogicalTypeID() != common::LogicalTypeID::STRING) { + throw common::RuntimeException{"Invalid option value for SCHEMA"}; + } + 
schemaName = val.getValue(); + } + auto connector = std::make_unique(attachOption); + connector->connect(dbPath); + auto catalog = std::make_unique(schemaName, clientContext, *connector); + catalog->init(); + return std::make_unique(dbName, ADBCStorageExtension::DB_TYPE, + std::move(catalog), std::move(connector)); +} + +ADBCStorageExtension::ADBCStorageExtension(main::Database& /*database*/) + : StorageExtension{attachADBC} {} + +bool ADBCStorageExtension::canHandleDB(std::string dbType_) const { + common::StringUtils::toUpper(dbType_); + return dbType_ == DB_TYPE; +} + +} // namespace adbc_extension +} // namespace lbug diff --git a/adbc/test/test_files/adbc.test b/adbc/test/test_files/adbc.test new file mode 100644 index 0000000..3324c1b --- /dev/null +++ b/adbc/test/test_files/adbc.test @@ -0,0 +1,17 @@ +-DATASET CSV empty + +-- + +-CASE ADBCDuckDBDriver +-LOAD_DYNAMIC_EXTENSION adbc +-STATEMENT ATTACH '${LBUG_ROOT_DIRECTORY}/extension/adbc/test/adbc.duckdb' AS adbc_duckdb (dbtype adbc, driver='duckdb', tables='games'); +---- 1 +Attached database successfully. +-STATEMENT LOAD FROM adbc_duckdb.games RETURN id, title, score ORDER BY id; +---- 3 +1|Portal|95 +2|Celeste|94 +3|Hades|93 +-STATEMENT DETACH adbc_duckdb; +---- 1 +Detached database successfully. diff --git a/extension_config.cmake b/extension_config.cmake index 65e38d7..b182655 100644 --- a/extension_config.cmake +++ b/extension_config.cmake @@ -1,4 +1,4 @@ -set(EXTENSION_LIST azure delta duckdb fts httpfs iceberg json llm postgres sqlite unity_catalog vector neo4j algo) +set(EXTENSION_LIST adbc azure delta duckdb fts httpfs iceberg json llm postgres sqlite unity_catalog vector neo4j algo) #set(EXTENSION_STATIC_LINK_LIST fts) foreach(extension IN LISTS EXTENSION_STATIC_LINK_LIST)