@@ -98,93 +98,62 @@ type transpose_type =
9898
9999type ternary_type = Pointwise_tern | Compose_accumulate [@@ deriving sexp , equal ]
100100
101- let identifier_multichar = Angstrom. take_while1 Char. is_alphanum
101+ let identifier ~multichar =
102+ let open Angstrom in
103+ if multichar then lift2 ( ^ ) (take_while1 Char. is_alpha) (take_while Char. is_alphanum)
104+ else Angstrom. satisfy Char. is_alpha >> | Char. to_string
105+
106+ let integer = Angstrom. (take_while1 Char. is_digit >> | Int. of_string)
107+
108+ let scaled_identifier ~multichar =
109+ let open Angstrom in
110+ integer < * char '*'
111+ >> = (fun coeff -> identifier ~multichar >> | fun id -> (coeff, id))
112+ < |> (identifier ~multichar >> | fun id -> (1 , id))
113+
114+ let conv_term ~multichar =
115+ let open Angstrom in
116+ let * stride, output_label = scaled_identifier ~multichar in
117+ char '+' *> scaled_identifier ~multichar >> | fun (dilation , kernel_label ) ->
118+ Conv_spec { stride; output_label; dilation; kernel_label }
102119
103- let opt_separators : _ Angstrom.t =
104- Angstrom. take_while (fun c -> Char. is_whitespace c || Char. equal c ',' )
120+ let opt_separators = Angstrom. take_while (fun c -> Char. is_whitespace c || Char. equal c ',' )
105121
106122let separators_with_comma =
107123 let open Angstrom in
108124 let * sep = opt_separators in
109125 if String. contains sep ',' then return () else fail " comma expected"
110126
111127(* * Parse a single axis specification that can be a label, fixed index, or conv expression. *)
112- let parse_single_axis_spec ~multichar : string Angstrom. t =
128+ let parse_single_axis_spec ~multichar =
113129 let open Angstrom in
114- if multichar then
115- (* In multichar mode, try to parse conv expressions first, then fall back to regular identifiers *)
116- let integer = take_while1 Char. is_digit in
117- let identifier = identifier_multichar in
118- let conv_term =
119- (integer < * char '*' >> = fun coeff -> identifier >> | fun id -> coeff ^ " *" ^ id)
120- < |> identifier
121- in
122- let conv_expr =
123- conv_term >> = fun first_term ->
124- char '+' *> conv_term >> | fun second_term ->
125- first_term ^ " +" ^ second_term
126- in
127- choice [ conv_expr; integer; identifier ]
128- else
129- (* In single-char mode, conv expressions are not supported since they would be ambiguous *)
130- take 1
131-
132- (* * Convert a string to axis_spec, handling conv expressions. *)
133- let string_to_axis_spec (s : string ) : axis_spec =
134- if String. contains s '*' || String. contains s '+' then
135- (* Parse as conv expression *)
136- let parse_conv_expr spec =
137- let open Angstrom in
138- let integer = take_while1 Char. is_digit >> | Int. of_string in
139- let identifier = take_while1 Char. is_alphanum in
140- let term =
141- (integer < * char '*' >> = fun coeff -> identifier >> | fun id -> (coeff, id))
142- < |> (identifier >> | fun id -> (1 , id))
143- in
144- let conv_expr =
145- term >> = fun (stride , output_label ) ->
146- char '+' *> term >> | fun (dilation , kernel_label ) ->
147- Conv_spec { stride; output_label; dilation; kernel_label }
148- in
149- match parse_string ~consume: Consume. All conv_expr spec with
150- | Ok result -> result
151- | Error _ ->
152- (* If conv parsing fails, treat as regular label *)
153- Label s
154- in
155- parse_conv_expr s
156- else
157- (* Try to parse as integer first, then as label *)
158- try Fixed_index (Int. of_string s)
159- with _ -> Label s
130+ choice
131+ [
132+ conv_term ~multichar < ?> " conv_term" ;
133+ integer >> | (fun i -> Fixed_index i) < ?> " fixed_index" ;
134+ identifier ~multichar >> | (fun s -> Label s) < ?> " label" ;
135+ ]
160136
161137let axes_spec ~from_end ~multichar : _ Angstrom. t =
162138 let open Angstrom in
163- let single_char (pos , acc ) c =
164- Option. some_if (Char. is_alphanum c) (pos + 1 , (pos, Char. to_string c) :: acc)
165- in
166139 let result =
167140 let p n i = if from_end then n - i else i + 1 in
168- if multichar then
169- lift (fun l ->
170- let n = List. length l in
171- List. mapi l ~f: (fun i v -> (p n i, v)))
172- @@ sep_by1 separators_with_comma (parse_single_axis_spec ~multichar )
173- else
174- lift (fun (_ , acc ) ->
175- let n = List. length acc in
176- List. rev_map acc ~f: (fun (i , v ) -> (p n i, v)))
177- @@ scan_state (0 , [] ) single_char
141+ lift (fun l ->
142+ let n = List. length l in
143+ List. mapi l ~f: (fun i v -> (p n i, v)))
144+ @@ sep_by1
145+ (if multichar then separators_with_comma else opt_separators >> | ignore)
146+ (parse_single_axis_spec ~multichar )
178147 in
179- opt_separators *> result < * opt_separators
148+ opt_separators *> result < * opt_separators < ?> " axes_spec "
180149
181150let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom. t =
182151 let open Angstrom in
183152 let combine ~key :_ _v1 _v2 = assert false in
184153 let axes_spec ~from_end =
185154 axes_spec ~from_end ~multichar < ?> if from_end then " axes_spec" else " axes_spec_beg"
186155 in
187- let ellipsis_spec = string " ..." < |> (string " .." *> identifier_multichar < * string " .." ) in
156+ let ellipsis_spec = string " ..." < |> (string " .." *> identifier ~multichar < * string " .." ) in
188157 let ellipsis_spec = ellipsis_spec < ?> " ellipsis_spec" in
189158 let for_row ~kind in_axes beg_axes row_var_spec end_axes =
190159 let f from_end (pos , label ) = (AxisKey. { in_axes; pos; from_end }, label) in
@@ -198,12 +167,12 @@ let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t =
198167 let parse_row ~kind in_axes =
199168 let row = lift3 (for_row ~kind in_axes) in
200169 opt_separators
201- *> (row (return [] ) (lift Option. some ellipsis_spec) (axes_spec ~from_end: true )
202- < |> row (axes_spec ~from_end: false ) (lift Option. some ellipsis_spec)
203- (axes_spec ~from_end: true )
170+ *> (row (axes_spec ~from_end: false ) (lift Option. some ellipsis_spec) (axes_spec ~from_end: true )
171+ < |> row (return [] ) (lift Option. some ellipsis_spec) (axes_spec ~from_end: true )
172+ < |> row (axes_spec ~from_end: false ) (lift Option. some ellipsis_spec) (return [] )
204173 < |> row (return [] ) (return None ) (axes_spec ~from_end: true )
205174 < |> row (return [] ) (lift Option. some ellipsis_spec) (return [] ))
206- < * opt_separators
175+ < * opt_separators < ?> " row_spec "
207176 in
208177 let default = Option. value ~default: (None , 0 , 0 , Map. empty (module AxisKey )) in
209178 let shape = lift3 (fun batch input output -> (default batch, default input, output)) in
@@ -217,10 +186,10 @@ let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t =
217186 < |> shape (p_b < * char '|' ) (p_i < * string " ->" ) p_o
218187 < |> shape (p_b < * char '|' ) (return None ) p_o
219188 < |> shape (return None ) (return None ) p_o
189+ < ?> " shape_spec"
220190 in
221191 let labels =
222192 Map. merge_skewed ~combine input_labels @@ Map. merge_skewed ~combine output_labels batch_labels
223- |> Map. map ~f: string_to_axis_spec
224193 in
225194 {
226195 bcast_batch;
@@ -236,7 +205,7 @@ let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t =
236205 }
237206
238207let axis_labels_of_spec spec =
239- let multichar = String. contains spec ',' || String. contains spec '*' || String. contains spec '+' in
208+ let multichar = String. contains spec ',' in
240209 match
241210 Angstrom. (
242211 parse_string ~consume: Consume. All (axis_labels_of_spec_parser ~multichar < * end_of_input) spec)
@@ -255,9 +224,10 @@ let einsum_of_spec_parser ~multichar : _ Angstrom.t =
255224 (p < ?> " RHS2" )
256225 (string " =>" *> (p < ?> " LHS" ))
257226 < |> lift2 (fun a c -> (a, None , c)) (p < ?> " RHS" ) (string " =>" *> (p < ?> " LHS" ))
227+ < ?> " einsum_spec"
258228
259229let einsum_of_spec spec =
260- let multichar = String. contains spec ',' || String. contains spec '*' || String. contains spec '+' in
230+ let multichar = String. contains spec ',' in
261231 match
262232 Angstrom. (
263233 parse_string ~consume: Consume. All (einsum_of_spec_parser ~multichar < * end_of_input) spec)
@@ -391,8 +361,16 @@ let einsum_slot_spec_to_dims_bio ~generative ~sh_id ~row_var_env ~dim_var_env la
391361 extras := Row. Dim_constr { d; constr = At_least_dim (i + 1 ) } :: ! extras;
392362 d
393363 | Conv_spec { stride; output_label; dilation; kernel_label } ->
394- let output_dim = Row. Var (Hashtbl. find_or_add dim_var_env output_label ~default: (fun () -> Row. get_var ~label: output_label () )) in
395- let kernel_dim = Row. Var (Hashtbl. find_or_add dim_var_env kernel_label ~default: (fun () -> Row. get_var ~label: kernel_label () )) in
364+ let output_dim =
365+ Row. Var
366+ (Hashtbl. find_or_add dim_var_env output_label ~default: (fun () ->
367+ Row. get_var ~label: output_label () ))
368+ in
369+ let kernel_dim =
370+ Row. Var
371+ (Hashtbl. find_or_add dim_var_env kernel_label ~default: (fun () ->
372+ Row. get_var ~label: kernel_label () ))
373+ in
396374 Row. Conv_input { stride; output = output_dim; dilation; kernel = kernel_dim }
397375 in
398376 let result = axes_spec_to_dims_bio ~sh_id ~row_var_env ~dim_var_env ~f labels in
@@ -934,8 +912,16 @@ let shape_spec_to_dims_bio labels =
934912 Var (Hashtbl. find_or_add dim_var_env label ~default: (fun () -> Row. get_var ~label () ))
935913 | Fixed_index d -> Row. get_dim ~d ()
936914 | Conv_spec { stride; output_label; dilation; kernel_label } ->
937- let output_dim = Row. Var (Hashtbl. find_or_add dim_var_env output_label ~default: (fun () -> Row. get_var ~label: output_label () )) in
938- let kernel_dim = Row. Var (Hashtbl. find_or_add dim_var_env kernel_label ~default: (fun () -> Row. get_var ~label: kernel_label () )) in
915+ let output_dim =
916+ Row. Var
917+ (Hashtbl. find_or_add dim_var_env output_label ~default: (fun () ->
918+ Row. get_var ~label: output_label () ))
919+ in
920+ let kernel_dim =
921+ Row. Var
922+ (Hashtbl. find_or_add dim_var_env kernel_label ~default: (fun () ->
923+ Row. get_var ~label: kernel_label () ))
924+ in
939925 Row. Conv_input { stride; output = output_dim; dilation; kernel = kernel_dim }
940926 in
941927 let row_var_env = Hashtbl. create (module String ) in
0 commit comments