diff --git a/CMakeLists.txt b/CMakeLists.txt index 2be2a8b..dbe1bf5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -244,8 +244,10 @@ add_library(flapi-lib STATIC src/request_handler.cpp src/request_validator.cpp src/rate_limit_middleware.cpp + src/prepared_template_rewriter.cpp src/route_translator.cpp src/security_auditor.cpp + src/sql_parameter_classifier.cpp src/sql_template_processor.cpp src/sql_utils.cpp src/mcp_server.cpp diff --git a/src/include/prepared_template_rewriter.hpp b/src/include/prepared_template_rewriter.hpp new file mode 100644 index 0000000..80efb21 --- /dev/null +++ b/src/include/prepared_template_rewriter.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include +#include +#include + +#include "sql_parameter_classifier.hpp" + +namespace flapi { + +struct RequestFieldConfig; + +// W3.1 (PR A): one entry in the binding plan produced by the rewriter. +// `position` is the 0-indexed `?` slot in the rewritten template; the +// caller binds values in that order via DuckDB's prepared-statement API. +struct PreparedBindingSpec { + std::string field_name; + SqlParameterType type = SqlParameterType::Varchar; + std::size_t position = 0; +}; + +struct PreparedRewriteResult { + std::string rewritten_template; + std::vector bindings; +}; + +// Pure helper. Scans a Mustache template and rewrites occurrences of +// `{{ params.X }}` to `?` for parameters X that the classifier marks +// bindable, recording the binding order. Untouched: +// - triple-brace `{{{ params.X }}}` (raw substitution; operators +// migrate these separately) +// - any `{{ params.X }}` inside a Mustache section block +// (`{{#X}}...{{/X}}` or `{{^X}}...{{/X}}`) +// - params whose `RequestFieldConfig` has no typed validator +// - non-`params.*` references like `{{ conn.X }}` or `{{ env.X }}` +// +// The result is suitable for: render-the-rewritten-template-as-Mustache +// (only structural variation remains), then `duckdb_prepare` + bind +// per the plan. +class PreparedTemplateRewriter { +public: + PreparedRewriteResult rewrite(const std::string& template_text, + const std::vector& request_fields, + const SqlParameterClassifier& classifier) const; +}; + +} // namespace flapi diff --git a/src/include/sql_parameter_classifier.hpp b/src/include/sql_parameter_classifier.hpp new file mode 100644 index 0000000..bb03582 --- /dev/null +++ b/src/include/sql_parameter_classifier.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include + +namespace flapi { + +struct RequestFieldConfig; + +// W3.1 (PR A): how a typed parameter should be bound into a DuckDB +// prepared statement. The classifier maps a `RequestFieldConfig` to one +// of these — `PreparedTemplateRewriter` then uses the result to decide +// whether a `{{ params.X }}` occurrence should become a `?` placeholder. +enum class SqlParameterType { + Integer, // duckdb_bind_int64 + Double, // duckdb_bind_double + Boolean, // duckdb_bind_boolean + Date, // duckdb_bind_date (caller parses yyyy-mm-dd) + Time, // duckdb_bind_time (caller parses hh:mm:ss) + Varchar, // duckdb_bind_varchar (default for string-shaped values) +}; + +struct Bindability { + bool bindable = false; + SqlParameterType type = SqlParameterType::Varchar; +}; + +// Pure helper. Stateless; safe to construct on demand at the call site. +// +// The classification rule is intentionally conservative: a field without +// any typed validator is NOT considered bindable, so the Mustache path +// remains the safe default during migration. Unknown validator type +// names also fall back to non-bindable for forward compatibility — a +// new validator added in a later version doesn't accidentally route +// through the prepared path before its binder has been written. +class SqlParameterClassifier { +public: + Bindability classify(const RequestFieldConfig& field) const; +}; + +} // namespace flapi diff --git a/src/prepared_template_rewriter.cpp b/src/prepared_template_rewriter.cpp new file mode 100644 index 0000000..adab240 --- /dev/null +++ b/src/prepared_template_rewriter.cpp @@ -0,0 +1,193 @@ +#include "prepared_template_rewriter.hpp" + +#include + +#include "config_manager.hpp" + +namespace flapi { + +namespace { + +// Tag types we recognise while scanning the template. `Sentinel` keeps +// the lexer code small without sprinkling magic chars. +enum class TagKind { + OpenSection, // {{#name}} + OpenInvertedSection, // {{^name}} + CloseSection, // {{/name}} + TripleBrace, // {{{ ... }}} + DoubleBrace, // {{ ... }} + NoTag, +}; + +struct TagScan { + TagKind kind = TagKind::NoTag; + std::size_t start = 0; // index of the first `{` + std::size_t end = 0; // one past the last `}` + std::string inner; // trimmed content between braces +}; + +void appendRange(std::string& out, const std::string& src, std::size_t a, std::size_t b) { + out.append(src, a, b - a); +} + +std::string trim(std::string s) { + auto begin = std::find_if_not(s.begin(), s.end(), [](char c) { return c == ' ' || c == '\t'; }); + s.erase(s.begin(), begin); + auto rend = std::find_if_not(s.rbegin(), s.rend(), [](char c) { return c == ' ' || c == '\t'; }); + s.erase(rend.base(), s.end()); + return s; +} + +bool startsWith(const std::string& s, std::size_t from, const char* prefix) { + std::size_t n = 0; + while (prefix[n] != '\0') ++n; + if (s.size() < from + n) { + return false; + } + return s.compare(from, n, prefix) == 0; +} + +// Find the next Mustache-ish tag starting at `from`. Returns NoTag when +// no further `{{` appears in the template. +TagScan nextTag(const std::string& s, std::size_t from) { + TagScan out; + const std::size_t open = s.find("{{", from); + if (open == std::string::npos) { + return out; + } + out.start = open; + + // Triple-brace? + if (startsWith(s, open, "{{{")) { + const std::size_t close = s.find("}}}", open + 3); + if (close == std::string::npos) { + return out; // unterminated; treat as no-tag + } + out.kind = TagKind::TripleBrace; + out.end = close + 3; + out.inner = trim(s.substr(open + 3, close - (open + 3))); + return out; + } + + const std::size_t close = s.find("}}", open + 2); + if (close == std::string::npos) { + return out; + } + out.end = close + 2; + std::string raw = s.substr(open + 2, close - (open + 2)); + if (!raw.empty() && raw.front() == '#') { + out.kind = TagKind::OpenSection; + out.inner = trim(raw.substr(1)); + } else if (!raw.empty() && raw.front() == '^') { + out.kind = TagKind::OpenInvertedSection; + out.inner = trim(raw.substr(1)); + } else if (!raw.empty() && raw.front() == '/') { + out.kind = TagKind::CloseSection; + out.inner = trim(raw.substr(1)); + } else { + out.kind = TagKind::DoubleBrace; + out.inner = trim(raw); + } + return out; +} + +// Extract X from "params.X". Returns empty string when the expression +// doesn't have the expected shape. +std::string paramName(const std::string& inner) { + static const std::string kPrefix = "params."; + if (inner.compare(0, kPrefix.size(), kPrefix) != 0) { + return {}; + } + return inner.substr(kPrefix.size()); +} + +const RequestFieldConfig* findField(const std::string& name, + const std::vector& fields) { + for (const auto& f : fields) { + if (f.fieldName == name) { + return &f; + } + } + return nullptr; +} + +} // namespace + +PreparedRewriteResult PreparedTemplateRewriter::rewrite( + const std::string& template_text, + const std::vector& request_fields, + const SqlParameterClassifier& classifier) const { + + PreparedRewriteResult result; + result.rewritten_template.reserve(template_text.size()); + + std::size_t cursor = 0; + int section_depth = 0; + + while (cursor < template_text.size()) { + const TagScan tag = nextTag(template_text, cursor); + if (tag.kind == TagKind::NoTag) { + appendRange(result.rewritten_template, template_text, cursor, template_text.size()); + break; + } + + // Copy untouched text up to the tag. + appendRange(result.rewritten_template, template_text, cursor, tag.start); + + switch (tag.kind) { + case TagKind::OpenSection: + case TagKind::OpenInvertedSection: + ++section_depth; + appendRange(result.rewritten_template, template_text, tag.start, tag.end); + break; + case TagKind::CloseSection: + if (section_depth > 0) { + --section_depth; + } + appendRange(result.rewritten_template, template_text, tag.start, tag.end); + break; + case TagKind::TripleBrace: + // Out of scope for PR A — operators migrate triple-brace + // sites manually (drop surrounding quotes, switch to + // double-brace) before they become bindable. + appendRange(result.rewritten_template, template_text, tag.start, tag.end); + break; + case TagKind::DoubleBrace: { + if (section_depth > 0) { + appendRange(result.rewritten_template, template_text, tag.start, tag.end); + break; + } + const std::string name = paramName(tag.inner); + if (name.empty()) { + appendRange(result.rewritten_template, template_text, tag.start, tag.end); + break; + } + const RequestFieldConfig* field = findField(name, request_fields); + if (field == nullptr) { + appendRange(result.rewritten_template, template_text, tag.start, tag.end); + break; + } + const auto bindability = classifier.classify(*field); + if (!bindability.bindable) { + appendRange(result.rewritten_template, template_text, tag.start, tag.end); + break; + } + result.rewritten_template += '?'; + PreparedBindingSpec spec; + spec.field_name = name; + spec.type = bindability.type; + spec.position = result.bindings.size(); + result.bindings.push_back(std::move(spec)); + break; + } + case TagKind::NoTag: + break; // unreachable; satisfied by the early break above + } + + cursor = tag.end; + } + + return result; +} + +} // namespace flapi diff --git a/src/sql_parameter_classifier.cpp b/src/sql_parameter_classifier.cpp new file mode 100644 index 0000000..c8c7554 --- /dev/null +++ b/src/sql_parameter_classifier.cpp @@ -0,0 +1,56 @@ +#include "sql_parameter_classifier.hpp" + +#include "config_manager.hpp" + +namespace flapi { + +namespace { + +// Returns `true` and sets `out` when `type_name` maps to a bindable +// SqlParameterType; otherwise returns `false` and leaves `out` untouched. +// Comparison is case-sensitive on purpose — see the test rationale. +bool tryMapType(const std::string& type_name, SqlParameterType& out) { + if (type_name == "int" || type_name == "integer") { + out = SqlParameterType::Integer; + return true; + } + if (type_name == "number" || type_name == "float" || type_name == "double") { + out = SqlParameterType::Double; + return true; + } + if (type_name == "boolean" || type_name == "bool") { + out = SqlParameterType::Boolean; + return true; + } + if (type_name == "date") { + out = SqlParameterType::Date; + return true; + } + if (type_name == "time") { + out = SqlParameterType::Time; + return true; + } + if (type_name == "uuid" || type_name == "string" || type_name == "email" || + type_name == "enum") { + out = SqlParameterType::Varchar; + return true; + } + return false; +} + +} // namespace + +Bindability SqlParameterClassifier::classify(const RequestFieldConfig& field) const { + Bindability result; + for (const auto& validator : field.validators) { + SqlParameterType mapped = SqlParameterType::Varchar; + if (tryMapType(validator.type, mapped)) { + result.bindable = true; + result.type = mapped; + return result; // first known type wins, for determinism + } + } + return result; // bindable stays false +} + +} // namespace flapi diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index a16ef28..d928956 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -23,8 +23,10 @@ add_executable(flapi_tests query_executor_test.cpp rate_limit_middleware_test.cpp request_handler_test.cpp + prepared_template_rewriter_test.cpp request_validator_test.cpp security_auditor_test.cpp + sql_parameter_classifier_test.cpp sql_template_processor_test.cpp sql_utils_test.cpp test_duckdb_raii.cpp diff --git a/test/cpp/prepared_template_rewriter_test.cpp b/test/cpp/prepared_template_rewriter_test.cpp new file mode 100644 index 0000000..166e01e --- /dev/null +++ b/test/cpp/prepared_template_rewriter_test.cpp @@ -0,0 +1,404 @@ +#include + +#include "config_manager.hpp" +#include "prepared_template_rewriter.hpp" + +namespace flapi { +namespace test { + +namespace { + +RequestFieldConfig typedField(const std::string& name, const std::string& validator_type) { + RequestFieldConfig f; + f.fieldName = name; + f.fieldIn = "query"; + ValidatorConfig v; + v.type = validator_type; + f.validators.push_back(v); + return f; +} + +RequestFieldConfig bareField(const std::string& name) { + RequestFieldConfig f; + f.fieldName = name; + f.fieldIn = "query"; + return f; +} + +} // namespace + +TEST_CASE("PreparedTemplateRewriter: empty template yields empty result with no bindings", + "[security][prepared][rewriter]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite("", {}, classifier); + REQUIRE(r.rewritten_template.empty()); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: template with no params is left alone", + "[security][prepared][rewriter]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite("SELECT 1", {}, classifier); + REQUIRE(r.rewritten_template == "SELECT 1"); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: simple {{ params.X }} with int field is rewritten to ?", + "[security][prepared][rewriter]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT * FROM t WHERE id = {{ params.id }}", + {typedField("id", "int")}, + classifier); + + REQUIRE(r.rewritten_template == "SELECT * FROM t WHERE id = ?"); + REQUIRE(r.bindings.size() == 1); + REQUIRE(r.bindings[0].field_name == "id"); + REQUIRE(r.bindings[0].type == SqlParameterType::Integer); + REQUIRE(r.bindings[0].position == 0); +} + +TEST_CASE("PreparedTemplateRewriter: triple-brace {{{ params.X }}} is never rewritten", + "[security][prepared][rewriter]") { + // Triple-brace is the existing convention for raw-substituted string + // values inside SQL quotes. Migrating those is a separate step + // (operators must drop the surrounding quotes); v1 leaves them + // untouched. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT '{{{ params.name }}}'", + {typedField("name", "string")}, + classifier); + + REQUIRE(r.rewritten_template == "SELECT '{{{ params.name }}}'"); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: param with no validator is left on the Mustache path", + "[security][prepared][rewriter]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT * FROM t WHERE id = {{ params.id }}", + {bareField("id")}, + classifier); + + REQUIRE(r.rewritten_template == "SELECT * FROM t WHERE id = {{ params.id }}"); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: param missing from request_fields is left alone", + "[security][prepared][rewriter]") { + // A template that references an undeclared parameter is a config + // error, but the rewriter must not crash or invent bindings — + // existing flapi validation handles the diagnostic separately. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT {{ params.mystery }}", + {typedField("known", "int")}, + classifier); + + REQUIRE(r.rewritten_template == "SELECT {{ params.mystery }}"); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: param inside {{#section}}...{{/section}} is left alone", + "[security][prepared][rewriter]") { + // A bindable param inside a Mustache conditional section is risky: + // if the section evaluates falsy at render time the `?` would + // disappear but the binding would still be queued, breaking position + // counts. v1 keeps the entire section on the Mustache path. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + const std::string in = + "SELECT * FROM t WHERE 1=1 " + "{{#params.id}}AND id = {{ params.id }}{{/params.id}}"; + auto r = rewriter.rewrite(in, {typedField("id", "int")}, classifier); + + REQUIRE(r.rewritten_template == in); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: param after a closed section is rewritten", + "[security][prepared][rewriter]") { + // Section depth must return to 0 after the matching closing tag, + // so subsequent top-level params are eligible again. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + const std::string in = + "WITH cte AS ({{#params.flag}}SELECT 1{{/params.flag}}) " + "SELECT id FROM t WHERE id = {{ params.id }}"; + auto r = rewriter.rewrite(in, {typedField("id", "int"), typedField("flag", "int")}, classifier); + + REQUIRE(r.rewritten_template.find("?") != std::string::npos); + REQUIRE(r.bindings.size() == 1); + REQUIRE(r.bindings[0].field_name == "id"); +} + +TEST_CASE("PreparedTemplateRewriter: multiple top-level params produce ordered bindings", + "[security][prepared][rewriter]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT * FROM t WHERE id = {{ params.id }} AND status = {{ params.status }}", + {typedField("id", "int"), typedField("status", "string")}, + classifier); + + REQUIRE(r.bindings.size() == 2); + REQUIRE(r.bindings[0].field_name == "id"); + REQUIRE(r.bindings[0].position == 0); + REQUIRE(r.bindings[0].type == SqlParameterType::Integer); + REQUIRE(r.bindings[1].field_name == "status"); + REQUIRE(r.bindings[1].position == 1); + REQUIRE(r.bindings[1].type == SqlParameterType::Varchar); +} + +TEST_CASE("PreparedTemplateRewriter: repeated occurrence of a single param produces two bindings", + "[security][prepared][rewriter]") { + // DuckDB's prepared API uses positional `?` — repeating the same + // logical param twice means two physical bindings, not one. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT {{ params.id }} AS a, {{ params.id }} AS b", + {typedField("id", "int")}, + classifier); + + REQUIRE(r.bindings.size() == 2); + REQUIRE(r.bindings[0].field_name == "id"); + REQUIRE(r.bindings[1].field_name == "id"); + REQUIRE(r.bindings[0].position == 0); + REQUIRE(r.bindings[1].position == 1); +} + +TEST_CASE("PreparedTemplateRewriter: whitespace inside braces is tolerated", + "[security][prepared][rewriter]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT {{ params.id }} FROM t", + {typedField("id", "int")}, + classifier); + REQUIRE(r.bindings.size() == 1); +} + +TEST_CASE("PreparedTemplateRewriter: inverted section ({{^X}}) also pauses rewriting", + "[security][prepared][rewriter]") { + // Mustache inverted sections (`{{^X}}...{{/X}}`) have the same + // depth semantics as positive ones; bindable params inside them + // must NOT be rewritten. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + const std::string in = + "SELECT * FROM t {{^params.skip}}WHERE id = {{ params.id }}{{/params.skip}}"; + auto r = rewriter.rewrite(in, {typedField("id", "int"), typedField("skip", "int")}, classifier); + + REQUIRE(r.rewritten_template == in); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: unterminated double-brace tag is left alone", + "[security][prepared][rewriter][edge]") { + // Defensive: a malformed template must never crash or produce an + // invalid binding plan. The whole tail of the template flows + // through as untouched text. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT * FROM t WHERE id = {{ params.id no closing brace", + {typedField("id", "int")}, + classifier); + REQUIRE(r.bindings.empty()); + REQUIRE(r.rewritten_template.find("{{ params.id") != std::string::npos); +} + +TEST_CASE("PreparedTemplateRewriter: unterminated triple-brace is left alone", + "[security][prepared][rewriter][edge]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT '{{{ params.name and then nothing", + {typedField("name", "string")}, + classifier); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: newlines and tabs inside tags are tolerated", + "[security][prepared][rewriter][edge]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT * FROM t WHERE id = {{\tparams.id\t}}", + {typedField("id", "int")}, + classifier); + REQUIRE(r.bindings.size() == 1); + REQUIRE(r.rewritten_template.find('?') != std::string::npos); +} + +TEST_CASE("PreparedTemplateRewriter: no-whitespace form {{params.X}} also matches", + "[security][prepared][rewriter][edge]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite("WHERE id = {{params.id}}", + {typedField("id", "int")}, classifier); + REQUIRE(r.bindings.size() == 1); + REQUIRE(r.rewritten_template == "WHERE id = ?"); +} + +TEST_CASE("PreparedTemplateRewriter: nested sections still suppress rewriting", + "[security][prepared][rewriter][edge]") { + // Depth must increment for every opening tag and decrement for + // every close, so a bindable param nested two sections deep is + // still left for Mustache. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + const std::string in = + "{{#params.a}}{{#params.b}}WHERE id = {{ params.id }}{{/params.b}}{{/params.a}}"; + auto r = rewriter.rewrite(in, + {typedField("a", "int"), typedField("b", "int"), + typedField("id", "int")}, + classifier); + REQUIRE(r.bindings.empty()); + REQUIRE(r.rewritten_template == in); +} + +TEST_CASE("PreparedTemplateRewriter: stray closing section without an open is harmless", + "[security][prepared][rewriter][edge]") { + // Bug-bait input. The depth counter clamps at zero so a stray + // `{{/X}}` doesn't underflow, and subsequent params remain bindable. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite("{{/orphan}} SELECT {{ params.id }} FROM t", + {typedField("id", "int")}, classifier); + REQUIRE(r.bindings.size() == 1); +} + +TEST_CASE("PreparedTemplateRewriter: section open without matching close keeps depth nonzero", + "[security][prepared][rewriter][edge]") { + // If the operator forgets to close a section, every subsequent + // bindable param stays on the Mustache path. This is a strictly + // safer-than-alternative behaviour: better an unbound param that + // Mustache renders than a bound `?` whose surrounding section + // disappears at render time. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "{{#params.a}}WHERE id = {{ params.id }} no-close-tag", + {typedField("a", "int"), typedField("id", "int")}, + classifier); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: empty params.name (just 'params.') is left alone", + "[security][prepared][rewriter][edge]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite("SELECT {{ params. }}", + {typedField("id", "int")}, classifier); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: bare {{ params }} without a dot is left alone", + "[security][prepared][rewriter][edge]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite("SELECT {{ params }}", + {typedField("id", "int")}, classifier); + REQUIRE(r.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: multi-line template binds across line boundaries", + "[security][prepared][rewriter][edge]") { + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + const std::string in = + "SELECT *\n" + "FROM customers\n" + "WHERE id = {{ params.id }}\n" + " AND region = {{ params.region }}\n" + "ORDER BY id"; + auto r = rewriter.rewrite(in, + {typedField("id", "int"), typedField("region", "string")}, + classifier); + REQUIRE(r.bindings.size() == 2); + REQUIRE(r.bindings[0].field_name == "id"); + REQUIRE(r.bindings[1].field_name == "region"); +} + +TEST_CASE("PreparedTemplateRewriter: many distinct bindings preserve order strictly", + "[security][prepared][rewriter][edge]") { + // Smoke-test the binding-position bookkeeping against a non-trivial + // count. DuckDB binds by position; an off-by-one would corrupt every + // query in production. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + std::vector fields; + std::string in = "SELECT 1"; + for (int i = 0; i < 25; ++i) { + const std::string name = "p" + std::to_string(i); + fields.push_back(typedField(name, "int")); + in += " AND col = {{ params." + name + " }}"; + } + auto r = rewriter.rewrite(in, fields, classifier); + REQUIRE(r.bindings.size() == 25); + for (std::size_t i = 0; i < r.bindings.size(); ++i) { + REQUIRE(r.bindings[i].field_name == "p" + std::to_string(i)); + REQUIRE(r.bindings[i].position == i); + } +} + +TEST_CASE("PreparedTemplateRewriter: classic SQL-injection payload as a value is bound, not injected", + "[security][prepared][rewriter][edge]") { + // The rewriter itself doesn't see values — only the template. But + // the proof of the security guarantee is: the template references + // `params.id` once and produces one `?` and one binding. The + // VALUE of params.id at runtime (whatever it is, including a + // classic injection payload) lands in DuckDB's bind buffer, never + // as syntax. This test pins the binding count so any future change + // that "expands" the value into the template would fail loudly. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT * FROM users WHERE id = {{ params.id }}", + {typedField("id", "int")}, + classifier); + REQUIRE(r.rewritten_template == "SELECT * FROM users WHERE id = ?"); + REQUIRE(r.bindings.size() == 1); +} + +TEST_CASE("PreparedTemplateRewriter: rewriting is idempotent under repeated invocation", + "[security][prepared][rewriter][edge]") { + // Sanity: running the rewriter on its own output must be a no-op + // (no `{{ params.X }}` left to rewrite the second time). + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto first = rewriter.rewrite("WHERE id = {{ params.id }}", + {typedField("id", "int")}, classifier); + auto second = rewriter.rewrite(first.rewritten_template, + {typedField("id", "int")}, classifier); + REQUIRE(second.rewritten_template == first.rewritten_template); + REQUIRE(second.bindings.empty()); +} + +TEST_CASE("PreparedTemplateRewriter: param names not starting with 'params.' are ignored", + "[security][prepared][rewriter]") { + // The rewriter only owns `params.*` — `conn.*` / `env.*` / `auth.*` + // are connection/env/auth variables that the existing Mustache path + // populates with operator-controlled values, not user input. + PreparedTemplateRewriter rewriter; + SqlParameterClassifier classifier; + auto r = rewriter.rewrite( + "SELECT * FROM read_parquet('{{{ conn.path }}}') WHERE x = {{ conn.x }}", + {typedField("x", "int")}, // x is a request field, but `conn.x` is not `params.x` + classifier); + + REQUIRE(r.bindings.empty()); +} + +} // namespace test +} // namespace flapi diff --git a/test/cpp/sql_parameter_classifier_test.cpp b/test/cpp/sql_parameter_classifier_test.cpp new file mode 100644 index 0000000..c48cb88 --- /dev/null +++ b/test/cpp/sql_parameter_classifier_test.cpp @@ -0,0 +1,229 @@ +#include + +#include "config_manager.hpp" +#include "sql_parameter_classifier.hpp" + +namespace flapi { +namespace test { + +namespace { + +RequestFieldConfig fieldWith(const std::string& name, const std::string& validator_type) { + RequestFieldConfig f; + f.fieldName = name; + f.fieldIn = "query"; + if (!validator_type.empty()) { + ValidatorConfig v; + v.type = validator_type; + f.validators.push_back(v); + } + return f; +} + +RequestFieldConfig bareField(const std::string& name) { + RequestFieldConfig f; + f.fieldName = name; + f.fieldIn = "query"; + return f; +} + +} // namespace + +TEST_CASE("SqlParameterClassifier: field with no validators is not bindable", + "[security][prepared][classifier]") { + // Without a validator we have no shape information for the value. + // The conservative call is to leave it on the Mustache path; an + // operator can opt in to binding by attaching a typed validator. + SqlParameterClassifier c; + auto r = c.classify(bareField("x")); + REQUIRE_FALSE(r.bindable); +} + +TEST_CASE("SqlParameterClassifier: empty validator type is not bindable", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + RequestFieldConfig f = bareField("x"); + f.validators.push_back(ValidatorConfig{}); // type == "" + auto r = c.classify(f); + REQUIRE_FALSE(r.bindable); +} + +TEST_CASE("SqlParameterClassifier: int validator binds as Integer", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + auto r = c.classify(fieldWith("id", "int")); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Integer); +} + +TEST_CASE("SqlParameterClassifier: integer alias also maps to Integer", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + auto r = c.classify(fieldWith("id", "integer")); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Integer); +} + +TEST_CASE("SqlParameterClassifier: number validator binds as Double", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + auto r = c.classify(fieldWith("amount", "number")); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Double); +} + +TEST_CASE("SqlParameterClassifier: float / double aliases map to Double", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + REQUIRE(c.classify(fieldWith("a", "float")).type == SqlParameterType::Double); + REQUIRE(c.classify(fieldWith("b", "double")).type == SqlParameterType::Double); +} + +TEST_CASE("SqlParameterClassifier: boolean validator binds as Boolean", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + auto r = c.classify(fieldWith("active", "boolean")); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Boolean); +} + +TEST_CASE("SqlParameterClassifier: date validator binds as Date", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + auto r = c.classify(fieldWith("created", "date")); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Date); +} + +TEST_CASE("SqlParameterClassifier: time validator binds as Time", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + auto r = c.classify(fieldWith("at", "time")); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Time); +} + +TEST_CASE("SqlParameterClassifier: uuid validator binds as Varchar", + "[security][prepared][classifier]") { + // DuckDB will parse the UUID from a VARCHAR binding; treating it as + // Varchar keeps the binder simple and correct. + SqlParameterClassifier c; + auto r = c.classify(fieldWith("id", "uuid")); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Varchar); +} + +TEST_CASE("SqlParameterClassifier: string validator binds as Varchar", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + auto r = c.classify(fieldWith("name", "string")); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Varchar); +} + +TEST_CASE("SqlParameterClassifier: email / enum validators bind as Varchar", + "[security][prepared][classifier]") { + SqlParameterClassifier c; + REQUIRE(c.classify(fieldWith("e", "email")).type == SqlParameterType::Varchar); + REQUIRE(c.classify(fieldWith("s", "enum")).type == SqlParameterType::Varchar); +} + +TEST_CASE("SqlParameterClassifier: unknown validator type is not bindable", + "[security][prepared][classifier]") { + // Forward-compat: a future validator we don't know about must NOT + // accidentally land on the prepared path. Falling back to Mustache + // is the safe default. + SqlParameterClassifier c; + auto r = c.classify(fieldWith("x", "future-type-from-2030")); + REQUIRE_FALSE(r.bindable); +} + +TEST_CASE("SqlParameterClassifier: multiple validators take the first known type", + "[security][prepared][classifier]") { + // A field can carry several validators (range AND pattern, say). + // We pick the first one that resolves to a typed binding so behaviour + // is deterministic. + SqlParameterClassifier c; + RequestFieldConfig f = bareField("x"); + ValidatorConfig v1; v1.type = "int"; + ValidatorConfig v2; v2.type = "string"; + f.validators.push_back(v1); + f.validators.push_back(v2); + auto r = c.classify(f); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Integer); +} + +TEST_CASE("SqlParameterClassifier: multiple validators, unknown first then known, picks the known", + "[security][prepared][classifier][edge]") { + // Forward-compat: a validator we don't recognise must NOT block a + // later recognised one from making the field bindable. + SqlParameterClassifier c; + RequestFieldConfig f = bareField("x"); + ValidatorConfig v1; v1.type = "future-type-from-2030"; + ValidatorConfig v2; v2.type = "int"; + f.validators.push_back(v1); + f.validators.push_back(v2); + auto r = c.classify(f); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Integer); +} + +TEST_CASE("SqlParameterClassifier: validator type with stray whitespace is not bindable", + "[security][prepared][classifier][edge]") { + // YAML doesn't trim values by default; an operator who writes + // `type: " int"` must not silently land on the prepared path. + SqlParameterClassifier c; + REQUIRE_FALSE(c.classify(fieldWith("x", " int")).bindable); + REQUIRE_FALSE(c.classify(fieldWith("x", "int ")).bindable); +} + +TEST_CASE("SqlParameterClassifier: empty field with empty validators is not bindable", + "[security][prepared][classifier][edge]") { + SqlParameterClassifier c; + RequestFieldConfig empty; + REQUIRE_FALSE(c.classify(empty).bindable); +} + +TEST_CASE("SqlParameterClassifier: very large validator list still picks the first known", + "[security][prepared][classifier][edge]") { + // Bookkeeping smoke test — make sure iteration ordering is stable + // and doesn't depend on validator count. + SqlParameterClassifier c; + RequestFieldConfig f = bareField("x"); + for (int i = 0; i < 50; ++i) { + ValidatorConfig v; v.type = "unknown-" + std::to_string(i); + f.validators.push_back(v); + } + ValidatorConfig known; known.type = "boolean"; + f.validators.push_back(known); + auto r = c.classify(f); + REQUIRE(r.bindable); + REQUIRE(r.type == SqlParameterType::Boolean); +} + +TEST_CASE("SqlParameterClassifier: bindability decision is deterministic across calls", + "[security][prepared][classifier][edge]") { + SqlParameterClassifier c; + auto f = fieldWith("x", "int"); + auto a = c.classify(f); + auto b = c.classify(f); + auto cc = c.classify(f); + REQUIRE(a.bindable == b.bindable); + REQUIRE(b.bindable == cc.bindable); + REQUIRE(a.type == b.type); + REQUIRE(b.type == cc.type); +} + +TEST_CASE("SqlParameterClassifier: validator type comparison is case-sensitive", + "[security][prepared][classifier]") { + // YAML config conventions in flapi are all lowercase; tolerating + // case differences would let `Int` quietly bypass the typed path on + // some platforms. Explicit type names only. + SqlParameterClassifier c; + REQUIRE_FALSE(c.classify(fieldWith("x", "Int")).bindable); + REQUIRE_FALSE(c.classify(fieldWith("x", "STRING")).bindable); +} + +} // namespace test +} // namespace flapi