Skip to content

Commit 2556abb

Browse files
committed
dim.Affine.solved is now solved_dim option; fix padding handling via add_dims; fix proj equations for Affine
1 parent f18e6e4 commit 2556abb

File tree

4 files changed

+108
-55
lines changed

4 files changed

+108
-55
lines changed

lib/row.ml

Lines changed: 97 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,23 @@ module Proj_id = struct
4444
end
4545

4646
type proj_id = Proj_id.t [@@deriving equal, hash, compare, sexp]
47+
type proj_cmp = Proj_id.comparator_witness
48+
type proj_var_set = Set.M(Proj_id).t [@@deriving equal, sexp]
49+
type 'a proj_map = 'a Map.M(Proj_id).t [@@deriving equal, sexp]
4750

51+
let proj_var_set_empty = Set.empty (module Proj_id)
52+
let proj_map_empty = Map.empty (module Proj_id)
4853
let dim_var_set_empty = Set.empty (module Dim_var)
4954
let dim_map_empty = Map.empty (module Dim_var)
5055
let use_padding = ref false
5156

52-
type solved_dim = {
53-
d : int;
54-
mutable padding : int option; [@hash.ignore]
55-
label : string option;
56-
proj_id : proj_id option;
57-
}
57+
type solved_dim = { d : int; padding : int option; label : string option; proj_id : proj_id option }
5858
[@@deriving equal, hash, compare, sexp]
5959

6060
type dim =
6161
| Var of dim_var
6262
| Dim of solved_dim
63-
| Affine of { solved : (int * solved_dim) list; unsolved : (int * dim_var) list }
63+
| Affine of { solved : solved_dim option; unsolved : (int * dim_var) list }
6464
[@@deriving equal, hash, compare, sexp, variants]
6565

