|
| 1 | +module DynamicExpressionsCUDAExt |
| 2 | + |
| 3 | +# TODO: Switch to KernelAbstractions.jl (once they hit v1.0) |
| 4 | +using CUDA: @cuda, CuArray, blockDim, blockIdx, threadIdx |
| 5 | +using DynamicExpressions: OperatorEnum, AbstractExpressionNode |
| 6 | +using DynamicExpressions.EvaluateEquationModule: get_nbin, get_nuna |
| 7 | +using DynamicExpressions.AsArrayModule: as_array |
| 8 | + |
| 9 | +import DynamicExpressions.EvaluateEquationModule: eval_tree_array |
| 10 | + |
# Host-backed stand-in for a `CuArray`, used exclusively in tests: it takes the
# same dispatch role as a device array while keeping data in ordinary memory,
# so the GPU code paths can be exercised without actual CUDA hardware.
struct FakeCuArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    a::A
end
# Minimal AbstractArray interface, forwarding everything to the wrapped array:
Base.size(fc::FakeCuArray) = size(fc.a)
Base.getindex(fc::FakeCuArray, idx::Int...) = getindex(fc.a, idx...)
Base.setindex!(fc::FakeCuArray, value, idx::Int...) = setindex!(fc.a, value, idx...)
Base.similar(fc::FakeCuArray, dims::Integer...) = FakeCuArray(similar(fc.a, dims...))
| 19 | + |
# Either a real device array or the test-only fake; both are accepted by the
# `eval_tree_array` methods below.
const MaybeCuArray{T,N} = Union{CuArray{T,N},FakeCuArray{T,N}}

# Copy host data `a` to the same "device" as the second argument
# (real `CuArray`, or the host-backed fake used in tests).
to_device(a, ::CuArray) = CuArray(a)
to_device(a, ::FakeCuArray) = FakeCuArray(a)
| 24 | + |
"""
    eval_tree_array(tree, gcX, operators; kws...)

Evaluate a single expression `tree` over the (possibly GPU-resident) data
matrix `gcX` by delegating to the batched method with a one-element tuple,
then unwrapping the single output and its success flag.
"""
function eval_tree_array(
    tree::AbstractExpressionNode{T}, gcX::MaybeCuArray{T,2}, operators::OperatorEnum; kws...
) where {T<:Number}
    outputs, flags = eval_tree_array((tree,), gcX, operators; kws...)
    return (only(outputs), only(flags))
