
Illegal type analysis error with gemv call #1020

Closed
gaurav-arya opened this issue Aug 21, 2023 · 1 comment · Fixed by EnzymeAD/Enzyme#1433

gaurav-arya (Member) commented Aug 21, 2023

With a slightly modified version of #1004, in which part of the loop body is moved into a separate function (`mymul!`), Enzyme.jl master fails with "Enzyme compilation failed due to illegal type analysis." Manually inlining `mymul!` resolves the error. (Edit: calling `mul!` instead of the gemv ccall directly also resolves it.) MWE:

```julia
using Enzyme, LinearAlgebra

LinearAlgebra.BLAS.set_num_threads(1)
Enzyme.Compiler.bitcode_replacement!(true)

@inline function coupled_springs(K, m, x0, v0, T)
    Ktmp = copy(K)
    xtmp = copy(x0)
    vtmp = copy(v0)
    N = length(m)  # local N; it is not passed on to mymul!
    pX = pointer(xtmp)
    pY = pointer(vtmp)
    pA = pointer(Ktmp)
    for j in 1:5000
        # Manually inlining mymul! here makes the error go away.
        mymul!(vtmp, Ktmp, xtmp, pX, pY, pA)
        xtmp .+= vtmp ./ 5000
    end
    return @inbounds xtmp[1]
end

# Raw BLAS gemv: vtmp := (1/5000) * Ktmp * xtmp + 1.0 * vtmp.
# Note: `N` below is not an argument, so it resolves to the global `N` defined further down.
@inline function mymul!(vtmp, Ktmp, xtmp, pX, pY, pA)
    sX = 1
    sY = 1
    GC.@preserve vtmp Ktmp xtmp ccall((:dgemv_64_, LinearAlgebra.BLAS.libblastrampoline), Cvoid,
        (Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{Float64},
         Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt},
         Ref{Float64}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Clong),
        'N', N, N, 1/5000,  # trans, m, n, alpha
        pA, N, pX, sX,      # A, lda, x, incx
        1.0, pY, sY, 1)     # beta, y, incy, hidden Fortran length of `trans`
end

function make_args(N)
    K = ones(N, N)  # alternative: collect(reshape(1:N^2, N, N))
    K[diagind(K)] .= 0
    m = 0.5 .+ 0.5 * rand(N)
    x0 = float.(collect(1:N)) ./ N
    v0 = zeros(N)
    T = 1.0
    return K, m, x0, v0, T
end

function enzyme_inputs(K, m, x0, v0, T)
    dK = zero(K)
    dm = zero(m)
    dx0 = zero(x0)
    dv0 = zero(v0)
    return Duplicated(K, dK), Duplicated(m, dm), Duplicated(x0, dx0), Duplicated(v0, dv0), Const(T)
end

function enzyme_gradient(args...)
    inputs = enzyme_inputs(args...)
    dK = inputs[1].dval
    Enzyme.autodiff(Reverse, Const(coupled_springs), inputs...)
    return dK
end

N = 200
args = make_args(N)
coupled_springs(args...)  # primal runs fine
enzyme_gradient(args...)  # fails with the illegal type analysis error
```

Error:
error.ll.txt
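For reference, a minimal sketch of the `mul!`-based workaround mentioned in the edit above, assuming Julia 1.3+ (for five-argument `mul!`). With `trans = 'N'`, the gemv call in `mymul!` computes `vtmp := α*Ktmp*xtmp + β*vtmp` with `α = 1/5000` and `β = 1.0`; `mul!(y, A, x, α, β)` expresses the same update without raw pointers or the global `N`. The name `mymul_via_mul!` is hypothetical and not from the issue, and the check below only compares numerics against `BLAS.gemv!` (it does not involve Enzyme).

```julia
using LinearAlgebra

# Hypothetical helper (not from the issue): same update as `mymul!`, but via
# five-argument mul!(y, A, x, α, β), which overwrites y with A*x*α + y*β.
function mymul_via_mul!(vtmp, Ktmp, xtmp)
    mul!(vtmp, Ktmp, xtmp, 1/5000, 1.0)
    return nothing
end

# Numerical check against LinearAlgebra's own BLAS.gemv! wrapper.
# Lowercase `n` avoids clobbering the MWE's global `N`.
n = 4
A = ones(n, n)
A[diagind(A)] .= 0
x = collect(1.0:n) ./ n
y1 = zeros(n)
y2 = zeros(n)
mymul_via_mul!(y1, A, x)
BLAS.gemv!('N', 1/5000, A, x, 1.0, y2)
@assert y1 ≈ y2
```

In the MWE, the loop body would then call `mymul_via_mul!(vtmp, Ktmp, xtmp)`, and the `pX`, `pY`, `pA` pointers are no longer needed.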

gaurav-arya changed the title from "Illegable type analysis error with gemv call" to "Illegal type analysis error with gemv call" on Aug 21, 2023
wsmoses (Member) commented Aug 23, 2023

I think this should've been fixed by: EnzymeAD/Enzyme#1369

I assume you're using the custom jll here. Can you make sure the branch you're building from is rebased on main (past that commit)?

cc @ZuseZ4
