Skip to content

Commit 4acb6df

Browse files
committed
Formatting
1 parent 148d7ef commit 4acb6df

File tree

6 files changed

+165
-92
lines changed

6 files changed

+165
-92
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,10 @@ module C_syntax (B : C_syntax_config) = struct
312312
| For_loop { index = i; from_; to_; body; trace_it = _ } ->
313313
let header =
314314
let idx_type = if Utils.settings.big_models then "uint64_t " else "uint32_t " in
315-
string ("for (" ^ idx_type) ^^ pp_symbol i ^^ string " = " ^^ PPrint.OCaml.int from_ ^^ semi
316-
^^ space ^^ pp_symbol i ^^ string " <= " ^^ PPrint.OCaml.int to_ ^^ semi ^^ space
317-
^^ string "++" ^^ pp_symbol i ^^ string ")"
315+
string ("for (" ^ idx_type)
316+
^^ pp_symbol i ^^ string " = " ^^ PPrint.OCaml.int from_ ^^ semi ^^ space ^^ pp_symbol i
317+
^^ string " <= " ^^ PPrint.OCaml.int to_ ^^ semi ^^ space ^^ string "++" ^^ pp_symbol i
318+
^^ string ")"
318319
in
319320
let body_doc = ref (pp_ll body) in
320321
(if Utils.debug_log_from_routines () then

arrayjit/lib/metal_backend.ml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
433433
let main_kernel_prefix = "kernel"
434434
let buffer_prefix = "device "
435435
let buffer_suffix = fun ~pos -> " [[buffer(" ^ Int.to_string pos ^ ")]]"
436-
let arg_int_prefix = if Utils.settings.big_models then "const uint64_t& " else "const uint32_t& "
436+
437+
let arg_int_prefix =
438+
if Utils.settings.big_models then "const uint64_t& " else "const uint32_t& "
437439

438440
let extra_args =
439441
[

arrayjit/lib/ops.ml

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ type ('ocaml, 'impl) precision =
1717
| Byte : (char, uint8_elt) precision
1818
| Uint16 : (int, uint16_elt) precision
1919
| Int32 : (int32, int32_elt) precision
20-
| Uint32 : (int32, int32_elt) precision (** Using int32_elt representation but treating as unsigned *)
20+
| Uint32 : (int32, int32_elt) precision
21+
(** Using int32_elt representation but treating as unsigned *)
2122
| Int64 : (int64, int64_elt) precision
22-
| Uint64 : (int64, int64_elt) precision (** Using int64_elt representation but treating as unsigned *)
23+
| Uint64 : (int64, int64_elt) precision
24+
(** Using int64_elt representation but treating as unsigned *)
2325
| Uint4x32 : (Stdlib.Complex.t, Bigarray.complex64_elt) precision
2426
(** A 128-bit value that corresponds to e.g. CUDA's uint4 type. Luckily, the OCaml Bigarray
2527
library supports complex64_elt which is a 128-bit value, so we avoid dims conversions. *)
@@ -563,21 +565,32 @@ let binop_c_syntax prec v =
563565
| Mul, _ -> ("(", " *", ")")
564566
| Div, _ -> ("(", " /", ")")
565567
| ToPowOf, Double_prec _ -> ("pow(", ",", ")")
566-
| ToPowOf, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
568+
| ( ToPowOf,
569+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
570+
| Fp8_prec _ ) ) ->
567571
invalid_arg "Ops.binop_c_syntax: ToPowOf not supported for integer precisions"
568572
| ToPowOf, _ -> ("powf(", ",", ")")
569-
| Relu_gate, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("(", " > 0 ?", " : 0)")
573+
| ( Relu_gate,
574+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
575+
| Fp8_prec _ ) ) ->
576+
("(", " > 0 ?", " : 0)")
570577
| Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)")
571-
| Satur01_gate, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
578+
| ( Satur01_gate,
579+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
580+
| Fp8_prec _ ) ) ->
572581
("(abs(", " ) > 0 ? 0 : (", "))")
573582
| Satur01_gate, Single_prec _ ->
574583
(* This disagrees at 0 with the semantics. *)
575584
("(fabsf(floorf(", ")) > 0.0 ? 0.0 : (", "))")
576585
| Satur01_gate, _ -> ("(fabs(floor(", ")) > 0.0 ? 0.0 : (", "))")
577-
| Max, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
586+
| ( Max,
587+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
588+
| Uint64_prec _ | Fp8_prec _ ) ) ->
578589
("fmax(", ",", ")")
579590
| Max, _ -> ("fmaxf(", ",", ")")
580-
| Min, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
591+
| ( Min,
592+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
593+
| Uint64_prec _ | Fp8_prec _ ) ) ->
581594
("fmin(", ",", ")")
582595
| Min, _ -> ("fminf(", ",", ")")
583596
| Mod, _ -> ("(", " %", ")")
@@ -654,43 +667,80 @@ let unop_c_syntax prec op =
654667
let fmax () =
655668
(* See: https://en.cppreference.com/w/c/numeric/math/fmax option (4) *)
656669
match prec with
657-
| Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ -> "fmax"
670+
| Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
671+
| Uint64_prec _ | Fp8_prec _ ->
672+
"fmax"
658673
| _ -> "fmaxf"
659674
in
660675
let fmin () =
661676
match prec with
662-
| Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ -> "fmin"
677+
| Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
678+
| Uint64_prec _ | Fp8_prec _ ->
679+
"fmin"
663680
| _ -> "fminf"
664681
in
665682
match (op, prec) with
666683
| Identity, _ -> ("", "")
667-
| Relu, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("fmax(0, ", ")")
684+
| ( Relu,
685+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
686+
| Fp8_prec _ ) ) ->
687+
("fmax(0, ", ")")
668688
| Relu, _ -> (fmax () ^ "(0.0, ", ")")
669-
| Satur01, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("fmax(0, fmin(1, ", "))")
689+
| ( Satur01,
690+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
691+
| Fp8_prec _ ) ) ->
692+
("fmax(0, fmin(1, ", "))")
670693
| Satur01, _ -> (fmax () ^ "(0.0, " ^ fmin () ^ "(1.0, ", "))")
671-
| Exp, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("exp(", ")")
694+
| ( Exp,
695+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
696+
| Uint64_prec _ | Fp8_prec _ ) ) ->
697+
("exp(", ")")
672698
| Exp, _ -> ("expf(", ")")
673-
| Log, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("log(", ")")
699+
| ( Log,
700+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
701+
| Uint64_prec _ | Fp8_prec _ ) ) ->
702+
("log(", ")")
674703
| Log, _ -> ("logf(", ")")
675-
| Exp2, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("exp2(", ")")
704+
| ( Exp2,
705+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
706+
| Uint64_prec _ | Fp8_prec _ ) ) ->
707+
("exp2(", ")")
676708
| Exp2, _ -> ("exp2f(", ")")
677-
| Log2, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("log2(", ")")
709+
| ( Log2,
710+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
711+
| Uint64_prec _ | Fp8_prec _ ) ) ->
712+
("log2(", ")")
678713
| Log2, _ -> ("log2f(", ")")
679-
| Sin, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("sin(", ")")
714+
| ( Sin,
715+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
716+
| Uint64_prec _ | Fp8_prec _ ) ) ->
717+
("sin(", ")")
680718
| Sin, _ -> ("sinf(", ")")
681-
| Cos, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("cos(", ")")
719+
| ( Cos,
720+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
721+
| Uint64_prec _ | Fp8_prec _ ) ) ->
722+
("cos(", ")")
682723
| Cos, _ -> ("cosf(", ")")
683-
| Sqrt, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) -> ("sqrt(", ")")
724+
| ( Sqrt,
725+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
726+
| Uint64_prec _ | Fp8_prec _ ) ) ->
727+
("sqrt(", ")")
684728
| Sqrt, _ -> ("sqrtf(", ")")
685-
| Recip, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
729+
| ( Recip,
730+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
731+
| Fp8_prec _ ) ) ->
686732
invalid_arg "Ops.unop_c_syntax: Recip not supported for integer precisions"
687733
| Recip, _ -> ("(1.0 / (", "))")
688-
| Recip_sqrt, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
734+
| ( Recip_sqrt,
735+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
736+
| Fp8_prec _ ) ) ->
689737
invalid_arg "Ops.unop_c_syntax: Recip_sqrt not supported for integer precisions"
690738
| Recip_sqrt, Double_prec _ -> ("(1.0 / sqrt(", "))")
691739
| Recip_sqrt, _ -> ("(1.0 / sqrtf(", "))")
692740
| Neg, _ -> ("(-(", "))")
693-
| Tanh_approx, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
741+
| ( Tanh_approx,
742+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
743+
| Fp8_prec _ ) ) ->
694744
invalid_arg "Ops.unop_c_syntax: Tanh_approx not supported for integer precisions"
695745
| Tanh_approx, _ -> ("tanhf(", ")")
696746
| Not, _ -> ("(", " == 0.0 ? 1.0 : 0.0)")
@@ -709,10 +759,14 @@ let ternop_cd_syntax = function Where -> "where" | FMA -> "fma"
709759

