Skip to content

Commit 4dabf5b

Browse files
committed
Factor out extract_dims_and_vars, flatten Prod on substitution
1 parent 8e818a1 commit 4dabf5b

File tree

2 files changed

+90
-64
lines changed

2 files changed

+90
-64
lines changed

lib/prod_dimension_implementation_status.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ The remaining functions follow similar patterns - they need to handle Prod by ei
116116

117117
## Next Steps
118118

119-
1. Implement the helper functions in row.ml
119+
1. Implement the helper functions in row.ml
120120
2. Systematically go through each linter error and add Prod handling
121121
3. Update shape.ml for any needed changes
122122
4. Design and implement einsum notation parsing for `&`

lib/row.ml

Lines changed: 89 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ let rec dim_to_string style = function
5757
| Dim { d; label = Some l; _ } -> [%string "%{l}=%{d#Int}"]
5858
| Var { id; label = Some l } -> [%string "$%{id#Int}:%{l}"]
5959
| Var { id; label = None } -> "$" ^ Int.to_string id
60-
| Prod ds -> String.concat ~sep:"*" (List.map ds ~f:(dim_to_string style))
60+
| Prod ds -> String.concat ~sep:"&" (List.map ds ~f:(dim_to_string style))
6161

6262
module Row_var = struct
6363
type t = Row_var of int [@@deriving equal, hash, compare, sexp]
@@ -196,15 +196,37 @@ let rec dim_to_int_exn = function
196196

197197
(* Helper functions for Prod *)
198198
let rec dim_vars = function
199-
| Var v -> [v]
199+
| Var v -> [ v ]
200200
| Dim _ -> []
201201
| Prod dims -> List.concat_map dims ~f:dim_vars
202202

203+
let rec extract_dims_and_vars = function
204+
| Dim { d; _ } -> ([ d ], [])
205+
| Var v -> ([], [ v ])
206+
| Prod dims ->
207+
List.fold dims ~init:([], []) ~f:(fun (ds, vs) dim ->
208+
let d', v' = extract_dims_and_vars dim in
209+
(ds @ d', vs @ v'))
210+
203211
let rec is_solved_dim = function
204212
| Var _ -> false
205213
| Dim _ -> true
206214
| Prod dims -> List.for_all dims ~f:is_solved_dim
207-
let s_dim_one v ~value ~in_ = match in_ with Var v2 when equal_dim_var v v2 -> value | _ -> in_
215+
216+
let s_dim_one v ~value ~in_ =
217+
match in_ with
218+
| Var v2 when equal_dim_var v v2 -> value
219+
| Prod dims -> (
220+
let rec flatten_prods = function
221+
| Var v2 when equal_dim_var v v2 -> flatten_prods value
222+
| Prod nested_dims -> List.concat_map nested_dims ~f:flatten_prods
223+
| d -> [ d ]
224+
in
225+
match List.concat_map dims ~f:flatten_prods with
226+
| [] -> get_dim ~d:1 ()
227+
| [ d ] -> d
228+
| dims -> Prod dims)
229+
| Dim _ | Var _ -> in_
208230

209231
(* For future flexibility *)
210232
let dim_conjunction constr1 constr2 =
@@ -247,8 +269,8 @@ let row_conjunction ?(id = phantom_row_id) constr1 constr2 =
247269
Some (extras ~keep_constr1:true, constr1)
248270
else None
249271

