Skip to content

Commit 77b3395

Browse files
committed
Memoize size_in_bytes inside Tnode.t
1 parent 186a2d3 commit 77b3395

File tree

5 files changed

+13
-14
lines changed

5 files changed

+13
-14
lines changed

arrayjit/lib/backends.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ module Add_device
271271

272272
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
273273
let s = dst.stream in
274-
let size_in_bytes = Tnode.size_in_bytes tn in
274+
let size_in_bytes = Lazy.force tn.Tnode.size_in_bytes in
275275
let work =
276276
(* TODO: log the operation if [Utils.settings.with_log_level > 1]. *)
277277
match (into_merge_buffer, dst_ptr) with
@@ -280,7 +280,6 @@ module Add_device
280280
| Streaming_for _, _ -> fun () -> s.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
281281
| Copy, _ ->
282282
fun () ->
283-
let size_in_bytes = Tnode.size_in_bytes tn in
284283
let allocated_capacity =
285284
match s.allocated_buffer with None -> 0 | Some buf -> buf.size_in_bytes
286285
in

arrayjit/lib/c_syntax.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ struct
3434
let in_ctx tn = B.(Tn.is_in_context_force ~use_host_memory tn 46)
3535

3636
let pp_zero_out ppf tn =
37-
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
37+
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn)
38+
@@ Lazy.force tn.size_in_bytes
3839

3940
let pp_include ppf s = Stdlib.Format.fprintf ppf "#include %s" s
4041

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
198198
dst.stream.stream_id src.stream.stream_id; *)
199199
let dev = dst.stream.device in
200200
let same_device = dev.ordinal = src.stream.device.ordinal in
201-
let size_in_bytes = Tn.size_in_bytes tn in
201+
let size_in_bytes = Lazy.force tn.Tn.size_in_bytes in
202202
let memcpy ~dst_ptr =
203203
if same_device && Cu.Deviceptr.equal dst_ptr src_ptr then ()
204204
else if same_device then
@@ -217,7 +217,6 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
217217
dst.stream.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
218218
| Copy, _ ->
219219
set_ctx @@ ctx_of dst;
220-
let size_in_bytes = Tn.size_in_bytes tn in
221220
opt_alloc_merge_buffer ~size_in_bytes dev.dev dst.stream;
222221
let buffer = Option.value_exn ~here:[%here] !(dst.stream.merge_buffer) in
223222
memcpy ~dst_ptr:buffer.ptr

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ let zero_out ctx block node =
9393
[
9494
Lazy.force node.ptr;
9595
RValue.zero ctx c_int;
96-
RValue.int ctx c_index @@ Tn.size_in_bytes node.tn;
96+
RValue.int ctx c_index @@ Lazy.force node.tn.size_in_bytes;
9797
]
9898

9999
let get_c_ptr ctx num_typ ptr = Gccjit.(RValue.ptr ctx (Type.pointer num_typ) ptr)

arrayjit/lib/tnode.ml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,14 @@ type memory_mode =
6262
optional [array] of {!t}). *)
6363
[@@deriving sexp, compare, equal]
6464

65-
type delayed_prec =
66-
| Not_specified
67-
| Default_spec of (Ops.prec Lazy.t[@sexp.opaque])
68-
| Specified of Ops.prec
65+
type delayed_prec = Not_specified | Default_spec of Ops.prec Lazy.t | Specified of Ops.prec
6966
[@@deriving sexp, equal]
7067

7168
type t = {
72-
array : (Nd.t option Lazy.t[@sexp.opaque]);
73-
prec : (Ops.prec Lazy.t[@sexp.opaque]);
74-
dims : (int array Lazy.t[@sexp.opaque]);
69+
array : Nd.t option Lazy.t;
70+
prec : Ops.prec Lazy.t;
71+
dims : int array Lazy.t;
72+
size_in_bytes : int Lazy.t;
7573
id : int;
7674
label : string list;
7775
(** Display information. It is better if the last element of the list is the most narrow or
@@ -90,7 +88,6 @@ let num_elems tn =
9088
let dims = Lazy.force tn.dims in
9189
if Array.is_empty dims then 0 else Array.reduce_exn dims ~f:( * )
9290

93-
let size_in_bytes tn = num_elems tn * Ops.prec_in_bytes (Lazy.force tn.prec)
9491
let id { id; _ } = "n" ^ Int.to_string id
9592
let label a = String.concat ~sep:"_" a.label
9693
let is_alphanum_ = String.for_all ~f:(fun c -> Char.equal c '_' || Char.is_alphanum c)
@@ -512,6 +509,7 @@ let create ?default_prec ~id ~label ~dims init_op =
512509
| Specified prec | Default_spec (lazy prec) -> prec
513510
| Not_specified ->
514511
raise @@ Utils.User_error "Tnode.update_prec: precision is not specified yet")
512+
and size_in_bytes = lazy (num_elems tn * Ops.prec_in_bytes (Lazy.force tn.prec))
515513
and tn =
516514
let delayed_prec_unsafe =
517515
match default_prec with None -> Not_specified | Some prec -> Default_spec prec
@@ -521,6 +519,7 @@ let create ?default_prec ~id ~label ~dims init_op =
521519
delayed_prec_unsafe;
522520
prec;
523521
dims;
522+
size_in_bytes;
524523
id;
525524
label;
526525
memory_mode = None;
@@ -541,6 +540,7 @@ let find =
541540
prec = lazy Ops.single;
542541
delayed_prec_unsafe = Specified Ops.single;
543542
dims = lazy [||];
543+
size_in_bytes = lazy 0;
544544
id = -1;
545545
label = [];
546546
memory_mode = None;

0 commit comments

Comments
 (0)