Skip to content

Commit f529aee

Browse files
committed
row.ml: fix missing divisibility guards, improve At_least_dim for Conv_input
1 parent 5420de6 commit f529aee

File tree

1 file changed

+28
-18
lines changed

1 file changed

+28
-18
lines changed

lib/row.ml

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -456,17 +456,24 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
456456
let extra_var = Option.to_list @@ if keep_constr1 then v1 else v2 in
457457
let num_var = if keep_constr1 then v2 else v1 in
458458
let diff_vars = extra_var @ if keep_constr1 then vars2_only else vars1_only in
459-
if List.is_empty diff_vars then (
460-
(* No difference in variables but different numerators - this is a mismatch *)
461-
match num_var with
462-
| None -> elems_mismatch n1 n2
463-
| Some v ->
464-
let d = if keep_constr1 then n1_val / n2_val else n2_val / n1_val in
465-
if d <= 0 then elems_mismatch n1 n2;
466-
[ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ])
459+
let n_big = if keep_constr1 then n2_val else n1_val in
460+
let n_small = if keep_constr1 then n1_val else n2_val in
461+
if n_small = 0 then
462+
raise @@ Shape_error ([%string "Division by zero in constraint solving"], [])
463+
else if n_big % n_small <> 0 then
464+
raise
465+
@@ Shape_error
466+
([%string "Total_elems constraint: %{n_big#Int} not divisible by %{n_small#Int}"], [])
467467
else
468-
let quotient = if keep_constr1 then n2_val / n1_val else n1_val / n2_val in
469-
if quotient <= 0 && Option.is_none num_var then elems_mismatch n1 n2
468+
let quotient = n_big / n_small in
469+
if List.is_empty diff_vars then (
470+
(* No difference in variables but different numerators - this is a mismatch *)
471+
match num_var with
472+
| None -> elems_mismatch n1 n2
473+
| Some v ->
474+
if quotient <= 0 then elems_mismatch n1 n2;
475+
[ Dim_eq { d1 = Var v; d2 = get_dim ~d:quotient () } ])
476+
else if quotient <= 0 && Option.is_none num_var then elems_mismatch n1 n2
470477
else if quotient = 1 && Option.is_none num_var then
471478
(* The difference variables must all be 1 *)
472479
List.map diff_vars ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () })
@@ -477,10 +484,8 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
477484
match num_var with
478485
| None -> Num_elems quotient
479486
| Some var ->
480-
let value = if keep_constr1 then n2_val else n1_val in
481-
let coeff = Utils.{ value = `Value value; unique_id = Int.to_string value } in
482-
let denom = if keep_constr1 then n1_val else n2_val in
483-
Strided_var { coeff; var; denom }
487+
let coeff = Utils.{ value = `Value n_big; unique_id = Int.to_string n_big } in
488+
Strided_var { coeff; var; denom = n_small }
484489
in
485490
[ Rows_constr { r = [ r ]; constr = Total_elems { numerator; divided_by = [] } } ]
486491
in
@@ -671,12 +676,17 @@ let rec apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim)
671676
[ Dim_mismatch [ dim ] ] )
672677
else ([], constr)
673678
| Conv_input { stride; output; dilation; kernel }, At_least_dim d_min -> (
679+
let quotient = if d_min % stride = 0 then d_min / stride else (d_min / stride) + 1 in
674680
match kernel with
675681
| Dim { d = d_k; _ } when not !use_padding ->
676-
apply_dim_constraint ~source ~stage output
677-
(At_least_dim ((d_min / stride) + (dilation * d_k)))
678-
env
679-
| _ -> apply_dim_constraint ~source ~stage output (At_least_dim (d_min / stride)) env)
682+
let d_min = d_min - (dilation * d_k) in
683+
if d_min <= 0 then ([], Unconstrained_dim)
684+
else
685+
let quotient = if d_min % stride = 0 then d_min / stride else (d_min / stride) + 1 in
686+
apply_dim_constraint ~source ~stage output
687+
(At_least_dim (quotient + (dilation * d_k)))
688+
env
689+
| _ -> apply_dim_constraint ~source ~stage output (At_least_dim quotient) env)
680690
| Var v, _ -> (
681691
match Map.find env.dim_env v with
682692
| None -> ([], constr)

0 commit comments

Comments
 (0)