710760
let ternop_c_syntax prec op =
711761
match (op, prec) with
712-
| Where, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
762+
| ( Where,
763+
( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
764+
| Fp8_prec _ ) ) ->
713765
("((", ") != 0 ? (", ") : (", "))")
714766
| Where, _ -> ("((", ") != 0.0 ? (", ") : (", "))")
715-
| FMA, (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _) ->
767+
| ( FMA,
768+
( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
769+
| Uint64_prec _ | Fp8_prec _ ) ) ->
716770
("fma(", ",", ",", ")")
717771
| FMA, _ -> ("fmaf(", ",", ",", ")")
718772

@@ -745,16 +799,22 @@ let c_convert_precision ~from ~to_ =
745799
(* Conversions involving BFloat16 and other types *)
746800
| Bfloat16_prec _, Half_prec _ -> ("FLOAT_TO_HALF(bfloat16_to_single(", "))")
747801
| Half_prec _, Bfloat16_prec _ -> ("single_to_bfloat16(HALF_TO_FLOAT(", "))")
748-
| Bfloat16_prec _, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) ->
802+
| ( Bfloat16_prec _,
803+
(Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) )
804+
->
749805
("(" ^ c_typ_of_prec to_ ^ ")bfloat16_to_single(", ")")
750-
| (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _), Bfloat16_prec _ ->
806+
| ( (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _),
807+
Bfloat16_prec _ ) ->
751808
("single_to_bfloat16((float)", ")")
752809
(* Conversions involving FP8 and other types *)
753810
| Fp8_prec _, Half_prec _ -> ("FLOAT_TO_HALF(fp8_to_single(", "))")
754811
| Half_prec _, Fp8_prec _ -> ("single_to_fp8(HALF_TO_FLOAT(", "))")
755-
| Fp8_prec _, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) ->
812+
| ( Fp8_prec _,
813+
(Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) )
814+
->
756815
("(" ^ c_typ_of_prec to_ ^ ")fp8_to_single(", ")")
757-
| (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _), Fp8_prec _ ->
816+
| ( (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _),
817+
Fp8_prec _ ) ->
758818
("single_to_fp8((float)", ")")
759819
(* BFloat16 <-> FP8 conversions *)
760820
| Bfloat16_prec _, Fp8_prec _ -> ("single_to_fp8(bfloat16_to_single(", "))")
@@ -764,9 +824,12 @@ let c_convert_precision ~from ~to_ =
764824
| Single_prec _, Half_prec _ -> ("FLOAT_TO_HALF(", ")")
765825
| Half_prec _, Double_prec _ -> ("(double)HALF_TO_FLOAT(", ")")
766826
| Double_prec _, Half_prec _ -> ("FLOAT_TO_HALF((float)", ")")
767-
| Half_prec _, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) ->
827+
| ( Half_prec _,
828+
(Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) )
829+
->
768830
("(" ^ c_typ_of_prec to_ ^ ")HALF_TO_FLOAT(", ")")
769-
| (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _), Half_prec _ ->
831+
| ( (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _),
832+
Half_prec _ ) ->
770833
("FLOAT_TO_HALF((float)", ")")
771834
(* Uint4x32 conversions - special handling *)
772835
| Uint4x32_prec _, _ -> ("uint4x32_to_" ^ prec_string to_ ^ "(", ")")

