@@ -164,40 +164,48 @@ struct
164164 let ternop_syntax prec op v1 v2 v3 =
165165 let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax prec op in
166166 let open PPrint in
167- group
168- (string op_prefix ^^ v1 ^^ string op_infix1 ^^ space ^^ v2 ^^ string op_infix2 ^^ space ^^ v3
169- ^^ string op_suffix)
167+ group (string op_prefix ^^ v1 ^^ string op_infix1
168+ ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
169+ ^^ string op_infix2
170+ ^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
171+ ^^ string op_suffix)
170172
171173 let binop_syntax prec op v1 v2 =
172174 match op with
173175 | Ops. Satur01_gate -> (
174176 match prec with
175177 | Ops. Byte_prec _ ->
176178 let open PPrint in
177- parens
178- (parens
179- (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f" )
180- ^^ string " ? " ^^ v2 ^^ string " : (unsigned char)0" )
179+ group (parens
180+ (group (parens
181+ (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f" ))
182+ ^^ ifflat (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space ^^ string " (unsigned char)0" )
183+ (nest 2 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space ^^ string " (unsigned char)0" ))))
181184 | Ops. Half_prec _ ->
182185 let open PPrint in
183- parens
184- (parens (v1 ^^ string " > 0.0f16 && " ^^ v1 ^^ string " < 1.0f16" )
185- ^^ string " ? " ^^ v2 ^^ string " : 0.0f16" )
186+ group (parens
187+ (group (parens (v1 ^^ string " > 0.0f16 && " ^^ v1 ^^ string " < 1.0f16" ))
188+ ^^ ifflat (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space ^^ string " 0.0f16" )
189+ (nest 2 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space ^^ string " 0.0f16" ))))
186190 | Ops. Single_prec _ ->
187191 let open PPrint in
188- parens
189- (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f" )
190- ^^ string " ? " ^^ v2 ^^ string " : 0.0f" )
192+ group (parens
193+ (group (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f" ))
194+ ^^ ifflat (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space ^^ string " 0.0f" )
195+ (nest 2 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space ^^ string " 0.0f" ))))
191196 | Ops. Double_prec _ ->
192197 let open PPrint in
193- parens
194- (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0" )
195- ^^ string " ? " ^^ v2 ^^ string " : 0.0" )
198+ group (parens
199+ (group (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0" ))
200+ ^^ ifflat (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space ^^ string " 0.0" )
201+ (nest 2 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space ^^ string " 0.0" ))))
196202 | Ops. Void_prec -> invalid_arg " Pure_C_config.binop_syntax: Satur01_gate on Void_prec" )
197203 | _ ->
198204 let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax prec op in
199205 let open PPrint in
200- group (string op_prefix ^^ v1 ^^ string op_infix ^^ space ^^ v2 ^^ string op_suffix)
206+ group (string op_prefix ^^ v1 ^^ string op_infix
207+ ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
208+ ^^ string op_suffix)
201209
202210 let unop_syntax prec op v =
203211 let op_prefix, op_suffix = Ops. unop_c_syntax prec op in
@@ -227,12 +235,11 @@ struct
227235 else res
228236 in
229237 log_file_check
230- ^^ group
231- (string " fprintf(log_file, "
232- ^^ dquotes (string base_message_literal)
233- ^^ (if List. is_empty args_docs then empty else comma ^^ space)
234- ^^ separate (PPrint. comma ^^ PPrint. space) args_docs
235- ^^ rparen ^^ semi)
238+ ^^ string " fprintf(log_file, "
239+ ^^ dquotes (string base_message_literal)
240+ ^^ (if List. is_empty args_docs then empty else comma ^^ space)
241+ ^^ separate (comma ^^ space) args_docs
242+ ^^ rparen ^^ semi
236243end
237244
238245module C_syntax (B : C_syntax_config ) = struct
@@ -267,7 +274,7 @@ module C_syntax (B : C_syntax_config) = struct
267274 let open PPrint in
268275 let includes = separate hardline (List. map B. includes ~f: pp_include) in
269276 let extras = separate hardline (List. map B. extra_declarations ~f: string ) in
270- group ( includes ^^ hardline ^^ hardline ^^ extras ^^ hardline ^^ hardline)
277+ includes ^^ hardline ^^ extras ^^ hardline
271278
272279 let rec pp_ll (c : Low_level.t ) : PPrint.document =
273280 let open PPrint in
@@ -279,12 +286,12 @@ module C_syntax (B : C_syntax_config) = struct
279286 (* Avoid extra hardlines if one side is empty *)
280287 if PPrint. is_empty d1 then d2
281288 else if PPrint. is_empty d2 then d1
282- else group ( d1 ^^ hardline ^^ d2)
289+ else d1 ^^ hardline ^^ d2
283290 | For_loop { index = i ; from_; to_; body; trace_it = _ } ->
284291 let header =
285292 string " for (int " ^^ pp_symbol i ^^ string " = " ^^ PPrint.OCaml. int from_ ^^ semi
286- ^^ space ^^ pp_symbol i ^^ string " <= " ^^ PPrint.OCaml. int to_ ^^ semi ^^ string " ++ "
287- ^^ pp_symbol i ^^ string " )"
293+ ^^ space ^^ pp_symbol i ^^ string " <= " ^^ PPrint.OCaml. int to_ ^^ semi
294+ ^^ space ^^ string " ++ " ^^ pp_symbol i ^^ string " )"
288295 in
289296 let body_doc = ref (pp_ll body) in
290297 (if Utils. debug_log_from_routines () then
@@ -296,7 +303,7 @@ module C_syntax (B : C_syntax_config) = struct
296303 ~args_docs: [ pp_symbol i ]
297304 in
298305 body_doc := log_doc ^^ hardline ^^ ! body_doc);
299- surround 2 1 (header ^^ space ^^ lbrace) ! body_doc rbrace
306+ group (header ^^ space ^^ lbrace ^^ nest 2 (hardline ^^ ! body_doc) ^^ hardline ^^ rbrace)
300307 | Zero_out tn ->
301308 pp_ll
302309 (Low_level. loop_over_dims (Lazy. force tn.dims) ~body: (fun idcs ->
@@ -307,11 +314,14 @@ module C_syntax (B : C_syntax_config) = struct
307314 let prec = Lazy. force tn.prec in
308315 let local_defs, val_doc = pp_float prec llv in
309316 let offset_doc = pp_array_offset (idcs, dims) in
310- let assignment = ident_doc ^^ brackets offset_doc ^^ string " = " ^^ val_doc ^^ semi in
317+ let assignment =
318+ group (ident_doc ^^ brackets offset_doc ^^ string " ="
319+ ^^ ifflat (space ^^ val_doc) (nest 4 (hardline ^^ val_doc)) ^^ semi)
320+ in
311321 if Utils. debug_log_from_routines () then
312322 let num_typ = string (B. typ_of_prec prec) in
313323 let new_var = string " new_set_v" in
314- let decl = group ( num_typ ^^ space ^^ new_var ^^ string " = " ^^ val_doc ^^ semi) in
324+ let decl = num_typ ^^ space ^^ new_var ^^ string " = " ^^ val_doc ^^ semi in
315325 let debug_val_doc, debug_args_docs = debug_float prec llv in
316326 let debug_val_str = doc_to_string debug_val_doc in
317327 let pp_args_docs =
@@ -340,14 +350,19 @@ module C_syntax (B : C_syntax_config) = struct
340350 ~base_message_literal: value_base_msg ~args_docs: log_args_for_printf
341351 in
342352 let flush_log =
343- if B. log_involves_file_management then string " fflush(log_file);" ^^ semi else empty
353+ if B. log_involves_file_management then string " fflush(log_file);" else empty
344354 in
345355 comment_log ^^ hardline ^^ value_log ^^ hardline ^^ flush_log
346356 in
347357 let assignment' = ident_doc ^^ brackets offset_doc ^^ string " = " ^^ new_var ^^ semi in
348- let block_content = decl ^^ hardline ^^ log_doc ^^ hardline ^^ assignment' in
349- surround 2 1 lbrace (local_defs ^^ hardline ^^ block_content) rbrace
350- else local_defs ^^ (if PPrint. is_empty local_defs then empty else hardline) ^^ assignment
358+ let block_content =
359+ if PPrint. is_empty local_defs then decl ^^ hardline ^^ log_doc ^^ hardline ^^ assignment'
360+ else local_defs ^^ hardline ^^ decl ^^ hardline ^^ log_doc ^^ hardline ^^ assignment'
361+ in
362+ lbrace ^^ nest 2 (hardline ^^ block_content) ^^ hardline ^^ rbrace
363+ else
364+ if PPrint. is_empty local_defs then assignment
365+ else local_defs ^^ hardline ^^ assignment
351366 | Comment message ->
352367 if Utils. debug_log_from_routines () then
353368 let base_message = " COMMENT: " ^ message ^ " \n " in
@@ -359,22 +374,19 @@ module C_syntax (B : C_syntax_config) = struct
359374 callback ()
360375 | Set_local ({ scope_id; tn = { prec; _ } } , value ) ->
361376 let local_defs, value_doc = pp_float (Lazy. force prec) value in
362- let assignment =
363- group (string (" v" ^ Int. to_string scope_id) ^^ string " = " ^^ value_doc ^^ semi)
364- in
365- local_defs ^^ (if PPrint. is_empty local_defs then empty else hardline) ^^ assignment
377+ let assignment = string (" v" ^ Int. to_string scope_id) ^^ string " = " ^^ value_doc ^^ semi in
378+ if PPrint. is_empty local_defs then assignment
379+ else local_defs ^^ hardline ^^ assignment
366380
367381 and pp_float (prec : Ops.prec ) (vcomp : Low_level.float_t ) : PPrint.document * PPrint.document =
368382 (* Returns (local definitions, value expression) *)
369383 let open PPrint in
370384 match vcomp with
371385 | Local_scope { id = { scope_id; tn = { prec = scope_prec ; _ } } ; body; orig_indices = _ } ->
372386 let num_typ = string (B. typ_of_prec @@ Lazy. force scope_prec) in
373- let decl =
374- group (num_typ ^^ space ^^ string (" v" ^ Int. to_string scope_id) ^^ string " = 0" ^^ semi)
375- in
387+ let decl = num_typ ^^ space ^^ string (" v" ^ Int. to_string scope_id) ^^ string " = 0" ^^ semi in
376388 let body_doc = pp_ll body in
377- let defs = group ( decl ^^ hardline ^^ body_doc) in
389+ let defs = decl ^^ hardline ^^ body_doc in
378390 let prefix, postfix = B. convert_precision ~from: (Lazy. force scope_prec) ~to_: prec in
379391 let expr = string prefix ^^ string (" v" ^ Int. to_string scope_id) ^^ string postfix in
380392 (defs, expr)
@@ -388,23 +400,20 @@ module C_syntax (B : C_syntax_config) = struct
388400 let from_prec = Lazy. force tn.prec in
389401 let prefix, postfix = B. convert_precision ~from: from_prec ~to_: prec in
390402 let offset_doc = pp_array_offset (idcs, Lazy. force tn.dims) in
391- let expr =
392- group (string prefix ^^ string " merge_buffer" ^^ brackets offset_doc ^^ string postfix)
393- in
403+ let expr = string prefix ^^ string " merge_buffer" ^^ brackets offset_doc ^^ string postfix in
394404 (empty, expr)
395405 | Get_global _ -> failwith " C_syntax: Get_global / FFI NOT IMPLEMENTED YET"
396406 | Get (tn , idcs ) ->
397407 let ident_doc = string (get_ident tn) in
398408 let from_prec = Lazy. force tn.prec in
399409 let prefix, postfix = B. convert_precision ~from: from_prec ~to_: prec in
400410 let offset_doc = pp_array_offset (idcs, Lazy. force tn.dims) in
401- let expr = group ( string prefix ^^ ident_doc ^^ brackets offset_doc ^^ string postfix) in
411+ let expr = string prefix ^^ ident_doc ^^ brackets offset_doc ^^ string postfix in
402412 (empty, expr)
403413 | Constant c ->
404414 let from_prec = Ops. double in
405415 let prefix, postfix = B. convert_precision ~from: from_prec ~to_: prec in
406416 let c_str = Printf. sprintf " %.16g" c in
407- (* Use Printf for float formatting *)
408417 let expr =
409418 if String. is_empty prefix && Float. (c < 0.0 ) then
410419 string " (" ^^ string c_str ^^ string " )" ^^ string postfix
@@ -422,24 +431,24 @@ module C_syntax (B : C_syntax_config) = struct
422431 let d1, e1 = pp_float prec v1 in
423432 let d2, e2 = pp_float prec v2 in
424433 let d3, e3 = pp_float prec v3 in
425- let defs =
426- d1
427- ^^ (if PPrint. is_empty d1 then empty else hardline)
428- ^^ d2
429- ^^ (if PPrint. is_empty d2 then empty else hardline)
430- ^^ d3
434+ let defs =
435+ List. filter_map [ d1; d2; d3 ] ~f: (fun d -> if PPrint. is_empty d then None else Some d)
436+ |> separate hardline
431437 in
432- let expr = B. ternop_syntax prec op e1 e2 e3 in
438+ let expr = group ( B. ternop_syntax prec op e1 e2 e3) in
433439 (defs, expr)
434440 | Binop (op , v1 , v2 ) ->
435441 let d1, e1 = pp_float prec v1 in
436442 let d2, e2 = pp_float prec v2 in
437- let defs = d1 ^^ (if PPrint. is_empty d1 then empty else hardline) ^^ d2 in
438- let expr = B. binop_syntax prec op e1 e2 in
443+ let defs =
444+ List. filter_map [ d1; d2 ] ~f: (fun d -> if PPrint. is_empty d then None else Some d)
445+ |> separate hardline
446+ in
447+ let expr = group (B. binop_syntax prec op e1 e2) in
439448 (defs, expr)
440449 | Unop (op , v ) ->
441450 let defs, expr_v = pp_float prec v in
442- let expr = B. unop_syntax prec op expr_v in
451+ let expr = group ( B. unop_syntax prec op expr_v) in
443452 (defs, expr)
444453
445454 and debug_float (prec : Ops.prec ) (value : Low_level.float_t ) :
@@ -463,33 +472,27 @@ module C_syntax (B : C_syntax_config) = struct
463472 let dims = Lazy. force tn.dims in
464473 let prefix, postfix = B. convert_precision ~from: from_prec ~to_: prec in
465474 let offset_doc = pp_array_offset (idcs, dims) in
466- let access_doc =
467- group (string prefix ^^ string " merge_buffer" ^^ brackets offset_doc ^^ string postfix)
468- in
475+ let access_doc = string prefix ^^ string " merge_buffer" ^^ brackets offset_doc ^^ string postfix in
469476 let expr_doc =
470- group
471- (string prefix ^^ string " merge_buffer"
472- ^^ brackets (string " %u" )
473- ^^ string postfix
474- ^^ braces (string (" =" ^ B. float_log_style)))
477+ string prefix ^^ string " merge_buffer"
478+ ^^ brackets (string " %u" )
479+ ^^ string postfix
480+ ^^ braces (string (" =" ^ B. float_log_style))
475481 in
476482 (expr_doc, [ `Accessor (idcs, dims); `Value access_doc ])
477- | Get_global _ -> failwith " Exec_as_cuda : Get_global / FFI NOT IMPLEMENTED YET"
483+ | Get_global _ -> failwith " C_syntax : Get_global / FFI NOT IMPLEMENTED YET"
478484 | Get (tn , idcs ) ->
479485 let ident_doc = string (get_ident tn) in
480486 let from_prec = Lazy. force tn.prec in
481487 let dims = Lazy. force tn.dims in
482488 let prefix, postfix = B. convert_precision ~from: from_prec ~to_: prec in
483489 let offset_doc = pp_array_offset (idcs, dims) in
484- let access_doc =
485- group (string prefix ^^ ident_doc ^^ brackets offset_doc ^^ string postfix)
486- in
490+ let access_doc = string prefix ^^ ident_doc ^^ brackets offset_doc ^^ string postfix in
487491 let expr_doc =
488- group
489- (string prefix ^^ ident_doc
490- ^^ brackets (string " %u" )
491- ^^ string postfix
492- ^^ braces (string (" =" ^ B. float_log_style)))
492+ string prefix ^^ ident_doc
493+ ^^ brackets (string " %u" )
494+ ^^ string postfix
495+ ^^ braces (string (" =" ^ B. float_log_style))
493496 in
494497 (expr_doc, [ `Accessor (idcs, dims); `Value access_doc ])
495498 | Constant c ->
@@ -563,9 +566,8 @@ module C_syntax (B : C_syntax_config) = struct
563566 @ List. map B. extra_args ~f: string
564567 in
565568 let func_header =
566- group
567- (string B. main_kernel_prefix ^^ space ^^ string " void" ^^ space ^^ string name
568- ^^ parens (separate comma_sep args_docs))
569+ string B. main_kernel_prefix ^^ space ^^ string " void" ^^ space ^^ string name
570+ ^^ nest 4 (lparen ^^ hardline ^^ separate (comma ^^ hardline) args_docs ^^ rparen)
569571 in
570572 let body = ref empty in
571573 if not (String. is_empty B. kernel_prep_line) then
@@ -579,12 +581,12 @@ module C_syntax (B : C_syntax_config) = struct
579581 in
580582 body :=
581583 ! body ^^ string " FILE* log_file = NULL;" ^^ hardline
582- ^^ group
583- ( string ( " if ( " ^ log_file_var_name ^ " ) " )
584- ^^ space
585- ^^ braces (nest 2 ( string ( " log_file = fopen( " ^ log_file_var_name ^ " , \" w \" ); " ))))
586- ^^ hardline
587- else body := ! body ^^ hardline;
584+ ^^ string ( " if ( " ^ log_file_var_name ^ " ) " )
585+ ^^ lbrace ^^ nest 2 (hardline
586+ ^^ string ( " log_file = fopen( " ^ log_file_var_name ^ " , \" w \" ); " ))
587+ ^^ hardline ^^ rbrace ^^ hardline
588+ else
589+ body := ! body ^^ hardline;
588590
589591 (if Utils. debug_log_from_routines () then
590592 let debug_init_doc =
@@ -625,14 +627,14 @@ module C_syntax (B : C_syntax_config) = struct
625627 let local_decls =
626628 string " /* Local declarations and initialization. */"
627629 ^^ hardline
628- ^^ separate_map hardline
630+ ^^ separate_map empty
629631 (fun (tn , node ) ->
630632 if not (Tn. is_virtual_force tn 333 || Tn. is_materialized_force tn 336 ) then
631633 let typ_doc = string (B. typ_of_prec @@ Lazy. force tn.prec) in
632634 let ident_doc = string (get_ident tn) in
633635 let size_doc = OCaml. int (Tn. num_elems tn) in
634636 let init_doc = if node.Low_level. zero_initialized then string " = {0}" else empty in
635- group ( typ_doc ^^ space ^^ ident_doc ^^ brackets size_doc ^^ init_doc ^^ semi)
637+ typ_doc ^^ space ^^ ident_doc ^^ brackets size_doc ^^ init_doc ^^ semi ^^ hardline
636638 else empty)
637639 (Hashtbl. to_alist traced_store)
638640 in
@@ -647,6 +649,6 @@ module C_syntax (B : C_syntax_config) = struct
647649 ^^ string " if (log_file) { fclose(log_file); log_file = NULL; }"
648650 ^^ hardline;
649651
650- let func_doc = surround 2 1 ( func_header ^^ space ^^ lbrace) ! body rbrace in
652+ let func_doc = func_header ^^ space ^^ lbrace ^^ nest 2 (hardline ^^ ! body) ^^ hardline ^^ rbrace in
651653 (sorted_params, func_doc)
652654end
0 commit comments