Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Try #181:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Nov 6, 2018
2 parents 1a5e239 + 2a5c89f commit 2d9011a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
18 changes: 9 additions & 9 deletions src/dnn/nnlib.jl
Expand Up @@ -73,58 +73,58 @@ function conv_workspace(bytes)
end

"""
    conv!(y, x, w; pad=0, stride=1, flipkernel=0, alpha=1, dilation=1,
          workspace=nothing, algo=0)

Run a cuDNN forward convolution of input `x` with kernel `w`, writing the
result into `y` (which is also returned by the underlying cuDNN call).

`flipkernel` is forwarded as cuDNN's `mode` argument — presumably selecting
convolution vs. cross-correlation; confirm against the cuDNN binding.
`pad`/`stride`/`dilation` are forwarded unchanged. If `workspace` is
`nothing`, a scratch buffer of the size cuDNN requests is allocated via
`conv_workspace`; otherwise the caller-supplied buffer is used.

Raises an error when `dilation != 1` on cuDNN < 6, which lacks dilated
convolution support.
"""
function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T};
               pad=0, stride=1, flipkernel=0, alpha=1, dilation=1,
               workspace::Union{CuVector, Nothing}=nothing, algo=0) where T<:CUDNNFloat
    # Dilated convolutions only exist from cuDNN 6 onwards.
    if CUDNN_VERSION < 6000
        all(x -> x == 1, dilation) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if workspace === nothing
        # Ask cuDNN how much scratch space this algo needs; allocate only if non-zero.
        workspace_size =
            cudnnGetConvolutionForwardWorkspaceSize(y, x, w, padding=pad, stride=stride,
                                                    dilation=dilation, algo=algo, mode=flipkernel)
        workspace = workspace_size != 0 ? conv_workspace(workspace_size) : workspace
    else
        workspace_size = length(workspace[])
    end
    cudnnConvolutionForward(y, x, w, padding=pad, stride=stride, dilation=dilation,
                            mode=flipkernel, alpha=alpha, algo=algo, workspace=workspace,
                            workspace_size=workspace_size)
end

"""
    ∇conv_filter!(dw, dy, x, w; pad=0, stride=1, flipkernel=0, alpha=1,
                  dilation=1, workspace=nothing, algo=0)

Compute the gradient of a convolution with respect to the kernel, writing it
into `dw`, given output gradient `dy`, input `x`, and kernel `w`.

`flipkernel` is forwarded as cuDNN's `mode` argument; other keywords are
forwarded unchanged. If `workspace` is `nothing`, a scratch buffer of the
size cuDNN requests is allocated; otherwise the supplied buffer is used.

Raises an error when `dilation != 1` on cuDNN < 6.
"""
function ∇conv_filter!(dw::CuArray{T}, dy::CuArray{T}, x::CuArray{T}, w::CuArray{T};
                       pad=0, stride=1, flipkernel=0, alpha=1, dilation=1,
                       workspace::Union{CuVector, Nothing}=nothing, algo=0) where T<:CUDNNFloat
    # Dilated convolutions only exist from cuDNN 6 onwards.
    if CUDNN_VERSION < 6000
        all(x -> x == 1, dilation) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if workspace === nothing
        # Ask cuDNN how much scratch space this algo needs; allocate only if non-zero.
        workspace_size =
            cudnnGetConvolutionBackwardFilterWorkspaceSize(dw, x, w, dy, padding=pad, stride=stride,
                                                           dilation=dilation, algo=algo, mode=flipkernel)
        workspace = workspace_size != 0 ? conv_workspace(workspace_size) : workspace
    else
        workspace_size = length(workspace[])
    end
    cudnnConvolutionBackwardFilter(dw, x, w, dy, padding=pad, stride=stride, dilation=dilation,
                                   mode=flipkernel, alpha=alpha, algo=algo, workspace=workspace,
                                   workspace_size=workspace_size)
end

"""
    ∇conv_data!(dx, dy, x, w; pad=0, stride=1, flipkernel=0, alpha=1,
                dilation=1, workspace=nothing, algo=0)

Compute the gradient of a convolution with respect to the input, writing it
into `dx`, given output gradient `dy`, input `x`, and kernel `w`.

`flipkernel` is forwarded as cuDNN's `mode` argument; other keywords are
forwarded unchanged. If `workspace` is `nothing`, a scratch buffer of the
size cuDNN requests is allocated; otherwise the supplied buffer is used.

Raises an error when `dilation != 1` on cuDNN < 6.
"""
function ∇conv_data!(dx::CuArray{T}, dy::CuArray{T}, x::CuArray{T}, w::CuArray{T};
                     pad=0, stride=1, flipkernel=0, alpha=1, dilation=1,
                     workspace::Union{CuVector, Nothing}=nothing, algo=0) where T<:CUDNNFloat
    # Dilated convolutions only exist from cuDNN 6 onwards.
    if CUDNN_VERSION < 6000
        all(x -> x == 1, dilation) || error("Only dilation = 1 is supported in cuDNN version < 6")
    end
    if workspace === nothing
        # Ask cuDNN how much scratch space this algo needs; allocate only if non-zero.
        workspace_size =
            cudnnGetConvolutionBackwardDataWorkspaceSize(dx, x, w, dy, padding=pad, stride=stride,
                                                         dilation=dilation, algo=algo, mode=flipkernel)
        workspace = workspace_size != 0 ? conv_workspace(workspace_size) : workspace
    else
        workspace_size = length(workspace[])
    end
    cudnnConvolutionBackwardData(dx, x, w, dy, padding=pad, stride=stride, dilation=dilation,
                                 mode=flipkernel, alpha=alpha, algo=algo, workspace=workspace,
                                 workspace_size=workspace_size)
