Skip to content

Commit 7e47040

Browse files
committed
Refactor proj_id into abstract type; in progress: solving affine projections
1 parent 12bc3d5 commit 7e47040

File tree

4 files changed

+156
-40
lines changed

4 files changed

+156
-40
lines changed

lib/row.ml

Lines changed: 151 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,27 @@ type dim_var = Dim_var.t [@@deriving equal, hash, compare, sexp]
2323
type dim_cmp = Dim_var.comparator_witness
2424
type dim_var_set = Set.M(Dim_var).t [@@deriving equal, sexp]
2525
type 'a dim_map = 'a Map.M(Dim_var).t [@@deriving equal, sexp]
26-
type proj_id = int [@@deriving equal, hash, compare, sexp]
26+
27+
module Proj_id = struct
28+
type t = Proj_id of int [@@deriving equal, hash, compare, sexp]
29+
30+
let to_string (Proj_id i) = Int.to_string i
31+
32+
let fresh =
33+
let uid = ref 0 in
34+
fun () ->
35+
Int.incr uid;
36+
Proj_id !uid
37+
38+
include Comparator.Make (struct
39+
type nonrec t = t
40+
41+
let compare = compare
42+
let sexp_of_t = sexp_of_t
43+
end)
44+
end
45+
46+
type proj_id = Proj_id.t [@@deriving equal, hash, compare, sexp]
2747

2848
let dim_var_set_empty = Set.empty (module Dim_var)
2949
let dim_map_empty = Map.empty (module Dim_var)
@@ -52,7 +72,7 @@ let dim_hashtbl () = Hashtbl.create (module Dim_var)
5272

