Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 84 additions & 36 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,27 @@ import ..Reactant:
ConcreteRNumber,
TracedRArray,
TracedRNumber,
RArray,
RNumber,
OrderedIdDict,
make_tracer,
TracedToConcrete,
append_path,
TracedType

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
) = Base.getindex(obj, field)
@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
(isbitstype(T) || obj isa RArray) && return Base.getfield(obj, field)
return Base.getindex(obj, field)
end

@inline traced_setfield!(@nospecialize(obj), field, val) = Base.setfield!(obj, field, val)
@inline function traced_setfield!(
@nospecialize(obj::AbstractArray{T}), field, val
) where {T}
(isbitstype(T) || obj isa RArray) && return Base.setfield!(obj, field, val)
return Base.setindex!(obj, val, field)
end

function create_result(tocopy::T, path, result_stores) where {T}
if !isstructtype(typeof(tocopy))
Expand Down Expand Up @@ -573,32 +584,32 @@ function codegen_flatten!(linear_args, result_stores)
push!(flatten_code, :($usbuf = $flatcode.data))
push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf)))

# TODO
respaths = ((p for p in arg.paths if p[1] != :args)...,)
# TODO: unused for the time being
# respaths = ((p for p in arg.paths if p[1] == :result || p[1] == :resargs)...,)

# resarg = false
for respath in respaths
if respath[1] == :result
flatcode = :result
respath = respath[2:end]
result_stores[respath] = usbuf
resarg = true
else
@assert respath[1] == :resargs
if respath[2] != path[2]
continue
end
# flatcode = :(args[$(respath[2])])
path = path[3:end]
end
# for p in path
# flatcode = :(traced_getfield($flatcode, $(Meta.quot(p))))
# end
# resarg = true
# flatcode = :($flatcode.data = $usbuf)
# @show flatcode
# push!(flatten_code, res)
end
# for respath in respaths
# if respath[1] == :result
# flatcode = :result
# respath = respath[2:end]
# result_stores[respath] = usbuf
# resarg = true
# else
# @assert respath[1] == :resargs
# if respath[2] != path[2]
# continue
# end
# # flatcode = :(args[$(respath[2])])
# path = path[3:end]
# end
# # for p in path
# # flatcode = :(traced_getfield($flatcode, $(Meta.quot(p))))
# # end
# # resarg = true
# # flatcode = :($flatcode.data = $usbuf)
# # @show flatcode
# # push!(flatten_code, res)
# end
# if resarg
# push!(resarg_code, :($usbuf = $flatcode.data))
# end
Expand All @@ -620,11 +631,16 @@ function codegen_unflatten!(
concrete_result,
result_stores,
)
unflatten_code = Expr[]
cache_dict = gensym("cache_dict")
unflatten_code = Expr[:(
$cache_dict = $(IdDict{
Union{TracedRArray,TracedRNumber},Union{ConcreteRArray,ConcreteRNumber}
}())
),]

# mutate the result stores to point to the correct concrete results
for (concrete_res_name, result) in zip(concretized_res_names, linear_results)
paths = ((p for p in result.paths if p[1] != :args)...,)
paths = ((p for p in result.paths if p[1] == :result || p[1] == :resargs)...,)
for path in paths
if path[1] == :result
unflatcode = :result
Expand All @@ -635,15 +651,47 @@ function codegen_unflatten!(
@assert path[1] == :resargs
unflatcode = :(args[$(path[2])])
path = path[3:end]
end

# unroll path tree
for p in path
unflatcode = :(traced_getfield($unflatcode, $(Meta.quot(p))))
end
unflatcode = :($unflatcode.data = $concrete_res_name)
for p in path[1:(end - 1)]
unflatcode = :(traced_getfield($unflatcode, $(Meta.quot(p))))
end

push!(unflatten_code, unflatcode)
if length(path) > 0
final_val = gensym("final_val")
clocal = gensym("clocal")
unflatcode = quote
$final_val = traced_getfield($unflatcode, $(Meta.quot(path[end])))
if $final_val isa TracedRArray
$clocal = if haskey($cache_dict, $final_val)
$cache_dict[$final_val]
else
$cache_dict[$final_val] = ConcreteRArray{
eltype($final_val),ndims($final_val)
}(
$concrete_res_name, size($final_val)
)
$cache_dict[$final_val]
end
traced_setfield!($unflatcode, $(Meta.quot(path[end])), $clocal)
elseif $final_val isa TracedRNumber
$clocal = if haskey($cache_dict, $final_val)
$cache_dict[$final_val]
else
$cache_dict[$final_val] = ConcreteRNumber{eltype($final_val)}(
$concrete_res_name
)
$cache_dict[$final_val]
end
traced_setfield!($unflatcode, $(Meta.quot(path[end])), $clocal)
else
traced_setfield!($final_val, :data, $concrete_res_name)
end
end
else
unflatcode = :($unflatcode.data = $concrete_res_name)
end
push!(unflatten_code, unflatcode)
end
end
end

Expand Down
15 changes: 12 additions & 3 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ function overload_autodiff(
primf = f.val
primargs = ((v.val for v in args)...,)

fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = Reactant.TracedUtils.make_mlir_fn(
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn(
primf, primargs, (), string(f) * "_autodiff", false
)

Expand Down Expand Up @@ -302,7 +302,7 @@ function overload_autodiff(
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
push!(ad_inputs, cst)
end
else
elseif TracedUtils.has_argidx(a)
idx, path = TracedUtils.get_argidx(a)
if idx == 1 && fnwrap
act = act_from_type(f, reverse, true)
Expand All @@ -322,6 +322,12 @@ function overload_autodiff(
end
TracedUtils.push_val!(ad_inputs, args[idx].dval, path[3:end])
end
else
act = act_from_type(Enzyme.Const, reverse, true)
push!(ret_activity, act)
if act != enzyme_out && act != enzyme_outnoneed
continue
end
end
end

Expand Down Expand Up @@ -385,7 +391,7 @@ function overload_autodiff(
end
residx += 1
end
else
elseif TracedUtils.has_argidx(a)
idx, path = TracedUtils.get_argidx(a)
if idx == 1 && fnwrap
TracedUtils.set!(
Expand All @@ -405,6 +411,9 @@ function overload_autodiff(
)
residx += 1
end
else
TracedUtils.set!(a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx)))
residx += 1
end
end

Expand Down
12 changes: 12 additions & 0 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,18 @@ function get_argidx(x)
throw(AssertionError("No path found for $x"))
end

function has_argidx(x)
for path in x.paths
if length(path) == 0
continue
end
if path[1] == :args
return true
end
end
return false
end

function set!(x, path, tostore; emptypath=false)
for p in path
x = Reactant.Compiler.traced_getfield(x, p)
Expand Down
Loading
Loading