@@ -44,23 +44,23 @@ module Proj_id = struct
4444end
4545
4646type proj_id = Proj_id .t [@@ deriving equal , hash , compare , sexp ]
47+ type proj_cmp = Proj_id .comparator_witness
48+ type proj_var_set = Set .M (Proj_id ).t [@@ deriving equal , sexp ]
49+ type 'a proj_map = 'a Map .M (Proj_id ).t [@@ deriving equal , sexp ]
4750
51+ let proj_var_set_empty = Set. empty (module Proj_id )
52+ let proj_map_empty = Map. empty (module Proj_id )
4853let dim_var_set_empty = Set. empty (module Dim_var )
4954let dim_map_empty = Map. empty (module Dim_var )
5055let use_padding = ref false
5156
52- type solved_dim = {
53- d : int ;
54- mutable padding : int option ; [@ hash.ignore]
55- label : string option ;
56- proj_id : proj_id option ;
57- }
57+ type solved_dim = { d : int ; padding : int option ; label : string option ; proj_id : proj_id option }
5858[@@ deriving equal , hash , compare , sexp ]
5959
6060type dim =
6161 | Var of dim_var
6262 | Dim of solved_dim
63- | Affine of { solved : ( int * solved_dim ) list ; unsolved : (int * dim_var ) list }
63+ | Affine of { solved : solved_dim option ; unsolved : (int * dim_var ) list }
6464[@@ deriving equal , hash , compare , sexp , variants ]
6565
6666let uid = ref 0
@@ -81,9 +81,9 @@ type print_style = Only_labels | Axis_size | Axis_number_and_size | Projection_a
8181let solved_dim_to_string style { d; label; proj_id; padding } =
8282 match style with
8383 | Only_labels -> ( match label with None -> " _" | Some l -> l)
84- | Axis_size | Axis_number_and_size ->
84+ | Axis_size | Axis_number_and_size -> (
8585 let label_prefix = match label with None -> " " | Some l -> l ^ " =" in
86- ( match (proj_id, padding) with
86+ match (proj_id, padding) with
8787 | None , None -> label_prefix ^ Int. to_string d
8888 | None , Some p -> label_prefix ^ [% string " %{d#Int}+%{p#Int}" ]
8989 | Some _ , None -> label_prefix ^ Int. to_string d
@@ -93,7 +93,7 @@ let solved_dim_to_string style { d; label; proj_id; padding } =
9393 let size_part = Int. to_string d in
9494 let padding_part = match padding with None -> " " | Some p -> " +" ^ Int. to_string p in
9595 let proj_part = match proj_id with None -> " " | Some pid -> " p" ^ Proj_id. to_string pid in
96- let extra_parts =
96+ let extra_parts =
9797 match (proj_id, padding) with
9898 | None , None -> " "
9999 | None , Some _ -> padding_part
@@ -105,25 +105,33 @@ let solved_dim_to_string style { d; label; proj_id; padding } =
105105let dim_to_string style = function
106106 | Dim { label = None ; _ } when equal_print_style style Only_labels -> " _"
107107 | Dim { label = Some l ; _ } when equal_print_style style Only_labels -> l
108- | Dim { d; label = None ; padding = None ; proj_id = None } when equal_print_style style Axis_size -> Int. to_string d
109- | Dim { d; label = Some l ; padding = None ; proj_id = None } when equal_print_style style Axis_size -> [% string " %{l}=%{d#Int}" ]
110- | Dim { d; label = None ; padding = Some p ; proj_id = None } when equal_print_style style Axis_size -> [% string " %{d#Int}+%{p#Int}" ]
111- | Dim { d; label = Some l ; padding = Some p ; proj_id = None } when equal_print_style style Axis_size -> [% string " %{l}=%{d#Int}+%{p#Int}" ]
108+ | Dim { d; label = None ; padding = None ; proj_id = None } when equal_print_style style Axis_size
109+ ->
110+ Int. to_string d
111+ | Dim { d; label = Some l; padding = None ; proj_id = None } when equal_print_style style Axis_size
112+ ->
113+ [% string " %{l}=%{d#Int}" ]
114+ | Dim { d; label = None ; padding = Some p; proj_id = None } when equal_print_style style Axis_size
115+ ->
116+ [% string " %{d#Int}+%{p#Int}" ]
117+ | Dim { d; label = Some l; padding = Some p; proj_id = None }
118+ when equal_print_style style Axis_size ->
119+ [% string " %{l}=%{d#Int}+%{p#Int}" ]
112120 | Dim solved_dim -> solved_dim_to_string style solved_dim
113121 | Var { id; label = Some l } -> [% string " $%{id#Int}:%{l}" ]
114122 | Var { id; label = None } -> " $" ^ Int. to_string id
115123 | Affine { solved; unsolved } -> (
116- let solved_terms =
117- List. map solved ~f: (fun (coeff , solved_dim ) ->
118- let base = solved_dim_to_string style solved_dim in
119- if coeff = 1 then base else [% string " %{coeff#Int}*%{base}" ])
124+ let solved_term =
125+ Option. to_list solved
126+ |> List. map ~f: (fun solved_dim -> solved_dim_to_string style solved_dim)
120127 in
121128 let unsolved_terms =
122129 List. map unsolved ~f: (fun (coeff , v ) ->
123130 let label_part = match v.label with None -> " " | Some l -> " :" ^ l in
124- if coeff = 1 then [% string " $%{v.id#Int}%{label_part}" ] else [% string " %{coeff#Int}*$%{v.id#Int}%{label_part}" ])
131+ if coeff = 1 then [% string " $%{v.id#Int}%{label_part}" ]
132+ else [% string " %{coeff#Int}*$%{v.id#Int}%{label_part}" ])
125133 in
126- let all_terms = solved_terms @ unsolved_terms in
134+ let all_terms = solved_term @ unsolved_terms in
127135 match all_terms with
128136 | [] -> " 0"
129137 | [ t ] -> t
@@ -264,6 +272,24 @@ let dim_to_int_exn = function
264272 | Var _ -> invalid_arg " dim_to_int: dim still unknown"
265273 | Affine _ -> invalid_arg " dim_to_int: affine dimension cannot be converted to single int"
266274
275+ let add_dims ~keep_proj_id :{ d = d1 ; padding = p1 ; label = l1 ; proj_id } ~coef
276+ { d = d2 ; padding = p2 ; label = l2 ; proj_id = _ } =
277+ match (p1, p2) with
278+ | Some p1 , Some p2 ->
279+ {
280+ d = d1 + (coef * d2);
281+ padding = Some (max p1 (coef * p2));
282+ label = Option. first_some l1 l2;
283+ proj_id;
284+ }
285+ | _ ->
286+ {
287+ d = d1 + (coef * d2);
288+ padding = Option. first_some p1 (Option. map ~f: (( * ) coef) p2);
289+ label = Option. first_some l1 l2;
290+ proj_id;
291+ }
292+
267293let s_dim_one v ~value ~in_ =
268294 match in_ with
269295 | Var v2 when equal_dim_var v v2 -> value
@@ -281,15 +307,24 @@ let s_dim_one v ~value ~in_ =
281307 if coeff_of_v = 0 then (solved, new_unsolved)
282308 else
283309 match value with
284- | Dim s -> ((coeff_of_v, s) :: solved, new_unsolved)
310+ | Dim s ->
311+ let existing = Option. value solved ~default: { s with d = 0 ; padding = None } in
312+ let new_solved = add_dims ~keep_proj_id: existing ~coef: coeff_of_v s in
313+ (Some new_solved, new_unsolved)
285314 | Var v' -> (solved, (coeff_of_v, v') :: new_unsolved)
286315 | Affine { solved = value_solved ; unsolved = value_unsolved } ->
287316 (* Inlining affine expression: coeff_of_v * value *)
288- let scaled_solved = List. map value_solved ~f: (fun (c , sd ) -> (coeff_of_v * c, sd)) in
317+ let new_solved =
318+ match value_solved with
319+ | None -> solved
320+ | Some vs ->
321+ let existing = Option. value solved ~default: { vs with d = 0 ; padding = None } in
322+ Some (add_dims ~keep_proj_id: existing ~coef: coeff_of_v vs)
323+ in
289324 let scaled_unsolved =
290325 List. map value_unsolved ~f: (fun (c , v ) -> (coeff_of_v * c, v))
291326 in
292- (scaled_solved @ solved , scaled_unsolved @ new_unsolved)
327+ (new_solved , scaled_unsolved @ new_unsolved)
293328 in
294329 Affine { solved = new_solved; unsolved = extra_unsolved }
295330 | _ -> in_
@@ -674,7 +709,7 @@ let%debug5_sexp rec unify_dim ~stage (eq : dim * dim) (env : environment) :
674709 | Dim { d = d1 ; _ } , Dim { d = d2 ; _ } when d1 = d2 -> ([] , env)
675710 | Var v1 , Var v2 when equal_dim_var v1 v2 -> ([] , env)
676711 | Affine _ , Affine _ ->
677- (* FIXME: For now, we can only unify identical affine expressions *)
712+ (* FIXME: NOT IMPLEMENTED YET *)
678713 if equal_dim dim1 dim2 then ([] , env)
679714 else
680715 raise
@@ -908,8 +943,7 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
908943 | _ , Dim { d = 1 ; _ } -> ([] , env)
909944 | (Dim { d = 1 ; _ } as cur ), _ -> ([ Dim_eq { d1 = subr; d2 = cur } ], env)
910945 | Affine _ , _ | _ , Affine _ ->
911- (* FIXME: For affine dimensions in inequalities, we can't directly compare them *)
912- (* This would need more sophisticated constraint solving *)
946+ (* FIXME: NOT IMPLEMENTED YET *)
913947 ([] , env)
914948 | Var cur_v , Var subr_v -> (
915949 match (Map. find env.dim_env cur_v, Map. find env.dim_env subr_v) with
@@ -1039,7 +1073,7 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
10391073 Dim_mismatch [ lub2; cur; subr ] ] ) *)
10401074 | Var _ , _ | _ , Var _ -> assert false
10411075 | Affine _ , _ | _ , Affine _ ->
1042- (* FIXME: Can't compute LUB with affine dimensions *)
1076+ (* FIXME: NOT IMPLEMENTED YET *)
10431077 (lub2, [] )
10441078 in
10451079 let from_constr, constr2 = apply_dim_constraint ~source: Cur ~stage cur constr2 env in
@@ -1319,7 +1353,7 @@ let%debug5_sexp close_dim_terminal ~(stage : stage) (env : environment) (dim : d
13191353 | _ when not (is_stage4_up stage) -> [ Terminal_dim dim ]
13201354 | _ -> [] )
13211355 | Affine _ ->
1322- (* FIXME: For affine dimensions, we can't generate simple terminal constraints *)
1356+ (* FIXME: NOT IMPLEMENTED YET *)
13231357 []
13241358
13251359let last_dim_is dims d2 = match List. last dims with Some (Dim { d; _ } ) -> d = d2 | _ -> false
@@ -1465,7 +1499,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
14651499 | Some (Bounds_dim bounds ) -> Bounds_dim { bounds with constr }
14661500 | None -> Bounds_dim { constr; lub = None ; cur = [] ; subr = [] });
14671501 }
1468- | _ , Affine _ -> env (* FIXME: Can't store constraints on affine dimensions *)
1502+ | _ , Affine _ -> env (* FIXME: NOT IMPLEMENTED YET *)
14691503 in
14701504 (extras @ ineqs, env)
14711505 | Row_constr { r; constr } ->
@@ -1542,7 +1576,12 @@ let fresh_row_proj r =
15421576 | Dim { d; padding; label; proj_id = _ } ->
15431577 Dim { d; padding; label; proj_id = Some (Proj_id. fresh () ) }
15441578 | Var _ as d -> d
1545- | Affine _ as d -> d (* FIXME: Affine dimensions don't get fresh projections *)
1579+ | Affine { solved; unsolved } as d ->
1580+ if List. is_empty unsolved then
1581+ Dim { (Option. value_exn ~here: [% here] solved) with proj_id = Some (Proj_id. fresh () ) }
1582+ else (
1583+ assert (Option. is_none solved);
1584+ d)
15461585 in
15471586 { r with dims = List. map r.dims ~f: fresh_dim }
15481587
@@ -1591,13 +1630,17 @@ let%debug4_sexp get_proj_equations (inequalities : constraint_ list) proj_axis_e
15911630 let to_proj : dim -> proj = function
15921631 | Var v when Map. mem proj_axis_env v -> Solved (Map. find_exn proj_axis_env v)
15931632 | Dim ({ proj_id = Some proj_id ; _ } as solved_dim ) -> Proj (proj_id, solved_dim)
1633+ | Affine { solved = None ; unsolved } as d -> Affine { solved = [] ; solving = [] ; unsolved }
15941634 | d -> (
15951635 match subst_dim env d with
15961636 | Dim ({ proj_id = Some proj_id ; _ } as solved_dim ) -> Proj (proj_id, solved_dim)
15971637 | Dim s -> Proj (Proj_id. fresh () , s)
15981638 | Var v when Map. mem proj_axis_env v -> Solved (Map. find_exn proj_axis_env v)
15991639 | Var v -> Var v
1600- | Affine _ -> failwith " get_proj_equations: affine dimensions not supported in projections" )
1640+ | Affine { solved = Some d ; unsolved } ->
1641+ assert (List. is_empty unsolved);
1642+ Proj (Option. value ~default: (Proj_id. fresh () ) d.proj_id, d)
1643+ | Affine { solved = None ; unsolved = _ } -> assert false (* handled above *) )
16011644 in
16021645 let rec expand_dims = function
16031646 | { dims; bcast = Row_var { v; beg_dims } ; _ } when Map. mem env.row_env v -> (
@@ -1691,8 +1734,12 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16911734 | Proj_eq (Proj (p , _ ), Solved idx ) | Proj_eq (Solved idx , Proj (p , _ )) ->
16921735 p_solved := (p, idx) :: ! p_solved
16931736 | Proj_eq (Proj (_ , _ ), (Affine _ as _affine )) | Proj_eq ((Affine _ as _affine ), Proj (_ , _ )) ->
1694- (* FIXME: can't store constraints on affine dimensions *)
1737+ (* FIXME: NOT IMPLEMENTED YET *)
16951738 ()
1739+ | Proj_eq (Solved (Idx. Affine { symbols; offset }), Affine { solved; solving; unsolved })
1740+ | Proj_eq (Affine { solved; solving; unsolved } , Solved (Idx. Affine { symbols; offset } )) ->
1741+ (* FIXME: NOT IMPLEMENTED YET *)
1742+ ignore (symbols, offset, solved, solving, unsolved)
16961743 | Proj_eq (Solved idx , Affine affine ) | Proj_eq (Affine affine , Solved idx ) ->
16971744 (* For affine expressions with solved indices, we can't directly substitute *)
16981745 (* This case might require more sophisticated handling *)
@@ -1755,6 +1802,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
17551802 let repr, _ = Utils. union_find ~equal: Proj_id. equal ! proj_classes ~key: p ~rank: 0 in
17561803 if Idx. iterated d && (not @@ Map. mem ! projs repr) then
17571804 Utils. mref_add product_dim ~key: repr ~data: d ~or_: (fun d2 ->
1805+ (* TODO: consider updating padding *)
17581806 if d <> d2 then
17591807 raise
17601808 @@ Shape_error
@@ -1776,8 +1824,7 @@ let get_proj_index proj_env =
17761824 let unknown_projection proj_id d =
17771825 raise
17781826 @@ Shape_error
1779- ( [% string " projection_of_solved_dims: unknown projection: %{proj_id#Proj_id} %{d#Int}" ],
1780- [] )
1827+ ([% string " projection_of_solved_dims: unknown projection: %{proj_id#Proj_id} %{d#Int}" ], [] )
17811828 in
17821829 function
17831830 | Dim { d; _ } when not @@ Idx. iterated d -> Idx. Fixed_idx 0
@@ -1807,22 +1854,21 @@ let get_proj_index proj_env =
18071854 let symbols = ref [] in
18081855 let offset = ref 0 in
18091856
1810- List. iter solved ~f: (fun (coeff , solved_dim ) ->
1811- match solved_dim with
1812- | { d; proj_id = Some proj_id ; _ } ->
1813- if Idx. iterated d then
1814- let repr, _ =
1815- Utils. union_find ~equal: Proj_id. equal proj_env.proj_classes ~key: proj_id ~rank: 0
1816- in
1817- match Map. find proj_env.proj_to_index repr with
1818- | Some (Iterator symbol ) -> symbols := (coeff, symbol) :: ! symbols
1819- | Some (Fixed_idx i ) -> offset := ! offset + (coeff * i)
1820- | Some (Affine _ ) ->
1821- (* Nested affine - would need recursive handling *)
1822- raise @@ Shape_error (" Nested affine projections not supported" , [] )
1823- | None -> unknown_projection proj_id d
1824- else ()
1825- | { proj_id = None ; _ } -> assert false );
1857+ Option. iter solved ~f: (function
1858+ | { d; proj_id = Some proj_id ; _ } ->
1859+ if Idx. iterated d then
1860+ let repr, _ =
1861+ Utils. union_find ~equal: Proj_id. equal proj_env.proj_classes ~key: proj_id ~rank: 0
1862+ in
1863+ match Map. find proj_env.proj_to_index repr with
1864+ | Some (Iterator symbol ) -> symbols := (1 , symbol) :: ! symbols
1865+ | Some (Fixed_idx i ) -> offset := ! offset + i
1866+ | Some (Affine _ ) ->
1867+ (* Nested affine - would need recursive handling *)
1868+ raise @@ Shape_error (" Nested affine projections not supported" , [] )
1869+ | None -> unknown_projection proj_id d
1870+ else ()
1871+ | { proj_id = None ; _ } -> assert false );
18261872
18271873 Idx. Affine { symbols = List. rev ! symbols; offset = ! offset }
18281874
@@ -1844,7 +1890,7 @@ let get_product_proj proj_env dim =
18441890 ( " projection_of_solved_dims: still not fully inferred for variable "
18451891 ^ Sexp. to_string_hum ([% sexp_of: dim_var] v),
18461892 [ Dim_mismatch [ dim ] ] )
1847- | Affine _ -> None (* FIXME: Affine dimensions don't participate in product projections *)
1893+ | Affine _ -> None
18481894
18491895let proj_to_iterator proj_env p =
18501896 match Map. find_exn proj_env.proj_to_index (proj_repr proj_env p) with
0 commit comments