250-
let apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim) (constr : dim_constraint)
251-
(env : environment) : constraint_ list * dim_constraint =
272+
let rec apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim)
273+
(constr : dim_constraint) (env : environment) : constraint_ list * dim_constraint =
252274
let extras, constr =
253275
match (dim, constr) with
254276
| Dim { d; _ }, At_least_dim d_min ->
@@ -258,21 +280,27 @@ let apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim) (constr
258280
( "At_least_dim constraint failed, expected " ^ Int.to_string d_min,
259281
[ Dim_mismatch [ dim ] ] )
260282
else ([], constr)
261-
| Prod dims, At_least_dim d_min ->
283+
| Prod dims, At_least_dim d_min -> (
262284
(* For a product, we check if the product of all known dimensions meets the constraint *)
263285
let product = ref 1 in
264-
let has_vars = ref false in
265-
List.iter dims ~f:(function
286+
let vars = ref [] in
287+
let rec f = function
266288
| Dim { d; _ } -> product := !product * d
267-
| Var _ -> has_vars := true
268-
| Prod _ -> has_vars := true (* Nested products need recursive handling *)
269-
);
270-
if not !has_vars && !product < d_min then
271-
raise
272-
@@ Shape_error
273-
( "At_least_dim constraint failed for product, expected at least " ^ Int.to_string d_min,
274-
[ Dim_mismatch [ dim ] ] )
275-
else ([], constr) (* TODO: Could propagate constraints to constituent dimensions *)
289+
| Var v -> vars := v :: !vars
290+
| Prod dims -> List.iter dims ~f
291+
in
292+
List.iter dims ~f;
293+
match !vars with
294+
| [] ->
295+
if !product < d_min then
296+
raise
297+
@@ Shape_error
298+
( "At_least_dim constraint failed for product, expected at least "
299+
^ Int.to_string d_min,
300+
[ Dim_mismatch [ dim ] ] )
301+
else ([], constr)
302+
| [ v ] -> apply_dim_constraint ~source ~stage (Var v) (At_least_dim (d_min / !product)) env
303+
| _ -> ([], constr))
276304
| Var v, _ -> (
277305
match Map.find env.dim_env v with
278306
| None -> ([], constr)
@@ -293,14 +321,6 @@ let reduce_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dim
293321
row_constraint =
294322
match constr with
295323
| Total_elems { nominator; divided_by } ->
296-
let rec extract_dims_and_vars = function
297-
| Dim { d; _ } -> ([ d ], [])
298-
| Var v -> ([], [ v ])
299-
| Prod dims ->
300-
List.fold dims ~init:([], []) ~f:(fun (ds, vs) dim ->
301-
let d', v' = extract_dims_and_vars dim in
302-
(ds @ d', vs @ v'))
303-
in
304324
let ds, vars =
305325
List.fold (beg_dims @ dims) ~init:([], []) ~f:(fun (ds, vs) dim ->
306326
let d', v' = extract_dims_and_vars dim in
@@ -325,14 +345,6 @@ let _lift_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dims
325345
row_constraint =
326346
match constr with
327347
| Total_elems { nominator; divided_by } ->
328-
let rec extract_dims_and_vars = function
329-
| Dim { d; _ } -> ([ d ], [])
330-
| Var v -> ([], [ v ])
331-
| Prod dims ->
332-
List.fold dims ~init:([], []) ~f:(fun (ds, vs) dim ->
333-
let d', v' = extract_dims_and_vars dim in
334-
(ds @ d', vs @ v'))
335-
in
336348
let ds, vars =
337349
List.fold (beg_dims @ dims) ~init:([], []) ~f:(fun (ds, vs) dim ->
338350
let d', v' = extract_dims_and_vars dim in
@@ -409,17 +421,9 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons
409421
| _, Unconstrained -> assert false
410422
| { dims; bcast = Broadcastable; _ }, Total_elems { nominator; divided_by }
411423
when Set.length divided_by <= 1 -> (
412-
let rec extract_dims_and_vars_dim = function
413-
| Dim { d; _ } -> ([ d ], [])
414-
| Var v -> ([], [ v ])
415-
| Prod dims ->
416-
List.fold dims ~init:([], []) ~f:(fun (ds, vs) dim ->
417-
let d', v' = extract_dims_and_vars_dim dim in
418-
(ds @ d', vs @ v'))
419-
in
420424
let ds, vars =
421425
List.fold dims ~init:([], []) ~f:(fun (ds, vs) dim ->
422-
let d', v' = extract_dims_and_vars_dim dim in
426+
let d', v' = extract_dims_and_vars dim in
423427
(ds @ d', vs @ v'))
424428
in
425429
let d : int = List.fold ds ~init:1 ~f:( * ) in
@@ -492,19 +496,21 @@ let s_dim_one_in_row_constr v ~value constr =
492496
( "s_dim_one_in_row_constr: Total_elems constraint failed: shape is too big",
493497
[ Dim_mismatch [ value ] ] )
494498
else Total_elems { nominator; divided_by }
495-
| Prod dims ->
496-
(* When substituting with a Prod, we need to calculate its total dimension *)
497-
let d = dim_to_int_exn (Prod dims) in
499+
| Prod _ ->
500+
(* When substituting with a Prod, we need to calculate its total dimension and extract
501+
variables *)
502+
let ds, vars = extract_dims_and_vars value in
503+
let d = List.fold ds ~init:1 ~f:( * ) in
498504
let nominator = nominator / d in
499505
if nominator <= 0 then
500506
raise
501507
@@ Shape_error
502508
( "s_dim_one_in_row_constr: Total_elems constraint failed: shape is too big",
503509
[ Dim_mismatch [ value ] ] )
504-
else
505-
(* Extract any variables from the Prod and add them to divided_by *)
506-
let vars = dim_vars value in
507-
Total_elems { nominator; divided_by = Set.union divided_by (Set.of_list (module Dim_var) vars) })
510+
else
511+
(* Add any variables from the Prod to divided_by *)
512+
Total_elems
513+
{ nominator; divided_by = Set.union divided_by (Set.of_list (module Dim_var) vars) })
508514
| _ -> constr
509515

510516
let s_dim_one_in_row_entry v ~value in_ =
@@ -521,7 +527,12 @@ let rec subst_dim env = function
521527
| Some (Solved_dim (Var v2)) when equal_dim_var v v2 -> default
522528
| Some (Solved_dim d) -> subst_dim env d
523529
| _ -> default)
524-
| Prod dims -> Prod (List.map dims ~f:(subst_dim env))
530+
| Prod dims -> (
531+
let rec f dim = match subst_dim env dim with Prod dims -> dims | dim -> [ dim ] in
532+
match List.concat_map dims ~f with
533+
| [] -> get_dim ~d:1 ()
534+
| [ dim ] -> dim
535+
| dims -> Prod dims)
525536

526537
let s_row_one v ~value:{ dims = more_dims; bcast; id = _ } ~in_ =
527538
match in_ with
@@ -822,7 +833,7 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
822833
||
823834
match Map.find env.dim_env cur_v with
824835
| None | Some (Solved_dim (Dim _)) -> false
825-
| Some (Solved_dim (Var v)) -> equal_dim_var subr_v v
836+
| Some (Solved_dim (Var v)) | Some (Solved_dim (Prod [ Var v ])) -> equal_dim_var subr_v v
826837
| Some (Solved_dim (Prod _)) -> false (* Prod doesn't contain variables directly *)
827838
| Some (Bounds_dim { cur = curs; _ }) -> cyclic ~subr_v ~curs)
828839
in
@@ -845,22 +856,30 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
845856
let expanded2 = get_dim ~d:d2 () in
846857
([ Dim_ineq { cur = expanded1; subr = expanded2 } ], env)
847858
else
848-
raise @@ Shape_error ("Cannot compare Prod dimensions with unresolved variables", [ Dim_mismatch [ cur; subr ] ])
859+
raise
860+
@@ Shape_error
861+
( "Cannot compare Prod dimensions with unresolved variables",
862+
[ Dim_mismatch [ cur; subr ] ] )
849863
| Prod ds, Dim _ | Dim _, Prod ds ->
850-
(* For now, we can't directly compare a Prod with a Dim in an inequality.
851-
We could potentially expand this to handle cases where the product is known. *)
864+
(* For now, we can't directly compare a Prod with a Dim in an inequality. We could potentially
865+
expand this to handle cases where the product is known. *)
852866
if is_solved_dim (Prod ds) then
853867
let prod_val = dim_to_int_exn (Prod ds) in
854868
let expanded = get_dim ~d:prod_val () in
855-
(match (cur, subr) with
856-
| Prod _, _ -> ([ Dim_ineq { cur = expanded; subr } ], env)
857-
| _, Prod _ -> ([ Dim_ineq { cur; subr = expanded } ], env)
858-
| _ -> assert false)
869+
match (cur, subr) with
870+
| Prod _, _ -> ([ Dim_ineq { cur = expanded; subr } ], env)
871+
| _, Prod _ -> ([ Dim_ineq { cur; subr = expanded } ], env)
872+
| _ -> assert false
859873
else
860-
raise @@ Shape_error ("Cannot compare Prod with unresolved variables in inequality", [ Dim_mismatch [ cur; subr ] ])
874+
raise
875+
@@ Shape_error
876+
( "Cannot compare Prod with unresolved variables in inequality",
877+
[ Dim_mismatch [ cur; subr ] ] )
861878
| Prod _, Var _ | Var _, Prod _ ->
862879
(* Similar to above - we need all dimensions resolved to compare *)
863-
raise @@ Shape_error ("Cannot compare Prod with variables in inequality", [ Dim_mismatch [ cur; subr ] ])
880+
raise
881+
@@ Shape_error
882+
("Cannot compare Prod with variables in inequality", [ Dim_mismatch [ cur; subr ] ])
864883
| Var cur_v, Var subr_v -> (
865884
match (Map.find env.dim_env cur_v, Map.find env.dim_env subr_v) with
866885
| Some (Bounds_dim { cur = cur1; _ }), _ when List.mem ~equal:equal_dim_var cur1 subr_v ->
@@ -997,9 +1016,13 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
9971016
let lub = get_dim ~d:1 () in
9981017
(lub, [ Dim_eq { d1 = subr; d2 = lub } ])
9991018
else
1000-
raise @@ Shape_error ("Cannot compute LUB between Prod and unsolved dimensions", [ Dim_mismatch [ cur; lub2 ] ])
1019+
raise
1020+
@@ Shape_error
1021+
( "Cannot compute LUB between Prod and unsolved dimensions",
1022+
[ Dim_mismatch [ cur; lub2 ] ] )
10011023
| Prod _, Prod _ ->
1002-
(* For LUB between two Prods, they need to match structurally or we force to dim-1 *)
1024+
(* For LUB between two Prods, they need to match structurally or we force to
1025+
dim-1 *)
10031026
if equal_dim cur lub2 then (cur, [])
10041027
else if is_solved_dim cur && is_solved_dim lub2 then
10051028
let d_cur = dim_to_int_exn cur in
@@ -1009,7 +1032,10 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
10091032
let lub = get_dim ~d:1 () in
10101033
(lub, [ Dim_eq { d1 = subr; d2 = lub } ])
10111034
else
1012-
raise @@ Shape_error ("Cannot compute LUB between different Prod structures", [ Dim_mismatch [ cur; lub2 ] ])
1035+
raise
1036+
@@ Shape_error
1037+
( "Cannot compute LUB between different Prod structures",
1038+
[ Dim_mismatch [ cur; lub2 ] ] )
10131039
| Var _, _ | _, Var _ -> assert false
10141040
in
10151041
let from_constr, constr2 = apply_dim_constraint ~source:Cur ~stage cur constr2 env in

0 commit comments

Comments
 (0)