Skip to content

Commit 97f756c

Browse files
committed
Claude's second pass at adding BF16, FP8: conversion functions
Now struggling with a build bug causing arrayjit/test to read its parent ocannl_config.
1 parent 032facc commit 97f756c

File tree

9 files changed

+325
-89
lines changed

9 files changed

+325
-89
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ _opam/
3232

3333
# Local configuration
3434
ocannl_config
35+
!arrayjit/test/ocannl_config
3536
!test/ocannl_config
3637
!test/config/ocannl_config
3738
!test_ppx/ocannl_config

arrayjit/lib/arrayjit_stubs.c

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#include <caml/alloc.h>
2+
#include <caml/memory.h>
3+
#include <caml/mlvalues.h>
4+
#include <math.h>
5+
#include <stdint.h>
6+
7+
/* BFloat16 to Float conversion */
8+
CAMLprim value arrayjit_bfloat16_to_float(value v_bf16)
9+
{
10+
CAMLparam1(v_bf16);
11+
uint16_t bf16 = (uint16_t)Int_val(v_bf16);
12+
13+
/* BFloat16 format: 1 sign bit, 8 exponent bits, 7 mantissa bits
14+
To convert to float32, we shift left by 16 bits */
15+
uint32_t f32 = ((uint32_t)bf16) << 16;
16+
float result = *((float*)&f32);
17+
18+
CAMLreturn(caml_copy_double((double)result));
19+
}
20+
21+
/* Float to BFloat16 conversion */
22+
CAMLprim value arrayjit_float_to_bfloat16(value v_float)
23+
{
24+
CAMLparam1(v_float);
25+
float f = (float)Double_val(v_float);
26+
uint32_t f32 = *((uint32_t*)&f);
27+
28+
/* Round to nearest even */
29+
uint32_t rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);
30+
uint16_t bf16 = (uint16_t)(rounded >> 16);
31+
32+
CAMLreturn(Val_int(bf16));
33+
}
34+
35+
/* FP8 E5M2 format to Float conversion
36+
Format: 1 sign bit, 5 exponent bits, 2 mantissa bits */
37+
CAMLprim value arrayjit_fp8_to_float(value v_fp8)
38+
{
39+
CAMLparam1(v_fp8);
40+
uint8_t fp8 = (uint8_t)Int_val(v_fp8);
41+
42+
/* Handle zero */
43+
if (fp8 == 0) {
44+
CAMLreturn(caml_copy_double(0.0));
45+
}
46+
47+
uint32_t sign = (fp8 >> 7) & 1;
48+
uint32_t exp = (fp8 >> 2) & 0x1F;
49+
uint32_t mant = fp8 & 0x3;
50+
51+
/* Handle special cases */
52+
if (exp == 0x1F) { /* Infinity or NaN */
53+
if (mant == 0) {
54+
float inf = sign ? -INFINITY : INFINITY;
55+
CAMLreturn(caml_copy_double((double)inf));
56+
} else {
57+
CAMLreturn(caml_copy_double((double)NAN));
58+
}
59+
}
60+
61+
/* Denormalized numbers */
62+
if (exp == 0) {
63+
float result = ldexpf((float)mant / 4.0f, -14);
64+
if (sign) result = -result;
65+
CAMLreturn(caml_copy_double((double)result));
66+
}
67+
68+
/* Normalized numbers */
69+
float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);
70+
if (sign) result = -result;
71+
72+
CAMLreturn(caml_copy_double((double)result));
73+
}
74+
75+
/* Float to FP8 E5M2 conversion */
76+
CAMLprim value arrayjit_float_to_fp8(value v_float)
77+
{
78+
CAMLparam1(v_float);
79+
float f = (float)Double_val(v_float);
80+
81+
/* Handle zero */
82+
if (f == 0.0f) {
83+
CAMLreturn(Val_int(0));
84+
}
85+
86+
uint32_t sign = (f < 0) ? 1 : 0;
87+
f = fabsf(f);
88+
89+
/* Handle special cases */
90+
if (isinf(f)) {
91+
CAMLreturn(Val_int((sign << 7) | 0x7C)); /* Infinity: exp=0x1F, mant=0 */
92+
}
93+
if (isnan(f)) {
94+
CAMLreturn(Val_int((sign << 7) | 0x7F)); /* NaN: exp=0x1F, mant!=0 */
95+
}
96+
97+
/* Get exponent and mantissa */
98+
int exp_val;
99+
float mant_f = frexpf(f, &exp_val);
100+
int exp = exp_val + 14; /* Bias is 15, but frexp gives us mantissa in [0.5, 1) */
101+
102+
/* Clamp to representable range */
103+
if (exp < 0) {
104+
/* Underflow to zero */
105+
CAMLreturn(Val_int(sign << 7));
106+
}
107+
if (exp > 30) {
108+
/* Overflow to infinity */
109+
CAMLreturn(Val_int((sign << 7) | 0x7C));
110+
}
111+
112+
/* Handle denormalized numbers */
113+
if (exp == 0) {
114+
float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;
115+
uint32_t mant_bits = (uint32_t)(denorm_mant + 0.5f);
116+
if (mant_bits > 3) mant_bits = 3;
117+
CAMLreturn(Val_int((sign << 7) | mant_bits));
118+
}
119+
120+
/* Normalized numbers: convert mantissa from [0.5, 1) to [0, 0.75] */
121+
mant_f = (mant_f - 0.5f) * 4.0f;
122+
uint32_t mant_bits = (uint32_t)(mant_f + 0.5f); /* Round to nearest */
123+
if (mant_bits > 3) mant_bits = 3;
124+
125+
uint8_t result = (uint8_t)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));
126+
CAMLreturn(Val_int(result));
127+
}

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -209,72 +209,4 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
209209
context_lifetime = (ctx_arrays, code);
210210
description = "executes " ^ code.name ^ " on " ^ runner_label;
211211
work;
212-
} )
213-
(*
214-
let typ_of_prec = function
215-
| Ops.Byte_prec _ -> "unsigned char"
216-
| Ops.Uint16_prec _ -> "unsigned short"
217-
| Ops.Int32_prec _ -> "int"
218-
| Ops.Half_prec _ -> "_Float16"
219-
| Ops.Bfloat16_prec _ -> "unsigned short" (* Stored as uint16, emulated as float *)
220-
| Ops.Fp8_prec _ -> "unsigned char" (* Stored as uint8, emulated as float *)
221-
| Ops.Single_prec _ -> "float"
222-
| Ops.Double_prec _ -> "double"
223-
| Ops.Void_prec -> "void"
224-
225-
(* Helper functions for bfloat16 and fp8 conversions *)
226-
let extra_declarations =
227-
[
228-
"/* Emulation functions for special float types */";
229-
"static inline float bfloat16_to_float(unsigned short bf16) {";
230-
" unsigned int f32 = ((unsigned int)bf16) << 16;";
231-
" return *(float*)&f32;";
232-
"}";
233-
"";
234-
"static inline unsigned short float_to_bfloat16(float f) {";
235-
" unsigned int f32 = *(unsigned int*)&f;";
236-
" unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);";
237-
" return (unsigned short)(rounded >> 16);";
238-
"}";
239-
"";
240-
"/* Simplified FP8 E5M2 format emulation */";
241-
"static inline float fp8_to_float(unsigned char fp8) {";
242-
" if (fp8 == 0) return 0.0f;";
243-
" unsigned int sign = (fp8 >> 7) & 1;";
244-
" unsigned int exp = (fp8 >> 2) & 0x1F;";
245-
" unsigned int mant = fp8 & 0x3;";
246-
" float result = (1.0f + mant * 0.25f) * powf(2.0f, (float)exp - 15.0f);";
247-
" return sign ? -result : result;";
248-
"}";
249-
"";
250-
"static inline unsigned char float_to_fp8(float f) {";
251-
" if (f == 0.0f) return 0;";
252-
" unsigned int sign = (f < 0) ? 1 : 0;";
253-
" f = fabsf(f);";
254-
" int exp = (int)floorf(log2f(f)) + 15;";
255-
" if (exp < 0) return 0;";
256-
" if (exp > 31) return sign ? 0xFF : 0x7F;";
257-
" float mant = f / powf(2.0f, (float)exp - 15.0f) - 1.0f;";
258-
" unsigned int mant_bits = (unsigned int)(mant * 4.0f + 0.5f);";
259-
" if (mant_bits > 3) mant_bits = 3;";
260-
" return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));";
261-
"}";
262-
]
263-
264-
let convert_precision ~from ~to_ =
265-
match (from, to_) with
266-
| p1, p2 when Ops.equal_prec p1 p2 -> ("", "")
267-
| Ops.Bfloat16_prec _, Ops.Single_prec _ -> ("bfloat16_to_float(", ")")
268-
| Ops.Bfloat16_prec _, Ops.Double_prec _ -> ("(double)bfloat16_to_float(", ")")
269-
| Ops.Single_prec _, Ops.Bfloat16_prec _ -> ("float_to_bfloat16(", ")")
270-
| Ops.Double_prec _, Ops.Bfloat16_prec _ -> ("float_to_bfloat16((float)", ")")
271-
| Ops.Fp8_prec _, Ops.Single_prec _ -> ("fp8_to_float(", ")")
272-
| Ops.Fp8_prec _, Ops.Double_prec _ -> ("(double)fp8_to_float(", ")")
273-
| Ops.Single_prec _, Ops.Fp8_prec _ -> ("float_to_fp8(", ")")
274-
| Ops.Double_prec _, Ops.Fp8_prec _ -> ("float_to_fp8((float)", ")")
275-
| Ops.Bfloat16_prec _, _ -> ("(float)bfloat16_to_float(", ")") (* Convert via float *)
276-
| _, Ops.Bfloat16_prec _ -> ("float_to_bfloat16((float)", ")")
277-
| Ops.Fp8_prec _, _ -> ("(float)fp8_to_float(", ")") (* Convert via float *)
278-
| _, Ops.Fp8_prec _ -> ("float_to_fp8((float)", ")")
279-
| _ -> Ops.c_convert_precision ~from ~to_
280-
*)
212+
} )