end
| 31 | + |
"""
    eval_tree_array(trees, gcX, operators; kws...)

Evaluate a collection of expression `trees` over the data matrix `gcX`
(laid out features × rows, per the `cX[feature, elem]` access in the kernel)
on the GPU, or on the host when given a `FakeCuArray`.

The keyword arguments allow buffers to be pre-allocated and re-used across
calls. When `update_buffers=Val(false)`, the host-side flattening via
`as_array` is skipped entirely, and the caller must supply `gpu_workspace`,
`gpu_buffer`, `roots`, `num_nodes`, and `num_launches` (no validation is
performed here — missing values will error downstream).

Returns `(out, is_good)`: one workspace view per tree root, and a matching
collection of success flags (always `true` — see note below).
"""
function eval_tree_array(
    trees::Union{Tuple{N,Vararg{N}},AbstractVector{N}},
    gcX::MaybeCuArray{T,2},
    operators::OperatorEnum;
    buffer=nothing,          # host-side Int32 buffer for `as_array`, re-usable across calls
    gpu_workspace=nothing,   # pre-allocated device workspace, re-usable across calls
    gpu_buffer=nothing,      # pre-allocated device copy of the index buffer
    roots=nothing,           # root-node indices; required when `update_buffers=Val(false)`
    num_nodes=nothing,       # total node count; required when `update_buffers=Val(false)`
    num_launches=nothing,    # kernel-launch count; required when `update_buffers=Val(false)`
    update_buffers::Val{_update_buffers}=Val(true),
    kws...,
) where {T<:Number,N<:AbstractExpressionNode{T},_update_buffers}
    if _update_buffers
        # Flatten the trees into contiguous host arrays (node metadata, topology,
        # constants, and the launch schedule); this overwrites the corresponding
        # keyword-argument locals above.
        (; val, roots, buffer, num_nodes, num_launches) = as_array(Int32, trees; buffer)
    end
    num_elem = size(gcX, 2)

    ## The following array is our "workspace" for
    ## the GPU kernel, with size equal to the number of rows
    ## in the input data by the number of nodes in the tree.
    ## It has one extra row to store the constant values.
    gworkspace = if gpu_workspace === nothing
        similar(gcX, num_elem + 1, num_nodes)
    else
        gpu_workspace
    end
    # The extra (last) row of the workspace holds each node's constant value.
    gval = @view gworkspace[end, :]
    if _update_buffers
        copyto!(gval, val)
    end

    ## Index arrays (much faster to have `@view` here)
    gbuffer = if !_update_buffers
        gpu_buffer  # caller guarantees this is already populated on-device
    elseif gpu_buffer === nothing
        to_device(buffer, gcX)
    else
        copyto!(gpu_buffer, buffer)  # `copyto!` returns the destination array
    end
    # One row of per-node metadata each; row layout must match `as_array`:
    gdegree = @view gbuffer[1, :]
    gfeature = @view gbuffer[2, :]
    gop = @view gbuffer[3, :]
    gexecution_order = @view gbuffer[4, :]
    gidx_self = @view gbuffer[5, :]
    gidx_l = @view gbuffer[6, :]
    gidx_r = @view gbuffer[7, :]
    gconstant = @view gbuffer[8, :]

    # One thread per (element, node) pair; block count rounded up to a power of two.
    num_threads = 256
    num_blocks = nextpow(2, ceil(Int, num_elem * num_nodes / num_threads))

    #! format: off
    _launch_gpu_kernel!(
        num_threads, num_blocks, num_launches, gworkspace,
        # Thread info:
        num_elem, num_nodes, gexecution_order,
        # Input data and tree
        operators, gcX, gidx_self, gidx_l, gidx_r,
        gdegree, gconstant, gval, gfeature, gop,
    )
    #! format: on

    # Each root's workspace column (excluding the constants row) is its output.
    out = (r -> @view(gworkspace[begin:(end - 1), r])).(roots)
    # NOTE(review): unlike the CPU evaluator, this kernel does not detect
    # NaN/Inf failures, so every tree is reported as successful.
    is_good = (_ -> true).(trees)

    return (out, is_good)
end
| 100 | + |
#! format: off
"""
    _launch_gpu_kernel!(num_threads, num_blocks, num_launches, buffer, ...)

Run the evaluation kernel once per scheduling "level": nodes whose
`execution_order` equals the current `launch` index depend only on nodes
evaluated in earlier launches, so each launch is internally data-parallel.
Dispatches to a real `@cuda` launch for `CuArray` buffers, or emulates the
CUDA grid with CPU threads otherwise (the `FakeCuArray` test path).
"""
function _launch_gpu_kernel!(
    num_threads, num_blocks, num_launches::Integer, buffer::AbstractArray{T,2},
    # Thread info:
    num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray{I},
    # Input data and tree
    operators::OperatorEnum, cX::AbstractArray{T,2}, idx_self::AbstractArray, idx_l::AbstractArray, idx_r::AbstractArray,
    degree::AbstractArray, constant::AbstractArray, val::AbstractArray{T,1}, feature::AbstractArray, op::AbstractArray,
) where {I,T}
    #! format: on
    nuna = get_nuna(typeof(operators))
    nbin = get_nbin(typeof(operators))
    # Kernels are pre-generated below only for 0:10 operators of each arity.
    (nuna > 10 || nbin > 10) &&
        error("Too many operators. Kernels are only compiled up to 10.")
    gpu_kernel! = create_gpu_kernel(operators, Val(nuna), Val(nbin))
    # Iterate in `I` (the eltype of `execution_order`) so the comparison inside
    # the kernel stays type-stable.
    for launch in one(I):I(num_launches)
        #! format: off
        if buffer isa CuArray
            @cuda threads=num_threads blocks=num_blocks gpu_kernel!(
                buffer,
                launch, num_elem, num_nodes, execution_order,
                cX, idx_self, idx_l, idx_r,
                degree, constant, val, feature, op
            )
        else
            # CPU fallback: pass the flat grid index `i` explicitly, since
            # there is no CUDA thread context to derive it from.
            Threads.@threads for i in 1:(num_threads * num_blocks)
                gpu_kernel!(
                    buffer,
                    launch, num_elem, num_nodes, execution_order,
                    cX, idx_self, idx_l, idx_r,
                    degree, constant, val, feature, op,
                    i
                )
            end
        end
        #! format: on
    end
    return nothing
