Skip to content

Commit 28c157e

Browse files
authored
Merge cb2d055 into 4bd83e9
2 parents 4bd83e9 + cb2d055 commit 28c157e

File tree

7 files changed

+470
-4
lines changed

7 files changed

+470
-4
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1414

1515
[weakdeps]
1616
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
17+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1718
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1819
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1920
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
2021
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2122

2223
[extensions]
2324
DynamicExpressionsBumperExt = "Bumper"
25+
DynamicExpressionsCUDAExt = "CUDA"
2426
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
2527
DynamicExpressionsOptimExt = "Optim"
2628
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
@@ -29,6 +31,7 @@ DynamicExpressionsZygoteExt = "Zygote"
2931
[compat]
3032
Aqua = "0.7"
3133
Bumper = "0.6"
34+
CUDA = "4, 5"
3235
Compat = "3.37, 4"
3336
Enzyme = "^0.11.12"
3437
LoopVectorization = "0.12"
@@ -44,6 +47,7 @@ julia = "1.6"
4447
[extras]
4548
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4649
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
50+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4751
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4852
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4953
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -57,4 +61,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5761
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5862

5963
[targets]
60-
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
64+
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "CUDA", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]

ext/DynamicExpressionsCUDAExt.jl

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
module DynamicExpressionsCUDAExt

# TODO: Switch to KernelAbstractions.jl (once they hit v1.0)
using CUDA: @cuda, CuArray, blockDim, blockIdx, threadIdx
using DynamicExpressions: OperatorEnum, AbstractExpressionNode
using DynamicExpressions.EvaluateEquationModule: get_nbin, get_nuna
using DynamicExpressions.AsArrayModule: as_array

import DynamicExpressions.EvaluateEquationModule: eval_tree_array

# Array type used exclusively for testing purposes: wraps a plain CPU array so
# the `CuArray` code path can be exercised without GPU hardware.
struct FakeCuArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    a::A
end
Base.similar(x::FakeCuArray, dims::Integer...) = FakeCuArray(similar(x.a, dims...))
Base.getindex(x::FakeCuArray, i::Int...) = getindex(x.a, i...)
Base.setindex!(x::FakeCuArray, v, i::Int...) = setindex!(x.a, v, i...)
Base.size(x::FakeCuArray) = size(x.a)

const MaybeCuArray{T,N} = Union{CuArray{T,N},FakeCuArray{T,N}}

# Copy `a` onto the same device family as the exemplar (second) argument.
to_device(a, ::CuArray) = CuArray(a)
to_device(a, ::FakeCuArray) = FakeCuArray(a)

"""
    eval_tree_array(tree::AbstractExpressionNode{T}, gcX::MaybeCuArray{T,2}, operators; kws...)

Evaluate a single expression `tree` against the device-resident data matrix
`gcX`. Delegates to the multi-tree method with a one-tuple and unwraps the
single result with `only`.
"""
function eval_tree_array(
    tree::AbstractExpressionNode{T}, gcX::MaybeCuArray{T,2}, operators::OperatorEnum; kws...
) where {T<:Number}
    (outs, is_good) = eval_tree_array((tree,), gcX, operators; kws...)
    return (only(outs), only(is_good))
end

