
Commit c49fc06

Fixes #358 for the Metal backend, by Claude Sonnet
Summary by Claude: I have converted the Metal backend to use the builtin filtering mechanism:

1. Split Metal builtins: converted builtins_metal.ml from a monolithic string into structured (key, definition, dependencies) triples
2. Updated Metal backend: modified the compile and compile_batch functions to use Syntax.filter_and_prepend_builtins (a sketch of the filtering idea follows below)
3. Fixed dune dependency: added the missing (env_var OCANNL_BACKEND) dependency to the test diff rule
4. Verified functionality: the test passes, showing that only necessary code is included in the generated .metal files

The Metal backend now joins the CC backend in having efficient, filtered builtin inclusion, reducing compilation time and generated code size by including only the builtin functions actually used by the kernels being compiled.
1 parent 24fddb4 commit c49fc06
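
For illustration only: the filtering the summary refers to can be thought of as a dependency-closure walk over the (key, definition, deps) triples. The sketch below assumes a naive substring test for whether a kernel mentions a builtin; the real Syntax.filter_and_prepend_builtins in OCANNL may differ in signature and in how usage is detected.

(* Illustrative sketch, not the OCANNL implementation: seed the set of needed
   builtins from names mentioned in the kernel source, close over the declared
   dependencies, then emit the kept definitions in declaration order. *)
let filter_builtins builtins ~source =
  let used = Hashtbl.create 16 in
  let rec mark key =
    if not (Hashtbl.mem used key) then begin
      Hashtbl.add used key ();
      (* Pull in transitive dependencies, e.g. arrayjit_threefry4x32 needs rotl32 via threefry_round. *)
      match List.find_opt (fun (k, _, _) -> String.equal k key) builtins with
      | Some (_, _, deps) -> List.iter mark deps
      | None -> ()
    end
  in
  (* Naive "does the generated kernel mention this key?" check. *)
  let contains ~needle haystack =
    let n = String.length needle and h = String.length haystack in
    let rec go i = i + n <= h && (String.sub haystack i n = needle || go (i + 1)) in
    n > 0 && go 0
  in
  List.iter (fun (key, _, _) -> if contains ~needle:key source then mark key) builtins;
  (* Declaration order is preserved, so definitions emitted earlier in the list precede their users. *)
  List.filter_map
    (fun (key, definition, _) -> if Hashtbl.mem used key then Some definition else None)
    builtins

A real implementation would additionally keep always-needed entries such as METAL_HEADERS, which no kernel mentions by name; that detail is omitted from this sketch.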

File tree

4 files changed: +132 −150 lines changed


arrayjit/lib/builtins_metal.ml

Lines changed: 104 additions & 137 deletions
@@ -1,28 +1,24 @@
-let source =
-  {|
-#include <metal_stdlib>
-using namespace metal;
-
-/* Threefry4x32 constants */
-constant uint32_t THREEFRY_C240 = 0x1BD11BDA;
-
-/* Rotation constants for Threefry4x32 */
-constant uint THREEFRY_ROTATION_0_0 = 13;
-constant uint THREEFRY_ROTATION_0_1 = 15;
-constant uint THREEFRY_ROTATION_0_2 = 26;
-constant uint THREEFRY_ROTATION_0_3 = 6;
-constant uint THREEFRY_ROTATION_1_0 = 17;
-constant uint THREEFRY_ROTATION_1_1 = 29;
-constant uint THREEFRY_ROTATION_1_2 = 16;
-constant uint THREEFRY_ROTATION_1_3 = 24;
-
-/* Metal rotate left using built-in rotate function */
-inline uint32_t rotl32(uint32_t x, uint n) {
+(* Metal builtin code split into (key, definition, dependencies) triples for filtering *)
+let builtins = [
+  ("METAL_HEADERS", {|#include <metal_stdlib>
+using namespace metal;|}, []);
+
+  ("THREEFRY_C240", {|constant uint32_t THREEFRY_C240 = 0x1BD11BDA;|}, []);
+
+  ("THREEFRY_ROTATION_0_0", {|constant uint THREEFRY_ROTATION_0_0 = 13;|}, []);
+  ("THREEFRY_ROTATION_0_1", {|constant uint THREEFRY_ROTATION_0_1 = 15;|}, []);
+  ("THREEFRY_ROTATION_0_2", {|constant uint THREEFRY_ROTATION_0_2 = 26;|}, []);
+  ("THREEFRY_ROTATION_0_3", {|constant uint THREEFRY_ROTATION_0_3 = 6;|}, []);
+  ("THREEFRY_ROTATION_1_0", {|constant uint THREEFRY_ROTATION_1_0 = 17;|}, []);
+  ("THREEFRY_ROTATION_1_1", {|constant uint THREEFRY_ROTATION_1_1 = 29;|}, []);
+  ("THREEFRY_ROTATION_1_2", {|constant uint THREEFRY_ROTATION_1_2 = 16;|}, []);
+  ("THREEFRY_ROTATION_1_3", {|constant uint THREEFRY_ROTATION_1_3 = 24;|}, []);
+
+  ("rotl32", {|inline uint32_t rotl32(uint32_t x, uint n) {
   return rotate(x, n);
-}
+}|}, []);
 
-/* Threefry4x32 round function using SIMD operations */
-inline void threefry_round(thread uint4 &x, uint r0, uint r1, uint r2, uint r3) {
+  ("threefry_round", {|inline void threefry_round(thread uint4 &x, uint r0, uint r1, uint r2, uint r3) {
   x.x += x.y; x.y = rotl32(x.y, r0); x.y ^= x.x;
   x.z += x.w; x.w = rotl32(x.w, r1); x.w ^= x.z;
 
@@ -36,10 +32,9 @@ inline void threefry_round(thread uint4 &x, uint r0, uint r1, uint r2, uint r3)
   tmp = x.y;
   x.y = x.w;
   x.w = tmp;
-}
+}|}, ["rotl32"]);
 
-/* Threefry4x32 implementation - 20 rounds */
-uint4 arrayjit_threefry4x32(uint4 key, uint4 counter) {
+  ("arrayjit_threefry4x32", {|uint4 arrayjit_threefry4x32(uint4 key, uint4 counter) {
   uint4 x = counter;
   uint4 k = key;
 
@@ -124,138 +119,115 @@ uint4 arrayjit_threefry4x32(uint4 key, uint4 counter) {
   x.w += 5;
 
   return x;
-}
-
-/* Vector types for efficient extraction of multiple values */
-struct float4_t { float4 v; };
-struct float2_t { float2 v; }; /* Using float2 since Metal lacks double */
-struct int32x4_t { int4 v; };
-struct int64x2_t { int64_t v[2]; };
-struct uint64x2_t { uint64_t v[2]; };
-struct int8x16_t { int8_t v[16]; };
-struct uint16x8_t { uint16_t v[8]; };
-struct uint8x16_t { uint8_t v[16]; };
-struct half8_t { half v[8]; };
-
-/* Conversion functions from uint4x32 to various precisions uniformly */
-// These return vectors to efficiently use all random bits
-
-/* Convert to float in [0, 1) */
-inline float uint32_to_single_uniform(uint32_t x) {
+}|}, ["THREEFRY_C240"; "threefry_round"; "THREEFRY_ROTATION_0_0"; "THREEFRY_ROTATION_0_1";
+      "THREEFRY_ROTATION_0_2"; "THREEFRY_ROTATION_0_3"; "THREEFRY_ROTATION_1_0";
+      "THREEFRY_ROTATION_1_1"; "THREEFRY_ROTATION_1_2"; "THREEFRY_ROTATION_1_3"]);
+
+  ("float4_t", {|struct float4_t { float4 v; };|}, []);
+  ("float2_t", {|struct float2_t { float2 v; };|}, []);
+  ("int32x4_t", {|struct int32x4_t { int4 v; };|}, []);
+  ("int64x2_t", {|struct int64x2_t { int64_t v[2]; };|}, []);
+  ("uint64x2_t", {|struct uint64x2_t { uint64_t v[2]; };|}, []);
+  ("int8x16_t", {|struct int8x16_t { int8_t v[16]; };|}, []);
+  ("uint16x8_t", {|struct uint16x8_t { uint16_t v[8]; };|}, []);
+  ("uint8x16_t", {|struct uint8x16_t { uint8_t v[16]; };|}, []);
+  ("half8_t", {|struct half8_t { half v[8]; };|}, []);
+
+  ("uint32_to_single_uniform", {|inline float uint32_to_single_uniform(uint32_t x) {
   return (x >> 8) * (1.0f / 16777216.0f);
-}
+}|}, []);
 
-/* Uint4x32 to float32 uniform */
-float uint4x32_to_single_uniform(uint4 x) {
+  ("uint4x32_to_single_uniform", {|float uint4x32_to_single_uniform(uint4 x) {
   return uint32_to_single_uniform(x.x);
-}
+}|}, ["uint32_to_single_uniform"]);
 
-/* Uint4x32 to float64 uniform - Metal doesn't have native double support */
-float uint4x32_to_double_uniform(uint4 x) {
+  ("uint4x32_to_double_uniform", {|float uint4x32_to_double_uniform(uint4 x) {
   /* Fallback to float precision */
   uint64_t combined = (uint64_t(x.y) << 32) | x.x;
   return float(combined) * (1.0f / 18446744073709551616.0f);
-}
+}|}, []);
 
-/* Uint4x32 to int32 uniform */
-int32_t uint4x32_to_int32_uniform(uint4 x) {
+  ("uint4x32_to_int32_uniform", {|int32_t uint4x32_to_int32_uniform(uint4 x) {
   return int32_t(x.x);
-}
+}|}, []);
 
-/* Uint4x32 to int64 uniform */
-int64_t uint4x32_to_int64_uniform(uint4 x) {
+  ("uint4x32_to_int64_uniform", {|int64_t uint4x32_to_int64_uniform(uint4 x) {
   return int64_t((uint64_t(x.y) << 32) | x.x);
-}
+}|}, []);
 
-/* Uint4x32 to uint32 uniform */
-uint32_t uint4x32_to_uint32_uniform(uint4 x) {
+  ("uint4x32_to_uint32_uniform", {|uint32_t uint4x32_to_uint32_uniform(uint4 x) {
   return x.x;
-}
+}|}, []);
 
-/* Uint4x32 to uint64 uniform */
-uint64_t uint4x32_to_uint64_uniform(uint4 x) {
+  ("uint4x32_to_uint64_uniform", {|uint64_t uint4x32_to_uint64_uniform(uint4 x) {
   return (uint64_t(x.y) << 32) | x.x;
-}
+}|}, []);
 
-/* Uint4x32 to byte uniform */
-int8_t uint4x32_to_byte_uniform(uint4 x) {
+  ("uint4x32_to_byte_uniform", {|int8_t uint4x32_to_byte_uniform(uint4 x) {
   return int8_t(x.x & 0xFF);
-}
+}|}, []);
 
-/* Uint4x32 to uint16 uniform */
-uint16_t uint4x32_to_uint16_uniform(uint4 x) {
+  ("uint4x32_to_uint16_uniform", {|uint16_t uint4x32_to_uint16_uniform(uint4 x) {
   return uint16_t(x.x & 0xFFFF);
-}
+}|}, []);
 
-/* Uint4x32 to bfloat16 uniform */
-uint16_t uint4x32_to_bfloat16_uniform(uint4 x) {
+  ("uint4x32_to_bfloat16_uniform", {|uint16_t uint4x32_to_bfloat16_uniform(uint4 x) {
   float f = uint32_to_single_uniform(x.x);
   return uint16_t(as_type<uint32_t>(f) >> 16);
-}
+}|}, ["uint32_to_single_uniform"]);
 
-/* Uint4x32 to float16 uniform */
-half uint4x32_to_half_uniform(uint4 x) {
+  ("uint4x32_to_half_uniform", {|half uint4x32_to_half_uniform(uint4 x) {
   float f = uint32_to_single_uniform(x.x);
   return half(f);
-}
+}|}, ["uint32_to_single_uniform"]);
 
-/* Uint4x32 to fp8 uniform */
-uint8_t uint4x32_to_fp8_uniform(uint4 x) {
+  ("uint4x32_to_fp8_uniform", {|uint8_t uint4x32_to_fp8_uniform(uint4 x) {
   return uint8_t(x.x & 0xFF);
-}
+}|}, []);
 
-/* Vectorized conversion functions that use all 128 bits efficiently */
-
-/* Convert uint4x32 to 4 floats in [0, 1) */
-float4_t uint4x32_to_single_uniform_vec(uint4 x) {
+  ("uint4x32_to_single_uniform_vec", {|float4_t uint4x32_to_single_uniform_vec(uint4 x) {
   float4_t result;
   result.v.x = uint32_to_single_uniform(x.x);
   result.v.y = uint32_to_single_uniform(x.y);
   result.v.z = uint32_to_single_uniform(x.z);
   result.v.w = uint32_to_single_uniform(x.w);
   return result;
-}
+}|}, ["float4_t"; "uint32_to_single_uniform"]);
 
-/* Convert uint4x32 to 2 floats in [0, 1) - Metal lacks double precision */
-float2_t uint4x32_to_double_uniform_vec(uint4 x) {
+  ("uint4x32_to_double_uniform_vec", {|float2_t uint4x32_to_double_uniform_vec(uint4 x) {
   float2_t result;
   uint64_t combined1 = (uint64_t(x.y) << 32) | x.x;
   uint64_t combined2 = (uint64_t(x.w) << 32) | x.z;
   result.v.x = float(combined1) * (1.0f / 18446744073709551616.0f);
   result.v.y = float(combined2) * (1.0f / 18446744073709551616.0f);
   return result;
-}
+}|}, ["float2_t"]);
 
-/* Convert uint4x32 to 4 int32s - full range */
-int32x4_t uint4x32_to_int32_uniform_vec(uint4 x) {
+  ("uint4x32_to_int32_uniform_vec", {|int32x4_t uint4x32_to_int32_uniform_vec(uint4 x) {
   int32x4_t result;
   result.v = int4(x);
   return result;
-}
+}|}, ["int32x4_t"]);
 
-/* Convert uint4x32 to 2 int64s - full range */
-int64x2_t uint4x32_to_int64_uniform_vec(uint4 x) {
+  ("uint4x32_to_int64_uniform_vec", {|int64x2_t uint4x32_to_int64_uniform_vec(uint4 x) {
   int64x2_t result;
   result.v[0] = (int64_t(x.y) << 32) | x.x;
   result.v[1] = (int64_t(x.w) << 32) | x.z;
   return result;
-}
+}|}, ["int64x2_t"]);
 
-/* Convert uint4x32 to 4 uint32s - full range */
-uint4 uint4x32_to_uint32_uniform_vec(uint4 x) {
+  ("uint4x32_to_uint32_uniform_vec", {|uint4 uint4x32_to_uint32_uniform_vec(uint4 x) {
   return x;
-}
+}|}, []);
 
-/* Convert uint4x32 to 2 uint64s - full range */
-uint64x2_t uint4x32_to_uint64_uniform_vec(uint4 x) {
+  ("uint4x32_to_uint64_uniform_vec", {|uint64x2_t uint4x32_to_uint64_uniform_vec(uint4 x) {
   uint64x2_t result;
   result.v[0] = (uint64_t(x.y) << 32) | x.x;
   result.v[1] = (uint64_t(x.w) << 32) | x.z;
   return result;
-}
+}|}, ["uint64x2_t"]);
 
-/* Convert uint4x32 to 16 int8s - full range */
-int8x16_t uint4x32_to_byte_uniform_vec(uint4 x) {
+  ("uint4x32_to_byte_uniform_vec", {|int8x16_t uint4x32_to_byte_uniform_vec(uint4 x) {
   int8x16_t result;
   uint4 v = x;
   for (int i = 0; i < 4; i++) {
@@ -266,10 +238,9 @@ int8x16_t uint4x32_to_byte_uniform_vec(uint4 x) {
     result.v[i*4 + 3] = int8_t((val >> 24) & 0xFF);
   }
   return result;
-}
+}|}, ["int8x16_t"]);
 
-/* Convert uint4x32 to 8 uint16s - full range */
-uint16x8_t uint4x32_to_uint16_uniform_vec(uint4 x) {
+  ("uint4x32_to_uint16_uniform_vec", {|uint16x8_t uint4x32_to_uint16_uniform_vec(uint4 x) {
   uint16x8_t result;
   uint4 v = x;
   for (int i = 0; i < 4; i++) {
@@ -278,10 +249,9 @@ uint16x8_t uint4x32_to_uint16_uniform_vec(uint4 x) {
     result.v[i*2 + 1] = uint16_t((val >> 16) & 0xFFFF);
   }
   return result;
-}
+}|}, ["uint16x8_t"]);
 
-/* Convert uint4x32 to 8 bfloat16s uniform */
-uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4 x) {
+  ("uint4x32_to_bfloat16_uniform_vec", {|uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4 x) {
   uint16x8_t result;
   uint4 v = x;
   for (int i = 0; i < 4; i++) {
@@ -292,10 +262,9 @@ uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4 x) {
     result.v[i*2 + 1] = uint16_t(as_type<uint32_t>(f2) >> 16);
   }
   return result;
-}
+}|}, ["uint16x8_t"]);
 
-/* Convert uint4x32 to 8 float16s uniform */
-half8_t uint4x32_to_half_uniform_vec(uint4 x) {
+  ("uint4x32_to_half_uniform_vec", {|half8_t uint4x32_to_half_uniform_vec(uint4 x) {
   half8_t result;
   uint4 v = x;
   for (int i = 0; i < 4; i++) {
@@ -306,10 +275,9 @@ half8_t uint4x32_to_half_uniform_vec(uint4 x) {
     result.v[i*2 + 1] = half(f2);
   }
   return result;
-}
+}|}, ["half8_t"]);
 
-/* Convert uint4x32 to 16 fp8s uniform */
-uint8x16_t uint4x32_to_fp8_uniform_vec(uint4 x) {
+  ("uint4x32_to_fp8_uniform_vec", {|uint8x16_t uint4x32_to_fp8_uniform_vec(uint4 x) {
   uint8x16_t result;
   uint4 v = x;
   for (int i = 0; i < 4; i++) {
@@ -320,54 +288,53 @@ uint8x16_t uint4x32_to_fp8_uniform_vec(uint4 x) {
     result.v[i*4 + 3] = uint8_t((val >> 24) & 0xFF);
   }
   return result;
-}
+}|}, ["uint8x16_t"]);
 
-/* Conversion functions from various precisions to uint4x32 */
-uint4 single_to_uint4x32(float x) {
+  ("single_to_uint4x32", {|uint4 single_to_uint4x32(float x) {
   uint32_t bits = as_type<uint32_t>(x);
   return uint4(bits, 0, 0, 0);
-}
+}|}, []);
 
-uint4 double_to_uint4x32(float x) {
+  ("double_to_uint4x32", {|uint4 double_to_uint4x32(float x) {
   /* Metal doesn't have native double support, use float fallback */
   uint32_t bits = as_type<uint32_t>(x);
   return uint4(bits, 0, 0, 0);
-}
+}|}, []);
 
-uint4 int32_to_uint4x32(int32_t x) {
+  ("int32_to_uint4x32", {|uint4 int32_to_uint4x32(int32_t x) {
   return uint4(uint32_t(x), 0, 0, 0);
-}
+}|}, []);
 
-uint4 int64_to_uint4x32(int64_t x) {
+  ("int64_to_uint4x32", {|uint4 int64_to_uint4x32(int64_t x) {
   uint64_t bits = uint64_t(x);
   return uint4(uint32_t(bits & 0xFFFFFFFF), uint32_t(bits >> 32), 0, 0);
-}
+}|}, []);
 
-uint4 uint32_to_uint4x32(uint32_t x) {
+  ("uint32_to_uint4x32", {|uint4 uint32_to_uint4x32(uint32_t x) {
   return uint4(x, 0, 0, 0);
-}
+}|}, []);
 
-uint4 uint64_to_uint4x32(uint64_t x) {
+  ("uint64_to_uint4x32", {|uint4 uint64_to_uint4x32(uint64_t x) {
   return uint4(uint32_t(x & 0xFFFFFFFF), uint32_t(x >> 32), 0, 0);
-}
+}|}, []);
 
-uint4 byte_to_uint4x32(int8_t x) {
+  ("byte_to_uint4x32", {|uint4 byte_to_uint4x32(int8_t x) {
   return uint4(uint32_t(x), 0, 0, 0);
-}
+}|}, []);
 
-uint4 uint16_to_uint4x32(uint16_t x) {
+  ("uint16_to_uint4x32", {|uint4 uint16_to_uint4x32(uint16_t x) {
   return uint4(uint32_t(x), 0, 0, 0);
-}
+}|}, []);
 
-uint4 bfloat16_to_uint4x32(uint16_t x) {
+  ("bfloat16_to_uint4x32", {|uint4 bfloat16_to_uint4x32(uint16_t x) {
   return uint4(uint32_t(x), 0, 0, 0);
-}
+}|}, []);
 
-uint4 half_to_uint4x32(uint16_t x) {
+  ("half_to_uint4x32", {|uint4 half_to_uint4x32(uint16_t x) {
   return uint4(uint32_t(x), 0, 0, 0);
-}
+}|}, []);
 
-uint4 fp8_to_uint4x32(uint8_t x) {
+  ("fp8_to_uint4x32", {|uint4 fp8_to_uint4x32(uint8_t x) {
   return uint4(uint32_t(x), 0, 0, 0);
-}
-|}
+}|}, []);
+]

0 commit comments
