@@ -398,12 +398,16 @@ module C_syntax (B : C_syntax_config) = struct
398398 ~args_docs: []
399399 else string " /* " ^^ string message ^^ string " */"
400400 | Staged_compilation callback -> callback ()
401- | Set_from_vec { tn; idcs; length; vec_unop; arg; debug } ->
401+ | Set_from_vec { tn; idcs; length; vec_unop; arg = arg , arg_prec ; debug } ->
402402 let ident_doc = string (get_ident tn) in
403403 let dims = Lazy. force tn.dims in
404404 let prec = Lazy. force tn.prec in
405- (* FIXME: this precision is hardcoded, bad, bad practice. *)
406- let arg_prec = Ops. uint4x32 in
405+ (* Determine argument precision based on operation homogeneity *)
406+ let arg_prec =
407+ if Ops. is_homogeneous_prec_vec_unop vec_unop then prec
408+ (* Homogeneous: argument uses result precision *)
409+ else arg_prec
410+ in
407411 let local_defs, arg_doc = pp_scalar arg_prec arg in
408412 let local_defs = pp_local_defs local_defs in
409413 (* Generate the function call *)
@@ -564,30 +568,62 @@ module C_syntax (B : C_syntax_config) = struct
564568 let idx_doc = if PPrint. is_empty idx_doc then string " 0" else idx_doc in
565569 let expr = string prefix ^^ idx_doc ^^ string postfix in
566570 ([] , expr)
567- | Binop (Arg1, v1 , _v2 ) -> pp_scalar prec v1
568- | Binop (Arg2, _v1 , v2 ) -> pp_scalar prec v2
569- | Ternop (op , v1 , v2 , v3 ) ->
570- let d1, e1 = pp_scalar prec v1 in
571- let d2, e2 = pp_scalar prec v2 in
572- let d3, e3 = pp_scalar prec v3 in
571+ | Binop (Arg1, (v1 , _ ), _v2 ) -> pp_scalar prec v1
572+ | Binop (Arg2, _v1 , (v2 , _ )) -> pp_scalar prec v2
573+ | Ternop (op , (v1 , v1_prec ), (v2 , v2_prec ), (v3 , v3_prec )) ->
574+ let d1, e1, d2, e2, d3, e3 =
575+ if Ops. is_homogeneous_prec_ternop op then
576+ (* Homogeneous: all arguments use result precision *)
577+ let d1, e1 = pp_scalar prec v1 in
578+ let d2, e2 = pp_scalar prec v2 in
579+ let d3, e3 = pp_scalar prec v3 in
580+ (d1, e1, d2, e2, d3, e3)
581+ else
582+ (* Heterogeneous: arguments keep their natural precision *)
583+ match op with
584+ | Ops. Where ->
585+ (* For Where: condition keeps its precision, then/else use result precision *)
586+ (* Note: we evaluate condition without precision conversion, but then/else
587+ need to match the result precision for the final assignment *)
588+ let d1, e1 = pp_scalar v1_prec v1 in
589+ (* condition: no conversion *)
590+ let d2, e2 = pp_scalar prec v2 in
591+ (* then: result precision *)
592+ let d3, e3 = pp_scalar prec v3 in
593+ (* else: result precision *)
594+ (d1, e1, d2, e2, d3, e3)
595+ | _ ->
596+ (* Other heterogeneous ternary ops would go here *)
597+ let d1, e1 = pp_scalar v1_prec v1 in
598+ let d2, e2 = pp_scalar v2_prec v2 in
599+ let d3, e3 = pp_scalar v3_prec v3 in
600+ (d1, e1, d2, e2, d3, e3)
601+ in
573602 let defs = List. concat [ d1; d2; d3 ] in
574603 let expr = group (B. ternop_syntax prec op e1 e2 e3) in
575604 (defs, expr)
576- | Binop (op , v1 , v2 ) ->
577- let d1, e1 = pp_scalar prec v1 in
578- let d2, e2 = pp_scalar prec v2 in
605+ | Binop (op , (v1 , v1_prec ), (v2 , v2_prec )) ->
606+ let d1, e1, d2, e2 =
607+ if Ops. is_homogeneous_prec_binop op then
608+ (* Homogeneous: both arguments use result precision *)
609+ let d1, e1 = pp_scalar prec v1 in
610+ let d2, e2 = pp_scalar prec v2 in
611+ (d1, e1, d2, e2)
612+ else
613+ (* Heterogeneous: arguments keep their natural precision *)
614+ (* Currently all binops are homogeneous, but this is here for future extension *)
615+ let d1, e1 = pp_scalar v1_prec v1 in
616+ let d2, e2 = pp_scalar v2_prec v2 in
617+ (d1, e1, d2, e2)
618+ in
579619 let defs = List. concat [ d1; d2 ] in
580620 let expr = group (B. binop_syntax prec op e1 e2) in
581621 (defs, expr)
582- | Unop (op , v ) ->
622+ | Unop (op , ( v , v_prec ) ) ->
583623 let arg_prec =
584- match op with
585- | Ops. Uint4x32_to_prec_uniform1 ->
586- (* The argument to Uint4x32_to_prec_uniform1 must be evaluated with uint4x32
587- precision, regardless of the target precision. This handles the case where the
588- operation is inlined as part of a scalar expression. *)
589- Ops. uint4x32
590- | _ -> prec
624+ if Ops. is_homogeneous_prec_unop op then prec
625+ (* Homogeneous: argument uses result precision *)
626+ else v_prec
591627 in
592628 let defs, expr_v = pp_scalar arg_prec v in
593629 let expr = group (B. unop_syntax prec op expr_v) in
@@ -651,19 +687,55 @@ module C_syntax (B : C_syntax_config) = struct
651687 | Embed_index idx ->
652688 let idx_doc = pp_axis_index idx in
653689 ((if PPrint. is_empty idx_doc then string " 0" else idx_doc), [] )
654- | Binop (Arg1, v1 , _v2 ) -> debug_float prec v1
655- | Binop (Arg2, _v1 , v2 ) -> debug_float prec v2
656- | Ternop (op , v1 , v2 , v3 ) ->
657- let v1_doc, idcs1 = debug_float prec v1 in
658- let v2_doc, idcs2 = debug_float prec v2 in
659- let v3_doc, idcs3 = debug_float prec v3 in
690+ | Binop (Arg1, (v1 , _ ), _v2 ) -> debug_float prec v1
691+ | Binop (Arg2, _v1 , (v2 , _ )) -> debug_float prec v2
692+ | Ternop (op , (v1 , v1_prec ), (v2 , v2_prec ), (v3 , v3_prec )) ->
693+ let v1_doc, idcs1, v2_doc, idcs2, v3_doc, idcs3 =
694+ if Ops. is_homogeneous_prec_ternop op then
695+ (* Homogeneous: all arguments use result precision *)
696+ let v1_doc, idcs1 = debug_float prec v1 in
697+ let v2_doc, idcs2 = debug_float prec v2 in
698+ let v3_doc, idcs3 = debug_float prec v3 in
699+ (v1_doc, idcs1, v2_doc, idcs2, v3_doc, idcs3)
700+ else
701+ (* Heterogeneous: handle based on operation *)
702+ match op with
703+ | Ops. Where ->
704+ let v1_doc, idcs1 = debug_float v1_prec v1 in
705+ (* condition: no conversion *)
706+ let v2_doc, idcs2 = debug_float prec v2 in
707+ (* then: result precision *)
708+ let v3_doc, idcs3 = debug_float prec v3 in
709+ (* else: result precision *)
710+ (v1_doc, idcs1, v2_doc, idcs2, v3_doc, idcs3)
711+ | _ ->
712+ let v1_doc, idcs1 = debug_float v1_prec v1 in
713+ let v2_doc, idcs2 = debug_float v2_prec v2 in
714+ let v3_doc, idcs3 = debug_float v3_prec v3 in
715+ (v1_doc, idcs1, v2_doc, idcs2, v3_doc, idcs3)
716+ in
660717 (B. ternop_syntax prec op v1_doc v2_doc v3_doc, idcs1 @ idcs2 @ idcs3)
661- | Binop (op , v1 , v2 ) ->
662- let v1_doc, idcs1 = debug_float prec v1 in
663- let v2_doc, idcs2 = debug_float prec v2 in
718+ | Binop (op , (v1 , v1_prec ), (v2 , v2_prec )) ->
719+ let v1_doc, idcs1, v2_doc, idcs2 =
720+ if Ops. is_homogeneous_prec_binop op then
721+ (* Homogeneous: both arguments use result precision *)
722+ let v1_doc, idcs1 = debug_float prec v1 in
723+ let v2_doc, idcs2 = debug_float prec v2 in
724+ (v1_doc, idcs1, v2_doc, idcs2)
725+ else
726+ (* Heterogeneous: arguments keep their natural precision *)
727+ let v1_doc, idcs1 = debug_float v1_prec v1 in
728+ let v2_doc, idcs2 = debug_float v2_prec v2 in
729+ (v1_doc, idcs1, v2_doc, idcs2)
730+ in
664731 (B. binop_syntax prec op v1_doc v2_doc, idcs1 @ idcs2)
665- | Unop (op , v ) ->
666- let v_doc, idcs = debug_float prec v in
732+ | Unop (op , (v , v_prec )) ->
733+ let arg_prec =
734+ if Ops. is_homogeneous_prec_unop op then prec
735+ (* Homogeneous: argument uses result precision *)
736+ else v_prec
737+ in
738+ let v_doc, idcs = debug_float arg_prec v in
667739 (B. unop_syntax prec op v_doc, idcs)
668740
669741 let compile_main llc : PPrint.document = pp_ll llc
0 commit comments