diff --git a/lib/cudnn/CUDNN.jl b/lib/cudnn/CUDNN.jl
index f7f34399a8..3a9bce21c7 100644
--- a/lib/cudnn/CUDNN.jl
+++ b/lib/cudnn/CUDNN.jl
@@ -12,6 +12,7 @@ using ..APIUtils
 using ..CUDA
 using ..CUDA: CUstream, libraryPropertyType
 using ..CUDA: libcudnn, @retry_reclaim, isdebug, @context!
+using BFloat16s
 
 using CEnum: @cenum
 
diff --git a/lib/cudnn/util.jl b/lib/cudnn/util.jl
index 8fa8ff7bcc..6400947b70 100644
--- a/lib/cudnn/util.jl
+++ b/lib/cudnn/util.jl
@@ -2,6 +2,7 @@
 cptr(x,a::DenseCuArray{Float64})=Float64[x]
 cptr(x,a::DenseCuArray{Float32})=Float32[x]
 cptr(x,a::DenseCuArray{Float16})=Float32[x]
+cptr(x,a::DenseCuArray{BFloat16s.BFloat16})=Float32[x]  # Float32 host scalar, same as the Float16 method: cuDNN reads scaling factors as float for 16-bit tensors
 
 # Conversion between Julia and CUDNN datatypes
 cudnnDataType(::Type{Float16})=CUDNN_DATA_HALF
@@ -10,6 +11,7 @@ cudnnDataType(::Type{Float64})=CUDNN_DATA_DOUBLE
 cudnnDataType(::Type{Int8}) = CUDNN_DATA_INT8
 cudnnDataType(::Type{UInt8}) = CUDNN_DATA_UINT8
 cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
+cudnnDataType(::Type{BFloat16s.BFloat16}) = CUDNN_DATA_BFLOAT16
 # The following are 32-bit elements each composed of 4 8-bit integers, only supported with CUDNN_TENSOR_NCHW_VECT_C
 # CUDNN_DATA_INT8x4,
 # CUDNN_DATA_UINT8x4,
@@ -19,7 +21,9 @@ juliaDataType(a)=(a==CUDNN_DATA_HALF ? Float16 :
                   a==CUDNN_DATA_DOUBLE ? Float64 :
                   a==CUDNN_DATA_INT8 ? Int8 :
                   a==CUDNN_DATA_UINT8 ? UInt8 :
-                  a==CUDNN_DATA_INT32 ? Int32 : error())
+                  a==CUDNN_DATA_INT32 ? Int32 :
+                  a==CUDNN_DATA_BFLOAT16 ? BFloat16s.BFloat16 :
+                  error("juliaDataType: unsupported cudnnDataType_t value $a"))
 
 tuple_strides(A::Tuple) = _strides((1,), A)
 _strides(out::Tuple{Int}, A::Tuple{}) = ()