Skip to content

Commit 032facc

Browse files
committed
Untested / broken: Claude's first pass at adding BF16, FP8, uint16, int32
1 parent 95b436d commit 032facc

File tree

7 files changed

+371
-108
lines changed

7 files changed

+371
-108
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ struct
175175
match op with
176176
| Ops.Satur01_gate -> (
177177
match prec with
178-
| Ops.Byte_prec _ ->
178+
| Ops.Byte_prec _ | Ops.Uint16_prec _ | Ops.Int32_prec _ | Ops.Fp8_prec _ ->
179179
let open PPrint in
180180
group
181181
(parens
@@ -185,10 +185,25 @@ struct
185185
^^ string " < 1.0f"))
186186
^^ ifflat
187187
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
188-
^^ string "(unsigned char)0")
188+
^^ string "(" ^^ string (typ_of_prec prec) ^^ string ")0")
189189
(nest 2
190190
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
191-
^^ string "(unsigned char)0"))))
191+
^^ string "(" ^^ string (typ_of_prec prec) ^^ string ")0"))))
192+
| Ops.Bfloat16_prec _ ->
193+
(* For CC backend, convert to float for computation *)
194+
let open PPrint in
195+
group
196+
(parens
197+
(group
198+
(parens
199+
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
200+
^^ string " < 1.0f"))
201+
^^ ifflat
202+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
203+
^^ string "(unsigned short)0")
204+
(nest 2
205+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
206+
^^ string "(unsigned short)0"))))
192207
| Ops.Half_prec _ ->
193208
let open PPrint in
194209
group

arrayjit/lib/cc_backend.ml

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,30 @@ let%track7_sexp c_compile_and_load ~f_name =
8383
Stdlib.Gc.finalise finalize result;
8484
result
8585

86-
let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized) : procedure =
87-
let module Syntax = C_syntax.C_syntax (C_syntax.Pure_C_config (struct
86+
module CC_syntax_config (Procs : sig
87+
val procs : Low_level.optimized array
88+
end) =
89+
struct
90+
include C_syntax.Pure_C_config (struct
8891
type nonrec buffer_ptr = buffer_ptr
8992

9093
let use_host_memory = use_host_memory
91-
let procs = [| lowered |]
94+
let procs = Procs.procs
9295

9396
let full_printf_support =
9497
not @@ Bool.of_string
9598
@@ Utils.get_global_arg ~default:"false" ~arg_name:"prefer_backend_uniformity"
99+
end)
100+
101+
(* Override to add our custom type and conversion support *)
102+
let typ_of_prec = typ_of_prec
103+
let extra_declarations = extra_declarations (* Our bfloat16/fp8 conversion functions *)
104+
let convert_precision = convert_precision
105+
end
106+
107+
let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized) : procedure =
108+
let module Syntax = C_syntax.C_syntax (CC_syntax_config (struct
109+
let procs = [| lowered |]
96110
end)) in
97111
(* FIXME: do we really want all of them, or only the used ones? *)
98112
let idx_params = Indexing.bound_symbols bindings in
@@ -110,15 +124,8 @@ let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized)
110124

