@@ -206,7 +206,7 @@ type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equa
206206
207207let op ~(label : string list ) ?(ternary_op = Shape. Pointwise_tern )
208208 ?(compose_op = Shape. Pointwise_bin ) ?(transpose_op = Shape. Pointwise_un )
209- ?fetch_op ? init_data ?init_data_spec ~op_asn ~grad_asn
209+ ?init_data ?fetch_op ~op_asn ~grad_asn
210210 ?(grad_spec = If_needed ) make_shape (orig_ts : t list ) : t =
211211 (* The code needs to be included in the order it was computed due to potential non-tree DAGs. *)
212212 let ordered_ts = List. dedup_and_sort orig_ts ~compare: (fun t1 t2 -> Int. ascending t1.id t2.id) in
@@ -222,13 +222,10 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
222222 |> Option. value ~default )
223223 in
224224 let terminal_logic () =
225- match fetch_op, init_data, init_data_spec with
226- | None , None , _ -> Shape. Terminal (`Fetch (Asgns. Constant 0.0 ))
227- | Some fetch_op , _ , _ -> Shape. Terminal (`Fetch fetch_op)
228- | None , Some data , Some `Reshape -> Shape. Terminal (`Data (Asgns. Reshape data))
229- | None , Some data , None -> Shape. Terminal (`Data (Asgns. Reshape data)) (* default *)
230- | None , Some data , Some `Keep_shape_no_padding -> Shape. Terminal (`Data (Asgns. Keep_shape_no_padding data))
231- | None , Some data , Some (`Padded (padding , padded_value )) -> Shape. Terminal (`Data (Asgns. Padded { data; padding; padded_value }))
225+ match fetch_op, init_data with
226+ | None , None -> Shape. Terminal (`Fetch (Asgns. Constant 0.0 ))
227+ | Some fetch_op , _ -> Shape. Terminal (`Fetch fetch_op)
228+ | None , Some init_data -> Shape. Terminal (`Data init_data)
232229 in
233230 let rec shape_logics = function
234231 | [] -> [ terminal_logic () ]
@@ -245,15 +242,13 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
245242 let projections = lazy (Shape. derive_projections @@ List. hd_exn local_shape_updates) in
246243 let padding = lazy (Shape. to_padding shape) in
247244 let v =
248- match ( init_data, init_data_spec) with
249- | None , _ -> Tn. create ~default_prec ~id ~label ~dims ~padding ()
250- | Some data , Some ` Reshape ->
245+ match init_data with
246+ | None -> Tn. create ~default_prec ~id ~label ~dims ~padding ()
247+ | Some ( Asgns. Reshape data ) ->
251248 Tn. create_with_reshape ~id ~label ~dims ~padding ~from_padded: false ~base_ndarray: data ()
252- | Some data , None -> (* default to Reshape *)
253- Tn. create_with_reshape ~id ~label ~dims ~padding ~from_padded: false ~base_ndarray: data ()
254- | Some data , Some `Keep_shape_no_padding ->
249+ | Some (Asgns. Keep_shape_no_padding data ) ->
255250 Tn. create_from_padded ~id ~label ~ndarray: data ~padding: None ()
256- | Some data , Some (`Padded ( padding_spec , padded_value ) ) ->
251+ | Some ( Asgns. Padded { data; padding = padding_spec ; padded_value } ) ->
257252 let padding = Some (padding_spec, padded_value) in
258253 Tn. create_from_padded ~id ~label ~ndarray: data ~padding ()
259254 in
@@ -359,7 +354,7 @@ let unop ~label ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 =
359354 op ~label ?compose_op:None ?transpose_op ~op_asn ~grad_asn ?grad_spec (Shape.make ( )) [ t1 ]
360355
361356let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
362- ?deduced ?init_data ?init_data_spec ? fetch_op () =
357+ ?deduced ?init_data ?fetch_op () =
363358 let op_asn ~v ~projections =
364359 let open Asgns in
365360 let dims = lazy (Lazy. force projections).Idx. lhs_dims in
@@ -380,7 +375,7 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
380375 Shape. make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced ()
381376 in
382377 (* Note: fetch_op in op is used only for shape inference. *)
383- op ~label ?compose_op:None ?transpose_op:None ?fetch_op ~op_asn ~grad_asn ~grad_spec make_shape []
378+ op ~label ?compose_op:None ?transpose_op:None ?init_data ? fetch_op ~op_asn ~grad_asn ~grad_spec make_shape []
384379
385380let float_to_label v = Float. to_string v
386381
0 commit comments