diff --git a/README.md b/README.md
index 0acec24..1a366b3 100644
--- a/README.md
+++ b/README.md
@@ -76,14 +76,14 @@ A Julia interface to the SQLite library and support for operations on DataFrames
 
   `drop` is pretty self-explanatory. It's really just a convenience wrapper around `query` to execute a DROP TABLE command, while also calling "VACUUM" to clean out freed memory from the database.
 
-* `registerfunc(db::SQLiteDB, nargs::Int, func::Function, isdeterm::Bool=true; name="")`
+* `register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true)`
+* `register(db::SQLiteDB, init, step::Function, final::Function=identity; nargs::Int=-1, name::AbstractString=string(step), isdeterm::Bool=true)`
 
-  Register a function `func` (which takes `nargs` number of arguments) with the SQLite database connection `db`. If the keyword argument `name` is given the function is registered with that name, otherwise it is registered with the name of `func`. If the function is stochastic (e.g. uses a random number) `isdeterm` should be set to `false`, see SQLite's [function creation documentation](http://sqlite.org/c3ref/create_function.html) for more information.
+  Register a scalar (first method) or aggregate (second method) function with a `SQLiteDB`.
 
-* `@scalarfunc function`
-  `@scalarfunc name function`
+* `@register db function`
 
-  Define a function which can then be passed to `registerfunc`. In the first usage the function name is infered from the function definition, in the second it is explicitly given as the first parameter. The second form is only recommended when it's use is absolutely necessary, see below.
+  Automatically define and then register `function` with a `SQLiteDB`.
 
 * `sr"..."`
 
@@ -188,45 +188,31 @@ The sr"..." currently escapes all special characters in a string but it may be c
 
 ##### Custom Scalar Functions
 
