Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #63 from XtractOpen/fasterConvFFT
updated convFFT to be more efficient (but still slower than GEMM).
- Loading branch information
Showing
3 changed files
with
195 additions
and
133 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
using MAT | ||
using Meganet | ||
|
||
# Benchmark script: apply the GEMM-based and the FFT-based convolution kernels
# to the same random input, time both, and print their relative difference.
nImg = [32, 32]          # spatial image size
sK   = [3, 3, 32, 32]    # kernel size; the last two entries are used as channel counts
nex  = 100               # number of examples

K1 = getConvGEMMKernel(Float64, nImg, sK)
K2 = getConvFFTKernel(Float64, nImg, sK)

theta = randn(nTheta(K1))

# Random images whose one-pixel border stays zero.
# NOTE(review): presumably the zero border makes the two kernels' boundary
# treatments agree -- confirm against the kernel implementations.
Y = zeros(nImg..., sK[3], nex)
# Y is 4-D, so all four dimensions are indexed explicitly (no partial linear
# indexing), and the scalar is subtracted from the size vector element-wise.
Y[2:end-1, 2:end-1, :, :] = randn((nImg .- 2)..., sK[3], nex)

# Each product is evaluated once before timing so that @time does not include
# first-call compilation.
t1 = Amv(K1, theta, Y)
@time t1 = Amv(K1, theta, Y)

t2 = Amv(K2, theta, Y)
@time t2 = Amv(K2, theta, Y)

