Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 120 additions & 29 deletions src/afs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,9 @@ struct SessionData {
dsa_pointer userName;
dsa_pointer password;
dsa_pointer clientAddress;
dsa_pointer query;
dsa_pointer selectQuery;
dsa_pointer updateQuery;
int64_t nUpdatedRecords;
SharedRingBufferData bufferData;
};

Expand Down Expand Up @@ -529,8 +531,10 @@ class WorkerProcessor : public Processor {
dsa_free(area_, session->userName);
if (DsaPointerIsValid(session->password))
dsa_free(area_, session->password);
if (DsaPointerIsValid(session->query))
dsa_free(area_, session->query);
if (DsaPointerIsValid(session->selectQuery))
dsa_free(area_, session->selectQuery);
if (DsaPointerIsValid(session->updateQuery))
dsa_free(area_, session->updateQuery);
SharedRingBuffer::free_data(&(session->bufferData), area_);
dshash_delete_entry(sessions_, session);
}
Expand Down Expand Up @@ -645,17 +649,20 @@ class Executor : public WorkerProcessor {

void signaled()
{
P("%s: %s: signaled: before: %d", Tag, tag_, session_->query);
P("signaled: before: %d", session_->query);
if (DsaPointerIsValid(session_->query))
P("%s: %s: signaled: before: %d/%d", Tag, tag_, session_->selectQuery, session_->updateQuery);
if (DsaPointerIsValid(session_->selectQuery))
{
execute();
select();
}
else if (DsaPointerIsValid(session_->updateQuery))
{
update();
}
else
{
Processor::signaled();
}
P("%s: %s: signaled: after: %d", Tag, tag_, session_->query);
P("%s: %s: signaled: after: %d/%d", Tag, tag_, session_->selectQuery, session_->updateQuery);
}

private:
Expand Down Expand Up @@ -844,26 +851,26 @@ class Executor : public WorkerProcessor {
return true;
}

void execute()
void select()
{
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": executing").c_str());
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": selecting").c_str());

PushActiveSnapshot(GetTransactionSnapshot());

LWLockAcquire(lock_, LW_EXCLUSIVE);
std::string query(
static_cast<const char*>(dsa_get_address(area_, session_->query)));
dsa_free(area_, session_->query);
session_->query = InvalidDsaPointer;
static_cast<const char*>(dsa_get_address(area_, session_->selectQuery)));
dsa_free(area_, session_->selectQuery);
session_->selectQuery = InvalidDsaPointer;
SetCurrentStatementStartTimestamp();
P("%s: %s: execute: %s", Tag, tag_, query.c_str());
P("%s: %s: select: %s", Tag, tag_, query.c_str());
auto result = SPI_execute(query.c_str(), true, 0);
LWLockRelease(lock_);

if (result == SPI_OK_SELECT)
{
pgstat_report_activity(STATE_RUNNING,
(std::string(Tag) + ": writing").c_str());
(std::string(Tag) + ": select: writing").c_str());
auto status = write();
if (!status.ok())
{
Expand All @@ -873,7 +880,7 @@ class Executor : public WorkerProcessor {
else
{
set_shared_string(session_->errorMessage,
std::string(Tag) + ": " + tag_ +
std::string(Tag) + ": " + tag_ + ": select" +
": failed to run a query: <" + query +
">: " + SPI_result_code_string(result));
}
Expand All @@ -882,7 +889,53 @@ class Executor : public WorkerProcessor {

if (sharedData_->serverPID != InvalidPid)
{
P("%s: %s: kill server: %d", Tag, tag_, sharedData_->serverPID);
P("%s: %s: select: kill server: %d", Tag, tag_, sharedData_->serverPID);
kill(sharedData_->serverPID, SIGUSR1);
}

pgstat_report_activity(STATE_IDLE, NULL);
}

void update()
{
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": updating").c_str());

PushActiveSnapshot(GetTransactionSnapshot());

LWLockAcquire(lock_, LW_EXCLUSIVE);
std::string query(
static_cast<const char*>(dsa_get_address(area_, session_->updateQuery)));
dsa_free(area_, session_->updateQuery);
session_->updateQuery = InvalidDsaPointer;
SetCurrentStatementStartTimestamp();
P("%s: %s: update: %s", Tag, tag_, query.c_str());
auto result = SPI_execute(query.c_str(), false, 0);
LWLockRelease(lock_);

switch (result)
{
case SPI_OK_INSERT:
case SPI_OK_DELETE:
case SPI_OK_UPDATE:
session_->nUpdatedRecords = SPI_processed;
break;
default:
set_shared_string(session_->errorMessage,
std::string(Tag) + ": " + tag_ + ": update" +
": failed to run a query: <" + query +
">: " + SPI_result_code_string(result));
break;
}

PopActiveSnapshot();

// TODO: Is this usage correct?
CommitTransactionCommand();
StartTransactionCommand();

if (sharedData_->serverPID != InvalidPid)
{
P("%s: %s: update: kill server: %d", Tag, tag_, sharedData_->serverPID);
kill(sharedData_->serverPID, SIGUSR1);
}

Expand Down Expand Up @@ -1128,48 +1181,75 @@ class Proxy : public WorkerProcessor {
}
}

