@@ -343,32 +343,29 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
     let dims = lazy (Lazy.force projections).Idx.lhs_dims in
     match fetch_op with
     | None -> Asgns.empty_comp
-    | Some fetch_op_fn ->
-        let fetch_op = fetch_op_fn ~v in
-        (match fetch_op with
-        | Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _
-        | Access (Uint4x32_to_prec_uniform _) ->
-            (* For these operations it makes sense to have a local / virtual tensor if the result is
-               consumed in the same computation. *)
-            ()
-        | Access _ ->
-            (* Note: [Access] can be used for merging across devices. But, some use cases of
-               [Access] will require a hosted tensor node. *)
-            Tn.update_memory_mode v Materialized 22);
+    | Some
+        ((Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _
+         | Access (Uint4x32_to_prec_uniform _)) as fetch_op) ->
+        Asgns.to_comp @@ Fetch { array = v; fetch_op; dims }
+    | Some (Access _ as fetch_op) ->
+        (* Note: [Access] can be used for merging across devices. But, some use cases of [Access]
+           will require a hosted tensor node. *)
+        Tn.update_memory_mode v Materialized 22;
         Asgns.to_comp @@ Fetch { array = v; fetch_op; dims }
   in
   let grad_asn ~t:_ ~g:_ ~projections:_ = Asgns.empty_comp in
   let make_shape =
     Shape.make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced ()
   in
-  op ~label ?compose_op:None ?transpose_op:None ~op_asn ~grad_asn ~grad_spec make_shape []
+  (* Note: fetch_op in op is used only for shape inference. *)
+  op ~label ?compose_op:None ?transpose_op:None ?fetch_op ~op_asn ~grad_asn ~grad_spec make_shape []

 let float_to_label v = Float.to_string v

 let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
   (* Note: no axis label so that we do not conflict with user labels. *)
   let label = float_to_label c :: label in
-  let fetch_op ~v:_ = Ir.Assignments.Constant c in
+  let fetch_op = Ir.Assignments.Constant c in
   let t = term ~label ~grad_spec ~batch_dims:[] ~input_dims:[] ~fetch_op () in
   let t =
     match axis_label with
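
With this change, [term]'s [?fetch_op] argument is a plain [Ir.Assignments.fetch_op] value instead of a callback taking [~v], so [op_asn] dispatches on the constructor alone: [Constant], [Slice], [Embed_symbol], [Range_over_offsets], [Constant_fill], and [Access (Uint4x32_to_prec_uniform _)] emit the [Fetch] directly and may stay virtual, while any other [Access] first forces the node to [Materialized]. Below is a minimal sketch of the new calling convention, modeled on [number]; the helper is hypothetical (not part of this commit) and assumes the trailing [()] seen at the other [term] call sites.

    (* Hypothetical helper, not in this commit: a scalar constant built with
       the new plain-variant [fetch_op]. [Constant] is in the virtual-friendly
       group, so no memory mode is forced on the node. *)
    let scalar ?(label = []) c =
      term ~label:(float_to_label c :: label) ~grad_spec:Prohibit_grad
        ~batch_dims:[] ~input_dims:[] ~fetch_op:(Ir.Assignments.Constant c) ()
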
@@ -416,7 +413,7 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
   let t =
     term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
       ~deduced:Not_constrained
-      ~fetch_op:(fun ~v:_ -> Asgns.Constant_fill values)
+      ~fetch_op:(Asgns.Constant_fill values)
       ()
   in
   Tn.update_memory_mode t.value Effectively_constant 24;
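
[ndarray] gets the same simplification: the closure wrapper around [Asgns.Constant_fill] is gone. A minimal sketch of a call site, assuming (the excerpt does not show this part of the signature) that [values] is the final positional [float array] argument:

    (* Sketch: a constant 2x3 tensor. [Constant_fill] is virtual-friendly, and
       the hunk above additionally marks the result [Effectively_constant]. *)
    let xs =
      ndarray ~batch_dims:[] ~input_dims:[ 3 ] ~output_dims:[ 2 ]
        [| 1.; 2.; 3.; 4.; 5.; 6. |]
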
@@ -428,7 +425,7 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?

 let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?value
     ?values label =
-  let fetch_op_fn ~v:_ =
+  let fetch_op =
     match (values, value) with
     | Some values, None -> Asgns.Constant_fill values
     | None, Some value -> Asgns.Constant value
@@ -437,7 +434,7 @@ let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?
   in
   let t =
     term ~label:(label :: more_label) ~grad_spec:Require_grad ~batch_dims:[] ?input_dims
-      ?output_dims ?input_axes ?output_axes ?deduced ~fetch_op:fetch_op_fn ()
+      ?output_dims ?input_axes ?output_axes ?deduced ~fetch_op ()
   in
   let v = t.value in
   (* It is convenient to use the param syntax for volatiles (mutable embedded_nodes). *)
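
[param] follows suit: [fetch_op] is built once from the [(values, value)] pair (the remaining branches of that match fall outside the diff context) and passed through unchanged, with [~grad_spec:Require_grad] set by [param] itself. A minimal sketch of the two initialization styles, assuming defaults for the omitted optionals:

    (* Sketch: a scalar bias initialized from [?value], and a weight vector
       initialized from [?values]; both become trainable via [Require_grad]. *)
    let b = param ~value:0.0 "b"
    let w = param ~output_dims:[ 3 ] ~values:[| 0.1; 0.2; 0.3 |] "w"
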