diff --git a/README.md b/README.md index fb79293..f8c649f 100644 --- a/README.md +++ b/README.md @@ -241,6 +241,75 @@ let newId = query: ``` +### Typed Queries + +Use `query(T):` when you want Ormin to deserialize selected columns directly into a named Nim type instead of returning the default tuple shape. This is useful at module boundaries where a named object, ref object, or scalar domain type is clearer than a tuple. + +For object results, selected column names must match fields on the destination type. Use `as` aliases when the database column name differs from the Nim field name: + +```nim +type + ThreadSummary = object + id: int + title: string + +let threads = query(ThreadSummary): + select thread(id, name as title) + orderby id +``` + +Selecting one column can map directly to a scalar type: + +```nim +let names = query(string): + select thread(name) +``` + +Queries that return a single row, such as a `limit 1` query, return one `T` instead of `seq[T]`: + +```nim +let thread = query(ThreadSummary): + select thread(id, name as title) + where id == ?threadId + limit 1 +``` + +#### `fromQueryHook` Column Hooks + +Typed queries deserialize each selected column through `fromQueryHook`. You can overload this hook for your own field or scalar destination types: + +```nim +import ormin/query_hooks + +type + TitleLength = distinct int + + ThreadTitleSize = object + id: int + title: TitleLength + +proc fromQueryHook*(val: var TitleLength, value: string) = + val = TitleLength(value.len) + +let rows = query(ThreadTitleSize): + select thread(id, name as title) +``` + +If a hook needs to handle SQL `NULL` itself, accept a `DbValue[SourceType]`: + +```nim +type + NullableTitle = distinct string + +proc fromQueryHook*(val: var NullableTitle, value: DbValue[string]) = + if value.isNull: + val = NullableTitle("") + else: + val = NullableTitle(value.value) +``` + +These are column deserialization hooks. In object typed queries, Ormin calls `fromQueryHook` separately for each selected column that maps to a destination field; it does not currently call a hook for the entire row object. For whole-row transformations, query into an intermediate typed result and convert it in regular Nim code. + ### JSON and Raw SQL JSON values can be spliced directly using `%` expressions. The `%` prefix tells Ormin to treat the following Nim expression as a `JsonNode` without conversion: diff --git a/config.nims b/config.nims index a9fa5fd..9eb46ad 100644 --- a/config.nims +++ b/config.nims @@ -16,6 +16,7 @@ task test, "Run all test suite": exec "nim c -f -r tests/tfeature" exec "nim c -f -r tests/tcommon" + exec "nim c -f -r -d:release tests/tquery_types" exec "nim c -f -r tests/tsqlite" exec "nim c -f -r tests/tdb_utils" exec "nim c -f -r tests/timportstatic" @@ -34,6 +35,7 @@ task test_postgres, "Run PostgreSQL test suite": exec "nim c -f -d:nimDebugDlOpen -r -d:postgre tests/tfeature" exec "nim c -f -d:nimDebugDlOpen -r -d:postgre tests/tcommon" + exec "nim c -f -d:nimDebugDlOpen -r -d:release -d:postgre tests/tquery_types" exec "nim c -f -d:nimDebugDlOpen -r -d:postgre tests/tpostgre" task buildexamples, "Build examples: chat and forum": diff --git a/ormin.nimble b/ormin.nimble index d1db178..1de2b6e 100644 --- a/ormin.nimble +++ b/ormin.nimble @@ -1,6 +1,6 @@ # Package -version = "0.8.1" +version = "0.9.0" author = "Araq" description = "Prepared SQL statement generator. A lightweight ORM." license = "MIT" @@ -20,3 +20,4 @@ feature "examples": import std/os when fileExists("config.nims"): include "config.nims" + diff --git a/ormin/ormin_postgre.nim b/ormin/ormin_postgre.nim index bbe75a4..e3f1beb 100644 --- a/ormin/ormin_postgre.nim +++ b/ormin/ormin_postgre.nim @@ -1,6 +1,7 @@ import strutils, db_connector/postgres, json, times import db_connector/db_common +import query_hooks export db_common type @@ -62,6 +63,9 @@ template bindParamUnchecked(db: DbConn; s: PStmt; idx: int; x: untyped; t: untyp pparams[idx-1] = $x parr[idx-1] = cstring(pparams[idx-1]) +template bindNullParam*(db: DbConn; s: PStmt; idx: int) = + parr[idx-1] = cstring(nil) + template bindParamJson*(db: DbConn; s: PStmt; idx: int; xx: JsonNode; t: typedesc) = let x = xx @@ -129,9 +133,15 @@ proc fillString(dest: var string; src: cstring; srcLen: int) {.inline.} = template bindResult*(db: DbConn; s: PStmt; idx: int; dest: var string; t: typedesc; name: string) = - let src = pqgetvalue(queryResult, queryI, idx.cint) - let srcLen = int(pqgetlength(queryResult, queryI, idx.cint)) - fillString(dest, src, srcLen) + if pqgetisnull(queryResult, queryI, idx.cint) != 0: + when defined(nimNoNilSeqs): + setLen(dest, 0) + else: + dest = nil + else: + let src = pqgetvalue(queryResult, queryI, idx.cint) + let srcLen = int(pqgetlength(queryResult, queryI, idx.cint)) + fillString(dest, src, srcLen) template bindResult*(db: DbConn; s: PStmt; idx: int; dest: float64; t: typedesc; name: string) = @@ -162,6 +172,19 @@ template bindResult*(db: DbConn; s: PStmt; idx: int; dest: JsonNode; t: typedesc; name: string) = dest = parseJson($pqgetvalue(queryResult, queryI, idx.cint)) +template bindResult*[T](db: DbConn; s: PStmt; idx: int; dest: var DbValue[T]; + t: typedesc; name: string) = + if pqgetisnull(queryResult, queryI, idx.cint) != 0: + dest.isNull = true + else: + dest.isNull = false + when T is string: + let src = pqgetvalue(queryResult, queryI, idx.cint) + let srcLen = int(pqgetlength(queryResult, queryI, idx.cint)) + fillString(dest.value, src, srcLen) + else: + bindResult(db, s, idx, dest.value, t, name) + template createJObject*(): untyped = newJObject() template createJArray*(): untyped = newJArray() @@ -174,6 +197,9 @@ template bindResultJson*(db: DbConn; s: PStmt; idx: int; obj: JsonNode; else: bindToJson(db, s, idx, x, t, name) +template columnIsNull*(db: DbConn; s: PStmt; idx: int): bool = + pqgetisnull(queryResult, queryI, idx.cint) != 0 + template bindToJson*(db: DbConn; s: PStmt; idx: int; obj: JsonNode; t: typedesc; name: string) = {.error: "invalid type for JSON object".} diff --git a/ormin/ormin_sqlite.nim b/ormin/ormin_sqlite.nim index c343e51..762f18a 100644 --- a/ormin/ormin_sqlite.nim +++ b/ormin/ormin_sqlite.nim @@ -5,6 +5,7 @@ import json, times import db_connector/db_common import db_connector/sqlite3 +import query_hooks export db_common type @@ -45,12 +46,12 @@ template bindParam*(db: DbConn; s: PStmt; idx: int; x, t: untyped) = cast[pointer](nil) else: cast[pointer](unsafeAddr(xs[0])) - if bind_blob(s, idx.cint, blobPtr, xs.len.cint, SQLITE_STATIC) != SQLITE_OK: + if bind_blob(s, idx.cint, blobPtr, xs.len.cint, SQLITE_TRANSIENT) != SQLITE_OK: dbError(db) elif t is int or t is int64 or t is bool: if bind_int64(s, idx.cint, x.int64) != SQLITE_OK: dbError(db) elif t is string: - if bind_text(s, idx.cint, cstring(x), x.len.cint, SQLITE_STATIC) != SQLITE_OK: + if bind_text(s, idx.cint, cstring(x), x.len.cint, SQLITE_TRANSIENT) != SQLITE_OK: dbError(db) elif t is float64: if bind_double(s, idx.cint, x) != SQLITE_OK: @@ -61,15 +62,19 @@ template bindParam*(db: DbConn; s: PStmt; idx: int; x, t: untyped) = else: x.utc().format("yyyy-MM-dd HH:mm:ss") - if bind_text(s, idx.cint, cstring(xx), xx.len.cint, SQLITE_STATIC) != SQLITE_OK: + if bind_text(s, idx.cint, cstring(xx), xx.len.cint, SQLITE_TRANSIENT) != SQLITE_OK: dbError(db) elif t is JsonNode: let xx = $x - if bind_text(s, idx.cint, cstring(xx), xx.len.cint, SQLITE_STATIC) != SQLITE_OK: + if bind_text(s, idx.cint, cstring(xx), xx.len.cint, SQLITE_TRANSIENT) != SQLITE_OK: dbError(db) else: {.error: "type mismatch for query argument at position " & $idx.} +template bindNullParam*(db: DbConn; s: PStmt; idx: int) = + if bind_null(s, idx.cint) != SQLITE_OK: + dbError(db) + template bindParamJson*(db: DbConn; s: PStmt; idx: int; xx: JsonNode; t: typedesc) = let x = xx @@ -86,7 +91,7 @@ template bindFromJson*(db: DbConn; s: PStmt; idx: int; x: JsonNode; t: typedesc[string]) = doAssert x.kind == JString let xs = x.str - if bind_text(s, idx.cint, cstring(xs), xs.len.cint, SQLITE_STATIC) != SQLITE_OK: + if bind_text(s, idx.cint, cstring(xs), xs.len.cint, SQLITE_TRANSIENT) != SQLITE_OK: dbError(db) template bindFromJson*(db: DbConn; s: PStmt; idx: int; x: JsonNode; @@ -118,7 +123,7 @@ template bindFromJson*(db: DbConn; s: PStmt; idx: int; x: JsonNode; dtStr[0 ..< i] else: dtStr - if bind_text(s, idx.cint, cstring(dt), dt.len.cint, SQLITE_STATIC) != SQLITE_OK: + if bind_text(s, idx.cint, cstring(dt), dt.len.cint, SQLITE_TRANSIENT) != SQLITE_OK: dbError(db) template bindResult*(db: DbConn; s: PStmt; idx: int; dest: int; @@ -149,9 +154,15 @@ proc fillBytes(dest: var seq[byte]; src: pointer; srcLen: int) = template bindResult*(db: DbConn; s: PStmt; idx: int; dest: var string; t: typedesc; name: string) = - let srcLen = column_bytes(s, idx.cint) - let src = column_text(s, idx.cint) - fillString(dest, src, srcLen) + if column_type(s, idx.cint) == SQLITE_NULL: + when defined(nimNoNilSeqs): + setLen(dest, 0) + else: + dest = nil + else: + let srcLen = column_bytes(s, idx.cint) + let src = column_text(s, idx.cint) + fillString(dest, src, srcLen) template bindResult*(db: DbConn; s: PStmt; idx: int; dest: var blobType; t: typedesc; name: string) = @@ -182,6 +193,19 @@ template bindResult*(db: DbConn; s: PStmt; idx: int; dest: JsonNode; let src = column_text(s, idx.cint) dest = parseJson($src) +template bindResult*[T](db: DbConn; s: PStmt; idx: int; dest: var DbValue[T]; + t: typedesc; name: string) = + if column_type(s, idx.cint) == SQLITE_NULL: + dest.isNull = true + else: + dest.isNull = false + when T is string: + let srcLen = column_bytes(s, idx.cint) + let src = column_text(s, idx.cint) + fillString(dest.value, src, srcLen) + else: + bindResult(db, s, idx, dest.value, t, name) + template createJObject*(): untyped = newJObject() template createJArray*(): untyped = newJArray() @@ -194,6 +218,9 @@ template bindResultJson*(db: DbConn; s: PStmt; idx: int; obj: JsonNode; else: bindToJson(db, s, idx, x, t, name) +template columnIsNull*(db: DbConn; s: PStmt; idx: int): bool = + column_type(s, idx.cint) == SQLITE_NULL + template bindToJson*(db: DbConn; s: PStmt; idx: int; obj: JsonNode; t: typedesc; name: string) = {.error: "invalid type for JSON object".} diff --git a/ormin/queries.nim b/ormin/queries.nim index 88cdd79..8848126 100644 --- a/ormin/queries.nim +++ b/ormin/queries.nim @@ -10,6 +10,7 @@ import db_connector/db_common from os import parentDir, `/` import db_types +import query_hooks # SQL dialect specific things: const @@ -31,6 +32,8 @@ type sql: string cols: seq[SourceColumn] +proc buildHookedParamBinding(prepStmt: NimNode; idx: int; ex, typ: NimNode; isJson: bool): NimNode + var functions {.compileTime.} = @[ Function(name: "count", arity: 1, typ: dbInt), @@ -809,10 +812,10 @@ proc generateRoutine(name: NimNode, q: QueryBuilder; for p in q.params: if p.isJson: finalParams.add newIdentDefs(p.ex, ident"JsonNode") - body.add newCall(bindSym"bindParamJson", ident"db", prepStmt, newLit(i), p.ex, p.typ) + body.add buildHookedParamBinding(prepStmt, i, p.ex, p.typ, true) else: finalParams.add newIdentDefs(p.ex, p.typ) - body.add newCall(bindSym"bindParam", ident"db", prepStmt, newLit(i), p.ex, p.typ) + body.add buildHookedParamBinding(prepStmt, i, p.ex, p.typ, false) inc i body.add newCall(bindSym"startQuery", ident"db", prepStmt) let yld = newStmtList() @@ -1490,18 +1493,75 @@ proc renderInlineQuery(n: NimNode; params: var Params; result.sql = queryAsString(subq, n) result.typ = DbType(kind: dbSet) -proc newGlobalVar(name, typ: NimNode, value: NimNode): NimNode = - result = newTree(nnkVarSection, - newTree(nnkIdentDefs, newTree(nnkPragmaExpr, name, - newTree(nnkPragma, ident"global")), typ, value) - ) - proc makeSeq(retType: NimNode; singleRow: bool): NimNode = if not singleRow: result = newTree(nnkBracketExpr, bindSym"seq", retType) else: result = retType +proc buildHookedParamBinding(prepStmt: NimNode; idx: int; ex, typ: NimNode; isJson: bool): NimNode = + if isJson: + return newCall(bindSym"bindParamJson", ident"db", prepStmt, newLit(idx), ex, typ) + + result = quote do: + block: + var converted: DbValue[`typ`] + toQueryHook(converted, `ex`) + if converted.isNull: + bindNullParam(db, `prepStmt`, `idx`) + else: + bindParam(db, `prepStmt`, `idx`, converted.value, `typ`) + +proc buildHookedResultAssign(prepStmt, destExpr, destType, sourceType: NimNode; idx: int; colName: string): NimNode = + result = quote do: + var rawValue: DbValue[`sourceType`] + bindResult(db, `prepStmt`, `idx`, rawValue, `sourceType`, `colName`) + `destExpr`.fromQueryHook(rawValue) + +proc buildQueryHookAction(q: QueryBuilder; prepStmt, res, retType: NimNode; singleRow: bool): NimNode = + let selectedCount = newLit(q.retType.len) + + let mapped = genSym(nskVar, "mapped") + let mappedStmt = newStmtList() + mappedStmt.add quote do: + var `mapped` = `retType`() + for idx, name in q.retNames: + let fieldName = ident(name) + let sourceType = q.retType[idx][1] + let destExpr = quote do: + `mapped`.`fieldName` + let hooked = buildHookedResultAssign(prepStmt, destExpr, retType, sourceType, idx, name) + mappedStmt.add quote do: + when compiles(`mapped`.`fieldName`): + `hooked` + if singleRow: + mappedStmt.add quote do: + `res` = `mapped` + else: + mappedStmt.add quote do: + `res`.add(`mapped`) + + let scalarStmt = newStmtList() + let mappedScalar = if singleRow: res else: genSym(nskVar, "mapped") + let sourceType = q.retType[0][1] + if not singleRow: + scalarStmt.add quote do: + var `mappedScalar`: `retType` + scalarStmt.add buildHookedResultAssign(prepStmt, mappedScalar, retType, sourceType, 0, q.retNames[0]) + if not singleRow: + scalarStmt.add quote do: + `res`.add(`mappedScalar`) + + result = quote do: + block: + when `retType` is object or `retType` is ref object: + `mappedStmt` + else: + when `selectedCount` != 1: + {.error: "query(T): scalar mapping expects exactly one selected column".} + else: + `scalarStmt` + proc queryImpl(q: QueryBuilder; body: NimNode; attempt, produceJson: bool): NimNode = expectKind body, nnkStmtList expectMinLen body, 1 @@ -1537,8 +1597,7 @@ proc queryImpl(q: QueryBuilder; body: NimNode; attempt, produceJson: bool): NimN if q.params.len > 0: blk.add newCall(bindSym"startBindings", prepStmt, newLit(q.params.len)) for p in q.params: - let fn = if p.isJson: bindSym"bindParamJson" else: bindSym"bindParam" - blk.add newCall(fn, ident"db", prepStmt, newLit(i), p.ex, p.typ) + blk.add buildHookedParamBinding(prepStmt, i, p.ex, p.typ, p.isJson) inc i blk.add newCall(bindSym"startQuery", ident"db", prepStmt) var body = newStmtList() @@ -1631,17 +1690,113 @@ proc queryImpl(q: QueryBuilder; body: NimNode; attempt, produceJson: bool): NimN if q.retType.len > 0: result.add res -macro query*(body: untyped): untyped = +proc queryHookImpl(q: QueryBuilder; body: NimNode; attempt: bool; retType: NimNode): NimNode = + expectKind body, nnkStmtList + expectMinLen body, 1 + + q.retTypeIsJson = false + applyQueryNode(body, q) + if q.kind notin {qkSelect, qkJoin}: + macros.error "query(T) currently supports select/join queries only", body + if q.retType.len == 0: + macros.error "query(T) requires a query that returns data", body + if q.retTypeIsJson: + macros.error "query(T) does not support 'produce json'", body + + let sql = queryAsString(q, body) + let prepStmt = genSym(nskLet) + let res = genSym(nskVar) + result = newTree( + nnkStmtListExpr, + newLetStmt(prepStmt, newCall(bindSym"prepareStmt", ident"db", newLit sql)) + ) + if q.singleRow: + result.add newTree(nnkVarSection, newIdentDefs(res, copyNimTree(retType), newEmptyNode())) + else: + result.add newTree(nnkVarSection, newIdentDefs(res, + newTree(nnkBracketExpr, bindSym"seq", copyNimTree(retType)), + newTree(nnkPrefix, bindSym"@", newTree(nnkBracket)))) + + let blk = newStmtList() + var i = 1 + if q.params.len > 0: + blk.add newCall(bindSym"startBindings", prepStmt, newLit(q.params.len)) + for p in q.params: + blk.add buildHookedParamBinding(prepStmt, i, p.ex, p.typ, p.isJson) + inc i + blk.add newCall(bindSym"startQuery", ident"db", prepStmt) + + let action = buildQueryHookAction(q, prepStmt, res, retType, q.singleRow) + + if q.singleRow: + if attempt: + blk.add newTree(nnkIfStmt, + newTree(nnkElifBranch, + newCall(bindSym"stepQuery", ident"db", prepStmt, newLit true), + action + ) + ) + blk.add newCall(bindSym"stopQuery", ident"db", prepStmt) + else: + blk.add newTree(nnkIfStmt, + newTree(nnkElifBranch, + newCall(bindSym"stepQuery", ident"db", prepStmt, newLit true), + newStmtList(action, newCall(bindSym"stopQuery", ident"db", prepStmt)) + ), + newTree(nnkElse, + newStmtList( + newCall(bindSym"stopQuery", ident"db", prepStmt), + newCall(bindSym"dbError", ident"db") + ) + ) + ) + else: + blk.add newTree(nnkWhileStmt, + newCall(bindSym"stepQuery", ident"db", prepStmt, newLit true), + action + ) + blk.add newCall(bindSym"stopQuery", ident"db", prepStmt) + + result.add newTree(nnkBlockStmt, newEmptyNode(), blk) + result.add res + +macro query*(args: varargs[untyped]): untyped = + if args.len == 1: + let body = args[0] + var q = newQueryBuilder() + result = queryImpl(q, body, false, false) + when defined(debugOrminDsl): + macros.hint("Ormin Query: " & repr(result), body) + return + + if args.len != 2: + macros.error("query expects either `query: ...` or `query(T): ...`", args) + + let retType = args[0] + let body = args[1] var q = newQueryBuilder() - result = queryImpl(q, body, false, false) + result = queryHookImpl(q, body, false, retType) when defined(debugOrminDsl): - macros.hint("Ormin Query: " & repr(result), body) + macros.hint("Ormin Query(T): " & repr(result), body) + +macro tryQuery*(args: varargs[untyped]): untyped = + if args.len == 1: + let body = args[0] + var q = newQueryBuilder() + result = queryImpl(q, body, true, false) + when defined(debugOrminDsl): + macros.hint("Ormin TryQuery: " & repr(result), body) + return + + if args.len != 2: + macros.error("tryQuery expects either `tryQuery: ...` or `tryQuery(T): ...`", args) -macro tryQuery*(body: untyped): untyped = + let retType = args[0] + let body = args[1] var q = newQueryBuilder() - result = queryImpl(q, body, true, false) + result = queryHookImpl(q, body, true, retType) when defined(debugOrminDsl): - macros.hint("Ormin Query: " & repr(result), body) + macros.hint("Ormin TryQuery(T): " & repr(result), body) # ------------------------- # Transactions DSL diff --git a/ormin/query_hooks.nim b/ormin/query_hooks.nim new file mode 100644 index 0000000..6933b26 --- /dev/null +++ b/ormin/query_hooks.nim @@ -0,0 +1,54 @@ +import options, json + +type + DbValue*[T] = object + isNull*: bool + value*: T + +template fromQueryHook*[T, S](val: var T, x: S) = + ## Default conversion hook used by `query(T): ...`. + ## Users can overload this proc to customize field/type conversions. + val = x + +template toQueryHook*[T, S](val: var T, x: S) = + ## Default conversion hook used for query parameters. + ## Users can overload this proc to customize parameter conversions. + val = x + +proc nullQueryValueError() {.noreturn.} = + raise newException(ValueError, "cannot map NULL query result") + +proc fromQueryHook*[T, S](val: var Option[T], x: var DbValue[S]) = + if x.isNull: + val = none(T) + else: + var converted: T + fromQueryHook(converted, move x.value) + val = some(converted) + +proc fromQueryHook*[T, S](val: var T, x: var DbValue[S]) = + if x.isNull: + when T is string: + val = "" + elif T is JsonNode: + val = newJNull() + else: + nullQueryValueError() + else: + fromQueryHook(val, move x.value) + +proc bindFromQueryHook*[T, S](dest: var T, x: var DbValue[S]) = + fromQueryHook(dest, x) + +proc toQueryHook*[S, T](val: var DbValue[S], x: Option[T]) = + if x.isSome: + val.isNull = false + toQueryHook(val.value, x.get) + else: + val.isNull = true + when compiles(val.value = default(S)): + val.value = default(S) + +proc toQueryHook*[S, T](val: var DbValue[S], x: T) = + val.isNull = false + toQueryHook(val.value, x) diff --git a/tests/tquery_types.nim b/tests/tquery_types.nim new file mode 100644 index 0000000..e782c39 --- /dev/null +++ b/tests/tquery_types.nim @@ -0,0 +1,216 @@ +import unittest, strformat, os, times, std/monotimes +import std/options +import ormin +import ormin/db_utils +import ormin/query_hooks + +when defined(postgre): + when defined(macosx): + {.passL: "-Wl,-rpath,/opt/homebrew/lib/postgresql@14".} + const backend = DbBackend.postgre + importModel(backend, "model_postgre") + const sqlFileName = "model_postgre.sql" + let db {.global.} = open("localhost", "test", "test", "test_ormin") +else: + const backend = DbBackend.sqlite + importModel(backend, "model_sqlite") + const sqlFileName = "model_sqlite.sql" + let db {.global.} = open("test.db", "", "", "") + +let + testDir = currentSourcePath.parentDir() + sqlFile = Path(testDir / sqlFileName) + +type + CompositeRow = object + id: int + message: string + + RefCompositeRow = ref object + id: int + message: string + + BenchmarkCompositeRow = object + pk1: int + message: string + + NullableNoteOptionRow = object + id: int + note: Option[string] + + MessageSize = distinct int + + HookedMessageRow = object + id: int + message: MessageSize + + NullFallbackNote = distinct string + + HookedNullableNoteRow = object + id: int + note: NullFallbackNote + +const + benchmarkRowCount = 256 + benchmarkWarmupIterations = 75 + benchmarkIterations = 250 + benchmarkRounds = 5 + maxTypedQuerySlowdown = 1.20 + +proc fromQueryHook*(val: var MessageSize, value: string) = + val = MessageSize(value.len) + +proc fromQueryHook*(val: var NullFallbackNote, value: DbValue[string]) = + if value.isNull: + val = NullFallbackNote("") + else: + val = NullFallbackNote("note:" & value.value) + +proc loadBenchmarkRows() = + db.dropTable(sqlFile, "tb_composite_pk") + db.createTable(sqlFile, "tb_composite_pk") + for i in 1 .. benchmarkRowCount: + let message = &"message-{i}" + query: + insert tb_composite_pk(pk1 = ?i, pk2 = ?i, message = ?message) + +proc benchmarkCurrentQuery(iterations: int): float = + var checksum = 0 + let started = getMonoTime() + for _ in 0 ..< iterations: + let rows = query: + select tb_composite_pk(pk1, message) + orderby pk1 + checksum += rows.len + rows[^1][0] + rows[^1][1].len + doAssert checksum > 0 + result = (getMonoTime() - started).inNanoseconds.float / 1_000_000_000.0 + +proc benchmarkTypedQuery(iterations: int): float = + var checksum = 0 + let started = getMonoTime() + for _ in 0 ..< iterations: + let rows = query(BenchmarkCompositeRow): + select tb_composite_pk(pk1, message) + orderby pk1 + checksum += rows.len + rows[^1].pk1 + rows[^1].message.len + doAssert checksum > 0 + result = (getMonoTime() - started).inNanoseconds.float / 1_000_000_000.0 + +suite &"query(T) mapping on {backend}": + setup: + db.dropTable(sqlFile, "tb_composite_pk") + db.createTable(sqlFile, "tb_composite_pk") + db.dropTable(sqlFile, "tb_nullable") + db.createTable(sqlFile, "tb_nullable") + + query: + insert tb_composite_pk(pk1 = 1, pk2 = 1, message = "hello") + query: + insert tb_composite_pk(pk1 = 2, pk2 = 2, message = "world") + + query: + insert tb_nullable(id = 1, note = nil) + query: + insert tb_nullable(id = 2, note = "hello") + + test "maps selected rows to objects": + let rows = query(CompositeRow): + select tb_composite_pk(pk1 as id, message) + orderby pk1 + + check rows == @[ + CompositeRow(id: 1, message: "hello"), + CompositeRow(id: 2, message: "world") + ] + + test "maps selected rows to ref objects": + let rows = query(RefCompositeRow): + select tb_composite_pk(pk1 as id, message) + orderby pk1 + + check rows.len == 2 + check rows[0] != nil + check rows[0].id == 1 + check rows[0].message == "hello" + check rows[1] != nil + check rows[1].id == 2 + check rows[1].message == "world" + + test "single-row query(T) returns a single object": + let row = query(CompositeRow): + select tb_composite_pk(pk1 as id, message) + where pk1 == 1 and pk2 == 1 + limit 1 + + check row == CompositeRow(id: 1, message: "hello") + + test "maps nullable column to Option": + let rows = query(NullableNoteOptionRow): + select tb_nullable(id, note) + orderby id + + check rows[0].id == 1 + check rows[0].note.isNone + check rows[1].id == 2 + check rows[1].note.isSome + check rows[1].note.get == "hello" + + test "maps object fields through user fromQueryHook overloads": + let rows = query(HookedMessageRow): + select tb_composite_pk(pk1 as id, message) + orderby pk1 + + check rows.len == 2 + check rows[0].id == 1 + check int(rows[0].message) == "hello".len + check rows[1].id == 2 + check int(rows[1].message) == "world".len + + test "maps scalar query(T) through user fromQueryHook overloads": + let values = query(MessageSize): + select tb_composite_pk(message) + orderby pk1 + + check values.len == 2 + check int(values[0]) == "hello".len + check int(values[1]) == "world".len + + test "allows user fromQueryHook overloads to handle NULL values": + let rows = query(HookedNullableNoteRow): + select tb_nullable(id, note) + orderby id + + check rows.len == 2 + check rows[0].id == 1 + check string(rows[0].note) == "" + check rows[1].id == 2 + check string(rows[1].note) == "note:hello" + + test "sqlite benchmark for query and query(T)": + loadBenchmarkRows() + + let untypedRows = query: + select tb_composite_pk(pk1, message) + orderby pk1 + let typedRows = query(BenchmarkCompositeRow): + select tb_composite_pk(pk1, message) + orderby pk1 + check untypedRows.len == typedRows.len + check typedRows[0] == BenchmarkCompositeRow(pk1: untypedRows[0][0], message: untypedRows[0][1]) + check typedRows[^1] == BenchmarkCompositeRow(pk1: untypedRows[^1][0], message: untypedRows[^1][1]) + + discard benchmarkCurrentQuery(benchmarkWarmupIterations) + discard benchmarkTypedQuery(benchmarkWarmupIterations) + + var currentBest = high(float) + var typedBest = high(float) + for _ in 0 ..< benchmarkRounds: + currentBest = min(currentBest, benchmarkCurrentQuery(benchmarkIterations)) + typedBest = min(typedBest, benchmarkTypedQuery(benchmarkIterations)) + + let ratio = typedBest / currentBest + echo &"sqlite benchmark query={currentBest:.6f}s query(T)={typedBest:.6f}s ratio={ratio:.3f}x; 20% budget={(ratio <= maxTypedQuerySlowdown)}" + check currentBest > 0.0 + check typedBest > 0.0 + when defined(release): + check ratio <= maxTypedQuerySlowdown