Merged
60 changes: 33 additions & 27 deletions README.md
@@ -76,14 +76,14 @@ A Julia interface to the SQLite library and support for operations on DataFrames

`drop` is pretty self-explanatory. It's really just a convenience wrapper around `query` to execute a DROP TABLE command, while also calling "VACUUM" to clean out freed memory from the database.

* `registerfunc(db::SQLiteDB, nargs::Int, func::Function, isdeterm::Bool=true; name="")`
* `register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true)`
* `register(db::SQLiteDB, init, step::Function, final::Function=identity; nargs::Int=-1, name::AbstractString=string(step), isdeterm::Bool=true)`

Register a function `func` (which takes `nargs` number of arguments) with the SQLite database connection `db`. If the keyword argument `name` is given the function is registered with that name, otherwise it is registered with the name of `func`. If the function is stochastic (e.g. uses a random number) `isdeterm` should be set to `false`, see SQLite's [function creation documentation](http://sqlite.org/c3ref/create_function.html) for more information.
Register a scalar (first method) or aggregate (second method) function with a `SQLiteDB`.

* `@scalarfunc function`
`@scalarfunc name function`
* `@register db function`

Define a function which can then be passed to `registerfunc`. In the first usage the function name is inferred from the function definition; in the second it is explicitly given as the first parameter. The second form is only recommended when its use is absolutely necessary, see below.
Automatically define then register `function` with a `SQLiteDB`.

* `sr"..."`

@@ -188,45 +188,31 @@ The sr"..." currently escapes all special characters in a string but it may be c

##### Custom Scalar Functions

SQLite.jl also provides a way that you can implement your own [Scalar Functions](https://www.sqlite.org/lang_corefunc.html) (though [Aggregate Functions](https://www.sqlite.org/lang_aggfunc.html) are not currently supported). This is done using the `registerfunc` function and `@scalarfunc` macro.
SQLite.jl also lets you implement your own [Scalar Functions](https://www.sqlite.org/lang_corefunc.html). This is done using the `register` function and the `@register` macro.

`@scalarfunc` takes an optional function name and a function and defines a new function which can be passed to `registerfunc`. It can be used with block function syntax
`@register` takes a `SQLiteDB` and a function. The function can be in block syntax

```julia
julia> @scalarfunc function add3(x)
julia> @register db function add3(x)
x + 3
end
add3 (generic function with 1 method)

julia> @scalarfunc add5 function irrelevantfuncname(x)
x + 5
end
add5 (generic function with 1 method)
```

inline function syntax

```julia
julia> @scalarfunc mult3(x) = 3 * x
mult3 (generic function with 1 method)

julia> @scalarfunc mult5 anotherirrelevantname(x) = 5 * x
mult5 (generic function with 1 method)
julia> @register db mult3(x) = 3 * x
```

and previously defined functions (note that name inference does not work with this method)
and previously defined functions

```julia
julia> @scalarfunc sin sin
sin (generic function with 1 method)

julia> @scalarfunc subtract -
subtract (generic function with 1 method)
julia> @register db sin
```

The function that is defined can then be passed to `registerfunc`. `registerfunc` takes three arguments; the database to which the function should be registered, the number of arguments that the function takes and the function itself. The function is registered to the database connection rather than the database itself so must be registered each time the database opens. Your function can not take more than 127 arguments unless it takes a variable number of arguments, if it does take a variable number of arguments then you must pass -1 as the second argument to `registerfunc`.
The `register` function takes three optional keyword arguments: `nargs`, which defaults to `-1`; `name`, which defaults to the name of the function; and `isdeterm`, which defaults to `true`. In practice these rarely need to be set.

The `@scalarfunc` macro uses the `sqlreturn` function to return your function's return value to SQLite. By default, `sqlreturn` maps the returned value to a [native SQLite type](http://sqlite.org/c3ref/result_blob.html) or, failing that, serializes the julia value and stores it as a `BLOB`. To change this behaviour simply define a new method for `sqlreturn` which then calls a previously defined method for `sqlreturn`. Methods which map to native SQLite types are
The `register` function uses the `sqlreturn` function to return your function's return value to SQLite. By default, `sqlreturn` maps the returned value to a [native SQLite type](http://sqlite.org/c3ref/result_blob.html) or, failing that, serializes the julia value and stores it as a `BLOB`. To change this behaviour simply define a new method for `sqlreturn` which then calls a previously defined method for `sqlreturn`. Methods which map to native SQLite types are

```julia
sqlreturn(context, ::NullType)
@@ -251,3 +237,23 @@ sqlreturn(context, val::Bool) = sqlreturn(context, int(val))
```

Any new method defined for `sqlreturn` must take two arguments and must pass the first argument straight through as the first argument.
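
As a hedged illustration of the forwarding rule above, the sketch below uses a stand-in `sqlreturn` that only records values in a vector, so it runs without SQLite.jl. The `Celsius` type, the vector used as `context`, and the stand-in string method are all assumptions made for the example; in SQLite.jl the context is an opaque pointer and the native methods come from the package. It is written in current Julia syntax (`struct` rather than `type`).

```julia
# Stand-in for a native SQLite.jl method (assumption: the real one hands the
# string to SQLite; this one just records it so the sketch is runnable).
sqlreturn(context, val::AbstractString) = push!(context, val)

# Hypothetical user type that should be stored as text instead of a BLOB.
struct Celsius
    degrees::Float64
end

# New two-argument method: convert the value, then forward to the existing
# method, passing `context` through untouched as the first argument.
sqlreturn(context, val::Celsius) = sqlreturn(context, string(val.degrees, " C"))

context = Any[]                      # mock context, for illustration only
sqlreturn(context, Celsius(21.5))
```

The new method does no work against SQLite itself; it only chooses a representation and delegates, which is why the first argument has to be passed along unchanged.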

##### Custom Aggregate Functions

Using the `register` function, you can also define your own aggregate functions with largely the same semantics.

The `register` method for aggregates takes a `SQLiteDB`, an initial value, a step function and a final function. The first argument to the step function is the return value of the previous step call (or the initial value on the first iteration); the remaining arguments are the values from the current row. The final function must take a single argument, which is the return value of the last step call.

```julia
julia> dsum(prev, cur) = prev + cur

julia> dsum(prev) = 2 * prev

julia> register(db, 0, dsum, dsum)
```

If no name is given, the name of the step function is used (in this case "dsum"). You can also use lambdas; the following does the same as the previous code snippet:

```julia
julia> register(db, 0, (p,c) -> p+c, p -> 2p, name="dsum")
```
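
The step/final semantics described above can be sketched in plain Julia without a database. `simulate_aggregate` is a hypothetical helper, not part of SQLite.jl, and each row is modelled as a tuple of column values. It assumes, following the `finalfunc` code in this change, that the initial value is returned directly when no step call ran.

```julia
# Hypothetical helper: fold `step` over the rows starting from `init`,
# then apply `final` once to the accumulated value.
function simulate_aggregate(init, step, final, rows)
    isempty(rows) && return init      # no step call ran: return init as-is
    acc = init
    for row in rows
        acc = step(acc, row...)       # previous result first, then column values
    end
    return final(acc)
end

dsum(prev, cur) = prev + cur          # step method (two arguments)
dsum(prev) = 2 * prev                 # final method (one argument)

doubled = simulate_aggregate(0, dsum, dsum, [(1,), (2,), (3,)])
counted = simulate_aggregate(0, (p, c) -> p + 1, identity, [(10,), (20,)])
```

Using one generic function with two methods for both roles, as `dsum` does, works because the step and final calls differ in argument count.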
149 changes: 149 additions & 0 deletions src/UDF.jl
@@ -54,11 +54,138 @@ function scalarfunc(expr::Expr)
f = eval(expr)
return scalarfunc(f)
end

# convert a byteptr to an int, assumes little-endian
function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int)
    s = 0
    for i in start:start+len-1
        v = unsafe_load(ptr, i)
        s += v * 256^(i - start)
    end

    # swap byte-order on big-endian machines
    # TODO: this desperately needs testing on a big-endian machine!!!!!
    return htol(s)
end
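
A minimal, self-contained sketch of the same little-endian reassembly, written for current Julia and operating on a byte vector instead of a raw pointer. `bytes_to_uint` is a hypothetical name, and the unsigned accumulator is a choice made for the example rather than what the PR does.

```julia
# Reassemble an unsigned integer from bytes stored least-significant first.
function bytes_to_uint(bytes::AbstractVector{UInt8})
    s = UInt(0)
    for (i, b) in enumerate(bytes)
        s |= UInt(b) << (8 * (i - 1))   # byte i contributes 256^(i-1)
    end
    return s
end

n = UInt(0x01020304)
# htol makes the in-memory byte order little-endian on any host
le_bytes = collect(reinterpret(UInt8, [htol(n)]))
roundtrip = bytes_to_uint(le_bytes)
```

Converting with `htol` before taking the bytes keeps the sketch independent of the host byte order.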

function stepfunc(init, func, fsym=symbol(string(func)*"_step"))
    nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
    return quote
        function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
            args = [sqlvalue(values, i) for i in 1:nargs]

            intsize = sizeof(Int)
            ptrsize = sizeof(Ptr)
            acsize = intsize + ptrsize
            acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, acsize))

            # acptr will be zeroed-out if this is the first iteration
            # (memcmp takes a size_t length, hence Csize_t)
            ret = ccall(
                :memcmp, Cint, (Ptr{UInt8}, Ptr{UInt8}, Csize_t),
                zeros(UInt8, acsize), acptr, acsize,
            )
            if ret == 0
                acval = $(init)
                valsize = 256
                # avoid the garbage collector using malloc
                valptr = convert(Ptr{UInt8}, c_malloc(valsize))
                valptr == C_NULL && throw(SQLiteException("memory error"))
            else
                # size of serialized value is first sizeof(Int) bytes
                valsize = bytestoint(acptr, 1, intsize)
                # ptr to serialized value is last sizeof(Ptr) bytes
                valptr = reinterpret(
                    Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
                )
                # deserialize the value pointed to by valptr
                acvalbuf = zeros(UInt8, valsize)
                unsafe_copy!(pointer(acvalbuf), valptr, valsize)
                acval = sqldeserialize(acvalbuf)
            end

            local funcret
            try
                funcret = sqlserialize($(func)(acval, args...))
            catch
                c_free(valptr)
                rethrow()
            end

            newsize = sizeof(funcret)
            if newsize > valsize
                # TODO: increase this in a cleverer way?
                tmp = convert(Ptr{UInt8}, c_realloc(valptr, newsize))
                if tmp == C_NULL
                    c_free(valptr)
                    throw(SQLiteException("memory error"))
                else
                    valptr = tmp
                end
            end
            # copy serialized return value
            unsafe_copy!(valptr, pointer(funcret), newsize)

            # copy the size of the serialized value
            unsafe_copy!(
                acptr,
                pointer(reinterpret(UInt8, [newsize])),
                intsize
            )
            # copy the address of the pointer to the serialized value
            valarr = reinterpret(UInt8, [valptr])
            for i in 1:length(valarr)
                unsafe_store!(acptr, valarr[i], intsize+i)
            end
            nothing
        end
    end
end
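
The step function above keeps the running value alive between calls by serializing it into `malloc`'d memory that the garbage collector does not manage. A rough, standalone sketch of that round trip in current Julia follows; `store_value` and `load_value` are hypothetical names, and the `Serialization` standard library stands in for the package's `sqlserialize`/`sqldeserialize`.

```julia
using Serialization

# Serialize `value` and copy the bytes into memory obtained from malloc,
# which the Julia garbage collector will not move or free.
function store_value(value)
    io = IOBuffer()
    serialize(io, value)
    bytes = take!(io)
    len = length(bytes)
    ptr = convert(Ptr{UInt8}, Libc.malloc(len))
    ptr == C_NULL && error("memory error")
    GC.@preserve bytes unsafe_copyto!(ptr, pointer(bytes), len)
    return ptr, len
end

# Copy `len` bytes back out of the malloc'd region and deserialize them.
function load_value(ptr::Ptr{UInt8}, len::Integer)
    bytes = Vector{UInt8}(undef, len)
    GC.@preserve bytes unsafe_copyto!(pointer(bytes), ptr, len)
    return deserialize(IOBuffer(bytes))
end

ptr, len = store_value((total = 12, label = "dsum"))
restored = load_value(ptr, len)
Libc.free(ptr)                        # the caller owns the buffer
```

Only the pointer and the length need to survive between calls, which is why the aggregate context in the PR is sized as `sizeof(Int) + sizeof(Ptr)`.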

function finalfunc(init, func, fsym=symbol(string(func)*"_final"))
    nm = isdefined(Base,fsym) ? :(Base.$fsym) : fsym
    return quote
        function $(nm)(context::Ptr{Void}, nargs::Cint, values::Ptr{Ptr{Void}})
            acptr = convert(Ptr{UInt8}, sqlite3_aggregate_context(context, 0))

            # step function wasn't run
            if acptr == C_NULL
                sqlreturn(context, $(init))
            else
                intsize = sizeof(Int)
                ptrsize = sizeof(Ptr)
                acsize = intsize + ptrsize

                # load size
                valsize = bytestoint(acptr, 1, intsize)
                # load ptr
                valptr = reinterpret(
                    Ptr{UInt8}, bytestoint(acptr, intsize+1, ptrsize)
                )

                # load value
                acvalbuf = zeros(UInt8, valsize)
                unsafe_copy!(pointer(acvalbuf), valptr, valsize)
                acval = sqldeserialize(acvalbuf)

                local ret
                try
                    ret = $(func)(acval)
                finally
                    c_free(valptr)
                end
                sqlreturn(context, ret)
            end
            nothing
        end
    end
end

# User-facing macro for convenience in registering a simple function
# with no configurations needed
macro register(db, func)
    :(register($(esc(db)), $(esc(func))))
end
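
The macro above only forwards its two arguments, escaped, to the `register` function. A small sketch of that forwarding pattern without SQLite follows, assuming a plain dictionary as the registration target. `registry`, `register_demo` and `@register_demo` are hypothetical names, and only the previously-defined-function form is shown.

```julia
# Hypothetical registration target standing in for a database connection.
registry = Dict{String,Function}()

# Function form: the name defaults to the function's own name.
function register_demo(reg, f::Function; name::AbstractString=string(nameof(f)))
    reg[name] = f
    return f
end

# Macro form: escape both arguments so they are evaluated in the caller's
# scope, then hand them to the function form with default keywords.
macro register_demo(reg, func)
    :(register_demo($(esc(reg)), $(esc(func))))
end

@register_demo registry sin
```

Because the macro passes no keyword arguments, anything beyond the defaults still goes through the function form.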

# User-facing method for registering a Julia function to be used within SQLite
function register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractString=string(func), isdeterm::Bool=true)
@assert nargs <= 127 "use -1 if > 127 arguments are needed"
@@ -78,6 +205,28 @@ function register(db::SQLiteDB, func::Function; nargs::Int=-1, name::AbstractStr
)
end

# as above but for aggregate functions
function register(
    db::SQLiteDB, init, step::Function, final::Function=identity;
    nargs::Int=-1, name::AbstractString=string(step), isdeterm::Bool=true
)
    @assert nargs <= 127 "use -1 if > 127 arguments are needed"
    nargs < -1 && (nargs = -1)
    @assert sizeof(name) <= 255 "size of function name must be <= 255 chars"

    s = eval(stepfunc(init, step, Base.function_name(step)))
    cs = cfunction(s, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))
    f = eval(finalfunc(init, final, Base.function_name(final)))
    cf = cfunction(f, Nothing, (Ptr{Void}, Cint, Ptr{Ptr{Void}}))

    enc = SQLITE_UTF8
    enc = isdeterm ? enc | SQLITE_DETERMINISTIC : enc

    @CHECK db sqlite3_create_function_v2(
        db.handle, name, nargs, enc, C_NULL, C_NULL, cs, cf, C_NULL
    )
end

# annotate types because the MethodError makes more sense that way
regexp(r::AbstractString, s::AbstractString) = ismatch(Regex(r), s)
# macro for preserving the special characters in a string
36 changes: 35 additions & 1 deletion test/runtests.jl
@@ -207,6 +207,41 @@ SQLite.@register db big
r = query(db, "SELECT big(5)")
@test r[1][1] == big(5)

doublesum_step(persist, current) = persist + current
doublesum_final(persist) = 2 * persist
register(db, 0, doublesum_step, doublesum_final, name="doublesum")
r = query(db, "SELECT doublesum(UnitPrice) FROM Track")
s = query(db, "SELECT UnitPrice FROM Track")
@test_approx_eq r[1][1] 2*sum(s[1])

mycount(p, c) = p + 1
register(db, 0, mycount)
r = query(db, "SELECT mycount(TrackId) FROM PlaylistTrack")
s = query(db, "SELECT count(TrackId) FROM PlaylistTrack")
@test r[1] == s[1]

bigsum(p, c) = p + big(c)
register(db, big(0), bigsum)
r = query(db, "SELECT bigsum(TrackId) FROM PlaylistTrack")
s = query(db, "SELECT TrackId FROM PlaylistTrack")
@test r[1][1] == big(sum(s[1]))

query(db, "CREATE TABLE points (x INT, y INT, z INT)")
query(db, "INSERT INTO points VALUES (?, ?, ?)", [1, 2, 3])
query(db, "INSERT INTO points VALUES (?, ?, ?)", [4, 5, 6])
query(db, "INSERT INTO points VALUES (?, ?, ?)", [7, 8, 9])
type Point3D{T<:Number}
    x::T
    y::T
    z::T
end
==(a::Point3D, b::Point3D) = a.x == b.x && a.y == b.y && a.z == b.z
+(a::Point3D, b::Point3D) = Point3D(a.x + b.x, a.y + b.y, a.z + b.z)
sumpoint(p::Point3D, x, y, z) = p + Point3D(x, y, z)
register(db, Point3D(0, 0, 0), sumpoint)
r = query(db, "SELECT sumpoint(x, y, z) FROM points")
@test r[1][1] == Point3D(12, 15, 18)
drop(db, "points")

db2 = SQLiteDB()
query(db2, "CREATE TABLE tab1 (r REAL, s INT)")
@@ -225,7 +260,6 @@ drop(db2, "tab2", ifexists=true)

close(db2)


@test size(tables(db)) == (11,1)

close(db)