end

Expand Down
8 changes: 8 additions & 0 deletions test/dnn.jl
Expand Up @@ -15,6 +15,10 @@
# Dilated-convolution gradients (cuDNN < 6 lacks dilation support per src/dnn/nnlib.jl).
@test testf(∇conv_data, rand(Float64, 8, 8, 4, 1), rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4); dilation=2)
@test testf(∇conv_filter, rand(Float64, 8, 8, 4, 1), rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4); dilation=2)

# flipkernel=1 paths, exercised alongside NNlib.crosscor — presumably the
# cross-correlation mode of the cuDNN wrappers; confirm against NNlib semantics.
@test testf(NNlib.crosscor, rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4))
@test testf(∇conv_data, rand(Float64, 9, 9, 4, 1), rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4); flipkernel=1)
@test testf(∇conv_filter, rand(Float64, 9, 9, 4, 1), rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4); flipkernel=1)

# Explicit non-default algorithm selection (algo=1) on the in-place wrappers
# should complete without emitting warnings.
@test_nowarn NNlib.conv!(cu(zeros(Float64, 9, 9, 3, 1)), cu(rand(Float64, 10, 10, 1, 1)), cu(rand(Float64, 2, 2, 1, 3)), algo=1)
@test_nowarn NNlib.∇conv_data!(cu(zeros(Float64, 10, 10, 1, 1)), cu(ones(Float64, 9, 9, 3, 1)), cu(rand(Float64, 10, 10, 1, 1)), cu(rand(Float64, 2, 2, 1, 3)), algo=1)
@test_nowarn NNlib.∇conv_filter!(cu(zeros(Float64, 2, 2, 1, 3)), cu(ones(Float64, 9, 9, 3, 1)), cu(rand(Float64, 10, 10, 1, 1)), cu(rand(Float64, 2, 2, 1, 3)), algo=1)
Expand All @@ -27,6 +31,10 @@
# 5-d (volumetric) dilated-convolution gradients.
@test testf(∇conv_data, rand(Float64, 8, 8, 8, 4, 1), rand(Float64, 10, 10, 10, 3, 1), rand(Float64, 2, 2, 2, 3, 4); dilation=2)
@test testf(∇conv_filter, rand(Float64, 8, 8, 8, 4, 1), rand(Float64, 10, 10, 10, 3, 1), rand(Float64, 2, 2, 2, 3, 4); dilation=2)

# 5-d flipkernel=1 paths, mirroring the 4-d crosscor tests above — presumably
# cross-correlation mode; confirm against NNlib semantics.
@test testf(NNlib.crosscor, rand(Float64, 10, 10, 10, 3, 1), rand(Float64, 2, 2, 2, 3, 4))
@test testf(∇conv_data, rand(Float64, 9, 9, 9, 4, 1), rand(Float64, 10, 10, 10, 3, 1), rand(Float64, 2, 2, 2, 3, 4); flipkernel=1)
@test testf(∇conv_filter, rand(Float64, 9, 9, 9, 4, 1), rand(Float64, 10, 10, 10, 3, 1), rand(Float64, 2, 2, 2, 3, 4); flipkernel=1)

# Pooling forward passes and the max-pool gradient on 4-d inputs.
@test testf(x -> maxpool(x, (2,2)), rand(Float64, 10, 10, 3, 1))
@test testf(x -> meanpool(x, (2,2)), rand(Float64, 10, 10, 3, 1))
@test testf((x, dy) -> ∇maxpool(dy, maxpool(x, (2,2)), x, (2,2)), rand(Float64, 10, 10, 3, 1), rand(Float64, 5, 5, 3, 1))
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Expand Up @@ -13,7 +13,7 @@ if haskey(ENV, "GITLAB_CI")
end

branch = ENV["CI_COMMIT_REF_NAME"]
# Check out the branch of each sibling dependency that matches this CI branch
# (NNlib added by this commit), so tests run against in-flight changes rather
# than the registered releases.
for package in ("GPUArrays", "CUDAnative", "NNlib")
    match_package(package, branch)
end
end
Expand Down

0 comments on commit 2d9011a

Please sign in to comment.