Skip to content

Commit

Permalink
Fix #15.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Jan 11, 2020
1 parent da4139d commit 8ed082c
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 44 deletions.
81 changes: 59 additions & 22 deletions src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,22 @@ isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
looprangehint(ls::LoopSet, s::Symbol) = ls.loops[s].rangehint
looprangesym(ls::LoopSet, s::Symbol) = ls.loops[s].rangesym
# itersyms(ls::LoopSet) = keys(ls.loops)
getop(ls::LoopSet, s::Symbol) = ls.opdict[s]
# Return the operation stored under `var` in `ls.opdict`. When no entry exists,
# `var` is registered through `add_constant!`, the assignment
# `mangledvar(op) = var` is pushed onto the preamble, and the new operation is
# stored in `ls.opdict` and returned.
function getop(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
    # Only invoked by `get!` when `var` has no entry yet.
    function newconstant()
        cop = add_constant!(ls, var, elementbytes)
        pushpreamble!(ls, Expr(:(=), mangledvar(cop), var))
        return cop
    end
    return get!(newconstant, ls.opdict, var)
end
# Variant of `getop` that forwards `deps` to `add_constant!` when `var` has no
# entry in `ls.opdict`. In that case the constant is created under a fresh
# `gensym(:constant)` name, `mangledvar(op) = var` is pushed onto the preamble,
# and the new operation is stored under `var` and returned. An existing entry is
# returned untouched.
function getop(ls::LoopSet, var::Symbol, deps, elementbytes::Int = 8)
    # Evaluated lazily by `get!`, so `gensym` only runs on a miss.
    fallback = () -> begin
        cop = add_constant!(ls, var, deps, gensym(:constant), elementbytes)
        pushpreamble!(ls, Expr(:(=), mangledvar(cop), var))
        cop
    end
    return get!(fallback, ls.opdict, var)
end
# Positional lookup: return the element of `ls.operations` at index `i + 1`,
# i.e. `i` is treated as a zero-based offset into the vector.
function getop(ls::LoopSet, i::Int)
    return ls.operations[i + 1]
end

# Unwrap a `Val` instance: return the value `N` carried in its type parameter.
@inline function extract_val(::Val{N}) where {N}
    return N
end
Expand Down Expand Up @@ -284,7 +299,7 @@ function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int = 8)
if body.head === :block
add_block!(ls, body, elementbytes)
else
Base.push!(ls, q, elementbytes)
push!(ls, q, elementbytes)
end
end
function add_loop!(ls::LoopSet, loop::Loop)
Expand Down Expand Up @@ -316,12 +331,18 @@ function add_load!(
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
)
if ref.loaded[] == true
op = getop(ls, var)
op = getop(ls, var, elementbytes)
@assert var === op.variable
return op
end
push!(ls.syms_aliasing_refs, var)
push!(ls.refs_aliasing_syms, ref)
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
if id === nothing
push!(ls.syms_aliasing_refs, var)
push!(ls.refs_aliasing_syms, ref)
else
opp = getop(ls, ls.syms_aliasing_refs[id], elementbytes)
return isstore(opp) ? getop(ls, first(parents(opp))) : opp
end
ref.loaded[] = true
# ls.sym_to_ref_aliases[ var ] = ref
# ls.ref_to_sym_aliases[ ref ] = var
Expand Down Expand Up @@ -427,7 +448,7 @@ function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
if id === nothing
add_load!( ls, gensym(:temporary), ref, elementbytes )
else
getop(ls, ls.syms_aliasing_refs[id])
getop(ls, ls.syms_aliasing_refs[id], elementbytes)
end
# id = includesarray(ls, array)
# if id > 0
Expand All @@ -440,12 +461,7 @@ function add_parent!(
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var, elementbytes::Int = 8
)
parent = if var isa Symbol
get!(ls.opdict, var) do
# might add constant
op = add_constant!(ls, var, elementbytes)
pushpreamble!(ls, Expr(:(=), mangledvar(op), var))
op
end
getop(ls, var, elementbytes)
elseif var isa Expr #CSE candidate
maybe_cse_load!(ls, var, elementbytes)
else # assumed constant
Expand All @@ -465,7 +481,7 @@ function add_reduction_update_parent!(
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet,
var::Symbol, instr::Symbol, elementbytes::Int = 8
)
parent = getop(ls, var)
parent = getop(ls, var, elementbytes)
setdiffv!(reduceddeps, deps, loopdependencies(parent))
pushparent!(parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
Expand Down Expand Up @@ -502,23 +518,33 @@ end
function add_store!(
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
)
parent = getop(ls, var)
op = Operation( length(operations(ls)), ref.array, elementbytes, :setindex!, memstore, loopdependencies(ref), reduceddependencies(parent), [parent], ref )
# @show loopdependencies(ref)
# @show ls.operations
ldref = loopdependencies(ref)
parent = getop(ls, var, ldref, elementbytes)
pvar = parent.variable
if pvar ∉ ls.syms_aliasing_refs
push!(ls.syms_aliasing_refs, pvar)
push!(ls.refs_aliasing_syms, ref)
end
op = Operation( length(operations(ls)), ref.array, elementbytes, :setindex!, memstore, ldref, reduceddependencies(parent), [parent], ref )
# @show loopdependencies(op) op
add_vptr!(ls, ref.array, identifier(op), ref.ptr)
pushop!(ls, op, ref.array)
end
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
ref = ref_from_ref(ex)
ref = ref_from_ref(ex)::ArrayReference
add_store!(ls, var, ref, elementbytes)
end
function add_store_setindex!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
ref = ref_from_setindex(ex)
add_store!(ls, var, ref, elementbytes)
ref = ref_from_setindex(ex)::ArrayReference
add_store!(ls, (ex.args[2])::Symbol, ref, elementbytes)
end
# add operation assigns X to var
function add_operation!(
ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int = 8
)
# @show LHS, RHS
if RHS.head === :ref
add_load_ref!(ls, LHS, RHS, elementbytes)
elseif RHS.head === :call
Expand All @@ -539,11 +565,17 @@ end
function add_operation!(
ls::LoopSet, LHS_sym::Symbol, RHS::Expr, LHS_ref::ArrayReference, elementbytes::Int = 8
)
# @show LHS_sym, RHS
if RHS.head === :ref# || (RHS.head === :call && first(RHS.args) === :getindex)
add_load!(ls, LHS_sym, LHS_ref, elementbytes)
elseif RHS.head === :call
if first(RHS.args) === :getindex
f = first(RHS.args)
if f === :getindex
add_load!(ls, LHS_sym, LHS_ref, elementbytes)
elseif f === :zero || f === :one
c = gensym(:constant)
pushpreamble!(ls, Expr(:(=), c, RHS))
add_constant!(ls, c, [keys(ls.loops)...], LHS_sym, elementbytes)
else
add_compute!(ls, LHS_sym, RHS, elementbytes, LHS_ref)
end
Expand All @@ -552,6 +584,7 @@ function add_operation!(
end
end
function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
# @show ex
if ex.head === :call
finex = first(ex.args)::Symbol
if finex === :setindex!
Expand All @@ -566,21 +599,25 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
if RHS isa Expr
add_operation!(ls, LHS, RHS, elementbytes)
else
# @show [keys(ls.loops)...]
add_constant!(ls, RHS, [keys(ls.loops)...], LHS, elementbytes)
end
elseif LHS isa Expr
@assert LHS.head === :ref
local lrhs::Symbol
# @show LHS, RHS
if RHS isa Symbol
lrhs = RHS
elseif RHS isa Expr
# need to check of LHS appears in RHS
# assign RHS to lrhs
ref = ArrayReference(LHS)
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
lrhs = id === nothing ? gensym(:RHS) : ls.syms_aliasing_refs[id]
# we pass ref, so it can compare references within RHS, and realize
# they equal lrhs
lrhs = if id === nothing
gensym(:RHS)
else
ls.syms_aliasing_refs[id]
end
add_operation!(ls, lrhs, RHS, ref, elementbytes)
end
add_store_ref!(ls, lrhs, LHS, elementbytes)
Expand Down
2 changes: 1 addition & 1 deletion src/lowering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ function reduce_unroll!(q, op, U, unrolled)
return U, isunrolled
end
unrolled ∈ reduceddependencies(op) || return U
var = pvariable_name(op, suffix)
var = mangledvar(op)
instr = first(parents(op)).instruction
reduce_expr!(q, var, instr, U) # assigns reduction to storevar
1, isunrolled
Expand Down
22 changes: 8 additions & 14 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ function Base.hash(x::ArrayReference, h::UInt)
end
hash(x.array, h)
end
loopdependencies(ref::ArrayReference) = filter(i -> i isa Symbol, ref.ref)
"""
    loopdependencies(ref::ArrayReference)

Collect the entries of `ref.ref` that are `Symbol`s, preserving their order.

The result is always a freshly allocated `Vector{Symbol}`, regardless of the
element type of `ref.ref`; non-`Symbol` entries are skipped.
"""
function loopdependencies(ref::ArrayReference)
    ld = Symbol[]
    for r ∈ ref.ref
        # Keep only symbolic indices; anything else is not a loop symbol.
        r isa Symbol && push!(ld, r)
    end
    return ld
end
function Base.isequal(x::ArrayReference, y::ArrayReference)
x.array === y.array || return false
nrefs = length(x.ref)
Expand Down Expand Up @@ -74,30 +80,18 @@ end

# TODO: can some computations be cached in the operations?
"""
if operation_type == memload || operation_type == memstore# || operation_type == compute_new || operation_type == compute_update
symbolic metadata contains info on direct dependencies / placement within loop.
if isload(op) -> Symbol(:vptr_, first(op.reduced_deps))
if isstore(op) -> Symbol(:vptr_, op.variable)
is how we access the memory.
is the stride for loop index
symbolic_metadata[i]
"""
struct Operation
identifier::Int
variable::Symbol
elementbytes::Int
instruction::Instruction
node_type::OperationType
dependencies::Vector{Symbol}#::Vector{Symbol}
dependencies::Vector{Symbol}
reduced_deps::Vector{Symbol}
parents::Vector{Operation}
ref::ArrayReference
mangledvariable::Symbol
# children::Vector{Operation}
# numerical_metadata::Vector{Int} # stride of -1 indicates dynamic
# symbolic_metadata::Vector{Symbol}
function Operation(
identifier::Int,
variable,
Expand Down
51 changes: 44 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,33 @@ using LinearAlgebra
@test logsumexp!(r, x) ≈ 102.35216846104409

@testset "GEMM" begin
AmulBq = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
using LoopVectorization, Test
U, T = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
AmulBq1 = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
C[m,n] = zeroB
for k ∈ 1:size(A,2)
C[m,n] += A[m,k] * B[k,n]
end
end)
lsAmulB1 = LoopVectorization.LoopSet(AmulBq1);
@test LoopVectorization.choose_order(lsAmulB1) == (Symbol[:n,:m,:k], :m, U, T)
AmulBq2 = :(for m ∈ 1:M, n ∈ 1:N
C[m,n] = zero(eltype(B))
for k ∈ 1:K
C[m,n] += A[m,k] * B[k,n]
end
end)
lsAmulB2 = LoopVectorization.LoopSet(AmulBq2);
@test LoopVectorization.choose_order(lsAmulB2) == (Symbol[:n,:m,:k], :m, U, T)
AmulBq3 = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
ΔCₘₙ = zero(eltype(C))
for k ∈ 1:size(A,2)
ΔCₘₙ += A[m,k] * B[k,n]
end
C[m,n] += ΔCₘₙ
end)

lsAmulB = LoopVectorization.LoopSet(AmulBq);
U, T = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
@test LoopVectorization.choose_order(lsAmulB) == (Symbol[:n,:m,:k], :m, U, T)
lsAmulB3 = LoopVectorization.LoopSet(AmulBq3);
@test LoopVectorization.choose_order(lsAmulB3) == (Symbol[:n,:m,:k], :m, U, T)

function AmulB!(C, A, B)
C .= 0
Expand All @@ -53,7 +69,7 @@ using LinearAlgebra
end
end
end
function AmulBavx!(C, A, B)
function AmulBavx1!(C, A, B)
@avx for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
Cₘₙ = zero(eltype(C))
for k ∈ 1:size(A,2)
Expand All @@ -62,6 +78,23 @@ using LinearAlgebra
C[m,n] = Cₘₙ
end
end
# GEMM kernel under test: each `C[m,n]` is first overwritten with a zero that is
# computed once outside the loop nest, then accumulated in place over `k`.
function AmulBavx2!(C, A, B)
    z = zero(eltype(C))
    @avx for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
        C[m,n] = z
        for k ∈ 1:size(A,2)
            C[m,n] += A[m,k] * B[k,n]
        end
    end
end
# GEMM kernel under test: each `C[m,n]` is first overwritten with
# `zero(eltype(C))`, written as a call inside the loop body, then accumulated in
# place over `k`.
function AmulBavx3!(C, A, B)
    @avx for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
        C[m,n] = zero(eltype(C))
        for k ∈ 1:size(A,2)
            C[m,n] += A[m,k] * B[k,n]
        end
    end
end
function AmuladdBavx!(C, A, B, factor = 1)
@avx for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
ΔCₘₙ = zero(eltype(C))
Expand Down Expand Up @@ -171,8 +204,12 @@ using LinearAlgebra
C = Matrix{TC}(undef, M, N);
A = rand(R, M, K); B = rand(R, K, N);
C2 = similar(C);
AmulBavx!(C, A, B)
AmulB!(C2, A, B)
AmulBavx1!(C, A, B)
@test C ≈ C2
fill!(C, 999.99); AmulBavx2!(C, A, B)
@test C ≈ C2
fill!(C, 999.99); AmulBavx3!(C, A, B)
@test C ≈ C2
fill!(C, 0.0); AmuladdBavx!(C, A, B)
@test C ≈ C2
Expand Down

0 comments on commit 8ed082c

Please sign in to comment.