lib/ppx_op.ml

Lines changed: 62 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ open Ppx_arrayjit.Ppx_helper
44
open Ppx_shared
55

66
let make_p ~opt_label ~loc ?value ?values ?param_init ~extra_args name =
7-
let more_label = match opt_label with
7+
let more_label =
8+
match opt_label with
89
| Some (_label_name, label_pat) -> [%expr Some [%e pat2expr label_pat]]
9-
| None -> [%expr None] in
10+
| None -> [%expr None]
11+
in
1012
let value = match value with Some c -> [%expr Some [%e c]] | None -> [%expr None] in
1113
let values = match values with Some c -> [%expr Some [%e c]] | None -> [%expr None] in
1214
let param_init =
@@ -264,7 +266,7 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
264266
(* Check if there's a unit parameter or a labeled parameter with label "label" *)
265267
let rec find_unit_pos idx = function
266268
| [] -> None
267-
| { pparam_desc = Pparam_val (Nolabel, _, pat); _ } :: _
269+
| { pparam_desc = Pparam_val (Nolabel, _, pat); _ } :: _
268270
when match pat.ppat_desc with
269271
| Ppat_construct ({ txt = Lident "()"; _ }, None) -> true
270272
| _ -> false ->
@@ -280,9 +282,11 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
280282
| Some unit_idx ->
281283
(* Split args at unit parameter *)
282284
let before_unit, unit_and_after = List.split_n args unit_idx in
283-
let unit_param, after_unit = match unit_and_after with
285+
let unit_param, after_unit =
286+
match unit_and_after with
284287
| unit :: rest -> (unit, rest)
285-
| [] -> failwith "Internal error: unit_and_after should not be empty" in
288+
| [] -> failwith "Internal error: unit_and_after should not be empty"
289+
in
286290
let opt_label = find_label_param before_unit in
287291
let vbs, inner_body =
288292
translate ~num_configs ~is_toplevel:false ~opt_label ?label
@@ -293,54 +297,61 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
293297
let new_body = inner_body in
294298
( no_vbs,
295299
if List.is_empty before_unit then
296-
{ expr with pexp_desc = Pexp_function ([unit_param], constr, Pfunction_body new_body) }
300+
{
301+
expr with
302+
pexp_desc = Pexp_function ([ unit_param ], constr, Pfunction_body new_body);
303+
}
297304
else
298-
{ expr with pexp_desc = Pexp_function (before_unit @ [unit_param], constr, Pfunction_body new_body) } )
305+
{
306+
expr with
307+
pexp_desc =
308+
Pexp_function (before_unit @ [ unit_param ], constr, Pfunction_body new_body);
309+
} )
299310
| None ->
300-
(* No unit parameter, normal processing *)
301-
let labels =
302-
Option.to_list label
303-
@ List.filter_map args ~f:(function
304-
| { pparam_desc = Pparam_val (_, _, pat); _ } ->
305-
let loc = pat.ppat_loc in
306-
Some [%expr [%e pat2expr pat].Tensor.value.Ir.Tnode.label]
307-
| _ -> None)
308-
in
309-
let label_locs = List.map labels ~f:(fun label -> label.pexp_loc) in
310-
let label_starts = List.map label_locs ~f:(fun l -> l.loc_start) in
311-
let label_ends = List.map label_locs ~f:(fun l -> l.loc_end) in
312-
let label_loc =
313-
if List.is_empty labels then loc
314-
else
315-
Location.
316-
{
317-
loc_start = List.reduce_exn label_starts ~f:min_pos;
318-
loc_end = List.reduce_exn label_ends ~f:max_pos;
319-
loc_ghost = false;
320-
}
321-
in
322-
let label =
323-
let loc = label_loc in
324-
[%expr List.concat [%e Ast_builder.Default.elist ~loc labels]]
325-
in
326-
let vbs, body =
327-
match body with
328-
| Pfunction_body body ->
329-
let vbs, body = loop ~label body in
330-
(vbs, Pfunction_body body)
331-
| Pfunction_cases (cases, loc, attrs) ->
332-
let vbs, cases =
333-
List.unzip
334-
@@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
335-
let vbs, pc_rhs = loop ~label pc_rhs in
336-
(vbs, { c with pc_rhs }))
337-
in
338-
( List.fold vbs
339-
~init:(Map.empty (module String))
340-
~f:(fun acc vbs -> Map.merge_disjoint_exn acc vbs),
341-
Pfunction_cases (cases, loc, attrs) )
342-
in
343-
(vbs, { expr with pexp_desc = Pexp_function (args, constr, body) }) )
311+
(* No unit parameter, normal processing *)
312+
let labels =
313+
Option.to_list label
314+
@ List.filter_map args ~f:(function
315+
| { pparam_desc = Pparam_val (_, _, pat); _ } ->
316+
let loc = pat.ppat_loc in
317+
Some [%expr [%e pat2expr pat].Tensor.value.Ir.Tnode.label]
318+
| _ -> None)
319+
in
320+
let label_locs = List.map labels ~f:(fun label -> label.pexp_loc) in
321+
let label_starts = List.map label_locs ~f:(fun l -> l.loc_start) in
322+
let label_ends = List.map label_locs ~f:(fun l -> l.loc_end) in
323+
let label_loc =
324+
if List.is_empty labels then loc
325+
else
326+
Location.
327+
{
328+
loc_start = List.reduce_exn label_starts ~f:min_pos;
329+
loc_end = List.reduce_exn label_ends ~f:max_pos;
330+
loc_ghost = false;
331+
}
332+
in
333+
let label =
334+
let loc = label_loc in
335+
[%expr List.concat [%e Ast_builder.Default.elist ~loc labels]]
336+
in
337+
let vbs, body =
338+
match body with
339+
| Pfunction_body body ->
340+
let vbs, body = loop ~label body in
341+
(vbs, Pfunction_body body)
342+
| Pfunction_cases (cases, loc, attrs) ->
343+
let vbs, cases =
344+
List.unzip
345+
@@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
346+
let vbs, pc_rhs = loop ~label pc_rhs in
347+
(vbs, { c with pc_rhs }))
348+
in
349+
( List.fold vbs
350+
~init:(Map.empty (module String))
351+
~f:(fun acc vbs -> Map.merge_disjoint_exn acc vbs),
352+
Pfunction_cases (cases, loc, attrs) )
353+
in
354+
(vbs, { expr with pexp_desc = Pexp_function (args, constr, body) }))
344355
| { pexp_desc = Pexp_function (args, constr, body); _ } ->
345356
let vbs, body =
346357
match body with

0 commit comments

Comments
 (0)