@@ -205,8 +205,8 @@ let raw_unop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1
(* Policy for whether a tensor gets a gradient node: [Require_grad] always
   allocates one, [Prohibit_grad] never does, [If_needed] defers the decision
   until a downstream computation demands it. *)
type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equal, variants]
206206
207207let op ~(label : string list ) ?(ternary_op = Shape. Pointwise_tern )
208- ?(compose_op = Shape. Pointwise_bin ) ?(transpose_op = Shape. Pointwise_un ) ?terminal_op
209- ~op_asn ~ grad_asn ?(grad_spec = If_needed ) make_shape (orig_ts : t list ) : t =
208+ ?(compose_op = Shape. Pointwise_bin ) ?(transpose_op = Shape. Pointwise_un ) ?terminal_op ~ op_asn
209+ ~grad_asn ?(grad_spec = If_needed ) make_shape (orig_ts : t list ) : t =
210210 (* The code needs to be included in the order it was computed due to potential non-tree DAGs. *)
211211 let ordered_ts = List. dedup_and_sort orig_ts ~compare: (fun t1 t2 -> Int. ascending t1.id t2.id) in
212212 let id = session_state.next_id in
@@ -250,8 +250,7 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
250250 | Some (Shape. Data (Asgns. Padded { data; padding = padding_spec ; padded_value } )) ->
251251 let padding = Some (padding_spec, padded_value) in
252252 Tn. create_from_padded ~id ~label ~ndarray: data ~padding ()
253- | Some (Shape. Fetch _ ) | None ->
254- Tn. create ~default_prec ~id ~label ~dims ~padding ()
253+ | Some (Shape. Fetch _ ) | None -> Tn. create ~default_prec ~id ~label ~dims ~padding ()
255254 in
256255 let embedded_nodes = ref @@ Set. singleton (module Tn ) v in
257256 let children =
@@ -358,7 +357,7 @@ let unop ~label ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 =
358357let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
359358 ?deduced ?init_data ?fetch_op () =
360359 let terminal_op =
361- match init_data, fetch_op with
360+ match ( init_data, fetch_op) with
362361 | Some _ , Some _ -> invalid_arg " Tensor.term: both init_data and fetch_op are provided"
363362 | Some init_data , None -> Some (Shape. Data init_data)
364363 | None , Some fetch_op -> Some (Shape. Fetch fetch_op)
@@ -369,16 +368,18 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
369368 let dims = lazy (Lazy. force projections).Idx. lhs_dims in
370369 match fetch_op with
371370 | None -> Asgns. empty_comp
372- | Some (( Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _ ) as fetch_op ) ->
371+ | Some
372+ ((Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _) as fetch_op)
373+ ->
373374 Asgns. to_comp @@ Fetch { array = v; fetch_op; dims }
374375 in
375376 let grad_asn ~t :_ ~g :_ ~projections :_ = Asgns. empty_comp in
376377 let make_shape =
377378 Shape. make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced ()
378379 in
379380 (* Note: terminal_op is used for both tensor creation and shape inference. *)
380- op ~label ?compose_op:None ?transpose_op:None ?terminal_op ~op_asn ~grad_asn ~grad_spec
381- make_shape []
381+ op ~label ?compose_op:None ?transpose_op:None ?terminal_op ~op_asn ~grad_asn ~grad_spec make_shape
382+ []
382383
383384let float_to_label v = Float. to_string v
384385
@@ -467,6 +468,8 @@ let consume_forward_code t =
467468 @@ Session_error
468469 ( " Tensor.consume_forward_code: tensor is not a root for tnode: " ^ Tn. debug_name t.value,
469470 Some t );
471+ (* FIXME(#321): this is too aggressive, instead we should check if the code contains any
472+ non-embedded nodes that are embedded nodes of the other roots. *)
470473 let unsafe_roots =
471474 Map. data session_state.forward_roots
472475 |> List. filter ~f: (fun r -> not (List. is_empty r.children || r.id = t.id))
0 commit comments