"""
    eval_tree_array(trees, gcX::MaybeCuArray{T,2}, operators; kws...)

Evaluate several expression trees at once on the GPU (or on the CPU when given
a `FakeCuArray`). The trees are flattened into flat index arrays via
`as_array`, copied to the device, and evaluated level-by-level with one kernel
launch per execution depth.

Keyword arguments allow callers to reuse previously-allocated host/device
buffers; passing `update_buffers=Val(false)` skips the `as_array` flattening
entirely, in which case `gpu_workspace`, `gpu_buffer`, `roots`, `num_nodes`,
and `num_launches` must all be supplied.

Returns `(out, is_good)` where `out` holds one view per root (each of length
`size(gcX, 2)`) and `is_good` is all-`true` (no NaN scan is performed here).
"""
function eval_tree_array(
    trees::Union{Tuple{N,Vararg{N}},AbstractVector{N}},
    gcX::MaybeCuArray{T,2},
    operators::OperatorEnum;
    buffer=nothing,
    gpu_workspace=nothing,
    gpu_buffer=nothing,
    roots=nothing,
    num_nodes=nothing,
    num_launches=nothing,
    update_buffers::Val{_update_buffers}=Val(true),
    kws...,
) where {T<:Number,N<:AbstractExpressionNode{T},_update_buffers}
    if _update_buffers
        (; val, roots, buffer, num_nodes, num_launches) = as_array(Int32, trees; buffer)
    end
    num_elem = size(gcX, 2)

    ## The following array is our "workspace" for
    ## the GPU kernel, with size equal to the number of rows
    ## in the input data by the number of nodes in the tree.
    ## It has one extra row to store the constant values.
    gworkspace = if gpu_workspace === nothing
        similar(gcX, num_elem + 1, num_nodes)
    else
        gpu_workspace
    end
    gval = @view gworkspace[end, :]
    if _update_buffers
        copyto!(gval, val)
    end

    ## Index arrays (much faster to have `@view` here)
    gbuffer = if !_update_buffers
        gpu_buffer
    elseif gpu_buffer === nothing
        to_device(buffer, gcX)
    else
        copyto!(gpu_buffer, buffer)
    end
    gdegree = @view gbuffer[1, :]
    gfeature = @view gbuffer[2, :]
    gop = @view gbuffer[3, :]
    gexecution_order = @view gbuffer[4, :]
    gidx_self = @view gbuffer[5, :]
    gidx_l = @view gbuffer[6, :]
    gidx_r = @view gbuffer[7, :]
    gconstant = @view gbuffer[8, :]

    num_threads = 256
    num_blocks = nextpow(2, ceil(Int, num_elem * num_nodes / num_threads))

    #! format: off
    _launch_gpu_kernel!(
        num_threads, num_blocks, num_launches, gworkspace,
        # Thread info:
        num_elem, num_nodes, gexecution_order,
        # Input data and tree
        operators, gcX, gidx_self, gidx_l, gidx_r,
        gdegree, gconstant, gval, gfeature, gop,
    )
    #! format: on

    # One output view per root node (the last workspace row holds constants,
    # so it is excluded).
    out = (r -> @view(gworkspace[begin:(end - 1), r])).(roots)
    is_good = (_ -> true).(trees)

    return (out, is_good)
end

"""
    _launch_gpu_kernel!(num_threads, num_blocks, num_launches, buffer, ...)

Run the evaluation kernel once per execution depth (`launch`). On a real
`CuArray` this uses `@cuda`; otherwise each "thread index" is simulated with
`Threads.@threads` so the identical kernel body is tested on the CPU.
"""
#! format: off
function _launch_gpu_kernel!(
    num_threads, num_blocks, num_launches::Integer, buffer::AbstractArray{T,2},
    # Thread info:
    num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray{I},
    # Input data and tree
    operators::OperatorEnum, cX::AbstractArray{T,2}, idx_self::AbstractArray, idx_l::AbstractArray, idx_r::AbstractArray,
    degree::AbstractArray, constant::AbstractArray, val::AbstractArray{T,1}, feature::AbstractArray, op::AbstractArray,
) where {I,T}
    #! format: on
    nuna = get_nuna(typeof(operators))
    nbin = get_nbin(typeof(operators))
    (nuna > 10 || nbin > 10) &&
        error("Too many operators. Kernels are only compiled up to 10.")
    gpu_kernel! = create_gpu_kernel(operators, Val(nuna), Val(nbin))
    for launch in one(I):I(num_launches)
        #! format: off
        if buffer isa CuArray
            @cuda threads=num_threads blocks=num_blocks gpu_kernel!(
                buffer,
                launch, num_elem, num_nodes, execution_order,
                cX, idx_self, idx_l, idx_r,
                degree, constant, val, feature, op
            )
        else
            Threads.@threads for i in 1:(num_threads * num_blocks)
                gpu_kernel!(
                    buffer,
                    launch, num_elem, num_nodes, execution_order,
                    cX, idx_self, idx_l, idx_r,
                    degree, constant, val, feature, op,
                    i
                )
            end
        end
        #! format: on
    end
    return nothing
