I've run into this error while trying to get Reactant to work with NFFT.jl; I had to customise NFFT a little to get there. The code below defines a React_NFFTPlan modelled on NFFT's GPU plan, adds plan_nfft/deconvolve!/mul! methods for traced arrays, and copies initParams and precomputation so the plan always uses FULL precomputation.
using Reactant
using NFFT
using LinearAlgebra
using AbstractFFTs
using Adapt
using SparseArrays                # sparse() is used in the NFFT.precomputation override below
using GPUArraysCore: @allowscalar # used in deconvolve! below
mutable struct React_NFFTPlan{T,D,arrTc,vecI,vecII,FP,BP,INV,SM} <: AbstractNFFTPlan{T,D,1}
    N::NTuple{D,Int64}
    NOut::NTuple{1,Int64}
    J::Int64
    k::Matrix{T}
    Ñ::NTuple{D,Int64}
    dims::UnitRange{Int64}
    params::NFFTParams{T}
    forwardFFT::FP
    backwardFFT::BP
    tmpVec::arrTc
    tmpVecHat::arrTc
    deconvolveIdx::vecI
    windowLinInterp::vecII
    windowHatInvLUT::INV
    B::SM
end
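For the 2-D Float64 repro at the end of this issue, the type parameters come out roughly as follows (written loosely; the exact traced array types are Reactant internals):

# Schematic instantiation for T = Float64, D = 2 (types written loosely):
#   arrTc ≈ TracedRArray{ComplexF64,2}   # tmpVec / tmpVecHat on the oversampled grid
#   vecI  ≈ TracedRArray{Int32,1}        # deconvolveIdx, linear indices into tmpVec
#   vecII =  Vector{Float64}             # windowLinInterp (empty when precompute == FULL)
#   FP = BP = Nothing                    # no precomputed FFT plans here
#   INV   ≈ TracedRArray{ComplexF64,1}   # windowHatInvLUT
#   SM    ≈ TracedRArray{ComplexF64,2}   # B, densified from the sparse matrix ("a bit hacky")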
function AbstractNFFTs.plan_nfft(arr::Type{<:Reactant.AnyTracedRArray}, k::AT, N::NTuple{D,Int}, rest...;
                                 timing::Union{Nothing,TimingStats}=nothing, kargs...) where {D, AT<:AbstractMatrix}
    t = @elapsed begin
        p = React_NFFTPlan(arr, k, N, rest...; kargs...)
    end
    if timing !== nothing
        timing.pre = t
    end
    return p
end
function React_NFFTPlan(arr, k::AbstractArray{T}, N::NTuple{D,Int};
                        fftflags=nothing, kwargs...) where {T,D}
    # The GPU NFFT path only supports transforming along all dimensions,
    # so dims is fixed to 1:D here.
    dims = 1:D
    params, N, NOut, J, Ñ, dims_ = NFFT.initParams(k, N, dims; kwargs...)
    params.storeDeconvolutionIdx = true # GPU NFFT only works this way
    params.precompute = NFFT.FULL       # GPU NFFT only works this way

    tmpVec = convert(Reactant.TracedRArray, zeros(Complex{T}, Ñ))

    FP = nothing
    BP = nothing

    windowLinInterp, windowPolyInterp, windowHatInvLUT, deconvolveIdx, B =
        NFFT.precomputation(k, N[dims_], Ñ[dims_], params)

    U = params.storeDeconvolutionIdx ? N : ntuple(d -> 0, Val(D))
    tmpVecHat = convert(Reactant.TracedRArray, zeros(Complex{T}, U))
    deconvIdx = convert(Reactant.TracedRArray, Int32.(adapt(arr, deconvolveIdx)))
    winHatInvLUT = convert(Reactant.TracedRArray, Complex{T}.(adapt(arr, windowHatInvLUT[1])))
    B_ = convert(Reactant.TracedRArray, Complex{T}.(adapt(arr, B))) # densify the sparse B; a bit hacky

    React_NFFTPlan{T, D, typeof(tmpVec), typeof(deconvIdx), typeof(windowLinInterp),
                   typeof(FP), typeof(BP), typeof(winHatInvLUT), typeof(B_)}(
        N, NOut, J, k, Ñ, dims_, params, FP, BP, tmpVec, tmpVecHat,
        deconvIdx, windowLinInterp, winHatInvLUT, B_)
end
AbstractNFFTs.size_in(p::React_NFFTPlan) = p.N
AbstractNFFTs.size_out(p::React_NFFTPlan) = p.NOut
function AbstractNFFTs.deconvolve!(p::React_NFFTPlan{T,D}, f::Reactant.AnyTracedRArray, g::Reactant.AnyTracedRArray) where {D,T}
    # Apodization correction, then scatter the result into the oversampled grid g
    # at the precomputed linear indices.
    tmp = f .* reshape(p.windowHatInvLUT, size(f))
    vtmp = vec(tmp)
    @allowscalar @inbounds g[p.deconvolveIdx] .= vtmp
    return nothing
end
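The scalar-indexed scatter above is presumably the most tracing-hostile line. A scatter-free formulation I'd expect to trace better is to encode the index map as a sparse 0/1 matrix once at plan-construction time, so the traced body is an ordinary matvec. This is only a sketch; S would be an extra plan field and is not part of the code above:

# Sketch only: S would be built in React_NFFTPlan right after NFFT.precomputation
# returns (deconvolveIdx, Ñ and T are in scope there).
S = sparse(deconvolveIdx, 1:length(deconvolveIdx),
           ones(Complex{T}, length(deconvolveIdx)),
           prod(Ñ), length(deconvolveIdx))   # prod(Ñ) × length(deconvolveIdx) scatter matrix
# The scatter in deconvolve! then becomes a plain matvec:
# vec(g) .= S * vec(tmp)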
function Base.:*(p::React_NFFTPlan{T}, f::Reactant.AnyTracedRArray; kargs...) where {T}
    fHat = similar(f, Complex{T}, size_out(p))
    mul!(fHat, p, f; kargs...)
    return fHat
end
""" in-place NFFT on the GPU"""
function LinearAlgebra.mul!(fHat::Reactant.AnyTracedRArray, p::React_NFFTPlan{T,D}, f::Reactant.AnyTracedRArray;
verbose=false, timing::Union{Nothing,TimingStats} = nothing) where {T,D}
NFFT.consistencyCheck(p, f, fHat)
fill!(p.tmpVec, zero(Complex{T}))
t1 = @elapsed @inbounds deconvolve!(p, f, p.tmpVec)
@info "Done Deconvolve"
@info size(fHat)
@info size(p.tmpVec)
@info size(p.tmpVec[1:length(fHat)])
fHat .= p.tmpVec[1:length(fHat)]
@info "Done copyto!"
return fHat
# @show size(p.tmpVec)
# tmpVec = fft(p.tmpVec)
# @info "Done FFT"
# p.tmpVec .= tmpVec
# @info "Done copyto!"
# t3 = @elapsed @inbounds convolve!(p, p.tmpVec, fHat)
# @info "Done convolve!"
# # if verbose
# # @info "Timing: deconv=$t1 fft=$t2 conv=$t3"
# # end
# # if timing != nothing
# # timing.conv = t3
# # timing.fft = t2
# # timing.deconv = t1
# # end
# @show fHat
# return fHat
end
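For reference, the full transform that the commented-out block is working towards would look roughly like this inside mul!. Note this is a sketch: convolve! is never defined for React_NFFTPlan in this issue, so the last step assumes a method analogous to NFFT's GPU convolve!.

# Sketch of the intended full pipeline (deconvolve -> FFT -> convolve);
# convolve! for React_NFFTPlan is assumed, not defined above.
fill!(p.tmpVec, zero(Complex{T}))
deconvolve!(p, f, p.tmpVec)    # apodization + scatter onto the oversampled grid
p.tmpVec .= fft(p.tmpVec)      # FFT of the oversampled grid
convolve!(p, p.tmpVec, fHat)   # gather at the nodes; a sparse matvec with p.B when precompute == FULL
return fHat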
function NFFT.nfft(k::AbstractMatrix, f::Reactant.AnyTracedRArray, args...; kwargs...)
    p = plan_nfft(typeof(f), k, size(f); kwargs...)
    return p * f
end
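For clarity, this is the call chain when compiling (all names refer to definitions above):

# Call chain during tracing:
#   @compile nfft(k, fr)
#     -> NFFT.nfft(k, f)                       # wrapper above; f is a traced array
#     -> plan_nfft(typeof(f), k, size(f))      # dispatches on Type{<:AnyTracedRArray}
#     -> React_NFFTPlan(typeof(f), k, size(f)) # runs initParams + precomputation eagerly
#     -> p * f -> mul!(fHat, p, f)             # deconvolve! plus the early-return copy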
function NFFT.initParams(k::AbstractMatrix{T},
                         N::NTuple{D,Int},
                         dims::Union{Integer,UnitRange{Int64}}=1:D;
                         kargs...) where {D,T}
    # convert dims to a unit range
    dims_ = (typeof(dims) <: Integer) ? (dims:dims) : dims

    params = NFFTParams{T,D}(; kargs...)
    m, σ, reltol = accuracyParams(; kargs...)
    params.m = m
    params.σ = σ
    params.reltol = reltol
    # Taken from NFFT3
    m2K = [1, 3, 7, 9, 14, 17, 20, 23, 24]
    K = m2K[min(m+1, length(m2K))]
    params.LUTSize = 2^K * m # ensure that LUTSize is divisible by m

    if length(dims_) != size(k,1)
        throw(ArgumentError("Nodes k have dimension $(size(k,1)) != $(length(dims_))"))
    end
    doTrafo = ntuple(d -> d ∈ dims_, Val(D))

    Ñ = ntuple(d -> doTrafo[d] ?
                    (ceil(Int, params.σ * N[d]) ÷ 2) * 2 : # ensure that Ñ[d] is an even integer
                    N[d], Val(D))

    params.σ = Ñ[dims_[1]] / N[dims_[1]]

    #params.blockSize = ntuple(d -> Ñ[d], D) # just one block
    if haskey(kargs, :blockSize)
        params.blockSize = kargs[:blockSize]
    else
        params.blockSize = ntuple(d -> NFFT._blockSize(Ñ, d), Val(D))
    end
    J = size(k, 2)

    # Calculate the output size: non-transformed dims keep their length and the
    # first transformed dim takes the number of nodes J.
    NOut = Int[]
    Mtaken = false
    for d = 1:D
        if !doTrafo[d]
            push!(NOut, N[d])
        elseif !Mtaken
            push!(NOut, J)
            Mtaken = true
        end
    end
    # Sort nodes in lexicographic order
    if params.sortNodes
        k .= sortslices(k, dims=2)
    end

    return params, N, Tuple(NOut), J, Ñ, dims_
end
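As a sanity check, for the repro at the end of this issue (N = (64, 64) and a 2×64 node matrix, with no keyword overrides, so I'm assuming NFFT's default σ = 2.0), the parameters work out to:

# Worked example, N = (64, 64), k of size 2×64, σ = 2.0 assumed:
# Ñ    = ((ceil(Int, 2.0*64) ÷ 2) * 2, (ceil(Int, 2.0*64) ÷ 2) * 2) = (128, 128)
# J    = size(k, 2) = 64
# NOut = (64,)   # both dims are transformed, so the only output dim is J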
function NFFT.precomputation(k::AbstractVecOrMat{T}, N::NTuple{D,Int}, Ñ, params) where {T,D}
    m = params.m; σ = params.σ; window = params.window
    LUTSize = params.LUTSize; precompute = params.precompute

    win, win_hat = NFFT.getWindow(window) # highly type-unstable, but what else can we do
    J = size(k, 2)

    windowHatInvLUT_ = Vector{Vector{T}}(undef, D)
    NFFT.precomputeWindowHatInvLUT(windowHatInvLUT_, win_hat, N, Ñ, m, σ, T)

    if params.storeDeconvolutionIdx
        windowHatInvLUT = Vector{Vector{T}}(undef, 1)
        windowHatInvLUT[1], deconvolveIdx = NFFT.precompWindowHatInvLUT(params, N, Ñ, windowHatInvLUT_)
    else
        windowHatInvLUT = windowHatInvLUT_
        deconvolveIdx = Array{Int64,1}(undef, 0)
    end

    if precompute == NFFT.LINEAR
        windowLinInterp = NFFT.precomputeLinInterp(win, m, σ, LUTSize, T)
        windowPolyInterp = Matrix{T}(undef, 0, 0)
        B = sparse([], [], T[])
    elseif precompute == NFFT.POLYNOMIAL
        windowLinInterp = Vector{T}(undef, 0)
        windowPolyInterp = NFFT.precomputePolyInterp(win, m, σ, T)
        B = sparse([], [], T[])
    elseif precompute == NFFT.FULL
        windowLinInterp = Vector{T}(undef, 0)
        windowPolyInterp = Matrix{T}(undef, 0, 0)
        B = NFFT.precomputeB(win, k, N, Ñ, m, J, σ, LUTSize, T)
        # Debugging variants:
        # windowLinInterp = NFFT.precomputeLinInterp(win, windowLinInterp, Ñ, m, σ, LUTSize, T)
        # B = NFFT.precomputeB(windowLinInterp, k, N, Ñ, m, J, σ, LUTSize, T)
    elseif precompute == NFFT.TENSOR
        windowLinInterp = Vector{T}(undef, 0)
        windowPolyInterp = Matrix{T}(undef, 0, 0)
        B = sparse([], [], T[])
    else
        error("precompute = $precompute not supported by NFFT.jl!")
    end

    return (windowLinInterp, windowPolyInterp, windowHatInvLUT, deconvolveIdx, B)
end
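With precompute == FULL, the convolution step of the transform reduces to a sparse matvec with B. I'm assuming B's orientation matches NFFT.jl, i.e. prod(Ñ) × J:

# Assumed usage of B in the (still disabled) convolve step:
# fHat .= transpose(B) * vec(tmpVec)   # forward: gather grid values at the J nodes
# vec(tmpVec) .= B * fHat              # adjoint: scatter node values back onto the grid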
f = randn(ComplexF64, 64, 64)
fr = Reactant.to_rarray(f)
k = rand(2, 64) .- 0.5
@compile nfft(k, fr)

This gives the following error:
LLVM ERROR: Incompatible index and shape found while flattening index