1- let source =
2- {|
3- #include < metal_stdlib>
4- using namespace metal;
5-
6- /* Threefry4x32 constants */
7- constant uint32_t THREEFRY_C240 = 0x1BD11BDA ;
8-
9- /* Rotation constants for Threefry4x32 */
10- constant uint THREEFRY_ROTATION_0_0 = 13 ;
11- constant uint THREEFRY_ROTATION_0_1 = 15 ;
12- constant uint THREEFRY_ROTATION_0_2 = 26 ;
13- constant uint THREEFRY_ROTATION_0_3 = 6 ;
14- constant uint THREEFRY_ROTATION_1_0 = 17 ;
15- constant uint THREEFRY_ROTATION_1_1 = 29 ;
16- constant uint THREEFRY_ROTATION_1_2 = 16 ;
17- constant uint THREEFRY_ROTATION_1_3 = 24 ;
18-
19- /* Metal rotate left using built- in rotate function */
20- inline uint32_t rotl32(uint32_t x, uint n) {
1+ (* Metal builtin code split into (key, definition, dependencies) triples for filtering *)
2+ let builtins = [
3+ (" METAL_HEADERS" , {| #include < metal_stdlib>
4+ using namespace metal;| }, [] );
5+
6+ (" THREEFRY_C240" , {| constant uint32_t THREEFRY_C240 = 0x1BD11BDA ;| }, [] );
7+
8+ (" THREEFRY_ROTATION_0_0" , {| constant uint THREEFRY_ROTATION_0_0 = 13 ;| }, [] );
9+ (" THREEFRY_ROTATION_0_1" , {| constant uint THREEFRY_ROTATION_0_1 = 15 ;| }, [] );
10+ (" THREEFRY_ROTATION_0_2" , {| constant uint THREEFRY_ROTATION_0_2 = 26 ;| }, [] );
11+ (" THREEFRY_ROTATION_0_3" , {| constant uint THREEFRY_ROTATION_0_3 = 6 ;| }, [] );
12+ (" THREEFRY_ROTATION_1_0" , {| constant uint THREEFRY_ROTATION_1_0 = 17 ;| }, [] );
13+ (" THREEFRY_ROTATION_1_1" , {| constant uint THREEFRY_ROTATION_1_1 = 29 ;| }, [] );
14+ (" THREEFRY_ROTATION_1_2" , {| constant uint THREEFRY_ROTATION_1_2 = 16 ;| }, [] );
15+ (" THREEFRY_ROTATION_1_3" , {| constant uint THREEFRY_ROTATION_1_3 = 24 ;| }, [] );
16+
17+ (" rotl32" , {| inline uint32_t rotl32(uint32_t x, uint n) {
2118 return rotate(x, n);
22- }
19+ }| }, [] );
2320
24- /* Threefry4x32 round function using SIMD operations */
25- inline void threefry_round(thread uint4 & x, uint r0, uint r1, uint r2, uint r3) {
21+ (" threefry_round" , {| inline void threefry_round(thread uint4 & x, uint r0, uint r1, uint r2, uint r3) {
2622 x.x += x.y; x.y = rotl32(x.y, r0); x.y ^= x.x;
2723 x.z += x.w; x.w = rotl32(x.w, r1); x.w ^= x.z;
2824
@@ -36,10 +32,9 @@ inline void threefry_round(thread uint4 &x, uint r0, uint r1, uint r2, uint r3)
3632 tmp = x.y;
3733 x.y = x.w;
3834 x.w = tmp;
39- }
35+ }| }, [ " rotl32 " ]);
4036
41- /* Threefry4x32 implementation - 20 rounds */
42- uint4 arrayjit_threefry4x32(uint4 key, uint4 counter) {
37+ (" arrayjit_threefry4x32" , {| uint4 arrayjit_threefry4x32(uint4 key, uint4 counter) {
4338 uint4 x = counter;
4439 uint4 k = key;
4540
@@ -124,138 +119,115 @@ uint4 arrayjit_threefry4x32(uint4 key, uint4 counter) {
124119 x.w += 5 ;
125120
126121 return x;
127- }
128-
129- /* Vector types for efficient extraction of multiple values */
130- struct float4_t { float4 v; };
131- struct float2_t { float2 v; }; /* Using float2 since Metal lacks double */
132- struct int32x4_t { int4 v; };
133- struct int64x2_t { int64_t v[2 ]; };
134- struct uint64x2_t { uint64_t v[2 ]; };
135- struct int8x16_t { int8_t v[16 ]; };
136- struct uint16x8_t { uint16_t v[8 ]; };
137- struct uint8x16_t { uint8_t v[16 ]; };
138- struct half8_t { half v[8 ]; };
139-
140- /* Conversion functions from uint4x32 to various precisions uniformly */
141- // These return vectors to efficiently use all random bits
142-
143- /* Convert to float in [0 , 1 ) */
144- inline float uint32_to_single_uniform(uint32_t x) {
122+ }| }, [" THREEFRY_C240" ; " threefry_round" ; " THREEFRY_ROTATION_0_0" ; " THREEFRY_ROTATION_0_1" ;
123+ " THREEFRY_ROTATION_0_2" ; " THREEFRY_ROTATION_0_3" ; " THREEFRY_ROTATION_1_0" ;
124+ " THREEFRY_ROTATION_1_1" ; " THREEFRY_ROTATION_1_2" ; " THREEFRY_ROTATION_1_3" ]);
125+
126+ (" float4_t" , {| struct float4_t { float4 v; };| }, [] );
127+ (" float2_t" , {| struct float2_t { float2 v; };| }, [] );
128+ (" int32x4_t" , {| struct int32x4_t { int4 v; };| }, [] );
129+ (" int64x2_t" , {| struct int64x2_t { int64_t v[2 ]; };| }, [] );
130+ (" uint64x2_t" , {| struct uint64x2_t { uint64_t v[2 ]; };| }, [] );
131+ (" int8x16_t" , {| struct int8x16_t { int8_t v[16 ]; };| }, [] );
132+ (" uint16x8_t" , {| struct uint16x8_t { uint16_t v[8 ]; };| }, [] );
133+ (" uint8x16_t" , {| struct uint8x16_t { uint8_t v[16 ]; };| }, [] );
134+ (" half8_t" , {| struct half8_t { half v[8 ]; };| }, [] );
135+
136+ (" uint32_to_single_uniform" , {| inline float uint32_to_single_uniform(uint32_t x) {
145137 return (x >> 8 ) * (1.0 f / 16777216.0 f);
146- }
138+ }| }, [] );
147139
148- /* Uint4x32 to float32 uniform */
149- float uint4x32_to_single_uniform(uint4 x) {
140+ (" uint4x32_to_single_uniform" , {| float uint4x32_to_single_uniform(uint4 x) {
150141 return uint32_to_single_uniform(x.x);
151- }
142+ }| }, [ " uint32_to_single_uniform " ]);
152143
153- /* Uint4x32 to float64 uniform - Metal doesn't have native double support */
154- float uint4x32_to_double_uniform(uint4 x) {
144+ (" uint4x32_to_double_uniform" , {| float uint4x32_to_double_uniform(uint4 x) {
155145 /* Fallback to float precision */
156146 uint64_t combined = (uint64_t(x.y) << 32 ) | x.x;
157147 return float (combined) * (1.0 f / 18446744073709551616.0 f);
158- }
148+ }| }, [] );
159149
160- /* Uint4x32 to int32 uniform */
161- int32_t uint4x32_to_int32_uniform(uint4 x) {
150+ (" uint4x32_to_int32_uniform" , {| int32_t uint4x32_to_int32_uniform(uint4 x) {
162151 return int32_t(x.x);
163- }
152+ }| }, [] );
164153
165- /* Uint4x32 to int64 uniform */
166- int64_t uint4x32_to_int64_uniform(uint4 x) {
154+ (" uint4x32_to_int64_uniform" , {| int64_t uint4x32_to_int64_uniform(uint4 x) {
167155 return int64_t((uint64_t(x.y) << 32 ) | x.x);
168- }
156+ }| }, [] );
169157
170- /* Uint4x32 to uint32 uniform */
171- uint32_t uint4x32_to_uint32_uniform(uint4 x) {
158+ (" uint4x32_to_uint32_uniform" , {| uint32_t uint4x32_to_uint32_uniform(uint4 x) {
172159 return x.x;
173- }
160+ }| }, [] );
174161
175- /* Uint4x32 to uint64 uniform */
176- uint64_t uint4x32_to_uint64_uniform(uint4 x) {
162+ (" uint4x32_to_uint64_uniform" , {| uint64_t uint4x32_to_uint64_uniform(uint4 x) {
177163 return (uint64_t(x.y) << 32 ) | x.x;
178- }
164+ }| }, [] );
179165
180- /* Uint4x32 to byte uniform */
181- int8_t uint4x32_to_byte_uniform(uint4 x) {
166+ (" uint4x32_to_byte_uniform" , {| int8_t uint4x32_to_byte_uniform(uint4 x) {
182167 return int8_t(x.x & 0xFF );
183- }
168+ }| }, [] );
184169
185- /* Uint4x32 to uint16 uniform */
186- uint16_t uint4x32_to_uint16_uniform(uint4 x) {
170+ (" uint4x32_to_uint16_uniform" , {| uint16_t uint4x32_to_uint16_uniform(uint4 x) {
187171 return uint16_t(x.x & 0xFFFF );
188- }
172+ }| }, [] );
189173
190- /* Uint4x32 to bfloat16 uniform */
191- uint16_t uint4x32_to_bfloat16_uniform(uint4 x) {
174+ (" uint4x32_to_bfloat16_uniform" , {| uint16_t uint4x32_to_bfloat16_uniform(uint4 x) {
192175 float f = uint32_to_single_uniform(x.x);
193176 return uint16_t(as_type< uint32_t> (f) >> 16 );
194- }
177+ }| }, [ " uint32_to_single_uniform " ]);
195178
196- /* Uint4x32 to float16 uniform */
197- half uint4x32_to_half_uniform(uint4 x) {
179+ (" uint4x32_to_half_uniform" , {| half uint4x32_to_half_uniform(uint4 x) {
198180 float f = uint32_to_single_uniform(x.x);
199181 return half(f);
200- }
182+ }| }, [ " uint32_to_single_uniform " ]);
201183
202- /* Uint4x32 to fp8 uniform */
203- uint8_t uint4x32_to_fp8_uniform(uint4 x) {
184+ (" uint4x32_to_fp8_uniform" , {| uint8_t uint4x32_to_fp8_uniform(uint4 x) {
204185 return uint8_t(x.x & 0xFF );
205- }
186+ }| }, [] );
206187
207- /* Vectorized conversion functions that use all 128 bits efficiently */
208-
209- /* Convert uint4x32 to 4 floats in [0 , 1 ) */
210- float4_t uint4x32_to_single_uniform_vec(uint4 x) {
188+ (" uint4x32_to_single_uniform_vec" , {| float4_t uint4x32_to_single_uniform_vec(uint4 x) {
211189 float4_t result;
212190 result.v.x = uint32_to_single_uniform(x.x);
213191 result.v.y = uint32_to_single_uniform(x.y);
214192 result.v.z = uint32_to_single_uniform(x.z);
215193 result.v.w = uint32_to_single_uniform(x.w);
216194 return result;
217- }
195+ }| }, [ " float4_t " ; " uint32_to_single_uniform " ]);
218196
219- /* Convert uint4x32 to 2 floats in [0 , 1 ) - Metal lacks double precision */
220- float2_t uint4x32_to_double_uniform_vec(uint4 x) {
197+ (" uint4x32_to_double_uniform_vec" , {| float2_t uint4x32_to_double_uniform_vec(uint4 x) {
221198 float2_t result;
222199 uint64_t combined1 = (uint64_t(x.y) << 32 ) | x.x;
223200 uint64_t combined2 = (uint64_t(x.w) << 32 ) | x.z;
224201 result.v.x = float (combined1) * (1.0 f / 18446744073709551616.0 f);
225202 result.v.y = float (combined2) * (1.0 f / 18446744073709551616.0 f);
226203 return result;
227- }
204+ }| }, [ " float2_t " ]);
228205
229- /* Convert uint4x32 to 4 int32s - full range */
230- int32x4_t uint4x32_to_int32_uniform_vec(uint4 x) {
206+ (" uint4x32_to_int32_uniform_vec" , {| int32x4_t uint4x32_to_int32_uniform_vec(uint4 x) {
231207 int32x4_t result;
232208 result.v = int4(x);
233209 return result;
234- }
210+ }| }, [ " int32x4_t " ]);
235211
236- /* Convert uint4x32 to 2 int64s - full range */
237- int64x2_t uint4x32_to_int64_uniform_vec(uint4 x) {
212+ (" uint4x32_to_int64_uniform_vec" , {| int64x2_t uint4x32_to_int64_uniform_vec(uint4 x) {
238213 int64x2_t result;
239214 result.v[0 ] = (int64_t(x.y) << 32 ) | x.x;
240215 result.v[1 ] = (int64_t(x.w) << 32 ) | x.z;
241216 return result;
242- }
217+ }| }, [ " int64x2_t " ]);
243218
244- /* Convert uint4x32 to 4 uint32s - full range */
245- uint4 uint4x32_to_uint32_uniform_vec(uint4 x) {
219+ (" uint4x32_to_uint32_uniform_vec" , {| uint4 uint4x32_to_uint32_uniform_vec(uint4 x) {
246220 return x;
247- }
221+ }| }, [] );
248222
249- /* Convert uint4x32 to 2 uint64s - full range */
250- uint64x2_t uint4x32_to_uint64_uniform_vec(uint4 x) {
223+ (" uint4x32_to_uint64_uniform_vec" , {| uint64x2_t uint4x32_to_uint64_uniform_vec(uint4 x) {
251224 uint64x2_t result;
252225 result.v[0 ] = (uint64_t(x.y) << 32 ) | x.x;
253226 result.v[1 ] = (uint64_t(x.w) << 32 ) | x.z;
254227 return result;
255- }
228+ }| }, [ " uint64x2_t " ]);
256229
257- /* Convert uint4x32 to 16 int8s - full range */
258- int8x16_t uint4x32_to_byte_uniform_vec(uint4 x) {
230+ (" uint4x32_to_byte_uniform_vec" , {| int8x16_t uint4x32_to_byte_uniform_vec(uint4 x) {
259231 int8x16_t result;
260232 uint4 v = x;
261233 for (int i = 0 ; i < 4 ; i++ ) {
@@ -266,10 +238,9 @@ int8x16_t uint4x32_to_byte_uniform_vec(uint4 x) {
266238 result.v[i* 4 + 3 ] = int8_t((val >> 24 ) & 0xFF );
267239 }
268240 return result;
269- }
241+ }| }, [ " int8x16_t " ]);
270242
271- /* Convert uint4x32 to 8 uint16s - full range */
272- uint16x8_t uint4x32_to_uint16_uniform_vec(uint4 x) {
243+ (" uint4x32_to_uint16_uniform_vec" , {| uint16x8_t uint4x32_to_uint16_uniform_vec(uint4 x) {
273244 uint16x8_t result;
274245 uint4 v = x;
275246 for (int i = 0 ; i < 4 ; i++ ) {
@@ -278,10 +249,9 @@ uint16x8_t uint4x32_to_uint16_uniform_vec(uint4 x) {
278249 result.v[i* 2 + 1 ] = uint16_t((val >> 16 ) & 0xFFFF );
279250 }
280251 return result;
281- }
252+ }| }, [ " uint16x8_t " ]);
282253
283- /* Convert uint4x32 to 8 bfloat16s uniform */
284- uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4 x) {
254+ (" uint4x32_to_bfloat16_uniform_vec" , {| uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4 x) {
285255 uint16x8_t result;
286256 uint4 v = x;
287257 for (int i = 0 ; i < 4 ; i++ ) {
@@ -292,10 +262,9 @@ uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4 x) {
292262 result.v[i* 2 + 1 ] = uint16_t(as_type< uint32_t> (f2) >> 16 );
293263 }
294264 return result;
295- }
265+ }| }, [ " uint16x8_t " ]);
296266
297- /* Convert uint4x32 to 8 float16s uniform */
298- half8_t uint4x32_to_half_uniform_vec(uint4 x) {
267+ (" uint4x32_to_half_uniform_vec" , {| half8_t uint4x32_to_half_uniform_vec(uint4 x) {
299268 half8_t result;
300269 uint4 v = x;
301270 for (int i = 0 ; i < 4 ; i++ ) {
@@ -306,10 +275,9 @@ half8_t uint4x32_to_half_uniform_vec(uint4 x) {
306275 result.v[i* 2 + 1 ] = half(f2);
307276 }
308277 return result;
309- }
278+ }| }, [ " half8_t " ]);
310279
311- /* Convert uint4x32 to 16 fp8s uniform */
312- uint8x16_t uint4x32_to_fp8_uniform_vec(uint4 x) {
280+ (" uint4x32_to_fp8_uniform_vec" , {| uint8x16_t uint4x32_to_fp8_uniform_vec(uint4 x) {
313281 uint8x16_t result;
314282 uint4 v = x;
315283 for (int i = 0 ; i < 4 ; i++ ) {
@@ -320,54 +288,53 @@ uint8x16_t uint4x32_to_fp8_uniform_vec(uint4 x) {
320288 result.v[i* 4 + 3 ] = uint8_t((val >> 24 ) & 0xFF );
321289 }
322290 return result;
323- }
291+ }| }, [ " uint8x16_t " ]);
324292
325- /* Conversion functions from various precisions to uint4x32 */
326- uint4 single_to_uint4x32(float x) {
293+ (" single_to_uint4x32" , {| uint4 single_to_uint4x32(float x) {
327294 uint32_t bits = as_type< uint32_t> (x);
328295 return uint4(bits, 0 , 0 , 0 );
329- }
296+ }| }, [] );
330297
331- uint4 double_to_uint4x32(float x) {
298+ ( " double_to_uint4x32 " , { | uint4 double_to_uint4x32(float x) {
332299 /* Metal doesn't have native double support, use float fallback */
333300 uint32_t bits = as_type< uint32_t> (x);
334301 return uint4(bits, 0 , 0 , 0 );
335- }
302+ }| }, [] );
336303
337- uint4 int32_to_uint4x32(int32_t x) {
304+ ( " int32_to_uint4x32 " , { | uint4 int32_to_uint4x32(int32_t x) {
338305 return uint4(uint32_t(x), 0 , 0 , 0 );
339- }
306+ }| }, [] );
340307
341- uint4 int64_to_uint4x32(int64_t x) {
308+ ( " int64_to_uint4x32 " , { | uint4 int64_to_uint4x32(int64_t x) {
342309 uint64_t bits = uint64_t(x);
343310 return uint4(uint32_t(bits & 0xFFFFFFFF ), uint32_t(bits >> 32 ), 0 , 0 );
344- }
311+ }| }, [] );
345312
346- uint4 uint32_to_uint4x32(uint32_t x) {
313+ ( " uint32_to_uint4x32 " , { | uint4 uint32_to_uint4x32(uint32_t x) {
347314 return uint4(x, 0 , 0 , 0 );
348- }
315+ }| }, [] );
349316
350- uint4 uint64_to_uint4x32(uint64_t x) {
317+ ( " uint64_to_uint4x32 " , { | uint4 uint64_to_uint4x32(uint64_t x) {
351318 return uint4(uint32_t(x & 0xFFFFFFFF ), uint32_t(x >> 32 ), 0 , 0 );
352- }
319+ }| }, [] );
353320
354- uint4 byte_to_uint4x32(int8_t x) {
321+ ( " byte_to_uint4x32 " , { | uint4 byte_to_uint4x32(int8_t x) {
355322 return uint4(uint32_t(x), 0 , 0 , 0 );
356- }
323+ }| }, [] );
357324
358- uint4 uint16_to_uint4x32(uint16_t x) {
325+ ( " uint16_to_uint4x32 " , { | uint4 uint16_to_uint4x32(uint16_t x) {
359326 return uint4(uint32_t(x), 0 , 0 , 0 );
360- }
327+ }| }, [] );
361328
362- uint4 bfloat16_to_uint4x32(uint16_t x) {
329+ ( " bfloat16_to_uint4x32 " , { | uint4 bfloat16_to_uint4x32(uint16_t x) {
363330 return uint4(uint32_t(x), 0 , 0 , 0 );
364- }
331+ }| }, [] );
365332
366- uint4 half_to_uint4x32(uint16_t x) {
333+ ( " half_to_uint4x32 " , { | uint4 half_to_uint4x32(uint16_t x) {
367334 return uint4(uint32_t(x), 0 , 0 , 0 );
368- }
335+ }| }, [] );
369336
370- uint4 fp8_to_uint4x32(uint8_t x) {
337+ ( " fp8_to_uint4x32 " , { | uint4 fp8_to_uint4x32(uint8_t x) {
371338 return uint4(uint32_t(x), 0 , 0 , 0 );
372- }
373- | }
339+ }| }, [] );
340+ ]
0 commit comments