Skip to content

Commit 16b612b

Browse files
authored
Merge pull request #375 from ahrefs/feature/heterogeneous-precision
Support heterogeneous precision for primitive operations
2 parents 03c80e7 + 181e6be commit 16b612b

14 files changed

+594
-183
lines changed

arrayjit/lib/assignments.ml

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -300,7 +300,14 @@ let%track4_sexp to_low_level code =
300300
| Ops.Void_prec -> failwith "Cannot use vector operation with void precision")
301301
in
302302
Set_from_vec
303-
{ tn = lhs; idcs = lhs_idcs; length; vec_unop = op; arg = rhs_ll; debug = "" }
303+
{
304+
tn = lhs;
305+
idcs = lhs_idcs;
306+
length;
307+
vec_unop = op;
308+
arg = (rhs_ll, Low_level.scalar_precision rhs_ll);
309+
debug = "";
310+
}
304311
in
305312
let rec for_loop rev_iters = function
306313
| [] -> basecase rev_iters

arrayjit/lib/c_syntax.ml

Lines changed: 103 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -398,12 +398,16 @@ module C_syntax (B : C_syntax_config) = struct
398398
~args_docs:[]
399399
else string "/* " ^^ string message ^^ string " */"
400400
| Staged_compilation callback -> callback ()
401-
| Set_from_vec { tn; idcs; length; vec_unop; arg; debug } ->
401+
| Set_from_vec { tn; idcs; length; vec_unop; arg = arg, arg_prec; debug } ->
402402
let ident_doc = string (get_ident tn) in
403403
let dims = Lazy.force tn.dims in
404404
let prec = Lazy.force tn.prec in
405-
(* FIXME: this precision is hardcoded, bad, bad practice. *)
406-
let arg_prec = Ops.uint4x32 in
405+
(* Determine argument precision based on operation homogeneity *)
406+
let arg_prec =
407+
if Ops.is_homogeneous_prec_vec_unop vec_unop then prec
408+
(* Homogeneous: argument uses result precision *)
409+
else arg_prec
410+
in
407411
let local_defs, arg_doc = pp_scalar arg_prec arg in
408412
let local_defs = pp_local_defs local_defs in
409413
(* Generate the function call *)
@@ -564,30 +568,62 @@ module C_syntax (B : C_syntax_config) = struct
564568
let idx_doc = if PPrint.is_empty idx_doc then string "0" else idx_doc in
565569
let expr = string prefix ^^ idx_doc ^^ string postfix in
566570
([], expr)
567-
| Binop (Arg1, v1, _v2) -> pp_scalar prec v1
568-
| Binop (Arg2, _v1, v2) -> pp_scalar prec v2
569-
| Ternop (op, v1, v2, v3) ->
570-
let d1, e1 = pp_scalar prec v1 in
571-
let d2, e2 = pp_scalar prec v2 in
572-
let d3, e3 = pp_scalar prec v3 in
571+
| Binop (Arg1, (v1, _), _v2) -> pp_scalar prec v1
572+
| Binop (Arg2, _v1, (v2, _)) -> pp_scalar prec v2
573+
| Ternop (op, (v1, v1_prec), (v2, v2_prec), (v3, v3_prec)) ->
574+
let d1, e1, d2, e2, d3, e3 =
575+
if Ops.is_homogeneous_prec_ternop op then
576+
(* Homogeneous: all arguments use result precision *)
577+
let d1, e1 = pp_scalar prec v1 in
578+
let d2, e2 = pp_scalar prec v2 in
579+
let d3, e3 = pp_scalar prec v3 in
580+
(d1, e1, d2, e2, d3, e3)
581+
else
582+
(* Heterogeneous: arguments keep their natural precision *)
583+
match op with
584+
| Ops.Where ->
585+
(* For Where: condition keeps its precision, then/else use result precision *)
586+
(* Note: we evaluate condition without precision conversion, but then/else
587+
need to match the result precision for the final assignment *)
588+
let d1, e1 = pp_scalar v1_prec v1 in
589+
(* condition: no conversion *)
590+
let d2, e2 = pp_scalar prec v2 in
591+
(* then: result precision *)
592+
let d3, e3 = pp_scalar prec v3 in
593+
(* else: result precision *)
594+
(d1, e1, d2, e2, d3, e3)
595+
| _ ->
596+
(* Other heterogeneous ternary ops would go here *)
597+
let d1, e1 = pp_scalar v1_prec v1 in
598+
let d2, e2 = pp_scalar v2_prec v2 in
599+
let d3, e3 = pp_scalar v3_prec v3 in
600+
(d1, e1, d2, e2, d3, e3)
601+
in
573602
let defs = List.concat [ d1; d2; d3 ] in
574603
let expr = group (B.ternop_syntax prec op e1 e2 e3) in
575604
(defs, expr)
576-
| Binop (op, v1, v2) ->
577-
let d1, e1 = pp_scalar prec v1 in
578-
let d2, e2 = pp_scalar prec v2 in
605+
| Binop (op, (v1, v1_prec), (v2, v2_prec)) ->
606+
let d1, e1, d2, e2 =
607+
if Ops.is_homogeneous_prec_binop op then
608+
(* Homogeneous: both arguments use result precision *)
609+
let d1, e1 = pp_scalar prec v1 in
610+
let d2, e2 = pp_scalar prec v2 in
611+
(d1, e1, d2, e2)
612+
else
613+
(* Heterogeneous: arguments keep their natural precision *)
614+
(* Currently all binops are homogeneous, but this is here for future extension *)
615+
let d1, e1 = pp_scalar v1_prec v1 in
616+
let d2, e2 = pp_scalar v2_prec v2 in
617+
(d1, e1, d2, e2)
618+
in
579619
let defs = List.concat [ d1; d2 ] in
580620
let expr = group (B.binop_syntax prec op e1 e2) in
581621
(defs, expr)
582-
| Unop (op, v) ->
622+
| Unop (op, (v, v_prec)) ->
583623
let arg_prec =
584-
match op with
585-
| Ops.Uint4x32_to_prec_uniform1 ->
586-
(* The argument to Uint4x32_to_prec_uniform1 must be evaluated with uint4x32
587-
precision, regardless of the target precision. This handles the case where the
588-
operation is inlined as part of a scalar expression. *)
589-
Ops.uint4x32
590-
| _ -> prec
624+
if Ops.is_homogeneous_prec_unop op then prec
625+
(* Homogeneous: argument uses result precision *)
626+
else v_prec
591627
in
592628
let defs, expr_v = pp_scalar arg_prec v in
593629
let expr = group (B.unop_syntax prec op expr_v) in
@@ -651,19 +687,55 @@ module C_syntax (B : C_syntax_config) = struct
651687
| Embed_index idx ->
652688
let idx_doc = pp_axis_index idx in
653689
((if PPrint.is_empty idx_doc then string "0" else idx_doc), [])
654-
| Binop (Arg1, v1, _v2) -> debug_float prec v1
655-
| Binop (Arg2, _v1, v2) -> debug_float prec v2
656-
| Ternop (op, v1, v2, v3) ->
657-
let v1_doc, idcs1 = debug_float prec v1 in
658-
let v2_doc, idcs2 = debug_float prec v2 in
659-
let v3_doc, idcs3 = debug_float prec v3 in
690+
| Binop (Arg1, (v1, _), _v2) -> debug_float prec v1
691+
| Binop (Arg2, _v1, (v2, _)) -> debug_float prec v2
692+
| Ternop (op, (v1, v1_prec), (v2, v2_prec), (v3, v3_prec)) ->
693+
let v1_doc, idcs1, v2_doc, idcs2, v3_doc, idcs3 =
694+
if Ops.is_homogeneous_prec_ternop op then
695+
(* Homogeneous: all arguments use result precision *)
696+
let v1_doc, idcs1 = debug_float prec v1 in
697+
let v2_doc, idcs2 = debug_float prec v2 in
698+
let v3_doc, idcs3 = debug_float prec v3 in
699+
(v1_doc, idcs1, v2_doc, idcs2, v3_doc, idcs3)
700+
else
701+
(* Heterogeneous: handle based on operation *)
702+
match op with
703+
| Ops.Where ->
704+
let v1_doc, idcs1 = debug_float v1_prec v1 in
705+
(* condition: no conversion *)
706+
let v2_doc, idcs2 = debug_float prec v2 in
707+
(* then: result precision *)
708+
let v3_doc, idcs3 = debug_float prec v3 in
709+
(* else: result precision *)
710+
(v1_doc, idcs1, v2_doc, idcs2, v3_doc, idcs3)
711+
| _ ->
712+
let v1_doc, idcs1 = debug_float v1_prec v1 in
713+
let v2_doc, idcs2 = debug_float v2_prec v2 in
714+
let v3_doc, idcs3 = debug_float v3_prec v3 in
715+
(v1_doc, idcs1, v2_doc, idcs2, v3_doc, idcs3)
716+
in
660717
(B.ternop_syntax prec op v1_doc v2_doc v3_doc, idcs1 @ idcs2 @ idcs3)
661-
| Binop (op, v1, v2) ->
662-
let v1_doc, idcs1 = debug_float prec v1 in
663-
let v2_doc, idcs2 = debug_float prec v2 in
718+
| Binop (op, (v1, v1_prec), (v2, v2_prec)) ->
719+
let v1_doc, idcs1, v2_doc, idcs2 =
720+
if Ops.is_homogeneous_prec_binop op then
721+
(* Homogeneous: both arguments use result precision *)
722+
let v1_doc, idcs1 = debug_float prec v1 in
723+
let v2_doc, idcs2 = debug_float prec v2 in
724+
(v1_doc, idcs1, v2_doc, idcs2)
725+
else
726+
(* Heterogeneous: arguments keep their natural precision *)
727+
let v1_doc, idcs1 = debug_float v1_prec v1 in
728+
let v2_doc, idcs2 = debug_float v2_prec v2 in
729+
(v1_doc, idcs1, v2_doc, idcs2)
730+
in
664731
(B.binop_syntax prec op v1_doc v2_doc, idcs1 @ idcs2)
665-
| Unop (op, v) ->
666-
let v_doc, idcs = debug_float prec v in
732+
| Unop (op, (v, v_prec)) ->
733+
let arg_prec =
734+
if Ops.is_homogeneous_prec_unop op then prec
735+
(* Homogeneous: argument uses result precision *)
736+
else v_prec
737+
in
738+
let v_doc, idcs = debug_float arg_prec v in
667739
(B.unop_syntax prec op v_doc, idcs)
668740

669741
(** [compile_main llc] renders the low-level code [llc] as a [PPrint.document] by delegating to
    [pp_ll]. *)
let compile_main llc : PPrint.document = llc |> pp_ll

0 commit comments

Comments
 (0)