Skip to content

Commit 54e0262

Browse files
committed
Fixed conv spec parsing in einsum and cleaned up einsum parsing overall
1 parent b02a785 commit 54e0262

File tree

3 files changed

+215
-90
lines changed

3 files changed

+215
-90
lines changed

lib/shape.ml

Lines changed: 62 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -98,93 +98,62 @@ type transpose_type =
9898

9999
type ternary_type = Pointwise_tern | Compose_accumulate [@@deriving sexp, equal]
100100

101-
let identifier_multichar = Angstrom.take_while1 Char.is_alphanum
101+
(** Parser for an axis identifier, returned as a string.
    With [~multichar:true]: one or more characters satisfying [Char.is_alpha], followed by zero or
    more characters satisfying [Char.is_alphanum], concatenated -- so an identifier never starts
    with a digit. With [~multichar:false]: exactly one character satisfying [Char.is_alpha]. *)
let identifier ~multichar =
102+
let open Angstrom in
103+
if multichar then lift2 ( ^ ) (take_while1 Char.is_alpha) (take_while Char.is_alphanum)
104+
else Angstrom.satisfy Char.is_alpha >>| Char.to_string
105+
106+
(** Parser for an unsigned decimal integer literal: one or more digits, converted to [int].
    If [Int.of_string] rejects the digit run (e.g. it does not fit in an [int]), the parser fails
    with "integer out of range" rather than letting the conversion exception escape from the
    parser, so callers see an ordinary parse failure. *)
let integer =
  let open Angstrom in
  take_while1 Char.is_digit >>= fun digits ->
  match Option.try_with (fun () -> Int.of_string digits) with
  | Some n -> return n
  | None -> fail "integer out of range"
107+
108+
(** Parser for an identifier with an optional integer coefficient, returning [(coeff, id)].
    Accepts either [<integer>*<identifier>], yielding the parsed coefficient, or a bare
    [<identifier>], in which case the coefficient is 1. No separators are skipped around ['*'].
    [~multichar] is passed through to [identifier]. *)
let scaled_identifier ~multichar =
109+
let open Angstrom in
110+
integer <* char '*'
111+
>>= (fun coeff -> identifier ~multichar >>| fun id -> (coeff, id))
112+
<|> (identifier ~multichar >>| fun id -> (1, id))
113+
114+
(** Parser for a convolution axis expression of the form [<scaled>+<scaled>], where each side is
    parsed by [scaled_identifier]. The left side supplies [stride] (coefficient) and
    [output_label] (identifier); the right side supplies [dilation] and [kernel_label]. The
    result is a [Conv_spec]. The ['+'] is mandatory and no separators are skipped around it. *)
let conv_term ~multichar =
115+
let open Angstrom in
116+
let* stride, output_label = scaled_identifier ~multichar in
117+
char '+' *> scaled_identifier ~multichar >>| fun (dilation, kernel_label) ->
118+
Conv_spec { stride; output_label; dilation; kernel_label }
102119

103-
let opt_separators : _ Angstrom.t =
104-
Angstrom.take_while (fun c -> Char.is_whitespace c || Char.equal c ',')
120+
(** Parser that consumes a possibly empty run of separator characters (whitespace or commas) and
    returns the consumed text. *)
let opt_separators =
  let is_separator = function ',' -> true | c -> Char.is_whitespace c in
  Angstrom.take_while is_separator
105121

106122
(** Parser for a mandatory comma separator: consumes the run of separators accepted by
    [opt_separators] and succeeds with [()] only if that run contains at least one comma;
    otherwise fails with "comma expected". *)
let separators_with_comma =
107123
let open Angstrom in
108124
let* sep = opt_separators in
109125
if String.contains sep ',' then return () else fail "comma expected"
110126

