11open Base
2+
3+ module Lazy = Utils. Lazy
4+
25(* * N-dimensional arrays: a precision-handling wrapper for [Bigarray.Genarray] and its utilities. *)
36
7+ (* External conversion functions for special float types *)
8+ external bfloat16_to_float : int -> float = " arrayjit_bfloat16_to_float"
9+ external float_to_bfloat16 : float -> int = " arrayjit_float_to_bfloat16"
10+ external fp8_to_float : int -> float = " arrayjit_fp8_to_float"
11+ external float_to_fp8 : float -> int = " arrayjit_float_to_fp8"
12+
413let _get_local_debug_runtime = Utils. get_local_debug_runtime
514
615[%% global_debug_log_level 9 ]
@@ -160,15 +169,15 @@ let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~di
160169 init_bigarray_of_prec prec dims ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
161170 | Ops. Half , Standard_uniform ->
162171 init_bigarray_of_prec prec dims ~f: (fun _ -> Rand.Lib. float_range 0.0 1.0 )
163- | Ops. Bfloat16 , Constant_fill { values; strict } -> constant_fill_f Int. of_float values strict (* TODO: proper bfloat16 conversion *)
172+ | Ops. Bfloat16 , Constant_fill { values; strict } -> constant_fill_f float_to_bfloat16 values strict
164173 | Ops. Bfloat16 , Range_over_offsets ->
165- init_bigarray_of_prec prec dims ~f: (fun idcs -> indices_to_offset ~dims ~idcs ) (* TODO: proper bfloat16 conversion * )
166- | Ops. Bfloat16 , Standard_uniform -> init_bigarray_of_prec prec dims ~f: (fun _ -> Random. int 65536 ) (* TODO: proper bfloat16 conversion * )
167- | Ops. Fp8 , Constant_fill { values; strict } -> constant_fill_f (Fn. compose Char. of_int_exn Int. of_float ) values strict (* TODO: proper fp8 conversion *)
174+ init_bigarray_of_prec prec dims ~f: (fun idcs -> float_to_bfloat16 @@ Float. of_int @@ indices_to_offset ~dims ~idcs )
175+ | Ops. Bfloat16 , Standard_uniform -> init_bigarray_of_prec prec dims ~f: (fun _ -> float_to_bfloat16 @@ Rand.Lib. float_range 0.0 1.0 )
176+ | Ops. Fp8 , Constant_fill { values; strict } -> constant_fill_f (Fn. compose Char. of_int_exn float_to_fp8 ) values strict
168177 | Ops. Fp8 , Range_over_offsets ->
169178 init_bigarray_of_prec prec dims ~f: (fun idcs ->
170- Char. of_int_exn @@ indices_to_offset ~dims ~idcs )
171- | Ops. Fp8 , Standard_uniform -> init_bigarray_of_prec prec dims ~f: (fun _ -> Rand.Lib. char () )
179+ Char. of_int_exn @@ float_to_fp8 @@ Float. of_int @@ indices_to_offset ~dims ~idcs )
180+ | Ops. Fp8 , Standard_uniform -> init_bigarray_of_prec prec dims ~f: (fun _ -> Char. of_int_exn @@ float_to_fp8 @@ Rand.Lib. float_range 0.0 1.0 )
172181 | Ops. Single , Constant_fill { values; strict } -> constant_fill_float values strict
173182 | Ops. Single , Range_over_offsets ->
174183 init_bigarray_of_prec prec dims ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
@@ -255,8 +264,8 @@ let set_from_float arr idx v =
255264 | Uint16_nd arr -> A. set arr idx @@ Int. of_float v
256265 | Int32_nd arr -> A. set arr idx @@ Int32. of_float v
257266 | Half_nd arr -> A. set arr idx v
258- | Bfloat16_nd arr -> A. set arr idx @@ Int. of_float v (* TODO: proper bfloat16 conversion *)
259- | Fp8_nd arr -> A. set arr idx @@ Char. of_int_exn @@ Int. of_float v (* TODO: proper fp8 conversion *)
267+ | Bfloat16_nd arr -> A. set arr idx @@ float_to_bfloat16 v
268+ | Fp8_nd arr -> A. set arr idx @@ Char. of_int_exn @@ float_to_fp8 v
260269 | Single_nd arr -> A. set arr idx v
261270 | Double_nd arr -> A. set arr idx v
262271
@@ -266,8 +275,8 @@ let fill_from_float arr v =
266275 | Uint16_nd arr -> A. fill arr @@ Int. of_float v
267276 | Int32_nd arr -> A. fill arr @@ Int32. of_float v
268277 | Half_nd arr -> A. fill arr v
269- | Bfloat16_nd arr -> A. fill arr @@ Int. of_float v (* TODO: proper bfloat16 conversion *)
270- | Fp8_nd arr -> A. fill arr @@ Char. of_int_exn @@ Int. of_float v (* TODO: proper fp8 conversion *)
278+ | Bfloat16_nd arr -> A. fill arr @@ float_to_bfloat16 v
279+ | Fp8_nd arr -> A. fill arr @@ Char. of_int_exn @@ float_to_fp8 v
271280 | Single_nd arr -> A. fill arr v
272281 | Double_nd arr -> A. fill arr v
273282
@@ -319,14 +328,15 @@ let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precisi
319328 | Ops. Half , Range_over_offsets ->
320329 set_bigarray arr ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
321330 | Ops. Half , Standard_uniform -> set_bigarray arr ~f: (fun _ -> Rand.Lib. float_range 0.0 1.0 )
322- | Ops. Bfloat16 , Constant_fill { values; strict } -> constant_set_f Int. of_float values strict (* TODO: proper bfloat16 conversion *)
331+ | Ops. Bfloat16 , Constant_fill { values; strict } -> constant_set_f float_to_bfloat16 values strict
323332 | Ops. Bfloat16 , Range_over_offsets ->
324- set_bigarray arr ~f: (fun idcs -> indices_to_offset ~dims ~idcs ) (* TODO: proper bfloat16 conversion * )
325- | Ops. Bfloat16 , Standard_uniform -> set_bigarray arr ~f: (fun _ -> Random. int 65536 ) (* TODO: proper bfloat16 conversion * )
326- | Ops. Fp8 , Constant_fill { values; strict } -> constant_set_f (Fn. compose Char. of_int_exn Int. of_float ) values strict (* TODO: proper fp8 conversion *)
333+ set_bigarray arr ~f: (fun idcs -> float_to_bfloat16 @@ Float. of_int @@ indices_to_offset ~dims ~idcs )
334+ | Ops. Bfloat16 , Standard_uniform -> set_bigarray arr ~f: (fun _ -> float_to_bfloat16 @@ Rand.Lib. float_range 0.0 1.0 )
335+ | Ops. Fp8 , Constant_fill { values; strict } -> constant_set_f (Fn. compose Char. of_int_exn float_to_fp8 ) values strict
327336 | Ops. Fp8 , Range_over_offsets ->
328- set_bigarray arr ~f: (fun idcs -> Char. of_int_exn @@ indices_to_offset ~dims ~idcs )
329- | Ops. Fp8 , Standard_uniform -> set_bigarray arr ~f: (fun _ -> Rand.Lib. char () )
337+ set_bigarray arr ~f: (fun idcs ->
338+ Char. of_int_exn @@ float_to_fp8 @@ Float. of_int @@ indices_to_offset ~dims ~idcs )
339+ | Ops. Fp8 , Standard_uniform -> set_bigarray arr ~f: (fun _ -> Char. of_int_exn @@ float_to_fp8 @@ Rand.Lib. float_range 0.0 1.0 )
330340 | Ops. Single , Constant_fill { values; strict } -> constant_set_float values strict
331341 | Ops. Single , Range_over_offsets ->
332342 set_bigarray arr ~f: (fun idcs -> Float. of_int @@ indices_to_offset ~dims ~idcs )
@@ -363,8 +373,8 @@ let fold_as_float ~init ~f arr =
363373 | Uint16_nd arr -> fold_bigarray ~init ~f: (fun accu idx v -> f accu idx @@ Float. of_int v) arr
364374 | Int32_nd arr -> fold_bigarray ~init ~f: (fun accu idx v -> f accu idx @@ Int32. to_float v) arr
365375 | Half_nd arr -> fold_bigarray ~init ~f arr
366- | Bfloat16_nd arr -> fold_bigarray ~init ~f: (fun accu idx v -> f accu idx @@ Float. of_int v) arr (* TODO: proper bfloat16 conversion *)
367- | Fp8_nd arr -> fold_bigarray ~init ~f: (fun accu idx c -> f accu idx @@ Float. of_int @@ Char. to_int c) arr (* TODO: proper fp8 conversion *)
376+ | Bfloat16_nd arr -> fold_bigarray ~init ~f: (fun accu idx v -> f accu idx @@ bfloat16_to_float v) arr
377+ | Fp8_nd arr -> fold_bigarray ~init ~f: (fun accu idx c -> f accu idx @@ fp8_to_float @@ Char. to_int c) arr
368378 | Single_nd arr -> fold_bigarray ~init ~f arr
369379 | Double_nd arr -> fold_bigarray ~init ~f arr
370380
@@ -380,8 +390,8 @@ let get_as_float arr idx =
380390 | Uint16_nd arr -> Float. of_int @@ A. get arr idx
381391 | Int32_nd arr -> Int32. to_float @@ A. get arr idx
382392 | Half_nd arr -> A. get arr idx
383- | Bfloat16_nd arr -> Float. of_int @@ A. get arr idx (* TODO: proper bfloat16 conversion *)
384- | Fp8_nd arr -> Float. of_int @@ Char. to_int @@ A. get arr idx (* TODO: proper fp8 conversion *)
393+ | Bfloat16_nd arr -> bfloat16_to_float @@ A. get arr idx
394+ | Fp8_nd arr -> fp8_to_float @@ Char. to_int @@ A. get arr idx
385395 | Single_nd arr -> A. get arr idx
386396 | Double_nd arr -> A. get arr idx
387397
0 commit comments