111125
let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized option array) :
112126
procedure option array =
113-
let module Syntax = C_syntax.C_syntax (C_syntax.Pure_C_config (struct
114-
type nonrec buffer_ptr = buffer_ptr
115-
116-
let use_host_memory = use_host_memory
127+
let module Syntax = C_syntax.C_syntax (CC_syntax_config (struct
117128
let procs = Array.filter_opt lowereds
118-
119-
let full_printf_support =
120-
not @@ Bool.of_string
121-
@@ Utils.get_global_arg ~default:"false" ~arg_name:"prefer_backend_uniformity"
122129
end)) in
123130
(* FIXME: do we really want all of them, or only the used ones? *)
124131
let idx_params = Indexing.bound_symbols bindings in
@@ -203,3 +210,71 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
203210
description = "executes " ^ code.name ^ " on " ^ runner_label;
204211
work;
205212
} )
213+
(*
214+
let typ_of_prec = function
215+
| Ops.Byte_prec _ -> "unsigned char"
216+
| Ops.Uint16_prec _ -> "unsigned short"
217+
| Ops.Int32_prec _ -> "int"
218+
| Ops.Half_prec _ -> "_Float16"
219+
| Ops.Bfloat16_prec _ -> "unsigned short" (* Stored as uint16, emulated as float *)
220+
| Ops.Fp8_prec _ -> "unsigned char" (* Stored as uint8, emulated as float *)
221+
| Ops.Single_prec _ -> "float"
222+
| Ops.Double_prec _ -> "double"
223+
| Ops.Void_prec -> "void"
224+
225+
(* Helper functions for bfloat16 and fp8 conversions *)
226+
let extra_declarations =
227+
[
228+
"/* Emulation functions for special float types */";
229+
"static inline float bfloat16_to_float(unsigned short bf16) {";
230+
" unsigned int f32 = ((unsigned int)bf16) << 16;";
231+
" return *(float*)&f32;";
232+
"}";
233+
"";
234+
"static inline unsigned short float_to_bfloat16(float f) {";
235+
" unsigned int f32 = *(unsigned int*)&f;";
236+
" unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);";
237+
" return (unsigned short)(rounded >> 16);";
238+
"}";
239+
"";
240+
"/* Simplified FP8 E5M2 format emulation */";
241+
"static inline float fp8_to_float(unsigned char fp8) {";
242+
" if (fp8 == 0) return 0.0f;";
243+
" unsigned int sign = (fp8 >> 7) & 1;";
244+
" unsigned int exp = (fp8 >> 2) & 0x1F;";
245+
" unsigned int mant = fp8 & 0x3;";
246+
" float result = (1.0f + mant * 0.25f) * powf(2.0f, (float)exp - 15.0f);";
247+
" return sign ? -result : result;";
248+
"}";
249+
"";
250+
"static inline unsigned char float_to_fp8(float f) {";
251+
" if (f == 0.0f) return 0;";
252+
" unsigned int sign = (f < 0) ? 1 : 0;";
253+
" f = fabsf(f);";
254+
" int exp = (int)floorf(log2f(f)) + 15;";
255+
" if (exp < 0) return 0;";
256+
" if (exp > 31) return sign ? 0xFF : 0x7F;";
257+
" float mant = f / powf(2.0f, (float)exp - 15.0f) - 1.0f;";
258+
" unsigned int mant_bits = (unsigned int)(mant * 4.0f + 0.5f);";
259+
" if (mant_bits > 3) mant_bits = 3;";
260+
" return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));";
261+
"}";
262+
]
263+
264+
let convert_precision ~from ~to_ =
265+
match (from, to_) with
266+
| p1, p2 when Ops.equal_prec p1 p2 -> ("", "")
267+
| Ops.Bfloat16_prec _, Ops.Single_prec _ -> ("bfloat16_to_float(", ")")
268+
| Ops.Bfloat16_prec _, Ops.Double_prec _ -> ("(double)bfloat16_to_float(", ")")
269+
| Ops.Single_prec _, Ops.Bfloat16_prec _ -> ("float_to_bfloat16(", ")")
270+
| Ops.Double_prec _, Ops.Bfloat16_prec _ -> ("float_to_bfloat16((float)", ")")
271+
| Ops.Fp8_prec _, Ops.Single_prec _ -> ("fp8_to_float(", ")")
272+
| Ops.Fp8_prec _, Ops.Double_prec _ -> ("(double)fp8_to_float(", ")")
273+
| Ops.Single_prec _, Ops.Fp8_prec _ -> ("float_to_fp8(", ")")
274+
| Ops.Double_prec _, Ops.Fp8_prec _ -> ("float_to_fp8((float)", ")")
275+
| Ops.Bfloat16_prec _, _ -> ("(float)bfloat16_to_float(", ")") (* Convert via float *)
276+
| _, Ops.Bfloat16_prec _ -> ("float_to_bfloat16((float)", ")")
277+
| Ops.Fp8_prec _, _ -> ("(float)fp8_to_float(", ")") (* Convert via float *)
278+
| _, Ops.Fp8_prec _ -> ("float_to_fp8((float)", ")")
279+
| _ -> Ops.c_convert_precision ~from ~to_
280+
*)

arrayjit/lib/cuda_backend.ml

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,14 @@ end) : Ir.Backend_impl.Lowered_backend = struct
281281

282282
let typ_of_prec = function
283283
| Ops.Byte_prec _ -> "unsigned char"
284-
| Half_prec _ -> "__half"
285-
| Single_prec _ -> "float"
286-
| Double_prec _ -> "double"
287-
| Void_prec -> "void"
284+
| Ops.Uint16_prec _ -> "unsigned short"
285+
| Ops.Int32_prec _ -> "int"
286+
| Ops.Half_prec _ -> "__half"
287+
| Ops.Bfloat16_prec _ -> "__nv_bfloat16" (* CUDA bfloat16 type *)
288+
| Ops.Fp8_prec _ -> "__nv_fp8_e5m2" (* CUDA FP8 type (E5M2 format) *)
289+
| Ops.Single_prec _ -> "float"
290+
| Ops.Double_prec _ -> "double"
291+
| Ops.Void_prec -> "void"
288292

