
Commit f6776d1

Cleanup: remove unused ~v input to fetch_op
1 parent 3ff3883 commit f6776d1
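
The change: `~fetch_op` of `Tensor.term` used to be a callback receiving the tensor node as a labeled argument `~v`, which every call site in this diff ignored; it is now a plain `fetch_op` value. Call sites simply drop the wrapping closure, as in this before/after sketch extracted from the diffs below:

    (* Before: wrap the operation in a closure that discards ~v. *)
    ~fetch_op:(fun ~v:_ -> Constant_fill xs)

    (* After: pass the operation value directly. *)
    ~fetch_op:(Constant_fill xs)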

File tree: 8 files changed, +25 −33 lines changed


bin/primitive_ops.ml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ let%debug_sexp graph_t () : unit =
   let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) + 0.1) in
   let x_flat =
     Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
-      ~fetch_op:(fun ~v:_ -> Constant_fill xs)
+      ~fetch_op:(Constant_fill xs)
       ()
   in
   let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in

bin/zero2hero_1of7.ml

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@ let _suspended () =
   let x_flat =
     Tensor.term ~grad_spec:Tensor.Require_grad
       ~label:[ "x_flat" ] (* ~input_dims:[] ~output_dims:[ 1 ] *)
-      ~fetch_op:(fun ~v:_ -> Constant_fill values)
+      ~fetch_op:(Constant_fill values)
       ()
   in
   let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in
@@ -111,7 +111,7 @@ let _suspended () =
   (* Yay, the whole shape gets inferred! *)
   let x_flat =
     Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
-      ~fetch_op:(fun ~v:_ -> Constant_fill xs)
+      ~fetch_op:(Constant_fill xs)
       ()
   in
   let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in

lib/operation.ml

Lines changed: 5 additions & 9 deletions
@@ -319,8 +319,7 @@ let range ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?axis_label upto =
   let result =
     Tensor.term
       ~label:(("0" ^ "..." ^ Int.to_string upto) :: label)
-      ~grad_spec ~batch_dims:[] ~input_dims:[]
-      ~fetch_op:(fun ~v:_ -> Range_over_offsets)
+      ~grad_spec ~batch_dims:[] ~input_dims:[] ~fetch_op:Range_over_offsets
   in
   match axis_label with
   | None -> result ~output_dims:[ upto + 1 ] ()
@@ -344,8 +343,7 @@ let range_of_shape ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?batch_dims
   Tensor.term
     ~label:(("r" ^ Idx.dims_to_string dims) :: label)
     ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
-    ~fetch_op:(fun ~v:_ -> Range_over_offsets)
-    ()
+    ~fetch_op:Range_over_offsets ()

 (** A [stop_gradient] is an identity in the forward pass and a no-op in the backprop pass. *)
 let stop_gradient ?(label = []) =
@@ -390,7 +388,7 @@ let random_seed =
   let seed = Option.value ~default:42 @@ Utils.settings.fixed_state_for_init in
   let res =
     Tensor.term ~label:[ "random_seed" ] ~grad_spec:Prohibit_grad
-      ~fetch_op:(fun ~v:_ -> Asgns.Constant_fill [| Int.to_float seed |])
+      ~fetch_op:(Asgns.Constant_fill [| Int.to_float seed |])
       ()
   in
   Tn.update_memory_mode res.value Tn.Effectively_constant 24;
@@ -469,14 +467,12 @@ module TDSL = struct
   *)
   let init_const ~l ?b ?(i = []) ~o values =
     Tensor.term ~label:[ l ] ~grad_spec:Prohibit_grad ?batch_dims:b ~input_dims:i ~output_dims:o
-      ~fetch_op:(fun ~v:_ -> Asgns.Constant_fill values)
-      ()
+      ~fetch_op:(Asgns.Constant_fill values) ()

   (** It's like `Tensor.param` but without shape inference. *)
   let init_param ~l ?(b = []) ?(i = []) ?(o = []) values =
     Tensor.term ~label:[ l ] ~grad_spec:Require_grad ~batch_dims:b ~input_dims:i ~output_dims:o
-      ~fetch_op:(fun ~v:_ -> Asgns.Constant_fill values)
-      ()
+      ~fetch_op:(Asgns.Constant_fill values) ()
 end

 module NTDSL = struct
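
A side benefit visible in `range` and `range_of_shape` above: constructors like `Range_over_offsets` can now be passed directly, without an eta-expanded closure. A minimal sketch of what the new call shape looks like, assuming `Range_over_offsets` is in scope as in lib/operation.ml (the real `range` also threads `?axis_label` and builds the label from `upto`):

    (* Hypothetical standalone call mirroring [range 9]. *)
    let r =
      Tensor.term ~label:[ "0...9" ] ~grad_spec:Prohibit_grad
        ~batch_dims:[] ~input_dims:[] ~output_dims:[ 10 ]
        ~fetch_op:Range_over_offsets ()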

lib/tensor.ml

Lines changed: 14 additions & 17 deletions
@@ -343,32 +343,29 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
     let dims = lazy (Lazy.force projections).Idx.lhs_dims in
     match fetch_op with
     | None -> Asgns.empty_comp
-    | Some fetch_op_fn ->
-        let fetch_op = fetch_op_fn ~v in
-        (match fetch_op with
-        | Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _
-        | Access (Uint4x32_to_prec_uniform _) ->
-            (* For these operations it makes sense to have a local / virtual tensor if the result is
-               consumed in the same computation. *)
-            ()
-        | Access _ ->
-            (* Note: [Access] can be used for merging across devices. But, some use cases of
-               [Access] will require a hosted tensor node. *)
-            Tn.update_memory_mode v Materialized 22);
+    | Some
+        (( Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _
+         | Access (Uint4x32_to_prec_uniform _) ) as fetch_op) ->
+        Asgns.to_comp @@ Fetch { array = v; fetch_op; dims }
+    | Some (Access _ as fetch_op) ->
+        (* Note: [Access] can be used for merging across devices. But, some use cases of [Access]
+           will require a hosted tensor node. *)
+        Tn.update_memory_mode v Materialized 22;
         Asgns.to_comp @@ Fetch { array = v; fetch_op; dims }
   in
   let grad_asn ~t:_ ~g:_ ~projections:_ = Asgns.empty_comp in
   let make_shape =
     Shape.make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced ()
   in
-  op ~label ?compose_op:None ?transpose_op:None ~op_asn ~grad_asn ~grad_spec make_shape []
+  (* Note: fetch_op in op is used only for shape inference. *)
+  op ~label ?compose_op:None ?transpose_op:None ?fetch_op ~op_asn ~grad_asn ~grad_spec make_shape []

 let float_to_label v = Float.to_string v

 let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
   (* Note: no axis label so that we do not conflict with user labels. *)
   let label = float_to_label c :: label in
-  let fetch_op ~v:_ = Ir.Assignments.Constant c in
+  let fetch_op = Ir.Assignments.Constant c in
   let t = term ~label ~grad_spec ~batch_dims:[] ~input_dims:[] ~fetch_op in
   let t =
     match axis_label with
@@ -416,7 +413,7 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
   let t =
     term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
       ~deduced:Not_constrained
-      ~fetch_op:(fun ~v:_ -> Asgns.Constant_fill values)
+      ~fetch_op:(Asgns.Constant_fill values)
       ()
   in
   Tn.update_memory_mode t.value Effectively_constant 24;
@@ -428,7 +425,7 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?

 let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?value
     ?values label =
-  let fetch_op_fn ~v:_ =
+  let fetch_op =
     match (values, value) with
     | Some values, None -> Asgns.Constant_fill values
     | None, Some value -> Asgns.Constant value
@@ -437,7 +434,7 @@ let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?
   in
   let t =
     term ~label:(label :: more_label) ~grad_spec:Require_grad ~batch_dims:[] ?input_dims
-      ?output_dims ?input_axes ?output_axes ?deduced ~fetch_op:fetch_op_fn ()
+      ?output_dims ?input_axes ?output_axes ?deduced ~fetch_op ()
   in
   let v = t.value in
   (* It is convenient to use the param syntax for volatiles (mutable embedded_nodes). *)
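
lib/tensor.ml is the one non-mechanical change: with `fetch_op` a plain value rather than a function of `~v`, the `Some` case can match on the operation's constructors directly, using an or-pattern with an `as` alias so that both branches end in the same `Fetch` assignment; the value is also now forwarded to `op` via `?fetch_op`, which per the added comment is used only for shape inference. A minimal, self-contained illustration of the or-pattern-with-alias idiom, using hypothetical constructors rather than the library's:

    type op = Constant of float | Range | Access of string

    let describe = function
      | None -> "empty computation"
      | Some ((Constant _ | Range) as op) ->
          (* a local / virtual result is fine for these *)
          (match op with
          | Constant c -> "fetch constant " ^ string_of_float c
          | _ -> "fetch range")
      | Some (Access name) ->
          (* the real code first marks the node Materialized here *)
          "fetch (materialized) " ^ name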

lib/tensor.mli

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ val term :
   ?input_axes:(string * int) list ->
   ?output_axes:(string * int) list ->
   ?deduced:Shape.deduce_within_shape ->
-  ?fetch_op:(v:tn -> fetch_op) ->
+  ?fetch_op:fetch_op ->
   unit ->
   t
 (** A terminal: a constant, a parameter, an input of the model. The semantics of shape specification

test/einsum/moons_demo_variant.ml

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ let () =
   let steps = epochs * 2 * len / batch_size in
   let noise () = Rand.float_range (-0.1) 0.1 in
   let moons_flat =
-    Bigarray.Genarray.init in
     Array.concat_map (Array.create ~len ())
       ~f:
         Float.(

test/primitive_ops.ml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ let plot_unop ~f ?(x_min = -5.) ?(x_max = 5.) () =
   in
   let x_flat =
     Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
-      ~fetch_op:(fun ~v:_ -> Constant_fill xs)
+      ~fetch_op:(Constant_fill xs)
       ()
   in
   let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in

test/zero2hero_1of7.ml

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ let%expect_test "Graph drawing fetch" =
   (* Yay, the whole shape gets inferred! *)
   let x_flat =
     Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
-      ~fetch_op:(fun ~v:_ -> Constant_fill xs)
+      ~fetch_op:(Constant_fill xs)
       ()
   in
   let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in
