diff --git a/src/afs.cc b/src/afs.cc index 3cd0009..8f0e778 100644 --- a/src/afs.cc +++ b/src/afs.cc @@ -335,7 +335,9 @@ struct SessionData { dsa_pointer userName; dsa_pointer password; dsa_pointer clientAddress; - dsa_pointer query; + dsa_pointer selectQuery; + dsa_pointer updateQuery; + int64_t nUpdatedRecords; SharedRingBufferData bufferData; }; @@ -529,8 +531,10 @@ class WorkerProcessor : public Processor { dsa_free(area_, session->userName); if (DsaPointerIsValid(session->password)) dsa_free(area_, session->password); - if (DsaPointerIsValid(session->query)) - dsa_free(area_, session->query); + if (DsaPointerIsValid(session->selectQuery)) + dsa_free(area_, session->selectQuery); + if (DsaPointerIsValid(session->updateQuery)) + dsa_free(area_, session->updateQuery); SharedRingBuffer::free_data(&(session->bufferData), area_); dshash_delete_entry(sessions_, session); } @@ -645,17 +649,20 @@ class Executor : public WorkerProcessor { void signaled() { - P("%s: %s: signaled: before: %d", Tag, tag_, session_->query); - P("signaled: before: %d", session_->query); - if (DsaPointerIsValid(session_->query)) + P("%s: %s: signaled: before: %d/%d", Tag, tag_, session_->selectQuery, session_->updateQuery); + if (DsaPointerIsValid(session_->selectQuery)) { - execute(); + select(); + } + else if (DsaPointerIsValid(session_->updateQuery)) + { + update(); } else { Processor::signaled(); } - P("%s: %s: signaled: after: %d", Tag, tag_, session_->query); + P("%s: %s: signaled: after: %d/%d", Tag, tag_, session_->selectQuery, session_->updateQuery); } private: @@ -844,26 +851,26 @@ class Executor : public WorkerProcessor { return true; } - void execute() + void select() { - pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": executing").c_str()); + pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": selecting").c_str()); PushActiveSnapshot(GetTransactionSnapshot()); LWLockAcquire(lock_, LW_EXCLUSIVE); std::string query( - static_cast(dsa_get_address(area_, session_->query))); - dsa_free(area_, session_->query); - session_->query = InvalidDsaPointer; + static_cast(dsa_get_address(area_, session_->selectQuery))); + dsa_free(area_, session_->selectQuery); + session_->selectQuery = InvalidDsaPointer; SetCurrentStatementStartTimestamp(); - P("%s: %s: execute: %s", Tag, tag_, query.c_str()); + P("%s: %s: select: %s", Tag, tag_, query.c_str()); auto result = SPI_execute(query.c_str(), true, 0); LWLockRelease(lock_); if (result == SPI_OK_SELECT) { pgstat_report_activity(STATE_RUNNING, - (std::string(Tag) + ": writing").c_str()); + (std::string(Tag) + ": select: writing").c_str()); auto status = write(); if (!status.ok()) { @@ -873,7 +880,7 @@ class Executor : public WorkerProcessor { else { set_shared_string(session_->errorMessage, - std::string(Tag) + ": " + tag_ + + std::string(Tag) + ": " + tag_ + ": select" + ": failed to run a query: <" + query + ">: " + SPI_result_code_string(result)); } @@ -882,7 +889,53 @@ class Executor : public WorkerProcessor { if (sharedData_->serverPID != InvalidPid) { - P("%s: %s: kill server: %d", Tag, tag_, sharedData_->serverPID); + P("%s: %s: select: kill server: %d", Tag, tag_, sharedData_->serverPID); + kill(sharedData_->serverPID, SIGUSR1); + } + + pgstat_report_activity(STATE_IDLE, NULL); + } + + void update() + { + pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str()); + + PushActiveSnapshot(GetTransactionSnapshot()); + + LWLockAcquire(lock_, LW_EXCLUSIVE); + std::string query( + static_cast(dsa_get_address(area_, session_->updateQuery))); + dsa_free(area_, session_->updateQuery); + session_->updateQuery = InvalidDsaPointer; + SetCurrentStatementStartTimestamp(); + P("%s: %s: update: %s", Tag, tag_, query.c_str()); + auto result = SPI_execute(query.c_str(), false, 0); + LWLockRelease(lock_); + + switch (result) + { + case SPI_OK_INSERT: + case SPI_OK_DELETE: + case SPI_OK_UPDATE: + session_->nUpdatedRecords = SPI_processed; + break; + default: + set_shared_string(session_->errorMessage, + std::string(Tag) + ": " + tag_ + ": update" + + ": failed to run a query: <" + query + + ">: " + SPI_result_code_string(result)); + break; + } + + PopActiveSnapshot(); + + // TODO: Is this usage correct? + CommitTransactionCommand(); + StartTransactionCommand(); + + if (sharedData_->serverPID != InvalidPid) + { + P("%s: %s: update: kill server: %d", Tag, tag_, sharedData_->serverPID); kill(sharedData_->serverPID, SIGUSR1); } @@ -1128,22 +1181,22 @@ class Proxy : public WorkerProcessor { } } - arrow::Result> execute(uint64_t sessionID, - const std::string& query) + arrow::Result> select(uint64_t sessionID, + const std::string& query) { auto session = find_session(sessionID); SessionReleaser sessionReleaser(sessions_, session); - set_shared_string(session->query, query); + set_shared_string(session->selectQuery, query); if (session->executorPID != InvalidPid) { - P("%s: %s: execute: kill executor: %d", Tag, tag_, session->executorPID); + P("%s: %s: select: kill executor: %d", Tag, tag_, session->executorPID); kill(session->executorPID, SIGUSR1); } { auto buffer = std::move(create_shared_ring_buffer(session)); std::unique_lock lock(mutex_); conditionVariable_.wait(lock, [&] { - P("%s: %s: %s: wait: execute", Tag, tag_, AFS_FUNC); + P("%s: %s: %s: wait: select", Tag, tag_, AFS_FUNC); return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0; }); } @@ -1151,7 +1204,7 @@ class Proxy : public WorkerProcessor { { return report_session_error(session); } - P("%s: %s: execute: open", Tag, tag_); + P("%s: %s: select: open", Tag, tag_); auto input = std::make_shared(this, session); // Read schema only stream format data. ARROW_ASSIGN_OR_RAISE(auto reader, @@ -1159,17 +1212,44 @@ class Proxy : public WorkerProcessor { while (true) { std::shared_ptr recordBatch; - P("%s: %s: execute: read next", Tag, tag_); + P("%s: %s: select: read next", Tag, tag_); ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch)); if (!recordBatch) { break; } } - P("%s: %s: execute: schema", Tag, tag_); + P("%s: %s: select: schema", Tag, tag_); return reader->schema(); } + arrow::Result update(uint64_t sessionID, const std::string& query) + { + auto session = find_session(sessionID); + SessionReleaser sessionReleaser(sessions_, session); + set_shared_string(session->updateQuery, query); + session->nUpdatedRecords = -1; + if (session->executorPID != InvalidPid) + { + P("%s: %s: update: kill executor: %d", + Tag, tag_, session->executorPID); + kill(session->executorPID, SIGUSR1); + } + { + std::unique_lock lock(mutex_); + conditionVariable_.wait(lock, [&] { + P("%s: %s: %s: wait: update", Tag, tag_, AFS_FUNC); + return DsaPointerIsValid(session->errorMessage) || session->nUpdatedRecords >= 0; + }); + } + if (DsaPointerIsValid(session->errorMessage)) + { + return report_session_error(session); + } + P("%s: %s: update: done: %ld", Tag, tag_, session->nUpdatedRecords); + return session->nUpdatedRecords; + } + arrow::Result> read(uint64_t sessionID) { auto session = find_session(sessionID); @@ -1210,7 +1290,9 @@ class Proxy : public WorkerProcessor { set_shared_string(session->userName, userName); set_shared_string(session->password, password); set_shared_string(session->clientAddress, clientAddress); - session->query = InvalidDsaPointer; + session->selectQuery = InvalidDsaPointer; + session->updateQuery = InvalidDsaPointer; + session->nUpdatedRecords = -1; SharedRingBuffer::initialize_data(&(session->bufferData)); LWLockRelease(lock_); return session; @@ -1507,11 +1589,11 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase { arrow::Result> GetFlightInfoStatement( const arrow::flight::ServerCallContext& context, const arrow::flight::sql::StatementQuery& command, - const arrow::flight::FlightDescriptor& descriptor) + const arrow::flight::FlightDescriptor& descriptor) override { ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); const auto& query = command.query; - ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->execute(sessionID, query)); + ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->select(sessionID, query)); ARROW_ASSIGN_OR_RAISE(auto ticket, arrow::flight::sql::CreateStatementQueryTicket(query)); std::vector endpoints{ @@ -1524,13 +1606,22 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase { arrow::Result> DoGetStatement( const arrow::flight::ServerCallContext& context, - const arrow::flight::sql::StatementQueryTicket& command) + const arrow::flight::sql::StatementQueryTicket& command) override { ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read(sessionID)); return std::make_unique(reader); } + arrow::Result DoPutCommandStatementUpdate( + const arrow::flight::ServerCallContext& context, + const arrow::flight::sql::StatementUpdate& command) override + { + ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context)); + const auto& query = command.query; + return proxy_->update(sessionID, query); + } + private: arrow::Result session_id(const arrow::flight::ServerCallContext& context) { diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb index 3e0992d..6997460 100644 --- a/test/test-flight-sql.rb +++ b/test/test-flight-sql.rb @@ -51,4 +51,27 @@ def test_select_from assert_equal(Arrow::Table.new(value: Arrow::Int32Array.new([1, -2, 3])), reader.read_all) end + + def test_isnert_int32 + unless filght_sql_client.respond_to?(:execute_update) + omit("red-arrow-flight-sql 13.0.0 or later is required") + end + + run_sql("CREATE TABLE data (value integer)") + + n_changed_records = flight_sql_client.execute_update( + "INSERT INTO data VALUES (1), (-2), (3)", + @options) + assert_equal(3, n_changed_records) + assert_equal([<<-RESULT, ""], run_sql("SELECT * FROM data")) +SELECT * FROM data + value +------- + 1 + -2 + 3 +(3 rows) + + RESULT + end end