@@ -23,7 +23,27 @@ type dim_var = Dim_var.t [@@deriving equal, hash, compare, sexp]
2323type dim_cmp = Dim_var .comparator_witness
2424type dim_var_set = Set .M (Dim_var ).t [@@ deriving equal , sexp ]
2525type 'a dim_map = 'a Map .M (Dim_var ).t [@@ deriving equal , sexp ]
26- type proj_id = int [@@ deriving equal , hash , compare , sexp ]
26+
27+ module Proj_id = struct
28+ type t = Proj_id of int [@@ deriving equal , hash , compare , sexp ]
29+
30+ let to_string (Proj_id i ) = Int. to_string i
31+
32+ let fresh =
33+ let uid = ref 0 in
34+ fun () ->
35+ Int. incr uid;
36+ Proj_id ! uid
37+
38+ include Comparator. Make (struct
39+ type nonrec t = t
40+
41+ let compare = compare
42+ let sexp_of_t = sexp_of_t
43+ end )
44+ end
45+
46+ type proj_id = Proj_id .t [@@ deriving equal , hash , compare , sexp ]
2747
2848let dim_var_set_empty = Set. empty (module Dim_var )
2949let dim_map_empty = Map. empty (module Dim_var )
@@ -52,7 +72,7 @@ let dim_hashtbl () = Hashtbl.create (module Dim_var)
5272
5373let solved_dim_to_string style { d; label; proj_id } =
5474 match proj_id with
55- | Some proj_id -> [% string " p%{proj_id#Int }" ]
75+ | Some proj_id -> [% string " p%{Proj_id.to_string proj_id}" ]
5676 | None -> (
5777 match style with
5878 | `Only_labels -> ( match label with None -> " _" | Some l -> l)
@@ -234,8 +254,7 @@ let s_dim_one v ~value ~in_ =
234254 else
235255 match value with
236256 | Dim s -> ((coeff_of_v, s) :: solved, new_unsolved)
237- | Var v' ->
238- (solved, (coeff_of_v, v') :: new_unsolved)
257+ | Var v' -> (solved, (coeff_of_v, v') :: new_unsolved)
239258 | Affine { solved = value_solved ; unsolved = value_unsolved } ->
240259 (* Inlining affine expression: coeff_of_v * value *)
241260 let scaled_solved = List. map value_solved ~f: (fun (c , sd ) -> (coeff_of_v * c, sd)) in
@@ -1490,15 +1509,9 @@ let rec row_to_labels env =
14901509
14911510(* * *** Projection inference *** *)
14921511
1493- let fresh_proj =
1494- let uid = ref 0 in
1495- fun () ->
1496- Int. incr uid;
1497- ! uid
1498-
14991512let fresh_row_proj r =
15001513 let fresh_dim = function
1501- | Dim { d; label; proj_id = _ } -> Dim { d; label; proj_id = Some (fresh_proj () ) }
1514+ | Dim { d; label; proj_id = _ } -> Dim { d; label; proj_id = Some (Proj_id. fresh () ) }
15021515 | Var _ as d -> d
15031516 | Affine _ as d -> d (* FIXME: Affine dimensions don't get fresh projections *)
15041517 in
@@ -1507,7 +1520,15 @@ let fresh_row_proj r =
15071520(* let update_proj_classes pid1 pid2 proj_classes = Utils.union_add ~equal:Int.equal proj_classes
15081521 pid1 pid2 *)
15091522
1510- type proj = Var of dim_var | Proj of { proj_id : proj_id ; d : int } | Solved of Idx .axis_index
1523+ type proj =
1524+ | Var of dim_var
1525+ | Proj of { proj_id : proj_id ; d : int }
1526+ | Solved of Idx .axis_index
1527+ | Affine of {
1528+ solved : (int * Idx .axis_index ) list ;
1529+ solving : (int * proj_id ) list ;
1530+ unsolved : (int * dim_var ) list ;
1531+ }
15111532[@@ deriving compare , equal , sexp ]
15121533
15131534type error_trace + = Projection_mismatch of proj list
@@ -1517,14 +1538,14 @@ let sexp_of_error_trace = function
15171538 Sexp. List (Sexp. Atom " Projection_mismatch" :: List. map ps ~f: sexp_of_proj)
15181539 | error_trace -> sexp_of_error_trace error_trace
15191540
1520- type proj_to_index = Idx .axis_index Map .M (Int ).t [@@ deriving sexp ]
1521- type proj_classes = int Map .M (Int ).t [@@ deriving sexp ]
1541+ type proj_to_index = Idx .axis_index Map .M (Proj_id ).t [@@ deriving sexp ]
1542+ type proj_classes = Proj_id .t Map .M (Proj_id ).t [@@ deriving sexp ]
15221543
15231544type proj_env = {
15241545 proj_to_index : proj_to_index ;
15251546 proj_classes : proj_classes ;
1526- product_dim : int Map .M (Int ).t;
1527- non_product : Set .M (Int ).t;
1547+ product_dim : int Map .M (Proj_id ).t;
1548+ non_product : Set .M (Proj_id ).t;
15281549}
15291550[@@ deriving sexp ]
15301551
@@ -1544,7 +1565,7 @@ let%debug4_sexp get_proj_equations (inequalities : constraint_ list) proj_axis_e
15441565 | d -> (
15451566 match subst_dim env d with
15461567 | Dim { proj_id = Some proj_id ; d; label = _ } -> Proj { proj_id; d }
1547- | Dim { proj_id = None ; d; _ } -> Proj { proj_id = fresh_proj () ; d }
1568+ | Dim { proj_id = None ; d; _ } -> Proj { proj_id = Proj_id. fresh () ; d }
15481569 | Var v when Map. mem proj_axis_env v -> Solved (Map. find_exn proj_axis_env v)
15491570 | Var v -> Var v
15501571 | Affine _ -> failwith " get_proj_equations: affine dimensions not supported in projections" )
@@ -1591,9 +1612,43 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
15911612 let v_env = dim_hashtbl () in
15921613 let p_solved = ref [] in
15931614 let p_dims = ref [] in
1594- let proj_classes = ref @@ Map. empty (module Int ) in
1615+ let proj_classes = ref @@ Map. empty (module Proj_id ) in
1616+ let _s_proj_one proj_id ~value ~in_ =
1617+ match in_ with
1618+ | Var _ -> in_
1619+ | Proj { proj_id = pid ; d = _ } -> if Proj_id. equal pid proj_id then value else in_
1620+ | Solved _ -> in_
1621+ | Affine { solved; solving; unsolved } ->
1622+ (* Substitute proj_id in affine expression *)
1623+ let coeff_of_p =
1624+ List. fold solving ~init: 0 ~f: (fun acc (c , pid ) ->
1625+ if Proj_id. equal pid proj_id then acc + c else acc)
1626+ in
1627+ let new_solving =
1628+ List. filter_map solving ~f: (fun (coeff , pid ) ->
1629+ if Proj_id. equal pid proj_id then None else Some (coeff, pid))
1630+ in
1631+ let new_solved, extra_solving =
1632+ if coeff_of_p = 0 then (solved, new_solving)
1633+ else
1634+ match value with
1635+ | Solved idx -> ((coeff_of_p, idx) :: solved, new_solving)
1636+ | Proj { proj_id = pid ; d = _ } -> (solved, (coeff_of_p, pid) :: new_solving)
1637+ | Var _ -> (solved, new_solving) (* Can't substitute variable into affine *)
1638+ | Affine { solved = value_solved ; solving = value_solving ; unsolved = _ } ->
1639+ (* Inlining affine expression: coeff_of_p * value *)
1640+ let scaled_solved =
1641+ List. map value_solved ~f: (fun (c , idx ) -> (coeff_of_p * c, idx))
1642+ in
1643+ let scaled_solving =
1644+ List. map value_solving ~f: (fun (c , pid ) -> (coeff_of_p * c, pid))
1645+ in
1646+ (scaled_solved @ solved, scaled_solving @ new_solving)
1647+ in
1648+ Affine { solved = new_solved; solving = extra_solving; unsolved }
1649+ in
15951650 let rec loop = function
1596- | Proj_eq (Proj { proj_id = p1 ; d } , Proj { proj_id = p2 ; _ } ) when p1 = p2 ->
1651+ | Proj_eq (Proj { proj_id = p1 ; d } , Proj { proj_id = p2 ; _ } ) when Proj_id. equal p1 p2 ->
15971652 p_dims := (p1, d) :: ! p_dims
15981653 | Proj_eq (Var v1 , Var v2 ) when equal_dim_var v1 v2 -> ()
15991654 | Proj_eq ((Proj { proj_id = p1; d = d1 } as proj1), (Proj { proj_id = p2; d = d2 } as proj2))
@@ -1604,9 +1659,32 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16041659 ( " Conflicting dimensions for the same projection" ,
16051660 [ Projection_mismatch [ proj1; proj2 ] ] );
16061661 p_dims := (p1, d1) :: ! p_dims;
1607- proj_classes := Utils. union_add ~equal: Int . equal ! proj_classes p1 p2
1662+ proj_classes := Utils. union_add ~equal: Proj_id . equal ! proj_classes p1 p2
16081663 | Proj_eq (Proj p , Solved idx ) | Proj_eq (Solved idx , Proj p ) ->
16091664 p_solved := (p.proj_id, idx) :: ! p_solved
1665+ | Proj_eq ((Proj { proj_id= _; _ } ), (Affine _ as _affine))
1666+ | Proj_eq ((Affine _ as _affine ), (Proj { proj_id =_ ; _ } )) ->
1667+ ()
1668+ | Proj_eq (Solved idx , Affine affine ) | Proj_eq (Affine affine , Solved idx ) ->
1669+ (* For affine expressions with solved indices, we can't directly substitute *)
1670+ (* This case might require more sophisticated handling *)
1671+ raise
1672+ @@ Shape_error
1673+ ( " Cannot unify solved index with affine projection" ,
1674+ [ Projection_mismatch [ Solved idx; Affine affine ] ] )
1675+ | Proj_eq (Var v , Affine affine ) | Proj_eq (Affine affine , Var v ) -> (
1676+ (* Handle variable to affine binding *)
1677+ match Hashtbl. find v_env v with
1678+ | None -> Hashtbl. add_exn v_env ~key: v ~data: (Affine affine)
1679+ | Some p2 -> loop (Proj_eq (Affine affine, p2)))
1680+ | Proj_eq (Affine affine1 , Affine affine2 ) ->
1681+ (* For now, we can only unify identical affine expressions *)
1682+ if equal_proj (Affine affine1) (Affine affine2) then ()
1683+ else
1684+ raise
1685+ @@ Shape_error
1686+ ( " Cannot unify different affine projections" ,
1687+ [ Projection_mismatch [ Affine affine1; Affine affine2 ] ] )
16101688 | Proj_eq (Solved idx1 , Solved idx2 ) when Idx. equal_axis_index idx1 idx2 -> ()
16111689 | Proj_eq (Solved idx1 , Solved idx2 ) ->
16121690 raise
@@ -1618,28 +1696,34 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16181696 | Some p2 -> loop (Proj_eq (p, p2)))
16191697 | Iterated (Solved _ ) -> ()
16201698 | Iterated (Proj { proj_id; d } ) -> p_dims := (proj_id, d) :: ! p_dims
1699+ | Iterated (Affine { solving; _ } ) ->
1700+ (* For affine expressions, mark all constituent projections as iterated *)
1701+ List. iter solving ~f: (fun (_ , proj_id ) -> p_dims := (proj_id, 1 ) :: ! p_dims)
16211702 | Iterated (Var v ) -> (
16221703 match Hashtbl. find v_env v with
16231704 | None ->
16241705 let idx = Idx. (Iterator (get_symbol () )) in
16251706 Hashtbl. add_exn v_env ~key: v ~data: (Solved idx)
16261707 | Some (Var v2 ) -> loop (Iterated (Var v2))
16271708 | Some (Solved _ ) -> ()
1628- | Some (Proj { proj_id; d } ) -> p_dims := (proj_id, d) :: ! p_dims)
1709+ | Some (Proj { proj_id; d } ) -> p_dims := (proj_id, d) :: ! p_dims
1710+ | Some (Affine { solving; _ } ) ->
1711+ (* For affine expressions, mark all constituent projections as iterated *)
1712+ List. iter solving ~f: (fun (_ , proj_id ) -> p_dims := (proj_id, 1 ) :: ! p_dims))
16291713 in
16301714 List. iter eqs ~f: loop;
1631- let projs = ref @@ Map. empty (module Int ) and non_product = ref @@ Set. empty (module Int ) in
1715+ let projs = ref @@ Map. empty (module Proj_id ) and non_product = ref @@ Set. empty (module Proj_id ) in
16321716 List. iter ! p_solved ~f: (fun (p , idx ) ->
1633- let repr, _ = Utils. union_find ~equal: Int . equal ! proj_classes ~key: p ~rank: 0 in
1717+ let repr, _ = Utils. union_find ~equal: Proj_id . equal ! proj_classes ~key: p ~rank: 0 in
16341718 non_product := Set. add ! non_product repr;
16351719 Utils. mref_add projs ~key: repr ~data: idx ~or_: (fun idx2 ->
16361720 if not @@ Idx. equal_axis_index idx idx2 then
16371721 raise
16381722 @@ Shape_error
16391723 (" Multiple constraints on the same projection" , [ Index_mismatch [ idx; idx2 ] ])));
1640- let product_dim = ref @@ Map. empty (module Int ) in
1724+ let product_dim = ref @@ Map. empty (module Proj_id ) in
16411725 List. iter ! p_dims ~f: (fun (p , d ) ->
1642- let repr, _ = Utils. union_find ~equal: Int . equal ! proj_classes ~key: p ~rank: 0 in
1726+ let repr, _ = Utils. union_find ~equal: Proj_id . equal ! proj_classes ~key: p ~rank: 0 in
16431727 if Idx. iterated d && (not @@ Map. mem ! projs repr) then
16441728 Utils. mref_add product_dim ~key: repr ~data: d ~or_: (fun d2 ->
16451729 if d <> d2 then
@@ -1650,7 +1734,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16501734 Projection_mismatch [ Proj { proj_id = p; d }; Proj { proj_id = p; d = d2 } ];
16511735 ] )));
16521736 Map. iteri ! product_dim ~f: (fun ~key :p ~data :_ ->
1653- let repr, _ = Utils. union_find ~equal: Int . equal ! proj_classes ~key: p ~rank: 0 in
1737+ let repr, _ = Utils. union_find ~equal: Proj_id . equal ! proj_classes ~key: p ~rank: 0 in
16541738 Utils. mref_add_missing projs repr ~f: (fun () -> Idx. (Iterator (get_symbol () ))));
16551739 {
16561740 proj_classes = ! proj_classes;
@@ -1659,7 +1743,14 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16591743 non_product = ! non_product;
16601744 }
16611745
1662- let get_proj_index proj_env = function
1746+ let get_proj_index proj_env =
1747+ let unknown_projection proj_id d =
1748+ raise
1749+ @@ Shape_error
1750+ ( " projection_of_solved_dims: unknown projection" ,
1751+ [ Projection_mismatch [ Proj { proj_id; d } ] ] )
1752+ in
1753+ function
16631754 | Dim { d; _ } when not @@ Idx. iterated d -> Idx. Fixed_idx 0
16641755 | Dim { proj_id = None ; _ } -> assert false
16651756 | Var v as dim ->
@@ -1669,18 +1760,43 @@ let get_proj_index proj_env = function
16691760 ^ Sexp. to_string_hum ([% sexp_of: dim_var] v),
16701761 [ Dim_mismatch [ dim ] ] )
16711762 | Dim { proj_id = Some proj_id ; d; _ } -> (
1672- let repr, _ = Utils. union_find ~equal: Int . equal proj_env.proj_classes ~key: proj_id ~rank: 0 in
1763+ let repr, _ = Utils. union_find ~equal: Proj_id . equal proj_env.proj_classes ~key: proj_id ~rank: 0 in
16731764 match Map. find proj_env.proj_to_index repr with
16741765 | Some i -> i
1675- | None ->
1676- raise
1677- @@ Shape_error
1678- ( " projection_of_solved_dims: unknown projection" ,
1679- [ Projection_mismatch [ Proj { proj_id; d } ] ] ))
1680- | Affine _ -> failwith " get_proj_index: affine dimensions not supported in projections"
1766+ | None -> unknown_projection proj_id d)
1767+ | Affine { solved; unsolved } ->
1768+ (* Handle unsolved variables - these should be resolved by now *)
1769+ if not (List. is_empty unsolved) then
1770+ raise
1771+ @@ Shape_error
1772+ ( " Affine dimension has unresolved variables" ,
1773+ [ Dim_mismatch [ Affine { solved; unsolved } ] ] );
1774+
1775+ (* Process solved terms *)
1776+ let symbols = ref [] in
1777+ let offset = ref 0 in
1778+
1779+ List. iter solved ~f: (fun (coeff , solved_dim ) ->
1780+ match solved_dim with
1781+ | { d; proj_id = Some proj_id ; _ } ->
1782+ if Idx. iterated d then
1783+ let repr, _ =
1784+ Utils. union_find ~equal: Proj_id. equal proj_env.proj_classes ~key: proj_id ~rank: 0
1785+ in
1786+ match Map. find proj_env.proj_to_index repr with
1787+ | Some (Iterator symbol ) -> symbols := (coeff, symbol) :: ! symbols
1788+ | Some (Fixed_idx i ) -> offset := ! offset + (coeff * i)
1789+ | Some (Affine _ ) ->
1790+ (* Nested affine - would need recursive handling *)
1791+ raise @@ Shape_error (" Nested affine projections not supported" , [] )
1792+ | None -> unknown_projection proj_id d
1793+ else ()
1794+ | { proj_id = None ; _ } -> assert false );
1795+
1796+ Idx. Affine { symbols = List. rev ! symbols; offset = ! offset }
16811797
16821798let proj_repr proj_env p =
1683- fst @@ Utils. union_find ~equal: Int . equal proj_env.proj_classes ~key: p ~rank: 0
1799+ fst @@ Utils. union_find ~equal: Proj_id . equal proj_env.proj_classes ~key: p ~rank: 0
16841800
16851801let get_product_proj proj_env dim =
16861802 match dim with
0 commit comments