Skip to content

Commit 19637e0

Browse files
committed
Fill-in missing cases in apply_rows_constraint
1 parent e4a69d2 commit 19637e0

File tree

1 file changed

+63
-34
lines changed

1 file changed

+63
-34
lines changed

lib/row.ml

Lines changed: 63 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -760,29 +760,29 @@ let _lift_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dims
760760
| Unconstrained -> Unconstrained
761761
| Exact exact_dims -> Exact (beg_dims @ exact_dims @ dims)
762762

763-
(** Helper function to convert a list of rows to either a single row or information about multiple
764-
row variables. Returns Either.First with a single row if there are zero or one row variables.
765-
Returns Either.Second with (all_dims, row_vars) if there are multiple row variables, where
766-
all_dims is a concatenation of all dims and beg_dims in proper order, and row_vars is a list
767-
of (row_var * row_id) pairs. *)
763+
(** Helper function to convert a list of rows to either a single row or information about multiple
764+
row variables. Returns Either.First with a single row if there are zero or one row variables.
765+
Returns Either.Second with (all_dims, row_vars) if there are multiple row variables, where
766+
all_dims is a concatenation of all dims and beg_dims in proper order, and row_vars is a list of
767+
(row_var * row_id) pairs. *)
768768
let rows_to_row_or_vars (rows : row list) : (row, dim list * (row_var * row_id) list) Either.t =
769769
let rec collect_info before_dims row_vars rows =
770770
match rows with
771771
| [] -> (List.rev before_dims, List.rev row_vars)
772772
| row :: remaining_rows -> (
773773
match row.bcast with
774-
| Broadcastable ->
774+
| Broadcastable ->
775775
(* Regular row, add its dims and continue *)
776776
collect_info (List.rev_append row.dims before_dims) row_vars remaining_rows
777-
| Row_var { v; beg_dims } ->
777+
| Row_var { v; beg_dims } ->
778778
(* Row variable - collect it and continue *)
779779
let new_before_dims = List.rev_append row.dims (List.rev_append beg_dims before_dims) in
780780
let new_row_vars = (v, row.id) :: row_vars in
781781
collect_info new_before_dims new_row_vars remaining_rows)
782782
in
783783
let all_dims, row_vars = collect_info [] [] rows in
784784
match row_vars with
785-
| [] ->
785+
| [] ->
786786
(* No row variables found *)
787787
let first_id = match rows with [] -> phantom_row_id | first_row :: _ -> first_row.id in
788788
Either.First { dims = all_dims; bcast = Broadcastable; id = first_id }
@@ -793,18 +793,17 @@ let rows_to_row_or_vars (rows : row list) : (row, dim list * (row_var * row_id)
793793
| [] -> failwith "rows_to_row_or_vars: single row variable not found during reconstruction"
794794
| row :: remaining_rows -> (
795795
match row.bcast with
796-
| Broadcastable ->
796+
| Broadcastable ->
797797
reconstruct_single_var (List.rev_append row.dims before_dims) remaining_rows
798798
| Row_var { v = found_v; beg_dims } when equal_row_var found_v v ->
799799
let new_beg_dims = List.rev_append before_dims beg_dims in
800800
let after_dims = List.concat_map remaining_rows ~f:(fun r -> r.dims) in
801801
let new_dims = row.dims @ after_dims in
802802
{ dims = new_dims; bcast = Row_var { v; beg_dims = new_beg_dims }; id }
803-
| Row_var _ ->
804-
reconstruct_single_var before_dims remaining_rows)
803+
| Row_var _ -> reconstruct_single_var before_dims remaining_rows)
805804
in
806805
Either.First (reconstruct_single_var [] rows)
807-
| _ ->
806+
| _ ->
808807
(* Multiple row variables *)
809808
Either.Second (all_dims, row_vars)
810809

