11let source =
22 {|
3+ #include < stdio.h>
34#include < math.h>
45#include < stdint.h>
56#include < string .h>
67#include < stdlib.h>
78
8- /* Windows DLL export/ import macros for proper symbol visibility */
9- #ifdef _WIN32
10- #ifdef BUILDING_DLL
11- #define ARRAYJIT_API __declspec(dllexport)
12- #else
13- #define ARRAYJIT_API __declspec(dllimport)
14- #endif
15- #else
16- #define ARRAYJIT_API
17- #endif
18-
19- /* For static linking within OCaml , we want symbols to be visible */
20- #ifdef __CYGWIN__
21- #undef ARRAYJIT_API
22- #define ARRAYJIT_API __attribute__((visibility(" default" )))
23- #endif
9+ /* No longer need export macros since we're using textual prepending */
2410
2511/* Check for _Float16 support and define macros for zero -overhead abstraction */
2612#ifdef __FLT16_MAX__
@@ -207,7 +193,7 @@ void threefry_round(uint32_t x[4], unsigned int r0, unsigned int r1, unsigned in
207193}
208194
209195/* Threefry4x32 implementation - 20 rounds */
210- ARRAYJIT_API uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
196+ uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
211197 uint32_t x[4 ];
212198 uint32_t ks[5 ];
213199
@@ -308,59 +294,59 @@ typedef struct { HALF_T v[8]; } half8_t;
308294// These return vectors to efficiently use all random bits
309295
310296/* Convert to float in [0 , 1 ) */
311- ARRAYJIT_API float uint32_to_single_uniform(uint32_t x) {
297+ float uint32_to_single_uniform(uint32_t x) {
312298 /* Use upper 24 bits for float mantissa (23 bits + implicit 1 ) */
313299 return (x >> 8 ) * (1.0 f / 16777216.0 f);
314300}
315301
316302/* Convert to double in [0 , 1 ) */
317- ARRAYJIT_API double uint32_to_double_uniform(uint32_t x) {
303+ double uint32_to_double_uniform(uint32_t x) {
318304 return x * (1.0 / 4294967296.0 );
319305}
320306
321307/* Uint4x32 to float32 uniform - uses first 32 bits */
322- ARRAYJIT_API float uint4x32_to_single_uniform(uint4x32_t x) {
308+ float uint4x32_to_single_uniform(uint4x32_t x) {
323309 return uint32_to_single_uniform(x.v[0 ]);
324310}
325311
326312/* Uint4x32 to float64 uniform - uses first 64 bits */
327- ARRAYJIT_API double uint4x32_to_double_uniform(uint4x32_t x) {
313+ double uint4x32_to_double_uniform(uint4x32_t x) {
328314 uint64_t combined = ((uint64_t)x.v[1 ] << 32 ) | x.v[0 ];
329315 return combined * (1.0 / 18446744073709551616.0 );
330316}
331317
332318/* Uint4x32 to int32 uniform - full range */
333- ARRAYJIT_API int32_t uint4x32_to_int32_uniform(uint4x32_t x) {
319+ int32_t uint4x32_to_int32_uniform(uint4x32_t x) {
334320 return (int32_t)x.v[0 ];
335321}
336322
337323/* Uint4x32 to int64 uniform - full range */
338- ARRAYJIT_API int64_t uint4x32_to_int64_uniform(uint4x32_t x) {
324+ int64_t uint4x32_to_int64_uniform(uint4x32_t x) {
339325 return (int64_t)(((uint64_t)x.v[1 ] << 32 ) | x.v[0 ]);
340326}
341327
342328/* Uint4x32 to uint32 uniform - full range */
343- ARRAYJIT_API uint32_t uint4x32_to_uint32_uniform(uint4x32_t x) {
329+ uint32_t uint4x32_to_uint32_uniform(uint4x32_t x) {
344330 return x.v[0 ];
345331}
346332
347333/* Uint4x32 to uint64 uniform - full range */
348- ARRAYJIT_API uint64_t uint4x32_to_uint64_uniform(uint4x32_t x) {
334+ uint64_t uint4x32_to_uint64_uniform(uint4x32_t x) {
349335 return ((uint64_t)x.v[1 ] << 32 ) | x.v[0 ];
350336}
351337
352338/* Uint4x32 to int8 uniform - full range */
353- ARRAYJIT_API int8_t uint4x32_to_byte_uniform(uint4x32_t x) {
339+ int8_t uint4x32_to_byte_uniform(uint4x32_t x) {
354340 return (int8_t)(x.v[0 ] & 0xFF );
355341}
356342
357343/* Uint4x32 to uint16 uniform - full range */
358- ARRAYJIT_API uint16_t uint4x32_to_uint16_uniform(uint4x32_t x) {
344+ uint16_t uint4x32_to_uint16_uniform(uint4x32_t x) {
359345 return (uint16_t)(x.v[0 ] & 0xFFFF );
360346}
361347
362348/* Uint4x32 to bfloat16 uniform - uses first 16 bits */
363- ARRAYJIT_API uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
349+ uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
364350 /* Convert to float first, then to bfloat16 */
365351 float f = uint32_to_single_uniform(x.v[0 ]);
366352 uint32_t bits;
@@ -372,21 +358,21 @@ ARRAYJIT_API uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
372358}
373359
374360/* Uint4x32 to float16 uniform - uses first 16 bits */
375- ARRAYJIT_API uint16_t uint4x32_to_half_uniform(uint4x32_t x) {
361+ uint16_t uint4x32_to_half_uniform(uint4x32_t x) {
376362 /* Convert through float for consistent behavior */
377363 float f = (x.v[0 ] & 0xFFFF ) * (1.0 f / 65536.0 f);
378364 return FLOAT_TO_HALF (f);
379365}
380366
381367/* Uint4x32 to fp8 uniform - uses first 8 bits */
382- ARRAYJIT_API uint8_t uint4x32_to_fp8_uniform(uint4x32_t x) {
368+ uint8_t uint4x32_to_fp8_uniform(uint4x32_t x) {
383369 return (uint8_t)(x.v[0 ] & 0xFF );
384370}
385371
386372/* Vectorized conversion functions that use all 128 bits efficiently */
387373
388374/* Convert uint4x32 to 4 floats in [0 , 1 ) */
389- ARRAYJIT_API float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
375+ float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
390376 float4_t result;
391377 for (int i = 0 ; i < 4 ; i++ ) {
392378 result.v[i] = uint32_to_single_uniform(x.v[i]);
@@ -395,7 +381,7 @@ ARRAYJIT_API float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
395381}
396382
397383/* Convert uint4x32 to 2 doubles in [0 , 1 ) */
398- ARRAYJIT_API double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
384+ double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
399385 double2_t result;
400386 uint64_t combined1 = ((uint64_t)x.v[1 ] << 32 ) | x.v[0 ];
401387 uint64_t combined2 = ((uint64_t)x.v[3 ] << 32 ) | x.v[2 ];
@@ -405,7 +391,7 @@ ARRAYJIT_API double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
405391}
406392
407393/* Convert uint4x32 to 4 int32s - full range */
408- ARRAYJIT_API int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
394+ int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
409395 int32x4_t result;
410396 for (int i = 0 ; i < 4 ; i++ ) {
411397 result.v[i] = (int32_t)x.v[i];
@@ -414,7 +400,7 @@ ARRAYJIT_API int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
414400}
415401
416402/* Convert uint4x32 to 2 int64s - full range */
417- ARRAYJIT_API int64x2_t uint4x32_to_int64_uniform_vec(uint4x32_t x) {
403+ int64x2_t uint4x32_to_int64_uniform_vec(uint4x32_t x) {
418404 int64x2_t result;
419405 result.v[0 ] = (int64_t)(((uint64_t)x.v[1 ] << 32 ) | x.v[0 ]);
420406 result.v[1 ] = (int64_t)(((uint64_t)x.v[3 ] << 32 ) | x.v[2 ]);
@@ -423,7 +409,7 @@ ARRAYJIT_API int64x2_t uint4x32_to_int64_uniform_vec(uint4x32_t x) {
423409
424410
425411/* Convert uint4x32 to 16 int8s - full range */
426- ARRAYJIT_API int8x16_t uint4x32_to_byte_uniform_vec(uint4x32_t x) {
412+ int8x16_t uint4x32_to_byte_uniform_vec(uint4x32_t x) {
427413 int8x16_t result;
428414 for (int i = 0 ; i < 4 ; i++ ) {
429415 result.v[i* 4 + 0 ] = (int8_t)(x.v[i] & 0xFF );
@@ -435,7 +421,7 @@ ARRAYJIT_API int8x16_t uint4x32_to_byte_uniform_vec(uint4x32_t x) {
435421}
436422
437423/* Convert uint4x32 to 8 uint16s - full range */
438- ARRAYJIT_API uint16x8_t uint4x32_to_uint16_uniform_vec(uint4x32_t x) {
424+ uint16x8_t uint4x32_to_uint16_uniform_vec(uint4x32_t x) {
439425 uint16x8_t result;
440426 for (int i = 0 ; i < 4 ; i++ ) {
441427 result.v[i* 2 + 0 ] = (uint16_t)(x.v[i] & 0xFFFF );
@@ -445,7 +431,7 @@ ARRAYJIT_API uint16x8_t uint4x32_to_uint16_uniform_vec(uint4x32_t x) {
445431}
446432
447433/* Convert uint4x32 to 8 bfloat16s uniform */
448- ARRAYJIT_API uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
434+ uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
449435 uint16x8_t result;
450436 for (int i = 0 ; i < 4 ; i++ ) {
451437 // Convert each uint32 to two bfloat16 values
@@ -467,7 +453,7 @@ ARRAYJIT_API uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
467453}
468454
469455/* Convert uint4x32 to 8 float16s uniform */
470- ARRAYJIT_API half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
456+ half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
471457 half8_t result;
472458 for (int i = 0 ; i < 4 ; i++ ) {
473459 // Extract two 16 - bit values and convert to float in [0 , 1 )
@@ -482,7 +468,7 @@ ARRAYJIT_API half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
482468}
483469
484470/* Convert uint4x32 to 16 fp8s uniform */
485- ARRAYJIT_API uint8x16_t uint4x32_to_fp8_uniform_vec(uint4x32_t x) {
471+ uint8x16_t uint4x32_to_fp8_uniform_vec(uint4x32_t x) {
486472 uint8x16_t result;
487473 for (int i = 0 ; i < 4 ; i++ ) {
488474 result.v[i* 4 + 0 ] = (uint8_t)(x.v[i] & 0xFF );
@@ -494,70 +480,70 @@ ARRAYJIT_API uint8x16_t uint4x32_to_fp8_uniform_vec(uint4x32_t x) {
494480}
495481
496482/* Conversion functions from various precisions to uint4x32_t */
497- ARRAYJIT_API uint4x32_t single_to_uint4x32(float x) {
483+ uint4x32_t single_to_uint4x32(float x) {
498484 uint32_t bits;
499485 memcpy(& bits, & x, sizeof(float ));
500486 uint4x32_t result = {{bits, 0 , 0 , 0 }};
501487 return result;
502488}
503489
504- ARRAYJIT_API uint4x32_t double_to_uint4x32(double x) {
490+ uint4x32_t double_to_uint4x32(double x) {
505491 uint64_t bits;
506492 memcpy(& bits, & x, sizeof(double));
507493 uint4x32_t result = {{(uint32_t)(bits & 0xFFFFFFFF ), (uint32_t)(bits >> 32 ), 0 , 0 }};
508494 return result;
509495}
510496
511- ARRAYJIT_API uint4x32_t int32_to_uint4x32(int32_t x) {
497+ uint4x32_t int32_to_uint4x32(int32_t x) {
512498 uint4x32_t result = {{(uint32_t)x, 0 , 0 , 0 }};
513499 return result;
514500}
515501
516- ARRAYJIT_API uint4x32_t int64_to_uint4x32(int64_t x) {
502+ uint4x32_t int64_to_uint4x32(int64_t x) {
517503 uint64_t bits = (uint64_t)x;
518504 uint4x32_t result = {{(uint32_t)(bits & 0xFFFFFFFF ), (uint32_t)(bits >> 32 ), 0 , 0 }};
519505 return result;
520506}
521507
522- ARRAYJIT_API uint4x32_t uint32_to_uint4x32(uint32_t x) {
508+ uint4x32_t uint32_to_uint4x32(uint32_t x) {
523509 uint4x32_t result = {{x, 0 , 0 , 0 }};
524510 return result;
525511}
526512
527- ARRAYJIT_API uint4x32_t uint64_to_uint4x32(uint64_t x) {
513+ uint4x32_t uint64_to_uint4x32(uint64_t x) {
528514 uint4x32_t result = {{(uint32_t)(x & 0xFFFFFFFF ), (uint32_t)(x >> 32 ), 0 , 0 }};
529515 return result;
530516}
531517
532- ARRAYJIT_API uint4x32_t byte_to_uint4x32(unsigned char x) {
518+ uint4x32_t byte_to_uint4x32(unsigned char x) {
533519 uint4x32_t result = {{(uint32_t)x, 0 , 0 , 0 }};
534520 return result;
535521}
536522
537- ARRAYJIT_API uint4x32_t uint16_to_uint4x32(uint16_t x) {
523+ uint4x32_t uint16_to_uint4x32(uint16_t x) {
538524 uint4x32_t result = {{(uint32_t)x, 0 , 0 , 0 }};
539525 return result;
540526}
541527
542- ARRAYJIT_API uint4x32_t bfloat16_to_uint4x32(uint16_t x) {
528+ uint4x32_t bfloat16_to_uint4x32(uint16_t x) {
543529 uint4x32_t result = {{(uint32_t)x, 0 , 0 , 0 }};
544530 return result;
545531}
546532
547- ARRAYJIT_API uint4x32_t half_to_uint4x32(uint16_t x) {
533+ uint4x32_t half_to_uint4x32(uint16_t x) {
548534 uint4x32_t result = {{(uint32_t)x, 0 , 0 , 0 }};
549535 return result;
550536}
551537
552- ARRAYJIT_API uint4x32_t fp8_to_uint4x32(uint8_t x) {
538+ uint4x32_t fp8_to_uint4x32(uint8_t x) {
553539 uint4x32_t result = {{(uint32_t)x, 0 , 0 , 0 }};
554540 return result;
555541}
556542
557543/* Pure C conversion functions for use in C backends */
558544
559545/* BFloat16 to Float conversion (C function) */
560- ARRAYJIT_API float bfloat16_to_single(uint16_t bf16)
546+ float bfloat16_to_single(uint16_t bf16)
561547{
562548 /* BFloat16 format : 1 sign bit , 8 exponent bits , 7 mantissa bits
563549 To convert to float32 , we shift left by 16 bits */
@@ -566,7 +552,7 @@ ARRAYJIT_API float bfloat16_to_single(uint16_t bf16)
566552}
567553
568554/* Float to BFloat16 conversion (C function) */
569- ARRAYJIT_API uint16_t single_to_bfloat16(float f)
555+ uint16_t single_to_bfloat16(float f)
570556{
571557 uint32_t f32 = * ((uint32_t * )& f);
572558
@@ -576,22 +562,22 @@ ARRAYJIT_API uint16_t single_to_bfloat16(float f)
576562}
577563
578564/* Half (Float16 ) to Float conversion (C function) */
579- ARRAYJIT_API float half_to_single(uint16_t h)
565+ float half_to_single(uint16_t h)
580566{
581567 HALF_T half_val = UINT16_TO_HALF (h);
582568 return HALF_TO_FLOAT (half_val);
583569}
584570
585571/* Float to Half (Float16 ) conversion (C function) */
586- ARRAYJIT_API uint16_t single_to_half(float f)
572+ uint16_t single_to_half(float f)
587573{
588574 HALF_T half_val = FLOAT_TO_HALF (f);
589575 return HALF_TO_UINT16 (half_val);
590576}
591577
592578/* FP8 E5M2 format to Float conversion (C function)
593579 Format : 1 sign bit, 5 exponent bits, 2 mantissa bits */
594- ARRAYJIT_API float fp8_to_single(uint8_t fp8)
580+ float fp8_to_single(uint8_t fp8)
595581{
596582 /* Handle zero */
597583 if (fp8 == 0 )
@@ -634,7 +620,7 @@ ARRAYJIT_API float fp8_to_single(uint8_t fp8)
634620}
635621
636622/* Float to FP8 E5M2 conversion (C function) */
637- ARRAYJIT_API uint8_t single_to_fp8(float f)
623+ uint8_t single_to_fp8(float f)
638624{
639625 /* Handle zero */
640626 if (f == 0.0 f)
0 commit comments