Skip to content

Commit 4a29157

Browse files
committed
Convolutions clarification and a bit of formatting
1 parent f5e5d08 commit 4a29157

File tree

2 files changed

+103
-86
lines changed

2 files changed

+103
-86
lines changed

lib/row.ml

Lines changed: 97 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,7 +1588,12 @@ let%debug5_sexp eliminate_variables (env : environment) ({ dims; bcast; id } as
15881588
| Some (Bounds_row { constr = Total_elems _; _ }) -> assert false
15891589
| _ -> elim_var :: elim_dims)
15901590

1591-
let empty_env = { dim_env = Map.empty (module Dim_var); row_env = Map.empty (module Row_var) }
1591+
let empty_env =
1592+
{
1593+
dim_env = Map.empty (module Dim_var);
1594+
row_env = Map.empty (module Row_var);
1595+
(* padding_env = Map.empty (module Dim_var); *)
1596+
}
15921597

15931598
let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (env : environment) :
15941599
constraint_ list * environment =
@@ -1827,19 +1832,19 @@ let get_proj_index proj_env =
18271832
| Some i -> i
18281833
| None -> unknown_projection proj_id d)
18291834
| Solved idx -> idx
1830-
| Conv_input { stride; output; solved_kernel; unsolved_kernel } ->
1835+
| Conv_input { stride; output; solved_kernel; unsolved_kernel } -> (
18311836
let output_idx = loop output in
18321837
let symbols = ref [] in
18331838
let offset = ref 0 in
1834-
1839+
18351840
(* Expand output index - multiply by stride *)
18361841
(match output_idx with
18371842
| Idx.Fixed_idx i -> offset := !offset + (stride * i)
18381843
| Idx.Iterator s -> symbols := (stride, s) :: !symbols
18391844
| Idx.Affine { symbols = output_syms; offset = output_offset } ->
18401845
symbols := List.map output_syms ~f:(fun (c, s) -> (stride * c, s)) @ !symbols;
18411846
offset := !offset + (stride * output_offset));
1842-
1847+
18431848
(* Process solved kernel terms *)
18441849
List.iter solved_kernel ~f:(fun (dilation, idx) ->
18451850
match idx with
@@ -1848,7 +1853,7 @@ let get_proj_index proj_env =
18481853
| Idx.Affine { symbols = kernel_syms; offset = kernel_offset } ->
18491854
symbols := List.map kernel_syms ~f:(fun (c, s) -> (dilation * c, s)) @ !symbols;
18501855
offset := !offset + (dilation * kernel_offset));
1851-
1856+
18521857
(* Process unsolved kernel terms *)
18531858
List.iter unsolved_kernel ~f:(fun (dilation, v) ->
18541859
let idx = loop (Var v) in
@@ -1858,41 +1863,41 @@ let get_proj_index proj_env =
18581863
| Idx.Affine { symbols = kernel_syms; offset = kernel_offset } ->
18591864
symbols := List.map kernel_syms ~f:(fun (c, s) -> (dilation * c, s)) @ !symbols;
18601865
offset := !offset + (dilation * kernel_offset));
1861-
1866+
18621867
(* Add padding if use_padding is true *)
1863-
let offset =
1868+
let offset =
18641869
if !use_padding then
18651870
match output with
18661871
| Proj (proj_id, _) ->
1867-
let repr, _ =
1872+
let repr, _ =
18681873
Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:proj_id ~rank:0
18691874
in
1870-
let total_padding =
1875+
let total_padding =
18711876
Option.value (Map.find proj_env.proj_to_padding repr) ~default:0
18721877
in
1873-
(* Left padding with right tilt when padding is even *)
1874-
let left_padding = (total_padding + 1) / 2 in
1875-
!offset - left_padding
1878+
(* Left padding smaller than right when split needed *)
1879+
let right_padding = (total_padding + 1) / 2 in
1880+
!offset + right_padding - total_padding
18761881
| _ -> !offset
18771882
else !offset
18781883
in
1879-
1884+
18801885
(* Combine and simplify symbols *)
1881-
let symbols =
1886+
let symbols =
18821887
!symbols
18831888
|> List.filter ~f:(fun (c, _) -> c <> 0)
18841889
|> List.sort ~compare:(fun (_, s1) (_, s2) -> Idx.compare_symbol s1 s2)
18851890
|> List.group ~break:(fun (_, s1) (_, s2) -> not (Idx.equal_symbol s1 s2))
18861891
|> List.map ~f:(fun group ->
1887-
let s = snd (List.hd_exn group) in
1888-
let coeff = List.sum (module Int) group ~f:fst in
1889-
(coeff, s))
1892+
let s = snd (List.hd_exn group) in
1893+
let coeff = List.sum (module Int) group ~f:fst in
1894+
(coeff, s))
18901895
|> List.filter ~f:(fun (c, _) -> c <> 0)
18911896
in
1892-
1893-
(match symbols with
1897+
1898+
match symbols with
18941899
| [] -> Idx.Fixed_idx offset
1895-
| [(1, s)] when offset = 0 -> Idx.Iterator s
1900+
| [ (1, s) ] when offset = 0 -> Idx.Iterator s
18961901
| _ -> Idx.Affine { symbols; offset })
18971902
| Var v when Hashtbl.mem proj_env.v_env v -> loop (Hashtbl.find_exn proj_env.v_env v)
18981903
| Var v ->
@@ -1931,34 +1936,37 @@ let get_dim_index proj_env =
19311936
| Dim s -> Proj (Proj_id.fresh (), s)
19321937
| Conv_input { stride; output; solved_kernel; unsolved_kernel } ->
19331938
let output_proj = dim_to_proj output in
1934-
let solved_kernel_proj =
1939+
let solved_kernel_proj =
19351940
match solved_kernel with
1936-
| Some sd ->
1941+
| Some sd ->
19371942
(* Convert solved_dim to axis_index *)
1938-
let idx =
1943+
let idx =
19391944
if Idx.iterated sd.d then
19401945
match sd.proj_id with
1941-
| Some proj_id ->
1942-
let repr, _ =
1943-
Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:proj_id ~rank:0
1946+
| Some proj_id -> (
1947+
let repr, _ =
1948+
Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes
1949+
~key:proj_id ~rank:0
19441950
in
1945-
(match Map.find proj_env.proj_to_index repr with
1951+
match Map.find proj_env.proj_to_index repr with
19461952
| Some idx -> idx
19471953
| None -> Idx.Fixed_idx 0)
19481954
| None -> Idx.Fixed_idx 0
19491955
else Idx.Fixed_idx 0
19501956
in
1951-
[(1, idx)]
1957+
[ (1, idx) ]
19521958
| None -> []
19531959
in
1954-
Conv_input {
1955-
stride;
1956-
output = output_proj;
1957-
solved_kernel = solved_kernel_proj;
1958-
unsolved_kernel
1959-
}
1960+
Conv_input
1961+
{
1962+
stride;
1963+
output = output_proj;
1964+
solved_kernel = solved_kernel_proj;
1965+
unsolved_kernel;
1966+
}
19601967
in
1961-
get_proj_index proj_env (dim_to_proj (Conv_input { stride; output; solved_kernel; unsolved_kernel }))
1968+
get_proj_index proj_env
1969+
(dim_to_proj (Conv_input { stride; output; solved_kernel; unsolved_kernel }))
19621970
in
19631971
loop
19641972

@@ -2051,99 +2059,102 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list) : proj_env =
20512059

