Skip to content

Commit

Permalink
Improve Float16 perf
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Aug 2, 2021
1 parent 856efbd commit d83fd8c
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 62 deletions.
2 changes: 2 additions & 0 deletions src/VectorizationBase.jl
Expand Up @@ -102,7 +102,9 @@ const Boolean = Union{Bit,Bool}
# end

# const NativeTypesExceptBit = Union{Bool,HWReal,Int128,UInt128,UInt256,UInt512,UInt1024}
const NativeTypesExceptBitandFloat16 = Union{Bool,HWReal}
const NativeTypesExceptBit = Union{Bool,HWReal,Float16}
const NativeTypesExceptFloat16 = Union{Bool,HWReal,Bit}
const NativeTypes = Union{NativeTypesExceptBit, Bit}

const _Vec{W,T<:Number} = NTuple{W,Core.VecElement{T}}
Expand Down
4 changes: 2 additions & 2 deletions src/alignment.jl
Expand Up @@ -20,7 +20,7 @@ alignment(x::Integer, N = 64) = reinterpret(Int, x) % N

function valloc(N::Integer, ::Type{T} = Float64, a = max(register_size(), cache_linesize())) where {T}
# We want alignment to both vector and cacheline-sized boundaries
size_T = max(1, sizeof(T))
reinterpret(Ptr{T}, align(reinterpret(UInt,Libc.malloc(size_T*N + a - 1)), a))
size_T = max(1, sizeof(T))
reinterpret(Ptr{T}, align(reinterpret(UInt,Libc.malloc(size_T*N + a - 1)), a))
end

15 changes: 6 additions & 9 deletions src/llvm_intrin/intrin_funcs.jl
Expand Up @@ -252,11 +252,10 @@ for (op,f,fast) ∈ [
("fma",:vfma,false),("fma",:vfma_fast,true),
("fmuladd",:vmuladd,false),("fmuladd",:vmuladd_fast,true)
]
@eval @generated function $f(v1::Vec{W,T}, v2::Vec{W,T}, v3::Vec{W,T}) where {W, T <: FloatingTypes}
TS = T === Float32 ? :Float32 : :Float64
# TS = JULIA_TYPES[T]
build_llvmcall_expr($op, W, TS, [W, W, W], [TS, TS, TS], $(fast_flags(fast)))
end
@eval @generated function $f(v1::Vec{W,T}, v2::Vec{W,T}, v3::Vec{W,T}) where {W, T <: FloatingTypes}
TS = JULIA_TYPES[T]
build_llvmcall_expr($op, W, TS, [W, W, W], [TS, TS, TS], $(fast_flags(fast)))
end
end
# @inline Base.fma(a::Vec, b::Vec, c::Vec) = vfma(a,b,c)
# @inline Base.muladd(a::Vec{W,T}, b::Vec{W,T}, c::Vec{W,T}) where {W,T<:FloatingTypes} = vmuladd(a,b,c)
Expand Down Expand Up @@ -309,8 +308,7 @@ for (opname,f) ∈ [
op = "vector.reduce." * opname
end
@eval @generated function $f(v1::T, v2::Vec{W,T}) where {W, T <: Union{Float32,Float64}}
# TS = JULIA_TYPES[T]
TS = T === Float32 ? :Float32 : :Float64
TS = JULIA_TYPES[T]
build_llvmcall_expr($op, -1, TS, [1, W], [TS, TS], "nsz arcp contract afn reassoc")
end
end
Expand All @@ -322,8 +320,7 @@ for (op,f) ∈ [
]
Base.libllvm_version < v"12" && (op = "experimental." * op)
@eval @generated function $f(v1::Vec{W,T}) where {W, T <: Union{Float32,Float64}}
# TS = JULIA_TYPES[T]
TS = T === Float32 ? :Float32 : :Float64
TS = JULIA_TYPES[T]
build_llvmcall_expr($op, -1, TS, [W], [TS], "nsz arcp contract afn reassoc")
end
end
Expand Down
35 changes: 31 additions & 4 deletions src/llvm_intrin/memory_addr.jl
Expand Up @@ -624,7 +624,7 @@ end
# no index, mask
@generated function __vload(
ptr::Ptr{T}, ::A, m::AbstractMask, ::StaticInt{RS}
) where {T <: NativeTypes, A <: StaticBool, RS}
) where {T <: NativeTypesExceptFloat16, A <: StaticBool, RS}
vload_quote(T, Int, :StaticInt, 1, 1, 0, 0, true, A === True, RS)
end
# index, no mask
Expand All @@ -637,11 +637,23 @@ end
# index, mask
@generated function __vload(
ptr::Ptr{T}, i::I, m::AbstractMask, ::A, ::StaticInt{RS}
) where {A <: StaticBool, T <: NativeTypes, I <: Index, RS}
) where {A <: StaticBool, T <: NativeTypesExceptFloat16, I <: Index, RS}
IT, ind_type, W, X, M, O = index_summary(I)
vload_quote(T, IT, ind_type, W, X, M, O, true, A === True, RS)
end