arrayjit/lib/dune

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
saturn_lockfree
3434
utils
3535
ppx_minidebug.runtime)
36+
(foreign_stubs
37+
(language c)
38+
(names arrayjit_stubs))
3639
(preprocess
3740
(pps
3841
ppx_compare

arrayjit/lib/ndarray.ml

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
open Base
2+
3+
module Lazy = Utils.Lazy
4+
25
(** N-dimensional arrays: a precision-handling wrapper for [Bigarray.Genarray] and its utilities. *)
36

7+
(* External conversion functions for special float types. They are implemented
   by the C stubs in arrayjit_stubs.c; the narrow formats are exchanged as
   plain [int] bit patterns. *)
8+
(** [bfloat16_to_float bits] decodes the low 16 bits of [bits] as a bfloat16
    (1 sign, 8 exponent, 7 mantissa bits) value. *)
external bfloat16_to_float : int -> float = "arrayjit_bfloat16_to_float"
9+
(** [float_to_bfloat16 v] encodes [v] as a bfloat16 bit pattern; the result is
    in the range [0, 65535]. *)
external float_to_bfloat16 : float -> int = "arrayjit_float_to_bfloat16"
10+
(** [fp8_to_float bits] decodes the low 8 bits of [bits] as an FP8 E5M2
    (1 sign, 5 exponent, 2 mantissa bits) value. *)
external fp8_to_float : int -> float = "arrayjit_fp8_to_float"
11+
(** [float_to_fp8 v] encodes [v] as an FP8 E5M2 bit pattern; the result is in
    the range [0, 255]. *)
external float_to_fp8 : float -> int = "arrayjit_float_to_fp8"
12+
413
let _get_local_debug_runtime = Utils.get_local_debug_runtime
514

615
[%%global_debug_log_level 9]
@@ -160,15 +169,15 @@ let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~di
160169
init_bigarray_of_prec prec dims ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
161170
| Ops.Half, Standard_uniform ->
162171
init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0)
163-
| Ops.Bfloat16, Constant_fill { values; strict } -> constant_fill_f Int.of_float values strict (* TODO: proper bfloat16 conversion *)
172+
| Ops.Bfloat16, Constant_fill { values; strict } -> constant_fill_f float_to_bfloat16 values strict
164173
| Ops.Bfloat16, Range_over_offsets ->
165-
init_bigarray_of_prec prec dims ~f:(fun idcs -> indices_to_offset ~dims ~idcs) (* TODO: proper bfloat16 conversion *)
166-
| Ops.Bfloat16, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Random.int 65536) (* TODO: proper bfloat16 conversion *)
167-
| Ops.Fp8, Constant_fill { values; strict } -> constant_fill_f (Fn.compose Char.of_int_exn Int.of_float) values strict (* TODO: proper fp8 conversion *)
174+
init_bigarray_of_prec prec dims ~f:(fun idcs -> float_to_bfloat16 @@ Float.of_int @@ indices_to_offset ~dims ~idcs)
175+
| Ops.Bfloat16, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> float_to_bfloat16 @@ Rand.Lib.float_range 0.0 1.0)
176+
| Ops.Fp8, Constant_fill { values; strict } -> constant_fill_f (Fn.compose Char.of_int_exn float_to_fp8) values strict
168177
| Ops.Fp8, Range_over_offsets ->
169178
init_bigarray_of_prec prec dims ~f:(fun idcs ->
170-
Char.of_int_exn @@ indices_to_offset ~dims ~idcs)
171-
| Ops.Fp8, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.char ())
179+
Char.of_int_exn @@ float_to_fp8 @@ Float.of_int @@ indices_to_offset ~dims ~idcs)
180+
| Ops.Fp8, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Char.of_int_exn @@ float_to_fp8 @@ Rand.Lib.float_range 0.0 1.0)
172181
| Ops.Single, Constant_fill { values; strict } -> constant_fill_float values strict
173182
| Ops.Single, Range_over_offsets ->
174183
init_bigarray_of_prec prec dims ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
@@ -255,8 +264,8 @@ let set_from_float arr idx v =
255264
| Uint16_nd arr -> A.set arr idx @@ Int.of_float v
256265
| Int32_nd arr -> A.set arr idx @@ Int32.of_float v
257266
| Half_nd arr -> A.set arr idx v
258-
| Bfloat16_nd arr -> A.set arr idx @@ Int.of_float v (* TODO: proper bfloat16 conversion *)
259-
| Fp8_nd arr -> A.set arr idx @@ Char.of_int_exn @@ Int.of_float v (* TODO: proper fp8 conversion *)
267+
| Bfloat16_nd arr -> A.set arr idx @@ float_to_bfloat16 v
268+
| Fp8_nd arr -> A.set arr idx @@ Char.of_int_exn @@ float_to_fp8 v
260269
| Single_nd arr -> A.set arr idx v
261270
| Double_nd arr -> A.set arr idx v
262271

