diff --git a/docs/src/index.md b/docs/src/index.md index adeaea04c..f6262c426 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -30,6 +30,18 @@ s = delayed(combine)(p, q, r) @assert collect(s) == 16 ``` + +The above computation can also be written in a more Julia-idiomatic syntax with `@par`: + +```julia +p = @par add1(4) +q = @par add2(p) +r = @par add1(3) +s = @par combine(p, q, r) + +@assert collect(s) == 16 +``` + The connections between nodes `p`, `q`, `r` and `s` is represented by this dependency graph: ![graph](https://user-images.githubusercontent.com/25916/26920104-7b9b5fa4-4c55-11e7-97fb-fe5b9e73cae6.png) diff --git a/src/thunk.jl b/src/thunk.jl index dcd057f06..cafb6ef5e 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -75,6 +75,19 @@ end delayedmap(f, xs...) = map(delayed(f), xs...) +""" + @par f(args...) -> Thunk + +Convenience macro to call `Dagger.delayed` on `f` with arguments `args`. +""" +macro par(ex) + @assert ex.head == :call "@par requires a function call as the argument" + f = ex.args[1] + args = ex.args[2:end] + # TODO: Support kwargs + :(Dagger.delayed(f)($(args...))) +end + persist!(t::Thunk) = (t.persist=true; t) cache_result!(t::Thunk) = (t.cache=true; t) diff --git a/test/runtests.jl b/test/runtests.jl index d9241ce35..12c829be7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ using Dagger include("fakeproc.jl") +include("thunk.jl") include("domain.jl") include("array.jl") include("scheduler.jl") diff --git a/test/thunk.jl b/test/thunk.jl new file mode 100644 index 000000000..635b0c160 --- /dev/null +++ b/test/thunk.jl @@ -0,0 +1,8 @@ +@testset "@par" begin + x = 2 + a = @par x + x + @test a isa Dagger.Thunk + b = @par sum([x,1,2]) + c = @par a * b + @test collect(c) == 20 +end