Skip to content

Commit

Permalink
conditional choice of fast modes
Browse files Browse the repository at this point in the history
  • Loading branch information
gartangh committed Jul 27, 2020
1 parent 009143c commit a73a069
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
5 changes: 3 additions & 2 deletions lib/cudnn/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ function ConvDesc(T, N, padding, stride, dilation, mode)
else
cudnnSetConvolutionNdDescriptor(cd[],N,cdsize(padding,N),cdsize(stride,N),cdsize(dilation,N),mode)
end
if version() >= v"7"
cudnnSetConvolutionMathType(cd[], cudnnMathType_t(1))
# enable tensor math mode if our device supports it, and fast math is enabled
if Base.JLOptions().fast_math == 1 && capability(device()) >= v"7.0" && version() >= v"9"
cudnnSetConvolutionMathType(cd[], CUDNN_TENSOR_OP_MATH)
end
this = ConvDesc(cd[])
finalizer(unsafe_free!, this)
Expand Down
8 changes: 7 additions & 1 deletion lib/cudnn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ function softmax(xs::CuVecOrMat{T}; dims=1) where T<:CUDNNFloat
end

function softmax!(out::CuVecOrMat{T}, xs::CuVecOrMat{T}; dims=1) where T<:CUDNNFloat
cudnnSoftmaxForward(reshape4D(xs), reshape4D(out), algorithm=CUDNN_SOFTMAX_FAST, mode=cudnnSoftmaxMode_t(dims-1))
# use fast over accurate algorithm if fast math is enabled
if Base.JLOptions().fast_math == 1
algorithm = CUDNN_SOFTMAX_FAST
else
algorithm = CUDNN_SOFTMAX_ACCURATE
end
cudnnSoftmaxForward(reshape4D(xs), reshape4D(out), algorithm=algorithm, mode=cudnnSoftmaxMode_t(dims-1))
return out
end

Expand Down

0 comments on commit a73a069

Please sign in to comment.