Skip to content

Commit 20809e9

Browse files
committed
Rename float_t -> scalar_t
1 parent 9a988fd commit 20809e9

File tree

4 files changed

+32
-31
lines changed

4 files changed

+32
-31
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ module C_syntax (B : C_syntax_config) = struct
576576
in
577577
if PPrint.is_empty local_defs then assignment else local_defs ^^ hardline ^^ assignment
578578

579-
and pp_float (prec : Ops.prec) (vcomp : Low_level.float_t) : PPrint.document * PPrint.document =
579+
and pp_float (prec : Ops.prec) (vcomp : Low_level.scalar_t) : PPrint.document * PPrint.document =
580580
(* Returns (local definitions, value expression) *)
581581
let open PPrint in
582582
match vcomp with
@@ -658,7 +658,7 @@ module C_syntax (B : C_syntax_config) = struct
658658
let expr = group (B.unop_syntax prec op expr_v) in
659659
(defs, expr)
660660

661-
and debug_float (prec : Ops.prec) (value : Low_level.float_t) :
661+
and debug_float (prec : Ops.prec) (value : Low_level.scalar_t) :
662662
PPrint.document
663663
* [ `Accessor of Indexing.axis_index array * int array | `Value of PPrint.document ] list =
664664
(* Returns (value expression doc, list of arguments for printf) *)

arrayjit/lib/low_level.ml

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,18 @@ type t =
3535
| Seq of t * t
3636
| For_loop of { index : Indexing.symbol; from_ : int; to_ : int; body : t; trace_it : bool }
3737
| Zero_out of Tn.t
38-
| Set of { tn : Tn.t; idcs : Indexing.axis_index array; llv : float_t; mutable debug : string }
39-
| Set_local of scope_id * float_t
38+
| Set of { tn : Tn.t; idcs : Indexing.axis_index array; llv : scalar_t; mutable debug : string }
39+
| Set_local of scope_id * scalar_t
4040
[@@deriving sexp_of, equal]
4141

42-
and float_t =
42+
and scalar_t =
4343
| Local_scope of { id : scope_id; body : t; orig_indices : Indexing.axis_index array }
4444
| Get_local of scope_id
4545
| Get of Tn.t * Indexing.axis_index array
4646
| Get_merge_buffer of Tn.t * Indexing.axis_index array
47-
| Ternop of Ops.ternop * float_t * float_t * float_t
48-
| Binop of Ops.binop * float_t * float_t
49-
| Unop of Ops.unop * float_t
47+
| Ternop of Ops.ternop * scalar_t * scalar_t * scalar_t
48+
| Binop of Ops.binop * scalar_t * scalar_t
49+
| Unop of Ops.unop * scalar_t
5050
| Constant of float
5151
| Embed_index of Indexing.axis_index
5252
[@@deriving sexp_of, equal, compare]
@@ -530,7 +530,7 @@ let inline_computation ~id computations_table traced static_indices call_args =
530530
| Set_local (id, llv) -> Some (Set_local (id, loop_float env llv))
531531
| Comment _ -> Some llc
532532
| Staged_compilation _ -> Some llc
533-
and loop_float env llv : float_t =
533+
and loop_float env llv : scalar_t =
534534
match llv with
535535
| Constant _ -> llv
536536
| Get (tn, indices) when Tn.equal tn traced.tn ->
@@ -564,7 +564,7 @@ let inline_computation ~id computations_table traced static_indices call_args =
564564

565565
let optimize_integer_pow = ref true
566566