end
| 140 | + |
# Need to pre-compute the GPU kernels with an `@eval` for each number of operators
# 1. We need to use an `@nif` over operators, as GPU kernels
#    can't index into arrays of operators.
# 2. `@nif` is evaluated at parse time and needs to know the number of
#    ifs to generate at that time, so we can't simply use specialization.
# 3. We can't use `@generated` because we can't create closures in those.
for nuna in 0:10, nbin in 0:10
    # Each `create_gpu_kernel` method returns an anonymous kernel closure that
    # evaluates one (element, node) pair per thread.
    @eval function create_gpu_kernel(operators::OperatorEnum, ::Val{$nuna}, ::Val{$nbin})
        #! format: off
        function (
            # Storage:
            buffer,
            # Thread info:
            launch::Integer, num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray,
            # Input data and tree
            cX::AbstractArray, idx_self::AbstractArray, idx_l::AbstractArray, idx_r::AbstractArray,
            degree::AbstractArray, constant::AbstractArray, val::AbstractArray, feature::AbstractArray, op::AbstractArray,
            # Override for unittesting:
            i=nothing,
        )
            #! format: on
            # Flat index: derived from the CUDA grid on-device, or passed
            # explicitly by the CPU fallback in `_launch_gpu_kernel!`.
            i = i === nothing ? (blockIdx().x - 1) * blockDim().x + threadIdx().x : i
            if i > num_elem * num_nodes
                return nothing
            end

            # Decode the flat index into a (node, data-element) pair;
            # inverse of `i = (elem - 1) * num_nodes + node`.
            node = (i - 1) % num_nodes + 1
            elem = (i - node) ÷ num_nodes + 1

            # Only evaluate nodes scheduled for this launch; their children
            # were filled in by earlier launches.
            if execution_order[node] != launch
                return nothing
            end

            cur_degree = degree[node]
            cur_idx = idx_self[node]
            if cur_degree == 0
                # Leaf node: either a stored constant or a feature column of `cX`.
                if constant[node] == 1
                    cur_val = val[node]
                    buffer[elem, cur_idx] = cur_val
                else
                    cur_feature = feature[node]
                    buffer[elem, cur_idx] = cX[cur_feature, elem]
                end
            else
                if cur_degree == 1 && $nuna > 0
                    # Unary operator: `@nif` unrolls one branch per operator,
                    # since kernels cannot dynamically index a tuple of functions.
                    cur_op = op[node]
                    l_idx = idx_l[node]
                    Base.Cartesian.@nif(
                        $nuna,
                        i -> i == cur_op,
                        i -> let op = operators.unaops[i]
                            buffer[elem, cur_idx] = op(buffer[elem, l_idx])
                        end
                    )
                elseif $nbin > 0 # Note this check is to avoid type inference issues when binops is empty
                    # Binary operator, dispatched via the same unrolled `@nif`.
                    cur_op = op[node]
                    l_idx = idx_l[node]
                    r_idx = idx_r[node]
                    Base.Cartesian.@nif(
                        $nbin,
                        i -> i == cur_op,
                        i -> let op = operators.binops[i]
                            buffer[elem, cur_idx] = op(buffer[elem, l_idx], buffer[elem, r_idx])
                        end
                    )
                end
            end
            return nothing
        end
    end
end
| 212 | + |
| 213 | +end |
0 commit comments