20522060
(* Process p_conv_input to populate projs and compute padding *)
20532061
let proj_to_padding = ref @@ Map.empty (module Proj_id) in
2054-
2062+
20552063
(* Helper to compute padding from Conv_input projection *)
20562064
let compute_padding_from_proj = function
20572065
| Conv_input { unsolved_kernel; _ } ->
20582066
(* Padding is the maximum dilation factor among kernel terms *)
20592067
List.fold unsolved_kernel ~init:0 ~f:(fun acc (dilation, _) -> Int.max acc dilation)
20602068
| _ -> 0
20612069
in
2062-
2070+
20632071
(* Process postponed Conv_input equations *)
20642072
List.iter !p_conv_input ~f:(fun (p, conv_input) ->
20652073
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
2066-
2074+
20672075
(* Substitute variables in conv_input to get resolved projection *)
20682076
let rec substitute_vars_in_proj = function
20692077
| Conv_input { stride; output; solved_kernel; unsolved_kernel } ->
20702078
let output' = substitute_vars_in_proj output in
20712079
Conv_input { stride; output = output'; solved_kernel; unsolved_kernel }
20722080
| Var v as proj -> (
2073-
match Hashtbl.find v_env v with
2074-
| Some p -> substitute_vars_in_proj p
2075-
| None -> proj)
2081+
match Hashtbl.find v_env v with Some p -> substitute_vars_in_proj p | None -> proj)
20762082
| proj -> proj
20772083
in
2078-
2084+
20792085
let resolved_conv_input = substitute_vars_in_proj conv_input in
2080-
2086+
20812087
(* Compute padding if use_padding is true *)
2082-
if !use_padding then (
2083-
let padding = compute_padding_from_proj resolved_conv_input in
2084-
if padding > 0 then
2085-
Utils.mref_add proj_to_padding ~key:repr ~data:padding ~or_:(fun existing ->
2086-
ignore (Int.max existing padding)));
2087-
2088+
(if !use_padding then
2089+
let padding = compute_padding_from_proj resolved_conv_input in
2090+
if padding > 0 then
2091+
Utils.mref_add proj_to_padding ~key:repr ~data:padding ~or_:(fun existing ->
2092+
ignore (Int.max existing padding)));
2093+
20882094
(* Try to get index for Conv_input - this creates a temporary proj_env *)
2089-
let temp_proj_env = {
2090-
v_env;
2091-
proj_classes = !proj_classes;
2092-
proj_to_index = !projs;
2093-
proj_to_padding = !proj_to_padding;
2094-
product_dim = !product_dim;
2095-
non_product = !non_product;
2096-
} in
2097-
2095+
let temp_proj_env =
2096+
{
2097+
v_env;
2098+
proj_classes = !proj_classes;
2099+
proj_to_index = !projs;
2100+
proj_to_padding = !proj_to_padding;
2101+
product_dim = !product_dim;
2102+
non_product = !non_product;
2103+
}
2104+
in
2105+
20982106
try
20992107
let idx = get_proj_index temp_proj_env resolved_conv_input in
21002108
Utils.mref_add projs ~key:repr ~data:idx ~or_:(fun idx2 ->
21012109
if not @@ Idx.equal_axis_index idx idx2 then
21022110
raise
21032111
@@ Shape_error
2104-
("Multiple constraints on the same Conv_input projection", [ Index_mismatch [ idx; idx2 ] ]))
2105-
with _ -> () (* Ignore errors for now, will be caught later if needed *)
2106-
);
2107-
2112+
( "Multiple constraints on the same Conv_input projection",
2113+
[ Index_mismatch [ idx; idx2 ] ] ))
2114+
with _ -> () (* Ignore errors for now, will be caught later if needed *));
2115+
21082116
(* Verify postponed equations *)
21092117
List.iter !verify_when_solved1 ~f:(fun (idx, conv_input) ->
2110-
let temp_proj_env = {
2111-
v_env;
2112-
proj_classes = !proj_classes;
2113-
proj_to_index = !projs;
2114-
proj_to_padding = !proj_to_padding;
2115-
product_dim = !product_dim;
2116-
non_product = !non_product;
2117-
} in
2118-
2118+
let temp_proj_env =
2119+
{
2120+
v_env;
2121+
proj_classes = !proj_classes;
2122+
proj_to_index = !projs;
2123+
proj_to_padding = !proj_to_padding;
2124+
product_dim = !product_dim;
2125+
non_product = !non_product;
2126+
}
2127+
in
2128+
21192129
try
21202130
let conv_idx = get_proj_index temp_proj_env conv_input in
21212131
if not @@ Idx.equal_axis_index idx conv_idx then
21222132
raise
21232133
@@ Shape_error
2124-
("Cannot unify index with Conv_input projection", [ Index_mismatch [ idx; conv_idx ] ])
2125-
with _ -> () (* Ignore errors for now *)
2126-
);
2127-
2134+
( "Cannot unify index with Conv_input projection",
2135+
[ Index_mismatch [ idx; conv_idx ] ] )
2136+
with _ -> () (* Ignore errors for now *));
2137+
21282138
List.iter !verify_when_solved2 ~f:(fun (conv_input1, conv_input2) ->
2129-
let temp_proj_env = {
2130-
v_env;
2131-
proj_classes = !proj_classes;
2132-
proj_to_index = !projs;
2133-
proj_to_padding = !proj_to_padding;
2134-
product_dim = !product_dim;
2135-
non_product = !non_product;
2136-
} in
2137-
2139+
let temp_proj_env =
2140+
{
2141+
v_env;
2142+
proj_classes = !proj_classes;
2143+
proj_to_index = !projs;
2144+
proj_to_padding = !proj_to_padding;
2145+
product_dim = !product_dim;
2146+
non_product = !non_product;
2147+
}
2148+
in
2149+
21382150
try
21392151
let idx1 = get_proj_index temp_proj_env conv_input1 in
21402152
let idx2 = get_proj_index temp_proj_env conv_input2 in
21412153
if not @@ Idx.equal_axis_index idx1 idx2 then
21422154
raise
21432155
@@ Shape_error
21442156
("Cannot unify two Conv_input projections", [ Index_mismatch [ idx1; idx2 ] ])
2145-
with _ -> () (* Ignore errors for now *)
2146-
);
2157+
with _ -> () (* Ignore errors for now *));
21472158