end

# Need to pre-compute the GPU kernels with an `@eval` for each number of operators
# 1. We need to use an `@nif` over operators, as GPU kernels
#    can't index into arrays of operators.
# 2. `@nif` is evaluated at parse time and needs to know the number of
#    ifs to generate at that time, so we can't simply use specialization.
# 3. We can't use `@generated` because we can't create closures in those.
for nuna in 0:10, nbin in 0:10
    @eval function create_gpu_kernel(operators::OperatorEnum, ::Val{$nuna}, ::Val{$nbin})
        #! format: off
        function (
            # Storage:
            buffer,
            # Thread info:
            launch::Integer, num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray,
            # Input data and tree
            cX::AbstractArray, idx_self::AbstractArray, idx_l::AbstractArray, idx_r::AbstractArray,
            degree::AbstractArray, constant::AbstractArray, val::AbstractArray, feature::AbstractArray, op::AbstractArray,
            # Override for unittesting:
            i=nothing,
        )
            #! format: on
            # Global flat thread index; the override `i` lets the CPU path
            # drive the kernel explicitly.
            i = i === nothing ? (blockIdx().x - 1) * blockDim().x + threadIdx().x : i
            if i > num_elem * num_nodes
                return nothing
            end

            # Decompose the flat index into (data row, tree node):
            node = (i - 1) % num_nodes + 1
            elem = (i - node) ÷ num_nodes + 1

            # Only process nodes scheduled for this launch (children of a node
            # always have a strictly smaller execution order).
            if execution_order[node] != launch
                return nothing
            end

            cur_degree = degree[node]
            cur_idx = idx_self[node]
            if cur_degree == 0
                if constant[node] == 1
                    cur_val = val[node]
                    buffer[elem, cur_idx] = cur_val
                else
                    cur_feature = feature[node]
                    buffer[elem, cur_idx] = cX[cur_feature, elem]
                end
            else
                if cur_degree == 1 && $nuna > 0
                    cur_op = op[node]
                    l_idx = idx_l[node]
                    Base.Cartesian.@nif(
                        $nuna,
                        i -> i == cur_op,
                        i -> let op = operators.unaops[i]
                            buffer[elem, cur_idx] = op(buffer[elem, l_idx])
                        end
                    )
                elseif $nbin > 0  # Note this check is to avoid type inference issues when binops is empty
                    cur_op = op[node]
                    l_idx = idx_l[node]
                    r_idx = idx_r[node]
                    Base.Cartesian.@nif(
                        $nbin,
                        i -> i == cur_op,
                        i -> let op = operators.binops[i]
                            buffer[elem, cur_idx] = op(buffer[elem, l_idx], buffer[elem, r_idx])
                        end
                    )
                end
            end
            return nothing
        end
    end
end

end

src/AsArray.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
module AsArrayModule

using ..EquationModule: AbstractExpressionNode, tree_mapreduce, count_nodes