-SQLite.jl also provides a way that you can implement your own [Scalar Functions](https://www.sqlite.org/lang_corefunc.html) (though [Aggregate Functions](https://www.sqlite.org/lang_aggfunc.html) are not currently supported). This is done using the `registerfunc` function and `@scalarfunc` macro.
+SQLite.jl also provides a way that you can implement your own [Scalar Functions](https://www.sqlite.org/lang_corefunc.html). This is done using the `register` function and the `@register` macro.
 
-`@scalarfunc` takes an optional function name and a function and defines a new function which can be passed to `registerfunc`. It can be used with block function syntax
+`@register` takes a `SQLiteDB` and a function. The function can be in block syntax
 
 ```julia
-julia> @scalarfunc function add3(x)
+julia> @register db function add3(x)
            x + 3
        end
-add3 (generic function with 1 method)
-
-julia> @scalarfunc add5 function irrelevantfuncname(x)
-           x + 5
-       end
-add5 (generic function with 1 method)
 ```
 
 inline function syntax
 
 ```julia
-julia> @scalarfunc mult3(x) = 3 * x
-mult3 (generic function with 1 method)
-
-julia> @scalarfunc mult5 anotherirrelevantname(x) = 5 * x
-mult5 (generic function with 1 method)
+julia> @register db mult3(x) = 3 * x
 ```
 
-and previously defined functions (note that name inference does not work with this method)
+and previously defined functions
 
 ```julia
-julia> @scalarfunc sin sin
-sin (generic function with 1 method)
-
-julia> @scalarfunc subtract -
-subtract (generic function with 1 method)
+julia> @register db sin
 ```
 
-The function that is defined can then be passed to `registerfunc`. `registerfunc` takes three arguments; the database to which the function should be registered, the number of arguments that the function takes and the function itself. The function is registered to the database connection rather than the database itself so must be registered each time the database opens. Your function can not take more than 127 arguments unless it takes a variable number of arguments, if it does take a variable number of arguments then you must pass -1 as the second argument to `registerfunc`.
+The `register` function takes three optional keyword arguments: `nargs`, which defaults to `-1`; `name`, which defaults to the name of the function; and `isdeterm`, which defaults to `true`. In practice these rarely need to be set.
 
-The `@scalarfunc` macro uses the `sqlreturn` function to return your function's return value to SQLite. By default, `sqlreturn` maps the returned value to a [native SQLite type](http://sqlite.org/c3ref/result_blob.html) or, failing that, serializes the julia value and stores it as a `BLOB`. To change this behaviour simply define a new method for `sqlreturn` which then calls a previously defined method for `sqlreturn`. Methods which map to native SQLite types are
+The `register` function uses the `sqlreturn` function to return your function's return value to SQLite. By default, `sqlreturn` maps the returned value to a [native SQLite type](http://sqlite.org/c3ref/result_blob.html) or, failing that, serializes the Julia value and stores it as a `BLOB`. To change this behaviour simply define a new method for `sqlreturn` which then calls a previously defined method for `sqlreturn`. Methods which map to native SQLite types are
 
 ```julia
 sqlreturn(context, ::NullType)
@@ -251,3 +237,23 @@ sqlreturn(context, val::Bool) = sqlreturn(context, int(val))
 ```
 
 Any new method defined for `sqlreturn` must take two arguments and must pass the first argument straight through as the first argument.
+
+#### Custom Aggregate Functions
+
+Using the `register` function, you can also define your own aggregate functions with largely the same semantics.
+
+The `register` function for aggregates must take a `SQLiteDB`, an initial value, a step function and, optionally, a final function (which defaults to `identity`). The first argument to the step function will be the return value of the previous call to the step function (or the initial value if it is the first iteration). The final function must take a single argument which will be the return value of the last call to the step function.
+
+```julia
+julia> dsum(prev, cur) = prev + cur
+
+julia> dsum(prev) = 2 * prev
+
+julia> register(db, 0, dsum, dsum)
+```
+
+If no name is given the name of the first (step) function is used (in this case "dsum"). You can also use lambdas; the following does the same as the previous code snippet:
+
+```julia
+julia> register(db, 0, (p,c) -> p+c, p -> 2p, name="dsum")
+```
diff --git a/src/UDF.jl b/src/UDF.jl
index 95ac2eb..6be2011 100644
--- a/src/UDF.jl
+++ b/src/UDF.jl
@@ -54,11 +54,138 @@ function scalarfunc(expr::Expr)
     f = eval(expr)
     return scalarfunc(f)
 end
+
+# convert a byteptr to an int, assumes little-endian
+function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int)
+    s = 0
+    for i in start:start+len-1
+        v = unsafe_load(ptr, i)
+        s += v * 256^(i - start)
+    end
+
+    # swap byte-order on big-endian machines
+    # TODO: this desperately needs testing on a big-endian machine!!!!!
+    return htol(s)
+end
+
+function stepfunc(init, func, fsym=symbol(string(func)*"_step"))
+    nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
+    return quote
+        function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
+            args = [sqlvalue(values, i) for i in 1:nargs]
+
+            intsize = sizeof(Int)
+            ptrsize = sizeof(Ptr)
+            acsize = intsize + ptrsize
+            acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, acsize))
+
+            # acptr will be zeroed-out if this is the first iteration
+            ret = ccall(
+                :memcmp, Cint, (Ptr{UInt8}, Ptr{UInt8}, Csize_t),
+                zeros(UInt8, acsize), acptr, acsize,
+            )
+            if ret == 0
+                acval = $(init)
+                valsize = 256
+                # avoid the garbage collector using malloc
+                valptr = convert(Ptr{UInt8}, c_malloc(valsize))
+                valptr == C_NULL && throw(SQLiteException("memory error"))
+            else
+                # size of serialized value is first sizeof(Int) bytes
+                valsize = bytestoint(acptr, 1, intsize)
+                # ptr to serialized value is last sizeof(Ptr) bytes
+                valptr = reinterpret(
+                    Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
+                )
+                # deserialize the value pointed to by valptr
+                acvalbuf = zeros(UInt8, valsize)
+                unsafe_copy!(pointer(acvalbuf), valptr, valsize)
+                acval = sqldeserialize(acvalbuf)
+            end
+
+            local funcret
+            try
+                funcret = sqlserialize($(func)(acval, args...))
+            catch
+                c_free(valptr)
+                rethrow()
+            end
+
+            newsize = sizeof(funcret)
+            if newsize > valsize
+                # TODO: increase this in a cleverer way?
+                tmp = convert(Ptr{UInt8}, c_realloc(valptr, newsize))
+                if tmp == C_NULL
+                    c_free(valptr)
+                    throw(SQLiteException("memory error"))
+                else
+                    valptr = tmp
+                end
+            end
+            # copy serialized return value
+            unsafe_copy!(valptr, pointer(funcret), newsize)
+
+            # copy the size of the serialized value
+            unsafe_copy!(
+                acptr,
+                pointer(reinterpret(UInt8, [newsize])),
+                intsize
+            )
+            # copy the address of the pointer to the serialized value
+            valarr = reinterpret(UInt8, [valptr])
+            for i in 1:length(valarr)
+                unsafe_store!(acptr, valarr[i], intsize+i)
+            end
+            nothing
+        end
+    end
+end
+
+function finalfunc(init, func, fsym=symbol(string(func)*"_final"))
+    nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
+    return quote
+        function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
+            acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, 0))
+
+            # step function wasn't run
+            if acptr == C_NULL
+                sqlreturn(context, $(init))
+            else
+                intsize = sizeof(Int)
+                ptrsize = sizeof(Ptr)
+                acsize = intsize + ptrsize
+
+                # load size
+                valsize = bytestoint(acptr, 1, intsize)
+                # load ptr
+                valptr = reinterpret(
+                    Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
+                )
+
+                # load value
+                acvalbuf = zeros(UInt8, valsize)
+                unsafe_copy!(pointer(acvalbuf), valptr, valsize)
+                acval = sqldeserialize(acvalbuf)
+
+                local ret
+                try
+                    ret = $(func)(acval)
+                finally
+                    c_free(valptr)
+                end
+                sqlreturn(context, ret)
+            end
+            nothing
+        end
+    end
+end
+
 # User-facing macro for convenience in registering a simple function
 # with no configurations needed
 macro register(db, func)
     :(register($(esc(db)), $(esc(func))))
 end