289293
let binop_syntax prec v =
290294
(* TODO: consider using binop_syntax inherited from Pure_C_config and overriding only where
@@ -317,9 +321,14 @@ end) : Ir.Backend_impl.Lowered_backend = struct
317321
(string "hexp2(hlog2(" ^^ v1 ^^ string "),"
318322
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
319323
^^ string ")")
320-
| ToPowOf, Byte_prec _ ->
321-
invalid_arg "Cuda_backend.binop_syntax: ToPowOf not supported for byte/integer precisions"
322-
| Relu_gate, Byte_prec _ ->
324+
| ToPowOf, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Fp8_prec _) ->
325+
invalid_arg "Cuda_backend.binop_syntax: ToPowOf not supported for integer precisions"
326+
| ToPowOf, Bfloat16_prec _ ->
327+
fun v1 v2 ->
328+
group
329+
(string "__float2bfloat16(powf(__bfloat162float(" ^^ v1 ^^ string "), __bfloat162float("
330+
^^ v2 ^^ string ")))")
331+
| Relu_gate, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Fp8_prec _) ->
323332
fun v1 v2 ->
324333
group
325334
(parens
@@ -330,31 +339,19 @@ end) : Ir.Backend_impl.Lowered_backend = struct
330339
(nest 2
331340
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
332341
^^ string "0"))))
333-
| Relu_gate, Half_prec _ ->
342+
| Relu_gate, Bfloat16_prec _ ->
334343
fun v1 v2 ->
335344
group
336345
(parens
337346
(group
338347
(parens
339-
(string "__hgt(" ^^ v1 ^^ comma
340-
^^ string " __ushort_as_half((unsigned short)0x0000U))"))
341-
^^ ifflat
342-
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
343-
^^ string "__ushort_as_half((unsigned short)0x0000U)")
344-
(nest 2
345-
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
346-
^^ string "__ushort_as_half((unsigned short)0x0000U)"))))
347-
| Relu_gate, _ ->
348-
fun v1 v2 ->
349-
group
350-
(parens
351-
(group (parens (v1 ^^ string " > 0.0"))
348+
(string "__bfloat162float(" ^^ v1 ^^ string ") > 0.0f"))
352349
^^ ifflat
353350
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
354-
^^ string "0.0")
351+
^^ string "__float2bfloat16(0.0f)")
355352
(nest 2
356353
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
357-
^^ string "0.0"))))
354+
^^ string "__float2bfloat16(0.0f)"))))
358355
| Satur01_gate, Byte_prec _ ->
359356
fun v1 v2 ->
360357
group

arrayjit/lib/metal_backend.ml

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
163163
queue_desc / queue itself. *)
164164
let created_q = Me.CommandQueue.on_device_with_descriptor metal_device queue_desc in
165165
(* Store the log_entries_ref for later retrieval, associated with the stream_id which will
166-
be assigned by make_stream shortly. We\'ll add it after make_stream. *)
166+
be assigned by make_stream shortly. We'll add it after make_stream. *)
167167
(created_q, Some log_entries_ref))
168168
else (Me.CommandQueue.on_device metal_device, None)
169169
in
@@ -440,18 +440,25 @@ end) : Ir.Backend_impl.Lowered_backend = struct
440440
let extra_declarations = [ "using namespace metal;" ]
441441

442442
let typ_of_prec = function
443-
| Ops.Byte_prec _ -> "uint8_t"
444-
| Half_prec _ -> "half"
445-
| Single_prec _ -> "float"
446-
| Double_prec _ -> "double"
447-
| Void_prec -> "void"
443+
| Ops.Byte_prec _ -> "uchar"
444+
| Ops.Uint16_prec _ -> "ushort"
445+
| Ops.Int32_prec _ -> "int"
446+
| Ops.Half_prec _ -> "half"
447+
| Ops.Bfloat16_prec _ -> "bfloat" (* Metal supports bfloat16 natively *)
448+
| Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
449+
| Ops.Single_prec _ -> "float"
450+
| Ops.Double_prec _ -> "double"
451+
| Ops.Void_prec -> "void"
448452

449453
let metal_prec_suffix_float = function
450-
(* Suffix for float literals like 0.0, 1.0 *)
454+
| Ops.Byte_prec _ -> ""
455+
| Ops.Uint16_prec _ -> ""
456+
| Ops.Int32_prec _ -> ""
451457
| Ops.Half_prec _ -> "h"
458+
| Ops.Bfloat16_prec _ -> "bf" (* TODO: Verify actual Metal suffix for bfloat16 *)
459+
| Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
452460
| Ops.Single_prec _ -> "f"
453-
| Ops.Double_prec _ -> "" (* No suffix for double literals *)
454-
| Ops.Byte_prec _ -> ""
461+
| Ops.Double_prec _ -> ""
455462
| Ops.Void_prec -> ""
456463

457464
let ternop_syntax _prec op =
@@ -661,7 +668,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
661668

662669
let work () : unit =
663670
[%log3_result "Launching", func_name, "on", runner_label];
664-
(* Unlike CUDA, we don\'t use Utils.add_log_processor here. Logs are captured by the LogState
671+
(* Unlike CUDA, we don't use Utils.add_log_processor here. Logs are captured by the LogState
665672
handler installed on the CommandQueue. They will be processed by Utils.log_trace_tree in
666673
`await`. *)
667674
try

0 commit comments

Comments
 (0)