Skip to content
This repository has been archived by the owner on Jun 24, 2022. It is now read-only.

Commit

Permalink
fix #22
Browse files Browse the repository at this point in the history
  • Loading branch information
shashi committed Mar 20, 2020
1 parent f8311e0 commit 5cfb1a8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
35 changes: 29 additions & 6 deletions src/blas.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,45 @@
# generic implementations

_name(x::Symbol) = x
_name(x::Expr) = (@assert x.head == :(::); x.args[1])
macro reroute(f, g)
fname = f.args[1]
fargs = f.args[2:end]
quote
@inline function Cassette.overdub(ctx::SparsityContext,
f::typeof($(esc(fname))),
$(fargs...))
Cassette.recurse(
ctx,
invoke,
f,
$(esc(:(Tuple{$(g.args[2:end]...)}))),
$(map(_name, fargs)...))
end

@inline function Cassette.overdub(ctx::HessianSparsityContext,
f::typeof($(esc(f))),
f::typeof($(esc(fname))),
args...)
println("rerouted")
Cassette.overdub(
Cassette.recurse(
ctx,
invoke,
$(esc(g.args[1])),
$(esc(:(Tuple{$(g.args[2:end]...)}))),
args...)
$(map(_name, fargs)...))
end
end
end

@reroute LinearAlgebra.BLAS.dot LinearAlgebra.dot(Any, Any)
@reroute LinearAlgebra.BLAS.axpy! LinearAlgebra.axpy!(Any,
@reroute LinearAlgebra.BLAS.dot(x,y) LinearAlgebra.dot(Any, Any)
@reroute LinearAlgebra.BLAS.axpy!(x, y) LinearAlgebra.axpy!(Any,
AbstractArray,
AbstractArray)
@reroute LinearAlgebra.mul!(y::AbstractVector,
A::AbstractVecOrMat,
x::AbstractVector,
α::Number,
β::Number) LinearAlgebra.mul!(AbstractVector,
AbstractVecOrMat,
AbstractVector,
Number,
Number)
12 changes: 12 additions & 0 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,15 @@ let
@test sparsity!(testsparse!, out, x) == sparse([1,2,1,2,3,2,3,4,3,4],
[1,1,2,2,2,3,3,3,4,4], true)
end

@testset "BLAS" begin
function f(out,in)
A = rand(length(in), length(in))
out .= A * in
return nothing
end

x = [1.0:10;]
out = similar(x)
@test all(sparsity!(f, out, x) .== 1)
end

0 comments on commit 5cfb1a8

Please sign in to comment.