# relative difference between the two results
println(norm(t1[:] - t2[:]) / norm(t1[:]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,126 +1,169 @@ | ||
export convFFTKernel, getEigs,getConvFFTKernel | ||
## For the functions nImgIn, nImgOut, nFeatIn, nFeatOut, nTheta, getOp, initTheta : see AbstractConvKernel.jl | ||
## All convKernel types are assumed to have fields nImg and sK | ||
# FFT-based convolution kernel.
#   nImg : image size
#   sK   : kernel size
#   S    : matrix filled by getEigs; column k holds the vectorised 2-D FFT of
#          the first column of the periodic convolution matrix for stencil entry k
mutable struct convFFTKernel{T} <: AbstractConvKernel{T}
    nImg :: Array{Int,1}
    sK   :: Array{Int,1}
    S    :: Array{Complex{T},2}
end

# Build a convFFTKernel for element type TYPE; S is precomputed through getEigs
# with element type Complex{TYPE}.
function getConvFFTKernel(TYPE::Type,nImg,sK)
    eigs = getEigs(Complex{TYPE},nImg,sK)
    kernel = convFFTKernel{TYPE}(nImg,sK,eigs)
    return kernel
end
|
||
# For every entry k of the sK[1] x sK[2] stencil, build the periodic convolution
# matrix of the unit stencil e_k, take its first column, and store the 2-D FFT
# of that column (reshaped to the image size) as column k of the result.
# Returns a prod(nImg) x prod(sK[1:2]) array with element type TYPE.
function getEigs(TYPE,nImg,sK)
    nStencil = prod(sK[1:2])
    S = zeros(TYPE,prod(nImg),nStencil)
    for k in 1:nStencil
        # unit stencil: a single one at linear position k
        unitStencil = zeros(sK[1],sK[2])
        unitStencil[k] = 1
        convMat  = getConvMatPeriodic(TYPE,unitStencil,[nImg[1],nImg[2], 1])
        firstCol = full(convert(Array{TYPE},convMat[:,1]))
        S[:,k] = vec(fft2(reshape(firstCol,nImg[1],nImg[2])))
    end
    return S
end
|
||
export Amv | ||
# Apply the convolution operator defined by the weights `theta` to the examples
# stored in `Y`.
#   theta : kernel weights, reshaped internally to
#           prod(sK[1:2]) x sK[3] x sK[4]
#   Y     : input features; the number of examples is
#           numel(Y) / prod(nImgIn(this))
# Returns a real array reshaped to (:, nex).
function Amv(this::convFFTKernel{T},theta::Array{T},Y::Array{T}) where {T<:Number}
    nex  = div(numel(Y),prod(nImgIn(this)))
    nOut = nImgOut(this)

    AY = zeros(Complex{T},tuple([nOut; nex]...))
    th = reshape(theta, tuple([prod(this.sK[1:2]); this.sK[3:4]]...))
    Yh = ifft2(reshape(Y,tuple([nImgIn(this); nex]...)))

    # Accumulator reused by hadamardSum. Its third dimension is fixed to 1 by
    # building the size explicitly, instead of overwriting an entry of the
    # vector returned by nImgOut.
    sumT = zeros(Complex{T},nOut[1],nOut[2],1,nex)

    for k=1:this.sK[4]
        # Sk is created fresh in every iteration, so no buffer is allocated for
        # it up front.
        Sk   = reshape(this.S*th[:,:,k],tuple(nImgIn(this)...))
        sumT = hadamardSum(sumT,Yh,Sk)
        AY[:,:,k,:] = sumT[:,:,1,:]
    end

    AY = real(fft2(AY))
    return reshape(AY,:,nex)
end
|
||
# Counterpart of Amv with the roles of the two channel dimensions swapped:
# the loop runs over sK[3], uses the weight slice theta[:,k,:], and applies
# fft2 to the input and ifft2 to the result.
#   theta : kernel weights, reshaped internally to
#           prod(sK[1:2]) x sK[3] x sK[4]
#   Z     : array whose number of examples is numel(Z) / prod(nImgOut(this))
# Returns a real array reshaped to (:, nex).
function ATmv(this::convFFTKernel{T},theta::Array{T},Z::Array{T}) where {T<:Number}
    nex  = div(numel(Z),prod(nImgOut(this)))
    nOut = nImgOut(this)

    ATY = zeros(Complex{T},tuple([nImgIn(this); nex]...))
    th  = reshape(theta, prod(this.sK[1:2]),this.sK[3],this.sK[4])

    # Accumulator reused by hadamardSum. Its third dimension is fixed to 1 by
    # building the size explicitly, instead of overwriting an entry of the
    # vector returned by nImgOut.
    sumT = zeros(Complex{T},nOut[1],nOut[2],1,nex)

    Zh = fft2(reshape(Z,tuple([nOut; nex]...)))
    for k=1:this.sK[3]
        # Sk is created fresh in every iteration, so no buffer is allocated for
        # it up front.
        Sk   = reshape(this.S*th[:,k,:],tuple(nOut...))
        sumT = hadamardSum(sumT,Zh,Sk)
        ATY[:,:,k,:] = sumT[:,:,1,:]
    end

    ATY = real(ifft2(ATY))
    return reshape(ATY,:,nex)
end
|
||
# Jacobian-vector product with respect to the weights: reshapes Y to (:, nex)
# and returns Amv(this, dtheta, Y). `dummy` and `temp` are not used.
function Jthetamv(this::convFFTKernel{T},dtheta::Array{T},dummy::Array{T},Y::Array{T},temp=nothing) where {T<:Number}
    nex  = div(numel(Y),nFeatIn(this))
    Ymat = reshape(Y,:,nex)
    return Amv(this,dtheta,Ymat)
end
|
||
# Derivative of Z*(A(theta)*Y) with respect to theta.
#   Z     : reshaped internally to [nImgOut(this); nex]
#   dummy : not used
#   Y     : reshaped internally to [nImgIn(this); nex]
# Returns an array of size sK with element type T.
function JthetaTmv(this::convFFTKernel{T},Z::Array{T},dummy::Array{T},Y::Array{T}) where {T<:Number}
    nex = div(numel(Y),nFeatIn(this))

    # Allocate the result with the kernel's element type so it is not silently
    # promoted to Float64 when T is a different type.
    dth1 = zeros(T,this.sK[1]*this.sK[2],this.sK[3],this.sK[4])

    # move the example dimension in front of the channel dimension
    Yp = permutedims(reshape(Y,tuple([nImgIn(this); nex ]...)),[1,2,4,3])
    Yh = reshape(fft2(Yp),prod(this.nImg[1:2]),nex*this.sK[3])
    Zh = permutedims(ifft2(reshape(Z,tuple([nImgOut(this); nex]...))),[1,2,4,3])
    Zh = reshape(Zh,:, this.sK[4])

    for k=1:prod(this.sK[1:2])
        # scale every column of Yh by the conjugate of column k of S
        SY = conj(this.S[:,k]) .* Yh
        SY = reshape(SY,:,this.sK[3])
        dth1[k,:,:] = real(conj(SY)'*Zh)
    end

    return reshape(dth1,tuple(this.sK...))
end
|
||
# Overwrite sumT with
#   sumT[i1,i2,1,i4] = sum over i3 of Sk[i1,i2,i3] * Yh[i1,i2,i3,i4]
# and return it. sumT is zeroed first, so its previous contents are discarded.
function hadamardSum(sumT::Array{T},Yh::Array{T},Sk::Array{T}) where {T<:Number}
    fill!(sumT, 0.0)
    n1, n2, n3, n4 = size(Yh,1), size(Yh,2), size(Yh,3), size(Yh,4)
    # leftmost range is the outermost loop; i1 runs fastest (column-major)
    for i4 = 1:n4, i3 = 1:n3, i2 = 1:n2, i1 = 1:n1
        @inbounds sumT[i1,i2,1,i4] += Sk[i1,i2,i3] * Yh[i1,i2,i3,i4]
    end
    return sumT
end
export convFFTKernel, getConvFFTKernel | ||
## For the functions nImgIn, nImgOut, nFeatIn, nFeatOut, nTheta, getOp, initTheta : see AbstractConvKernel.jl | ||
## All convKernel types are assumed to have fields nImg and sK | ||
# FFT-based convolution kernel.
#   nImg : image size
#   sK   : kernel size
#   Kp   : zero-padded kernel buffer; empty until getKp fills it
#   I    : linear indices into Kp, one per kernel weight; empty until getKp
#          fills it
mutable struct convFFTKernel{T} <: AbstractConvKernel{T}
    nImg :: Array{Int,1}
    sK   :: Array{Int,1}
    Kp   :: Array{T}
    I    :: Array{Int}
end

# Create a convFFTKernel with element type TYPE. Kp and I start out empty.
function getConvFFTKernel(TYPE::Type,nImg,sK)
    emptyKp  = (TYPE)[]
    emptyIdx = (Int64)[]
    return convFFTKernel{TYPE}(nImg,sK,emptyKp,emptyIdx)
end
|
||
# Set up (once) and return the padded kernel buffer together with the linear
# indices of the entries that carry kernel weights.
#
# On the first call (this.Kp empty) this
#   1. places the markers 1..prod(sK) in the top-left sK[1] x sK[2] corner of
#      an nImg[1] x nImg[2] x prod(sK[3:4]) array,
#   2. circularly shifts the array by 1 - center, where center = (sK[1:2]+1)/2,
#   3. stores the positions of the markers, ordered by marker value, in this.I,
#   4. stores an all-zero array of the same size in this.Kp.
# Later calls return the cached this.Kp and this.I unchanged.
#
# The markers are integers rather than values of type T, so the ordering
# stays exact even when prod(sK) exceeds the range of integers that T can
# represent exactly.
#
# NOTE(review): for even stencil sizes the center is fractional; the shift is
# then presumably rejected by circshift -- confirm.
function getKp(this::convFFTKernel{T}) where {T<:Number}
    if isempty(this.Kp)
        nWeights = prod(this.sK)
        marker = zeros(Int,this.nImg[1],this.nImg[2],prod(this.sK[3:4]))
        marker[1:this.sK[1],1:this.sK[2],:] =
            reshape(collect(1:nWeights),this.sK[1],this.sK[2],prod(this.sK[3:4]))

        center = (this.sK[1:2] .+ 1) ./ 2
        marker = circshift(marker,1 .- center)

        I = find(marker)
        this.Kp = zeros(T,size(marker))
        # order the positions by marker value, i.e. by weight number
        this.I  = I[sortperm(marker[I])]
    end
    return this.Kp,this.I
end
# Get the first columns of the convolution matrices: write the weights `theta`
# into the padded kernel at the positions returned by getKp and return the
# padded kernel. The returned array is the buffer handed out by getKp, so it
# is overwritten by the next call.
function getK1(this::convFFTKernel{T},theta::Array{T}) where {T<:Number}
    padded, positions = getKp(this)
    for (k, idx) in enumerate(positions)
        padded[idx] = theta[k]
    end
    return padded
end
|
||
# methods for A*x | ||
# Overwrite Zkh with
#   Zkh[i1,i2] = sum over i3 of S[i1,i2,i3,k] * Yh[i1,i2,i3]
# (plain products, no conjugation) and return Zkh.
#   Zkh : output buffer; previous contents are discarded
#   S   : 4-D array; the slice S[:,:,:,k] is used
#   Yh  : 3-D array
#   k   : index into the fourth dimension of S
function multRed!(Zkh::Array{Complex{T},2},S::Array{Complex{T},4},Yh::Array{Complex{T},3},k::Int) where {T<:Number}
    # reset with a zero of the buffer's own element type (not a fixed
    # double-precision complex constant)
    fill!(Zkh, zero(Complex{T}))
    for i3=1:size(Yh,3), i2=1:size(Zkh,2), i1=1:size(Zkh,1)
        @inbounds Zkh[i1,i2] += S[i1,i2,i3,k]*Yh[i1,i2,i3]
    end
    return Zkh
end
|
||
# 2-D convolution for a single example.
#   Z   : output; slice Z[:,:,k] is written for k = 1..sK[4]
#   S   : kernel array passed on to multRed!
#   Y   : input for this example
#   Yh  : complex work buffer; overwritten with Y and transformed by ifft2!
#   Zkh : complex work buffer for one output slice
# Returns Z.
function Amv!(this::convFFTKernel{T},Z::AbstractArray{T,3},S::Array{Complex{T}},Y::AbstractArray{T,3},Yh::Array{Complex{T},3},Zkh::Array{Complex{T},2}) where {T<:Number}
    # load the example into the complex buffer and transform it
    Yh[:] = Y
    ifft2!(Yh)

    for chan in 1:this.sK[4]
        multRed!(Zkh,S,Yh,chan)
        Z[:,:,chan] = real(fft2!(Zkh))
    end
    return Z
end
# Allocate an output of size [nImgOut(this); nex], fill it through Amv!, and
# return it.
# NOTE(review): the array returned here keeps the [nImgOut(this); nex] shape,
# whereas Amv! returns a (:, nex) reshape of the same data -- confirm which
# shape callers expect.
function Amv(this::convFFTKernel{T},theta::Array{T},Y::Array{T}) where {T<:Number}
    nex     = div(length(Y),prod(nImgIn(this)))
    outDims = tuple([nImgOut(this); nex]...)
    Z = zeros(T,outDims)
    Amv!(this,Z,theta,Y)
    return Z
end
# Convolve all examples in Y with the kernel given by `theta`, writing the
# result into Z.
#   Z     : output buffer with prod(nImgOut(this))*nex entries; any shape with
#           that many entries is accepted (it is reshaped to
#           [nImgOut(this); nex] internally; the data is shared)
#   theta : kernel weights, passed to getK1
#   Y     : input with prod(nImgIn(this))*nex entries
# Returns the result reshaped to (:, nex).
function Amv!(this::convFFTKernel{T},Z::Array{T},theta::Array{T},Y::Array{T}) where {T<:Number}
    nex = div(length(Y),prod(nImgIn(this)))
    Y = reshape(Y,tuple([nImgIn(this);nex]...))
    # give Z the 4-D shape that the per-example views below index into, so a
    # (:, nex) buffer such as the one this function returns can be passed back
    Z = reshape(Z,tuple([nImgOut(this);nex]...))

    # work buffers shared by all examples
    Ykh = zeros(Complex{T},this.nImg[1],this.nImg[2],this.sK[3])
    Zik = zeros(Complex{T},this.nImg[1],this.nImg[2])

    # transformed kernel
    S = reshape(fft2(getK1(this,theta)),this.nImg[1],this.nImg[2],this.sK[3],this.sK[4])

    # convolve one example at a time
    for k=1:nex
        Amv!(this,view(Z,:,:,:,k),S,view(Y,:,:,:,k),Ykh,Zik)
    end
    return reshape(Z,:,nex)
end
|
||
# methods for A'*x | ||
# Overwrite Ykh with
#   Ykh[i1,i2] = sum over i3 of S[i1,i2,j,i3] * Zh[i1,i2,i3]
# and return Ykh. Previous contents of Ykh are discarded.
function multRedT!(Ykh::Array{Complex{T},2},S::Array{Complex{T},4},Zh::Array{Complex{T},3},j::Int) where {T<:Number}
    fill!(Ykh, Complex{T}(0.0))
    nRow, nCol = size(Ykh,1), size(Ykh,2)
    for c = 1:size(Zh,3), q = 1:nCol, p = 1:nRow
        @inbounds Ykh[p,q] += S[p,q,j,c] * Zh[p,q,c]
    end
    return Ykh
end
|
||
# Transposed 2-D convolution for a single example.
#   Y   : output; slice Y[:,:,j] is written for j = 1..sK[3]
#   S   : kernel array passed on to multRedT!
#   Z   : input for this example
#   Zh  : complex work buffer; overwritten with Z and transformed by fft2!
#   Ykh : complex work buffer for one output slice
# Returns Y, mirroring the single-example Amv! method, which returns its
# output buffer.
function ATmv!(this::convFFTKernel{T},Y::AbstractArray{T,3},S::Array{Complex{T}},Z::AbstractArray{T,3},Zh::Array{Complex{T},3},Ykh::Array{Complex{T},2}) where {T<:Number}
    # load the example into the complex buffer and transform it
    Zh[:] = Z
    Zh = fft2!(Zh)

    for j=1:this.sK[3]
        multRedT!(Ykh,S,Zh,j)
        Y[:,:,j] = real(ifft2!(Ykh))
    end
    return Y
end
# Allocate an output of size [nImgIn(this); nex], fill it through ATmv!, and
# return it.
# NOTE(review): the array returned here keeps the [nImgIn(this); nex] shape,
# whereas ATmv! returns a (:, nex) reshape of the same data -- confirm which
# shape callers expect.
function ATmv(this::convFFTKernel{T},theta::Array{T},Z::Array{T}) where {T<:Number}
    nex     = div(length(Z),prod(nImgOut(this)))
    outDims = tuple([nImgIn(this); nex]...)
    Y = zeros(T,outDims)
    ATmv!(this,Y,theta,Z)
    return Y
end
|
||
# Apply the transposed convolution to all examples in Z, writing the result
# into Y.
#   Y     : output buffer with prod(nImgIn(this))*nex entries; any shape with
#           that many entries is accepted (it is reshaped to
#           [nImgIn(this); nex] internally; the data is shared)
#   theta : kernel weights, passed to getK1
#   Z     : input, reshaped to [nImgOut(this); nex]
# Returns the result reshaped to (:, nex).
function ATmv!(this::convFFTKernel{T},Y::Array{T},theta::Array{T},Z::Array{T}) where {T<:Number}
    nex = div(length(Y),prod(nImgIn(this)))
    Z = reshape(Z,tuple([nImgOut(this);nex]...))
    # give Y the 4-D shape that the per-example views below index into, so a
    # (:, nex) buffer such as the one this function returns can be passed back
    Y = reshape(Y,tuple([nImgIn(this);nex]...))

    # work buffers shared by all examples
    Zkh = zeros(Complex{T},this.nImg[1],this.nImg[2],this.sK[4])
    Yik = zeros(Complex{T},this.nImg[1],this.nImg[2])

    # transformed kernel
    S = reshape(fft2(getK1(this,theta)),this.nImg[1],this.nImg[2],this.sK[3],this.sK[4])

    # process one example at a time
    for k=1:nex
        ATmv!(this,view(Y,:,:,:,k),S,view(Z,:,:,:,k),Zkh,Yik)
    end
    return reshape(Y,:,nex)
end
|
||
# Jacobian-vector product with respect to the weights: simply Amv evaluated
# with the weights dtheta. `dummy` and `temp` are not used.
function Jthetamv(this::convFFTKernel{T},dtheta::Array{T},dummy::Array{T},Y::Array{T},temp=nothing) where {T<:Number}
    JdthetaY = Amv(this,dtheta,Y)
    return JdthetaY
end
|
||
# Transposed Jacobian-vector product with respect to the weights.
#   Z     : reshaped to nImg[1] x nImg[2] x sK[4] x nex
#   dummy : not used
#   Y     : reshaped to nImg[1] x nImg[2] x sK[3] x nex
# Returns a vector with one entry per kernel weight, taken at the positions
# provided by getKp.
function JthetaTmv(this::convFFTKernel{T},Z::Array{T},dummy::Array{T},Y::Array{T}) where {T<:Number}
    nex = div(length(Y),nFeatIn(this))

    Y = reshape(Y,this.nImg[1],this.nImg[2],this.sK[3],nex)
    Z = reshape(Z,this.nImg[1],this.nImg[2],this.sK[4],nex)

    # Fetch the weight positions through getKp instead of reading this.I
    # directly: the field is empty until getKp has run, and indexing with an
    # empty index set would silently return an empty result.
    I = getKp(this)[2]

    # work buffers
    Yh = zeros(Complex{T},this.nImg[1],this.nImg[2],this.sK[3])
    Zh = zeros(Complex{T},this.nImg[1],this.nImg[2],this.sK[4])
    tt = zeros(Complex{T},this.nImg[1],this.nImg[2],this.sK[3],this.sK[4])

    for k=1:nex
        # load example k without creating temporary copies of the slices
        Yh[:] = view(Y,:,:,:,k)
        Zh[:] = view(Z,:,:,:,k)

        ifft2!(Yh)
        fft2!(Zh)

        # accumulate the products of all channel pairs over the examples
        for i4=1:this.sK[4], i3=1:this.sK[3], i2=1:this.nImg[2], i1=1:this.nImg[1]
            @inbounds tt[i1,i2,i3,i4] += Yh[i1,i2,i3]*Zh[i1,i2,i4]
        end
    end

    # separate name for the real result so tt keeps its complex element type
    dKp = real(fft2(reshape(tt,this.nImg[1],this.nImg[2],:)))
    return dKp[I]
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters