Skip to content

Commit

Permalink
added test for convGEMMKernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
lruthotto committed Jan 27, 2018
1 parent cd791b2 commit 6dc2e4e
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/kernelTypes/convGEMMKernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ function Amv(this::convGEMMKernel{T},theta::Array{T},Y::Array{T}) where {T}
nex = div(numel(Y),prod(nImgIn(this)))
# compute convolution
Y = reshape(Y,nImg[1],nImg[2],this.sK[3],nex);
AY = zeros(eltype(Y),nImg[1]*nImg[2],this.sK[4],nex);
aux = zeros(eltype(Y),nImg[1],nImg[2],this.sK[3]);
AYk = zeros(eltype(Y),nImg[1]*nImg[2],this.sK[4]);
AY = zeros(T,nImg[1]*nImg[2],this.sK[4],nex);
aux = zeros(T,nImg[1],nImg[2],this.sK[3]);
AYk = zeros(T,nImg[1]*nImg[2],this.sK[4]);
### reshape the kernels for gemm!:
K = reshape(theta,tuple(sK...));
KK = Array{Array{T,2}}(sK[1],sK[2]);
Expand All @@ -28,11 +28,11 @@ function Amv(this::convGEMMKernel{T},theta::Array{T},Y::Array{T}) where {T}
end
shiftX = [0;-1;0;0;1;0];
shiftT = [1;0;0;0;0;-1];

for k = 1:nex
AYk = multConv2Dblock(Y,KK, AYk,aux,shiftX,shiftT,k);
@inbounds AY[:,:,k] = AYk;
AYk[:] = 0.0;
AYk[:] = zero(T)
end
AY = reshape(AY,:,nex);
return AY
Expand All @@ -47,7 +47,7 @@ function ATmv(this::convGEMMKernel{T},theta::Array{T},Z::Array{T}) where {T}
aux = zeros(T,nImg[1],nImg[2],sK[4]);
ATZ = zeros(T,nImg[1]*nImg[2],sK[3],nex);
ATZk = zeros(T,nImg[1]*nImg[2],sK[3]);

### reshape the kernels for gemm!:
KK = Array{Array{T,2}}(sK[1],sK[2]);
for k1 = 1:sK[1]
Expand All @@ -62,20 +62,20 @@ function ATmv(this::convGEMMKernel{T},theta::Array{T},Z::Array{T}) where {T}
for k = 1:nex
ATZk = multConv2Dblock(Z,KK, ATZk,aux,shiftX,shiftT,k);
@inbounds ATZ[:,:,k] = ATZk;
ATZk[:] = 0.0;
ATZk[:] = zero(T)
end
ATZ = reshape(ATZ,:,nex);
return ATZ
end

function Jthetamv(this::convGEMMKernel{T},dtheta::Array{T},dummy,Y::Array{T},temp=nothing) where {T}
nex = div(numel(Y),nFeatIn(this));
Z = Amv(this,dtheta,Y);
return Z
end

function JthetaTmv(this::convGEMMKernel{T},Z::Array{T},dummy,Y::Array{T}) where {T}
# derivative of Z*(A(theta)*Y) w.r.t. theta
# derivative of Z*(A(theta)*Y) w.r.t. theta
sK = this.sK;
nImg = this.nImg;
nex = div(numel(Y),prod(nImgIn(this)))
Expand Down Expand Up @@ -156,7 +156,7 @@ for p = 1:2:2*kernelWidth
if it <= nImg1
@inbounds t[it:nImg1,jt,cc] = 0.0;
end
jt+=1;jx+=1;
jt+=1;jx+=1;
end
if jt <= nImg2
@inbounds t[:,jt:nImg2,cc] = 0.0;
Expand Down Expand Up @@ -187,8 +187,8 @@ function transposeTest()
ATZ = ATmv(Kernel2,K,Z);
println(vecdot(Z,AY));
println(vecdot(ATZ,Y));

println(vecdot(Z,Jthetamv(Kernel2,K,[],Y)));
println(vecdot(K,JthetaTmv(Kernel2,Z,[],Y)));

end
35 changes: 35 additions & 0 deletions test/kernel/convGEMMKernelTest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using Base.Test
using Meganet
using LinearOperators


nImg = [8,10]
sK = [3,3,4,4]
for TYPE=[Float64,Float32]
K = getConvGEMMKernel(TYPE,nImg,sK)

@testset "adjoint test $TYPE" begin
theta = initTheta(K)
A = getOp(K,theta);
v = randn(TYPE,nFeatIn(K))
w = randn(TYPE,nFeatOut(K))

t1 = dot(w,A*v)
t2 = dot(v,A'*w)
# println("adjointTest t1=$t1\t t2=$t2")
@test norm(t1-t2)/norm(t1) < 1e3*eps(TYPE)
end

@testset "derivative Test" begin
th = initTheta(K);
dth = initTheta(K);
nex = 1;
Y = randn(TYPE,nFeatIn(K),nex)+nex;
Z = randn(TYPE,nFeatOut(K),nex)-nex;

t1 = vec(Z)'*vec(Jthetamv(K,dth,th,Y));
t2 = vec(dth)'*vec(JthetaTmv(K,Z,th,Y));
# println("derivativeTest t1=$t1\t t2=$t2")
@test norm(t1-t2)/norm(t2) < 1e3*eps(TYPE)
end
end
4 changes: 4 additions & 0 deletions test/kernel/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ end
@testset "convFFTKernel" begin
include("convFFTKernelTest.jl")
end

@testset "convGEMMKernel" begin
include("convFFTKernelTest.jl")
end

0 comments on commit 6dc2e4e

Please sign in to comment.