Skip to content

Commit

Permalink
rewrite WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
fredo-dedup committed Jan 26, 2016
1 parent d46b52a commit f0ca0c3
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 142 deletions.
4 changes: 2 additions & 2 deletions src2/base_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ ReverseDiffSource.@deriv_rule reshape(x::AbstractArray, d::Tuple) d 0.
# setindex
@deriv_rule_mut setindex!(x, y, i) x ds[i] = 0.
@deriv_rule setindex!(x, y, i) y ds[i]
@deriv_rule setindex!(x, y::Real, i) y sum(ds[i])
@deriv_rule setindex!(x, y::Real, i::Range) y sum(ds[i])
@deriv_rule setindex!(x, y, i) i 0.

@deriv_rule_mut setindex!(x, y, i1, i2) x ds[i1,i2] = 0.
Expand Down Expand Up @@ -143,7 +143,7 @@ ReverseDiffSource.@deriv_rule .-(x , y::AbstractArray) y -d

# sum()
ReverseDiffSource.@deriv_rule sum(x::Real ) x ds
ReverseDiffSource.@deriv_rule sum(x::AbstractArray) x ones(size(x)).*ds
ReverseDiffSource.@deriv_rule sum(x::AbstractArray) x ones(x).*ds

# dot()
ReverseDiffSource.@deriv_rule dot(x::Real , y::Real ) x y * ds
Expand Down
43 changes: 39 additions & 4 deletions src2/debug2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,21 @@ end
# @test fullcycle(:( x = b+6 ), keep_var=[:x]) == :(x=b+6)
# @test fullcycle(:( x = y = b+6 ), keep_var=[:x]) == :(x=b+6)
# @test fullcycle(:( x = y = b+6 ), keep_var=[:y]) == :(y=b+6)
# @test fullcycle(:( x = 3.; y = b+6 ), keep_var=[:x]) == :(y=b+6)


@test fullcycle(:(sin(b); x=3)) == :(3)

@test fullcycle(:(x = b+6; x+=1)) == :((b+6)+1)
@test fullcycle(:(x = 1; x -= b+6)) == :(1 - (b+6))
@test fullcycle(:(x = a; x *= b+6)) == :(a * (b+6))
g = tograph(:(x = a; x *= b+6))
tocode(g)

