Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rocFFT #640

Merged
merged 1 commit into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions src/fft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
const ROCFFT_INVERSE = false

# TODO: Real to Complex full not possible atm
# For R2C -> cast array to Complex first

# K is flag for forward/inverse
mutable struct cROCFFTPlan{T,K,inplace,N} <: ROCFFTPlan{T,K,inplace}
mutable struct cROCFFTPlan{T, K, inplace, N} <: ROCFFTPlan{T, K, inplace}
handle::rocfft_plan
stream::HIPStream
workarea::ROCVector{Int8}
execution_info::rocfft_execution_info
sz::NTuple{N,Int} # Julia size of input array
osz::NTuple{N,Int} # Julia size of output array
sz::NTuple{N, Int} # Julia size of input array
osz::NTuple{N, Int} # Julia size of output array
xtype::rocfft_transform_type
region::Any
pinv::ScaledPlan # required by AbstractFFTs API
Expand All @@ -35,18 +38,20 @@
info = info_ref[]

# assign to the current stream
rocfft_execution_info_set_stream(info, AMDGPU.stream())
stream = AMDGPU.stream()
rocfft_execution_info_set_stream(info, stream)
if length(workarea) > 0
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
end
p = new(handle, workarea, info, size(X), sizey, xtype, region)
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region)
finalizer(unsafe_free!, p)
p
end
end

mutable struct rROCFFTPlan{T,K,inplace,N} <: ROCFFTPlan{T,K,inplace}
handle::rocfft_plan
stream::HIPStream
workarea::ROCVector{Int8}
execution_info::rocfft_execution_info
sz::NTuple{N,Int} # Julia size of input array
Expand All @@ -63,16 +68,26 @@
rocfft_execution_info_create(info_ref)
info = info_ref[]

rocfft_execution_info_set_stream(info, AMDGPU.stream())
stream = AMDGPU.stream()
rocfft_execution_info_set_stream(info, stream)
if length(workarea) > 0
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
end
p = new(handle, workarea, info, size(X), sizey, xtype, region)
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region)
finalizer(unsafe_free!, p)
p
end
end

function update_stream!(plan::ROCFFTPlan)
new_stream = AMDGPU.stream()
if plan.stream != new_stream
plan.stream = new_stream
rocfft_execution_info_set_stream(info, new_stream)

Check warning on line 86 in src/fft/fft.jl

View check run for this annotation

Codecov / codecov/patch

src/fft/fft.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
end
return
end

const xtypenames = (
"complex forward", "complex inverse", "real forward", "real inverse")

Expand Down Expand Up @@ -140,8 +155,7 @@
xtype = rocfft_transform_type_complex_inverse
pp = get_plan(xtype, p.sz, T, inplace, p.region)
ScaledPlan(
cROCFFTPlan{T,ROCFFT_INVERSE,inplace,N}(
pp..., X, p.sz, xtype, p.region),
cROCFFTPlan{T,ROCFFT_INVERSE,inplace,N}(pp..., X, p.sz, xtype, p.region),
normalization(X, p.region))
end

Expand Down Expand Up @@ -198,9 +212,8 @@
end
end

# TODO update stream

function unsafe_execute!(plan::cROCFFTPlan{T,K,true,N}, X::ROCArray{T,N}) where {T,K,N}
update_stream!(plan)
rocfft_execute(plan, [pointer(X),], C_NULL, plan.execution_info)
end

Expand All @@ -209,6 +222,7 @@
) where {T,N,K}
X = copy(X) # since input array can also be modified
# TODO on 1.11 we need to manually cast `pointer(X)` to `Ptr{Cvoid}`.
update_stream!(plan)
rocfft_execute(plan, [pointer(X),], [pointer(Y),], plan.execution_info)
end

Expand All @@ -218,6 +232,7 @@
) where {T<:rocfftReals,N}
@assert plan.xtype == rocfft_transform_type_real_forward
Xcopy = copy(X)
update_stream!(plan)
rocfft_execute(plan, [pointer(Xcopy),], [pointer(Y),], plan.execution_info)
end

