Skip to content

Commit b6dec32

Browse files
committed
Complete the refactoring of builtins; fix timeouts (wait longer)
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent da369c6 commit b6dec32

File tree

3 files changed

+47
-58
lines changed

3 files changed

+47
-58
lines changed

arrayjit/lib/builtins_cc.ml

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,12 @@
11
let source =
22
{|
3+
#include <stdio.h>
34
#include <math.h>
45
#include <stdint.h>
56
#include <string.h>
67
#include <stdlib.h>
78

8-
/* Windows DLL export/import macros for proper symbol visibility */
9-
#ifdef _WIN32
10-
#ifdef BUILDING_DLL
11-
#define ARRAYJIT_API __declspec(dllexport)
12-
#else
13-
#define ARRAYJIT_API __declspec(dllimport)
14-
#endif
15-
#else
16-
#define ARRAYJIT_API
17-
#endif
18-
19-
/* For static linking within OCaml, we want symbols to be visible */
20-
#ifdef __CYGWIN__
21-
#undef ARRAYJIT_API
22-
#define ARRAYJIT_API __attribute__((visibility("default")))
23-
#endif
9+
/* Export macros are no longer needed since we're using textual prepending */
2410

2511
/* Check for _Float16 support and define macros for zero-overhead abstraction */
2612
#ifdef __FLT16_MAX__
@@ -207,7 +193,7 @@ void threefry_round(uint32_t x[4], unsigned int r0, unsigned int r1, unsigned in
207193
}
208194

209195
/* Threefry4x32 implementation - 20 rounds */
210-
ARRAYJIT_API uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
196+
uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
211197
uint32_t x[4];
212198
uint32_t ks[5];
213199

@@ -308,59 +294,59 @@ typedef struct { HALF_T v[8]; } half8_t;
308294
// These return vectors to efficiently use all random bits
309295

310296
/* Convert to float in [0, 1) */
311-
ARRAYJIT_API float uint32_to_single_uniform(uint32_t x) {
297+
float uint32_to_single_uniform(uint32_t x) {
312298
/* Use upper 24 bits for float mantissa (23 bits + implicit 1) */
313299
return (x >> 8) * (1.0f / 16777216.0f);
314300
}
315301

316302
/* Convert to double in [0, 1) */
317-
ARRAYJIT_API double uint32_to_double_uniform(uint32_t x) {
303+
double uint32_to_double_uniform(uint32_t x) {
318304
return x * (1.0 / 4294967296.0);
319305
}
320306

321307
/* Uint4x32 to float32 uniform - uses first 32 bits */
322-
ARRAYJIT_API float uint4x32_to_single_uniform(uint4x32_t x) {
308+
float uint4x32_to_single_uniform(uint4x32_t x) {
323309
return uint32_to_single_uniform(x.v[0]);
324310
}
325311

326312
/* Uint4x32 to float64 uniform - uses first 64 bits */
327-
ARRAYJIT_API double uint4x32_to_double_uniform(uint4x32_t x) {
313+
double uint4x32_to_double_uniform(uint4x32_t x) {
328314
uint64_t combined = ((uint64_t)x.v[1] << 32) | x.v[0];
329315
return combined * (1.0 / 18446744073709551616.0);
330316
}
331317

332318
/* Uint4x32 to int32 uniform - full range */
333-
ARRAYJIT_API int32_t uint4x32_to_int32_uniform(uint4x32_t x) {
319+
int32_t uint4x32_to_int32_uniform(uint4x32_t x) {
334320
return (int32_t)x.v[0];
335321
}
336322

337323
/* Uint4x32 to int64 uniform - full range */
338-
ARRAYJIT_API int64_t uint4x32_to_int64_uniform(uint4x32_t x) {
324+
int64_t uint4x32_to_int64_uniform(uint4x32_t x) {
339325
return (int64_t)(((uint64_t)x.v[1] << 32) | x.v[0]);
340326
}
341327

342328
/* Uint4x32 to uint32 uniform - full range */
343-
ARRAYJIT_API uint32_t uint4x32_to_uint32_uniform(uint4x32_t x) {
329+
uint32_t uint4x32_to_uint32_uniform(uint4x32_t x) {
344330
return x.v[0];
345331
}
346332

347333
/* Uint4x32 to uint64 uniform - full range */
348-
ARRAYJIT_API uint64_t uint4x32_to_uint64_uniform(uint4x32_t x) {
334+
uint64_t uint4x32_to_uint64_uniform(uint4x32_t x) {
349335
return ((uint64_t)x.v[1] << 32) | x.v[0];
350336
}
351337

352338
/* Uint4x32 to int8 uniform - full range */
353-
ARRAYJIT_API int8_t uint4x32_to_byte_uniform(uint4x32_t x) {
339+
int8_t uint4x32_to_byte_uniform(uint4x32_t x) {
354340
return (int8_t)(x.v[0] & 0xFF);
355341
}
356342

357343
/* Uint4x32 to uint16 uniform - full range */
358-
ARRAYJIT_API uint16_t uint4x32_to_uint16_uniform(uint4x32_t x) {
344+
uint16_t uint4x32_to_uint16_uniform(uint4x32_t x) {
359345
return (uint16_t)(x.v[0] & 0xFFFF);
360346
}
361347

362348
/* Uint4x32 to bfloat16 uniform - derived from the first 32-bit word */
363-
ARRAYJIT_API uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
349+
uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
364350
/* Convert to float first, then to bfloat16 */
365351
float f = uint32_to_single_uniform(x.v[0]);
366352
uint32_t bits;
@@ -372,21 +358,21 @@ ARRAYJIT_API uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
372358
}
373359

374360
/* Uint4x32 to float16 uniform - uses the low 16 bits of the first word */
375-
ARRAYJIT_API uint16_t uint4x32_to_half_uniform(uint4x32_t x) {
361+
uint16_t uint4x32_to_half_uniform(uint4x32_t x) {
376362
/* Convert through float for consistent behavior */
377363
float f = (x.v[0] & 0xFFFF) * (1.0f / 65536.0f);
378364
return FLOAT_TO_HALF(f);
379365
}
380366

381367
/* Uint4x32 to fp8 uniform - uses first 8 bits */
382-
ARRAYJIT_API uint8_t uint4x32_to_fp8_uniform(uint4x32_t x) {
368+
uint8_t uint4x32_to_fp8_uniform(uint4x32_t x) {
383369
return (uint8_t)(x.v[0] & 0xFF);
384370
}
385371

386372
/* Vectorized conversion functions that use all 128 bits efficiently */
387373

388374
/* Convert uint4x32 to 4 floats in [0, 1) */
389-
ARRAYJIT_API float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
375+
float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
390376
float4_t result;
391377
for (int i = 0; i < 4; i++) {
392378
result.v[i] = uint32_to_single_uniform(x.v[i]);
@@ -395,7 +381,7 @@ ARRAYJIT_API float4_t uint4x32_to_single_uniform_vec(uint4x32_t x) {
395381
}
396382

397383
/* Convert uint4x32 to 2 doubles in [0, 1) */
398-
ARRAYJIT_API double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
384+
double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
399385
double2_t result;
400386
uint64_t combined1 = ((uint64_t)x.v[1] << 32) | x.v[0];
401387
uint64_t combined2 = ((uint64_t)x.v[3] << 32) | x.v[2];
@@ -405,7 +391,7 @@ ARRAYJIT_API double2_t uint4x32_to_double_uniform_vec(uint4x32_t x) {
405391
}
406392

407393
/* Convert uint4x32 to 4 int32s - full range */
408-
ARRAYJIT_API int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
394+
int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
409395
int32x4_t result;
410396
for (int i = 0; i < 4; i++) {
411397
result.v[i] = (int32_t)x.v[i];
@@ -414,7 +400,7 @@ ARRAYJIT_API int32x4_t uint4x32_to_int32_uniform_vec(uint4x32_t x) {
414400
}
415401

416402
/* Convert uint4x32 to 2 int64s - full range */
417-
ARRAYJIT_API int64x2_t uint4x32_to_int64_uniform_vec(uint4x32_t x) {
403+
int64x2_t uint4x32_to_int64_uniform_vec(uint4x32_t x) {
418404
int64x2_t result;
419405
result.v[0] = (int64_t)(((uint64_t)x.v[1] << 32) | x.v[0]);
420406
result.v[1] = (int64_t)(((uint64_t)x.v[3] << 32) | x.v[2]);
@@ -423,7 +409,7 @@ ARRAYJIT_API int64x2_t uint4x32_to_int64_uniform_vec(uint4x32_t x) {
423409

424410

425411
/* Convert uint4x32 to 16 int8s - full range */
426-
ARRAYJIT_API int8x16_t uint4x32_to_byte_uniform_vec(uint4x32_t x) {
412+
int8x16_t uint4x32_to_byte_uniform_vec(uint4x32_t x) {
427413
int8x16_t result;
428414
for (int i = 0; i < 4; i++) {
429415
result.v[i*4 + 0] = (int8_t)(x.v[i] & 0xFF);
@@ -435,7 +421,7 @@ ARRAYJIT_API int8x16_t uint4x32_to_byte_uniform_vec(uint4x32_t x) {
435421
}
436422

437423
/* Convert uint4x32 to 8 uint16s - full range */
438-
ARRAYJIT_API uint16x8_t uint4x32_to_uint16_uniform_vec(uint4x32_t x) {
424+
uint16x8_t uint4x32_to_uint16_uniform_vec(uint4x32_t x) {
439425
uint16x8_t result;
440426
for (int i = 0; i < 4; i++) {
441427
result.v[i*2 + 0] = (uint16_t)(x.v[i] & 0xFFFF);
@@ -445,7 +431,7 @@ ARRAYJIT_API uint16x8_t uint4x32_to_uint16_uniform_vec(uint4x32_t x) {
445431
}
446432

447433
/* Convert uint4x32 to 8 bfloat16s uniform */
448-
ARRAYJIT_API uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
434+
uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
449435
uint16x8_t result;
450436
for (int i = 0; i < 4; i++) {
451437
// Convert each uint32 to two bfloat16 values
@@ -467,7 +453,7 @@ ARRAYJIT_API uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
467453
}
468454

469455
/* Convert uint4x32 to 8 float16s uniform */
470-
ARRAYJIT_API half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
456+
half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
471457
half8_t result;
472458
for (int i = 0; i < 4; i++) {
473459
// Extract two 16-bit values and convert to float in [0, 1)
@@ -482,7 +468,7 @@ ARRAYJIT_API half8_t uint4x32_to_half_uniform_vec(uint4x32_t x) {
482468
}
483469

484470
/* Convert uint4x32 to 16 fp8s uniform */
485-
ARRAYJIT_API uint8x16_t uint4x32_to_fp8_uniform_vec(uint4x32_t x) {
471+
uint8x16_t uint4x32_to_fp8_uniform_vec(uint4x32_t x) {
486472
uint8x16_t result;
487473
for (int i = 0; i < 4; i++) {
488474
result.v[i*4 + 0] = (uint8_t)(x.v[i] & 0xFF);
@@ -494,70 +480,70 @@ ARRAYJIT_API uint8x16_t uint4x32_to_fp8_uniform_vec(uint4x32_t x) {
494480
}
495481

496482
/* Conversion functions from various precisions to uint4x32_t */
497-
ARRAYJIT_API uint4x32_t single_to_uint4x32(float x) {
483+
uint4x32_t single_to_uint4x32(float x) {
498484
uint32_t bits;
499485
memcpy(&bits, &x, sizeof(float));
500486
uint4x32_t result = {{bits, 0, 0, 0}};
501487
return result;
502488
}
503489

504-
ARRAYJIT_API uint4x32_t double_to_uint4x32(double x) {
490+
uint4x32_t double_to_uint4x32(double x) {
505491
uint64_t bits;
506492
memcpy(&bits, &x, sizeof(double));
507493
uint4x32_t result = {{(uint32_t)(bits & 0xFFFFFFFF), (uint32_t)(bits >> 32), 0, 0}};
508494
return result;
509495
}
510496

511-
ARRAYJIT_API uint4x32_t int32_to_uint4x32(int32_t x) {
497+
uint4x32_t int32_to_uint4x32(int32_t x) {
512498
uint4x32_t result = {{(uint32_t)x, 0, 0, 0}};
513499
return result;
514500
}
515501

516-
ARRAYJIT_API uint4x32_t int64_to_uint4x32(int64_t x) {
502+
uint4x32_t int64_to_uint4x32(int64_t x) {
517503
uint64_t bits = (uint64_t)x;
518504
uint4x32_t result = {{(uint32_t)(bits & 0xFFFFFFFF), (uint32_t)(bits >> 32), 0, 0}};
519505
return result;
520506
}
521507

522-
ARRAYJIT_API uint4x32_t uint32_to_uint4x32(uint32_t x) {
508+
uint4x32_t uint32_to_uint4x32(uint32_t x) {
523509
uint4x32_t result = {{x, 0, 0, 0}};
524510
return result;
525511
}
526512

527-
ARRAYJIT_API uint4x32_t uint64_to_uint4x32(uint64_t x) {
513+
uint4x32_t uint64_to_uint4x32(uint64_t x) {
528514
uint4x32_t result = {{(uint32_t)(x & 0xFFFFFFFF), (uint32_t)(x >> 32), 0, 0}};
529515
return result;
530516
}
531517

532-
ARRAYJIT_API uint4x32_t byte_to_uint4x32(unsigned char x) {
518+
uint4x32_t byte_to_uint4x32(unsigned char x) {
533519
uint4x32_t result = {{(uint32_t)x, 0, 0, 0}};
534520
return result;
535521
}
536522

537-
ARRAYJIT_API uint4x32_t uint16_to_uint4x32(uint16_t x) {
523+
uint4x32_t uint16_to_uint4x32(uint16_t x) {
538524
uint4x32_t result = {{(uint32_t)x, 0, 0, 0}};
539525
return result;
540526
}
541527

542-
ARRAYJIT_API uint4x32_t bfloat16_to_uint4x32(uint16_t x) {
528+
uint4x32_t bfloat16_to_uint4x32(uint16_t x) {
543529
uint4x32_t result = {{(uint32_t)x, 0, 0, 0}};
544530
return result;
545531
}
546532

547-
ARRAYJIT_API uint4x32_t half_to_uint4x32(uint16_t x) {
533+
uint4x32_t half_to_uint4x32(uint16_t x) {
548534
uint4x32_t result = {{(uint32_t)x, 0, 0, 0}};
549535
return result;
550536
}
551537

552-
ARRAYJIT_API uint4x32_t fp8_to_uint4x32(uint8_t x) {
538+
uint4x32_t fp8_to_uint4x32(uint8_t x) {
553539
uint4x32_t result = {{(uint32_t)x, 0, 0, 0}};
554540
return result;
555541
}
556542

557543
/* Pure C conversion functions for use in C backends */
558544

559545
/* BFloat16 to Float conversion (C function) */
560-
ARRAYJIT_API float bfloat16_to_single(uint16_t bf16)
546+
float bfloat16_to_single(uint16_t bf16)
561547
{
562548
/* BFloat16 format: 1 sign bit, 8 exponent bits, 7 mantissa bits
563549
To convert to float32, we shift left by 16 bits */
@@ -566,7 +552,7 @@ ARRAYJIT_API float bfloat16_to_single(uint16_t bf16)
566552
}
567553

568554
/* Float to BFloat16 conversion (C function) */
569-
ARRAYJIT_API uint16_t single_to_bfloat16(float f)
555+
uint16_t single_to_bfloat16(float f)
570556
{
571557
uint32_t f32 = *((uint32_t *)&f);
572558

@@ -576,22 +562,22 @@ ARRAYJIT_API uint16_t single_to_bfloat16(float f)
576562
}
577563

578564
/* Half (Float16) to Float conversion (C function) */
579-
ARRAYJIT_API float half_to_single(uint16_t h)
565+
float half_to_single(uint16_t h)
580566
{
581567
HALF_T half_val = UINT16_TO_HALF(h);
582568
return HALF_TO_FLOAT(half_val);
583569
}
584570

585571
/* Float to Half (Float16) conversion (C function) */
586-
ARRAYJIT_API uint16_t single_to_half(float f)
572+
uint16_t single_to_half(float f)
587573
{
588574
HALF_T half_val = FLOAT_TO_HALF(f);
589575
return HALF_TO_UINT16(half_val);
590576
}
591577

592578
/* FP8 E5M2 format to Float conversion (C function)
593579
Format: 1 sign bit, 5 exponent bits, 2 mantissa bits */
594-
ARRAYJIT_API float fp8_to_single(uint8_t fp8)
580+
float fp8_to_single(uint8_t fp8)
595581
{
596582
/* Handle zero */
597583
if (fp8 == 0)
@@ -634,7 +620,7 @@ ARRAYJIT_API float fp8_to_single(uint8_t fp8)
634620
}
635621

636622
/* Float to FP8 E5M2 conversion (C function) */
637-
ARRAYJIT_API uint8_t single_to_fp8(float f)
623+
uint8_t single_to_fp8(float f)
638624
{
639625
/* Handle zero */
640626
if (f == 0.0f)

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ let%track7_sexp c_compile_and_load ~f_path =
111111
(* Note: it seems waiting for the file to exist is necessary here and below regardless of needing
112112
the logs. *)
113113
let start_time = Unix.gettimeofday () in
114-
let timeout = 1.0 in
114+
let timeout = Float.of_string @@ Utils.get_global_arg ~default:"5.0" ~arg_name:"cc_backend_post_compile_timeout" in
115115
while rc = 0 && (not @@ (Stdlib.Sys.file_exists libname && Stdlib.Sys.file_exists log_fname)) do
116116
let elapsed = Unix.gettimeofday () -. start_time in
117117
if Float.(elapsed > timeout) then

ocannl_config.example

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,7 @@ default_prec=single
202202

203203
# Limit on the allowed size of unrolled constant tensor nodes
204204
# (where the initialization code sets the values).
205-
limit_constant_fill_size=16
205+
limit_constant_fill_size=16
206+
207+
# Timeout, in seconds, for how long the CC backend waits after compilation
# for the compiled library and log files to appear.
208+
cc_backend_post_compile_timeout=5.0

0 commit comments

Comments
 (0)