Skip to content

Commit

Permalink
Handle mixed activity of literal 0 constant
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 17, 2024
1 parent cc8ceb6 commit c73c2c0
Showing 1 changed file with 103 additions and 35 deletions.
138 changes: 103 additions & 35 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1754,92 +1754,160 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err
elseif errtype == API.ET_MixedActivityError
data2 = LLVM.Value(data2)
badval = nothing
gutils = GradientUtils(API.EnzymeGradientUtilsRef(data))
# Ignore mismatched activity if phi/store of ghost
todo = LLVM.Value[data2]
seen = Set{LLVM.Value}()
seen = Dict{LLVM.Value, LLVM.Value}()
illegal = false
while length(todo) != 0
cur = pop!(todo)
if cur in seen
continue
end
push!(seen, cur)
if isa(cur, LLVM.PHIInst)
for v in LLVM.incoming(cur)
push!(todo, cur)
end
continue
created = LLVM.Instruction[]
function make_replacement(cur::LLVM.Value, prevbb)::LLVM.Value
ncur = new_from_original(gutils, cur)
if cur in keys(seen)
return seen[cur]
end

legal, TT = abs_typeof(cur, true)
if legal
world = enzyme_extract_world(LLVM.parent(position(IRBuilder(B))))
if guaranteed_const_nongen(TT, world)
continue
return ncur
end

legal2, obj = absint(cur)

if legal2 && active_reg_inner(TT, (), world) == ActiveState && isa(cur, LLVM.ConstantExpr)
res = emit_allocobj!(prevbb, Base.RefValue{TT})
push!(created, res)
return res
end

badval = if legal2
string(obj)*" of type"*" "*string(TT)
else
"Unknown object of type"*" "*string(TT)
end
illegal = true
break
return ncur
end

if isa(cur, LLVM.PointerNull)
continue
return ncur
end
if isa(cur, LLVM.UndefValue)
continue
return ncur
end
@static if LLVM.version() >= v"12"
if isa(cur, LLVM.PoisonValue)
continue
return ncur
end
end
if isa(cur, LLVM.ConstantAggregateZero)
continue
return ncur
end
if isa(cur, LLVM.ConstantAggregate)
continue
return ncur
end
if isa(cur, LLVM.ConstantDataSequential)
cvals = LLVM.Value[]
changed = false
for v in collect(cur)
push!(todo, v)
tmp = make_replacement(v, prevbb)
if illegal
return cur
end
if v != tmp
changed = true
end
push!(todo, tmp)
end
continue

cur2 = if changed
illegal = true
# TODO replace with correct insertions/splats
ncur
else
ncur
end
return cur2
end
if isa(cur, LLVM.ConstantInt)
if width(value_type(cur)) <= 8
continue
return ncur
end
# if storing a constant int as a non-pointer, presume it is not a GC'd var and is safe
# for activity state to mix
if isa(val, LLVM.StoreInst) operands(val)[1] == cur && !isa(value_type(operands(val)[1]), LLVM.PointerType)
continue
return ncur
end
end

if isa(cur, LLVM.PHIInst)
B = IRBuilder()
position!(B, ncur)
phi2 = phi!(prevbb, value_type(cur), "tempphi"*LLVM.name(cur))
seen[cur] = phi2
changed = false
recsize = length(created)
for (v, bb) in LLVM.incoming(cur)
B2 = IRBuilder()
position!(B2, last(instructions(bb)))
tmp = make_replacement(v, B2)
if illegal
changed = true
break
end
if tmp != v && v != cur
changed = true
break
end
push!(LLVM.incoming(phi2), (tmp, bb))
end
if !changed || illegal
LLVM.API.LLVMInstructionEraseFromParent(phi2)
seen[cur] = ncur
plen = length(created)
for i in changed:plen
u = created[i]
replace_uses!(u, LLVM.UndefValue(value_type(u)))
end
for i in changed:plen
u = created[i]
LLVM.API.LLVMInstructionEraseFromParent(u)
end
for i in changed:plen
pop!(created)
end
return ncur
end
push!(created, phi2)
return phi2
end

illegal = true
break
return ncur
end

if !illegal
return C_NULL
newb = new_from_original(gutils, val)
while isa(newb, LLVM.PHIInst)
newb = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(newb))
end
b = IRBuilder(B)
replacement = make_replacement(data2, b)

if !illegal
return replacement.ref
end
for u in created
replace_uses!(u, LLVM.UndefValue(value_type(u)))
end
for u in created
LLVM.API.LLVMInstructionEraseFromParent(u)
end
if LLVM.API.LLVMIsAReturnInst(val) != C_NULL
mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false)
if mi !== nothing && isghostty(rt)
return C_NULL
end
end

gutils = GradientUtils(API.EnzymeGradientUtilsRef(data))
newb = new_from_original(gutils, val)
while isa(newb, LLVM.PHIInst)
newb = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(newb))
end
b = IRBuilder(B)
msg2 = sprint() do io
print(io, msg)
println(io)
Expand Down

0 comments on commit c73c2c0

Please sign in to comment.