arrow::Result<std::shared_ptr<arrow::Schema>> execute(uint64_t sessionID,
const std::string& query)
arrow::Result<std::shared_ptr<arrow::Schema>> select(uint64_t sessionID,
const std::string& query)
{
auto session = find_session(sessionID);
SessionReleaser sessionReleaser(sessions_, session);
set_shared_string(session->query, query);
set_shared_string(session->selectQuery, query);
if (session->executorPID != InvalidPid)
{
P("%s: %s: execute: kill executor: %d", Tag, tag_, session->executorPID);
P("%s: %s: select: kill executor: %d", Tag, tag_, session->executorPID);
kill(session->executorPID, SIGUSR1);
}
{
auto buffer = std::move(create_shared_ring_buffer(session));
std::unique_lock<std::mutex> lock(mutex_);
conditionVariable_.wait(lock, [&] {
P("%s: %s: %s: wait: execute", Tag, tag_, AFS_FUNC);
P("%s: %s: %s: wait: select", Tag, tag_, AFS_FUNC);
return DsaPointerIsValid(session->errorMessage) || buffer.size() > 0;
});
}
if (DsaPointerIsValid(session->errorMessage))
{
return report_session_error(session);
}
P("%s: %s: execute: open", Tag, tag_);
P("%s: %s: select: open", Tag, tag_);
auto input = std::make_shared<SharedRingBufferInputStream>(this, session);
// Read schema only stream format data.
ARROW_ASSIGN_OR_RAISE(auto reader,
arrow::ipc::RecordBatchStreamReader::Open(input));
while (true)
{
std::shared_ptr<arrow::RecordBatch> recordBatch;
P("%s: %s: execute: read next", Tag, tag_);
P("%s: %s: select: read next", Tag, tag_);
ARROW_RETURN_NOT_OK(reader->ReadNext(&recordBatch));
if (!recordBatch)
{
break;
}
}
P("%s: %s: execute: schema", Tag, tag_);
P("%s: %s: select: schema", Tag, tag_);
return reader->schema();
}

arrow::Result<int64_t> update(uint64_t sessionID, const std::string& query)
{
auto session = find_session(sessionID);
SessionReleaser sessionReleaser(sessions_, session);
set_shared_string(session->updateQuery, query);
session->nUpdatedRecords = -1;
if (session->executorPID != InvalidPid)
{
P("%s: %s: update: kill executor: %d",
Tag, tag_, session->executorPID);
kill(session->executorPID, SIGUSR1);
}
{
std::unique_lock<std::mutex> lock(mutex_);
conditionVariable_.wait(lock, [&] {
P("%s: %s: %s: wait: update", Tag, tag_, AFS_FUNC);
return DsaPointerIsValid(session->errorMessage) || session->nUpdatedRecords >= 0;
});
}
if (DsaPointerIsValid(session->errorMessage))
{
return report_session_error(session);
}
P("%s: %s: update: done: %ld", Tag, tag_, session->nUpdatedRecords);
return session->nUpdatedRecords;
}

arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> read(uint64_t sessionID)
{
auto session = find_session(sessionID);
Expand Down Expand Up @@ -1210,7 +1290,9 @@ class Proxy : public WorkerProcessor {
set_shared_string(session->userName, userName);
set_shared_string(session->password, password);
set_shared_string(session->clientAddress, clientAddress);
session->query = InvalidDsaPointer;
session->selectQuery = InvalidDsaPointer;
session->updateQuery = InvalidDsaPointer;
session->nUpdatedRecords = -1;
SharedRingBuffer::initialize_data(&(session->bufferData));
LWLockRelease(lock_);
return session;
Expand Down Expand Up @@ -1507,11 +1589,11 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQuery& command,
const arrow::flight::FlightDescriptor& descriptor)
const arrow::flight::FlightDescriptor& descriptor) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = command.query;
ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->execute(sessionID, query));
ARROW_ASSIGN_OR_RAISE(auto schema, proxy_->select(sessionID, query));
ARROW_ASSIGN_OR_RAISE(auto ticket,
arrow::flight::sql::CreateStatementQueryTicket(query));
std::vector<arrow::flight::FlightEndpoint> endpoints{
Expand All @@ -1524,13 +1606,22 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {

arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementQueryTicket& command)
const arrow::flight::sql::StatementQueryTicket& command) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read(sessionID));
return std::make_unique<arrow::flight::RecordBatchStream>(reader);
}

arrow::Result<int64_t> DoPutCommandStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::StatementUpdate& command) override
{
ARROW_ASSIGN_OR_RAISE(auto sessionID, session_id(context));
const auto& query = command.query;
return proxy_->update(sessionID, query);
}

private:
arrow::Result<uint64_t> session_id(const arrow::flight::ServerCallContext& context)
{
Expand Down
23 changes: 23 additions & 0 deletions test/test-flight-sql.rb
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,27 @@ def test_select_from
assert_equal(Arrow::Table.new(value: Arrow::Int32Array.new([1, -2, 3])),
reader.read_all)
end

def test_isnert_int32
unless filght_sql_client.respond_to?(:execute_update)
omit("red-arrow-flight-sql 13.0.0 or later is required")
end

run_sql("CREATE TABLE data (value integer)")

n_changed_records = flight_sql_client.execute_update(
"INSERT INTO data VALUES (1), (-2), (3)",
@options)
assert_equal(3, n_changed_records)
assert_equal([<<-RESULT, ""], run_sql("SELECT * FROM data"))
SELECT * FROM data
value
-------
1
-2
3
(3 rows)

RESULT
end
end