"""
    as_array(::Type{I}, trees; buffer=nothing)

Flatten one or more expression `trees` into a struct-of-arrays layout suitable
for flat (e.g. GPU) evaluation. All trees are packed into the same arrays,
one column per node, with `roots` recording the column of each tree's root.

# Arguments
- `I`: integer type used for all index arrays.
- `trees`: a single node, a tuple of nodes, or a vector of nodes.

# Keywords
- `buffer`: optional preallocated `8 × num_nodes` integer matrix to reuse.

# Returns
A named tuple with (views into `buffer` unless noted):
`degree`, `constant`, `val` (node constants, own array), `feature`, `op`,
`execution_order` (bottom-up level of each node), `num_launches` (maximum
execution order over all nodes), `idx_self`, `idx_l`, `idx_r`, `roots`,
`buffer`, and `num_nodes`.
"""
function as_array(
    ::Type{I},
    trees::Union{N,Tuple{N,Vararg{N}},AbstractVector{N}};
    buffer::Union{AbstractArray,Nothing}=nothing,
) where {T,N<:AbstractExpressionNode{T},I}
    if trees isa N
        return as_array(I, (trees,); buffer=buffer)
    end
    each_num_nodes = (t -> count_nodes(t; break_sharing=Val(true))).(trees)
    num_nodes = sum(each_num_nodes)

    # Want `roots` to be tuple if `trees` is tuple and similar for vector
    roots = cumsum(
        if each_num_nodes isa Tuple
            tuple(one(I), each_num_nodes[1:(end - 1)]...)
        else
            vcat(one(I), each_num_nodes[1:(end - 1)])
        end,
    )

    val = Array{T}(undef, num_nodes)

    ## Views of the same matrix:
    buffer = buffer === nothing ? Array{I}(undef, 8, num_nodes) : buffer
    degree = @view buffer[1, :]
    feature = @view buffer[2, :]
    op = @view buffer[3, :]
    execution_order = @view buffer[4, :]
    idx_self = @view buffer[5, :]
    idx_l = @view buffer[6, :]
    idx_r = @view buffer[7, :]
    constant = @view buffer[8, :]

    cursor = Ref(zero(I))
    num_launches = zero(I)
    for (root, tree) in zip(roots, trees)
        @assert root == cursor[] + 1
        tree_mapreduce(
            leaf -> begin
                self = (cursor[] += one(I))
                idx_self[self] = self
                degree[self] = 0
                # Leaves can always be evaluated in the first launch.
                execution_order[self] = one(I)
                constant[self] = leaf.constant
                if leaf.constant
                    val[self] = leaf.val::T
                else
                    feature[self] = leaf.feature
                end

                (id=self, order=one(I))
            end,
            branch -> begin
                self = (cursor[] += one(I))
                idx_self[self] = self
                op[self] = branch.op
                degree[self] = branch.degree

                (id=self, order=one(I))  # this order is unused
            end,
            ((parent, children::Vararg{Any,C}) where {C}) -> begin
                idx_l[parent.id] = children[1].id
                if C == 2
                    idx_r[parent.id] = children[2].id
                end
                # A parent runs one launch after its deepest child.
                parent_execution_order = if C == 1
                    children[1].order + one(I)
                else
                    max(children[1].order, children[2].order) + one(I)
                end
                execution_order[parent.id] = parent_execution_order

                # Global number of launches equal to maximum execution order
                if parent_execution_order > num_launches
                    num_launches = parent_execution_order
                end

                (id=parent.id, order=parent_execution_order)
            end,
            tree;
            break_sharing=Val(true),
        )
    end

    # BUG FIX: `num_launches` was only updated in the branch reducer above, so
    # a tree consisting of a single leaf (no branch nodes) left it at zero —
    # yet the leaf itself has execution order 1 and still needs one launch.
    # Clamp to at least one whenever any node exists.
    if num_nodes > 0
        num_launches = max(num_launches, one(I))
    end

    return (;
        degree,
        constant,
        val,
        feature,
        op,
        execution_order,
        num_launches,
        idx_self,
        idx_l,
        idx_r,
        roots,
        buffer,
        num_nodes,
    )
end

end

src/DynamicExpressions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include("EvaluationHelpers.jl")
1212
include("SimplifyEquation.jl")
1313
include("OperatorEnumConstruction.jl")
1414
include("Random.jl")
15+
include("AsArray.jl")
1516

1617
import PackageExtensionCompat: @require_extensions
1718
import Reexport: @reexport
@@ -49,6 +50,7 @@ import .EquationModule: constructorof, preserve_sharing
4950
@reexport import .EvaluationHelpersModule
5051
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
5152
@reexport import .RandomModule: NodeSampler
53+
@reexport import .AsArrayModule: as_array
5254

5355
function __init__()
5456
@require_extensions

0 commit comments

Comments
 (0)