111127
(** Parse a single axis specification: a conv expression, a fixed integer index, or a label, tried in that order. *)
112-
let parse_single_axis_spec ~multichar : string Angstrom.t =
128+
let parse_single_axis_spec ~multichar =
113129
let open Angstrom in
114-
if multichar then
115-
(* In multichar mode, try to parse conv expressions first, then fall back to regular identifiers *)
116-
let integer = take_while1 Char.is_digit in
117-
let identifier = identifier_multichar in
118-
let conv_term =
119-
(integer <* char '*' >>= fun coeff -> identifier >>| fun id -> coeff ^ "*" ^ id)
120-
<|> identifier
121-
in
122-
let conv_expr =
123-
conv_term >>= fun first_term ->
124-
char '+' *> conv_term >>| fun second_term ->
125-
first_term ^ "+" ^ second_term
126-
in
127-
choice [ conv_expr; integer; identifier ]
128-
else
129-
(* In single-char mode, conv expressions are not supported since they would be ambiguous *)
130-
take 1
131-
132-
(** Convert a string to axis_spec, handling conv expressions. *)
133-
let string_to_axis_spec (s : string) : axis_spec =
134-
if String.contains s '*' || String.contains s '+' then
135-
(* Parse as conv expression *)
136-
let parse_conv_expr spec =
137-
let open Angstrom in
138-
let integer = take_while1 Char.is_digit >>| Int.of_string in
139-
let identifier = take_while1 Char.is_alphanum in
140-
let term =
141-
(integer <* char '*' >>= fun coeff -> identifier >>| fun id -> (coeff, id))
142-
<|> (identifier >>| fun id -> (1, id))
143-
in
144-
let conv_expr =
145-
term >>= fun (stride, output_label) ->
146-
char '+' *> term >>| fun (dilation, kernel_label) ->
147-
Conv_spec { stride; output_label; dilation; kernel_label }
148-
in
149-
match parse_string ~consume:Consume.All conv_expr spec with
150-
| Ok result -> result
151-
| Error _ ->
152-
(* If conv parsing fails, treat as regular label *)
153-
Label s
154-
in
155-
parse_conv_expr s
156-
else
157-
(* Try to parse as integer first, then as label *)
158-
try Fixed_index (Int.of_string s)
159-
with _ -> Label s
130+
choice
131+
[
132+
conv_term ~multichar <?> "conv_term";
133+
integer >>| (fun i -> Fixed_index i) <?> "fixed_index";
134+
identifier ~multichar >>| (fun s -> Label s) <?> "label";
135+
]
160136

161137
let axes_spec ~from_end ~multichar : _ Angstrom.t =
162138
let open Angstrom in
163-
let single_char (pos, acc) c =
164-
Option.some_if (Char.is_alphanum c) (pos + 1, (pos, Char.to_string c) :: acc)
165-
in
166139
let result =
167140
let p n i = if from_end then n - i else i + 1 in
168-
if multichar then
169-
lift (fun l ->
170-
let n = List.length l in
171-
List.mapi l ~f:(fun i v -> (p n i, v)))
172-
@@ sep_by1 separators_with_comma (parse_single_axis_spec ~multichar)
173-
else
174-
lift (fun (_, acc) ->
175-
let n = List.length acc in
176-
List.rev_map acc ~f:(fun (i, v) -> (p n i, v)))
177-
@@ scan_state (0, []) single_char
141+
lift (fun l ->
142+
let n = List.length l in
143+
List.mapi l ~f:(fun i v -> (p n i, v)))
144+
@@ sep_by1
145+
(if multichar then separators_with_comma else opt_separators >>| ignore)
146+
(parse_single_axis_spec ~multichar)
178147
in
179-
opt_separators *> result <* opt_separators
148+
opt_separators *> result <* opt_separators <?> "axes_spec"
180149