6666
let uid = ref 0
@@ -81,9 +81,9 @@ type print_style = Only_labels | Axis_size | Axis_number_and_size | Projection_a
8181
let solved_dim_to_string style { d; label; proj_id; padding } =
8282
match style with
8383
| Only_labels -> ( match label with None -> "_" | Some l -> l)
84-
| Axis_size | Axis_number_and_size ->
84+
| Axis_size | Axis_number_and_size -> (
8585
let label_prefix = match label with None -> "" | Some l -> l ^ "=" in
86-
(match (proj_id, padding) with
86+
match (proj_id, padding) with
8787
| None, None -> label_prefix ^ Int.to_string d
8888
| None, Some p -> label_prefix ^ [%string "%{d#Int}+%{p#Int}"]
8989
| Some _, None -> label_prefix ^ Int.to_string d
@@ -93,7 +93,7 @@ let solved_dim_to_string style { d; label; proj_id; padding } =
9393
let size_part = Int.to_string d in
9494
let padding_part = match padding with None -> "" | Some p -> "+" ^ Int.to_string p in
9595
let proj_part = match proj_id with None -> "" | Some pid -> "p" ^ Proj_id.to_string pid in
96-
let extra_parts =
96+
let extra_parts =
9797
match (proj_id, padding) with
9898
| None, None -> ""
9999
| None, Some _ -> padding_part
@@ -105,25 +105,33 @@ let solved_dim_to_string style { d; label; proj_id; padding } =
105105
let dim_to_string style = function
106106
| Dim { label = None; _ } when equal_print_style style Only_labels -> "_"
107107
| Dim { label = Some l; _ } when equal_print_style style Only_labels -> l
108-
| Dim { d; label = None; padding = None; proj_id = None } when equal_print_style style Axis_size -> Int.to_string d
109-
| Dim { d; label = Some l; padding = None; proj_id = None } when equal_print_style style Axis_size -> [%string "%{l}=%{d#Int}"]
110-
| Dim { d; label = None; padding = Some p; proj_id = None } when equal_print_style style Axis_size -> [%string "%{d#Int}+%{p#Int}"]
111-
| Dim { d; label = Some l; padding = Some p; proj_id = None } when equal_print_style style Axis_size -> [%string "%{l}=%{d#Int}+%{p#Int}"]
108+
| Dim { d; label = None; padding = None; proj_id = None } when equal_print_style style Axis_size
109+
->
110+
Int.to_string d
111+
| Dim { d; label = Some l; padding = None; proj_id = None } when equal_print_style style Axis_size
112+
->
113+
[%string "%{l}=%{d#Int}"]
114+
| Dim { d; label = None; padding = Some p; proj_id = None } when equal_print_style style Axis_size
115+
->
116+
[%string "%{d#Int}+%{p#Int}"]
117+
| Dim { d; label = Some l; padding = Some p; proj_id = None }
118+
when equal_print_style style Axis_size ->
119+
[%string "%{l}=%{d#Int}+%{p#Int}"]
112120
| Dim solved_dim -> solved_dim_to_string style solved_dim
113121
| Var { id; label = Some l } -> [%string "$%{id#Int}:%{l}"]
114122
| Var { id; label = None } -> "$" ^ Int.to_string id
115123
| Affine { solved; unsolved } -> (
116-
let solved_terms =
117-
List.map solved ~f:(fun (coeff, solved_dim) ->
118-
let base = solved_dim_to_string style solved_dim in
119-
if coeff = 1 then base else [%string "%{coeff#Int}*%{base}"])
124+
let solved_term =
125+
Option.to_list solved
126+
|> List.map ~f:(fun solved_dim -> solved_dim_to_string style solved_dim)
120127
in
121128
let unsolved_terms =
122129
List.map unsolved ~f:(fun (coeff, v) ->
123130
let label_part = match v.label with None -> "" | Some l -> ":" ^ l in
124-
if coeff = 1 then [%string "$%{v.id#Int}%{label_part}"] else [%string "%{coeff#Int}*$%{v.id#Int}%{label_part}"])
131+
if coeff = 1 then [%string "$%{v.id#Int}%{label_part}"]
132+
else [%string "%{coeff#Int}*$%{v.id#Int}%{label_part}"])
125133
in
126-
let all_terms = solved_terms @ unsolved_terms in
134+
let all_terms = solved_term @ unsolved_terms in
127135
match all_terms with
128136
| [] -> "0"
129137
| [ t ] -> t
@@ -264,6 +272,24 @@ let dim_to_int_exn = function
264272
| Var _ -> invalid_arg "dim_to_int: dim still unknown"
265273
| Affine _ -> invalid_arg "dim_to_int: affine dimension cannot be converted to single int"
266274

275+
let add_dims ~keep_proj_id:{ d = d1; padding = p1; label = l1; proj_id } ~coef
276+
{ d = d2; padding = p2; label = l2; proj_id = _ } =
277+
match (p1, p2) with
278+
| Some p1, Some p2 ->
279+
{
280+
d = d1 + (coef * d2);
281+
padding = Some (max p1 (coef * p2));
282+
label = Option.first_some l1 l2;
283+
proj_id;
284+
}
285+
| _ ->
286+
{
287+
d = d1 + (coef * d2);
288+
padding = Option.first_some p1 (Option.map ~f:(( * ) coef) p2);
289+
label = Option.first_some l1 l2;
290+
proj_id;
291+
}
292+
267293
let s_dim_one v ~value ~in_ =
268294
match in_ with
269295
| Var v2 when equal_dim_var v v2 -> value
@@ -281,15 +307,24 @@ let s_dim_one v ~value ~in_ =
281307
if coeff_of_v = 0 then (solved, new_unsolved)
282308
else
283309
match value with
284-
| Dim s -> ((coeff_of_v, s) :: solved, new_unsolved)
310+
| Dim s ->
311+
let existing = Option.value solved ~default:{ s with d = 0; padding = None } in
312+
let new_solved = add_dims ~keep_proj_id:existing ~coef:coeff_of_v s in
313+
(Some new_solved, new_unsolved)
285314
| Var v' -> (solved, (coeff_of_v, v') :: new_unsolved)
286315
| Affine { solved = value_solved; unsolved = value_unsolved } ->
287316
(* Inlining affine expression: coeff_of_v * value *)
288-
let scaled_solved = List.map value_solved ~f:(fun (c, sd) -> (coeff_of_v * c, sd)) in
317+
let new_solved =
318+
match value_solved with
319+
| None -> solved
320+
| Some vs ->
321+
let existing = Option.value solved ~default:{ vs with d = 0; padding = None } in
322+
Some (add_dims ~keep_proj_id:existing ~coef:coeff_of_v vs)
323+
in
289324
let scaled_unsolved =
290325
List.map value_unsolved ~f:(fun (c, v) -> (coeff_of_v * c, v))
291326
in
292-
(scaled_solved @ solved, scaled_unsolved @ new_unsolved)
327+
(new_solved, scaled_unsolved @ new_unsolved)
293328
in
294329
Affine { solved = new_solved; unsolved = extra_unsolved }
295330
| _ -> in_
@@ -674,7 +709,7 @@ let%debug5_sexp rec unify_dim ~stage (eq : dim * dim) (env : environment) :
674709
| Dim { d = d1; _ }, Dim { d = d2; _ } when d1 = d2 -> ([], env)
675710
| Var v1, Var v2 when equal_dim_var v1 v2 -> ([], env)
676711
| Affine _, Affine _ ->
677-
(* FIXME: For now, we can only unify identical affine expressions *)
712+
(* FIXME: NOT IMPLEMENTED YET *)
678713
if equal_dim dim1 dim2 then ([], env)
679714
else
680715
raise
@@ -908,8 +943,7 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
908943
| _, Dim { d = 1; _ } -> ([], env)
909944
| (Dim { d = 1; _ } as cur), _ -> ([ Dim_eq { d1 = subr; d2 = cur } ], env)
910945
| Affine _, _ | _, Affine _ ->
911-
(* FIXME: For affine dimensions in inequalities, we can't directly compare them *)
912-
(* This would need more sophisticated constraint solving *)
946+
(* FIXME: NOT IMPLEMENTED YET *)
913947
([], env)
914948
| Var cur_v, Var subr_v -> (
915949
match (Map.find env.dim_env cur_v, Map.find env.dim_env subr_v) with
@@ -1039,7 +1073,7 @@ let%debug5_sexp solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env
10391073
Dim_mismatch [ lub2; cur; subr ] ] ) *)
10401074
| Var _, _ | _, Var _ -> assert false
10411075
| Affine _, _ | _, Affine _ ->
1042-
(* FIXME: Can't compute LUB with affine dimensions *)
1076+
(* FIXME: NOT IMPLEMENTED YET *)
10431077
(lub2, [])
10441078
in
10451079
let from_constr, constr2 = apply_dim_constraint ~source:Cur ~stage cur constr2 env in
@@ -1319,7 +1353,7 @@ let%debug5_sexp close_dim_terminal ~(stage : stage) (env : environment) (dim : d
13191353
| _ when not (is_stage4_up stage) -> [ Terminal_dim dim ]
13201354
| _ -> [])
13211355
| Affine _ ->
1322-
(* FIXME: For affine dimensions, we can't generate simple terminal constraints *)
1356+
(* FIXME: NOT IMPLEMENTED YET *)
13231357
[]
13241358

13251359
let last_dim_is dims d2 = match List.last dims with Some (Dim { d; _ }) -> d = d2 | _ -> false
@@ -1465,7 +1499,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
14651499
| Some (Bounds_dim bounds) -> Bounds_dim { bounds with constr }
14661500
| None -> Bounds_dim { constr; lub = None; cur = []; subr = [] });
14671501
}
1468-
| _, Affine _ -> env (* FIXME: Can't store constraints on affine dimensions *)
1502+
| _, Affine _ -> env (* FIXME: NOT IMPLEMENTED YET *)
14691503
in
14701504
(extras @ ineqs, env)
14711505
| Row_constr { r; constr } ->
@@ -1542,7 +1576,12 @@ let fresh_row_proj r =
15421576
| Dim { d; padding; label; proj_id = _ } ->
15431577
Dim { d; padding; label; proj_id = Some (Proj_id.fresh ()) }
15441578
| Var _ as d -> d
1545-
| Affine _ as d -> d (* FIXME: Affine dimensions don't get fresh projections *)
1579+
| Affine { solved; unsolved } as d ->
1580+
if List.is_empty unsolved then
1581+
Dim { (Option.value_exn ~here:[%here] solved) with proj_id = Some (Proj_id.fresh ()) }
1582+
else (
1583+
assert (Option.is_none solved);
1584+
d)
15461585
in
15471586
{ r with dims = List.map r.dims ~f:fresh_dim }
15481587

@@ -1591,13 +1630,17 @@ let%debug4_sexp get_proj_equations (inequalities : constraint_ list) proj_axis_e
15911630
let to_proj : dim -> proj = function
15921631
| Var v when Map.mem proj_axis_env v -> Solved (Map.find_exn proj_axis_env v)
15931632
| Dim ({ proj_id = Some proj_id; _ } as solved_dim) -> Proj (proj_id, solved_dim)
1633+
| Affine { solved = None; unsolved } as d -> Affine { solved = []; solving = []; unsolved }
15941634
| d -> (
15951635
match subst_dim env d with
15961636
| Dim ({ proj_id = Some proj_id; _ } as solved_dim) -> Proj (proj_id, solved_dim)
15971637
| Dim s -> Proj (Proj_id.fresh (), s)
15981638
| Var v when Map.mem proj_axis_env v -> Solved (Map.find_exn proj_axis_env v)
15991639
| Var v -> Var v
1600-
| Affine _ -> failwith "get_proj_equations: affine dimensions not supported in projections")
1640+
| Affine { solved = Some d; unsolved } ->
1641+
assert (List.is_empty unsolved);
1642+
Proj (Option.value ~default:(Proj_id.fresh ()) d.proj_id, d)
1643+
| Affine { solved = None; unsolved = _ } -> assert false (* handled above *))
16011644
in
16021645
let rec expand_dims = function
16031646
| { dims; bcast = Row_var { v; beg_dims }; _ } when Map.mem env.row_env v -> (
@@ -1691,8 +1734,12 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
16911734
| Proj_eq (Proj (p, _), Solved idx) | Proj_eq (Solved idx, Proj (p, _)) ->
16921735
p_solved := (p, idx) :: !p_solved
16931736
| Proj_eq (Proj (_, _), (Affine _ as _affine)) | Proj_eq ((Affine _ as _affine), Proj (_, _)) ->
1694-
(* FIXME: can't store constraints on affine dimensions *)
1737+
(* FIXME: NOT IMPLEMENTED YET *)
16951738
()
1739+
| Proj_eq (Solved (Idx.Affine { symbols; offset }), Affine { solved; solving; unsolved })
1740+
| Proj_eq (Affine { solved; solving; unsolved }, Solved (Idx.Affine { symbols; offset })) ->
1741+
(* FIXME: NOT IMPLEMENTED YET *)
1742+
ignore (symbols, offset, solved, solving, unsolved)
16961743
| Proj_eq (Solved idx, Affine affine) | Proj_eq (Affine affine, Solved idx) ->
16971744
(* For affine expressions with solved indices, we can't directly substitute *)
16981745
(* This case might require more sophisticated handling *)
@@ -1755,6 +1802,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
17551802
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
17561803
if Idx.iterated d && (not @@ Map.mem !projs repr) then
17571804
Utils.mref_add product_dim ~key:repr ~data:d ~or_:(fun d2 ->
1805+
(* TODO: consider updating padding *)
17581806
if d <> d2 then
17591807
raise
17601808
@@ Shape_error
@@ -1776,8 +1824,7 @@ let get_proj_index proj_env =
17761824
let unknown_projection proj_id d =
17771825
raise
17781826
@@ Shape_error
1779-
( [%string "projection_of_solved_dims: unknown projection: %{proj_id#Proj_id} %{d#Int}"],
1780-
[] )
1827+
([%string "projection_of_solved_dims: unknown projection: %{proj_id#Proj_id} %{d#Int}"], [])
17811828
in
17821829
function
17831830
| Dim { d; _ } when not @@ Idx.iterated d -> Idx.Fixed_idx 0
@@ -1807,22 +1854,21 @@ let get_proj_index proj_env =
18071854
let symbols = ref [] in
18081855
let offset = ref 0 in
18091856

1810-
List.iter solved ~f:(fun (coeff, solved_dim) ->
1811-
match solved_dim with
1812-
| { d; proj_id = Some proj_id; _ } ->
1813-
if Idx.iterated d then
1814-
let repr, _ =
1815-
Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:proj_id ~rank:0
1816-
in
1817-
match Map.find proj_env.proj_to_index repr with
1818-
| Some (Iterator symbol) -> symbols := (coeff, symbol) :: !symbols
1819-
| Some (Fixed_idx i) -> offset := !offset + (coeff * i)
1820-
| Some (Affine _) ->
1821-
(* Nested affine - would need recursive handling *)
1822-
raise @@ Shape_error ("Nested affine projections not supported", [])
1823-
| None -> unknown_projection proj_id d
1824-
else ()
1825-
| { proj_id = None; _ } -> assert false);
1857+
Option.iter solved ~f:(function
1858+
| { d; proj_id = Some proj_id; _ } ->
1859+
if Idx.iterated d then
1860+
let repr, _ =
1861+
Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:proj_id ~rank:0
1862+
in
1863+
match Map.find proj_env.proj_to_index repr with
1864+
| Some (Iterator symbol) -> symbols := (1, symbol) :: !symbols
1865+
| Some (Fixed_idx i) -> offset := !offset + i
1866+
| Some (Affine _) ->
1867+
(* Nested affine - would need recursive handling *)
1868+
raise @@ Shape_error ("Nested affine projections not supported", [])
1869+
| None -> unknown_projection proj_id d
1870+
else ()
1871+
| { proj_id = None; _ } -> assert false);
18261872

18271873
Idx.Affine { symbols = List.rev !symbols; offset = !offset }
18281874

@@ -1844,7 +1890,7 @@ let get_product_proj proj_env dim =
18441890
( "projection_of_solved_dims: still not fully inferred for variable "
18451891
^ Sexp.to_string_hum ([%sexp_of: dim_var] v),
18461892
[ Dim_mismatch [ dim ] ] )
1847-
| Affine _ -> None (* FIXME: Affine dimensions don't participate in product projections *)
1893+
| Affine _ -> None
18481894

18491895
let proj_to_iterator proj_env p =
18501896
match Map.find_exn proj_env.proj_to_index (proj_repr proj_env p) with

lib/row.mli

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,20 @@ type proj_id [@@deriving equal, hash, compare, sexp]
88
type dim_cmp
99
type dim_var_set = (dim_var, dim_cmp) Base.Set.t [@@deriving equal, sexp]
1010
type 'a dim_map = (dim_var, 'a, dim_cmp) Base.Map.t [@@deriving equal, sexp]
11+
type proj_cmp
12+
type proj_var_set = (proj_id, proj_cmp) Base.Set.t [@@deriving equal, sexp]
13+
type 'a proj_map = (proj_id, 'a, proj_cmp) Base.Map.t [@@deriving equal, sexp]
1114

1215
val get_var : ?label:string -> unit -> dim_var
1316
val dim_var_set_empty : dim_var_set
1417
val dim_map_empty : 'a dim_map
18+
val proj_var_set_empty : proj_var_set
19+
val proj_map_empty : 'a proj_map
1520
val use_padding : bool ref
1621

1722
type solved_dim = {
1823
d : int;
19-
mutable padding : int option; [@hash.ignore]
24+
padding : int option;
2025
(** The maximal required total (left + right) padding for this axis. *)
2126
label : string option;
2227
proj_id : proj_id option;
@@ -27,7 +32,7 @@ type solved_dim = {
2732
type dim =
2833
| Var of dim_var
2934
| Dim of solved_dim
30-
| Affine of { solved : (int * solved_dim) list; unsolved : (int * dim_var) list }
35+
| Affine of { solved : solved_dim option; unsolved : (int * dim_var) list }
3136
(** The offset is implicit, automatically derived. Most frequent use case: convolutions. If
3237
[!use_padding] is [true], the offset is the dimensionality-preserving left padding,
3338
otherwise it is 0. NOTE: negative strides are not supported (negative coefficients are

lib/shape.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,8 @@ let derive_projections (update_step : update_step) : Idx.projections =
714714
let unsolved, local_env = Row.solve_inequalities ~stage:Stage6 unsolved local_env in
715715
let unsolved, local_env = Row.solve_inequalities ~stage:Stage7 unsolved local_env in
716716
assert (List.is_empty unsolved);
717+
(* Important: ineqs must not be substituted / solved before getting proj_equations, because
718+
get_inequalities provides indexing information that is lost after substitution. *)
717719
let proj_eqs : Row.proj_equation list = Row.get_proj_equations ineqs proj_axis_env local_env in
718720
let proj_env : Row.proj_env = Row.solve_proj_equations proj_eqs in
719721
let dims_of (sh : t) = sh.batch.dims @ sh.output.dims @ sh.input.dims in

lib/shape.mli

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
- if there is a comma [','] anywhere in the initial text, the multicharacter version is used,
99
- otherwise the single character version is used.
1010
- Currently, the only non-whitespace, non-alphanumeric characters that make sense / are allowed
11-
in a spec are: ['>', '|', '-', ',', '=', ';'].
11+
in a spec are: ['>', '|', '-', ',', '=', ';', '+', '*', '_'].
1212
- identifier: single alphanum character or '_' in single-char mode, a sequence of alphanum
1313
characters or '_' otherwise (whitespace not allowed).
1414
- separators: a sequence of commas and whitespaces.
@@ -51,7 +51,7 @@ type t = {
5151
mutable batch : Row.t;
5252
mutable input : Row.t;
5353
mutable output : Row.t;
54-
id : int; (** A node that has the same shape as this shape. *)
54+
id : int; (** A node that has the same shape as this shape, or [-1]. *)
5555
debug_name : string;
5656
}
5757
[@@deriving equal, sexp]

0 commit comments

Comments
 (0)