Closure with kwarg constructor #1541

Closed
michel2323 opened this issue Jun 17, 2024 · 1 comment
@michel2323
Collaborator

  args = LLVM.Value[LLVM.AddrSpaceCastInst(%70 = addrspacecast [1 x double] addrspace(10)* %69 to [1 x double] addrspace(11)*, !dbg !46), LLVM.ConstantInt(0x000000000815b670), LLVM.AddrSpaceCastInst(%46 = addrspacecast [2 x [1 x {} addrspace(10)*]] addrspace(10)* %45 to [2 x [1 x {} addrspace(10)*]] addrspace(11)*, !dbg !46), LLVM.AddrSpaceCastInst(%60 = addrspacecast [1 x double] addrspace(10)* %59 to [1 x double] addrspace(11)*, !dbg !46)]
  i = 1
  args[i] = LLVM.AddrSpaceCastInst(%70 = addrspacecast [1 x double] addrspace(10)* %69 to [1 x double] addrspace(11)*, !dbg !46)
  party = LLVM.IntegerType(i64)
  ctype = LLVM.PointerType([1 x double] addrspace(11)*)
  tape = LLVM.IntegerType(i64)
  val =   %70 = addrspacecast [1 x double] addrspace(10)* %69 to [1 x double] addrspace(11)*, !dbg !46
  prev = i64 undef
  lidxs = UInt32[]
  ridxs = UInt32[]
  tape_type(tape) = UInt64
  convert(LLVMType, tape_type(tape)) = LLVM.IntegerType(i64)

using Enzyme
using Enzyme.EnzymeCore
using Enzyme.EnzymeCore.EnzymeRules
using Test

# Callable struct whose only constructor takes `v` as a keyword argument: Closure(; v=...)
struct Closure
    v::Vector{Float64}
    Closure(;v::Vector{Float64}) = new(v)
end

# Calling the struct returns v[1] * x, then zeroes v[1] in place.
function (cl::Closure)(x)
    val = cl.v[1] * x
    cl.v[1] = 0.0
    return val
end


# Forward pass of the custom rule: save a copy of `v` on the tape before the call zeroes it.
function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure},
    ::Type{<:Active}, args::Vararg{Active,N}; kwargs...) where {N}
    @show args
    @show kwargs
    vec = copy(func.val.v)
    pval = func.val(args[1].val)
    primal = if EnzymeRules.needs_primal(config)
        pval
    else
        nothing
    end
    return AugmentedReturn(primal, nothing, vec)
end

# Reverse pass of the custom rule: return one derivative per active argument, using the tape.
function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure},
    dret::Active, tape, args::Vararg{Active,N}; kwargs...) where {N}
    @show kwargs
    # Distinctive value (7 * x * dret + 1000 * v_old[1]) so the test can tell this rule ran.
    dargs = ntuple(Val(N)) do i
        7 * args[1].val * dret.val + tape[1] * 1000
    end
    return dargs
end

# The closure is built with the keyword constructor inside the differentiated function.
function driver(x, y)
    cl = Closure(; v=[y])
    cl(x)
end

@testset "Closure rule" begin
    driver(2.7, 3.14)
    res = autodiff(Reverse, driver, Active, Active(2.7), Active(3.14))[1][1]
    @show res

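    @test res ≈ 7 * 2.7 + 3.14 * 1000
    # `cl` is local to `driver`, so check the in-place mutation on a separately built closure.
    cl = Closure(; v=[3.14])
    cl(2.7)
    @test cl.v[1] ≈ 0.0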
end
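The primal behaviour that the custom rule has to reproduce can be checked in plain Julia, without Enzyme. The sketch below repeats the `Closure` definition from the reproducer (run it in a fresh session) and shows that only the keyword form `Closure(; v=...)` constructs the struct, that calling it zeroes `v[1]`, and what value the test expects from the rule for `x = 2.7`, `y = 3.14`, assuming a seed of `dret = 1.0`. The variable names `positional_fails`, `zeroed` and `expected` are illustrative only.

```julia
# Same definition as in the reproducer: the only constructor is keyword-only.
struct Closure
    v::Vector{Float64}
    Closure(; v::Vector{Float64}) = new(v)
end

function (cl::Closure)(x)
    val = cl.v[1] * x
    cl.v[1] = 0.0
    return val
end

# With an inner constructor defined, no default positional constructor exists,
# so Closure([3.14]) has no matching method.
positional_fails = try
    Closure([3.14])
    false
catch err
    err isa MethodError
end

cl = Closure(; v=[3.14])
y = cl(2.7)               # 3.14 * 2.7
zeroed = cl.v[1] == 0.0   # the call mutated the captured vector in place

# Value the custom reverse rule should return: 7 * x * dret plus 1000 times the
# tape entry, where the tape is a copy of v taken before the call (v[1] = 3.14).
expected = 7 * 2.7 * 1.0 + 3.14 * 1000
println((positional_fails, y, zeroed, expected))
```

Because the tape stores a copy made before the primal call, the reverse pass still sees `3.14` even though `v[1]` is `0.0` by then.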
@wsmoses
Member

wsmoses commented Jun 18, 2024

Fixed by #1543

wsmoses closed this as completed Jun 18, 2024