@@ -266,8 +275,8 @@ let fill_from_float arr v =
266275
| Uint16_nd arr -> A.fill arr @@ Int.of_float v
267276
| Int32_nd arr -> A.fill arr @@ Int32.of_float v
268277
| Half_nd arr -> A.fill arr v
269-
| Bfloat16_nd arr -> A.fill arr @@ Int.of_float v (* TODO: proper bfloat16 conversion *)
270-
| Fp8_nd arr -> A.fill arr @@ Char.of_int_exn @@ Int.of_float v (* TODO: proper fp8 conversion *)
278+
| Bfloat16_nd arr -> A.fill arr @@ float_to_bfloat16 v
279+
| Fp8_nd arr -> A.fill arr @@ Char.of_int_exn @@ float_to_fp8 v
271280
| Single_nd arr -> A.fill arr v
272281
| Double_nd arr -> A.fill arr v
273282

@@ -319,14 +328,15 @@ let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precisi
319328
| Ops.Half, Range_over_offsets ->
320329
set_bigarray arr ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
321330
| Ops.Half, Standard_uniform -> set_bigarray arr ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0)
322-
| Ops.Bfloat16, Constant_fill { values; strict } -> constant_set_f Int.of_float values strict (* TODO: proper bfloat16 conversion *)
331+
| Ops.Bfloat16, Constant_fill { values; strict } -> constant_set_f float_to_bfloat16 values strict
323332
| Ops.Bfloat16, Range_over_offsets ->
324-
set_bigarray arr ~f:(fun idcs -> indices_to_offset ~dims ~idcs) (* TODO: proper bfloat16 conversion *)
325-
| Ops.Bfloat16, Standard_uniform -> set_bigarray arr ~f:(fun _ -> Random.int 65536) (* TODO: proper bfloat16 conversion *)
326-
| Ops.Fp8, Constant_fill { values; strict } -> constant_set_f (Fn.compose Char.of_int_exn Int.of_float) values strict (* TODO: proper fp8 conversion *)
333+
set_bigarray arr ~f:(fun idcs -> float_to_bfloat16 @@ Float.of_int @@ indices_to_offset ~dims ~idcs)
334+
| Ops.Bfloat16, Standard_uniform -> set_bigarray arr ~f:(fun _ -> float_to_bfloat16 @@ Rand.Lib.float_range 0.0 1.0)
335+
| Ops.Fp8, Constant_fill { values; strict } -> constant_set_f (Fn.compose Char.of_int_exn float_to_fp8) values strict
327336
| Ops.Fp8, Range_over_offsets ->
328-
set_bigarray arr ~f:(fun idcs -> Char.of_int_exn @@ indices_to_offset ~dims ~idcs)
329-
| Ops.Fp8, Standard_uniform -> set_bigarray arr ~f:(fun _ -> Rand.Lib.char ())
337+
set_bigarray arr ~f:(fun idcs ->
338+
Char.of_int_exn @@ float_to_fp8 @@ Float.of_int @@ indices_to_offset ~dims ~idcs)
339+
| Ops.Fp8, Standard_uniform -> set_bigarray arr ~f:(fun _ -> Char.of_int_exn @@ float_to_fp8 @@ Rand.Lib.float_range 0.0 1.0)
330340
| Ops.Single, Constant_fill { values; strict } -> constant_set_float values strict
331341
| Ops.Single, Range_over_offsets ->
332342
set_bigarray arr ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs)
@@ -363,8 +373,8 @@ let fold_as_float ~init ~f arr =
363373
| Uint16_nd arr -> fold_bigarray ~init ~f:(fun accu idx v -> f accu idx @@ Float.of_int v) arr
364374
| Int32_nd arr -> fold_bigarray ~init ~f:(fun accu idx v -> f accu idx @@ Int32.to_float v) arr
365375
| Half_nd arr -> fold_bigarray ~init ~f arr
366-
| Bfloat16_nd arr -> fold_bigarray ~init ~f:(fun accu idx v -> f accu idx @@ Float.of_int v) arr (* TODO: proper bfloat16 conversion *)
367-
| Fp8_nd arr -> fold_bigarray ~init ~f:(fun accu idx c -> f accu idx @@ Float.of_int @@ Char.to_int c) arr (* TODO: proper fp8 conversion *)
376+
| Bfloat16_nd arr -> fold_bigarray ~init ~f:(fun accu idx v -> f accu idx @@ bfloat16_to_float v) arr
377+
| Fp8_nd arr -> fold_bigarray ~init ~f:(fun accu idx c -> f accu idx @@ fp8_to_float @@ Char.to_int c) arr
368378
| Single_nd arr -> fold_bigarray ~init ~f arr
369379
| Double_nd arr -> fold_bigarray ~init ~f arr
370380

