
Commit 32410cb

Fixes #358 for the CUDA backend, by Claude Sonnet
Summary by Claude: ✅ Implementation Complete

What was done:
1. Converted CUDA builtins structure: transformed builtins_cuda.ml from a monolithic string format to the filtered triple format (key, definition, dependencies), similar to the Metal backend.
2. Updated CUDA backend: modified cuda_backend.ml to use the filtering mechanism instead of directly including all builtins.
3. Tested functionality: verified the implementation works correctly with the CUDA backend tests.
4. Confirmed filtering works: generated .cu files now contain only the necessary builtins, based on actual usage.

Key improvements:
- Reduced code bloat: generated CUDA files only include functions that are actually used.
- Better dependency management: builtin functions are included based on their dependencies.
- Consistent architecture: the CUDA backend now uses the same filtering approach as the Metal backend.
- Maintained compatibility: all existing CUDA functionality continues to work without changes.

The implementation follows the pattern established by the Metal backend and achieves the goal of filtering unused builtins from generated CUDA code, making the generated files more efficient and focused.

Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent c49fc06 commit 32410cb
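To make the filtering approach concrete, here is a minimal OCaml sketch of how a backend could consume the new (key, definition, dependencies) triples. The triple shape comes from the builtins_cuda.ml diff below; filter_builtins and the used argument are hypothetical names for illustration (plain Stdlib, not the actual cuda_backend.ml code), and the sketch assumes the builtins list orders every dependency before its users, as builtins_cuda.ml does.

(* Hypothetical sketch of dependency-closure filtering over the
   (key, definition, dependencies) triples introduced by this commit.
   [filter_builtins] and [used] are illustrative names, not the real
   cuda_backend.ml API. *)
let filter_builtins ~(used : string list)
    (builtins : (string * string * string list) list) : string =
  let by_key = Hashtbl.create 64 in
  List.iter (fun (key, def, deps) -> Hashtbl.replace by_key key (def, deps)) builtins;
  let needed = Hashtbl.create 64 in
  (* Mark every builtin reachable from the keys a kernel actually uses. *)
  let rec visit key =
    if not (Hashtbl.mem needed key) then
      match Hashtbl.find_opt by_key key with
      | None -> ()
      | Some (_, deps) ->
          Hashtbl.replace needed key ();
          List.iter visit deps
  in
  List.iter visit used;
  (* Keep the original list order, which already places dependencies
     (types, helpers) before the definitions that use them. *)
  builtins
  |> List.filter (fun (key, _, _) -> Hashtbl.mem needed key)
  |> List.map (fun (_, def, _) -> def)
  |> String.concat "\n\n"

For example, filter_builtins ~used:["arrayjit_threefry4x32"] builtins would emit only uint4x32_t, THREEFRY_C240, THREEFRY_ROTATION, rotl32, threefry_round, and arrayjit_threefry4x32.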

File tree

3 files changed: +120 -147 lines changed


arrayjit/lib/builtins_cuda.ml

Lines changed: 92 additions & 127 deletions
@@ -1,127 +1,103 @@
-let source =
-{|
-typedef struct {
+(* CUDA builtin code split into (key, definition, dependencies) triples for filtering *)
+let builtins = [
+("uint4x32_t", {|typedef struct {
 unsigned int v[4];
-} uint4x32_t;
-
-/* Vector types for efficient extraction of multiple values */
-typedef struct { float v[4]; } float4_t;
-typedef struct { double v[2]; } double2_t;
-typedef struct { int v[4]; } int32x4_t;
-typedef struct { long long v[2]; } int64x2_t;
-typedef struct { signed char v[16]; } int8x16_t;
-typedef struct { unsigned short v[8]; } uint16x8_t;
-typedef struct { unsigned char v[16]; } uint8x16_t;
-typedef struct { __half v[8]; } half8_t;
-
-/* Conversion functions from uint4x32 to various precisions uniformly */
-// These return vectors to efficiently use all random bits
-
-/* Convert to float in [0, 1) using CUDA intrinsics */
-__device__ __forceinline__ float uint32_to_single_uniform(unsigned int x) {
+} uint4x32_t;|}, []);
+
+("float4_t", {|typedef struct { float v[4]; } float4_t;|}, []);
+("double2_t", {|typedef struct { double v[2]; } double2_t;|}, []);
+("int32x4_t", {|typedef struct { int v[4]; } int32x4_t;|}, []);
+("int64x2_t", {|typedef struct { long long v[2]; } int64x2_t;|}, []);
+("int8x16_t", {|typedef struct { signed char v[16]; } int8x16_t;|}, []);
+("uint16x8_t", {|typedef struct { unsigned short v[8]; } uint16x8_t;|}, []);
+("uint8x16_t", {|typedef struct { unsigned char v[16]; } uint8x16_t;|}, []);
+("half8_t", {|typedef struct { __half v[8]; } half8_t;|}, []);
+
+("uint32_to_single_uniform", {|__device__ __forceinline__ float uint32_to_single_uniform(unsigned int x) {
 /* Use __uint2float_rn for correct rounding */
 return __uint2float_rn(x >> 8) * (1.0f / 16777216.0f);
-}
+}|}, []);
 
-/* Convert to double in [0, 1) */
-__device__ double uint32_to_double_uniform(unsigned int x) {
+("uint32_to_double_uniform", {|__device__ double uint32_to_double_uniform(unsigned int x) {
 return __uint2double_rn(x) * (1.0 / 4294967296.0);
-}
+}|}, []);
 
-/* Uint4x32 to float32 uniform */
-__device__ float uint4x32_to_single_uniform(uint4x32_t x) {
+("uint4x32_to_single_uniform", {|__device__ float uint4x32_to_single_uniform(uint4x32_t x) {
 return uint32_to_single_uniform(x.v[0]);
-}
+}|}, ["uint4x32_t"; "uint32_to_single_uniform"]);
 
-/* Uint4x32 to float64 uniform */
-__device__ double uint4x32_to_double_uniform(uint4x32_t x) {
+("uint4x32_to_double_uniform", {|__device__ double uint4x32_to_double_uniform(uint4x32_t x) {
 unsigned long long combined = __double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
 return __longlong_as_double(combined) * (1.0 / 18446744073709551616.0);
-}
+}|}, ["uint4x32_t"]);
 
-/* Uint4x32 to int32 uniform */
-__device__ int uint4x32_to_int32_uniform(uint4x32_t x) {
+("uint4x32_to_int32_uniform", {|__device__ int uint4x32_to_int32_uniform(uint4x32_t x) {
 return (int)x.v[0];
-}
+}|}, ["uint4x32_t"]);
 
-/* Uint4x32 to int64 uniform */
-__device__ long long uint4x32_to_i64_uniform(uint4x32_t x) {
+("uint4x32_to_i64_uniform", {|__device__ long long uint4x32_to_i64_uniform(uint4x32_t x) {
 return __double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
-}
+}|}, ["uint4x32_t"]);
 
-/* Uint4x32 to uint32 uniform */
-__device__ unsigned int uint4x32_to_u32_uniform(uint4x32_t x) {
+("uint4x32_to_u32_uniform", {|__device__ unsigned int uint4x32_to_u32_uniform(uint4x32_t x) {
 return x.v[0];
-}
+}|}, ["uint4x32_t"]);
 
-/* Uint4x32 to uint64 uniform */
-__device__ unsigned long long uint4x32_to_u64_uniform(uint4x32_t x) {
+("uint4x32_to_u64_uniform", {|__device__ unsigned long long uint4x32_to_u64_uniform(uint4x32_t x) {
 return (unsigned long long)__double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
-}
+}|}, ["uint4x32_t"]);
 
-/* Uint4x32 to int8 uniform */
-__device__ signed char uint4x32_to_i8_uniform(uint4x32_t x) {
+("uint4x32_to_i8_uniform", {|__device__ signed char uint4x32_to_i8_uniform(uint4x32_t x) {
 return (signed char)(x.v[0] & 0xFF);
-}
+}|}, ["uint4x32_t"]);
 
-/* Uint4x32 to uint8 uniform */
-__device__ unsigned char uint4x32_to_u8_uniform(uint4x32_t x) {
+("uint4x32_to_u8_uniform", {|__device__ unsigned char uint4x32_to_u8_uniform(uint4x32_t x) {
 return (unsigned char)(x.v[0] & 0xFF);
-}
+}|}, ["uint4x32_t"]);
 
-/* Uint4x32 to bfloat16 uniform */
-__device__ unsigned short uint4x32_to_bfloat16_uniform(uint4x32_t x) {
+("uint4x32_to_bfloat16_uniform", {|__device__ unsigned short uint4x32_to_bfloat16_uniform(uint4x32_t x) {
 float f = uint32_to_single_uniform(x.v[0]);
 return (unsigned short)(__float_as_uint(f) >> 16);
-}
+}|}, ["uint4x32_t"; "uint32_to_single_uniform"]);
 
-/* Uint4x32 to float16 uniform using CUDA half intrinsics */
-__device__ __half uint4x32_to_half_uniform(uint4x32_t x) {
+("uint4x32_to_half_uniform", {|__device__ __half uint4x32_to_half_uniform(uint4x32_t x) {
 float f = uint32_to_single_uniform(x.v[0]);
 return __float2half(f);
-}
+}|}, ["uint4x32_t"; "uint32_to_single_uniform"]);
 
-/* Vectorized conversion functions that use all 128 bits efficiently */
-
-/* Convert uint4x32 to 4 floats in [0, 1) */
-__device__ float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
+("uint4x32_to_single_uniform_vec", {|__device__ float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
 float4_t result;
 #pragma unroll
 for (int i = 0; i < 4; i++) {
 result.v[i] = uint32_to_single_uniform(x.v[i]);
 }
 return result;
-}
+}|}, ["uint4x32_t"; "float4_t"; "uint32_to_single_uniform"]);
 
-/* Convert uint4x32 to 2 doubles in [0, 1) */
-__device__ double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
+("uint4x32_to_double_uniform_vec", {|__device__ double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
 double2_t result;
 result.v[0] = __longlong_as_double(__double_as_longlong(__hiloint2double(x.v[1], x.v[0]))) * (1.0 / 18446744073709551616.0);
 result.v[1] = __longlong_as_double(__double_as_longlong(__hiloint2double(x.v[3], x.v[2]))) * (1.0 / 18446744073709551616.0);
 return result;
-}
+}|}, ["uint4x32_t"; "double2_t"]);
 
-/* Convert uint4x32 to 4 int32s - full range */
-__device__ int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
+("uint4x32_to_int32_uniform_vec", {|__device__ int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
 int32x4_t result;
 #pragma unroll
 for (int i = 0; i < 4; i++) {
 result.v[i] = (int)x.v[i];
 }
 return result;
-}
+}|}, ["uint4x32_t"; "int32x4_t"]);
 
-/* Convert uint4x32 to 2 int64s - full range */
-__device__ int64x2_t uint4x32_to_i64_uniform_vec(uint4x32_t x) {
+("uint4x32_to_i64_uniform_vec", {|__device__ int64x2_t uint4x32_to_i64_uniform_vec(uint4x32_t x) {
 int64x2_t result;
 result.v[0] = __double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
 result.v[1] = __double_as_longlong(__hiloint2double(x.v[3], x.v[2]));
 return result;
-}
-
+}|}, ["uint4x32_t"; "int64x2_t"]);
 
-/* Convert uint4x32 to 16 int8s - full range */
-__device__ int8x16_t uint4x32_to_i8_uniform_vec(uint4x32_t x) {
+("uint4x32_to_i8_uniform_vec", {|__device__ int8x16_t uint4x32_to_i8_uniform_vec(uint4x32_t x) {
 int8x16_t result;
 #pragma unroll
 for (int i = 0; i < 4; i++) {
@@ -131,21 +107,19 @@ __device__ int8x16_t uint4x32_to_i8_uniform_vec(uint4x32_t x) {
 result.v[i*4 + 3] = (signed char)((x.v[i] >> 24) & 0xFF);
 }
 return result;
-}
+}|}, ["uint4x32_t"; "int8x16_t"]);
 
-/* Convert uint4x32 to 8 uint16s - full range */
-__device__ uint16x8_t uint4x32_to_u16_uniform_vec(uint4x32_t x) {
+("uint4x32_to_u16_uniform_vec", {|__device__ uint16x8_t uint4x32_to_u16_uniform_vec(uint4x32_t x) {
 uint16x8_t result;
 #pragma unroll
 for (int i = 0; i < 4; i++) {
 result.v[i*2 + 0] = (unsigned short)(x.v[i] & 0xFFFF);
 result.v[i*2 + 1] = (unsigned short)((x.v[i] >> 16) & 0xFFFF);
 }
 return result;
-}
+}|}, ["uint4x32_t"; "uint16x8_t"]);
 
-/* Convert uint4x32 to 8 bfloat16s uniform */
-__device__ uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
+("uint4x32_to_bfloat16_uniform_vec", {|__device__ uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
 uint16x8_t result;
 #pragma unroll
 for (int i = 0; i < 4; i++) {
@@ -156,10 +130,9 @@ __device__ uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
 result.v[i*2 + 1] = (unsigned short)(__float_as_uint(f2) >> 16);
 }
 return result;
-}
+}|}, ["uint4x32_t"; "uint16x8_t"]);
 
-/* Convert uint4x32 to 8 float16s uniform */
-__device__ half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
+("uint4x32_to_half_uniform_vec", {|__device__ half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
 half8_t result;
 #pragma unroll
 for (int i = 0; i < 4; i++) {
@@ -169,10 +142,9 @@ __device__ half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
 result.v[i*2 + 1] = __float2half(f2);
 }
 return result;
-}
+}|}, ["uint4x32_t"; "half8_t"]);
 
-/* Convert uint4x32 to 16 uint8s uniform */
-__device__ uint8x16_t uint4x32_to_u8_uniform_vec(uint4x32_t x) {
+("uint4x32_to_u8_uniform_vec", {|__device__ uint8x16_t uint4x32_to_u8_uniform_vec(uint4x32_t x) {
 uint8x16_t result;
 #pragma unroll
 for (int i = 0; i < 4; i++) {
@@ -182,74 +154,70 @@ __device__ uint8x16_t uint4x32_to_u8_uniform_vec(uint4x32_t x) {
 result.v[i*4 + 3] = (unsigned char)((x.v[i] >> 24) & 0xFF);
 }
 return result;
-}
-
-/* Convert int64 to uint4x32 */
-__device__ uint4x32_t int64_to_uint4x32(long long x) {
-unsigned long long bits = (unsigned long long)x;
-uint4x32_t result = {{(unsigned int)(bits & 0xFFFFFFFF), (unsigned int)(bits >> 32), 0, 0}};
-return result;
-}
+}|}, ["uint4x32_t"; "uint8x16_t"]);
 
-/* Conversion functions from various precisions to uint4x32_t */
-__device__ uint4x32_t single_to_uint4x32(float x) {
+("single_to_uint4x32", {|__device__ uint4x32_t single_to_uint4x32(float x) {
 unsigned int bits = __float_as_uint(x);
 uint4x32_t result = {{bits, 0, 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t double_to_uint4x32(double x) {
+("double_to_uint4x32", {|__device__ uint4x32_t double_to_uint4x32(double x) {
 unsigned long long bits = __double_as_longlong(x);
 uint4x32_t result = {{(unsigned int)(bits & 0xFFFFFFFF), (unsigned int)(bits >> 32), 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t int32_to_uint4x32(int x) {
+("int32_to_uint4x32", {|__device__ uint4x32_t int32_to_uint4x32(int x) {
 uint4x32_t result = {{(unsigned int)x, 0, 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
+
+("int64_to_uint4x32", {|__device__ uint4x32_t int64_to_uint4x32(long long x) {
+unsigned long long bits = (unsigned long long)x;
+uint4x32_t result = {{(unsigned int)(bits & 0xFFFFFFFF), (unsigned int)(bits >> 32), 0, 0}};
+return result;
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t uint32_to_uint4x32(unsigned int x) {
+("uint32_to_uint4x32", {|__device__ uint4x32_t uint32_to_uint4x32(unsigned int x) {
 uint4x32_t result = {{x, 0, 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t uint64_to_uint4x32(unsigned long long x) {
+("uint64_to_uint4x32", {|__device__ uint4x32_t uint64_to_uint4x32(unsigned long long x) {
 uint4x32_t result = {{(unsigned int)(x & 0xFFFFFFFF), (unsigned int)(x >> 32), 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t byte_to_uint4x32(unsigned char x) {
+("byte_to_uint4x32", {|__device__ uint4x32_t byte_to_uint4x32(unsigned char x) {
 uint4x32_t result = {{(unsigned int)x, 0, 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t uint16_to_uint4x32(unsigned short x) {
+("uint16_to_uint4x32", {|__device__ uint4x32_t uint16_to_uint4x32(unsigned short x) {
 uint4x32_t result = {{(unsigned int)x, 0, 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t bfloat16_to_uint4x32(unsigned short x) {
+("bfloat16_to_uint4x32", {|__device__ uint4x32_t bfloat16_to_uint4x32(unsigned short x) {
 uint4x32_t result = {{(unsigned int)x, 0, 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t half_to_uint4x32(__half x) {
+("half_to_uint4x32", {|__device__ uint4x32_t half_to_uint4x32(__half x) {
 unsigned short bits = __half_as_ushort(x);
 uint4x32_t result = {{(unsigned int)bits, 0, 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-__device__ uint4x32_t fp8_to_uint4x32(unsigned char x) {
+("fp8_to_uint4x32", {|__device__ uint4x32_t fp8_to_uint4x32(unsigned char x) {
 uint4x32_t result = {{(unsigned int)x, 0, 0, 0}};
 return result;
-}
+}|}, ["uint4x32_t"]);
 
-/* Threefry4x32 constants */
-__device__ __constant__ unsigned int THREEFRY_C240 = 0x1BD11BDA;
+("THREEFRY_C240", {|__device__ __constant__ unsigned int THREEFRY_C240 = 0x1BD11BDA;|}, []);
 
-/* Rotation constants for Threefry4x32 */
-__device__ __constant__ unsigned int THREEFRY_ROTATION[8][4] = {
+("THREEFRY_ROTATION", {|__device__ __constant__ unsigned int THREEFRY_ROTATION[8][4] = {
 {13, 15, 26, 6},
 {17, 29, 16, 24},
 {13, 15, 26, 6},
@@ -258,15 +226,13 @@ __device__ __constant__ unsigned int THREEFRY_ROTATION[8][4] = {
 {17, 29, 16, 24},
 {13, 15, 26, 6},
 {17, 29, 16, 24}
-};
+};|}, []);
 
-/* CUDA intrinsic-based rotate left */
-__device__ __forceinline__ unsigned int rotl32(unsigned int x, unsigned int n) {
+("rotl32", {|__device__ __forceinline__ unsigned int rotl32(unsigned int x, unsigned int n) {
 return __funnelshift_l(x, x, n);
-}
+}|}, []);
 
-/* Threefry4x32 round function using vector operations */
-__device__ __forceinline__ void threefry_round(uint4 &x, unsigned int r0, unsigned int r1, unsigned int r2, unsigned int r3) {
+("threefry_round", {|__device__ __forceinline__ void threefry_round(uint4 &x, unsigned int r0, unsigned int r1, unsigned int r2, unsigned int r3) {
 x.x += x.y; x.y = rotl32(x.y, r0); x.y ^= x.x;
 x.z += x.w; x.w = rotl32(x.w, r1); x.w ^= x.z;

@@ -280,10 +246,9 @@ __device__ __forceinline__ void threefry_round(uint4 &x, unsigned int r0, unsign
 tmp = x.y;
 x.y = x.w;
 x.w = tmp;
-}
+}|}, ["rotl32"]);
 
-/* Threefry4x32 implementation - 20 rounds */
-__device__ uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
+("arrayjit_threefry4x32", {|__device__ uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
 uint4 x = make_uint4(counter.v[0], counter.v[1], counter.v[2], counter.v[3]);
 uint4 k = make_uint4(key.v[0], key.v[1], key.v[2], key.v[3]);

@@ -345,5 +310,5 @@ __device__ uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter)
 result.v[2] = x.z;
 result.v[3] = x.w;
 return result;
-}
-|}
+}|}, ["uint4x32_t"; "THREEFRY_C240"; "threefry_round"; "THREEFRY_ROTATION"]);
+]
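The commit message notes that generated .cu files now contain only the builtins that are actually referenced. One plausible way to obtain the used set fed into a filter like the sketch above is a simple textual scan of the generated kernel source for each builtin key; this snippet is again only an illustration (the real cuda_backend.ml may track usage during code generation instead), and used_builtin_keys is a hypothetical name.

(* Hypothetical detection of referenced builtins by scanning the generated
   CUDA source for each key; illustrative only, not the actual backend code. *)
let used_builtin_keys ~(kernel_source : string)
    (builtins : (string * string * string list) list) : string list =
  let contains hay needle =
    let lh = String.length hay and ln = String.length needle in
    let rec go i = i + ln <= lh && (String.sub hay i ln = needle || go (i + 1)) in
    ln > 0 && go 0
  in
  List.filter_map
    (fun (key, _, _) -> if contains kernel_source key then Some key else None)
    builtins

A kernel whose body only calls uint4x32_to_single_uniform would then pull in just uint4x32_t, uint32_to_single_uniform, and the caller itself via the dependency closure.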