21482159
{
21492160
v_env;

lib/shape_inference.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ When `use_padding` is true, the offset is chosen to preserve dimensionality (i.e
6565

6666
The shape and projection inference handles `Conv_input` terms differently depending on `use_padding`. If `use_padding` is false, the impact of convolution kernels is incorporated additively during shape inference, and there's nothing more to do during projection inference. If `use_padding` is true, convolution kernels don't contribute during shape inference, and padding is computed during projection inference, keyed by `proj_id`. Padding is the maximal size (width) of dilated kernels as encountered in `Dim` - `Conv_input` constraints and is propagated in either direction, although in practice for CNNs `Conv_input` should only appear as the `subr` of a `Dim_ineq`.
6767

68+
Shape inference does not maintain padding for the axes of individual tensor nodes; these padding values are computed and updated during projection inference.
69+
6870
### Inference strategy
6971

7072
The actual shape inference combines row polymorphism with (nominal) subtyping, as known in the type inference literature. The subtyping stems merely from the fact that a dimension-1 axis can be used in the context of any dimension due to per-axis broadcasting. Row polymorphism stems from broadcasting to more axes: for example, when unifying an unknown (shape) row with a known one, we cannot assume that the unknown row will have just the axes of the known one, because maybe the known row is meant to be broadcasted here to more axes. The combination of row polymorphism with nominal subtyping means that the constraints we are solving are inequalities, both inequalities between rows (the `Row.t` type, i.e. the `row` type above), and between axes/dimensions (the `Row.dim` type). We maintain the inequality ordering between variables in the environment to compute the transitive closure during simplification. We also maintain a least upper bound on the solution.
@@ -210,6 +212,10 @@ The projection inference functions.
210212
* `solve_proj_equations` unifies the projection equations, using union-find to maintain a representative for equal projections. Projections that already have an `axis_index` are `non_product` (not to be iterated over). The remaining projections have a `product_dim`, and get a fresh iterator.
211213
* `get_dim_index` gets an `axis_index` for a `dim` based on the representative of its `proj_id`; and `Fixed_idx 0` for dim=1.
212214

215+
### Convolutions
216+
217+
There is an important and intentional disconnect in how `dims` are interpreted. In the `arrayjit` part of the project (tensor nodes, `Ndarray` buffers, code generation), dimension sizes include padding; on the other hand, shape types, shape inference and tensors exclude padding from the dimension sizes. This creates a tension: once the delayed computations of padding, projections and dims (dimension sizes) are forced for a particular node, the padding can no longer be updated (the underlying `Ndarray` buffer might already be created). Since during inference we update the padding incrementally, without variables standing in for insufficient information, this unfortunately makes the distinction between during-inference and post-inference padding of a tensor node observable.
218+
213219
## Deriving the constraints
214220

215221
Other important functions in the `Shape` module.

0 commit comments

Comments
 (0)