Skip to content

Commit ec7bf34

Browse files
committed
Fix compilation errors and add missing pattern match cases
- Fixed tensor.ml to match interface with init_data parameter ordering
- Added handling of Exact row constraint with FIXME comments for implementation
- Added missing pattern match cases to avoid warnings
- Fixed use of init_data in terminal_logic and tensor node creation
1 parent 760f2d8 commit ec7bf34

File tree

2 files changed

+34
-18
lines changed

2 files changed

+34
-18
lines changed

lib/row.ml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,9 @@ let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
312312
else if Sequence.for_all ~f:Either.is_second subsum then
313313
Some (extras ~keep_constr1:true, constr1)
314314
else None
315+
| Exact _, _ | _, Exact _ ->
316+
(* FIXME: NOT IMPLEMENTED YET *)
317+
None
315318

316319
let rec apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim)
317320
(constr : dim_constraint) (env : environment) : constraint_ list * dim_constraint =
@@ -387,6 +390,9 @@ let reduce_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dim
387390
else if d = 1 && Set.is_empty vars then constr
388391
else Total_elems { nominator; divided_by = Utils.Set_O.(divided_by + vars) }
389392
with Given_up -> Unconstrained)
393+
| Exact _ ->
394+
(* FIXME: NOT IMPLEMENTED YET *)
395+
constr
390396

391397
(* Inverts what [reduce_row_constraint] would do. *)
392398
let _lift_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dims : dim list) :
@@ -406,6 +412,9 @@ let _lift_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dims
406412
if d = 1 && Set.is_empty vars then constr
407413
else Total_elems { nominator = nominator * d; divided_by = Utils.Set_O.(divided_by - vars) }
408414
| Unconstrained -> Unconstrained
415+
| Exact _ ->
416+
(* FIXME: NOT IMPLEMENTED YET *)
417+
constr
409418

410419
let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : constraint_ list * _ =
411420
if is_unconstrained constr then ([], env)
@@ -504,6 +513,10 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons
504513
| { bcast = Row_var _; _ }, _ | _, Total_elems { nominator = _; divided_by = _ } ->
505514
if stored then (extras, env)
506515
else (Rows_constr { r = [r]; constr } :: extras, env (* Wait for more shape inference. *))
516+
| _, Exact _ ->
517+
(* FIXME: NOT IMPLEMENTED YET *)
518+
if stored then (extras, env)
519+
else (Rows_constr { r = [r]; constr } :: extras, env (* Wait for more shape inference. *))
507520

508521
let s_dim_one_in_entry v ~value (in_ : dim_entry) : _ * dim_entry =
509522
match in_ with
@@ -585,7 +598,12 @@ let s_row_one v ~value:{ dims = more_dims; bcast; id = _ } ~in_ =
585598
})
586599
| _ -> in_
587600

588-
let s_row_one_in_row_constr _v ~value:_ ~in_ = match in_ with Unconstrained | Total_elems _ -> in_
601+
let s_row_one_in_row_constr _v ~value:_ ~in_ =
602+
match in_ with
603+
| Unconstrained | Total_elems _ -> in_
604+
| Exact _ ->
605+
(* FIXME: NOT IMPLEMENTED YET *)
606+
in_
589607
let row_of_var v id = { dims = []; bcast = Row_var { v; beg_dims = [] }; id }
590608

591609
let s_row_one_in_entry (v : row_var) ~(value : row) ~(in_ : row_entry) :
@@ -1390,6 +1408,9 @@ let%debug5_sexp rec eliminate_row_constraint ~lub (r : row) (constr : row_constr
13901408
| ineq -> [ ineq ])
13911409
| _, [ v ], _ -> no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ]
13921410
| _ -> [])
1411+
| Exact _ ->
1412+
(* FIXME: NOT IMPLEMENTED YET *)
1413+
[]
13931414
| _ -> [])
13941415

13951416
let%debug5_sexp close_row_terminal ~(stage : stage) (env : environment)

lib/tensor.ml

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equa
206206