# Float16 with mask
@inline function __vload(
ptr::Ptr{Float16}, ::A, m::AbstractMask, ::StaticInt{RS}
) where {A <: StaticBool, RS}
reinterpret(Float16, __vload(reinterpret(Ptr{Int16}, ptr), A(), m, StaticInt{RS}()))
end
@inline function __vload(
ptr::Ptr{Float16}, i::I, m::AbstractMask, ::A, ::StaticInt{RS}
) where {A <: StaticBool, I <: Index, RS}
reinterpret(Float16, __vload(reinterpret(Ptr{Int16}, ptr), i, m, A(), StaticInt{RS}()))
end


@inline function _vload_scalar(ptr::Ptr{Bit}, i::Integer, ::A, ::StaticInt{RS}) where {RS,A<:StaticBool}
d = i >> 3; r = i & 7;
Expand Down Expand Up @@ -837,7 +849,7 @@ end
# no index, mask, vector store
@generated function __vstore!(
ptr::Ptr{T}, v::V, m::AbstractMask{W}, ::A, ::S, ::NT, ::StaticInt{RS}
) where {T <: NativeTypesExceptBit, W, VT <: NativeTypes, V <: AbstractSIMDVector{W,VT}, A <: StaticBool, S <: StaticBool, NT <: StaticBool, RS}
) where {T <: NativeTypesExceptBitandFloat16, W, VT <: NativeTypes, V <: AbstractSIMDVector{W,VT}, A <: StaticBool, S <: StaticBool, NT <: StaticBool, RS}
if W == 1
return Expr(:block, Expr(:meta,:inline), :(Bool(m) && __vstore!(ptr, convert($T, v), data(i), $(A()), $(S()), $(NT()), StaticInt{$RS}())))
elseif V !== Vec{W,T}
Expand All @@ -848,7 +860,7 @@ end
# index, mask, vector store
@generated function __vstore!(
ptr::Ptr{T}, v::V, i::I, m::AbstractMask{W}, ::A, ::S, ::NT, ::StaticInt{RS}
) where {T <: NativeTypesExceptBit, W, VT <: NativeTypes, V <: AbstractSIMDVector{W,VT}, I <: Index, A <: StaticBool, S <: StaticBool, NT <: StaticBool, RS}
) where {T <: NativeTypesExceptBitandFloat16, W, VT <: NativeTypes, V <: AbstractSIMDVector{W,VT}, I <: Index, A <: StaticBool, S <: StaticBool, NT <: StaticBool, RS}
if W == 1
return Expr(:block, Expr(:meta,:inline), :(Bool(m) && __vstore!(ptr, convert($T, v), data(i), $(A()), $(S()), $(NT()), StaticInt{$RS}())))
elseif V !== Vec{W,T}
Expand All @@ -863,6 +875,21 @@ end
end


# no index, mask, vector store
@generated function __vstore!(
ptr::Ptr{Float16}, v::V, m::AbstractMask{W}, ::A, ::S, ::NT, ::StaticInt{RS}
) where {W, V <: AbstractSIMDVector{W,Float16}, A <: StaticBool, S <: StaticBool, NT <: StaticBool, RS}
__vstore!(reinterpret(Ptr{Int16}, ptr), reinterpret(Int16, v), m, A(), S(), NT(), StaticInt{RS}())
end
# index, mask, vector store
@inline function __vstore!(
ptr::Ptr{Float16}, v::V, i::I, m::AbstractMask{W}, ::A, ::S, ::NT, ::StaticInt{RS}
) where {W, V <: AbstractSIMDVector{W,Float16}, I <: Index, A <: StaticBool, S <: StaticBool, NT <: StaticBool, RS}
__vstore!(reinterpret(Ptr{Int16}, ptr), reinterpret(Int16, v), i, m, A(), S(), NT(), StaticInt{RS}())
end




