From 7b1660d8d79580a50b33f829402c480d71b8a997 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Sun, 7 Dec 2014 20:06:44 +0000 Subject: [PATCH 01/15] Initial implementation of aggregate functions. --- src/UDF.jl | 159 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) diff --git a/src/UDF.jl b/src/UDF.jl index 95ac2eb..9c2e8c5 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -35,6 +35,141 @@ sqlreturn(context, val::Vector{UInt8}) = sqlite3_result_blob(context, val) sqlreturn(context, val::Bool) = sqlreturn(context, int(val)) sqlreturn(context, val) = sqlreturn(context, sqlserialize(val)) +# a fixed size type to store the aggregate context +#type AggCont +# nbytes::Int +# optr::Ptr{UInt8} +# +# function AggCont(o) +# # TODO: is serialization necessary? maybe just store an array instead +# oarr = sqlserialize(o) +# osize = sizeof(oarr) +# # TODO: can we stop julia from garbage collecting without c_malloc? +# optr = convert(Ptr{UInt8}, c_malloc(osize)) +# unsafe_copy!(optr, pointer(oarr), osize) +# +# return new(osize, optr) +# end +#end +# convert a bytearray to an int arr[1] is 256^0, arr[2] is 256^1... +# TODO: would making this a method of convert needlessly pollute the Base namespace? +function bytestoint(arr::Vector{UInt8}) + l = length(arr) + s = 0 + for (i, v) in enumerate(arr) + s += v * 256^(i - 1) + end + s +end +# TODO: remove this +inttobytes(i::Int) = reinterpret(UInt8, [i]) + +function stepfunc(init, func, fsym=symbol(string(func)*"_step")) + nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym + return quote + function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) + args = [sqlvalue(values, i) for i in 1:nargs] + intsize = sizeof(Int) + ptrsize = sizeof(Ptr) + acsize = intsize + ptrsize + acarr = pointer_to_array( + convert(Ptr{UInt8}, sqlite3_aggregate_context(context, acsize)), + acsize, + false, + ) + # acarr will be zeroed-out if this is the first iteration + ret = ccall( + :memcmp, Cint, (Ptr{UInt8}, Ptr{UInt8}, Cuint), + zeros(UInt8, acsize), acarr, acsize, + ) + try + if ret == 0 + acval = $(init) + # TODO: i'm sure there's a better way + valsize = sizeof(sqlserialize(acval)) + valptr = convert(Ptr{UInt8}, c_malloc(valsize)) + else + # retrieve the size of the serialized value (first sizeof(Int) bytes) + sizebuf = zeros(UInt8, intsize) + unsafe_copy!(sizebuf, 1, acarr, 1, intsize) + valsize = bytestoint(sizebuf) + # retrieve the ptr to the serialized value (last sizeof(Ptr) bytes) + ptrbuf = zeros(UInt8, ptrsize) + unsafe_copy!(ptrbuf, 1, acarr, intsize+1, ptrsize) + valptr = reinterpret(Ptr{UInt8}, bytestoint(ptrbuf)) + # deserialize the value pointed to by valptr + acvalbuf = zeros(UInt8, valsize) + unsafe_copy!(pointer(acvalbuf), valptr, valsize) + acval = sqldeserialize(acvalbuf) + end + funcret = sqlserialize($(func)(acval, args...)) + newsize = length(funcret) + # TODO: increase this in a cleverer way? + newsize > valsize && (valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize))) + # copy serialized return value + unsafe_copy!(valptr, pointer(funcret), newsize) + # copy the size of the serialized value + unsafe_copy!( + acarr, 1, + reinterpret(UInt8, [newsize]), 1, + intsize, + ) + # copy the value of the pointer to the serialized value + # TODO: can we just use ptrbuf here? + unsafe_copy!( + acarr, intsize+1, + reinterpret(UInt8, [valptr]), 1, + ptrsize, + ) + catch + # TODO: this won't catch all memory leaks so add an else clause + if isdefined(:valptr) + c_free(valptr) + end + rethrow() + end + nothing + end + end +end + +# TODO: free valptr on error +function finalfunc(init, func, fsym=symbol(string(func)*"_final")) + nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym + return quote + function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) + args = [sqlvalue(context, i) for i in 1:nargs] + acptr = sqlite3_aggregate_context(context, 0) + # step function wasn't run + if acptr === C_NULL + sqlreturn(context, $(init)) + else + intsize = sizeof(Int) + ptrsize = sizeof(Ptr) + acsize = intsize + ptrsize + acarr = pointer_to_array(convert(Ptr{UInt8}, acptr), acsize, false) + # load size + sizebuf = zeros(UInt8, intsize) + unsafe_copy!(sizebuf, 1, acarr, 1, intsize) + valsize = bytestoint(sizebuf) + # load ptr + ptrbuf = zeros(UInt8, ptrsize) + unsafe_copy!(ptrbuf, 1, acarr, intsize+1, ptrsize) + valptr = reinterpret(Ptr{UInt8}, bytestoint(ptrbuf)) + # load value + acvalbuf = zeros(UInt8, valsize) + unsafe_copy!(pointer(acvalbuf), valptr, valsize) + + acval = sqldeserialize(acvalbuf) + ret = $(func)(acval, args...) + c_free(valptr) + sqlreturn(context, ret) + end + nothing + end + end +end + # Internal method for generating an SQLite scalar function from # a Julia function name function scalarfunc(func,fsym=symbol(string(func))) @@ -54,11 +189,13 @@ function scalarfunc(expr::Expr) f = eval(expr) return scalarfunc(f) end + # User-facing macro for convenience in registering a simple function # with no configurations needed macro register(db, func) :(register($(esc(db)), $(esc(func)))) end + # User-facing method for registering a Julia function to be used within SQLite function register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true) @assert nargs <= 127 "use -1 if > 127 arguments are needed" @@ -78,6 +215,28 @@ function register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractStr ) end +# as above but for aggregate functions +function register( + db::SQLiteDB, init, step::Function, final::Function; + nargs::Int=-1, name::AbstractString=string(final), isdeterm::Bool=true +) + @assert nargs <= 127 "use -1 if > 127 arguments are needed" + nargs < -1 && (nargs = -1) + @assert sizeof(name) <= 255 "size of function name must be <= 255 chars" + + s = eval(stepfunc(init, step, Base.function_name(step))) + cs = cfunction(s, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}})) + f = eval(finalfunc(init, final, Base.function_name(final))) + cf = cfunction(f, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}})) + + enc = SQLITE_UTF8 + enc = isdeterm ? enc | SQLITE_DETERMINISTIC : enc + + @CHECK db sqlite3_create_function_v2( + db.handle, name, nargs, enc, C_NULL, C_NULL, cs, cf, C_NULL + ) +end + # annotate types because the MethodError makes more sense that way regexp(r::AbstractString, s::AbstractString) = ismatch(Regex(r), s) # macro for preserving the special characters in a string From 868ab16f1ab187c0dfc6461b1d865b6082018955 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Sun, 7 Dec 2014 20:44:29 +0000 Subject: [PATCH 02/15] Remove unneeded type. --- src/UDF.jl | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/UDF.jl b/src/UDF.jl index 9c2e8c5..06f1a83 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -35,22 +35,6 @@ sqlreturn(context, val::Vector{UInt8}) = sqlite3_result_blob(context, val) sqlreturn(context, val::Bool) = sqlreturn(context, int(val)) sqlreturn(context, val) = sqlreturn(context, sqlserialize(val)) -# a fixed size type to store the aggregate context -#type AggCont -# nbytes::Int -# optr::Ptr{UInt8} -# -# function AggCont(o) -# # TODO: is serialization necessary? maybe just store an array instead -# oarr = sqlserialize(o) -# osize = sizeof(oarr) -# # TODO: can we stop julia from garbage collecting without c_malloc? -# optr = convert(Ptr{UInt8}, c_malloc(osize)) -# unsafe_copy!(optr, pointer(oarr), osize) -# -# return new(osize, optr) -# end -#end # convert a bytearray to an int arr[1] is 256^0, arr[2] is 256^1... # TODO: would making this a method of convert needlessly pollute the Base namespace? function bytestoint(arr::Vector{UInt8}) @@ -61,8 +45,6 @@ function bytestoint(arr::Vector{UInt8}) end s end -# TODO: remove this -inttobytes(i::Int) = reinterpret(UInt8, [i]) function stepfunc(init, func, fsym=symbol(string(func)*"_step")) nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym From c7d205aab140dce1d81b623d454caaff48f47d75 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Sun, 7 Dec 2014 21:10:06 +0000 Subject: [PATCH 03/15] Add simple test for aggregates. Also put a little TODO reminder. --- src/UDF.jl | 2 ++ test/runtests.jl | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/src/UDF.jl b/src/UDF.jl index 06f1a83..5a3c7e3 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -120,6 +120,8 @@ function finalfunc(init, func, fsym=symbol(string(func)*"_final")) nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym return quote function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) + # TODO: I don't think arguments are ever passed to this function, + # should we leave them in anyway? args = [sqlvalue(context, i) for i in 1:nargs] acptr = sqlite3_aggregate_context(context, 0) # step function wasn't run diff --git a/test/runtests.jl b/test/runtests.jl index 5e8cec6..600dc75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -207,6 +207,12 @@ SQLite.@register db big r = query(db, "SELECT big(5)") @test r[1][1] == big(5) +doublesum_step(persist, current) = persist + current +doublesum_final(persist) = 2 * persist +register(db, 0, doublesum_step, doublesum_final, name="doublesum") +r = query(db, "SELECT doublesum(UnitPrice) FROM Track") +s = query(db, "SELECT UnitPrice FROM Track") +@test_approx_eq r[1][1] 2*sum(s[1]) db2 = SQLiteDB() query(db2, "CREATE TABLE tab1 (r REAL, s INT)") From bd4c132a79997bb20c52ec5d89ccca2f95f6c4b8 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Sun, 7 Dec 2014 21:27:22 +0000 Subject: [PATCH 04/15] Fix aggregates on Linux. No arguments are ever passed to the final function so values was a void pointer causing sqlite3_value_type to fail. This simply stopped trying to collect args. Also changed the order of the function definitions to something that made more sense. --- src/UDF.jl | 45 +++++++++++++++++++++------------------------ 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/UDF.jl b/src/UDF.jl index 5a3c7e3..c2292a7 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -35,6 +35,26 @@ sqlreturn(context, val::Vector{UInt8}) = sqlite3_result_blob(context, val) sqlreturn(context, val::Bool) = sqlreturn(context, int(val)) sqlreturn(context, val) = sqlreturn(context, sqlserialize(val)) +# Internal method for generating an SQLite scalar function from +# a Julia function name +function scalarfunc(func,fsym=symbol(string(func))) + # check if name defined in Base so we don't clobber Base methods + nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym + return quote + #nm needs to be a symbol or expr, i.e. :sin or :(Base.sin) + function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) + args = [SQLite.sqlvalue(values, i) for i in 1:nargs] + ret = $(func)(args...) + SQLite.sqlreturn(context, ret) + nothing + end + end +end +function scalarfunc(expr::Expr) + f = eval(expr) + return scalarfunc(f) +end + # convert a bytearray to an int arr[1] is 256^0, arr[2] is 256^1... # TODO: would making this a method of convert needlessly pollute the Base namespace? function bytestoint(arr::Vector{UInt8}) @@ -120,9 +140,6 @@ function finalfunc(init, func, fsym=symbol(string(func)*"_final")) nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym return quote function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) - # TODO: I don't think arguments are ever passed to this function, - # should we leave them in anyway? - args = [sqlvalue(context, i) for i in 1:nargs] acptr = sqlite3_aggregate_context(context, 0) # step function wasn't run if acptr === C_NULL @@ -145,7 +162,7 @@ function finalfunc(init, func, fsym=symbol(string(func)*"_final")) unsafe_copy!(pointer(acvalbuf), valptr, valsize) acval = sqldeserialize(acvalbuf) - ret = $(func)(acval, args...) + ret = $(func)(acval) c_free(valptr) sqlreturn(context, ret) end @@ -154,26 +171,6 @@ function finalfunc(init, func, fsym=symbol(string(func)*"_final")) end end -# Internal method for generating an SQLite scalar function from -# a Julia function name -function scalarfunc(func,fsym=symbol(string(func))) - # check if name defined in Base so we don't clobber Base methods - nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym - return quote - #nm needs to be a symbol or expr, i.e. :sin or :(Base.sin) - function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) - args = [SQLite.sqlvalue(values, i) for i in 1:nargs] - ret = $(func)(args...) - SQLite.sqlreturn(context, ret) - nothing - end - end -end -function scalarfunc(expr::Expr) - f = eval(expr) - return scalarfunc(f) -end - # User-facing macro for convenience in registering a simple function # with no configurations needed macro register(db, func) From 4a0bcff839348015d7545f5bf2cb2581f8741022 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Mon, 8 Dec 2014 10:08:02 +0000 Subject: [PATCH 05/15] Replace copies with loads. No measurable speed-up but slightly simplifies the implementation. --- src/UDF.jl | 71 +++++++++++++++++++++++++----------------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/src/UDF.jl b/src/UDF.jl index c2292a7..e956fc4 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -55,15 +55,16 @@ function scalarfunc(expr::Expr) return scalarfunc(f) end -# convert a bytearray to an int arr[1] is 256^0, arr[2] is 256^1... -# TODO: would making this a method of convert needlessly pollute the Base namespace? -function bytestoint(arr::Vector{UInt8}) - l = length(arr) +# convert a byteptr to an int, ptr[start] -> 256^0, ptr[start+1] -> 256^1... +# TODO: this assumes little-endian +function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int) s = 0 - for (i, v) in enumerate(arr) - s += v * 256^(i - 1) + for i in start:start+len-1 + v = unsafe_load(ptr, i) + s += v * 256^(i - start) end - s + + return s end function stepfunc(init, func, fsym=symbol(string(func)*"_step")) @@ -74,31 +75,25 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) intsize = sizeof(Int) ptrsize = sizeof(Ptr) acsize = intsize + ptrsize - acarr = pointer_to_array( - convert(Ptr{UInt8}, sqlite3_aggregate_context(context, acsize)), - acsize, - false, - ) - # acarr will be zeroed-out if this is the first iteration + acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, acsize)) + # acptr will be zeroed-out if this is the first iteration ret = ccall( :memcmp, Cint, (Ptr{UInt8}, Ptr{UInt8}, Cuint), - zeros(UInt8, acsize), acarr, acsize, + zeros(UInt8, acsize), acptr, acsize, ) try if ret == 0 acval = $(init) - # TODO: i'm sure there's a better way + # TODO: allocate 256 byte valsize = sizeof(sqlserialize(acval)) valptr = convert(Ptr{UInt8}, c_malloc(valsize)) else - # retrieve the size of the serialized value (first sizeof(Int) bytes) - sizebuf = zeros(UInt8, intsize) - unsafe_copy!(sizebuf, 1, acarr, 1, intsize) - valsize = bytestoint(sizebuf) - # retrieve the ptr to the serialized value (last sizeof(Ptr) bytes) - ptrbuf = zeros(UInt8, ptrsize) - unsafe_copy!(ptrbuf, 1, acarr, intsize+1, ptrsize) - valptr = reinterpret(Ptr{UInt8}, bytestoint(ptrbuf)) + # size of serialized value is first sizeof(Int) bytes + valsize = bytestoint(acptr, 1, intsize) + # ptr to serialized value is last sizeof(Ptr) bytes + valptr = reinterpret( + Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize) + ) # deserialize the value pointed to by valptr acvalbuf = zeros(UInt8, valsize) unsafe_copy!(pointer(acvalbuf), valptr, valsize) @@ -110,21 +105,24 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) newsize > valsize && (valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize))) # copy serialized return value unsafe_copy!(valptr, pointer(funcret), newsize) + # following copies are easier with arrays + acptr = pointer_to_array(acptr, acsize, false) # copy the size of the serialized value unsafe_copy!( - acarr, 1, + acptr, 1, reinterpret(UInt8, [newsize]), 1, intsize, ) - # copy the value of the pointer to the serialized value - # TODO: can we just use ptrbuf here? + # copy the address of the pointer to the serialized value unsafe_copy!( - acarr, intsize+1, + acptr, intsize+1, reinterpret(UInt8, [valptr]), 1, ptrsize, ) catch - # TODO: this won't catch all memory leaks so add an else clause + # TODO: + # this won't catch all memory leaks so add an else clause + # alternatively use c-style checking in this function if isdefined(:valptr) c_free(valptr) end @@ -140,28 +138,25 @@ function finalfunc(init, func, fsym=symbol(string(func)*"_final")) nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym return quote function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) - acptr = sqlite3_aggregate_context(context, 0) + acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, 0)) # step function wasn't run - if acptr === C_NULL + if acptr == C_NULL sqlreturn(context, $(init)) else intsize = sizeof(Int) ptrsize = sizeof(Ptr) acsize = intsize + ptrsize - acarr = pointer_to_array(convert(Ptr{UInt8}, acptr), acsize, false) # load size - sizebuf = zeros(UInt8, intsize) - unsafe_copy!(sizebuf, 1, acarr, 1, intsize) - valsize = bytestoint(sizebuf) + valsize = bytestoint(acptr, 1, intsize) # load ptr - ptrbuf = zeros(UInt8, ptrsize) - unsafe_copy!(ptrbuf, 1, acarr, intsize+1, ptrsize) - valptr = reinterpret(Ptr{UInt8}, bytestoint(ptrbuf)) + valptr = reinterpret( + Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize) + ) # load value acvalbuf = zeros(UInt8, valsize) unsafe_copy!(pointer(acvalbuf), valptr, valsize) - acval = sqldeserialize(acvalbuf) + ret = $(func)(acval) c_free(valptr) sqlreturn(context, ret) From 99ceeb699003607c3af50bcdf72e2499a411c937 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Mon, 8 Dec 2014 11:30:45 +0000 Subject: [PATCH 06/15] Fix TODOs. Attempt to account for big-endianness. Don't call sqlserialize unnecessarily. Try to avoid memory-leaks. --- src/UDF.jl | 75 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/src/UDF.jl b/src/UDF.jl index e956fc4..3ab0300 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -55,8 +55,7 @@ function scalarfunc(expr::Expr) return scalarfunc(f) end -# convert a byteptr to an int, ptr[start] -> 256^0, ptr[start+1] -> 256^1... -# TODO: this assumes little-endian +# convert a byteptr to an int, assumes little-endian function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int) s = 0 for i in start:start+len-1 @@ -64,7 +63,9 @@ function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int) s += v * 256^(i - start) end - return s + # swap byte-order on big-endian machines + # TODO: this desperately needs testing on a big-endian machine!!!!! + return htol(s) end function stepfunc(init, func, fsym=symbol(string(func)*"_step")) @@ -72,39 +73,46 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) return quote function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) args = [sqlvalue(values, i) for i in 1:nargs] + intsize = sizeof(Int) ptrsize = sizeof(Ptr) acsize = intsize + ptrsize acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, acsize)) + # acptr will be zeroed-out if this is the first iteration ret = ccall( :memcmp, Cint, (Ptr{UInt8}, Ptr{UInt8}, Cuint), zeros(UInt8, acsize), acptr, acsize, ) + if ret == 0 + acval = $(init) + valsize = 256 + # avoid the garbage collector using malloc + valptr = convert(Ptr{UInt8}, c_malloc(valsize)) + else + # size of serialized value is first sizeof(Int) bytes + valsize = bytestoint(acptr, 1, intsize) + # ptr to serialized value is last sizeof(Ptr) bytes + valptr = reinterpret( + Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize) + ) + # deserialize the value pointed to by valptr + acvalbuf = zeros(UInt8, valsize) + unsafe_copy!(pointer(acvalbuf), valptr, valsize) + acval = sqldeserialize(acvalbuf) + end + try - if ret == 0 - acval = $(init) - # TODO: allocate 256 byte - valsize = sizeof(sqlserialize(acval)) - valptr = convert(Ptr{UInt8}, c_malloc(valsize)) - else - # size of serialized value is first sizeof(Int) bytes - valsize = bytestoint(acptr, 1, intsize) - # ptr to serialized value is last sizeof(Ptr) bytes - valptr = reinterpret( - Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize) - ) - # deserialize the value pointed to by valptr - acvalbuf = zeros(UInt8, valsize) - unsafe_copy!(pointer(acvalbuf), valptr, valsize) - acval = sqldeserialize(acvalbuf) - end funcret = sqlserialize($(func)(acval, args...)) - newsize = length(funcret) - # TODO: increase this in a cleverer way? - newsize > valsize && (valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize))) + + newsize = sizeof(funcret) + if newsize > valsize + # TODO: increase this in a cleverer way? + valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize)) + end # copy serialized return value unsafe_copy!(valptr, pointer(funcret), newsize) + # following copies are easier with arrays acptr = pointer_to_array(acptr, acsize, false) # copy the size of the serialized value @@ -133,12 +141,12 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) end end -# TODO: free valptr on error function finalfunc(init, func, fsym=symbol(string(func)*"_final")) nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym return quote function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}}) acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, 0)) + # step function wasn't run if acptr == C_NULL sqlreturn(context, $(init)) @@ -146,20 +154,25 @@ function finalfunc(init, func, fsym=symbol(string(func)*"_final")) intsize = sizeof(Int) ptrsize = sizeof(Ptr) acsize = intsize + ptrsize + # load size valsize = bytestoint(acptr, 1, intsize) # load ptr valptr = reinterpret( Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize) ) - # load value - acvalbuf = zeros(UInt8, valsize) - unsafe_copy!(pointer(acvalbuf), valptr, valsize) - acval = sqldeserialize(acvalbuf) - ret = $(func)(acval) - c_free(valptr) - sqlreturn(context, ret) + try + # load value + acvalbuf = zeros(UInt8, valsize) + unsafe_copy!(pointer(acvalbuf), valptr, valsize) + acval = sqldeserialize(acvalbuf) + + ret = $(func)(acval) + sqlreturn(context, ret) + finally + c_free(valptr) + end end nothing end From fb4926523e26f8c4100fd5d681c04b7e33168d48 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Mon, 8 Dec 2014 12:46:51 +0000 Subject: [PATCH 07/15] Add test and update README. --- test/runtests.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 600dc75..9200e81 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -214,6 +214,13 @@ r = query(db, "SELECT doublesum(UnitPrice) FROM Track") s = query(db, "SELECT UnitPrice FROM Track") @test_approx_eq r[1][1] 2*sum(s[1]) +mycount(p, c) = p + 1 +mycount(p) = p +register(db, 0, mycount, mycount) +r = query(db, "SELECT mycount(TrackId) FROM PlaylistTrack") +s = query(db, "SELECT count(TrackId) FROM PlaylistTrack") +@test r[1] == s[1] + db2 = SQLiteDB() query(db2, "CREATE TABLE tab1 (r REAL, s INT)") From 0ddb0e807dd7a7c40cb1f6f8bd751a1d16ba1cae Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Thu, 11 Dec 2014 11:52:27 +0000 Subject: [PATCH 08/15] Didn't actually add the README in the last commit. --- README.md | 60 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 0acec24..f8ec1c9 100644 --- a/README.md +++ b/README.md @@ -76,14 +76,14 @@ A Julia interface to the SQLite library and support for operations on DataFrames `drop` is pretty self-explanatory. It's really just a convenience wrapper around `query` to execute a DROP TABLE command, while also calling "VACUUM" to clean out freed memory from the database. -* `registerfunc(db::SQLiteDB, nargs::Int, func::Function, isdeterm::Bool=true; name="")` +* `register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true)` +* `register(db::SQLiteDB, init, step::Function, final::Function; nargs::Int=-1, name::AbstractString=string(final), isdeterm::Bool=true)` - Register a function `func` (which takes `nargs` number of arguments) with the SQLite database connection `db`. If the keyword argument `name` is given the function is registered with that name, otherwise it is registered with the name of `func`. If the function is stochastic (e.g. uses a random number) `isdeterm` should be set to `false`, see SQLite's [function creation documentation](http://sqlite.org/c3ref/create_function.html) for more information. + Register a scalar (first method) or aggregate (second method) function with a `SQLiteDB`. -* `@scalarfunc function` - `@scalarfunc name function` +* `@register db function` - Define a function which can then be passed to `registerfunc`. In the first usage the function name is infered from the function definition, in the second it is explicitly given as the first parameter. The second form is only recommended when it's use is absolutely necessary, see below. + Automatically define then register `function` with a `SQLiteDB`. * `sr"..."` @@ -188,45 +188,31 @@ The sr"..." currently escapes all special characters in a string but it may be c ##### Custom Scalar Functions -SQLite.jl also provides a way that you can implement your own [Scalar Functions](https://www.sqlite.org/lang_corefunc.html) (though [Aggregate Functions](https://www.sqlite.org/lang_aggfunc.html) are not currently supported). This is done using the `registerfunc` function and `@scalarfunc` macro. +SQLite.jl also provides a way that you can implement your own [Scalar Functions](https://www.sqlite.org/lang_corefunc.html). This is done using the `register` function and macro. -`@scalarfunc` takes an optional function name and a function and defines a new function which can be passed to `registerfunc`. It can be used with block function syntax +`@register` takes a `SQLiteDB` and a function. The function can be in block syntax ```julia -julia> @scalarfunc function add3(x) +julia> @register db function add3(x) x + 3 end -add3 (generic function with 1 method) - -julia> @scalarfunc add5 function irrelevantfuncname(x) - x + 5 - end -add5 (generic function with 1 method) ``` inline function syntax ```julia -julia> @scalarfunc mult3(x) = 3 * x -mult3 (generic function with 1 method) - -julia> @scalarfunc mult5 anotherirrelevantname(x) = 5 * x -mult5 (generic function with 1 method) +julia> @register db mult3(x) = 3 * x ``` -and previously defined functions (note that name inference does not work with this method) +and previously defined functions ```julia -julia> @scalarfunc sin sin -sin (generic function with 1 method) - -julia> @scalarfunc subtract - -subtract (generic function with 1 method) +julia> @register db sin ``` -The function that is defined can then be passed to `registerfunc`. `registerfunc` takes three arguments; the database to which the function should be registered, the number of arguments that the function takes and the function itself. The function is registered to the database connection rather than the database itself so must be registered each time the database opens. Your function can not take more than 127 arguments unless it takes a variable number of arguments, if it does take a variable number of arguments then you must pass -1 as the second argument to `registerfunc`. +The `register` function takes optional arguments; `nargs` which defaults to `-1`, `name` which defaults to the name of the function, `isdeterm` which defaults to `true`. In practice these rarely need to be used. -The `@scalarfunc` macro uses the `sqlreturn` function to return your function's return value to SQLite. By default, `sqlreturn` maps the returned value to a [native SQLite type](http://sqlite.org/c3ref/result_blob.html) or, failing that, serializes the julia value and stores it as a `BLOB`. To change this behaviour simply define a new method for `sqlreturn` which then calls a previously defined method for `sqlreturn`. Methods which map to native SQLite types are +The `register` function uses the `sqlreturn` function to return your function's return value to SQLite. By default, `sqlreturn` maps the returned value to a [native SQLite type](http://sqlite.org/c3ref/result_blob.html) or, failing that, serializes the julia value and stores it as a `BLOB`. To change this behaviour simply define a new method for `sqlreturn` which then calls a previously defined method for `sqlreturn`. Methods which map to native SQLite types are ```julia sqlreturn(context, ::NullType) @@ -251,3 +237,23 @@ sqlreturn(context, val::Bool) = sqlreturn(context, int(val)) ``` Any new method defined for `sqlreturn` must take two arguments and must pass the first argument straight through as the first argument. + +#### Custom Aggregate Functions + +Using the `register` function, you can also define your own aggregate functions with largely the same semantics. + +The `register` function for aggregates must take a `SQLiteDB`, an initial value, a step function and a final function. The first argument to the step function will be the return value of the previous function (or the initial value if it is the first iteration). The final function must take a single argument which will be the return value of the last step function. + +```julia +julia> dsum(prev, cur) = prev + cur + +julia> dsum(prev) = 2 * prev + +julia> register(db, 0, dsum, dsum) +``` + +If no name is given the name of the second (final) function is used (in this case "dsum"). You can also use lambdas, the following does the same as the previous code snippet + +```julia +julia> register(db, 0, (p,c) -> p+c, p -> 2p, name="dsum") +``` From 91189c6d0458b90c64fcd7b5bb45e5e15d7fd22a Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Thu, 11 Dec 2014 12:00:39 +0000 Subject: [PATCH 09/15] Default the final function in register to identity. And therefore default the name to the step functions name. README and tests have also been updated. --- README.md | 4 ++-- src/UDF.jl | 4 ++-- test/runtests.jl | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f8ec1c9..1a366b3 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ A Julia interface to the SQLite library and support for operations on DataFrames `drop` is pretty self-explanatory. It's really just a convenience wrapper around `query` to execute a DROP TABLE command, while also calling "VACUUM" to clean out freed memory from the database. * `register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true)` -* `register(db::SQLiteDB, init, step::Function, final::Function; nargs::Int=-1, name::AbstractString=string(final), isdeterm::Bool=true)` +* `register(db::SQLiteDB, init, step::Function, final::Function=identity; nargs::Int=-1, name::AbstractString=string(final), isdeterm::Bool=true)` Register a scalar (first method) or aggregate (second method) function with a `SQLiteDB`. @@ -252,7 +252,7 @@ julia> dsum(prev) = 2 * prev julia> register(db, 0, dsum, dsum) ``` -If no name is given the name of the second (final) function is used (in this case "dsum"). You can also use lambdas, the following does the same as the previous code snippet +If no name is given the name of the first (step) function is used (in this case "dsum"). You can also use lambdas, the following does the same as the previous code snippet ```julia julia> register(db, 0, (p,c) -> p+c, p -> 2p, name="dsum") diff --git a/src/UDF.jl b/src/UDF.jl index 3ab0300..60da3e8 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -206,8 +206,8 @@ end # as above but for aggregate functions function register( - db::SQLiteDB, init, step::Function, final::Function; - nargs::Int=-1, name::AbstractString=string(final), isdeterm::Bool=true + db::SQLiteDB, init, step::Function, final::Function=identity; + nargs::Int=-1, name::AbstractString=string(step), isdeterm::Bool=true ) @assert nargs <= 127 "use -1 if > 127 arguments are needed" nargs < -1 && (nargs = -1) diff --git a/test/runtests.jl b/test/runtests.jl index 9200e81..51f854d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -215,8 +215,7 @@ s = query(db, "SELECT UnitPrice FROM Track") @test_approx_eq r[1][1] 2*sum(s[1]) mycount(p, c) = p + 1 -mycount(p) = p -register(db, 0, mycount, mycount) +register(db, 0, mycount) r = query(db, "SELECT mycount(TrackId) FROM PlaylistTrack") s = query(db, "SELECT count(TrackId) FROM PlaylistTrack") @test r[1] == s[1] From 1d9a7d043f24f8ff69852a3f5691aa11b91e6396 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Thu, 11 Dec 2014 13:27:56 +0000 Subject: [PATCH 10/15] Test more complex types. --- test/runtests.jl | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 51f854d..26ebf9b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -220,6 +220,29 @@ r = query(db, "SELECT mycount(TrackId) FROM PlaylistTrack") s = query(db, "SELECT count(TrackId) FROM PlaylistTrack") @test r[1] == s[1] +bigsum(p, c) = p + big(c) +register(db, big(0), bigsum) +r = query(db, "SELECT bigsum(TrackId) FROM PlaylistTrack") +s = query(db, "SELECT TrackId FROM PlaylistTrack") +@test r[1][1] == big(sum(s[1])) + +query(db, "CREATE TABLE points (x INT, y INT, z INT)") +query(db, "INSERT INTO points VALUES (?, ?, ?)", [1, 2, 3]) +query(db, "INSERT INTO points VALUES (?, ?, ?)", [4, 5, 6]) +query(db, "INSERT INTO points VALUES (?, ?, ?)", [7, 8, 9]) +type Point3D{T<:Number} + x::T + y::T + z::T +end +==(a::Point3D, b::Point3D) = a.x == b.x && a.y == b.y && a.z == b.z ++(a::Point3D, b::Point3D) = Point3D(a.x + b.x, a.y + b.y, a.z + b.z) +sumpoint(p::Point3D, x, y, z) = p + Point3D(x, y, z) +register(db, Point3D(0, 0, 0), sumpoint) +r = query(db, "SELECT sumpoint(x, y, z) FROM points") +@test r[1][1] == Point3D(12, 15, 18) +drop(db, "points") + db2 = SQLiteDB() query(db2, "CREATE TABLE tab1 (r REAL, s INT)") @@ -237,7 +260,6 @@ drop(db2, "tab2", ifexists=true) close(db2) - @test size(tables(db)) == (11,1) close(db) From 685606d81878ea81632d0e9d0e5ab380f4ac49e2 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Thu, 11 Dec 2014 13:49:59 +0000 Subject: [PATCH 11/15] Use unsafe store instead of copy. No performance change but reduces bytes allocated by ~5%. --- src/UDF.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/UDF.jl b/src/UDF.jl index 60da3e8..d3a3dcb 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -113,20 +113,16 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) # copy serialized return value unsafe_copy!(valptr, pointer(funcret), newsize) - # following copies are easier with arrays - acptr = pointer_to_array(acptr, acsize, false) # copy the size of the serialized value unsafe_copy!( - acptr, 1, - reinterpret(UInt8, [newsize]), 1, - intsize, + acptr, + pointer(reinterpret(UInt8, [newsize])), + intsize ) # copy the address of the pointer to the serialized value - unsafe_copy!( - acptr, intsize+1, - reinterpret(UInt8, [valptr]), 1, - ptrsize, - ) + for (i, byte) in enumerate(reinterpret(UInt8, [valptr])) + unsafe_store!(acptr, byte, intsize+i) + end catch # TODO: # this won't catch all memory leaks so add an else clause From c130f1d68893a9f2472a14073642575f06a89bc0 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Thu, 11 Dec 2014 13:58:29 +0000 Subject: [PATCH 12/15] Even fewer bytes allocated. --- src/UDF.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/UDF.jl b/src/UDF.jl index d3a3dcb..c51ee7f 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -120,8 +120,9 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) intsize ) # copy the address of the pointer to the serialized value - for (i, byte) in enumerate(reinterpret(UInt8, [valptr])) - unsafe_store!(acptr, byte, intsize+i) + valarr = reinterpret(UInt8, [valptr]) + for i in 1:length(valarr) + unsafe_store!(acptr, a[i], intsize+i) end catch # TODO: From fa7742d43c715717cc1234f0ea888fec265a5fa3 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Thu, 11 Dec 2014 18:54:24 +0000 Subject: [PATCH 13/15] Fix undefined variable. --- src/UDF.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/UDF.jl b/src/UDF.jl index c51ee7f..4dfe2f3 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -122,7 +122,7 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) # copy the address of the pointer to the serialized value valarr = reinterpret(UInt8, [valptr]) for i in 1:length(valarr) - unsafe_store!(acptr, a[i], intsize+i) + unsafe_store!(acptr, valarr[i], intsize+i) end catch # TODO: From 57c19cab735cb23602b016cbdf59fd615a252de6 Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Fri, 12 Dec 2014 17:26:10 +0000 Subject: [PATCH 14/15] Reduce scope of try-statements. The only unknown is the users function, any other exceptions are bugs in SQLite.jl and should be fixed not hidden away. --- src/UDF.jl | 61 ++++++++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/src/UDF.jl b/src/UDF.jl index 4dfe2f3..28f14c6 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -102,37 +102,33 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) acval = sqldeserialize(acvalbuf) end + local funcret try funcret = sqlserialize($(func)(acval, args...)) - - newsize = sizeof(funcret) - if newsize > valsize - # TODO: increase this in a cleverer way? - valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize)) - end - # copy serialized return value - unsafe_copy!(valptr, pointer(funcret), newsize) - - # copy the size of the serialized value - unsafe_copy!( - acptr, - pointer(reinterpret(UInt8, [newsize])), - intsize - ) - # copy the address of the pointer to the serialized value - valarr = reinterpret(UInt8, [valptr]) - for i in 1:length(valarr) - unsafe_store!(acptr, valarr[i], intsize+i) - end catch - # TODO: - # this won't catch all memory leaks so add an else clause - # alternatively use c-style checking in this function - if isdefined(:valptr) - c_free(valptr) - end + c_free(valptr) rethrow() end + + newsize = sizeof(funcret) + if newsize > valsize + # TODO: increase this in a cleverer way? + valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize)) + end + # copy serialized return value + unsafe_copy!(valptr, pointer(funcret), newsize) + + # copy the size of the serialized value + unsafe_copy!( + acptr, + pointer(reinterpret(UInt8, [newsize])), + intsize + ) + # copy the address of the pointer to the serialized value + valarr = reinterpret(UInt8, [valptr]) + for i in 1:length(valarr) + unsafe_store!(acptr, valarr[i], intsize+i) + end nothing end end @@ -159,17 +155,18 @@ function finalfunc(init, func, fsym=symbol(string(func)*"_final")) Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize) ) - try - # load value - acvalbuf = zeros(UInt8, valsize) - unsafe_copy!(pointer(acvalbuf), valptr, valsize) - acval = sqldeserialize(acvalbuf) + # load value + acvalbuf = zeros(UInt8, valsize) + unsafe_copy!(pointer(acvalbuf), valptr, valsize) + acval = sqldeserialize(acvalbuf) + local ret + try ret = $(func)(acval) - sqlreturn(context, ret) finally c_free(valptr) end + sqlreturn(context, ret) end nothing end From ef1d1a9e544dbe5fa1be6bfb7d4eff5df166358e Mon Sep 17 00:00:00 2001 From: Sean Marshallsay Date: Sun, 14 Dec 2014 12:41:38 +0000 Subject: [PATCH 15/15] Check malloc and realloc for memory errors. --- src/UDF.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/UDF.jl b/src/UDF.jl index 28f14c6..6be2011 100644 --- a/src/UDF.jl +++ b/src/UDF.jl @@ -89,6 +89,7 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) valsize = 256 # avoid the garbage collector using malloc valptr = convert(Ptr{UInt8}, c_malloc(valsize)) + valptr == C_NULL && throw(SQLiteException("memory error")) else # size of serialized value is first sizeof(Int) bytes valsize = bytestoint(acptr, 1, intsize) @@ -113,7 +114,13 @@ function stepfunc(init, func, fsym=symbol(string(func)*"_step")) newsize = sizeof(funcret) if newsize > valsize # TODO: increase this in a cleverer way? - valptr = convert(Ptr{UInt8}, c_realloc(valptr, newsize)) + tmp = convert(Ptr{UInt8}, c_realloc(valptr, newsize)) + if tmp == C_NULL + c_free(valptr) + throw(SQLiteException("memory error")) + else + valptr = tmp + end end # copy serialized return value unsafe_copy!(valptr, pointer(funcret), newsize)