181150
let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t =
182151
let open Angstrom in
183152
let combine ~key:_ _v1 _v2 = assert false in
184153
let axes_spec ~from_end =
185154
axes_spec ~from_end ~multichar <?> if from_end then "axes_spec" else "axes_spec_beg"
186155
in
187-
let ellipsis_spec = string "..." <|> (string ".." *> identifier_multichar <* string "..") in
156+
let ellipsis_spec = string "..." <|> (string ".." *> identifier ~multichar <* string "..") in
188157
let ellipsis_spec = ellipsis_spec <?> "ellipsis_spec" in
189158
let for_row ~kind in_axes beg_axes row_var_spec end_axes =
190159
let f from_end (pos, label) = (AxisKey.{ in_axes; pos; from_end }, label) in
@@ -198,12 +167,12 @@ let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t =
198167
let parse_row ~kind in_axes =
199168
let row = lift3 (for_row ~kind in_axes) in
200169
opt_separators
201-
*> (row (return []) (lift Option.some ellipsis_spec) (axes_spec ~from_end:true)
202-
<|> row (axes_spec ~from_end:false) (lift Option.some ellipsis_spec)
203-
(axes_spec ~from_end:true)
170+
*> (row (axes_spec ~from_end:false) (lift Option.some ellipsis_spec) (axes_spec ~from_end:true)
171+
<|> row (return []) (lift Option.some ellipsis_spec) (axes_spec ~from_end:true)
172+
<|> row (axes_spec ~from_end:false) (lift Option.some ellipsis_spec) (return [])
204173
<|> row (return []) (return None) (axes_spec ~from_end:true)
205174
<|> row (return []) (lift Option.some ellipsis_spec) (return []))
206-
<* opt_separators
175+
<* opt_separators <?> "row_spec"
207176
in
208177
let default = Option.value ~default:(None, 0, 0, Map.empty (module AxisKey)) in
209178
let shape = lift3 (fun batch input output -> (default batch, default input, output)) in
@@ -217,10 +186,10 @@ let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t =
217186
<|> shape (p_b <* char '|') (p_i <* string "->") p_o
218187
<|> shape (p_b <* char '|') (return None) p_o
219188
<|> shape (return None) (return None) p_o
189+
<?> "shape_spec"
220190
in
221191
let labels =
222192
Map.merge_skewed ~combine input_labels @@ Map.merge_skewed ~combine output_labels batch_labels
223-
|> Map.map ~f:string_to_axis_spec
224193
in
225194
{
226195
bcast_batch;
@@ -236,7 +205,7 @@ let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t =
236205
}
237206

238207
let axis_labels_of_spec spec =
239-
let multichar = String.contains spec ',' || String.contains spec '*' || String.contains spec '+' in
208+
let multichar = String.contains spec ',' in
240209
match
241210
Angstrom.(
242211
parse_string ~consume:Consume.All (axis_labels_of_spec_parser ~multichar <* end_of_input) spec)
@@ -255,9 +224,10 @@ let einsum_of_spec_parser ~multichar : _ Angstrom.t =
255224
(p <?> "RHS2")
256225
(string "=>" *> (p <?> "LHS"))
257226
<|> lift2 (fun a c -> (a, None, c)) (p <?> "RHS") (string "=>" *> (p <?> "LHS"))
227+
<?> "einsum_spec"
258228

259229
let einsum_of_spec spec =
260-
let multichar = String.contains spec ',' || String.contains spec '*' || String.contains spec '+' in
230+
let multichar = String.contains spec ',' in
261231
match
262232
Angstrom.(
263233
parse_string ~consume:Consume.All (einsum_of_spec_parser ~multichar <* end_of_input) spec)
@@ -391,8 +361,16 @@ let einsum_slot_spec_to_dims_bio ~generative ~sh_id ~row_var_env ~dim_var_env la
391361
extras := Row.Dim_constr { d; constr = At_least_dim (i + 1) } :: !extras;
392362
d
393363
| Conv_spec { stride; output_label; dilation; kernel_label } ->
394-
let output_dim = Row.Var (Hashtbl.find_or_add dim_var_env output_label ~default:(fun () -> Row.get_var ~label:output_label ())) in
395-
let kernel_dim = Row.Var (Hashtbl.find_or_add dim_var_env kernel_label ~default:(fun () -> Row.get_var ~label:kernel_label ())) in
364+
let output_dim =
365+
Row.Var
366+
(Hashtbl.find_or_add dim_var_env output_label ~default:(fun () ->
367+
Row.get_var ~label:output_label ()))
368+
in
369+
let kernel_dim =
370+
Row.Var
371+
(Hashtbl.find_or_add dim_var_env kernel_label ~default:(fun () ->
372+
Row.get_var ~label:kernel_label ()))
373+
in
396374
Row.Conv_input { stride; output = output_dim; dilation; kernel = kernel_dim }
397375
in
398376
let result = axes_spec_to_dims_bio ~sh_id ~row_var_env ~dim_var_env ~f labels in
@@ -934,8 +912,16 @@ let shape_spec_to_dims_bio labels =
934912
Var (Hashtbl.find_or_add dim_var_env label ~default:(fun () -> Row.get_var ~label ()))
935913
| Fixed_index d -> Row.get_dim ~d ()
936914
| Conv_spec { stride; output_label; dilation; kernel_label } ->
937-
let output_dim = Row.Var (Hashtbl.find_or_add dim_var_env output_label ~default:(fun () -> Row.get_var ~label:output_label ())) in
938-
let kernel_dim = Row.Var (Hashtbl.find_or_add dim_var_env kernel_label ~default:(fun () -> Row.get_var ~label:kernel_label ())) in
915+
let output_dim =
916+
Row.Var
917+
(Hashtbl.find_or_add dim_var_env output_label ~default:(fun () ->
918+
Row.get_var ~label:output_label ()))
919+
in
920+
let kernel_dim =
921+
Row.Var
922+
(Hashtbl.find_or_add dim_var_env kernel_label ~default:(fun () ->
923+
Row.get_var ~label:kernel_label ()))
924+
in
939925
Row.Conv_input { stride; output = output_dim; dilation; kernel = kernel_dim }
940926
in
941927
let row_var_env = Hashtbl.create (module String) in
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
2+
Welcome to OCANNL! Reading configuration defaults from /Users/lukstafi/ocannl/_build/default/test/einsum/ocannl_config.
3+
Retrieving commandline, environment, or config file variable ocannl_log_level
4+
Found 0, in the config file
5+
Testing conv syntax parsing...
6+
Test 1: Parsed '2*o+3*k' successfully
7+
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
8+
(given_input 0) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
9+
(given_beg_output 0)
10+
(labels
11+
((((in_axes Output) (pos 1) (from_end true))
12+
(Conv_spec (stride 2) (output_label o) (dilation 3) (kernel_label k))))))
13+
14+
Test 2: Parsed 'o+k' successfully
15+
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
16+
(given_input 0) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
17+
(given_beg_output 0)
18+
(labels
19+
((((in_axes Output) (pos 1) (from_end true))
20+
(Conv_spec (stride 1) (output_label o) (dilation 1) (kernel_label k))))))
21+
22+
Test 3: Parsed 'a,2*b+c' successfully
23+
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
24+
(given_input 0) (given_output 2) (given_beg_batch 0) (given_beg_input 0)
25+
(given_beg_output 0)
26+
(labels
27+
((((in_axes Output) (pos 1) (from_end true))
28+
(Conv_spec (stride 2) (output_label b) (dilation 1) (kernel_label c)))
29+
(((in_axes Output) (pos 2) (from_end true)) (Label a)))))
30+
31+
Test 4: Parsed 'io+kj' successfully (single-char mode)
32+
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
33+
(given_input 0) (given_output 3) (given_beg_batch 0) (given_beg_input 0)
34+
(given_beg_output 0)
35+
(labels
36+
((((in_axes Output) (pos 1) (from_end true)) (Label j))
37+
(((in_axes Output) (pos 2) (from_end true))
38+
(Conv_spec (stride 1) (output_label o) (dilation 1) (kernel_label k)))
39+
(((in_axes Output) (pos 3) (from_end true)) (Label i)))))
40+
41+
Test 5: Parsed 'a+bc' successfully (single-char mode)
42+
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
43+
(given_input 0) (given_output 2) (given_beg_batch 0) (given_beg_input 0)
44+
(given_beg_output 0)
45+
(labels
46+
((((in_axes Output) (pos 1) (from_end true)) (Label c))
47+
(((in_axes Output) (pos 2) (from_end true))
48+
(Conv_spec (stride 1) (output_label a) (dilation 1) (kernel_label b))))))
49+
50+
Test 6: Parsed 'i,j->2*i+j' successfully
51+
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
52+
(given_input 2) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
53+
(given_beg_output 0)
54+
(labels
55+
((((in_axes Output) (pos 1) (from_end true))
56+
(Conv_spec (stride 2) (output_label i) (dilation 1) (kernel_label j)))
57+
(((in_axes Input) (pos 1) (from_end true)) (Label j))
58+
(((in_axes Input) (pos 2) (from_end true)) (Label i)))))
59+
60+
Test 7: Parsed 'batch|input->3*output+1*kernel,' successfully
61+
Structure: ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 1)
62+
(given_input 1) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
63+
(given_beg_output 0)
64+
(labels
65+
((((in_axes Batch) (pos 1) (from_end true)) (Label batch))
66+
(((in_axes Output) (pos 1) (from_end true))
67+
(Conv_spec (stride 3) (output_label output) (dilation 1)
68+
(kernel_label kernel)))
69+
(((in_axes Input) (pos 1) (from_end true)) (Label input)))))
70+
71+
All conv syntax parsing tests passed!
72+
73+
Testing multichar mode detection...
74+
✓ Multichar spec 'a,b' parsed correctly
75+
✓ Multichar spec '2*o+k' parsed correctly
76+
✓ Multichar spec 'o+k' parsed correctly
77+
✓ Multichar spec 'a,2*b+c' parsed correctly
78+
✓ Single-char spec 'abc' parsed correctly
79+
✓ Single-char spec 'ijk' parsed correctly
80+
✓ Single-char spec 'i->j' parsed correctly
81+
✓ Single-char spec 'io+kj' parsed correctly
82+
✓ Single-char spec 'a+bc' parsed correctly
83+
✓ Single-char spec '...|ij' parsed correctly
84+
✓ Single-char spec 'j...' parsed correctly
85+
✓ Single-char spec '...|j...->i' parsed correctly
86+
✓ Single-char spec '...|i->1' parsed correctly
87+
88+
Testing single-char conv equivalence...
89+
Single-char 'a+b': ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
90+
(given_input 0) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
91+
(given_beg_output 0)
92+
(labels
93+
((((in_axes Output) (pos 1) (from_end true))
94+
(Conv_spec (stride 1) (output_label a) (dilation 1) (kernel_label b))))))
95+
Multi-char 'a+b': ((bcast_batch ()) (bcast_input ()) (bcast_output ()) (given_batch 0)
96+
(given_input 0) (given_output 1) (given_beg_batch 0) (given_beg_input 0)
97+
(given_beg_output 0)
98+
(labels
99+
((((in_axes Output) (pos 1) (from_end true))
100+
(Conv_spec (stride 1) (output_label a) (dilation 1) (kernel_label b))))))
101+
Note: Both should produce the same Conv_spec structure
102+
103+
All conv syntax tests completed!

0 commit comments

Comments
 (0)