Skip to content

Commit ad2fce3

Browse files
committed
Incremental progress on handling Exact row constraints: mostly row_conjunction
1 parent ec7bf34 commit ad2fce3

File tree

2 files changed

+74
-17
lines changed

2 files changed

+74
-17
lines changed

lib/row.ml

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -304,17 +304,72 @@ let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
304304
let r = { dims = List.map shared ~f:(fun v -> Var v); bcast = Broadcastable; id } in
305305
[
306306
Rows_constr
307-
{ r = [r]; constr = Total_elems { nominator; divided_by = Set.empty (module Dim_var) } };
307+
{
308+
r = [ r ];
309+
constr = Total_elems { nominator; divided_by = Set.empty (module Dim_var) };
310+
};
308311
]
309312
in
310313
let subsum = Set.symmetric_diff vars1 vars2 in
311314
if Sequence.for_all ~f:Either.is_first subsum then Some (extras ~keep_constr1:false, constr2)
312315
else if Sequence.for_all ~f:Either.is_second subsum then
313316
Some (extras ~keep_constr1:true, constr1)
314317
else None
315-
| Exact _, _ | _, Exact _ ->
316-
(* FIXME: NOT IMPLEMENTED YET *)
317-
None
318+
| Exact dims1, Exact dims2 ->
319+
if List.length dims1 <> List.length dims2 then
320+
raise @@ Shape_error ("Exact constraint length mismatch", [])
321+
else
322+
let eqs = List.map2_exn dims1 dims2 ~f:(fun d1 d2 -> Dim_eq { d1; d2 }) in
323+
Some (eqs, constr1)
324+
| Total_elems { nominator; divided_by }, Exact dims
325+
| Exact dims, Total_elems { nominator; divided_by } -> (
326+
(* Simple collect_factors logic - handle only basic Dim and Var cases *)
327+
let rec collect_dim_factors (ds, vars) = function
328+
| Dim { d; _ } -> Some (d :: ds, vars)
329+
| Var v -> Some (ds, v :: vars)
330+
| Conv_input _ -> None (* Too complex, give up *)
331+
in
332+
match
333+
List.fold dims
334+
~init:(Some ([], []))
335+
~f:(fun acc dim ->
336+
match acc with None -> None | Some (ds, vars) -> collect_dim_factors (ds, vars) dim)
337+
with
338+
| None -> None (* Give up on complex cases *)
339+
| Some (ds, vars) ->
340+
let known_product = List.fold ds ~init:1 ~f:( * ) in
341+
if nominator <= 0 then
342+
raise @@ Shape_error ([%string "Invalid Total_elems nominator: %{nominator#Int}"], [])
343+
else if known_product = 0 then
344+
raise @@ Shape_error ("Exact constraint has zero dimension", [])
345+
else if nominator % known_product <> 0 then
346+
raise
347+
@@ Shape_error
348+
( [%string
349+
"Total_elems nominator %{nominator#Int} not divisible by Exact dimensions \
350+
product %{known_product#Int}"],
351+
[] )
352+
else
353+
let reminder = nominator / known_product in
354+
if reminder = 1 then
355+
(* reminder is 1: equate all variables on both sides to 1 *)
356+
let divided_by_eqs =
357+
Set.to_list divided_by
358+
|> List.map ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () })
359+
in
360+
let exact_vars_eqs =
361+
List.map vars ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () })
362+
in
363+
Some (divided_by_eqs @ exact_vars_eqs, Exact dims)
364+
else if Set.is_empty divided_by && List.length vars = 1 && reminder > 0 then
365+
(* divided_by is empty and there is only one dim variable in Exact dims *)
366+
let v = List.hd_exn vars in
367+
Some ([ Dim_eq { d1 = Var v; d2 = get_dim ~d:reminder () } ], Exact dims)
368+
else if List.is_empty vars && Set.length divided_by = 1 && reminder > 0 then
369+
(* Exact dims contain only known dimensions and divided_by has exactly one variable *)
370+
let v = Set.choose_exn divided_by in
371+
Some ([ Dim_eq { d1 = Var v; d2 = get_dim ~d:reminder () } ], Exact dims)
372+
else None (* Cannot handle this case *))
318373