Expand All @@ -227,10 +242,10 @@
) where {T<:rocfftComplexes,N}
@assert plan.xtype == rocfft_transform_type_real_inverse
Xcopy = copy(X)
update_stream!(plan)
rocfft_execute(plan, [pointer(Xcopy),], [pointer(Y),], plan.execution_info)
end


function LinearAlgebra.mul!(y::ROCArray{Ty}, p::ROCFFTPlan{T,K,false}, x::ROCArray{T}) where {T,Ty,K}
assert_applicable(p, x, y)
unsafe_execute!(p, x, y)
Expand Down
86 changes: 63 additions & 23 deletions src/fft/librocfft.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using CEnum

mutable struct rocfft_plan_t end

const rocfft_plan = Ptr{rocfft_plan_t}
Expand All @@ -12,6 +10,14 @@ mutable struct rocfft_execution_info_t end

const rocfft_execution_info = Ptr{rocfft_execution_info_t}

mutable struct rocfft_field_t end

const rocfft_field = Ptr{rocfft_field_t}

mutable struct rocfft_brick_t end

const rocfft_brick = Ptr{rocfft_brick_t}

@cenum rocfft_status_e::UInt32 begin
rocfft_status_success = 0
rocfft_status_failure = 1
Expand All @@ -38,6 +44,7 @@ const rocfft_transform_type = rocfft_transform_type_e
@cenum rocfft_precision_e::UInt32 begin
rocfft_precision_single = 0
rocfft_precision_double = 1
rocfft_precision_half = 2
end

const rocfft_precision = rocfft_precision_e
Expand All @@ -60,100 +67,133 @@ end

const rocfft_array_type = rocfft_array_type_e

# no prototype is found for this function at rocfft.h:124:29, please use with caution
function rocfft_setup()
AMDGPU.prepare_state()
ccall((:rocfft_setup, librocfft), rocfft_status, ()) |> check
@check ccall((:rocfft_setup, librocfft), rocfft_status, ())
end

# no prototype is found for this function at rocfft.h:128:29, please use with caution
function rocfft_cleanup()
AMDGPU.prepare_state()
ccall((:rocfft_cleanup, librocfft), rocfft_status, ()) |> check
@check ccall((:rocfft_cleanup, librocfft), rocfft_status, ())
end

function rocfft_plan_create(plan, placement, transform_type, precision, dimensions, lengths, number_of_transforms, description)
AMDGPU.prepare_state()
ccall((:rocfft_plan_create, librocfft), rocfft_status, (Ptr{rocfft_plan}, rocfft_result_placement, rocfft_transform_type, rocfft_precision, Cint, Ptr{Cint}, Cint, rocfft_plan_description), plan, placement, transform_type, precision, dimensions, lengths, number_of_transforms, description) |> check
@check ccall((:rocfft_plan_create, librocfft), rocfft_status, (Ptr{rocfft_plan}, rocfft_result_placement, rocfft_transform_type, rocfft_precision, Csize_t, Ptr{Csize_t}, Csize_t, rocfft_plan_description), plan, placement, transform_type, precision, dimensions, lengths, number_of_transforms, description)
end

function rocfft_execute(plan, in_buffer, out_buffer, info)
AMDGPU.prepare_state()
ccall((:rocfft_execute, librocfft), rocfft_status, (rocfft_plan, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, rocfft_execution_info), plan, in_buffer, out_buffer, info) |> check
@check ccall((:rocfft_execute, librocfft), rocfft_status, (rocfft_plan, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, rocfft_execution_info), plan, in_buffer, out_buffer, info)
end

function rocfft_plan_destroy(plan)
AMDGPU.prepare_state()
ccall((:rocfft_plan_destroy, librocfft), rocfft_status, (rocfft_plan,), plan) |> check
@check ccall((:rocfft_plan_destroy, librocfft), rocfft_status, (rocfft_plan,), plan)
end

