@@ -36,7 +36,14 @@ type t =
3636 | For_loop of { index : Indexing .symbol ; from_ : int ; to_ : int ; body : t ; trace_it : bool }
3737 | Zero_out of Tn .t
3838 | Set of { tn : Tn .t ; idcs : Indexing .axis_index array ; llsc : scalar_t ; mutable debug : string }
39- | Set_from_vec of { tn : Tn .t ; idcs : Indexing .axis_index array ; length : int ; vec_unop : Ops .vec_unop ; arg : scalar_t ; mutable debug : string }
39+ | Set_from_vec of {
40+ tn : Tn .t ;
41+ idcs : Indexing .axis_index array ;
42+ length : int ;
43+ vec_unop : Ops .vec_unop ;
44+ arg : scalar_t ;
45+ mutable debug : string ;
46+ }
4047 | Set_local of scope_id * scalar_t
4148[@@ deriving sexp_of , equal ]
4249
@@ -255,14 +262,15 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
255262 let traced : traced_array = get_node traced_store tn in
256263 (* Vector operations cannot be scalar constexpr *)
257264 traced.is_scalar_constexpr < - false ;
258- if first_visit then
259- traced.is_complex < - traced.is_complex || is_complex_comp traced_store arg;
265+ if first_visit then traced.is_complex < - false ;
260266 (* Mark all positions that will be written to *)
261267 for i = 0 to length - 1 do
262268 let pos_idcs = Array. copy idcs in
263269 (match pos_idcs.(Array. length pos_idcs - 1 ) with
264270 | Fixed_idx idx -> pos_idcs.(Array. length pos_idcs - 1 ) < - Fixed_idx (idx + i)
265- | _ -> failwith " Set_from_vec: last index must be Fixed_idx" );
271+ | _ ->
272+ (* FIXME: NOT IMPLEMENTED YET *)
273+ failwith " FIXME: Set_from_vec: NOT IMPLEMENTED YET general index" );
266274 Hash_set. add traced.assignments (lookup env pos_idcs)
267275 done ;
268276 Array. iter idcs ~f: (function
@@ -566,10 +574,11 @@ let inline_computation ~id computations_table traced static_indices call_args =
566574 | Set { tn; idcs; llsc; debug = _ } when Tn. equal tn traced.tn ->
567575 assert ([% equal: Indexing. axis_index array option ] (Some idcs) def_args);
568576 Some (Set_local (id, loop_float env llsc))
569- | Set_from_vec { tn; idcs; length = _ ; vec_unop = _ ; arg = _ ; debug = _ } when Tn. equal tn traced.tn ->
577+ | Set_from_vec { tn; idcs; length = _; vec_unop = _; arg = _; debug = _ }
578+ when Tn. equal tn traced.tn ->
570579 assert ([% equal: Indexing. axis_index array option ] (Some idcs) def_args);
571580 (* For vector operations, we cannot inline them as scalar operations *)
572- raise @@ Non_virtual 14
581+ raise @@ Non_virtual 140
573582 | Zero_out _ -> None
574583 | Set _ -> None
575584 | Set_from_vec _ -> None
@@ -649,7 +658,9 @@ let virtual_llc computations_table traced_store reverse_node_map static_indices
649658 | Set_from_vec { tn; idcs; length; vec_unop; arg; debug } ->
650659 let traced : traced_array = get_node traced_store tn in
651660 let next = if Tn. known_non_virtual traced.tn then process_for else Set. add process_for tn in
652- let result = Set_from_vec { tn; idcs; length; vec_unop; arg = loop_float ~process_for: next arg; debug } in
661+ let result =
662+ Set_from_vec { tn; idcs; length; vec_unop; arg = loop_float ~process_for: next arg; debug }
663+ in
653664 if (not @@ Set. mem process_for tn) && (not @@ Tn. known_non_virtual traced.tn) then
654665 check_and_store_virtual computations_table traced static_indices result;
655666 result
@@ -735,7 +746,9 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t =
735746 else (
736747 assert (
737748 Array. for_all idcs ~f: (function Indexing. Iterator s -> Set. mem env_dom s | _ -> true ));
738- Some (Set_from_vec { tn; idcs; length; vec_unop; arg = loop_float ~balanced ~env_dom arg; debug }))
749+ Some
750+ (Set_from_vec
751+ { tn; idcs; length; vec_unop; arg = loop_float ~balanced ~env_dom arg; debug }))
739752 | Set_local (id , llsc ) ->
740753 assert (not @@ Tn. known_non_virtual id.tn);
741754 Tn. update_memory_mode id.tn Virtual 16 ;
@@ -867,12 +880,14 @@ let simplify_llc llc =
867880 | Binop (Arg1, llv1 , _ ) -> loop_float llv1
868881 | Binop (Arg2, _ , llv2 ) -> loop_float llv2
869882 | Binop (op , Constant c1 , Constant c2 ) -> Constant (Ops. interpret_binop op c1 c2)
870- | Binop (Add , llsc, Constant 0. ) | Binop (Sub , llsc, Constant 0. ) | Binop (Add , Constant 0. , llsc)
871- ->
883+ | Binop (Add , llsc, Constant 0. )
884+ | Binop (Sub , llsc, Constant 0. )
885+ | Binop (Add, Constant 0. , llsc ) ->
872886 loop_float llsc
873887 | Binop (Sub, Constant 0. , llsc ) -> loop_float @@ Binop (Mul , Constant (- 1. ), llsc)
874- | Binop (Mul , llsc, Constant 1. ) | Binop (Div , llsc, Constant 1. ) | Binop (Mul , Constant 1. , llsc)
875- ->
888+ | Binop (Mul , llsc, Constant 1. )
889+ | Binop (Div , llsc, Constant 1. )
890+ | Binop (Mul, Constant 1. , llsc ) ->
876891 loop_float llsc
877892 | Binop (Mul , _ , Constant 0. ) | Binop (Div , Constant 0. , _ ) | Binop (Mul, Constant 0. , _ ) ->
878893 Constant 0.
@@ -1130,14 +1145,17 @@ let to_doc_cstyle ?name ?static_indices () llc =
11301145 p.debug < - Buffer. contents b);
11311146 result
11321147 | Set_from_vec p ->
1148+ let prec = Lazy. force p.tn.prec in
1149+ let prefix, postfix = Ops. vec_unop_c_syntax prec p.vec_unop in
1150+ (* TODO: this assumes argument is generated from the high-level code, which means it is
1151+ either Get or Local_scope -- they don't need precision. *)
1152+ let vec_result = string prefix ^^ doc_of_float Ops. Void_prec p.arg ^^ string postfix in
1153+ let length_doc = string (" <" ^ Int. to_string p.length ^ " >" ) in
11331154 let result =
11341155 group
11351156 (doc_ident p.tn
11361157 ^^ brackets (pp_indices p.idcs)
1137- ^^ string " := "
1138- ^^ string (Ops. vec_unop_cd_syntax p.vec_unop)
1139- ^^ string " (" ^^ doc_of_float (Ops. uint4x32) p.arg ^^ string " , "
1140- ^^ int p.length ^^ string " );" )
1158+ ^^ length_doc ^^ string " := " ^^ vec_result ^^ string " ;" )
11411159 in
11421160 if not (String. is_empty p.debug) then (
11431161 let b = Buffer. create 100 in
@@ -1215,22 +1233,23 @@ let to_doc ?name ?static_indices () llc =
12151233 p.debug < - Buffer. contents b;
12161234 result
12171235 | Set_from_vec p ->
1236+ let length_doc = string (" <" ^ Int. to_string p.length ^ " >" ) in
12181237 let result =
12191238 group
12201239 (doc_ident p.tn
12211240 ^^ brackets (pp_indices p.idcs)
1222- ^^ string " := "
1241+ ^^ length_doc ^^ string " := "
12231242 ^^ string (Ops. vec_unop_cd_syntax p.vec_unop)
1224- ^^ string " (" ^^ doc_of_float p.arg ^^ string " , "
1225- ^^ int p.length ^^ string " );" )
1243+ ^^ string " (" ^^ doc_of_float p.arg ^^ string " , " ^^ length_doc ^^ string " );" )
12261244 in
12271245 let b = Buffer. create 100 in
12281246 PPrint.ToBuffer. pretty 0.7 100 b result;
12291247 p.debug < - Buffer. contents b;
12301248 result
12311249 | Comment message -> string (" /* " ^ message ^ " */" )
12321250 | Staged_compilation callback -> callback ()
1233- | Set_local (id , llsc ) -> group (doc_local id ^^ string " := " ^^ doc_of_float llsc ^^ string " ;" )
1251+ | Set_local (id , llsc ) ->
1252+ group (doc_local id ^^ string " := " ^^ doc_of_float llsc ^^ string " ;" )
12341253 and doc_of_float value =
12351254 match value with
12361255 | Local_scope { id; body; _ } ->
0 commit comments