@test fullcycle(:(x = b')) == :( b')
@test fullcycle(:(x = [1,2])) == :( [1,2])
@test fullcycle(:(x = 4:5 )) == :( 4:5)

@test fullcycle(:(x = b+4+5)) == :((b+4)+5)
@test fullcycle(:(x = b+0)) == :(b)

# @test fullcycle(:(x = b*0)) == :(0)
@test fullcycle(:(x = b*1)) == :(b)
# @test fullcycle(:(x = b*(0.5+0.5))) == :(b)
Expand Down Expand Up @@ -122,6 +122,8 @@ tocode(g)

# @test fullcycle(:( X = copy(B) ; X[1:2] = X[1:2] )) == :( X = copy(B) ; X[1:2] = X[1:2] )
@test fullcycle(:( X = copy(D) ; X[1:2,3] = a )) == :( X = copy(D) ; X[1:2,3] = a )
show(simplify!(tograph(:( X = copy(D) ; X[1:2,3] = a ))))
g = simplify!(tograph(:( X = copy(D) ; X[1:2,3] = a )))
@test fullcycle(:( X = copy(D) ; X[1:2,2] = D[1:2,3] )) == :( X = copy(D) ; X[1:2,2] = D[1:2,3] )

# @test fullcycle(:( B[:] )) == Expr(:block, :( x[1:length(x)] ) )
Expand All @@ -140,7 +142,7 @@ tocode(g)

# @test fullcycle(:( a = b.f[i])) == Expr(:block, :(a = b.f[i]) )
# @test fullcycle(:( a = b[j].f[i])) == Expr(:block, :(a = b[j].f[i]) )

z = Z(0,0).x = a

### test evalconstants, simplify
ex = quote
Expand Down Expand Up @@ -186,6 +188,39 @@ end


################# for loops ################
ex = quote
x = 0.
for i in 1:3
x = x + i
end
x
end
fullcycle(ex)
g = tograph(ex)
show(g)
tocode(g)
simplify!(g)

o = g.block.ops[1]
isfusable(o.asc[1], o.desc[1], Walk(o,g))


keeps2 = intersect([EXIT_SYM;], keys(g.block.symbols))
keep = Set{Loc}([ g.block.symbols[s] for s in keeps2])

prune!(g, keep)
splitnary!(g)

fusecopies!(g)
removerightneutral!(g)
removeleftneutral!(g)
prune!(g, keep)






ex = quote
x = 0.
for i in 1:10
Expand Down
53 changes: 30 additions & 23 deletions src2/debug3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,39 +87,46 @@ end

g = tograph(ex)
simplify!(g)
# gdiff!(g, g.block.symbols[EXIT_SYM], g.block.symbols[:a])
gdiff!(g, g.block.symbols[EXIT_SYM], g.locs[16])
gdiff!(g, g.block.symbols[EXIT_SYM], g.block.symbols[:a])
simplify!(g)
dex = tocode(g)
show(g)
show(dex)
@eval let a = 1.0; $dex ; end
@eval let a = 1.00001; $dex ; end

ispivot(rest(g.block.ops,1))
ispivot(rest(g.block.ops,2))
ispivot(rest(g.block.ops,3))
ispivot(rest(g.block.ops,4))
ispivot(rest(g.block.ops,5))

ops = g.block.ops[4].ops
ispivot(ops[1],1)
function ispivot(o::Op, line)
# checks if any desc of `o` appears several times afterward
# or if they are modified
o, line = ops[1],1
for l in o.desc # l = o.desc[1]
ct = 0
for o2 in ops[line+1:end]
l in o2.desc && println("1")
ct += l in o2.asc
ct > 1 && println("2")
end
end
# checks if any asc of `o` is modified later
for l in o.asc # l = o.asc[1]
for o2 in ops[line+1:end]
l in o2.desc && l in o2.asc && println("3")
end
end
#
false
ispivot(rest(ops,1))
ispivot(rest(ops,2))
ispivot(rest(ops,3))
ispivot(rest(ops,4))


o, line = ops[4], 4
l = o.desc[1]
writ = false
used = false
for o2 in ops[line+1:end]
l in o2.desc && println("true 1") # result is mutated
writ = any(a -> a in o2.desc, o.asc) # ascendants of results modified
l in o2.asc || continue
used && println("true 2") # result is used at least twice
writ && println("true 3") # result is used after ascendants modification
used = true
end

a = [1,2,3,4]
start(a)
next(a,2)
done(a,4)
done(a,5)

ex = quote
x = 0.
for i in 1:10
Expand Down
30 changes: 19 additions & 11 deletions src2/forblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,14 @@ getops(bl::ForBlock) = Any[bl.ops, bl.lops]
flatops(bl::ForBlock) = vcat(flatops(bl.ops), bl) # ignore var looping block

function summarize(bl::ForBlock)
# note : only `ops` vector considered, `lops` is only implicit assignements
asc = mapreduce(o -> o.asc, union, Set(), bl.ops)
desc = mapreduce(o -> o.desc, union, Set(), bl.ops)
# # note : only `ops` vector considered, `lops` is only implicit assignements
# asc = mapreduce(o -> o.asc, union, Set(), bl.ops)
# desc = mapreduce(o -> o.desc, union, Set(), bl.ops)
asc = mapreduce(o -> o.asc, union, Set{Loc}(), bl.ops)
asc = mapreduce(o -> o.asc, union, asc, bl.lops)
desc = mapreduce(o -> o.desc, union, Set{Loc}(), bl.ops)
desc = mapreduce(o -> o.desc, union, desc, bl.lops)

# keep var and range in correct positions
asc = vcat(bl.asc[1:2], setdiff(asc, bl.asc[1:2]))
collect(asc), collect(desc)
Expand Down Expand Up @@ -86,11 +91,11 @@ function blockparse!(ex::ExFor, parentops, parentsymbols, g::Graph)
oloc = parentsymbols[k]
dloc = symbols[k]

ns = Snippet(:(copy!(a,b)), [:a, :b])
appendsnippet!(ns, thisblock.lops, Loc[oloc, dloc], g)
# fcop = CLoc(copy!)
# push!(g.locs, fcop)
# push!(thisblock.lops, FOp(fcop, [oloc, dloc], [oloc;]))
# ns = Snippet(:(a=b), [:a, :b])
# appendsnippet!(ns, thisblock.lops, Loc[oloc, dloc], g)
fcop = CLoc(copy)
push!(g.locs, fcop)
push!(thisblock.lops, FOp(fcop, [dloc;], [oloc;]))

# update the parents' symbol map
parentsymbols[k] = dloc
Expand Down Expand Up @@ -129,7 +134,7 @@ function blockcode(bl::ForBlock, locex, g::Graph)
# for each variable rebinding ( != mutated variables) : force creation of
# variable before loop if there isn't one
for lop in bl.lops
li, lo = lop.asc
li, lo = lop.asc, lop.desc

# find symbol
ks = collect(keys(bl.symbols))
Expand All @@ -148,11 +153,14 @@ function blockcode(bl::ForBlock, locex, g::Graph)
end

# for updated and mutated variables : mark as exit for code generation
exits = copy(bl.desc)
append!(exits, Loc[ op.asc[2] for op in bl.lops])
# exits = copy(bl.desc)
# append!(exits, Loc[ op.asc[2] for op in bl.lops])
exits = Loc[ op.asc[1] for op in bl.lops]

# expression for inner code
println("in")
fex = _tocode(bl.ops, exits, bl.symbols, g, locex)
println("out")

push!(out, Expr(:for, Expr(:(=), ixs, rgs), fex))

Expand Down
28 changes: 28 additions & 0 deletions src2/ifblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,34 @@ function summarize(bl::IfBlock)
collect(asc), collect(desc)
end

function prune!(bl::IfBlock, keep::Set{Loc})
del_list = Int64[]
iop = collect(enumerate(bl.trueops))
for (i, op) in reverse(iop) # i,op = iop[9]
if any(l -> l in op.desc, keep)
isa(op, AbstractBlock) && prune!(op, keep)
union!(keep, op.asc)
else
push!(del_list,i)
end
end
deleteat!(bl.trueops, reverse(del_list))

del_list = Int64[]
iop = collect(enumerate(bl.falseops))
for (i, op) in reverse(iop) # i,op = iop[9]
if any(l -> l in op.desc, keep)
isa(op, AbstractBlock) && prune!(op, keep)
union!(keep, op.asc)
else
push!(del_list,i)
end
end
deleteat!(bl.falseops, reverse(del_list))

bl.asc, bl.desc = summarize(bl)
end


function show(io::IO, bl::IfBlock)
nt = length(bl.trueops)
Expand Down
39 changes: 6 additions & 33 deletions src2/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end
#######################################################

flatops(op::Op) = [op;]
flatops(bl::AbstractBlock) = vcat(map(flatops, getops(bl))..., bl)
flatops(bl::AbstractBlock) = vcat(map(flatops, getops(bl))...)
flatops(ops::Vector{Op}) = vcat(map(flatops, ops)...)
flatops(g::Graph) = flatops(g.block)

Expand Down Expand Up @@ -66,45 +66,18 @@ next(w::RevWalk, state) = (state[2][state[1]], (state[1]-1, state[2]))
done(w::RevWalk, state) = state[1] == 0

# removes unecessary elements (as specified by 'keep')
# function prune!(g, keep)
# # find all locs that are relevant to calculate 'keep'
# keep2 = intersect(keep, keys(g.block.symbols))
# lset = Set{Loc}([ g.block.symbols[s] for s in keep2])
#
# for o in RevWalk(g)
# if any(l -> l in o.desc, lset)
# union!(lset, o.asc)
# isa(o, FOp) && push!(lset, o.f)
# end
# end
#
# # filter all locs, symbols, ops unrelated to lset
# filter!(l -> l in lset, g.locs)
#
# for bl in allblocks(g)
# filter!((s,l) -> l in lset, bl.symbols)
# for ops in getops(bl)
# filter!(o -> any(l -> l in o.desc, lset), ops)
# end
# bl.asc, bl.desc = summarize(bl)
# end
#
# g
# end


prune!(g::Graph, keep::Set{Loc}) = prune!(g.block, keep)

function prune!(bl::AbstractBlock, keep::Set{Loc})
del_list = Int64[]
iop = collect(enumerate(bl.ops))
for (i, op) in reverse(iop) # i,op = iop[9]
if any(l -> l in op.desc, keep)
println("keep $i")
# println("keep $i")
isa(op, AbstractBlock) && prune!(op, keep)
union!(keep, op.asc)
else
println("remove $i")
# println("remove $i")
push!(del_list,i)
end
end
Expand Down Expand Up @@ -135,9 +108,9 @@ function splitnary!(g)
end

function isfusable(org::Loc, cpy::Loc, w::Walk)
# org, cpy, g = o.asc[1], o.desc[1], g
# org, cpy, w = o.asc[1], o.desc[1], Walk(o,g)
# if org is external, checks that copy is not mutated
if loctype(org) == :external
if loctype(org) in [:external, :constant]
any(l -> cpy in l.desc, w) && return false
end

Expand Down Expand Up @@ -206,7 +179,7 @@ function removerightneutral!(g)
for bl in allblocks(g)
for ops in getops(bl)
del_list = Int64[]
for (line, o) in enumerate(ops) # line=1 ; o = ops[1]
for (line, o) in enumerate(ops) # line=1 ; o = g.block.ops[1]
isa(o, FOp) || continue
length(o.asc) == 2 || continue
loctype(o.asc[2]) == :constant || continue
Expand Down

0 comments on commit f0ca0c3

Please sign in to comment.