@@ -380,8 +390,8 @@ let get_as_float arr idx =
380390
| Uint16_nd arr -> Float.of_int @@ A.get arr idx
381391
| Int32_nd arr -> Int32.to_float @@ A.get arr idx
382392
| Half_nd arr -> A.get arr idx
383-
| Bfloat16_nd arr -> Float.of_int @@ A.get arr idx (* TODO: proper bfloat16 conversion *)
384-
| Fp8_nd arr -> Float.of_int @@ Char.to_int @@ A.get arr idx (* TODO: proper fp8 conversion *)
393+
| Bfloat16_nd arr -> bfloat16_to_float @@ A.get arr idx
394+
| Fp8_nd arr -> fp8_to_float @@ Char.to_int @@ A.get arr idx
385395
| Single_nd arr -> A.get arr idx
386396
| Double_nd arr -> A.get arr idx
387397

arrayjit/test/dune

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; Test executable exercising the numerical-type support.
(executable
 (name test_numerical_types)
 (modules test_numerical_types)
 (libraries base stdio arrayjit.ir)
 (preprocess
  (pps ppx_jane)))

; Runs the test and captures its standard output.
; ocannl_config is listed as a dependency so that dune copies this directory's
; configuration file next to the executable in the build tree; dune only
; materialises source files that a rule declares. Without it the local config
; is absent at run time -- presumably why the test picked up the parent
; directory's ocannl_config (confirm against the config lookup logic).
(rule
 (target test_numerical_types.output)
 (deps
  test_numerical_types.exe
  ocannl_config)
 (action
  (with-stdout-to
   %{target}
   (run ./test_numerical_types.exe))))

; `dune runtest` compares the captured output with the checked-in expectation.
(rule
 (alias runtest)
 (action
  (diff test_numerical_types.expected test_numerical_types.output)))

arrayjit/test/ocannl_config

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
randomness_lib=for_tests
2+
log_main_domain_to_stdout=true
3+
backend=multicore_cc
4+
log_level=0
5+
print_decimals_precision=2
6+
prefer_backend_uniformity=true

0 commit comments

Comments
 (0)