# BitArray stores
@inline function __vstore!(
Expand Down
100 changes: 53 additions & 47 deletions src/llvm_intrin/vbroadcast.jl
@@ -1,27 +1,30 @@

@inline vzero(::Val{1}, ::Type{T}) where {T<:NativeTypes} = zero(T)
@inline vzero(::StaticInt{1}, ::Type{T}) where {T<:NativeTypes} = zero(T)
@generated function _vzero(::StaticInt{W}, ::Type{T}, ::StaticInt{RS}) where {W,T<:NativeTypes,RS}
isone(W) && return Expr(:block, Expr(:meta,:inline), Expr(:call, :zero, T))
if W * sizeof(T) > RS
d, r1 = divrem(sizeof(T) * W, RS)
Wnew, r2 = divrem(W, d)
(iszero(r1) & iszero(r2)) || throw(ArgumentError("If broadcasting to greater than 1 vector length, should make it an integer multiple of the number of vectors."))
t = Expr(:tuple)
for i 1:d
push!(t.args, :v)
end
# return Expr(:block, Expr(:meta,:inline), :(v = vzero(StaticInt{$Wnew}(), $T)), :(VecUnroll{$(d-1),$Wnew,$T,Vec{$Wnew,$T}}($t)))
return Expr(:block, Expr(:meta,:inline), :(v = _vzero(StaticInt{$Wnew}(), $T, StaticInt{$RS}())), :(VecUnroll($t)))
# return Expr(:block, Expr(:meta,:inline), :(v = _vzero(StaticInt{$Wnew}(), $T, StaticInt{$RS}())), :(VecUnroll($t)::VecUnroll{$(d-1),$Wnew,$T,Vec{$Wnew,$T}}))
end
typ = LLVM_TYPES[T]
instrs = "ret <$W x $typ> zeroinitializer"
quote
$(Expr(:meta,:inline))
Vec($LLVMCALL($instrs, _Vec{$W,$T}, Tuple{}))
@inline _vzero(::StaticInt{W}, ::Type{Float16}, ::StaticInt{RS}) where {W, RS} = _vzero_float16(StaticInt{W}(), StaticInt{RS}(), fast_half())
@inline _vzero_float16(::StaticInt{W}, ::StaticInt{RS}, ::False) where {W,RS} = _vzero(StaticInt{W}(), Float32, StaticInt{RS}())
function _vzero_expr(W::Int, typ::String, T::Symbol, st::Int, RS::Int)
isone(W) && return Expr(:block, Expr(:meta,:inline), Expr(:call, :zero, T))
if W * st > RS
d, r1 = divrem(st * W, RS)
Wnew, r2 = divrem(W, d)
(iszero(r1) & iszero(r2)) || throw(ArgumentError("If broadcasting to greater than 1 vector length, should make it an integer multiple of the number of vectors."))
t = Expr(:tuple)
for i 1:d
push!(t.args, :v)
end
# return Expr(:block, Expr(:meta,:inline), :(v = vzero(StaticInt{$Wnew}(), $T)), :(VecUnroll{$(d-1),$Wnew,$T,Vec{$Wnew,$T}}($t)))
return Expr(:block, Expr(:meta,:inline), :(v = _vzero(StaticInt{$Wnew}(), $T, StaticInt{$RS}())), :(VecUnroll($t)))
# return Expr(:block, Expr(:meta,:inline), :(v = _vzero(StaticInt{$Wnew}(), $T, StaticInt{$RS}())), :(VecUnroll($t)::VecUnroll{$(d-1),$Wnew,$T,Vec{$Wnew,$T}}))
end
instrs = "ret <$W x $typ> zeroinitializer"
quote
$(Expr(:meta,:inline))
Vec($LLVMCALL($instrs, _Vec{$W,$T}, Tuple{}))
end
end
@generated _vzero_float16(::StaticInt{W}, ::StaticInt{RS}) where {W,RS} = _vzero_expr(W, "half", :Float16, 2, RS)
@generated _vzero(::StaticInt{W}, ::Type{T}, ::StaticInt{RS}) where {W,T<:NativeTypes,RS} = _vzero_expr(W, LLVM_TYPES[T], JULIA_TYPES[T], sizeof(T), RS)
@generated function _vundef(::StaticInt{W}, ::Type{T}) where {W,T<:NativeTypes}
typ = LLVM_TYPES[T]
if W == 1
Expand All @@ -41,40 +44,43 @@ end
@inline _vundef(::T) where {T<:NativeTypes} = _vundef(StaticInt{1}(), T)
@inline _vundef(::Vec{W,T}) where {W,T} = _vundef(StaticInt{W}(), T)
@generated _vundef(::VecUnroll{N,W,T}) where {N,W,T} = Expr(:block,Expr(:meta,:inline), :(VecUnroll(Base.Cartesian.@ntuple $(N+1) n -> _vundef(StaticInt{$W}(), $T))))
@generated function _vbroadcast(::StaticInt{W}, s::_T, ::StaticInt{RS}) where {W,_T<:NativeTypes,RS}
isone(W) && return :s
if _T <: Integer && sizeof(_T) * W > RS
intbytes = max(4, RS ÷ W)
T = integer_of_bytes(intbytes)
if _T <: Unsigned
T = unsigned(T)
end
# ssym = :(s % $T)
ssym = :(convert($T, s))
elseif sizeof(_T) * W > RS
d, r1 = divrem(sizeof(_T) * W, RS)
Wnew, r2 = divrem(W, d)
(iszero(r1) & iszero(r2)) || throw(ArgumentError("If broadcasting to greater than 1 vector length, should make it an integer multiple of the number of vectors."))
t = Expr(:tuple)
for i 1:d
push!(t.args, :v)
end
return Expr(:block, Expr(:meta,:inline), :(v = vbroadcast(StaticInt{$Wnew}(), s)), :(VecUnroll($t)))
else
T = _T
ssym = :s
function vbroadcast_expr(W::Int, typ::String, T::Symbol, st::Int, RS::Int)
isone(W) && return :s
if st * W > RS
d, r1 = divrem(st * W, RS)
Wnew, r2 = divrem(W, d)
(iszero(r1) & iszero(r2)) || throw(ArgumentError("If broadcasting to greater than 1 vector length, should make it an integer multiple of the number of vectors."))
t = Expr(:tuple)
for i 1:d
push!(t.args, :v)
end
typ = LLVM_TYPES[T]
vtyp = vtype(W, typ)
instrs = """
return Expr(:block, Expr(:meta,:inline), :(v = vbroadcast(StaticInt{$Wnew}(), s)), :(VecUnroll($t)))
end
vtyp = vtype(W, typ)
instrs = """
%ie = insertelement $vtyp undef, $typ %0, i32 0
%v = shufflevector $vtyp %ie, $vtyp undef, <$W x i32> zeroinitializer
ret $vtyp %v
"""
quote
$(Expr(:meta,:inline))
Vec($LLVMCALL($instrs, _Vec{$W,$T}, Tuple{$T}, $ssym))
quote
$(Expr(:meta,:inline))
Vec($LLVMCALL($instrs, _Vec{$W,$T}, Tuple{$T}, s))
end
end
@inline _vbroadcast(::StaticInt{W}, s::Float16, ::StaticInt{RS}) where {W,RS} = _vbroadcast_float16(StaticInt{W}(), s, StaticInt{RS}(), fast_half())
@inline _vbroadcast_float16(::StaticInt{W}, s::Float16, ::StaticInt{RS}, ::False) where {W,RS} = _vbroadcast(StaticInt{W}(), convert(Float32, s), StaticInt{RS}())
@generated _vbroadcast_float16(::StaticInt{W}, s::Float16, ::StaticInt{RS}, ::True) where {W,RS} = vbroadcast_expr(W, "half", :Float16, 2, RS)
@generated function _vbroadcast(::StaticInt{W}, s::_T, ::StaticInt{RS}) where {W,_T<:NativeTypes,RS}
if _T <: Integer && sizeof(_T) * W > RS
intbytes = max(4, RS ÷ W)
T = integer_of_bytes(intbytes)
if _T <: Unsigned
T = unsigned(T)
end
# ssym = :(s % $T)
return Expr(:block, Expr(:meta,:inline), :(_vbroadcast(StaticInt{$W}(), :(convert($T, s)), StaticInt{$RS}())))
end
vbroadcast_expr(W, LLVM_TYPES[_T], JULIA_TYPES[_T], sizeof(_T), RS)
end
@inline _vbroadcast(::StaticInt{W}, m::EVLMask{W}, ::StaticInt{RS}) where {W,RS} = Mask(m)
@inline vzero(::Union{Val{W},StaticInt{W}}, ::Type{T}) where {W,T} = _vzero(StaticInt{W}(), T, register_size(T))
Expand Down

0 comments on commit d83fd8c

Please sign in to comment.