Skip to content

Integration with NFFT.jl #699

@ptiede

Description

@ptiede

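I ran into the error below while trying to get Reactant to work with NFFT.jl. To get this far I had to customise NFFT a little: the code below defines a Reactant-specific plan type and overrides some NFFT internals (`initParams`, `precomputation`).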

using Reactant
using NFFT
using LinearAlgebra
using AbstractFFTs
using Adapt

"""
    React_NFFTPlan{T,D,arrTc,vecI,vecII,FP,BP,INV,SM} <: AbstractNFFTPlan{T,D,1}

NFFT plan whose scratch buffers and precomputed operators are stored as Reactant
traced arrays, so that the transform can be traced and compiled by Reactant.

# Fields
- `N`: input (uniform grid) size.
- `NOut`: output size, as computed by `initParams`.
- `J`: number of nonequispaced nodes (`size(k, 2)`).
- `k`: node matrix, one node per column.
- `Ñ`: oversampled grid size.
- `dims`: dimensions the transform runs along.
- `params`: NFFT parameters.
- `forwardFFT`, `backwardFFT`: FFT plans (the constructor currently stores `nothing`).
- `tmpVec`: oversampled scratch buffer that `deconvolve!` scatters into.
- `tmpVecHat`: scratch buffer of input size.
- `deconvolveIdx`: indices used to scatter the deconvolved input into `tmpVec`.
- `windowLinInterp`: linear-interpolation window table from `NFFT.precomputation`.
- `windowHatInvLUT`: inverse window LUT used to scale the input in `deconvolve!`.
- `B`: precomputed matrix `B` from `NFFT.precomputation`.
"""
mutable struct React_NFFTPlan{T,D,arrTc,vecI,vecII,FP,BP,INV,SM} <: AbstractNFFTPlan{T,D,1}
    N::NTuple{D,Int64}
    NOut::NTuple{1,Int64}
    J::Int64
    k::Matrix{T}
    # `Ñ` must be its own field: the constructor passes it positionally right after `k`.
    Ñ::NTuple{D,Int64}
    dims::UnitRange{Int64}
    params::NFFTParams{T}
    forwardFFT::FP
    backwardFFT::BP
    tmpVec::arrTc
    tmpVecHat::arrTc
    deconvolveIdx::vecI
    windowLinInterp::vecII
    windowHatInvLUT::INV
    B::SM
end
  
  """
      plan_nfft(arr::Type{<:Reactant.AnyTracedRArray}, k, N, rest...; timing=nothing, kargs...)

  Build a `React_NFFTPlan` for nodes `k` and input size `N` when `arr` is a Reactant
  traced array type. Positional `rest...` and keyword `kargs...` are forwarded to
  the `React_NFFTPlan` constructor.

  If a `TimingStats` object is passed as `timing`, the construction time in seconds
  (measured with `@elapsed`) is stored in `timing.pre`.
  """
  function AbstractNFFTs.plan_nfft(arr::Type{<:Reactant.AnyTracedRArray}, k::AT,
                                   N::NTuple{D,Int}, rest...;
                                   timing::Union{Nothing,TimingStats}=nothing,
                                   kargs...) where {D,AT<:AbstractMatrix}
      t = @elapsed begin
          p = React_NFFTPlan(arr, k, N, rest...; kargs...)
      end
      # `!==` is the idiomatic, non-overloadable test against `nothing`.
      if timing !== nothing
          timing.pre = t
      end
      return p
  end
  
  """
      React_NFFTPlan(arr, k, N; fftflags=nothing, kwargs...)

  Build a `React_NFFTPlan` for the node matrix `k` and input size `N`, always
  transforming along every dimension `1:D`.

  `NFFT.initParams` sets up the parameters, receiving the forwarded `kwargs...`.
  The parameters are then forced to store deconvolution indices and to use
  `NFFT.FULL` precomputation.

  The deconvolution indices, the inverse window LUT and the precomputed matrix `B`
  are passed through `adapt(arr, …)` and converted to `Reactant.TracedRArray`s. The
  zero-initialised scratch buffers are also converted to `Reactant.TracedRArray`s.

  `fftflags` is accepted but ignored. The forward and backward FFT plan fields are
  left as `nothing`.
  """
  function React_NFFTPlan(arr, k::AbstractArray{T}, N::NTuple{D,Int};
                          fftflags=nothing, kwargs...) where {T,D}
      # Transforms along a subset of directions are not supported, so use every dimension.
      params, N, NOut, J, Ñ, dims_ = NFFT.initParams(k, N, 1:D; kwargs...)

      # This port only supports stored deconvolution indices with full precomputation.
      params.storeDeconvolutionIdx = true
      params.precompute = NFFT.FULL

      scratch = convert(Reactant.TracedRArray, zeros(Complex{T}, Ñ))

      # FFT plans are not created yet.
      fwd = nothing
      bwd = nothing

      linInterp, _, hatInvLUTs, deconvIdxHost, Bhost =
          NFFT.precomputation(k, N[dims_], Ñ[dims_], params)

      hatSize = if params.storeDeconvolutionIdx
          N
      else
          ntuple(d -> 0, Val(D))
      end
      scratchHat = convert(Reactant.TracedRArray, zeros(Complex{T}, hatSize))

      toTraced(x) = convert(Reactant.TracedRArray, x)
      deconvIdx = toTraced(Int32.(adapt(arr, deconvIdxHost)))
      hatInvLUT = toTraced(Complex{T}.(adapt(arr, hatInvLUTs[1])))
      # A bit hacky: B is converted element-wise to Complex{T} before tracing.
      Bmat = toTraced(Complex{T}.(adapt(arr, Bhost)))

      PlanType = React_NFFTPlan{T,D,typeof(scratch),typeof(deconvIdx),typeof(linInterp),
                                typeof(fwd),typeof(bwd),typeof(hatInvLUT),typeof(Bmat)}
      return PlanType(N, NOut, J, k, Ñ, dims_, params, fwd, bwd, scratch, scratchHat,
                      deconvIdx, linInterp, hatInvLUT, Bmat)
  end
  
# Input (uniform grid) size of the plan.
function AbstractNFFTs.size_in(plan::React_NFFTPlan)
    return plan.N
end

# Output size of the plan, as computed by `initParams` when the plan was built.
function AbstractNFFTs.size_out(plan::React_NFFTPlan)
    return plan.NOut
end
  
  
"""
    deconvolve!(p::React_NFFTPlan, f, g)

Deconvolution step of the plan. Multiply `f` elementwise by `p.windowHatInvLUT`
(reshaped to `size(f)`), flatten the result with `vec`, and scatter it into `g`
at the indices `p.deconvolveIdx`.

Mutates `g` and returns `nothing`.
"""
function AbstractNFFTs.deconvolve!(p::React_NFFTPlan{T,D}, f::Reactant.AnyTracedRArray,
                                   g::Reactant.AnyTracedRArray) where {D,T}
    scaled = f .* reshape(p.windowHatInvLUT, size(f))
    flat = vec(scaled)
    # NOTE(review): `g` is multi-dimensional while `deconvolveIdx` is a vector of
    # indices. Linear-index scatter into a traced N-d array may be what triggers
    # "Incompatible index and shape found while flattening index" — confirm.
    @allowscalar @inbounds g[p.deconvolveIdx] .= flat
    return nothing
end

# Apply plan `p` to the traced array `f`: allocate a Complex{T} output of size
# `size_out(p)` shaped like `f`, fill it in place with `mul!`, and return it.
# Keyword arguments are forwarded to `mul!`.
function Base.:*(p::React_NFFTPlan{T}, f::Reactant.AnyTracedRArray; kargs...) where {T}
    outSize = size_out(p)
    result = similar(f, Complex{T}, outSize)
    mul!(result, p, f; kargs...)
    return result
end
  
  
  """
      mul!(fHat, p::React_NFFTPlan, f; verbose=false, timing=nothing)

  Intended as the in-place NFFT of the traced array `f` into `fHat` using plan `p`.

  NOTE(review): this is currently a truncated debugging version, so the result is
  not an NFFT. It does the following:
  - zeroes `p.tmpVec`;
  - runs the deconvolution step into `p.tmpVec`;
  - logs several sizes;
  - copies the first `length(fHat)` entries of `p.tmpVec` into `fHat` and returns `fHat`.

  The FFT and convolution steps are commented out. `verbose` and `timing` are
  currently unused.
  """
  function LinearAlgebra.mul!(fHat::Reactant.AnyTracedRArray, p::React_NFFTPlan{T,D}, f::Reactant.AnyTracedRArray; 
                            verbose=false, timing::Union{Nothing,TimingStats} = nothing) where {T,D} 
      # NFFT.jl's consistency check of plan/input/output (assumed to validate sizes — confirm).
      NFFT.consistencyCheck(p, f, fHat)
  
      # Zero the scratch buffer before the deconvolution scatters into it.
      fill!(p.tmpVec, zero(Complex{T}))
      # Deconvolve `f` into `p.tmpVec`; the elapsed time `t1` is not used yet.
      t1 = @elapsed @inbounds deconvolve!(p, f, p.tmpVec)
      # Debug logging of progress and shapes while investigating the tracing error.
      @info "Done Deconvolve"
      @info size(fHat)
      @info size(p.tmpVec)
      @info size(p.tmpVec[1:length(fHat)])
      # Placeholder: copies the first length(fHat) linear entries of the scratch
      # buffer instead of the FFT + convolution result.
      fHat .= p.tmpVec[1:length(fHat)]
      @info "Done copyto!"
      # Early return: the remaining NFFT steps below are disabled.
      return fHat
      # @show size(p.tmpVec)
      # tmpVec = fft(p.tmpVec)
      # @info "Done FFT"
      # p.tmpVec .= tmpVec
      # @info "Done copyto!"
      # t3 = @elapsed @inbounds convolve!(p, p.tmpVec, fHat)
      # @info "Done convolve!"
      # # if verbose
      # #     @info "Timing: deconv=$t1 fft=$t2 conv=$t3"
      # # end
      # # if timing != nothing
      # #   timing.conv = t3
      # #   timing.fft = t2
      # #   timing.deconv = t1
      # # end

      # @show fHat
  
      # return fHat
  end


# One-shot NFFT for a Reactant traced input.
# Builds a plan from `typeof(f)`, `k` and `size(f)`, then applies it.
# Extra positional `args...` are accepted but ignored; keywords go to `plan_nfft`.
NFFT.nfft(k::AbstractMatrix, f::Reactant.AnyTracedRArray, args...; kwargs...) =
    plan_nfft(typeof(f), k, size(f); kwargs...) * f
  

  """
      NFFT.initParams(k, N, dims=1:D; kargs...)

  Set up the NFFT parameters for nodes `k` (one node per column) and input size `N`.

  Returns `(params, N, NOut, J, Ñ, dims_)` where:
  - `dims_` is `dims` normalized to a range;
  - `J = size(k, 2)`;
  - `Ñ` is the oversampled size, rounded up to an even integer along transformed
    dimensions;
  - `NOut` holds `N[d]` for untransformed dimensions plus one entry `J` for the
    transformed ones.

  Throws `ArgumentError` if `length(dims_) != size(k, 1)`.

  If `params.sortNodes` is set, the columns of `k` are sorted in place, which
  mutates the caller's array.

  NOTE(review): NFFT.jl's own method for `Matrix` nodes is presumably more specific
  than this one and will be selected for plain matrices — confirm which method runs.
  """
  function NFFT.initParams(k::AbstractMatrix{T},
                           N::NTuple{D,Int},
                           dims::Union{Integer,UnitRange{Int64}}=1:D;
                           kargs...) where {D,T}
      # Normalize an integer `dims` to a one-element range.
      dims_ = dims isa Integer ? (dims:dims) : dims

      params = NFFT.NFFTParams{T,D}(; kargs...)
      m, σ, reltol = NFFT.accuracyParams(; kargs...)
      params.m = m
      params.σ = σ
      params.reltol = reltol

      # Table taken from NFFT3; multiplying by m makes LUTSize divisible by m.
      m2K = [1, 3, 7, 9, 14, 17, 20, 23, 24]
      K = m2K[min(m + 1, length(m2K))]
      params.LUTSize = 2^K * m

      if length(dims_) != size(k, 1)
          throw(ArgumentError("Nodes x have dimension $(size(k,1)) != $(length(dims_))"))
      end

      doTrafo = ntuple(d -> d ∈ dims_, Val(D))

      # Oversampled size: round σ*N[d] up to an even integer along transformed dims.
      Ñ = ntuple(d -> doTrafo[d] ? (ceil(Int, params.σ * N[d]) ÷ 2) * 2 : N[d], Val(D))

      # Store the effective oversampling factor after rounding.
      params.σ = Ñ[dims_[1]] / N[dims_[1]]

      if haskey(kargs, :blockSize)
          params.blockSize = kargs[:blockSize]
      else
          params.blockSize = ntuple(d -> NFFT._blockSize(Ñ, d), Val(D))
      end

      J = size(k, 2)

      # Output size: untransformed dims keep N[d]. All transformed dims collapse
      # into one axis of length J, placed at the first transformed dimension.
      NOut = Int[]
      Mtaken = false
      for d in 1:D
          if !doTrafo[d]
              push!(NOut, N[d])
          elseif !Mtaken
              push!(NOut, J)
              Mtaken = true
          end
      end

      # Sort nodes lexicographically by column; this overwrites the caller's `k`.
      if params.sortNodes
          k .= sortslices(k, dims=2)
      end
      return params, N, Tuple(NOut), J, Ñ, dims_
  end

"""
    NFFT.precomputation(k, N, Ñ, params)

Precompute window data for nodes `k`, input size `N` and oversampled size `Ñ`,
following `params`.

Returns `(windowLinInterp, windowPolyInterp, windowHatInvLUT, deconvolveIdx, B)`.
Only the representation selected by `params.precompute` is filled in; the other
entries are empty placeholders.

When `params.storeDeconvolutionIdx` is set, `windowHatInvLUT` holds a single
combined LUT and `deconvolveIdx` holds the matching indices. Otherwise it holds one
LUT per dimension and `deconvolveIdx` is empty.

Throws an error for an unsupported `params.precompute` value.

NOTE(review): NFFT.jl's own method for `Matrix`/`Vector` nodes is presumably more
specific than this one and will be selected for plain arrays — confirm which runs.
"""
function NFFT.precomputation(k::AbstractVecOrMat{T}, N::NTuple{D,Int}, Ñ,
                             params) where {T,D}
    # `T` must be bound as a static parameter: it types every LUT and placeholder below.
    m = params.m
    σ = params.σ
    window = params.window
    LUTSize = params.LUTSize
    precompute = params.precompute

    win, win_hat = NFFT.getWindow(window) # highly type-unstable; window is chosen at runtime
    J = size(k, 2)

    windowHatInvLUT_ = Vector{Vector{T}}(undef, D)
    NFFT.precomputeWindowHatInvLUT(windowHatInvLUT_, win_hat, N, Ñ, m, σ, T)

    if params.storeDeconvolutionIdx
        windowHatInvLUT = Vector{Vector{T}}(undef, 1)
        windowHatInvLUT[1], deconvolveIdx =
            NFFT.precompWindowHatInvLUT(params, N, Ñ, windowHatInvLUT_)
    else
        windowHatInvLUT = windowHatInvLUT_
        deconvolveIdx = Array{Int64,1}(undef, 0)
    end

    # Internal names are qualified with `NFFT.` because this method is defined
    # outside the NFFT module.
    if precompute == NFFT.LINEAR
        windowLinInterp = NFFT.precomputeLinInterp(win, m, σ, LUTSize, T)
        windowPolyInterp = Matrix{T}(undef, 0, 0)
        B = NFFT.sparse([], [], T[])
    elseif precompute == NFFT.POLYNOMIAL
        windowLinInterp = Vector{T}(undef, 0)
        windowPolyInterp = NFFT.precomputePolyInterp(win, m, σ, T)
        B = NFFT.sparse([], [], T[])
    elseif precompute == NFFT.FULL
        windowLinInterp = Vector{T}(undef, 0)
        windowPolyInterp = Matrix{T}(undef, 0, 0)
        B = NFFT.precomputeB(win, k, N, Ñ, m, J, σ, LUTSize, T)
    elseif precompute == NFFT.TENSOR
        windowLinInterp = Vector{T}(undef, 0)
        windowPolyInterp = Matrix{T}(undef, 0, 0)
        B = NFFT.sparse([], [], T[])
    else
        error("precompute = $precompute not supported by NFFT.jl!")
    end

    return (windowLinInterp, windowPolyInterp, windowHatInvLUT, deconvolveIdx, B)
end
  
# Minimal reproducer.
# Inputs: a 64×64 complex host array, its Reactant copy, and 64 random 2-D nodes
# with coordinates in [-0.5, 0.5).
f = randn(ComplexF64, 64, 64)
fr = Reactant.to_rarray(f)
k = rand(2, 64) .- 0.5

# Compiling the NFFT of the Reactant array is what fails with the LLVM error reported below.
 @compile nfft(k, fr)  

Compiling this with `@compile` fails with the following error:

LLVM ERROR: Incompatible index and shape found while flattening index

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions