Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 244 additions & 25 deletions src/afs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,20 @@ extern "C"
#include <utils/wait_event.h>
}

#undef Abs

#include <arrow/buffer.h>
#include <arrow/builder.h>
#include <arrow/flight/server_middleware.h>
#include <arrow/flight/sql/server.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/reader.h>
#include <arrow/ipc/writer.h>
#include <arrow/table_builder.h>
#include <arrow/util/base64.h>

#include <condition_variable>
#include <sstream>

extern "C"
{
Expand Down Expand Up @@ -97,13 +108,25 @@ struct ConnectData {
dsa_pointer password;
};

// Descriptor of a byte buffer that lives in the DSA area.
struct Buffer {
// DSA pointer to the buffer; InvalidDsaPointer while nothing is allocated.
dsa_pointer data;
// Size in bytes that was requested from dsa_allocate() for `data`.
size_t total;
// Number of bytes at the start of `data` that hold a serialized result;
// 0 means no result is available.
size_t used;
};

// Shared state exchanged for one query execution.
struct ExecuteData {
// DSA pointer to the NUL-terminated query text to run.
dsa_pointer query;
// Buffer that receives the serialized result of the query.
Buffer buffer;
};

// State shared between the cooperating processes.
struct SharedData {
// NOTE(review): presumably the handle used to attach to the DSA area that
// backs every dsa_pointer below -- confirm against the attach code.
dsa_handle handle;
// NOTE(review): presumably the LWLock guarding this struct -- confirm.
LWLock* lock;
// PID that receives SIGUSR1 when a query has been published;
// InvalidPid means there is no such process.
pid_t executorPID;
// PID that receives SIGUSR1 when a query has finished executing;
// InvalidPid means there is no such process.
pid_t serverPID;
// PID that receives SIGUSR1 when connection parameters have been published.
pid_t mainPID;
// Connection parameters (database name, user name, password).
ConnectData connectData;
// Query text and result buffer.
ExecuteData executeData;
};

class Processor {
Expand Down Expand Up @@ -151,10 +174,17 @@ class Executor : public WorkerProcessor {
BackgroundWorkerInitializeConnection(
static_cast<const char*>(
dsa_get_address(area_, sharedData_->connectData.databaseName)),
nullptr,
static_cast<const char*>(
dsa_get_address(area_, sharedData_->connectData.userName)),
0);
dsa_free(area_, sharedData_->connectData.databaseName);
sharedData_->connectData.databaseName = InvalidDsaPointer;
unsetSharedString(sharedData_->connectData.databaseName);
unsetSharedString(sharedData_->connectData.userName);
unsetSharedString(sharedData_->connectData.password);
// TODO: Customizable.
sharedData_->executeData.buffer.total = 1L * 1024L;
sharedData_->executeData.buffer.data =
dsa_allocate(area_, sharedData_->executeData.buffer.total);
sharedData_->executeData.buffer.used = 0;
LWLockRelease(lock_);
StartTransactionCommand();
SPI_connect();
Expand All @@ -168,29 +198,161 @@ class Executor : public WorkerProcessor {
PopActiveSnapshot();
SPI_finish();
CommitTransactionCommand();
LWLockAcquire(lock_, LW_EXCLUSIVE);
dsa_free(area_, sharedData_->executeData.buffer.data);
sharedData_->executeData.buffer.data = InvalidDsaPointer;
sharedData_->executeData.buffer.total = 0;
sharedData_->executeData.buffer.used = 0;
LWLockRelease(lock_);
pgstat_report_activity(STATE_IDLE, NULL);
}

void execute()
{
pgstat_report_activity(STATE_RUNNING, (std::string(Tag) + ": executing").c_str());
LWLockAcquire(lock_, LW_EXCLUSIVE);
auto query = static_cast<const char*>(
dsa_get_address(area_, sharedData_->executeData.query));
SetCurrentStatementStartTimestamp();
auto result = SPI_execute(query, true, 0);
dsa_free(area_, sharedData_->executeData.query);
sharedData_->executeData.query = InvalidDsaPointer;
if (result == SPI_OK_SELECT)
{
auto bufferResult = write();
if (bufferResult.ok())
{
auto buffer = *bufferResult;
auto output =
dsa_get_address(area_, sharedData_->executeData.buffer.data);
memcpy(output, buffer->data(), buffer->size());
sharedData_->executeData.buffer.used = buffer->size();
}
}
LWLockRelease(lock_);
if (sharedData_->serverPID != InvalidPid)
{
kill(sharedData_->serverPID, SIGUSR1);
}
pgstat_report_activity(STATE_IDLE, NULL);
}

void execute() {}
private:
// Releases the DSA allocation referenced by `pointer` and resets it to
// InvalidDsaPointer. Does nothing when `pointer` is already invalid.
void unsetSharedString(dsa_pointer& pointer)
{
    if (DsaPointerIsValid(pointer))
    {
        dsa_free(area_, pointer);
        pointer = InvalidDsaPointer;
    }
}

// Serializes the current SPI result set (SPI_tuptable / SPI_processed) as
// an Arrow IPC stream holding a single record batch.
//
// Only INT4 columns are supported; any other column type yields
// Status::NotImplemented. Returns the finished in-memory buffer.
arrow::Result<std::shared_ptr<arrow::Buffer>> write()
{
    if (!SPI_tuptable)
    {
        return arrow::Status::Invalid("No SPI result to serialize");
    }
    auto tupleDesc = SPI_tuptable->tupdesc;
    const int nAttributes = tupleDesc->natts;

    ARROW_ASSIGN_OR_RAISE(auto output, arrow::io::BufferOutputStream::Create());
    std::vector<std::shared_ptr<arrow::Field>> fields;
    fields.reserve(nAttributes);
    for (int i = 0; i < nAttributes; ++i)
    {
        auto attribute = TupleDescAttr(tupleDesc, i);
        std::shared_ptr<arrow::DataType> type;
        switch (attribute->atttypid)
        {
            case INT4OID:
                type = arrow::int32();
                break;
            default:
                return arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
                                                     attribute->atttypid);
        }
        fields.push_back(
            arrow::field(NameStr(attribute->attname), type, !attribute->attnotnull));
    }
    auto schema = arrow::schema(fields);
    auto option = arrow::ipc::IpcWriteOptions::Defaults();
    option.emit_dictionary_deltas = true;
    ARROW_ASSIGN_OR_RAISE(auto writer,
                          arrow::ipc::MakeStreamWriter(output, schema, option));
    ARROW_ASSIGN_OR_RAISE(
        auto builder,
        arrow::RecordBatchBuilder::Make(schema, arrow::default_memory_pool()));
    for (uint64_t iTuple = 0; iTuple < SPI_processed; ++iTuple)
    {
        // Iterate over the columns of the row. SPITupleTable::numvals is
        // the number of rows, so it must not be used as the column bound.
        for (int iAttribute = 0; iAttribute < nAttributes; ++iAttribute)
        {
            bool isNull;
            // SPI attribute numbers are 1-based.
            auto datum = SPI_getbinval(
                SPI_tuptable->vals[iTuple], tupleDesc, iAttribute + 1, &isNull);
            if (isNull)
            {
                auto arrayBuilder = builder->GetField(iAttribute);
                ARROW_RETURN_NOT_OK(arrayBuilder->AppendNull());
            }
            else
            {
                auto arrayBuilder =
                    builder->GetFieldAs<arrow::Int32Builder>(iAttribute);
                ARROW_RETURN_NOT_OK(arrayBuilder->Append(DatumGetInt32(datum)));
            }
        }
    }
    ARROW_ASSIGN_OR_RAISE(auto recordBatch, builder->Flush());
    ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
    ARROW_RETURN_NOT_OK(writer->Close());
    return output->Finish();
}
};

class Proxy : public WorkerProcessor {
public:
explicit Proxy() : WorkerProcessor("proxy") {}

void connect(const std::string& databaseName)
arrow::Status connect(const std::string& databaseName,
const std::string& userName,
const std::string& password)
{
if (sharedData_->executorPID != InvalidPid)
{
return arrow::Status::OK();
}
LWLockAcquire(lock_, LW_EXCLUSIVE);
sharedData_->connectData.databaseName =
dsa_allocate(area_, databaseName.size() + 1);
memcpy(dsa_get_address(area_, sharedData_->connectData.databaseName),
databaseName.c_str(),
databaseName.size() + 1);
setSharedString(sharedData_->connectData.databaseName, databaseName);
setSharedString(sharedData_->connectData.userName, userName);
setSharedString(sharedData_->connectData.password, password);
LWLockRelease(lock_);
kill(sharedData_->mainPID, SIGUSR1);
std::unique_lock<std::mutex> lock(mutex_);
condition_variable_.wait(lock,
[&] { return sharedData_->executorPID != InvalidPid; });
return arrow::Status::OK();
}

// Publishes `query` in shared memory, signals the executor (SIGUSR1) and
// blocks until a result has been written to the shared buffer, then
// returns a reader over that result.
//
// Returns Status::Invalid for an empty query or when no executor is
// registered; in both cases nothing would ever complete the wait.
arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> execute(
    const std::string& query)
{
    if (query.empty())
    {
        return arrow::Status::Invalid("Query is empty");
    }
    if (sharedData_->executorPID == InvalidPid)
    {
        return arrow::Status::Invalid("Executor isn't running");
    }
    LWLockAcquire(lock_, LW_EXCLUSIVE);
    // Discard the previous result; otherwise the wait below is satisfied
    // immediately by stale data from an earlier query.
    sharedData_->executeData.buffer.used = 0;
    setSharedString(sharedData_->executeData.query, query);
    LWLockRelease(lock_);
    kill(sharedData_->executorPID, SIGUSR1);
    std::unique_lock<std::mutex> lock(mutex_);
    condition_variable_.wait(
        lock, [&] { return sharedData_->executeData.buffer.used != 0; });
    return read();
}

// Opens an Arrow IPC stream reader over the bytes currently stored in the
// shared result buffer (the first `used` bytes of `data`). The bytes are
// referenced in place, not copied.
arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> read()
{
    const auto& resultBuffer = sharedData_->executeData.buffer;
    auto address =
        static_cast<const uint8_t*>(dsa_get_address(area_, resultBuffer.data));
    auto input = std::make_shared<arrow::io::BufferReader>(address, resultBuffer.used);
    return arrow::ipc::RecordBatchStreamReader::Open(input);
}

void signaled()
Expand All @@ -200,6 +362,16 @@ class Proxy : public WorkerProcessor {
}

private:
// Copies `input`, including its terminating NUL, into a fresh DSA
// allocation and stores the allocation in `pointer`. An empty `input`
// leaves `pointer` untouched.
void setSharedString(dsa_pointer& pointer, const std::string& input)
{
    if (!input.empty())
    {
        const auto nBytes = input.size() + 1;
        pointer = dsa_allocate(area_, nBytes);
        memcpy(dsa_get_address(area_, pointer), input.c_str(), nBytes);
    }
}

std::mutex mutex_;
std::condition_variable condition_variable_;
};
Expand All @@ -225,6 +397,9 @@ class MainProcessor : public Processor {
sharedData->connectData.databaseName = InvalidDsaPointer;
sharedData->connectData.userName = InvalidDsaPointer;
sharedData->connectData.password = InvalidDsaPointer;
sharedData->executeData.buffer.data = InvalidDsaPointer;
sharedData->executeData.buffer.total = 0;
sharedData->executeData.buffer.used = 0;
lock_ = &(GetNamedLWLockTranche(LWLockTrancheName)[0].lock);
LWLockRelease(AddinShmemInitLock);
sharedData_ = sharedData;
Expand Down Expand Up @@ -279,26 +454,43 @@ class MainProcessor : public Processor {
}
};

class AuthHandler : public arrow::flight::ServerAuthHandler {
class HeaderAuthServerMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory {
public:
explicit AuthHandler(Proxy* proxy) : arrow::flight::ServerAuthHandler(), proxy_(proxy)
explicit HeaderAuthServerMiddlewareFactory(Proxy* proxy)
: arrow::flight::ServerMiddlewareFactory(), proxy_(proxy)
{
}

~AuthHandler() override {}

// Called at the start of every Flight call. Extracts the target database
// and the credentials from the request headers and connects the proxy.
//
// Headers:
//   x-flight-sql-database  database name (defaults to "postgres")
//   authorization          "Basic <base64(user:password)>"; a value
//                          without the "Basic " prefix is decoded as is
//
// Returns OK on success, otherwise an Unauthenticated Flight error that
// carries the connect failure. `middleware` is left unset.
arrow::Status StartCall(const arrow::flight::CallInfo& info,
                        const arrow::flight::CallHeaders& incoming_headers,
                        std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override
{
    std::string databaseName("postgres");
    auto databaseHeader = incoming_headers.find("x-flight-sql-database");
    if (databaseHeader != incoming_headers.end())
    {
        databaseName = databaseHeader->second;
    }
    std::string userName;
    std::string password;
    auto authorizationHeader = incoming_headers.find("authorization");
    if (authorizationHeader != incoming_headers.end())
    {
        std::string credentials;
        credentials = authorizationHeader->second;
        // Strip the authentication scheme; only the part after it is
        // base64-encoded.
        const std::string basicPrefix("Basic ");
        if (credentials.compare(0, basicPrefix.size(), basicPrefix) == 0)
        {
            credentials.erase(0, basicPrefix.size());
        }
        const std::string decoded = arrow::util::base64_decode(credentials);
        // Split on the first ':' only; the password may contain ':'.
        const auto separator = decoded.find(':');
        userName = decoded.substr(0, separator);
        if (separator != std::string::npos)
        {
            password = decoded.substr(separator + 1);
        }
    }
    auto status = proxy_->connect(databaseName, userName, password);
    if (status.ok())
    {
        return status;
    }
    return arrow::flight::MakeFlightError(
        arrow::flight::FlightStatusCode::Unauthenticated, status.ToString());
}

private:
Expand All @@ -314,6 +506,32 @@ class FlightSQLServer : public arrow::flight::sql::FlightSqlServerBase {

~FlightSQLServer() override {}

// Executes the statement through the proxy and describes the result: the
// returned FlightInfo carries the result schema and one endpoint whose
// ticket encodes the query. Record and byte counts are reported as
// unknown (-1).
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
    const arrow::flight::ServerCallContext& context,
    const arrow::flight::sql::StatementQuery& command,
    const arrow::flight::FlightDescriptor& descriptor) override
{
    const auto& query = command.query;
    ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->execute(query));
    auto schema = reader->schema();
    ARROW_ASSIGN_OR_RAISE(auto ticket,
                          arrow::flight::sql::CreateStatementQueryTicket(query));
    std::vector<arrow::flight::FlightEndpoint> endpoints{
        arrow::flight::FlightEndpoint{std::move(ticket), {}}};
    ARROW_ASSIGN_OR_RAISE(
        auto result,
        arrow::flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
    // Move instead of copying the freshly built FlightInfo.
    return std::make_unique<arrow::flight::FlightInfo>(std::move(result));
}

// Streams the result that is currently stored in the proxy's shared
// buffer. The ticket content is not inspected.
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
    const arrow::flight::ServerCallContext& context,
    const arrow::flight::sql::StatementQueryTicket& command) override
{
    ARROW_ASSIGN_OR_RAISE(auto reader, proxy_->read());
    return std::make_unique<arrow::flight::RecordBatchStream>(reader);
}

private:
Proxy* proxy_;
};
Expand All @@ -323,7 +541,8 @@ afs_server_internal(Proxy* proxy)
{
ARROW_ASSIGN_OR_RAISE(auto location, arrow::flight::Location::Parse(URI));
arrow::flight::FlightServerOptions options(location);
options.auth_handler = std::make_shared<AuthHandler>(proxy);
options.middleware.push_back(
{"header-auth", std::make_shared<HeaderAuthServerMiddlewareFactory>(proxy)});
FlightSQLServer flightSQLServer(proxy);
ARROW_RETURN_NOT_OK(flightSQLServer.Init(options));

Expand Down
4 changes: 4 additions & 0 deletions test/run.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@

require_relative "helper/sandbox"

# When run from a Ninja build directory, install the build first and stop
# with a failure status if the install step does not succeed.
exit(false) if File.exist?("build.ninja") && !system("ninja", "install")

exit(Test::Unit::AutoRunner.run(true, __dir__))
Loading