@@ -307,7 +307,7 @@ let dim_conjunction constr1 constr2 =
307307 | _ , Unconstrained_dim -> Some ([] , constr1)
308308 | At_least_dim d1 , At_least_dim d2 -> Some ([] , At_least_dim (Int. max d1 d2))
309309
310- let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
310+ let rec row_conjunction ?(id = phantom_row_id) constr1 constr2 =
311311 let elems_mismatch n1 n2 =
312312 raise
313313 @@ Shape_error
@@ -376,40 +376,77 @@ let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
376376 (* Neither coefficient divides the other - we can't make progress *)
377377 (* Keep both constraints - this will be resolved later *)
378378 None )
379+ | Total_elems { nominator = Strided_var { coeff = c1; var = v1 }; divided_by = vars1 }, constr2
380+ when Set. mem vars1 v1 ->
381+ row_conjunction ~id
382+ (Total_elems { nominator = Num_elems (Lazy. force c1); divided_by = Set. remove vars1 v1 })
383+ constr2
384+ | constr2, Total_elems { nominator = Strided_var { coeff = c1; var = v1 }; divided_by = vars1 }
385+ when Set. mem vars1 v1 ->
386+ row_conjunction ~id
387+ (Total_elems { nominator = Num_elems (Lazy. force c1); divided_by = Set. remove vars1 v1 })
388+ constr2
379389 | ( Total_elems { nominator = n1; divided_by = vars1 },
380390 Total_elems { nominator = n2; divided_by = vars2 } ) ->
381- let shared = Set. inter vars1 vars2 |> Set. to_list in
391+ let vars1_only = Set. diff vars1 vars2 |> Set. to_list in
392+ let vars2_only = Set. diff vars2 vars1 |> Set. to_list in
393+ let extras ~keep_constr1 ?(extra_var = [] ) ?(div_by = [] ) ~n1_val ~n2_val () =
394+ (* If we keep constr1, then it has fewer divided_by, i.e. vars1 ⊂ vars2. n1 / (product of
395+ vars1) = n2 / (product of vars2) Since vars1 ⊂ vars2, we have vars2 = vars1 ∪ vars2_only
396+ So: n1 / (product of vars1) = n2 / (product of vars1 × product of vars2_only) Thus: n1 =
397+ n2 / (product of vars2_only) Which means: product of vars2_only = n2 / n1 *)
398+ let diff_vars = extra_var @ if keep_constr1 then vars2_only else vars1_only in
399+ let quotient = if keep_constr1 then n2_val / n1_val else n1_val / n2_val in
400+ if quotient < = 0 then elems_mismatch n1 n2
401+ else if quotient = 1 then
402+ (* The difference variables must all be 1 *)
403+ List. map diff_vars ~f: (fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () })
404+ else if List. is_empty diff_vars then
405+ (* No difference in variables but different nominators - this is a mismatch *)
406+ elems_mismatch n1 n2
407+ else
408+ (* The product of difference variables equals the quotient *)
409+ let r = { dims = List. map diff_vars ~f: (fun v -> Var v); bcast = Broadcastable ; id } in
410+ [
411+ Rows_constr
412+ {
413+ r = [ r ];
414+ constr =
415+ Total_elems
416+ {
417+ nominator = Num_elems quotient;
418+ divided_by = Set. of_list (module Dim_var ) div_by;
419+ };
420+ };
421+ ]
422+ in
382423 let extras ~keep_constr1 =
383- (* If we keep constr1, then it has fewer divided_by, i.e. n1 > n2. *)
384424 match (n1, n2) with
385- | Num_elems n1_val , Num_elems n2_val ->
386- let nominator_val = if keep_constr1 then n1_val / n2_val else n2_val / n1_val in
387- if nominator_val < = 0 then elems_mismatch n1 n2
388- else if nominator_val = 1 then
389- List. map shared ~f: (fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () })
390- else if List. is_empty shared then []
391- else
392- let r = { dims = List. map shared ~f: (fun v -> Var v); bcast = Broadcastable ; id } in
393- [
394- Rows_constr
395- {
396- r = [ r ];
397- constr =
398- Total_elems
399- {
400- nominator = Num_elems nominator_val;
401- divided_by = Set. empty (module Dim_var );
402- };
403- };
404- ]
425+ | Num_elems n1_val , Num_elems n2_val -> extras ~keep_constr1 ~n1_val ~n2_val ()
426+ | Strided_var { coeff = c1; var = v1 }, Strided_var { coeff = c2; var = v2 }
427+ when equal_dim_var v1 v2 ->
428+ extras ~keep_constr1 ~n1_val: (Lazy. force c1) ~n2_val: (Lazy. force c2) ()
429+ | Strided_var { coeff = c1 ; var = v1 } , Num_elems n2_val when keep_constr1 ->
430+ (* v1 from the nominator joins v2_vars from the denominator. *)
431+ extras ~keep_constr1 ~extra_var: [ v1 ] ~n1_val: (Lazy. force c1) ~n2_val ()
432+ | Num_elems n1_val , Strided_var { coeff = c2 ; var = v2 } when not keep_constr1 ->
433+ extras ~keep_constr1 ~extra_var: [ v2 ] ~n1_val ~n2_val: (Lazy. force c2) ()
434+ | Strided_var { coeff = c1; var = v1 }, Strided_var { coeff = c2; var = v2 }
435+ when keep_constr1 ->
436+ (* v1 from the nominator joins v2_vars from the denominator. *)
437+ extras ~keep_constr1 ~extra_var: [ v1 ] ~div_by: [ v2 ] ~n1_val: (Lazy. force c1)
438+ ~n2_val: (Lazy. force c2) ()
439+ | Strided_var { coeff = c1; var = v1 }, Strided_var { coeff = c2; var = v2 }
440+ when not keep_constr1 ->
441+ (* v2 from the nominator joins v1_vars from the denominator. *)
442+ extras ~keep_constr1 ~extra_var: [ v2 ] ~div_by: [ v1 ] ~n1_val: (Lazy. force c1)
443+ ~n2_val: (Lazy. force c2) ()
405444 | _ ->
406- (* TODO: Handle Strided_var cases - for now, return empty list *)
445+ (* NOTE: being leaky here to not overcomplicate the code *)
407446 []
408447 in
409- let subsum = Set. symmetric_diff vars1 vars2 in
410- if Sequence. for_all ~f: Either. is_first subsum then Some (extras ~keep_constr1: false , constr2)
411- else if Sequence. for_all ~f: Either. is_second subsum then
412- Some (extras ~keep_constr1: true , constr1)
448+ if List. is_empty vars1_only then Some (extras ~keep_constr1: false , constr2)
449+ else if List. is_empty vars2_only then Some (extras ~keep_constr1: true , constr1)
413450 else None
414451 | Exact dims1 , Exact dims2 ->
415452 if List. length dims1 <> List. length dims2 then
0 commit comments