Skip to content

Commit bdc8695

Browse files
lukstafi and claude committed
Refactor Conv_spec to Affine_spec and clean up Shape module interface
- Rename Conv_spec to Affine_spec in einsum_types.ml with structure matching Row.Affine: stride, over_label, conv (optional), stride_offset
- Add conv_spec type for the optional convolution component
- Update parser.mly with expanded affine_expr grammar supporting stride*over+offset+dilation*kernel patterns
- Remove duplicate type exports (conv_spec, axis_spec, axis_map, parsed_axis_labels) from shape.mli; use Einsum_parser directly
- Remove axis_labels_of_spec wrapper from shape.ml
- Add parse_n5_layout helper to Shape for N5_layout parsing
- Update tensor.ml to use Shape.parse_n5_layout
- Update test_conv_syntax.ml to use Einsum_parser directly

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 20967eb commit bdc8695

File tree

8 files changed

+138
-136
lines changed

8 files changed

+138
-136
lines changed

tensor/einsum_types.ml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,21 @@
44

55
open Base
66

7+
(** Convolution component for affine axis specifications. *)
8+
type conv_spec = { dilation : int; kernel_label : string } [@@deriving compare, sexp]
9+
710
(** Specification for individual axes in the einsum notation. *)
811
type axis_spec =
912
| Label of string (** A variable axis label. *)
1013
| Fixed_index of int (** A fixed index, used for projection. *)
11-
| Conv_spec of { stride : int; output_label : string; dilation : int; kernel_label : string }
12-
(** Convolution-style axis specification: stride*output + dilation*kernel. *)
14+
| Affine_spec of {
15+
stride : int; (** Coefficient for the over dimension. *)
16+
over_label : string; (** The output/iteration dimension label. *)
17+
conv : conv_spec option; (** Optional convolution: dilation*kernel. *)
18+
stride_offset : int; (** Constant offset added after stride*over. *)
19+
}
20+
(** Affine axis specification: stride*over + stride_offset [+ dilation*kernel].
21+
Corresponds to [Row.Affine] in shape inference. *)
1322
[@@deriving compare, sexp]
1423

1524
(** An index pointing to any of a shape's axes, including the kind of the axis ([Batch, Input,

tensor/parser.mly

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -96,31 +96,53 @@ axis_spec:
9696
{ Label id }
9797
| UNDERSCORE
9898
{ Label "_" }
99-
| conv_expr
99+
| affine_expr
100100
{ $1 }
101101

102-
/* Convolution expression: [stride*]output[+[dilation*]kernel] */
103-
conv_expr:
104-
| stride = INT; STAR; output = IDENT; PLUS; dilation = INT; STAR; kernel = IDENT
105-
{ Conv_spec { stride; output_label = output; dilation; kernel_label = kernel } }
106-
| stride = INT; STAR; output = IDENT; PLUS; dilation = INT; STAR; UNDERSCORE
107-
{ Conv_spec { stride; output_label = output; dilation; kernel_label = "_offset_only" } }
108-
| stride = INT; STAR; output = IDENT; PLUS; offset = INT
109-
{ Conv_spec { stride; output_label = output; dilation = offset; kernel_label = "_offset_only" } }
110-
| output = IDENT; PLUS; dilation = INT; STAR; kernel = IDENT
111-
{ Conv_spec { stride = 1; output_label = output; dilation; kernel_label = kernel } }
112-
| output = IDENT; PLUS; dilation = INT; STAR; UNDERSCORE
113-
{ Conv_spec { stride = 1; output_label = output; dilation; kernel_label = "_offset_only" } }
114-
| output = IDENT; PLUS; offset = INT
115-
{ Conv_spec { stride = 1; output_label = output; dilation = offset; kernel_label = "_offset_only" } }
116-
| output = IDENT; PLUS; kernel = IDENT
117-
{ Conv_spec { stride = 1; output_label = output; dilation = 1; kernel_label = kernel } }
118-
| stride = INT; STAR; output = IDENT; PLUS; kernel = IDENT
119-
{ Conv_spec { stride; output_label = output; dilation = 1; kernel_label = kernel } }
120-
| stride = INT; STAR; output = IDENT; PLUS; UNDERSCORE
121-
{ Conv_spec { stride; output_label = output; dilation = 1; kernel_label = "_offset_only" } }
122-
| stride = INT; STAR; output = IDENT
123-
{ Conv_spec { stride; output_label = output; dilation = 0; kernel_label = "_stride_only" } }
102+
/* Affine expression: [stride*]over[+offset][+[dilation*]kernel]
103+
Supports various combinations of stride, offset, and convolution components.
104+
Note: the underscore (_) placeholder accepted by the former conv_expr rule has no production in affine_expr below. */
105+
affine_expr:
106+
/* Full form: stride*over+offset+dilation*kernel */
107+
| stride = INT; STAR; over = IDENT; PLUS; offset = INT; PLUS; dilation = INT; STAR; kernel = IDENT
108+
{ Affine_spec { stride; over_label = over; stride_offset = offset;
109+
conv = Some { dilation; kernel_label = kernel } } }
110+
/* stride*over+offset+kernel (dilation=1) */
111+
| stride = INT; STAR; over = IDENT; PLUS; offset = INT; PLUS; kernel = IDENT
112+
{ Affine_spec { stride; over_label = over; stride_offset = offset;
113+
conv = Some { dilation = 1; kernel_label = kernel } } }
114+
/* over+offset+dilation*kernel (stride=1) */
115+
| over = IDENT; PLUS; offset = INT; PLUS; dilation = INT; STAR; kernel = IDENT
116+
{ Affine_spec { stride = 1; over_label = over; stride_offset = offset;
117+
conv = Some { dilation; kernel_label = kernel } } }
118+
/* over+offset+kernel (stride=1, dilation=1) */
119+
| over = IDENT; PLUS; offset = INT; PLUS; kernel = IDENT
120+
{ Affine_spec { stride = 1; over_label = over; stride_offset = offset;
121+
conv = Some { dilation = 1; kernel_label = kernel } } }
122+
/* stride*over+dilation*kernel (no offset) */
123+
| stride = INT; STAR; over = IDENT; PLUS; dilation = INT; STAR; kernel = IDENT
124+
{ Affine_spec { stride; over_label = over; stride_offset = 0;
125+
conv = Some { dilation; kernel_label = kernel } } }
126+
/* stride*over+kernel (no offset, dilation=1) */
127+
| stride = INT; STAR; over = IDENT; PLUS; kernel = IDENT
128+
{ Affine_spec { stride; over_label = over; stride_offset = 0;
129+
conv = Some { dilation = 1; kernel_label = kernel } } }
130+
/* over+dilation*kernel (stride=1, no offset) */
131+
| over = IDENT; PLUS; dilation = INT; STAR; kernel = IDENT
132+
{ Affine_spec { stride = 1; over_label = over; stride_offset = 0;
133+
conv = Some { dilation; kernel_label = kernel } } }
134+
/* over+kernel (stride=1, dilation=1, no offset) */
135+
| over = IDENT; PLUS; kernel = IDENT
136+
{ Affine_spec { stride = 1; over_label = over; stride_offset = 0;
137+
conv = Some { dilation = 1; kernel_label = kernel } } }
138+
/* stride*over+offset (no conv) */
139+
| stride = INT; STAR; over = IDENT; PLUS; offset = INT
140+
{ Affine_spec { stride; over_label = over; stride_offset = offset; conv = None } }
141+
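/* over+offset (stride=1, no conv): affine_expr has no IDENT PLUS INT production, so an integer offset is only accepted together with an explicit stride (rule above) or a trailing kernel term. */
142+
/* When the term after PLUS is an IDENT, the input is parsed by the over+kernel rule instead. */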
143+
/* stride*over (no offset, no conv) */
144+
| stride = INT; STAR; over = IDENT
145+
{ Affine_spec { stride; over_label = over; stride_offset = 0; conv = None } }
124146

125147
/* List of axis specifications - can be empty, allows trailing comma */
126148
axes_spec:

tensor/shape.ml

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,6 @@ type terminal_type = Data of Ir.Assignments.init_data | Fetch of Ir.Assignments.
6767
type ternary_type = Pointwise_tern | Compose_accumulate | Defined_by_cd_logic
6868
[@@deriving sexp, equal]
6969

70-
let axis_labels_of_spec spec =
71-
try Einsum_parser.axis_labels_of_spec spec
72-
with Einsum_parser.Parse_error msg ->
73-
raise
74-
@@ Utils.User_error ("Shape.axis_labels_of_spec: while parsing: " ^ spec ^ " error: " ^ msg)
75-
7670
let einsum_of_spec spec =
7771
try Einsum_parser.einsum_of_spec spec
7872
with Einsum_parser.Parse_error msg ->
@@ -221,28 +215,22 @@ let einsum_slot_spec_to_dims_bio ~original_spec ~sh_id ~row_var_env ~dim_var_env
221215
}
222216
:: !extras;
223217
d
224-
| Conv_spec { stride; output_label; dilation; kernel_label } ->
218+
| Affine_spec { stride; over_label; conv; stride_offset } ->
225219
let over_dim =
226220
Row.Var
227-
(Hashtbl.find_or_add dim_var_env output_label ~default:(fun () ->
228-
Row.get_var ~name:output_label ()))
221+
(Hashtbl.find_or_add dim_var_env over_label ~default:(fun () ->
222+
Row.get_var ~name:over_label ()))
229223
in
230224
let conv =
231-
if String.equal kernel_label "_stride_only" then
232-
(* For strided iteration (dilation=0), no convolution *)
233-
None
234-
else if String.equal kernel_label "_offset_only" then
235-
(* For offset-only iteration (dilation=offset), no convolution *)
236-
None
237-
else
238-
let kernel =
239-
Row.Var
240-
(Hashtbl.find_or_add dim_var_env kernel_label ~default:(fun () ->
241-
Row.get_var ~name:kernel_label ()))
242-
in
243-
Some { Row.dilation; kernel; use_padding = !Row.use_padding }
225+
Option.map conv ~f:(fun { dilation; kernel_label } ->
226+
let kernel =
227+
Row.Var
228+
(Hashtbl.find_or_add dim_var_env kernel_label ~default:(fun () ->
229+
Row.get_var ~name:kernel_label ()))
230+
in
231+
{ Row.dilation; kernel; use_padding = !Row.use_padding })
244232
in
245-
Row.Affine { stride; over = over_dim; conv; stride_offset = 0 }
233+
Row.Affine { stride; over = over_dim; conv; stride_offset }
246234
in
247235
let result = axes_spec_to_dims_bio ~sh_id ~row_var_env ~dim_var_env ~f labels in
248236
(!extras, !proj_env_update, result)
@@ -1790,28 +1778,22 @@ let shape_spec_to_dims_bio labels =
17901778
| Label name ->
17911779
Var (Hashtbl.find_or_add dim_var_env name ~default:(fun () -> Row.get_var ~name ()))
17921780
| Fixed_index d -> Row.get_dim ~d ()
1793-
| Conv_spec { stride; output_label; dilation; kernel_label } ->
1781+
| Affine_spec { stride; over_label; conv; stride_offset } ->
17941782
let over_dim =
17951783
Row.Var
1796-
(Hashtbl.find_or_add dim_var_env output_label ~default:(fun () ->
1797-
Row.get_var ~name:output_label ()))
1784+
(Hashtbl.find_or_add dim_var_env over_label ~default:(fun () ->
1785+
Row.get_var ~name:over_label ()))
17981786
in
17991787
let conv =
1800-
if String.equal kernel_label "_stride_only" then
1801-
(* For strided iteration (dilation=0), no convolution *)
1802-
None
1803-
else if String.equal kernel_label "_offset_only" then
1804-
(* For offset-only iteration (dilation=offset), no convolution *)
1805-
None
1806-
else
1807-
let kernel =
1808-
Row.Var
1809-
(Hashtbl.find_or_add dim_var_env kernel_label ~default:(fun () ->
1810-
Row.get_var ~name:kernel_label ()))
1811-
in
1812-
Some { Row.dilation; kernel; use_padding = !Row.use_padding }
1788+
Option.map conv ~f:(fun { dilation; kernel_label } ->
1789+
let kernel =
1790+
Row.Var
1791+
(Hashtbl.find_or_add dim_var_env kernel_label ~default:(fun () ->
1792+
Row.get_var ~name:kernel_label ()))
1793+
in
1794+
{ Row.dilation; kernel; use_padding = !Row.use_padding })
18131795
in
1814-
Row.Affine { stride; over = over_dim; conv; stride_offset = 0 }
1796+
Row.Affine { stride; over = over_dim; conv; stride_offset }
18151797
in
18161798
let row_var_env = Hashtbl.create (module String) in
18171799
axes_spec_to_dims_bio ~row_var_env ~dim_var_env ~f labels
@@ -1940,3 +1922,12 @@ let%debug5_sexp default_display_indices (sh : t) : int array =
19401922
in
19411923
let axes = loop 1 axes in
19421924
axis_map_to_dims_index axes
1925+
1926+
let parse_n5_layout priorities =
1927+
let f : Einsum_parser.axis_spec -> int = function
1928+
| Fixed_index i -> i
1929+
| Label _ -> invalid_arg "parse_n5_layout requires integer-only labels"
1930+
| Affine_spec _ -> invalid_arg "parse_n5_layout does not support affine expressions"
1931+
in
1932+
let p_labels = Einsum_parser.(axis_labels @@ axis_labels_of_spec priorities) in
1933+
axis_map_to_dims_index p_labels |> Array.map ~f

tensor/shape.mli

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,6 @@ open Base
5858

5959
type padding = Row.axis_padding array option [@@deriving sexp, equal]
6060

61-
(** Specification for individual axes in the einsum notation. *)
62-
type axis_spec =
63-
| Label of string (** A variable axis label. *)
64-
| Fixed_index of int (** A fixed index, used for projection. *)
65-
| Conv_spec of { stride : int; output_label : string; dilation : int; kernel_label : string }
66-
(** Convolution-style axis specification: stride*output + dilation*kernel. *)
67-
[@@deriving compare, sexp]
68-
6961
type t = {
7062
mutable batch : Row.t;
7163
mutable input : Row.t;
@@ -246,9 +238,6 @@ val default_display_indices : t -> int array
246238
val to_labels : t -> string array
247239
(** Uses the matrix convention of putting the input axes last. *)
248240

249-
type 'a axis_map
250-
type parsed_axis_labels [@@deriving sexp]
251-
252-
val axis_labels : parsed_axis_labels -> axis_spec axis_map
253-
val axis_labels_of_spec : string -> parsed_axis_labels
254-
val axis_map_to_dims_index : ?default:'a -> 'a axis_map -> 'a array
241+
val parse_n5_layout : string -> int array
242+
(** Parse an N5_layout priority string (e.g., "0,1,2") into display indices.
243+
Only integer labels (Fixed_index) are supported; a Label or Affine_spec entry raises [Invalid_argument]. *)

tensor/tensor.ml

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -880,14 +880,7 @@ let%debug5_sexp to_doc ~force ~with_grad ~with_code ?(with_low_level = false)
880880
let indices =
881881
match style with
882882
| `Default -> Shape.default_display_indices sh
883-
| `N5_layout priorities ->
884-
let f : Shape.axis_spec -> int = function
885-
| Shape.Fixed_index i -> i
886-
| Shape.Label _ -> invalid_arg "`N5_layout requires integer-only labels"
887-
| Shape.Conv_spec _ -> invalid_arg "`N5_layout does not support conv expressions"
888-
in
889-
let p_labels = Shape.(axis_labels @@ axis_labels_of_spec priorities) in
890-
(Shape.axis_map_to_dims_index p_labels : Shape.axis_spec array) |> Array.map ~f
883+
| `N5_layout priorities -> Shape.parse_n5_layout priorities
891884
| `Label_layout label_idcs ->
892885
let inv_labels =
893886
Array.mapi labels ~f:(fun i l -> (l, i)) |> Array.to_list |> Map.of_alist (module String)

test/einsum/dune

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@
3030
ocannl_config
3131
(env_var OCANNL_BACKEND))
3232
(modules test_conv_syntax)
33-
(libraries ocannl)
34-
(preprocess
35-
(pps ppx_here ppx_ocannl)))
33+
(libraries base einsum_parser stdio))
3634

3735
(test
3836
(name test_print_style)

test/einsum/test_conv_syntax.expected

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,30 @@
1-
Retrieving commandline, environment, or config file variable ocannl_log_level
2-
Found 0, in the config file
31
Testing conv syntax parsing...
42
Test 1: Parsed '2*o+3*k' successfully
53
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
64
(given_input 0) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
75
(given_beg_output 0)
86
(labels
97
((((in_axes Output) (pos 1) (from_end true))
10-
(Conv_spec (stride 2) (output_label o) (dilation 3) (kernel_label k))))))
8+
(Affine_spec (stride 2) (over_label o)
9+
(conv (((dilation 3) (kernel_label k)))) (stride_offset 0))))))
1110

1211
Test 2: Parsed 'o+k' successfully
1312
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
1413
(given_input 0) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
1514
(given_beg_output 0)
1615
(labels
1716
((((in_axes Output) (pos 1) (from_end true))
18-
(Conv_spec (stride 1) (output_label o) (dilation 1) (kernel_label k))))))
17+
(Affine_spec (stride 1) (over_label o)
18+
(conv (((dilation 1) (kernel_label k)))) (stride_offset 0))))))
1919

2020
Test 3: Parsed 'a, 2*b+c' successfully
2121
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
2222
(given_input 0) (given_output 2) (given_beg_batch 0) (given_beg_input 0)
2323
(given_beg_output 0)
2424
(labels
2525
((((in_axes Output) (pos 1) (from_end true))
26-
(Conv_spec (stride 2) (output_label b) (dilation 1) (kernel_label c)))
26+
(Affine_spec (stride 2) (over_label b)
27+
(conv (((dilation 1) (kernel_label c)))) (stride_offset 0)))
2728
(((in_axes Output) (pos 2) (from_end true)) (Label a)))))
2829

2930
Test 4: Parsed 'i, o+k, j' successfully (multichar mode)
@@ -33,7 +34,8 @@ Test 4: Parsed 'i, o+k, j' successfully (multichar mode)
3334
(labels
3435
((((in_axes Output) (pos 1) (from_end true)) (Label j))
3536
(((in_axes Output) (pos 2) (from_end true))
36-
(Conv_spec (stride 1) (output_label o) (dilation 1) (kernel_label k)))
37+
(Affine_spec (stride 1) (over_label o)
38+
(conv (((dilation 1) (kernel_label k)))) (stride_offset 0)))
3739
(((in_axes Output) (pos 3) (from_end true)) (Label i)))))
3840

3941
Test 5: Parsed 'a+bc' successfully (multichar mode)
@@ -42,15 +44,17 @@ Test 5: Parsed 'a+bc' successfully (multichar mode)
4244
(given_beg_output 0)
4345
(labels
4446
((((in_axes Output) (pos 1) (from_end true))
45-
(Conv_spec (stride 1) (output_label a) (dilation 1) (kernel_label bc))))))
47+
(Affine_spec (stride 1) (over_label a)
48+
(conv (((dilation 1) (kernel_label bc)))) (stride_offset 0))))))
4649

4750
Test 6: Parsed 'i, j -> 2*i+j' successfully
4851
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
4952
(given_input 2) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
5053
(given_beg_output 0)
5154
(labels
5255
((((in_axes Output) (pos 1) (from_end true))
53-
(Conv_spec (stride 2) (output_label i) (dilation 1) (kernel_label j)))
56+
(Affine_spec (stride 2) (over_label i)
57+
(conv (((dilation 1) (kernel_label j)))) (stride_offset 0)))
5458
(((in_axes Input) (pos 1) (from_end true)) (Label j))
5559
(((in_axes Input) (pos 2) (from_end true)) (Label i)))))
5660

@@ -61,8 +65,8 @@ Test 7: Parsed 'batch|input->3*output+1*kernel,' successfully
6165
(labels
6266
((((in_axes Batch) (pos 1) (from_end true)) (Label batch))
6367
(((in_axes Output) (pos 1) (from_end true))
64-
(Conv_spec (stride 3) (output_label output) (dilation 1)
65-
(kernel_label kernel)))
68+
(Affine_spec (stride 3) (over_label output)
69+
(conv (((dilation 1) (kernel_label kernel)))) (stride_offset 0)))
6670
(((in_axes Input) (pos 1) (from_end true)) (Label input)))))
6771

6872
All conv syntax parsing tests passed!
@@ -74,26 +78,23 @@ Test 1: Parsed strided iteration '2*output' successfully
7478
(given_beg_output 0)
7579
(labels
7680
((((in_axes Output) (pos 1) (from_end true))
77-
(Conv_spec (stride 2) (output_label output) (dilation 0)
78-
(kernel_label _stride_only))))))
81+
(Affine_spec (stride 2) (over_label output) (conv ()) (stride_offset 0))))))
7982

8083
Test 2: Parsed strided iteration '3*i' successfully
8184
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
8285
(given_input 0) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
8386
(given_beg_output 0)
8487
(labels
8588
((((in_axes Output) (pos 1) (from_end true))
86-
(Conv_spec (stride 3) (output_label i) (dilation 0)
87-
(kernel_label _stride_only))))))
89+
(Affine_spec (stride 3) (over_label i) (conv ()) (stride_offset 0))))))
8890

8991
Test 3: Parsed einsum with strided iteration 'input -> 2*output' successfully
9092
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
9193
(given_input 1) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
9294
(given_beg_output 0)
9395
(labels
9496
((((in_axes Output) (pos 1) (from_end true))
95-
(Conv_spec (stride 2) (output_label output) (dilation 0)
96-
(kernel_label _stride_only)))
97+
(Affine_spec (stride 2) (over_label output) (conv ()) (stride_offset 0)))
9798
(((in_axes Input) (pos 1) (from_end true)) (Label input)))))
9899

99100
Test 4: Parsed mixed labels with strided iteration 'regular, 3*strided' successfully
@@ -102,8 +103,7 @@ Test 4: Parsed mixed labels with strided iteration 'regular, 3*strided' successf
102103
(given_beg_output 0)
103104
(labels
104105
((((in_axes Output) (pos 1) (from_end true))
105-
(Conv_spec (stride 3) (output_label strided) (dilation 0)
106-
(kernel_label _stride_only)))
106+
(Affine_spec (stride 3) (over_label strided) (conv ()) (stride_offset 0)))
107107
(((in_axes Output) (pos 2) (from_end true)) (Label regular)))))
108108

109109

@@ -128,7 +128,8 @@ Conv spec 'a+b': ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_bat
128128
(given_beg_output 0)
129129
(labels
130130
((((in_axes Output) (pos 1) (from_end true))
131-
(Conv_spec (stride 1) (output_label a) (dilation 1) (kernel_label b))))))
131+
(Affine_spec (stride 1) (over_label a)
132+
(conv (((dilation 1) (kernel_label b)))) (stride_offset 0))))))
132133
Note: Conv expressions with + or * now always use multichar mode
133134

134135
All conv syntax tests completed!

0 commit comments

Comments
 (0)