@@ -456,17 +456,24 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
456456 let extra_var = Option. to_list @@ if keep_constr1 then v1 else v2 in
457457 let num_var = if keep_constr1 then v2 else v1 in
458458 let diff_vars = extra_var @ if keep_constr1 then vars2_only else vars1_only in
459- if List. is_empty diff_vars then (
460- (* No difference in variables but different numerators - this is a mismatch *)
461- match num_var with
462- | None -> elems_mismatch n1 n2
463- | Some v ->
464- let d = if keep_constr1 then n1_val / n2_val else n2_val / n1_val in
465- if d < = 0 then elems_mismatch n1 n2;
466- [ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ])
459+ let n_big = if keep_constr1 then n2_val else n1_val in
460+ let n_small = if keep_constr1 then n1_val else n2_val in
461+ if n_small = 0 then
462+ raise @@ Shape_error ([ % string " Division by zero in constraint solving " ], [] )
463+ else if n_big % n_small <> 0 then
464+ raise
465+ @@ Shape_error
466+ ([ % string " Total_elems constraint: %{n_big#Int} not divisible by %{n_small#Int} " ], [ ] )
467467 else
468- let quotient = if keep_constr1 then n2_val / n1_val else n1_val / n2_val in
469- if quotient < = 0 && Option. is_none num_var then elems_mismatch n1 n2
468+ let quotient = n_big / n_small in
469+ if List. is_empty diff_vars then (
470+ (* No difference in variables but different numerators - this is a mismatch *)
471+ match num_var with
472+ | None -> elems_mismatch n1 n2
473+ | Some v ->
474+ if quotient < = 0 then elems_mismatch n1 n2;
475+ [ Dim_eq { d1 = Var v; d2 = get_dim ~d: quotient () } ])
476+ else if quotient < = 0 && Option. is_none num_var then elems_mismatch n1 n2
470477 else if quotient = 1 && Option. is_none num_var then
471478 (* The difference variables must all be 1 *)
472479 List. map diff_vars ~f: (fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () })
@@ -477,10 +484,8 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
477484 match num_var with
478485 | None -> Num_elems quotient
479486 | Some var ->
480- let value = if keep_constr1 then n2_val else n1_val in
481- let coeff = Utils. { value = `Value value; unique_id = Int. to_string value } in
482- let denom = if keep_constr1 then n1_val else n2_val in
483- Strided_var { coeff; var; denom }
487+ let coeff = Utils. { value = `Value n_big; unique_id = Int. to_string n_big } in
488+ Strided_var { coeff; var; denom = n_small }
484489 in
485490 [ Rows_constr { r = [ r ]; constr = Total_elems { numerator; divided_by = [] } } ]
486491 in
@@ -671,12 +676,17 @@ let rec apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim)
671676 [ Dim_mismatch [ dim ] ] )
672677 else ([] , constr)
673678 | Conv_input { stride; output; dilation; kernel } , At_least_dim d_min -> (
679+ let quotient = if d_min % stride = 0 then d_min / stride else (d_min / stride) + 1 in
674680 match kernel with
675681 | Dim { d = d_k ; _ } when not ! use_padding ->
676- apply_dim_constraint ~source ~stage output
677- (At_least_dim ((d_min / stride) + (dilation * d_k)))
678- env
679- | _ -> apply_dim_constraint ~source ~stage output (At_least_dim (d_min / stride)) env)
682+ let d_min = d_min - (dilation * d_k) in
683+ if d_min < = 0 then ([] , Unconstrained_dim )
684+ else
685+ let quotient = if d_min % stride = 0 then d_min / stride else (d_min / stride) + 1 in
686+ apply_dim_constraint ~source ~stage output
687+ (At_least_dim (quotient + (dilation * d_k)))
688+ env
689+ | _ -> apply_dim_constraint ~source ~stage output (At_least_dim quotient) env)
680690 | Var v , _ -> (
681691 match Map. find env.dim_env v with
682692 | None -> ([] , constr)
0 commit comments