function rocfft_plan_description_set_scale_factor(description, scale_factor)
AMDGPU.prepare_state()
ccall((:rocfft_plan_description_set_scale_factor, librocfft), rocfft_status, (rocfft_plan_description, Cdouble), description, scale_factor) |> check
@check ccall((:rocfft_plan_description_set_scale_factor, librocfft), rocfft_status, (rocfft_plan_description, Cdouble), description, scale_factor)
end

function rocfft_plan_description_set_data_layout(description, in_array_type, out_array_type, in_offsets, out_offsets, in_strides_size, in_strides, in_distance, out_strides_size, out_strides, out_distance)
AMDGPU.prepare_state()
ccall((:rocfft_plan_description_set_data_layout, librocfft), rocfft_status, (rocfft_plan_description, rocfft_array_type, rocfft_array_type, Ptr{Cint}, Ptr{Cint}, Cint, Ptr{Cint}, Cint, Cint, Ptr{Cint}, Cint), description, in_array_type, out_array_type, in_offsets, out_offsets, in_strides_size, in_strides, in_distance, out_strides_size, out_strides, out_distance) |> check
@check ccall((:rocfft_plan_description_set_data_layout, librocfft), rocfft_status, (rocfft_plan_description, rocfft_array_type, rocfft_array_type, Ptr{Csize_t}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}, Csize_t, Csize_t, Ptr{Csize_t}, Csize_t), description, in_array_type, out_array_type, in_offsets, out_offsets, in_strides_size, in_strides, in_distance, out_strides_size, out_strides, out_distance)
end

function rocfft_field_create(field)
AMDGPU.prepare_state()
@check ccall((:rocfft_field_create, librocfft), rocfft_status, (Ptr{rocfft_field},), field)
end

function rocfft_field_destroy(field)
AMDGPU.prepare_state()
@check ccall((:rocfft_field_destroy, librocfft), rocfft_status, (rocfft_field,), field)
end

function rocfft_get_version_string(buf, len)
AMDGPU.prepare_state()
ccall((:rocfft_get_version_string, librocfft), rocfft_status, (Ptr{Cchar}, Cint), buf, len) |> check
@check ccall((:rocfft_get_version_string, librocfft), rocfft_status, (Ptr{Cchar}, Csize_t), buf, len)
end

function rocfft_brick_create(brick, field_lower, field_upper, brick_stride, dim, deviceID)
AMDGPU.prepare_state()
@check ccall((:rocfft_brick_create, librocfft), rocfft_status, (Ptr{rocfft_brick}, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Csize_t}, Csize_t, Cint), brick, field_lower, field_upper, brick_stride, dim, deviceID)
end

function rocfft_brick_destroy(brick)
AMDGPU.prepare_state()
@check ccall((:rocfft_brick_destroy, librocfft), rocfft_status, (rocfft_brick,), brick)
end

function rocfft_field_add_brick(field, brick)
AMDGPU.prepare_state()
@check ccall((:rocfft_field_add_brick, librocfft), rocfft_status, (rocfft_field, rocfft_brick), field, brick)
end

function rocfft_plan_description_add_infield(description, field)
AMDGPU.prepare_state()
@check ccall((:rocfft_plan_description_add_infield, librocfft), rocfft_status, (rocfft_plan_description, rocfft_field), description, field)
end

function rocfft_plan_description_add_outfield(description, field)
AMDGPU.prepare_state()
@check ccall((:rocfft_plan_description_add_outfield, librocfft), rocfft_status, (rocfft_plan_description, rocfft_field), description, field)
end

function rocfft_plan_get_work_buffer_size(plan, size_in_bytes)
AMDGPU.prepare_state()
ccall((:rocfft_plan_get_work_buffer_size, librocfft), rocfft_status, (rocfft_plan, Ptr{Cint}), plan, size_in_bytes) |> check
@check ccall((:rocfft_plan_get_work_buffer_size, librocfft), rocfft_status, (rocfft_plan, Ptr{Csize_t}), plan, size_in_bytes)
end

function rocfft_plan_get_print(plan)
AMDGPU.prepare_state()
ccall((:rocfft_plan_get_print, librocfft), rocfft_status, (rocfft_plan,), plan) |> check
@check ccall((:rocfft_plan_get_print, librocfft), rocfft_status, (rocfft_plan,), plan)
end

