@@ -760,29 +760,29 @@ let _lift_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dims
760760 | Unconstrained -> Unconstrained
761761 | Exact exact_dims -> Exact (beg_dims @ exact_dims @ dims)
762762
763- (* * Helper function to convert a list of rows to either a single row or information about multiple
764- row variables. Returns Either.First with a single row if there are zero or one row variables.
765- Returns Either.Second with (all_dims, row_vars) if there are multiple row variables, where
766- all_dims is a concatenation of all dims and beg_dims in proper order, and row_vars is a list
767- of (row_var * row_id) pairs. *)
763+ (* * Helper function to convert a list of rows to either a single row or information about multiple
764+ row variables. Returns Either.First with a single row if there are zero or one row variables.
765+ Returns Either.Second with (all_dims, row_vars) if there are multiple row variables, where
766+ all_dims is a concatenation of all dims and beg_dims in proper order, and row_vars is a list of
767+ (row_var * row_id) pairs. *)
768768let rows_to_row_or_vars (rows : row list ) : (row, dim list * (row_var * row_id) list) Either.t =
769769 let rec collect_info before_dims row_vars rows =
770770 match rows with
771771 | [] -> (List. rev before_dims, List. rev row_vars)
772772 | row :: remaining_rows -> (
773773 match row.bcast with
774- | Broadcastable ->
774+ | Broadcastable ->
775775 (* Regular row, add its dims and continue *)
776776 collect_info (List. rev_append row.dims before_dims) row_vars remaining_rows
777- | Row_var { v; beg_dims } ->
777+ | Row_var { v; beg_dims } ->
778778 (* Row variable - collect it and continue *)
779779 let new_before_dims = List. rev_append row.dims (List. rev_append beg_dims before_dims) in
780780 let new_row_vars = (v, row.id) :: row_vars in
781781 collect_info new_before_dims new_row_vars remaining_rows)
782782 in
783783 let all_dims, row_vars = collect_info [] [] rows in
784784 match row_vars with
785- | [] ->
785+ | [] ->
786786 (* No row variables found *)
787787 let first_id = match rows with [] -> phantom_row_id | first_row :: _ -> first_row.id in
788788 Either. First { dims = all_dims; bcast = Broadcastable ; id = first_id }
@@ -793,18 +793,17 @@ let rows_to_row_or_vars (rows : row list) : (row, dim list * (row_var * row_id)
793793 | [] -> failwith " rows_to_row_or_vars: single row variable not found during reconstruction"
794794 | row :: remaining_rows -> (
795795 match row.bcast with
796- | Broadcastable ->
796+ | Broadcastable ->
797797 reconstruct_single_var (List. rev_append row.dims before_dims) remaining_rows
798798 | Row_var { v = found_v ; beg_dims } when equal_row_var found_v v ->
799799 let new_beg_dims = List. rev_append before_dims beg_dims in
800800 let after_dims = List. concat_map remaining_rows ~f: (fun r -> r.dims) in
801801 let new_dims = row.dims @ after_dims in
802802 { dims = new_dims; bcast = Row_var { v; beg_dims = new_beg_dims }; id }
803- | Row_var _ ->
804- reconstruct_single_var before_dims remaining_rows)
803+ | Row_var _ -> reconstruct_single_var before_dims remaining_rows)
805804 in
806805 Either. First (reconstruct_single_var [] rows)
807- | _ ->
806+ | _ ->
808807 (* Multiple row variables *)
809808 Either. Second (all_dims, row_vars)
810809
@@ -824,9 +823,55 @@ let rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint)
824823 : constraint_ list * environment =
825824 match rows_to_row_or_vars rows with
826825 | Either. First single_row -> apply_row_constraint ~stage single_row constr env
827- | Either. Second (_all_dims , _row_vars ) -> (
826+ | Either. Second (all_dims , row_vars ) -> (
828827 match constr with
828+ | Exact dims when List. length dims < List. length all_dims ->
829+ (* Case 1: Exact dims has fewer axes than all_dims - raise mismatch *)
830+ raise
831+ @@ Shape_error
832+ (" apply_rows_constraint: Exact constraint has too few axes" , [ Row_mismatch rows ])
833+ | Exact dims when List. length dims = List. length all_dims ->
834+ (* Case 2: Exact dims has same length as all_dims - derive pairwise equations *)
835+ let dim_eqs = List. map2_exn dims all_dims ~f: (fun d1 d2 -> Dim_eq { d1; d2 }) in
836+ let row_eqs =
837+ List. map row_vars ~f: (fun (v , id ) ->
838+ Row_eq { r1 = row_of_var v id; r2 = { dims = [] ; bcast = Broadcastable ; id } })
839+ in
840+ (dim_eqs @ row_eqs, env)
841+ | Total_elems { numerator = Num_elems n ; divided_by } -> (
842+ (* Case 3: Total_elems with known numerator *)
843+ match collect_factors all_dims with
844+ | None -> ([ Rows_constr { r = rows; constr } ], env) (* Give up on complex cases *)
845+ | Some (known_product , product_vars ) ->
846+ (* Move divided_by variables to the other side by combining with product_vars *)
847+ let all_product_vars = product_vars @ divided_by in
848+ if n % known_product <> 0 then
849+ raise
850+ @@ Shape_error
851+ ( [% string
852+ " Total_elems constraint: %{n#Int} not divisible by known product \
853+ %{known_product#Int}" ],
854+ [] )
855+ else if n = known_product then
856+ (* Equate all product vars to d=1 and add Total_elems 1 for each row var *)
857+ let var_eqs =
858+ List. map all_product_vars ~f: (fun v ->
859+ Dim_eq { d1 = Var v; d2 = get_dim ~d: 1 () })
860+ in
861+ let row_constrs =
862+ List. map row_vars ~f: (fun (v , id ) ->
863+ Rows_constr
864+ {
865+ r = [ row_of_var v id ];
866+ constr = Total_elems { numerator = Num_elems 1 ; divided_by = [] };
867+ })
868+ in
869+ (var_eqs @ row_constrs, env)
870+ else
871+ (* Cannot deduce no_further_axes, return unchanged *)
872+ ([ Rows_constr { r = rows; constr } ], env))
829873 | Exact [ single_dim ] -> (
874+ (* Keep existing logic for single_dim case *)
830875 match List. rev rows with
831876 | { dims = [] ; bcast = Broadcastable ; id = _ } :: more_rows ->
832877 apply_rows_constraint ~stage (List. rev more_rows) constr env
@@ -838,30 +883,16 @@ let rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint)
838883 env )
839884 | { dims = [] ; bcast = Row_var { v; beg_dims = [] }; id = { kind = `Output ; _ } as id }
840885 :: more_rows ->
841- (* TODO: we could check if there is a non-empty output row later on. *)
842886 ( Row_eq
843887 {
844888 r1 = row_of_var v id;
845889 r2 = { dims = [ single_dim ]; bcast = Broadcastable ; id };
846890 }
847891 :: List. concat_map ~f: check_empty_row more_rows,
848892 env )
849- | {
850- dims = [ single_other ];
851- bcast = Row_var { v; beg_dims = [] };
852- id = { kind = `Output ; _ } as id;
853- }
854- :: more_rows
855- | {
856- dims = [] ;
857- bcast = Row_var { v; beg_dims = [ single_other ] };
858- id = { kind = `Output ; _ } as id;
859- }
860- :: more_rows ->
861- ( Dim_eq { d1 = single_other; d2 = single_dim }
862- :: Row_eq { r1 = row_of_var v id; r2 = { dims = [] ; bcast = Broadcastable ; id } }
863- :: List. concat_map ~f: check_empty_row more_rows,
864- env )
893+ | { dims = _; bcast = Row_var { v = _; beg_dims = _ }; id = { kind = `Output ; _ } } :: _
894+ ->
895+ assert false
865896 | _ -> raise @@ Shape_error (" apply_rows_constraint: shape too big" , [ Row_mismatch rows ])
866897 )
867898 | _ -> ([ Rows_constr { r = rows; constr } ], env))
@@ -932,10 +963,8 @@ and apply_row_constraint ~stage (r : row) (constr : row_constraint) env : constr
932963 | { dims; bcast = Broadcastable ; _ }, Total_elems { numerator; divided_by }
933964 when List. length divided_by < = 1 -> (
934965 try
935- let d, vars =
936- match collect_factors dims with
937- | Some (d , vars ) -> (d, vars)
938- | None -> raise Given_up
966+ let d, vars =
967+ match collect_factors dims with Some (d , vars ) -> (d, vars) | None -> raise Given_up
939968 in
940969 let numerator = total_elems_divide numerator d in
941970 if total_elems_known_zero numerator then
0 commit comments