@@ -230,10 +230,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
230230 let % diagn2_sexp cuda_to_ptx ~name cu_src =
231231 let name_cu = name ^ " .cu" in
232232 if Utils. settings.output_debug_files_in_build_directory then (
233- let oc = Out_channel. open_text @@ Utils. build_file name_cu in
234- Stdio.Out_channel. output_string oc cu_src;
235- Stdio.Out_channel. flush oc;
236- Stdio.Out_channel. close oc);
233+ let build_file = Utils. open_build_file ~base_name: name ~extension: " .cu" in
234+ Stdio.Out_channel. output_string build_file.oc cu_src; (* Keep direct string output for source *)
235+ build_file.finalize () );
237236 [% log " compiling to PTX" ];
238237 let with_debug =
239238 Utils. settings.output_debug_files_in_build_directory || Utils. settings.log_level > 0
@@ -256,7 +255,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
256255 Stdio.Out_channel. close oc);
257256 ptx
258257
259- module C_syntax_config (Input : sig
258+ module Cuda_syntax_config (Input : sig
260259 val procs : Low_level .optimized array
261260 end ) =
262261 struct
@@ -283,8 +282,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
283282 | Void_prec -> " void"
284283
285284 let binop_syntax prec v =
286- let f op_str = C_syntax. binop_adapter (" (" , " " ^ op_str, " )" ) in
287- let func fn = C_syntax. binop_adapter (fn ^ " (" , " ," , " )" ) in
285+ let open PPrint in
286+ let f op_str v1 v2 = group (lparen ^^ v1 ^^ space ^^ string op_str ^^ space ^^ v2 ^^ rparen) in
287+ let func fn v1 v2 = group (string fn ^^ parens (separate comma_sep [ v1; v2 ])) in
288288 match (v, prec) with
289289 | Ops. Arg1 , _ -> invalid_arg " Cuda_backend.binop_syntax: Arg1 is not an operator"
290290 | Arg2 , _ -> invalid_arg " Cuda_backend.binop_syntax: Arg2 is not an operator"
@@ -302,31 +302,17 @@ end) : Ir.Backend_impl.Lowered_backend = struct
302302 | ToPowOf , Half_prec _ -> C_syntax. binop_adapter (" hexp2(hlog2(" , " )," , " )" )
303303 | ToPowOf , Byte_prec _ ->
304304 invalid_arg " Cuda_backend.binop_syntax: ToPowOf not supported for byte/integer precisions"
305- | Relu_gate , Byte_prec _ -> C_syntax. binop_adapter (" (" , " > 0 ?" , " : 0)" )
306- | Relu_gate , Half_prec _ ->
307- C_syntax. binop_adapter
308- ( " (__hgt(" ,
309- " , __ushort_as_half((unsigned short)0x0000U)) ?" ,
310- " : __ushort_as_half((unsigned short)0x0000U))" )
311- | Relu_gate , _ -> C_syntax. binop_adapter (" (" , " > 0.0 ?" , " : 0.0)" )
312- | Satur01_gate , Byte_prec _ ->
313- fun ppf pp1 v1 pp2 v2 ->
314- Stdlib.Format. fprintf ppf
315- " (((float)%a > 0.0f && (float)%a < 1.0f) ? %a : (unsigned char)0)" pp1 v1 pp1 v1 pp2
316- v2
317- | Satur01_gate , Half_prec _ ->
318- fun ppf pp1 v1 pp2 v2 ->
319- Stdlib.Format. fprintf ppf
320- " ((__hgt(%a, __ushort_as_half((unsigned short)0x0000U)) && __hlt(%a, \
321- __ushort_as_half((unsigned short)0x3C00U))) ? %a : __ushort_as_half((unsigned \
322- short)0x0000U))"
323- pp1 v1 pp1 v1 pp2 v2
324- | Satur01_gate , Single_prec _ ->
325- fun ppf pp1 v1 pp2 v2 ->
326- Stdlib.Format. fprintf ppf " ((%a > 0.0f && %a < 1.0f) ? %a : 0.0f)" pp1 v1 pp1 v1 pp2 v2
327- | Satur01_gate , Double_prec _ ->
328- fun ppf pp1 v1 pp2 v2 ->
329- Stdlib.Format. fprintf ppf " ((%a > 0.0 && %a < 1.0) ? %a : 0.0)" pp1 v1 pp1 v1 pp2 v2
305+ | Relu_gate , Byte_prec _ -> fun v1 v2 -> group (parens (v1 ^^ string " > 0" ) ^^ string " ? " ^^ v2 ^^ string " : 0" )
306+ | Relu_gate , Half_prec _ -> fun v1 v2 -> group (parens (string " __hgt(" ^^ v1 ^^ comma ^^ string " __ushort_as_half((unsigned short)0x0000U))" ) ^^ string " ? " ^^ v2 ^^ string " : __ushort_as_half((unsigned short)0x0000U)" )
307+ | Relu_gate , _ -> fun v1 v2 -> group (parens (v1 ^^ string " > 0.0" ) ^^ string " ? " ^^ v2 ^^ string " : 0.0" )
308+ | Satur01_gate , Byte_prec _ -> fun v1 v2 ->
309+ parens (parens (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f" ) ^^ string " ? " ^^ v2 ^^ string " : (unsigned char)0" )
310+ | Satur01_gate , Half_prec _ -> fun v1 v2 ->
311+ parens (parens (string " __hgt(" ^^ v1 ^^ comma ^^ string " __ushort_as_half((unsigned short)0x0000U)) && __hlt(" ^^ v1 ^^ comma ^^ string " __ushort_as_half((unsigned short)0x3C00U)))" ) ^^ string " ? " ^^ v2 ^^ string " : __ushort_as_half((unsigned short)0x0000U)" )
312+ | Satur01_gate , Single_prec _ -> fun v1 v2 ->
313+ parens (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f" ) ^^ string " ? " ^^ v2 ^^ string " : 0.0f" )
314+ | Satur01_gate , Double_prec _ -> fun v1 v2 ->
315+ parens (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0" ) ^^ string " ? " ^^ v2 ^^ string " : 0.0" )
330316 | Max , Byte_prec _ -> func " max"
331317 | Max , Half_prec _ -> func " __hmax"
332318 | Max , Double_prec _ -> func " fmax"
@@ -344,8 +330,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
344330 | And , _ -> f " &&"
345331
346332 let unop_syntax prec v =
347- let f prefix suffix = C_syntax. unop_adapter (prefix, suffix) in
348- let func fn = C_syntax. unop_adapter (fn ^ " (" , " )" ) in
333+ let open PPrint in
334+ let f prefix suffix expr = group (string prefix ^^ expr ^^ string suffix) in
335+ let func fn expr = group (string fn ^^ parens expr) in
349336 match (v, prec) with
350337 | Ops. Identity , _ -> f " " " "
351338 | Relu , Ops. Single_prec _ -> f " fmaxf(0.0, " " )"
@@ -392,22 +379,22 @@ end) : Ir.Backend_impl.Lowered_backend = struct
392379 | Recip_sqrt , Double_prec _ -> f " (1.0 / sqrt(" " ))"
393380 | Recip_sqrt , _ -> f " (1.0 / sqrtf(" " ))"
394381 | Neg , _ -> f " (-(" " ))"
395- | Tanh_approx , Byte_prec _ ->
396- invalid_arg
397- " Cuda_backend.unop_syntax: Tanh_approx not supported for byte/integer precisions"
382+ | Tanh_approx , Byte_prec _ -> invalid_arg " Cuda_backend.unop_syntax: Tanh_approx not supported for byte/integer precisions"
398383 | Tanh_approx , Half_prec _ -> func " htanh_approx"
399384 | Tanh_approx , Single_prec _ -> func " __tanhf"
400385 | Tanh_approx , _ -> func " tanh"
401386 | Not , _ -> f " (" " == 0.0 ? 1.0 : 0.0)"
402387
403388 let ternop_syntax prec v =
389+ let open PPrint in
404390 match (v, prec) with
405- | Ops. Where , _ -> C_syntax. ternop_adapter ( " ( " , " ?" , " :" , " ) " )
391+ | Ops. Where , _ -> fun v1 v2 v3 -> group (parens v1 ^^ string " ? " ^^ v2 ^^ string " : " ^^ v3 )
406392 | FMA , Ops. Half_prec _ -> C_syntax. ternop_adapter (" __hfma(" , " ," , " ," , " )" )
407393 | FMA , Ops. Single_prec _ -> C_syntax. ternop_adapter (" fmaf(" , " ," , " ," , " )" )
408394 | FMA , _ -> C_syntax. ternop_adapter (" fma(" , " ," , " ," , " )" )
409395
410396 let convert_precision ~from ~to_ =
397+ let open PPrint in
411398 match (from, to_) with
412399 | Ops. Double_prec _, Ops. Double_prec _
413400 | Single_prec _, Single_prec _
@@ -421,43 +408,49 @@ end) : Ir.Backend_impl.Lowered_backend = struct
421408 | _ -> (" (" ^ typ_of_prec to_ ^ " )(" , " )" )
422409 end
423410
424- let compile ~name bindings ({ Low_level. traced_store; _ } as lowered ) =
411+ let % diagn2_sexp compile ~name bindings ({ Low_level. traced_store; _ } as lowered) =
425412 (* TODO: The following link seems to claim it's better to expand into loops than use memset.
426413 https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
427- let module Syntax = C_syntax. C_syntax (C_syntax_config (struct
414+ let module Syntax = C_syntax. C_syntax (Cuda_syntax_config (struct
428415 let procs = [| lowered |]
429416 end )) in
430417 let idx_params = Indexing. bound_symbols bindings in
431418 let b = Buffer. create 4096 in
432- let ppf = Stdlib.Format. formatter_of_buffer b in
433419 if Utils. debug_log_from_routines () then
434- Stdlib.Format. fprintf ppf " @,__device__ int printf (const char * format, ... );@," ;
435- Syntax. print_declarations ppf;
436- let params = Syntax. compile_proc ~name ppf idx_params lowered in
437- let ptx = cuda_to_ptx ~name @@ Buffer. contents b in
420+ Buffer. add_string b " __device__ int printf (const char * format, ... );\n " ;
421+ let declarations_doc = Syntax. print_declarations () in
422+ let params, proc_doc = Syntax. compile_proc ~name idx_params lowered in
423+ let final_doc = PPrint. (declarations_doc ^^ proc_doc) in
424+ PPrint.ToBuffer. pretty 1.0 110 b final_doc; (* Use ToBuffer *)
425+ let ptx = cuda_to_ptx ~name (Buffer. contents b) in
438426 { traced_store; ptx; params; bindings; name }
439427
440- let compile_batch ~names bindings lowereds =
441- let module Syntax = C_syntax. C_syntax (C_syntax_config (struct
428+ let % diagn2_sexp compile_batch ~names bindings lowereds =
429+ let module Syntax = C_syntax. C_syntax (Cuda_syntax_config (struct
442430 let procs = Array. filter_opt lowereds
443431 end )) in
444432 let idx_params = Indexing. bound_symbols bindings in
445433 let b = Buffer. create 4096 in
446- let ppf = Stdlib.Format. formatter_of_buffer b in
447- Syntax. print_declarations ppf;
448- let params_and_names =
434+ let declarations_doc = Syntax. print_declarations () in
435+ let params_and_docs =
449436 Array. map2_exn names lowereds
450437 ~f:
451438 (Option. map2 ~f: (fun name lowered ->
452- (Syntax. compile_proc ~name ppf idx_params lowered, name)))
439+ let params, doc = Syntax. compile_proc ~name idx_params lowered in
440+ ((params, name), doc)))
453441 in
442+ let all_proc_docs = List. filter_map (Array. to_list params_and_docs) ~f: (Option. map ~f: snd) in
443+ let final_doc = PPrint. (declarations_doc ^^ separate hardline all_proc_docs) in
444+ PPrint.ToBuffer. pretty 1.0 110 b final_doc;
445+
454446 let name : string =
455447 String. (
456448 strip ~drop: (equal_char '_' )
457449 @@ common_prefix (Array. to_list names |> List. concat_map ~f: Option. to_list))
458450 in
459- let ptx = cuda_to_ptx ~name @@ Buffer. contents b in
451+ let ptx = cuda_to_ptx ~name ( Buffer. contents b) in
460452 let traced_stores = Array. map lowereds ~f: (Option. map ~f: (fun l -> l.Low_level. traced_store)) in
453+ let params_and_names = Array. map params_and_docs ~f: (Option. map ~f: fst) in
461454 { traced_stores; ptx; params_and_names; bindings }
462455
463456 let get_global_run_id =
0 commit comments