@@ -57,7 +57,7 @@ let rec dim_to_string style = function
5757 | Dim { d; label = Some l ; _ } -> [% string " %{l}=%{d#Int}" ]
5858 | Var { id; label = Some l } -> [% string " $%{id#Int}:%{l}" ]
5959 | Var { id; label = None } -> " $" ^ Int. to_string id
60- | Prod ds -> String. concat ~sep: " * " (List. map ds ~f: (dim_to_string style))
60+ | Prod ds -> String. concat ~sep: " & " (List. map ds ~f: (dim_to_string style))
6161
6262module Row_var = struct
6363 type t = Row_var of int [@@ deriving equal , hash , compare , sexp ]
@@ -196,15 +196,37 @@ let rec dim_to_int_exn = function
196196
197197(* Helper functions for Prod *)
198198let rec dim_vars = function
199- | Var v -> [v ]
199+ | Var v -> [ v ]
200200 | Dim _ -> []
201201 | Prod dims -> List. concat_map dims ~f: dim_vars
202202
203+ let rec extract_dims_and_vars = function
204+ | Dim { d; _ } -> ([ d ], [] )
205+ | Var v -> ([] , [ v ])
206+ | Prod dims ->
207+ List. fold dims ~init: ([] , [] ) ~f: (fun (ds , vs ) dim ->
208+ let d', v' = extract_dims_and_vars dim in
209+ (ds @ d', vs @ v'))
210+
203211let rec is_solved_dim = function
204212 | Var _ -> false
205213 | Dim _ -> true
206214 | Prod dims -> List. for_all dims ~f: is_solved_dim
207- let s_dim_one v ~value ~in_ = match in_ with Var v2 when equal_dim_var v v2 -> value | _ -> in_
215+
216+ let s_dim_one v ~value ~in_ =
217+ match in_ with
218+ | Var v2 when equal_dim_var v v2 -> value
219+ | Prod dims -> (
220+ let rec flatten_prods = function
221+ | Var v2 when equal_dim_var v v2 -> flatten_prods value
222+ | Prod nested_dims -> List. concat_map nested_dims ~f: flatten_prods
223+ | d -> [ d ]
224+ in
225+ match List. concat_map dims ~f: flatten_prods with
226+ | [] -> get_dim ~d: 1 ()
227+ | [ d ] -> d
228+ | dims -> Prod dims)
229+ | Dim _ | Var _ -> in_
208230
209231(* For future flexibility *)
210232let dim_conjunction constr1 constr2 =
@@ -247,8 +269,8 @@ let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
247269 Some (extras ~keep_constr1: true , constr1)
248270 else None
249271
250- let apply_dim_constraint ~(source : source ) ~(stage : stage ) (dim : dim ) ( constr : dim_constraint )
251- (env : environment ) : constraint_ list * dim_constraint =
272+ let rec apply_dim_constraint ~(source : source ) ~(stage : stage ) (dim : dim )
273+ (constr : dim_constraint ) ( env : environment ) : constraint_ list * dim_constraint =
252274 let extras, constr =
253275 match (dim, constr) with
254276 | Dim { d; _ } , At_least_dim d_min ->
@@ -258,21 +280,27 @@ let apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim) (constr
258280 ( " At_least_dim constraint failed, expected " ^ Int. to_string d_min,
259281 [ Dim_mismatch [ dim ] ] )
260282 else ([] , constr)
261- | Prod dims , At_least_dim d_min ->
283+ | Prod dims , At_least_dim d_min -> (
262284 (* For a product, we check if the product of all known dimensions meets the constraint *)
263285 let product = ref 1 in
264- let has_vars = ref false in
265- List. iter dims ~f: ( function
286+ let vars = ref [] in
287+ let rec f = function
266288 | Dim { d; _ } -> product := ! product * d
267- | Var _ -> has_vars := true
268- | Prod _ -> has_vars := true (* Nested products need recursive handling *)
269- );
270- if not ! has_vars && ! product < d_min then
271- raise
272- @@ Shape_error
273- ( " At_least_dim constraint failed for product, expected at least " ^ Int. to_string d_min,
274- [ Dim_mismatch [ dim ] ] )
275- else ([] , constr) (* TODO: Could propagate constraints to constituent dimensions *)
289+ | Var v -> vars := v :: ! vars
290+ | Prod dims -> List. iter dims ~f
291+ in
292+ List. iter dims ~f ;
293+ match ! vars with
294+ | [] ->
295+ if ! product < d_min then
296+ raise
297+ @@ Shape_error
298+ ( " At_least_dim constraint failed for product, expected at least "
299+ ^ Int. to_string d_min,
300+ [ Dim_mismatch [ dim ] ] )
301+ else ([] , constr)
302+ | [ v ] -> apply_dim_constraint ~source ~stage (Var v) (At_least_dim (d_min / ! product)) env
303+ | _ -> ([] , constr))
276304 | Var v , _ -> (
277305 match Map. find env.dim_env v with
278306 | None -> ([] , constr)
@@ -293,14 +321,6 @@ let reduce_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dim
293321 row_constraint =
294322 match constr with
295323 | Total_elems { nominator; divided_by } ->
296- let rec extract_dims_and_vars = function
297- | Dim { d; _ } -> ([ d ], [] )
298- | Var v -> ([] , [ v ])
299- | Prod dims ->
300- List. fold dims ~init: ([] , [] ) ~f: (fun (ds , vs ) dim ->
301- let d', v' = extract_dims_and_vars dim in
302- (ds @ d', vs @ v'))
303- in
304324 let ds, vars =
305325 List. fold (beg_dims @ dims) ~init: ([] , [] ) ~f: (fun (ds , vs ) dim ->
306326 let d', v' = extract_dims_and_vars dim in
@@ -325,14 +345,6 @@ let _lift_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dims
325345 row_constraint =
326346 match constr with
327347 | Total_elems { nominator; divided_by } ->
328- let rec extract_dims_and_vars = function
329- | Dim { d; _ } -> ([ d ], [] )
330- | Var v -> ([] , [ v ])
331- | Prod dims ->
332- List. fold dims ~init: ([] , [] ) ~f: (fun (ds , vs ) dim ->
333- let d', v' = extract_dims_and_vars dim in
334- (ds @ d', vs @ v'))
335- in
336348 let ds, vars =
337349 List. fold (beg_dims @ dims) ~init: ([] , [] ) ~f: (fun (ds , vs ) dim ->
338350 let d', v' = extract_dims_and_vars dim in
@@ -409,17 +421,9 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons
409421 | _ , Unconstrained -> assert false
410422 | { dims; bcast = Broadcastable ; _ }, Total_elems { nominator; divided_by }
411423 when Set. length divided_by < = 1 -> (
412- let rec extract_dims_and_vars_dim = function
413- | Dim { d; _ } -> ([ d ], [] )
414- | Var v -> ([] , [ v ])
415- | Prod dims ->
416- List. fold dims ~init: ([] , [] ) ~f: (fun (ds , vs ) dim ->
417- let d', v' = extract_dims_and_vars_dim dim in
418- (ds @ d', vs @ v'))
419- in
420424 let ds, vars =
421425 List. fold dims ~init: ([] , [] ) ~f: (fun (ds , vs ) dim ->
422- let d', v' = extract_dims_and_vars_dim dim in
426+ let d', v' = extract_dims_and_vars dim in
423427 (ds @ d', vs @ v'))
424428 in
425429 let d : int = List. fold ds ~init: 1 ~f: ( * ) in
@@ -492,19 +496,21 @@ let s_dim_one_in_row_constr v ~value constr =
492496 ( " s_dim_one_in_row_constr: Total_elems constraint failed: shape is too big" ,
493497 [ Dim_mismatch [ value ] ] )
494498 else Total_elems { nominator; divided_by }
495- | Prod dims ->
496- (* When substituting with a Prod, we need to calculate its total dimension *)
497- let d = dim_to_int_exn (Prod dims) in
499+ | Prod _ ->
500+ (* When substituting with a Prod, we need to calculate its total dimension and extract
501+ variables *)
502+ let ds, vars = extract_dims_and_vars value in
503+ let d = List. fold ds ~init: 1 ~f: ( * ) in
498504 let nominator = nominator / d in
499505 if nominator < = 0 then
500506 raise
501507 @@ Shape_error
502508 ( " s_dim_one_in_row_constr: Total_elems constraint failed: shape is too big" ,
503509 [ Dim_mismatch [ value ] ] )
504- else
505- (* Extract any variables from the Prod and add them to divided_by *)
506- let vars = dim_vars value in
507- Total_elems { nominator; divided_by = Set. union divided_by (Set. of_list (module Dim_var ) vars) })
510+ else
511+ (* Add any variables from the Prod to divided_by *)
512+ Total_elems
513+ { nominator; divided_by = Set. union divided_by (Set. of_list (module Dim_var ) vars) })
508514 | _ -> constr
509515
510516let s_dim_one_in_row_entry v ~value in_ =
@@ -521,7 +527,12 @@ let rec subst_dim env = function
521527 | Some (Solved_dim (Var v2 )) when equal_dim_var v v2 -> default
522528 | Some (Solved_dim d ) -> subst_dim env d
523529 | _ -> default)
524- | Prod dims -> Prod (List. map dims ~f: (subst_dim env))
530+ | Prod dims -> (
531+ let rec f dim = match subst_dim env dim with Prod dims -> dims | dim -> [ dim ] in
532+ match List. concat_map dims ~f with
533+ | [] -> get_dim ~d: 1 ()
534+ | [ dim ] -> dim
535+ | dims -> Prod dims)
525536
526537let s_row_one v ~value :{ dims = more_dims ; bcast; id = _ } ~in_ =
527538 match in_ with
@@ -822,7 +833,7 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
822833 ||
823834 match Map. find env.dim_env cur_v with
824835 | None | Some (Solved_dim (Dim _ )) -> false
825- | Some (Solved_dim (Var v )) -> equal_dim_var subr_v v
836+ | Some (Solved_dim (Var v )) | Some (Solved_dim (Prod [ Var v ])) -> equal_dim_var subr_v v
826837 | Some (Solved_dim (Prod _ )) -> false (* Prod doesn't contain variables directly *)
827838 | Some (Bounds_dim { cur = curs ; _ } ) -> cyclic ~subr_v ~curs )
828839 in
@@ -845,22 +856,30 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
845856 let expanded2 = get_dim ~d: d2 () in
846857 ([ Dim_ineq { cur = expanded1; subr = expanded2 } ], env)
847858 else
848- raise @@ Shape_error (" Cannot compare Prod dimensions with unresolved variables" , [ Dim_mismatch [ cur; subr ] ])
859+ raise
860+ @@ Shape_error
861+ ( " Cannot compare Prod dimensions with unresolved variables" ,
862+ [ Dim_mismatch [ cur; subr ] ] )
849863 | Prod ds , Dim _ | Dim _ , Prod ds ->
850- (* For now, we can't directly compare a Prod with a Dim in an inequality.
851- We could potentially expand this to handle cases where the product is known. *)
864+ (* For now, we can't directly compare a Prod with a Dim in an inequality. We could potentially
865+ expand this to handle cases where the product is known. *)
852866 if is_solved_dim (Prod ds) then
853867 let prod_val = dim_to_int_exn (Prod ds) in
854868 let expanded = get_dim ~d: prod_val () in
855- ( match (cur, subr) with
856- | Prod _ , _ -> ([ Dim_ineq { cur = expanded; subr } ], env)
857- | _ , Prod _ -> ([ Dim_ineq { cur; subr = expanded } ], env)
858- | _ -> assert false )
869+ match (cur, subr) with
870+ | Prod _ , _ -> ([ Dim_ineq { cur = expanded; subr } ], env)
871+ | _ , Prod _ -> ([ Dim_ineq { cur; subr = expanded } ], env)
872+ | _ -> assert false
859873 else
860- raise @@ Shape_error (" Cannot compare Prod with unresolved variables in inequality" , [ Dim_mismatch [ cur; subr ] ])
874+ raise
875+ @@ Shape_error
876+ ( " Cannot compare Prod with unresolved variables in inequality" ,
877+ [ Dim_mismatch [ cur; subr ] ] )
861878 | Prod _ , Var _ | Var _ , Prod _ ->
862879 (* Similar to above - we need all dimensions resolved to compare *)
863- raise @@ Shape_error (" Cannot compare Prod with variables in inequality" , [ Dim_mismatch [ cur; subr ] ])
880+ raise
881+ @@ Shape_error
882+ (" Cannot compare Prod with variables in inequality" , [ Dim_mismatch [ cur; subr ] ])
864883 | Var cur_v , Var subr_v -> (
865884 match (Map. find env.dim_env cur_v, Map. find env.dim_env subr_v) with
866885 | Some (Bounds_dim { cur = cur1 ; _ } ), _ when List. mem ~equal: equal_dim_var cur1 subr_v ->
@@ -997,9 +1016,13 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
9971016 let lub = get_dim ~d: 1 () in
9981017 (lub, [ Dim_eq { d1 = subr; d2 = lub } ])
9991018 else
1000- raise @@ Shape_error (" Cannot compute LUB between Prod and unsolved dimensions" , [ Dim_mismatch [ cur; lub2 ] ])
1019+ raise
1020+ @@ Shape_error
1021+ ( " Cannot compute LUB between Prod and unsolved dimensions" ,
1022+ [ Dim_mismatch [ cur; lub2 ] ] )
10011023 | Prod _ , Prod _ ->
1002- (* For LUB between two Prods, they need to match structurally or we force to dim-1 *)
1024+ (* For LUB between two Prods, they need to match structurally or we force to
1025+ dim-1 *)
10031026 if equal_dim cur lub2 then (cur, [] )
10041027 else if is_solved_dim cur && is_solved_dim lub2 then
10051028 let d_cur = dim_to_int_exn cur in
@@ -1009,7 +1032,10 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
10091032 let lub = get_dim ~d: 1 () in
10101033 (lub, [ Dim_eq { d1 = subr; d2 = lub } ])
10111034 else
1012- raise @@ Shape_error (" Cannot compute LUB between different Prod structures" , [ Dim_mismatch [ cur; lub2 ] ])
1035+ raise
1036+ @@ Shape_error
1037+ ( " Cannot compute LUB between different Prod structures" ,
1038+ [ Dim_mismatch [ cur; lub2 ] ] )
10131039 | Var _ , _ | _ , Var _ -> assert false
10141040 in
10151041 let from_constr, constr2 = apply_dim_constraint ~source: Cur ~stage cur constr2 env in
0 commit comments