Skip to content

Commit f2e0756

Browse files
committed
Yay, a better design for convolution shape and projection inference
TODO: don't pass all pre-existing paddings as resolved_padding; some can still be updated (depending on whether the tensor node's fields have been forced).
1 parent 50d15cf commit f2e0756

File tree

6 files changed

+268
-415
lines changed

6 files changed

+268
-415
lines changed

lib/row.ml

Lines changed: 204 additions & 367 deletions
Large diffs are not rendered by default.

lib/row.mli

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
open Base
44

5+
type axis_padding = Ir.Ndarray.axis_padding [@@deriving equal, sexp]
56
type kind = [ `Batch | `Input | `Output ] [@@deriving equal, compare, sexp, hash, variants]
67
type dim_var [@@deriving equal, hash, compare, sexp]
78
type proj_id [@@deriving equal, hash, compare, sexp]
@@ -26,15 +27,9 @@ type solved_dim = { d : int; label : string option; proj_id : proj_id option }
2627
type dim =
2728
| Var of dim_var
2829
| Dim of solved_dim
29-
| Conv_input of {
30-
stride : int;
31-
output : dim;
32-
solved_kernel : solved_dim option;
33-
unsolved_kernel : (int * dim_var) list;
34-
}
35-
(** The offset is implicit, automatically derived. Most frequent use case: convolutions. If
36-
[!use_padding] is [true], the offset is the dimensionality-preserving left padding,
37-
otherwise it is 0. *)
30+
| Conv_input of { stride : int; output : dim; dilation : int; kernel : dim }
31+
(** The offset is implicit, automatically derived. If [!use_padding] is [true], the offset is
32+
the left part of the dimensionality-preserving symmetric padding; otherwise it is 0. *)
3833
[@@deriving equal, hash, compare, sexp, variants]
3934

4035
val get_dim : d:int -> ?label:string -> unit -> dim
@@ -151,7 +146,7 @@ type proj_equation =
151146
val get_proj_equations :
152147
constraint_ list -> Ir.Indexing.axis_index dim_map -> environment -> proj_equation list
153148

154-
val solve_proj_equations : proj_equation list -> proj_env
149+
val solve_proj_equations : proj_equation list -> resolved_padding:(proj_id, axis_padding) List.Assoc.t -> proj_env
155150
val get_proj_index : proj_env -> proj -> Ir.Indexing.axis_index
156151
val get_dim_index : proj_env -> dim -> Ir.Indexing.axis_index
157152
val get_product_proj : proj_env -> dim -> (proj_id * int) option

lib/shape.ml

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,15 @@ type parsed_axis_labels = {
6060

6161
let axis_labels parsed = parsed.labels
6262

63+
type padding = Row.axis_padding array option [@@deriving sexp, equal]
64+
6365
type t = {
6466
mutable batch : Row.t;
6567
mutable input : Row.t;
6668
mutable output : Row.t;
69+
mutable batch_padding : padding;
70+
mutable input_padding : padding;
71+
mutable output_padding : padding;
6772
id : int; (** A node that has the same shape as this shape. *)
6873
debug_name : string;
6974
}
@@ -676,13 +681,24 @@ let () =
676681
(** *** Projection inference *** *)
677682

678683
let fresh_proj_ids update =
684+
let resolved_padding = ref [] in
685+
let fetch_padding row row_padding =
686+
Option.iter row_padding ~f:(fun padding ->
687+
Array.iter2_exn (Array.of_list row.Row.dims) padding ~f:(fun d p ->
688+
match d with
689+
| Row.Dim { proj_id = Some proj_id; _ } -> resolved_padding := (proj_id, p) :: !resolved_padding
690+
| _ -> ()))
691+
in
679692
let fresh_shape (sh : t) =
680693
sh.batch <- Row.fresh_row_proj sh.batch;
681694
sh.input <- Row.fresh_row_proj sh.input;
682-
sh.output <- Row.fresh_row_proj sh.output
695+
sh.output <- Row.fresh_row_proj sh.output;
696+
fetch_padding sh.batch sh.batch_padding;
697+
fetch_padding sh.input sh.input_padding;
698+
fetch_padding sh.output sh.output_padding
683699
in
684700
fresh_shape update.shape;
685-
match update.logic with
701+
(match update.logic with
686702
| Terminal _ -> ()
687703
| Transpose (_, sh) -> fresh_shape sh
688704
| Broadcast (_, sh1, sh2) ->
@@ -691,13 +707,14 @@ let fresh_proj_ids update =
691707
| Broadcast_tern (_, sh1, sh2, sh3) ->
692708
fresh_shape sh1;
693709
fresh_shape sh2;
694-
fresh_shape sh3
710+
fresh_shape sh3);
711+
!resolved_padding
695712

696713
(** Computes the indexing into subtensors given the shape information of a tensor.
697714
[derive_projections] should only be invoked when the shapes are fully inferred already! *)
698715
let derive_projections (update_step : update_step) : Idx.projections =
699716
finish_inference ();
700-
fresh_proj_ids update_step;
717+
let resolved_padding = fresh_proj_ids update_step in
701718
let _debug_update_step : update_step = update_step in
702719
let (proj_axis_env, ineqs) : proj_axis_env * Row.constraint_ list =
703720
get_inequalities update_step
@@ -717,7 +734,7 @@ let derive_projections (update_step : update_step) : Idx.projections =
717734
(* Important: ineqs must not be substituted / solved before getting proj_equations, because
718735
get_inequalities provides indexing information that is lost after substitution. *)
719736
let proj_eqs : Row.proj_equation list = Row.get_proj_equations ineqs proj_axis_env local_env in
720-
let proj_env : Row.proj_env = Row.solve_proj_equations proj_eqs in
737+
let proj_env : Row.proj_env = Row.solve_proj_equations ~resolved_padding proj_eqs in
721738
let dims_of (sh : t) = sh.batch.dims @ sh.output.dims @ sh.input.dims in
722739
let lhs = update_step.shape in
723740
let rhs =
@@ -809,7 +826,18 @@ let make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_ax
809826
| None, None -> make_unknown `Output
810827
| Some _, Some _ -> invalid_arg "Shape.make: do not provide both output_dims, output_axes"
811828
in
812-
let result = { input; output; batch; id; debug_name } in
829+
let result =
830+
{
831+
input;
832+
output;
833+
batch;
834+
id;
835+
debug_name;
836+
batch_padding = None;
837+
input_padding = None;
838+
output_padding = None;
839+
}
840+
in
813841
(match deduced with
814842
| Not_constrained -> ()
815843
| Input_equals_output -> (
@@ -841,7 +869,18 @@ let shape_spec_to_dims_bio labels =
841869

842870
let of_spec ?(deduced = Not_constrained) ~debug_name ~id spec =
843871
let batch, input, output = shape_spec_to_dims_bio ~sh_id:id @@ axis_labels_of_spec spec in
844-
let result = { input; output; batch; id; debug_name } in
872+
let result =
873+
{
874+
input;
875+
output;
876+
batch;
877+
id;
878+
debug_name;
879+
batch_padding = None;
880+
input_padding = None;
881+
output_padding = None;
882+
}
883+
in
845884
(match deduced with
846885
| Not_constrained -> ()
847886
| Input_equals_output -> (

lib/shape.mli

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
open Base
4949

50-
type padding = Ir.Ndarray.axis_padding array option
50+
type padding = Row.axis_padding array option [@@deriving sexp, equal]
5151

5252
type t = {
5353
mutable batch : Row.t;
@@ -113,10 +113,7 @@ val make :
113113
that these are dimensions labels and not axis labels: they need not be unique for a row, are
114114
inferred when provided, and must match whenever the axis sizes must match. *)
115115

116-
val to_string_hum :
117-
?style:Row.print_style ->
118-
t ->
119-
string
116+
val to_string_hum : ?style:Row.print_style -> t -> string
120117

121118
val unsafe_reinitialize : unit -> unit
122119
(** Bring global state to its initialization values. This invalidates any unfinished inference. *)

lib/shape_inference.md

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,15 @@ A tensor shape in OCANNL is composed of three rows of axes: batch, input and out
1313
A row is a sequence of axes of a single kind: batch, input, or output. The shape type incorporates information relevant to inference, in particular shape variables: both for individual axes (`dim` variables), and for extending a row with more axes (`row` variables). Currently, all rows are (independently) broadcastable: each can be broadcasted to a larger number of axes. Moreover, in OCANNL broadcasting can happen "in the middle" of a row: both the given trailing axes and the given leading axes stay fixed, and new axes are inserted between them.
1414

1515
```ocaml
16-
type solved_dim = {
17-
d : int;
18-
label : string option;
19-
proj_id : proj_id option;
20-
}
16+
type solved_dim = { d : int; label : string option; proj_id : proj_id option }
2117
(** A single axis in a shape. *)
2218
2319
type dim =
2420
| Var of dim_var
2521
| Dim of solved_dim
26-
| Conv_input of {
27-
stride : int;
28-
output : dim;
29-
solved_kernel : solved_dim option;
30-
unsolved_kernel : (int * dim_var) list;
31-
}
32-
(** Represents convolution-style input dimensions where the output dimension
33-
relates to the input dimension through: input = stride * output + kernel_terms.
34-
This is a generalization of convolutions that supports affine indexing patterns.
35-
The offset is implicit and depends on the global setting use_padding. *)
22+
| Conv_input of { stride : int; output : dim; dilation : int; kernel : dim }
23+
(** The offset is implicit, automatically derived. If [!use_padding] is [true], the offset is
24+
the left part of the dimensionality-preserving symmetric padding; otherwise it is 0. *)
3625
3726
type bcast =
3827
| Row_var of row_var (** The row can be inferred to have more axes. *)
@@ -214,7 +203,7 @@ The projection inference functions.
214203

215204
### Convolutions
216205

217-
There is an important and intentional disconnect between `dims` in the `arrayjit` part of the project: tensor nodes, `Ndarray` buffers, code generation: they include padding in the dimension sizes -- and on the other hand shape types, shape inference and tensors exclude padding from the dimension sizes. There is a tension: once the delayed computations of padding, projections and dims (dimension sizes) are forced for a particular node, the padding can no longer be updated (the underlying `Ndarray` buffer might already be created). Since during inference we update the padding incrementally without variables standing in for insufficient information, this unfortunately causes observability of the during-inference and post-inference distinction for the padding of a tensor node.
206+
There is an important and intentional difference in how `dims` treat padding. In the `arrayjit` part of the project — tensor nodes, `Ndarray` buffers, code generation — dimension sizes include padding; whereas shape types, shape inference and tensors exclude padding from dimension sizes. There is a tension: once the delayed computations of padding, projections and dims (dimension sizes) are forced for a particular node, the padding can no longer be updated (the underlying `Ndarray` buffer might already be created). Since during inference we update the padding incrementally, without variables standing in for insufficient information, this unfortunately makes the during-inference vs. post-inference distinction observable for the padding of a tensor node.
218207

219208
## Deriving the constraints
220209

test/test_print_style.expected

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,24 @@ Testing print_style functionality:
77
=== Testing solved_dim_to_string ===
88
Full attributes (d=28, padding=2, label=height, proj_id):
99
Only_labels: height
10-
Axis_size: height=28+2
11-
Axis_number_and_size: height=28+2
12-
Projection_and_size: height=28+2
10+
Axis_size: height=28
11+
Axis_number_and_size: height=28
12+
Projection_and_size: height=28
1313

1414
Minimal attributes (d=64, no padding, no label, no proj_id):
1515
Only_labels: _
1616
Axis_size: 64
1717
Projection_and_size: 64
1818

19-
With padding only (d=32, padding=3, label=width, no proj_id):
20-
Axis_size: width=32+3
21-
Projection_and_size: width=32+3
22-
2319
With projection (d=32, label=width, proj_id):
2420
Axis_size: width=32
25-
Projection_and_size: width=32[p1]
21+
Projection_and_size: width=32p1
2622

2723
=== Testing dim_to_string ===
2824
Solved dimensions:
2925
Only_labels (full): height
30-
Axis_size (full): height=28+2
31-
Projection_and_size (full): height=28+2
26+
Axis_size (full): height=28
27+
Projection_and_size (full): height=28
3228
Only_labels (minimal): _
3329
Axis_size (minimal): 64
3430

0 commit comments

Comments
 (0)