From ee532cde46c8481513622bf87042a0eac6a44a39 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 9 Aug 2021 15:53:59 +0530
Subject: [PATCH 1/3] add basic utils for bf16

---
 lib/cudnn/util.jl | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/lib/cudnn/util.jl b/lib/cudnn/util.jl
index 8fa8ff7bcc..6400947b70 100644
--- a/lib/cudnn/util.jl
+++ b/lib/cudnn/util.jl
@@ -2,6 +2,7 @@
 cptr(x,a::DenseCuArray{Float64})=Float64[x]
 cptr(x,a::DenseCuArray{Float32})=Float32[x]
 cptr(x,a::DenseCuArray{Float16})=Float32[x]
+cptr(x,a::DenseCuArray{BFloat16s.BFloat16})=Float32[x] # Float32 scalar, as for Float16
 
 # Conversion between Julia and CUDNN datatypes
 cudnnDataType(::Type{Float16})=CUDNN_DATA_HALF
@@ -10,6 +11,7 @@ cudnnDataType(::Type{Float64})=CUDNN_DATA_DOUBLE
 cudnnDataType(::Type{Int8}) = CUDNN_DATA_INT8
 cudnnDataType(::Type{UInt8}) = CUDNN_DATA_UINT8
 cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
+cudnnDataType(::Type{BFloat16s.BFloat16}) = CUDNN_DATA_BFLOAT16
 # The following are 32-bit elements each composed of 4 8-bit integers, only supported with CUDNN_TENSOR_NCHW_VECT_C
 # CUDNN_DATA_INT8x4,
 # CUDNN_DATA_UINT8x4,
@@ -19,7 +21,9 @@ juliaDataType(a)=(a==CUDNN_DATA_HALF ? Float16 :
                   a==CUDNN_DATA_DOUBLE ? Float64 :
                   a==CUDNN_DATA_INT8 ? Int8 :
                   a==CUDNN_DATA_UINT8 ? UInt8 :
-                  a==CUDNN_DATA_INT32 ? Int32 : error())
+                  a==CUDNN_DATA_INT32 ? Int32 :
+                  a==CUDNN_DATA_BFLOAT16 ? BFloat16s.BFloat16 :
+                  error())
 
 tuple_strides(A::Tuple) = _strides((1,), A)
 _strides(out::Tuple{Int}, A::Tuple{}) = ()

From 8272c714a593e8f833b6174cd868c0c5c1f6ee35 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 9 Aug 2021 15:54:10 +0530
Subject: [PATCH 2/3] imports fix

---
 lib/cudnn/CUDNN.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lib/cudnn/CUDNN.jl b/lib/cudnn/CUDNN.jl
index f7f34399a8..93b0b1fb63 100644
--- a/lib/cudnn/CUDNN.jl
+++ b/lib/cudnn/CUDNN.jl
@@ -12,6 +12,7 @@ using ..APIUtils
 using ..CUDA
 using ..CUDA: CUstream, libraryPropertyType
 using ..CUDA: libcudnn, @retry_reclaim, isdebug, @context!
+using ..CUDA: BFloat16s
 
 using CEnum: @cenum
 

From 2e21db389967cd6f7c98e271e4e94184f0ba71b8 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 9 Aug 2021 16:08:33 +0530
Subject: [PATCH 3/3] precompile fixes

---
 lib/cudnn/CUDNN.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/cudnn/CUDNN.jl b/lib/cudnn/CUDNN.jl
index 93b0b1fb63..3a9bce21c7 100644
--- a/lib/cudnn/CUDNN.jl
+++ b/lib/cudnn/CUDNN.jl
@@ -12,7 +12,7 @@ using ..APIUtils
 using ..CUDA
 using ..CUDA: CUstream, libraryPropertyType
 using ..CUDA: libcudnn, @retry_reclaim, isdebug, @context!
-using ..CUDA: BFloat16s
+using BFloat16s
 
 using CEnum: @cenum
 