319374
let rec apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim)
320375
(constr : dim_constraint) (env : environment) : constraint_ list * dim_constraint =
@@ -506,17 +561,19 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons
506561
(* | v :: _, [] | [], v :: _ when (is_stage4_up stage) -> (Dim_eq { d1 = Var v; d2 = get_dim
507562
~d:nominator () } :: extras, env) *)
508563
| _ :: _, _ when stored -> (extras, env)
509-
| _, _ -> (Rows_constr { r = [r]; constr } :: extras, env (* Wait for more shape inference. *))
564+
| _, _ ->
565+
(Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *))
510566
with Given_up ->
511567
if stored then (extras, env)
512-
else (Rows_constr { r = [r]; constr } :: extras, env (* Wait for more shape inference. *)))
568+
else
569+
(Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *)))
513570
| { bcast = Row_var _; _ }, _ | _, Total_elems { nominator = _; divided_by = _ } ->
514571
if stored then (extras, env)
515-
else (Rows_constr { r = [r]; constr } :: extras, env (* Wait for more shape inference. *))
572+
else (Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *))
516573
| _, Exact _ ->
517574
(* FIXME: NOT IMPLEMENTED YET *)
518575
if stored then (extras, env)
519-
else (Rows_constr { r = [r]; constr } :: extras, env (* Wait for more shape inference. *))
576+
else (Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *))
520577

521578
let s_dim_one_in_entry v ~value (in_ : dim_entry) : _ * dim_entry =
522579
match in_ with
@@ -598,12 +655,13 @@ let s_row_one v ~value:{ dims = more_dims; bcast; id = _ } ~in_ =
598655
})
599656
| _ -> in_
600657

601-
let s_row_one_in_row_constr _v ~value:_ ~in_ =
602-
match in_ with
658+
let s_row_one_in_row_constr _v ~value:_ ~in_ =
659+
match in_ with
603660
| Unconstrained | Total_elems _ -> in_
604-
| Exact _ ->
661+
| Exact _ ->
605662
(* FIXME: NOT IMPLEMENTED YET *)
606663
in_
664+
607665
let row_of_var v id = { dims = []; bcast = Row_var { v; beg_dims = [] }; id }
608666

609667
let s_row_one_in_entry (v : row_var) ~(value : row) ~(in_ : row_entry) :
@@ -1408,9 +1466,8 @@ let%debug5_sexp rec eliminate_row_constraint ~lub (r : row) (constr : row_constr
14081466
| ineq -> [ ineq ])
14091467
| _, [ v ], _ -> no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ]
14101468
| _ -> [])
1411-
| Exact _ ->
1412-
(* FIXME: NOT IMPLEMENTED YET *)
1413-
[]
1469+
| Exact dims ->
1470+
[ Row_eq { r1; r2 = { dims; bcast = Broadcastable; id } } ]
14141471
| _ -> [])
14151472

14161473
let%debug5_sexp close_row_terminal ~(stage : stage) (env : environment)
@@ -1523,8 +1580,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
15231580
let r = subst_row env (List.hd_exn rows) in
15241581
let more_ineqs, env = apply_row_constraint ~stage r constr env in
15251582
(more_ineqs @ ineqs, env)
1526-
else
1527-
(ineqs, env)
1583+
else (ineqs, env)
15281584
| Terminal_dim d ->
15291585
let more_ineqs = close_dim_terminal ~stage env @@ subst_dim env d in
15301586
(more_ineqs @ ineqs, env)

lib/row.mli

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ type constraint_ =
114114
| Row_ineq of { cur : t; subr : t }
115115
| Dim_constr of { d : dim; constr : dim_constraint }
116116
| Rows_constr of { r : t list; constr : row_constraint }
117-
(** The constraint applies to the concatenation of the rows. *)
117+
(** The constraint applies to the concatenation of the rows. Note: broadcasting does not
118+
affect the constraint (i.e. there is no "subtyping", it resembles Row_eq). *)
118119
| Terminal_dim of dim
119120
| Terminal_row of t
120121
[@@deriving compare, equal, sexp, variants]

0 commit comments

Comments
 (0)