5373
let solved_dim_to_string style { d; label; proj_id } =
5474
match proj_id with
55-
| Some proj_id -> [%string "p%{proj_id#Int}"]
75+
| Some proj_id -> [%string "p%{Proj_id.to_string proj_id}"]
5676
| None -> (
5777
match style with
5878
| `Only_labels -> ( match label with None -> "_" | Some l -> l)
@@ -234,8 +254,7 @@ let s_dim_one v ~value ~in_ =
234254
else
235255
match value with
236256
| Dim s -> ((coeff_of_v, s) :: solved, new_unsolved)
237-
| Var v' ->
238-
(solved, (coeff_of_v, v') :: new_unsolved)
257+
| Var v' -> (solved, (coeff_of_v, v') :: new_unsolved)
239258
| Affine { solved = value_solved; unsolved = value_unsolved } ->
240259
(* Inlining affine expression: coeff_of_v * value *)
241260
let scaled_solved = List.map value_solved ~f:(fun (c, sd) -> (coeff_of_v * c, sd)) in
@@ -1490,15 +1509,9 @@ let rec row_to_labels env =
14901509

14911510
(** *** Projection inference *** *)
14921511

1493-
let fresh_proj =
1494-
let uid = ref 0 in
1495-
fun () ->
1496-
Int.incr uid;
1497-
!uid
1498-
14991512
let fresh_row_proj r =
15001513
let fresh_dim = function
1501-
| Dim { d; label; proj_id = _ } -> Dim { d; label; proj_id = Some (fresh_proj ()) }
1514+
| Dim { d; label; proj_id = _ } -> Dim { d; label; proj_id = Some (Proj_id.fresh ()) }
15021515
| Var _ as d -> d
15031516
| Affine _ as d -> d (* FIXME: Affine dimensions don't get fresh projections *)
15041517
in
@@ -1507,7 +1520,15 @@ let fresh_row_proj r =
15071520
(* let update_proj_classes pid1 pid2 proj_classes = Utils.union_add ~equal:Int.equal proj_classes
15081521
pid1 pid2 *)
15091522

1510-
type proj = Var of dim_var | Proj of { proj_id : proj_id; d : int } | Solved of Idx.axis_index
1523+
type proj =
1524+
| Var of dim_var
1525+
| Proj of { proj_id : proj_id; d : int }
1526+
| Solved of Idx.axis_index
1527+
| Affine of {
1528+
solved : (int * Idx.axis_index) list;
1529+
solving : (int * proj_id) list;
1530+
unsolved : (int * dim_var) list;
1531+
}
15111532
[@@deriving compare, equal, sexp]
15121533

15131534
type error_trace += Projection_mismatch of proj list
@@ -1517,14 +1538,14 @@ let sexp_of_error_trace = function
15171538
Sexp.List (Sexp.Atom "Projection_mismatch" :: List.map ps ~f:sexp_of_proj)
15181539
| error_trace -> sexp_of_error_trace error_trace
15191540

1520-
type proj_to_index = Idx.axis_index Map.M(Int).t [@@deriving sexp]
1521-
type proj_classes = int Map.M(Int).t [@@deriving sexp]
1541+
type proj_to_index = Idx.axis_index Map.M(Proj_id).t [@@deriving sexp]
1542+
type proj_classes = Proj_id.t Map.M(Proj_id).t [@@deriving sexp]
15221543

15231544
type proj_env = {
15241545
proj_to_index : proj_to_index;
15251546
proj_classes : proj_classes;
1526-
product_dim : int Map.M(Int).t;
1527-
non_product : Set.M(Int).t;
1547+
product_dim : int Map.M(Proj_id).t;
1548+
non_product : Set.M(Proj_id).t;
15281549
}
15291550
[@@deriving sexp]
15301551

@@ -1544,7 +1565,7 @@ let%debug4_sexp get_proj_equations (inequalities : constraint_ list) proj_axis_e
15441565
| d -> (
15451566
match subst_dim env d with
15461567
| Dim { proj_id = Some proj_id; d; label = _ } -> Proj { proj_id; d }
1547-
| Dim { proj_id = None; d; _ } -> Proj { proj_id = fresh_proj (); d }
1568+
| Dim { proj_id = None; d; _ } -> Proj { proj_id = Proj_id.fresh (); d }
15481569
| Var v when Map.mem proj_axis_env v -> Solved (Map.find_exn proj_axis_env v)
15491570
| Var v -> Var v
15501571
| Affine _ -> failwith "get_proj_equations: affine dimensions not supported in projections")
@@ -1591,9 +1612,43 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
15911612
let v_env = dim_hashtbl () in
15921613
let p_solved = ref [] in
15931614
let p_dims = ref [] in
1594-
let proj_classes = ref @@ Map.empty (module Int) in
1615+
let proj_classes = ref @@ Map.empty (module Proj_id) in
1616+
let _s_proj_one proj_id ~value ~in_ =
1617+
match in_ with
1618+
| Var _ -> in_
1619+
| Proj { proj_id = pid; d = _ } -> if Proj_id.equal pid proj_id then value else in_
1620+
| Solved _ -> in_
1621+
| Affine { solved; solving; unsolved } ->
1622+
(* Substitute proj_id in affine expression *)
1623+
let coeff_of_p =
1624+
List.fold solving ~init:0 ~f:(fun acc (c, pid) ->
1625+
if Proj_id.equal pid proj_id then acc + c else acc)
1626+
in
1627+
let new_solving =
1628+
List.filter_map solving ~f:(fun (coeff, pid) ->
1629+
if Proj_id.equal pid proj_id then None else Some (coeff, pid))
1630+
in
1631+
let new_solved, extra_solving =
1632+
if coeff_of_p = 0 then (solved, new_solving)
1633+
else
1634+
match value with
1635+
| Solved idx -> ((coeff_of_p, idx) :: solved, new_solving)
1636+
| Proj { proj_id = pid; d = _ } -> (solved, (coeff_of_p, pid) :: new_solving)
1637+
| Var _ -> (solved, new_solving) (* Can't substitute variable into affine *)
1638+
| Affine { solved = value_solved; solving = value_solving; unsolved = _ } ->
1639+
(* Inlining affine expression: coeff_of_p * value *)
1640+
let scaled_solved =
1641+
List.map value_solved ~f:(fun (c, idx) -> (coeff_of_p * c, idx))
1642+
in
1643+
let scaled_solving =
1644+
List.map value_solving ~f:(fun (c, pid) -> (coeff_of_p * c, pid))
1645+
in
1646+
(scaled_solved @ solved, scaled_solving @ new_solving)
1647+
in
1648+
Affine { solved = new_solved; solving = extra_solving; unsolved }
1649+
in
15951650
let rec loop = function
1596-
| Proj_eq (Proj { proj_id = p1; d }, Proj { proj_id = p2; _ }) when p1 = p2 ->
1651+
| Proj_eq (Proj { proj_id = p1; d }, Proj { proj_id = p2; _ }) when Proj_id.equal p1 p2 ->
15971652
p_dims := (p1, d) :: !p_dims
15981653
| Proj_eq (Var v1, Var v2) when equal_dim_var v1 v2 -> ()
15991654
| Proj_eq ((Proj { proj_id = p1; d = d1 } as proj1), (Proj { proj_id = p2; d = d2 } as proj2))
@@ -1604,9 +1659,32 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16041659
( "Conflicting dimensions for the same projection",
16051660
[ Projection_mismatch [ proj1; proj2 ] ] );
16061661
p_dims := (p1, d1) :: !p_dims;
1607-
proj_classes := Utils.union_add ~equal:Int.equal !proj_classes p1 p2
1662+
proj_classes := Utils.union_add ~equal:Proj_id.equal !proj_classes p1 p2
16081663
| Proj_eq (Proj p, Solved idx) | Proj_eq (Solved idx, Proj p) ->
16091664
p_solved := (p.proj_id, idx) :: !p_solved
1665+
| Proj_eq ((Proj { proj_id=_; _ } ), (Affine _ as _affine))
1666+
| Proj_eq ((Affine _ as _affine), (Proj { proj_id=_; _ })) ->
1667+
()
1668+
| Proj_eq (Solved idx, Affine affine) | Proj_eq (Affine affine, Solved idx) ->
1669+
(* For affine expressions with solved indices, we can't directly substitute *)
1670+
(* This case might require more sophisticated handling *)
1671+
raise
1672+
@@ Shape_error
1673+
( "Cannot unify solved index with affine projection",
1674+
[ Projection_mismatch [ Solved idx; Affine affine ] ] )
1675+
| Proj_eq (Var v, Affine affine) | Proj_eq (Affine affine, Var v) -> (
1676+
(* Handle variable to affine binding *)
1677+
match Hashtbl.find v_env v with
1678+
| None -> Hashtbl.add_exn v_env ~key:v ~data:(Affine affine)
1679+
| Some p2 -> loop (Proj_eq (Affine affine, p2)))
1680+
| Proj_eq (Affine affine1, Affine affine2) ->
1681+
(* For now, we can only unify identical affine expressions *)
1682+
if equal_proj (Affine affine1) (Affine affine2) then ()
1683+
else
1684+
raise
1685+
@@ Shape_error
1686+
( "Cannot unify different affine projections",
1687+
[ Projection_mismatch [ Affine affine1; Affine affine2 ] ] )
16101688
| Proj_eq (Solved idx1, Solved idx2) when Idx.equal_axis_index idx1 idx2 -> ()
16111689
| Proj_eq (Solved idx1, Solved idx2) ->
16121690
raise
@@ -1618,28 +1696,34 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16181696
| Some p2 -> loop (Proj_eq (p, p2)))
16191697
| Iterated (Solved _) -> ()
16201698
| Iterated (Proj { proj_id; d }) -> p_dims := (proj_id, d) :: !p_dims
1699+
| Iterated (Affine { solving; _ }) ->
1700+
(* For affine expressions, mark all constituent projections as iterated *)
1701+
List.iter solving ~f:(fun (_, proj_id) -> p_dims := (proj_id, 1) :: !p_dims)
16211702
| Iterated (Var v) -> (
16221703
match Hashtbl.find v_env v with
16231704
| None ->
16241705
let idx = Idx.(Iterator (get_symbol ())) in
16251706
Hashtbl.add_exn v_env ~key:v ~data:(Solved idx)
16261707
| Some (Var v2) -> loop (Iterated (Var v2))
16271708
| Some (Solved _) -> ()
1628-
| Some (Proj { proj_id; d }) -> p_dims := (proj_id, d) :: !p_dims)
1709+
| Some (Proj { proj_id; d }) -> p_dims := (proj_id, d) :: !p_dims
1710+
| Some (Affine { solving; _ }) ->
1711+
(* For affine expressions, mark all constituent projections as iterated *)
1712+
List.iter solving ~f:(fun (_, proj_id) -> p_dims := (proj_id, 1) :: !p_dims))
16291713
in
16301714
List.iter eqs ~f:loop;
1631-
let projs = ref @@ Map.empty (module Int) and non_product = ref @@ Set.empty (module Int) in
1715+
let projs = ref @@ Map.empty (module Proj_id) and non_product = ref @@ Set.empty (module Proj_id) in
16321716
List.iter !p_solved ~f:(fun (p, idx) ->
1633-
let repr, _ = Utils.union_find ~equal:Int.equal !proj_classes ~key:p ~rank:0 in
1717+
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
16341718
non_product := Set.add !non_product repr;
16351719
Utils.mref_add projs ~key:repr ~data:idx ~or_:(fun idx2 ->
16361720
if not @@ Idx.equal_axis_index idx idx2 then
16371721
raise
16381722
@@ Shape_error
16391723
("Multiple constraints on the same projection", [ Index_mismatch [ idx; idx2 ] ])));
1640-
let product_dim = ref @@ Map.empty (module Int) in
1724+
let product_dim = ref @@ Map.empty (module Proj_id) in
16411725
List.iter !p_dims ~f:(fun (p, d) ->
1642-
let repr, _ = Utils.union_find ~equal:Int.equal !proj_classes ~key:p ~rank:0 in
1726+
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
16431727
if Idx.iterated d && (not @@ Map.mem !projs repr) then
16441728
Utils.mref_add product_dim ~key:repr ~data:d ~or_:(fun d2 ->
16451729
if d <> d2 then
@@ -1650,7 +1734,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16501734
Projection_mismatch [ Proj { proj_id = p; d }; Proj { proj_id = p; d = d2 } ];
16511735
] )));
16521736
Map.iteri !product_dim ~f:(fun ~key:p ~data:_ ->
1653-
let repr, _ = Utils.union_find ~equal:Int.equal !proj_classes ~key:p ~rank:0 in
1737+
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
16541738
Utils.mref_add_missing projs repr ~f:(fun () -> Idx.(Iterator (get_symbol ()))));
16551739
{
16561740
proj_classes = !proj_classes;
@@ -1659,7 +1743,14 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16591743
non_product = !non_product;
16601744
}
16611745

1662-
let get_proj_index proj_env = function
1746+
let get_proj_index proj_env =
1747+
let unknown_projection proj_id d =
1748+
raise
1749+
@@ Shape_error
1750+
( "projection_of_solved_dims: unknown projection",
1751+
[ Projection_mismatch [ Proj { proj_id; d } ] ] )
1752+
in
1753+
function
16631754
| Dim { d; _ } when not @@ Idx.iterated d -> Idx.Fixed_idx 0
16641755
| Dim { proj_id = None; _ } -> assert false
16651756
| Var v as dim ->
@@ -1669,18 +1760,43 @@ let get_proj_index proj_env = function
16691760
^ Sexp.to_string_hum ([%sexp_of: dim_var] v),
16701761
[ Dim_mismatch [ dim ] ] )
16711762
| Dim { proj_id = Some proj_id; d; _ } -> (
1672-
let repr, _ = Utils.union_find ~equal:Int.equal proj_env.proj_classes ~key:proj_id ~rank:0 in
1763+
let repr, _ = Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:proj_id ~rank:0 in
16731764
match Map.find proj_env.proj_to_index repr with
16741765
| Some i -> i
1675-
| None ->
1676-
raise
1677-
@@ Shape_error
1678-
( "projection_of_solved_dims: unknown projection",
1679-
[ Projection_mismatch [ Proj { proj_id; d } ] ] ))
1680-
| Affine _ -> failwith "get_proj_index: affine dimensions not supported in projections"
1766+
| None -> unknown_projection proj_id d)
1767+
| Affine { solved; unsolved } ->
1768+
(* Handle unsolved variables - these should be resolved by now *)
1769+
if not (List.is_empty unsolved) then
1770+
raise
1771+
@@ Shape_error
1772+
( "Affine dimension has unresolved variables",
1773+
[ Dim_mismatch [ Affine { solved; unsolved } ] ] );
1774+
1775+
(* Process solved terms *)
1776+
let symbols = ref [] in
1777+
let offset = ref 0 in
1778+
1779+
List.iter solved ~f:(fun (coeff, solved_dim) ->
1780+
match solved_dim with
1781+
| { d; proj_id = Some proj_id; _ } ->
1782+
if Idx.iterated d then
1783+
let repr, _ =
1784+
Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:proj_id ~rank:0
1785+
in
1786+
match Map.find proj_env.proj_to_index repr with
1787+
| Some (Iterator symbol) -> symbols := (coeff, symbol) :: !symbols
1788+
| Some (Fixed_idx i) -> offset := !offset + (coeff * i)
1789+
| Some (Affine _) ->
1790+
(* Nested affine - would need recursive handling *)
1791+
raise @@ Shape_error ("Nested affine projections not supported", [])
1792+
| None -> unknown_projection proj_id d
1793+
else ()
1794+
| { proj_id = None; _ } -> assert false);
1795+
1796+
Idx.Affine { symbols = List.rev !symbols; offset = !offset }
16811797

16821798
let proj_repr proj_env p =
1683-
fst @@ Utils.union_find ~equal:Int.equal proj_env.proj_classes ~key:p ~rank:0
1799+
fst @@ Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:p ~rank:0
16841800

16851801
let get_product_proj proj_env dim =
16861802
match dim with

lib/row.mli

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,5 @@ val get_proj_equations :
139139

140140
val solve_proj_equations : proj_equation list -> proj_env
141141
val get_proj_index : proj_env -> dim -> Ir.Indexing.axis_index
142-
val get_product_proj : proj_env -> dim -> (int * int) option
143-
val proj_to_iterator : proj_env -> int -> Ir.Indexing.symbol
142+
val get_product_proj : proj_env -> dim -> (proj_id * int) option
143+
val proj_to_iterator : proj_env -> proj_id -> Ir.Indexing.symbol

lib/shape.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -729,8 +729,8 @@ let derive_projections (update_step : update_step) : Idx.projections =
729729
let rhs_dims = Array.of_list_map ~f:to_dims rhs in
730730
let all_dims : Row.dim list = List.concat_map ~f:dims_of @@ (lhs :: rhs) in
731731
(* Note: the ordering will affect performance of naive backends. *)
732-
let all_product_projs : (int * int) list =
733-
Utils.unique_keep_first ~equal:(fun (p, _) (q, _) -> p = q)
732+
let all_product_projs : (Row.proj_id * int) list =
733+
Utils.unique_keep_first ~equal:(fun (p, _) (q, _) -> Row.equal_proj_id p q)
734734
@@ List.filter_map all_dims ~f:(Row.get_product_proj proj_env)
735735
in
736736
let product_space : int array = Array.of_list_map all_product_projs ~f:snd in

lib/shape_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ The projection inference functions.
195195

196196
* `get_proj_equations inequalities proj_axis_env env` converts both equations and inequalitites to projection equations. For inequalities, it takes broadcasting into account, and equates a potentially-broadcasted dim-1 projection to `Fixed_idx 0`. `proj_axis_env` originates from the `Shape` module, holds projections from the slice operator and the einsum syntax.
197197
* `solve_proj_equations` unifies the projection equations, using union-find to maintain a representative for equal projections. Projections that already have an `axis_index` are `non_product` (not to be iterated over). The remaining projections have a `product_dim`, and get a fresh iterator.
198-
* `get_proj_index` gets an `axis_index` for a `dim` based on the representative of its `proj_id`; and `Fixed_idx 0` for dim-1.
198+
* `get_proj_index` gets an `axis_index` for a `dim` based on the representative of its `proj_id`; and `Fixed_idx 0` for dim=1.
199199

200200
## Deriving the constraints
201201

0 commit comments

Comments
 (0)