function rocfft_plan_description_create(description)
AMDGPU.prepare_state()
ccall((:rocfft_plan_description_create, librocfft), rocfft_status, (Ptr{rocfft_plan_description},), description) |> check
@check ccall((:rocfft_plan_description_create, librocfft), rocfft_status, (Ptr{rocfft_plan_description},), description)
end

function rocfft_plan_description_destroy(description)
AMDGPU.prepare_state()
ccall((:rocfft_plan_description_destroy, librocfft), rocfft_status, (rocfft_plan_description,), description) |> check
@check ccall((:rocfft_plan_description_destroy, librocfft), rocfft_status, (rocfft_plan_description,), description)
end

function rocfft_execution_info_create(info)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_create, librocfft), rocfft_status, (Ptr{rocfft_execution_info},), info) |> check
@check ccall((:rocfft_execution_info_create, librocfft), rocfft_status, (Ptr{rocfft_execution_info},), info)
end

function rocfft_execution_info_destroy(info)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_destroy, librocfft), rocfft_status, (rocfft_execution_info,), info) |> check
@check ccall((:rocfft_execution_info_destroy, librocfft), rocfft_status, (rocfft_execution_info,), info)
end

function rocfft_execution_info_set_work_buffer(info, work_buffer, size_in_bytes)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_set_work_buffer, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Cvoid}, Cint), info, work_buffer, size_in_bytes) |> check
@check ccall((:rocfft_execution_info_set_work_buffer, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Cvoid}, Csize_t), info, work_buffer, size_in_bytes)
end

function rocfft_execution_info_set_stream(info, stream)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_set_stream, librocfft), rocfft_status, (rocfft_execution_info, hipStream_t), info, stream) |> check
@check ccall((:rocfft_execution_info_set_stream, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Cvoid}), info, stream)
end

function rocfft_execution_info_set_load_callback(info, cb_functions, cb_data, shared_mem_bytes)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_set_load_callback, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Cint), info, cb_functions, cb_data, shared_mem_bytes) |> check
@check ccall((:rocfft_execution_info_set_load_callback, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Csize_t), info, cb_functions, cb_data, shared_mem_bytes)
end

function rocfft_execution_info_set_store_callback(info, cb_functions, cb_data, shared_mem_bytes)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_set_store_callback, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Cint), info, cb_functions, cb_data, shared_mem_bytes) |> check
@check ccall((:rocfft_execution_info_set_store_callback, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Csize_t), info, cb_functions, cb_data, shared_mem_bytes)
end

const rocfft_version_major = 1

const rocfft_version_minor = 0

const rocfft_version_patch = 21
const rocfft_version_patch = 27
16 changes: 8 additions & 8 deletions src/fft/rocFFT.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
module rocFFT
export ROCFFTError

import AbstractFFTs: complexfloat, realfloat
import AbstractFFTs: plan_fft, plan_fft!, plan_bfft, plan_bfft!
import AbstractFFTs: plan_rfft, plan_brfft, plan_inv, normalization
import AbstractFFTs: fft, bfft, ifft, rfft, Plan, ScaledPlan
using CEnum
using LinearAlgebra

# TODO
# @reexport using AbstractFFTs

using LinearAlgebra
import AbstractFFTs: complexfloat, realfloat
import AbstractFFTs: plan_fft, plan_fft!, plan_bfft, plan_bfft!
import AbstractFFTs: plan_rfft, plan_brfft, plan_inv, normalization
import AbstractFFTs: fft, bfft, ifft, rfft, Plan, ScaledPlan

import ..AMDGPU
import .AMDGPU: ROCArray, ROCVector, HandleCache, HIP, unsafe_free!, check
import .AMDGPU: ROCArray, ROCVector, HandleCache, HIP, unsafe_free!, check, @check
import AMDGPU: librocfft
import .HIP: hipStream_t, HIPContext, HIPStream

using CEnum

include("librocfft.jl")
include("error.jl")
include("util.jl")
include("wrappers.jl")
include("fft.jl")

version() = VersionNumber(
Expand Down
Loading