Skip to content

Commit ba1de07

Browse files
committed
Flesh out row_conjunction case Total_elems vs Total_elems
1 parent b520bd3 commit ba1de07

File tree

1 file changed

+65
-28
lines changed

1 file changed

+65
-28
lines changed

lib/row.ml

Lines changed: 65 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ let dim_conjunction constr1 constr2 =
307307
| _, Unconstrained_dim -> Some ([], constr1)
308308
| At_least_dim d1, At_least_dim d2 -> Some ([], At_least_dim (Int.max d1 d2))
309309

310-
let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
310+
let rec row_conjunction ?(id = phantom_row_id) constr1 constr2 =
311311
let elems_mismatch n1 n2 =
312312
raise
313313
@@ Shape_error
@@ -376,40 +376,77 @@ let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
376376
(* Neither coefficient divides the other - we can't make progress *)
377377
(* Keep both constraints - this will be resolved later *)
378378
None)
379+
| Total_elems { nominator = Strided_var { coeff = c1; var = v1 }; divided_by = vars1 }, constr2
380+
when Set.mem vars1 v1 ->
381+
row_conjunction ~id
382+
(Total_elems { nominator = Num_elems (Lazy.force c1); divided_by = Set.remove vars1 v1 })
383+
constr2
384+
| constr2, Total_elems { nominator = Strided_var { coeff = c1; var = v1 }; divided_by = vars1 }
385+
when Set.mem vars1 v1 ->
386+
row_conjunction ~id
387+
(Total_elems { nominator = Num_elems (Lazy.force c1); divided_by = Set.remove vars1 v1 })
388+
constr2
379389
| ( Total_elems { nominator = n1; divided_by = vars1 },
380390
Total_elems { nominator = n2; divided_by = vars2 } ) ->
381-
let shared = Set.inter vars1 vars2 |> Set.to_list in
391+
let vars1_only = Set.diff vars1 vars2 |> Set.to_list in
392+
let vars2_only = Set.diff vars2 vars1 |> Set.to_list in
393+
let extras ~keep_constr1 ?(extra_var = []) ?(div_by = []) ~n1_val ~n2_val () =
394+
(* If we keep constr1, then it has fewer divided_by, i.e. vars1 ⊂ vars2. n1 / (product of
395+
vars1) = n2 / (product of vars2) Since vars1 ⊂ vars2, we have vars2 = vars1 ∪ vars2_only
396+
So: n1 / (product of vars1) = n2 / (product of vars1 × product of vars2_only) Thus: n1 =
397+
n2 / (product of vars2_only) Which means: product of vars2_only = n2 / n1 *)
398+
let diff_vars = extra_var @ if keep_constr1 then vars2_only else vars1_only in
399+
let quotient = if keep_constr1 then n2_val / n1_val else n1_val / n2_val in
400+
if quotient <= 0 then elems_mismatch n1 n2
401+
else if quotient = 1 then
402+
(* The difference variables must all be 1 *)
403+
List.map diff_vars ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () })
404+
else if List.is_empty diff_vars then
405+
(* No difference in variables but different nominators - this is a mismatch *)
406+
elems_mismatch n1 n2
407+
else
408+
(* The product of difference variables equals the quotient *)
409+
let r = { dims = List.map diff_vars ~f:(fun v -> Var v); bcast = Broadcastable; id } in
410+
[
411+
Rows_constr
412+
{
413+
r = [ r ];
414+
constr =
415+
Total_elems
416+
{
417+
nominator = Num_elems quotient;
418+
divided_by = Set.of_list (module Dim_var) div_by;
419+
};
420+
};
421+
]
422+
in
382423
let extras ~keep_constr1 =
383-
(* If we keep constr1, then it has fewer divided_by, i.e. n1 > n2. *)
384424
match (n1, n2) with
385-
| Num_elems n1_val, Num_elems n2_val ->
386-
let nominator_val = if keep_constr1 then n1_val / n2_val else n2_val / n1_val in
387-
if nominator_val <= 0 then elems_mismatch n1 n2
388-
else if nominator_val = 1 then
389-
List.map shared ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () })
390-
else if List.is_empty shared then []
391-
else
392-
let r = { dims = List.map shared ~f:(fun v -> Var v); bcast = Broadcastable; id } in
393-
[
394-
Rows_constr
395-
{
396-
r = [ r ];
397-
constr =
398-
Total_elems
399-
{
400-
nominator = Num_elems nominator_val;
401-
divided_by = Set.empty (module Dim_var);
402-
};
403-
};
404-
]
425+
| Num_elems n1_val, Num_elems n2_val -> extras ~keep_constr1 ~n1_val ~n2_val ()
426+
| Strided_var { coeff = c1; var = v1 }, Strided_var { coeff = c2; var = v2 }
427+
when equal_dim_var v1 v2 ->
428+
extras ~keep_constr1 ~n1_val:(Lazy.force c1) ~n2_val:(Lazy.force c2) ()
429+
| Strided_var { coeff = c1; var = v1 }, Num_elems n2_val when keep_constr1 ->
430+
(* v1 from the nominator joins v2_vars from the denominator. *)
431+
extras ~keep_constr1 ~extra_var:[ v1 ] ~n1_val:(Lazy.force c1) ~n2_val ()
432+
| Num_elems n1_val, Strided_var { coeff = c2; var = v2 } when not keep_constr1 ->
433+
extras ~keep_constr1 ~extra_var:[ v2 ] ~n1_val ~n2_val:(Lazy.force c2) ()
434+
| Strided_var { coeff = c1; var = v1 }, Strided_var { coeff = c2; var = v2 }
435+
when keep_constr1 ->
436+
(* v1 from the nominator joins v2_vars from the denominator. *)
437+
extras ~keep_constr1 ~extra_var:[ v1 ] ~div_by:[ v2 ] ~n1_val:(Lazy.force c1)
438+
~n2_val:(Lazy.force c2) ()
439+
| Strided_var { coeff = c1; var = v1 }, Strided_var { coeff = c2; var = v2 }
440+
when not keep_constr1 ->
441+
(* v2 from the nominator joins v1_vars from the denominator. *)
442+
extras ~keep_constr1 ~extra_var:[ v2 ] ~div_by:[ v1 ] ~n1_val:(Lazy.force c1)
443+
~n2_val:(Lazy.force c2) ()
405444
| _ ->
406-
(* TODO: Handle Strided_var cases - for now, return empty list *)
445+
(* NOTE: being leaky here to not overcomplicate the code *)
407446
[]
408447
in
409-
let subsum = Set.symmetric_diff vars1 vars2 in
410-
if Sequence.for_all ~f:Either.is_first subsum then Some (extras ~keep_constr1:false, constr2)
411-
else if Sequence.for_all ~f:Either.is_second subsum then
412-
Some (extras ~keep_constr1:true, constr1)
448+
if List.is_empty vars1_only then Some (extras ~keep_constr1:false, constr2)
449+
else if List.is_empty vars2_only then Some (extras ~keep_constr1:true, constr1)
413450
else None
414451
| Exact dims1, Exact dims2 ->
415452
if List.length dims1 <> List.length dims2 then

0 commit comments

Comments
 (0)