@@ -824,9 +823,55 @@ let rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint)
824823
: constraint_ list * environment =
825824
match rows_to_row_or_vars rows with
826825
| Either.First single_row -> apply_row_constraint ~stage single_row constr env
827-
| Either.Second (_all_dims, _row_vars) -> (
826+
| Either.Second (all_dims, row_vars) -> (
828827
match constr with
828+
| Exact dims when List.length dims < List.length all_dims ->
829+
(* Case 1: Exact dims has fewer axes than all_dims - raise mismatch *)
830+
raise
831+
@@ Shape_error
832+
("apply_rows_constraint: Exact constraint has too few axes", [ Row_mismatch rows ])
833+
| Exact dims when List.length dims = List.length all_dims ->
834+
(* Case 2: Exact dims has same length as all_dims - derive pairwise equations *)
835+
let dim_eqs = List.map2_exn dims all_dims ~f:(fun d1 d2 -> Dim_eq { d1; d2 }) in
836+
let row_eqs =
837+
List.map row_vars ~f:(fun (v, id) ->
838+
Row_eq { r1 = row_of_var v id; r2 = { dims = []; bcast = Broadcastable; id } })
839+
in
840+
(dim_eqs @ row_eqs, env)
841+
| Total_elems { numerator = Num_elems n; divided_by } -> (
842+
(* Case 3: Total_elems with known numerator *)
843+
match collect_factors all_dims with
844+
| None -> ([ Rows_constr { r = rows; constr } ], env) (* Give up on complex cases *)
845+
| Some (known_product, product_vars) ->
846+
(* Move divided_by variables to the other side by combining with product_vars *)
847+
let all_product_vars = product_vars @ divided_by in
848+
if n % known_product <> 0 then
849+
raise
850+
@@ Shape_error
851+
( [%string
852+
"Total_elems constraint: %{n#Int} not divisible by known product \
853+
%{known_product#Int}"],
854+
[] )
855+
else if n = known_product then
856+
(* Equate all product vars to d=1 and add Total_elems 1 for each row var *)
857+
let var_eqs =
858+
List.map all_product_vars ~f:(fun v ->
859+
Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () })
860+
in
861+
let row_constrs =
862+
List.map row_vars ~f:(fun (v, id) ->
863+
Rows_constr
864+
{
865+
r = [ row_of_var v id ];
866+
constr = Total_elems { numerator = Num_elems 1; divided_by = [] };
867+
})
868+
in
869+
(var_eqs @ row_constrs, env)
870+
else
871+
(* Cannot deduce no_further_axes, return unchanged *)
872+
([ Rows_constr { r = rows; constr } ], env))
829873
| Exact [ single_dim ] -> (
874+
(* Keep existing logic for single_dim case *)
830875
match List.rev rows with
831876
| { dims = []; bcast = Broadcastable; id = _ } :: more_rows ->
832877
apply_rows_constraint ~stage (List.rev more_rows) constr env
@@ -838,30 +883,16 @@ let rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint)
838883
env )
839884
| { dims = []; bcast = Row_var { v; beg_dims = [] }; id = { kind = `Output; _ } as id }
840885
:: more_rows ->
841-
(* TODO: we could check if there is a non-empty output row later on. *)
842886
( Row_eq
843887
{
844888
r1 = row_of_var v id;
845889
r2 = { dims = [ single_dim ]; bcast = Broadcastable; id };
846890
}
847891
:: List.concat_map ~f:check_empty_row more_rows,
848892
env )
849-
| {
850-
dims = [ single_other ];
851-
bcast = Row_var { v; beg_dims = [] };
852-
id = { kind = `Output; _ } as id;
853-
}
854-
:: more_rows
855-
| {
856-
dims = [];
857-
bcast = Row_var { v; beg_dims = [ single_other ] };
858-
id = { kind = `Output; _ } as id;
859-
}
860-
:: more_rows ->
861-
( Dim_eq { d1 = single_other; d2 = single_dim }
862-
:: Row_eq { r1 = row_of_var v id; r2 = { dims = []; bcast = Broadcastable; id } }
863-
:: List.concat_map ~f:check_empty_row more_rows,
864-
env )
893+
| { dims = _; bcast = Row_var { v = _; beg_dims = _ }; id = { kind = `Output; _ } } :: _
894+
->
895+
assert false
865896
| _ -> raise @@ Shape_error ("apply_rows_constraint: shape too big", [ Row_mismatch rows ])
866897
)
867898
| _ -> ([ Rows_constr { r = rows; constr } ], env))
@@ -932,10 +963,8 @@ and apply_row_constraint ~stage (r : row) (constr : row_constraint) env : constr
932963
| { dims; bcast = Broadcastable; _ }, Total_elems { numerator; divided_by }
933964
when List.length divided_by <= 1 -> (
934965
try
935-
let d, vars =
936-
match collect_factors dims with
937-
| Some (d, vars) -> (d, vars)
938-
| None -> raise Given_up
966+
let d, vars =
967+
match collect_factors dims with Some (d, vars) -> (d, vars) | None -> raise Given_up
939968
in
940969
let numerator = total_elems_divide numerator d in
941970
if total_elems_known_zero numerator then

0 commit comments

Comments
 (0)