+
 # User-facing method for registering a Julia function to be used within SQLite
 function register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true)
     @assert nargs <= 127 "use -1 if > 127 arguments are needed"
@@ -78,6 +205,28 @@ function register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractStr
     )
 end
 
+# as above but for aggregate functions
+function register(
+    db::SQLiteDB, init, step::Function, final::Function=identity;
+    nargs::Int=-1, name::AbstractString=string(step), isdeterm::Bool=true
+)
+    @assert nargs <= 127 "use -1 if > 127 arguments are needed"
+    nargs < -1 && (nargs = -1)
+    @assert sizeof(name) <= 255 "size of function name must be <= 255 chars"
+
+    s = eval(stepfunc(init, step, Base.function_name(step)))
+    cs = cfunction(s, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
+    f = eval(finalfunc(init, final, Base.function_name(final)))
+    cf = cfunction(f, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
+
+    enc = SQLITE_UTF8
+    enc = isdeterm ? enc | SQLITE_DETERMINISTIC : enc
+
+    @CHECK db sqlite3_create_function_v2(
+        db.handle, name, nargs, enc, C_NULL, C_NULL, cs, cf, C_NULL
+    )
+end
+
 # annotate types because the MethodError makes more sense that way
 regexp(r::AbstractString, s::AbstractString) = ismatch(Regex(r), s)
 # macro for preserving the special characters in a string
diff --git a/test/runtests.jl b/test/runtests.jl
index 5e8cec6..26ebf9b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -207,6 +207,41 @@ SQLite.@register db big
 r = query(db, "SELECT big(5)")
 @test r[1][1] == big(5)
 
+doublesum_step(persist, current) = persist + current
+doublesum_final(persist) = 2 * persist
+register(db, 0, doublesum_step, doublesum_final, name="doublesum")
+r = query(db, "SELECT doublesum(UnitPrice) FROM Track")
+s = query(db, "SELECT UnitPrice FROM Track")
+@test_approx_eq r[1][1] 2*sum(s[1])
+
+mycount(p, c) = p + 1
+register(db, 0, mycount)
+r = query(db, "SELECT mycount(TrackId) FROM PlaylistTrack")
+s = query(db, "SELECT count(TrackId) FROM PlaylistTrack")
+@test r[1] == s[1]
+
+bigsum(p, c) = p + big(c)
+register(db, big(0), bigsum)
+r = query(db, "SELECT bigsum(TrackId) FROM PlaylistTrack")
+s = query(db, "SELECT TrackId FROM PlaylistTrack")
+@test r[1][1] == big(sum(s[1]))
+
+query(db, "CREATE TABLE points (x INT, y INT, z INT)")
+query(db, "INSERT INTO points VALUES (?, ?, ?)", [1, 2, 3])
+query(db, "INSERT INTO points VALUES (?, ?, ?)", [4, 5, 6])
+query(db, "INSERT INTO points VALUES (?, ?, ?)", [7, 8, 9])
+type Point3D{T<:Number}
+    x::T
+    y::T
+    z::T
+end
+==(a::Point3D, b::Point3D) = a.x == b.x && a.y == b.y && a.z == b.z
++(a::Point3D, b::Point3D) = Point3D(a.x + b.x, a.y + b.y, a.z + b.z)
+sumpoint(p::Point3D, x, y, z) = p + Point3D(x, y, z)
+register(db, Point3D(0, 0, 0), sumpoint)
+r = query(db, "SELECT sumpoint(x, y, z) FROM points")
+@test r[1][1] == Point3D(12, 15, 18)
+drop(db, "points")
 
 db2 = SQLiteDB()
 query(db2, "CREATE TABLE tab1 (r REAL, s INT)")
@@ -225,7 +260,6 @@ drop(db2, "tab2", ifexists=true)
 
 close(db2)
 
-
 @test size(tables(db)) == (11,1)
 
 close(db)