207207
let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
208208
?(compose_op = Shape.Pointwise_bin) ?(transpose_op = Shape.Pointwise_un)
209-
?fetch_op ?init_data ?init_data_spec ~op_asn ~grad_asn
209+
?init_data ?fetch_op ~op_asn ~grad_asn
210210
?(grad_spec = If_needed) make_shape (orig_ts : t list) : t =
211211
(* The code needs to be included in the order it was computed due to potential non-tree DAGs. *)
212212
let ordered_ts = List.dedup_and_sort orig_ts ~compare:(fun t1 t2 -> Int.ascending t1.id t2.id) in
@@ -222,13 +222,10 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
222222
|> Option.value ~default)
223223
in
224224
let terminal_logic () =
225-
match fetch_op, init_data, init_data_spec with
226-
| None, None, _ -> Shape.Terminal (`Fetch (Asgns.Constant 0.0))
227-
| Some fetch_op, _, _ -> Shape.Terminal (`Fetch fetch_op)
228-
| None, Some data, Some `Reshape -> Shape.Terminal (`Data (Asgns.Reshape data))
229-
| None, Some data, None -> Shape.Terminal (`Data (Asgns.Reshape data)) (* default *)
230-
| None, Some data, Some `Keep_shape_no_padding -> Shape.Terminal (`Data (Asgns.Keep_shape_no_padding data))
231-
| None, Some data, Some (`Padded (padding, padded_value)) -> Shape.Terminal (`Data (Asgns.Padded { data; padding; padded_value }))
225+
match fetch_op, init_data with
226+
| None, None -> Shape.Terminal (`Fetch (Asgns.Constant 0.0))
227+
| Some fetch_op, _ -> Shape.Terminal (`Fetch fetch_op)
228+
| None, Some init_data -> Shape.Terminal (`Data init_data)
232229
in
233230
let rec shape_logics = function
234231
| [] -> [ terminal_logic () ]
@@ -245,15 +242,13 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
245242
let projections = lazy (Shape.derive_projections @@ List.hd_exn local_shape_updates) in
246243
let padding = lazy (Shape.to_padding shape) in
247244
let v =
248-
match (init_data, init_data_spec) with
249-
| None, _ -> Tn.create ~default_prec ~id ~label ~dims ~padding ()
250-
| Some data, Some `Reshape ->
245+
match init_data with
246+
| None -> Tn.create ~default_prec ~id ~label ~dims ~padding ()
247+
| Some (Asgns.Reshape data) ->
251248
Tn.create_with_reshape ~id ~label ~dims ~padding ~from_padded:false ~base_ndarray:data ()
252-
| Some data, None -> (* default to Reshape *)
253-
Tn.create_with_reshape ~id ~label ~dims ~padding ~from_padded:false ~base_ndarray:data ()
254-
| Some data, Some `Keep_shape_no_padding ->
249+
| Some (Asgns.Keep_shape_no_padding data) ->
255250
Tn.create_from_padded ~id ~label ~ndarray:data ~padding:None ()
256-
| Some data, Some (`Padded (padding_spec, padded_value)) ->
251+
| Some (Asgns.Padded { data; padding = padding_spec; padded_value }) ->
257252
let padding = Some (padding_spec, padded_value) in
258253
Tn.create_from_padded ~id ~label ~ndarray:data ~padding ()
259254
in
@@ -359,7 +354,7 @@ let unop ~label ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 =
359354
op ~label ?compose_op:None ?transpose_op ~op_asn ~grad_asn ?grad_spec (Shape.make ()) [ t1 ]
360355

361356
let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
362-
?deduced ?init_data ?init_data_spec ?fetch_op () =
357+
?deduced ?init_data ?fetch_op () =
363358
let op_asn ~v ~projections =
364359
let open Asgns in
365360
let dims = lazy (Lazy.force projections).Idx.lhs_dims in
@@ -380,7 +375,7 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
380375
Shape.make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced ()
381376
in
382377
(* Note: fetch_op in op is used only for shape inference. *)
383-
op ~label ?compose_op:None ?transpose_op:None ?fetch_op ~op_asn ~grad_asn ~grad_spec make_shape []
378+
op ~label ?compose_op:None ?transpose_op:None ?init_data ?fetch_op ~op_asn ~grad_asn ~grad_spec make_shape []
384379

385380
let float_to_label v = Float.to_string v
386381

0 commit comments

Comments (0)