567-
let rec unroll_pow ~(base : float_t) ~(exp : int) : float_t =
567+
let rec unroll_pow ~(base : scalar_t) ~(exp : int) : scalar_t =
568568
if exp < 0 then unroll_pow ~base:(Binop (Div, Constant 1., base)) ~exp:(Int.neg exp)
569569
else if exp = 0 then Constant 1.
570570
else Fn.apply_n_times ~n:(exp - 1) (fun accu -> Binop (Mul, base, accu)) base
@@ -603,7 +603,7 @@ let virtual_llc computations_table traced_store reverse_node_map static_indices
603603
| Set_local (id, llv) -> Set_local (id, loop_float ~process_for llv)
604604
| Comment _ -> llc
605605
| Staged_compilation _ -> llc
606-
and loop_float ~process_for (llv : float_t) : float_t =
606+
and loop_float ~process_for (llv : scalar_t) : scalar_t =
607607
match llv with
608608
| Constant _ -> llv
609609
| Get (tn, _) when Set.mem process_for tn ->
@@ -680,7 +680,7 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t =
680680
Some (Set_local (id, loop_float ~balanced ~env_dom llv))
681681
| Comment _ -> Some llc
682682
| Staged_compilation _ -> Some llc
683-
and loop_float ~balanced ~env_dom (llv : float_t) : float_t =
683+
and loop_float ~balanced ~env_dom (llv : scalar_t) : scalar_t =
684684
let loop = loop_float ~balanced ~env_dom in
685685
match llv with
686686
| Constant _ -> llv
@@ -770,7 +770,7 @@ let simplify_llc llc =
770770
| Set_local (id, llv) -> Set_local (id, loop_float llv)
771771
| Comment _ -> llc
772772
| Staged_compilation _ -> llc
773-
and loop_float (llv : float_t) : float_t =
773+
and loop_float (llv : scalar_t) : scalar_t =
774774
let local_scope_body, llv' =
775775
match llv with
776776
| Local_scope opts ->
@@ -838,9 +838,9 @@ let simplify_llc llc =
838838
| Binop (Div, Binop (Div, llv1, llv2), llv3) ->
839839
loop_float @@ Binop (Div, llv1, Binop (Mul, llv2, llv3))
840840
| Binop (ToPowOf, llv1, llv2) -> (
841-
let v1 : float_t = loop_float llv1 in
842-
let v2 : float_t = loop_float llv2 in
843-
let result : float_t = Binop (ToPowOf, v1, v2) in
841+
let v1 : scalar_t = loop_float llv1 in
842+
let v2 : scalar_t = loop_float llv2 in
843+
let result : scalar_t = Binop (ToPowOf, v1, v2) in
844844
if not !optimize_integer_pow then result
845845
else
846846
match v2 with

arrayjit/lib/low_level.mli

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,32 @@ type scope_id = Scope_id.t = { tn : Tnode.t; scope_id : int }
1616

1717
(** {2 Low-level representation} *)
1818

19-
(** Cases: [t] -- code, [float_t] -- single number at some precision. *)
19+
(** Cases: [t] -- code, [scalar_t] -- single number at some precision. *)
2020
type t =
2121
| Noop
2222
| Comment of string
2323
| Staged_compilation of (unit -> PPrint.document)
2424
| Seq of t * t
2525
| For_loop of { index : Indexing.symbol; from_ : int; to_ : int; body : t; trace_it : bool }
2626
| Zero_out of Tnode.t
27-
| Set of { tn : Tnode.t; idcs : Indexing.axis_index array; llv : float_t; mutable debug : string }
28-
| Set_local of scope_id * float_t
27+
| Set of { tn : Tnode.t; idcs : Indexing.axis_index array; llv : scalar_t; mutable debug : string }
28+
| Set_local of scope_id * scalar_t
2929
[@@deriving sexp_of, equal]
3030

31-
and float_t =
31+
and scalar_t =
3232
| Local_scope of { id : scope_id; body : t; orig_indices : Indexing.axis_index array }
3333
| Get_local of scope_id
3434
| Get of Tnode.t * Indexing.axis_index array
3535
| Get_merge_buffer of Tnode.t * Indexing.axis_index array
36-
| Ternop of Ops.ternop * float_t * float_t * float_t
37-
| Binop of Ops.binop * float_t * float_t
38-
| Unop of Ops.unop * float_t
36+
| Ternop of Ops.ternop * scalar_t * scalar_t * scalar_t
37+
| Binop of Ops.binop * scalar_t * scalar_t
38+
| Unop of Ops.unop * scalar_t
3939
| Constant of float
4040
| Embed_index of Indexing.axis_index
41+
| Vec of int * scalar_t
4142
[@@deriving sexp_of, equal, compare]
4243

43-
val apply_op : Ops.op -> float_t array -> float_t
44+
val apply_op : Ops.op -> scalar_t array -> scalar_t
4445
val flat_lines : t list -> t list
4546
val unflat_lines : t list -> t
4647
val loop_over_dims : int array -> body:(Indexing.axis_index array -> t) -> t

arrayjit/lib/lowering_and_inlining.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,22 @@ type t =
1414
| Seq of t * t
1515
| For_loop of { index : Indexing.symbol; from_ : int; to_ : int; body : t; trace_it : bool }
1616
| Zero_out of Tnode.t
17-
| Set of { array : Tnode.t; idcs : Indexing.axis_index array; llv : float_t; mutable debug : string }
18-
| Set_local of scope_id * float_t
17+
| Set of { array : Tnode.t; idcs : Indexing.axis_index array; llv : scalar_t; mutable debug : string }
18+
| Set_local of scope_id * scalar_t
1919
20-
and float_t =
20+
and scalar_t =
2121
| Local_scope of { id : scope_id; body : t; orig_indices : Indexing.axis_index array }
2222
| Get_local of scope_id
2323
| Access of dedicated_access * Indexing.axis_index array option
2424
| Get of Tnode.t * Indexing.axis_index array
25-
| Ternop of Ops.ternop * float_t * float_t * float_t
26-
| Binop of Ops.binop * float_t * float_t
27-
| Unop of Ops.unop * float_t
25+
| Ternop of Ops.ternop * scalar_t * scalar_t * scalar_t
26+
| Binop of Ops.binop * scalar_t * scalar_t
27+
| Unop of Ops.unop * scalar_t
2828
| Constant of float
2929
| Embed_index of Indexing.axis_index
3030
```
3131

32-
`t` represents code/statements while `float_t` represents scalar expressions. The `trace_it` flag in `For_loop` indicates whether the loop should be traced for optimization (its initial segment will be unrolled for analysis).
32+
`t` represents code/statements while `scalar_t` represents scalar expressions. The `trace_it` flag in `For_loop` indicates whether the loop should be traced for optimization (its initial segment will be unrolled for analysis).
3333

3434
## Translation from Assignments
3535

0 commit comments

Comments
 (0)