Description
I just found this highly interesting package, and would like to learn from your handling of conditionals to enhance what is essentially a vectorization package of mine.
Background: The package I'm talking about is MonteCarloMeasurements.jl, which aims to propagate sampled distributions through functions. The strategy employed by MCM is to implement a type, Particles, that behaves like a number but internally contains a vector of samples (particles). Each scalar operation on this type is then easily translated, through dispatch, into a map over the inner vector. This works wonderfully well and is much faster than manually calling a function the corresponding number of times, thanks to its SPMD (single program, multiple data) properties (some benchmarks here).
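To make the dispatch idea concrete, here is a minimal sketch of such a number-like type. It is a stripped-down stand-in written for illustration, not MonteCarloMeasurements.jl's actual definition, and only the few operators needed for the example are overloaded:

```julia
# Hypothetical, stripped-down Particles type; NOT the real MonteCarloMeasurements.jl definition
struct Particles{T} <: Real
    particles::Vector{T}
end

# Each scalar operation becomes an elementwise operation over the inner samples
Base.:(+)(a::Particles, b::Particles) = Particles(a.particles .+ b.particles)
Base.:(*)(a::Particles, b::Particles) = Particles(a.particles .* b.particles)
Base.:(+)(a::Particles, b::Real) = Particles(a.particles .+ b)
Base.:(*)(a::Real, b::Particles) = Particles(a .* b.particles)

# Ordinary scalar code, written with no knowledge of Particles
f(x) = 2x + 1

p = Particles([1.0, 2.0, 3.0])
result = f(p).particles   # one output per sample
```

Because f is plain generic code, dispatch alone routes every + and * to the per-sample methods and f itself needs no changes. This is exactly what stops working once f contains an if on a Particles value.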
The weakness of the type-based approach is that control flow based on the type Particles cannot currently be handled. A Particles{Bool} appearing in a boolean context should ideally cause each particle to flow through the branching code independently. For this, I imagine a compiler pass to be one viable approach, something Hydra already seems to have in place. I'm very new to compiler passes, IR and all this, but I'm willing to spend some time learning it if it seems like a reasonable approach.
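One way to see where such a pass would have to intervene is to look at the lowered code of a branching function. In recent Julia versions (1.6 and later, as far as I know) each conditional branch shows up as a Core.GotoIfNot statement whose condition must be a Bool at run time; that is where a Particles{Bool} triggers the TypeError. The function below is the same negsquare used later in this issue, and the snippet only locates the branches, it does not transform them:

```julia
# Scalar function with a single if/else branch
function negsquare(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

# Untyped lowered IR for a Float64 argument (code_lowered returns one CodeInfo per method match)
ci = only(code_lowered(negsquare, (Float64,)))

# Statements whose control flow depends on a runtime value
branches = filter(stmt -> stmt isa Core.GotoIfNot, ci.code)
for b in branches
    # b.cond is the SSA value holding `x > 0`; b.dest is the statement index of the else branch
    println("branch on ", b.cond, ", otherwise jump to statement ", b.dest)
end
```

A pass in the spirit of Hydra's would have to handle each of these GotoIfNot nodes whenever its condition is not a plain Bool.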
An example of the transformation I would like to do is from
code = Meta.@lower if x > 0
    return x^2
else
    return -x^2
end
to
code2 = Meta.@lower map(x.particles) do x
    if x > 0
        return x^2
    else
        return -x^2
    end
end
where the translation should only be done if x is of the special type defined below:
struct Particles{T} <: Real
    particles::Vector{T}
end
Without this transformation, code like the following fails predictably:
Base.:(^)(p::Particles, r) = Particles(p.particles .^ r)
Base.:(>)(p::Particles, r) = Particles(p.particles .> r)  # elementwise comparison, yields Particles{Bool}
p = Particles(randn(10))
function negsquare(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end
julia> negsquare(p)
ERROR: TypeError: non-boolean (Particles) used in boolean context
If the code were translated to
function negsquare(x)
    Particles(map(x.particles) do x
        if x > 0
            return x^2
        else
            return -x^2
        end
    end)
end
julia> negsquare(p)
Particles([0.404953, 0.210984, -1.00176, 0.253796, -0.00620389, 0.831144, -0.0240916, -1.90169, 0.875192, 1.4788])
I would get the desired result.
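Short of a compiler pass, the hand translation above can be packaged as a generic helper that runs any scalar function, branches included, once per particle. The sketch below reuses the Particles struct from this issue; particlewise is a hypothetical name, not part of MonteCarloMeasurements.jl or Hydra:

```julia
# Minimal stand-in for the Particles type from the question
struct Particles{T} <: Real
    particles::Vector{T}
end

# Hypothetical helper: apply a scalar function to every particle and rewrap the results
particlewise(f, p::Particles) = Particles(map(f, p.particles))

# Plain scalar code; inside particlewise it only ever sees Float64 values, so the `if` is fine
function negsquare(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

p = Particles([-2.0, 0.5, 3.0])
out = particlewise(negsquare, p).particles
```

The trade-off is that the caller has to opt in explicitly, and functions taking several Particles arguments would need something like map(f, a.particles, b.particles) instead; a compiler pass would make this automatic.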
Do you think I can learn from the following code in Hydra to make this happen?
function fix_stmt(ssavalue, stmt, ir, old_to_new_ssavalue)
    replacement(x::SSAValue) = old_to_new_ssavalue[x]
    replacement(x) = x
    if stmt.expr isa GotoIfNot
        new_stmt = GotoIfNot(old_to_new_ssavalue[stmt.expr.cond], stmt.expr.dest)
        push!(ir, new_stmt)
    elseif stmt.expr isa GotoNode
        push!(ir, stmt)
Thankful for any input or insight on this!