Skip to content

Commit 53c007b

Browse files
committed
Embedding of dimensions in tensor expressions: %op syntax extension and state updates, collab with Claude
I took over for the ppx_op.ml part.
1 parent a103908 commit 53c007b

File tree

9 files changed

+197
-19
lines changed

9 files changed

+197
-19
lines changed

lib/operation.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -661,8 +661,8 @@ struct
661661
let ( = ) ?label t1 t2 = eq ?label t1 t2 ()
662662
let ( <> ) ?label t1 t2 = ne ?label t1 t2 ()
663663
let embed_self_id = embed_self_id
664-
let einsum ?label spec t1 t2 = einsum ?label spec t1 t2 ()
665-
let einsum1 ?label spec t1 = einsum1 ?label spec t1 ()
664+
let einsum ?label ?capture_dims spec t1 t2 = einsum ?label ?capture_dims spec t1 t2 ()
665+
let einsum1 ?label ?capture_dims spec t1 = einsum1 ?label ?capture_dims spec t1 ()
666666
let ndarray = ndarray
667667
let uniform ?label () = uniform () ?label ()
668668
let uniform_at ?label counter = uniform_at ?label counter ()

lib/ppx_op.ml

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,26 +216,31 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
216216
| [%expr
217217
[%e? expr1]
218218
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
219-
[%e? { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ }]
219+
([%e? { pexp_desc = Pexp_constant (Pconst_string _); _ } as head] :: [%e? rest])
220220
[%e? expr2]]
221221
when String.contains spec_str '>' ->
222-
(* FIXME: introduce inline definitions for new Indexing.variable_ref objects corresponding to
223-
the strings in the list, and pass them as ~capture_dims *)
222+
let capture_vbs, capture_dims_expr = collect_capture_labels ~loc head rest in
224223
let vbs1, e1 = loop expr1 in
225224
let vbs2, e2 = loop expr2 in
226225
let spec = substitute_identifiers_in_einsum_spec ~loc spec_str in
227-
( reduce_vbss [ vbs1; vbs2 ],
228-
[%expr einsum ?label:[%e opt_expr ~loc label] [%e spec] [%e e1] [%e e2]] )
226+
let combined_vbs = reduce_vbss [ vbs1; vbs2; capture_vbs ] in
227+
( combined_vbs,
228+
[%expr
229+
einsum ?label:[%e opt_expr ~loc label] ~capture_dims:[%e capture_dims_expr] [%e spec]
230+
[%e e1] [%e e2]] )
229231
| [%expr
230232
[%e? expr1]
231233
++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
232-
[%e? { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ }]]
234+
([%e? { pexp_desc = Pexp_constant (Pconst_string _); _ } as head] :: [%e? rest])]
233235
when String.contains spec_str '>' ->
234-
(* FIXME: introduce inline definitions for new Indexing.variable_ref objects corresponding to
235-
the strings in the list, and pass them as ~capture_dims *)
236+
let capture_vbs, capture_dims_expr = collect_capture_labels ~loc head rest in
236237
let vbs1, e1 = loop expr1 in
237238
let spec = substitute_identifiers_in_einsum_spec ~loc spec_str in
238-
(vbs1, [%expr einsum1 ?label:[%e opt_expr ~loc label] [%e spec] [%e e1]])
239+
let combined_vbs = reduce_vbss [ vbs1; capture_vbs ] in
240+
( combined_vbs,
241+
[%expr
242+
einsum1 ?label:[%e opt_expr ~loc label] ~capture_dims:[%e capture_dims_expr] [%e spec]
243+
[%e e1]] )
239244
| { pexp_desc = Pexp_record ([], _); _ } ->
240245
(* Empty record - not a tensor definition *)
241246
(no_vbs, expr)

lib/ppx_shared.ml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,36 @@ let ndarray_op ?axis_labels ?label expr =
319319
[%expr
320320
[%e op] ~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
321321
~output_dims:[%e edims output_dims] ()]
322+
323+
let collect_capture_labels ~loc head rest =
324+
let capture_labels = head :: collect_list [] rest in
325+
let capture_labels, errors =
326+
List.partition_map capture_labels ~f:(function
327+
| { pexp_desc = Pexp_constant (Pconst_string (label, _, _)); pexp_loc; _ } ->
328+
Either.First (pexp_loc, label)
329+
| expr ->
330+
Either.Second
331+
(Ast_builder.Default.pexp_extension ~loc:expr.pexp_loc
332+
@@ Location.error_extensionf ~loc:expr.pexp_loc
333+
"ppx_ocannl %%op: expected a string literal"))
334+
in
335+
let capture_refs, capture_bindings =
336+
List.map capture_labels ~f:(fun (loc, label) ->
337+
let ref_expr =
338+
[%expr
339+
{
340+
Ir.Indexing.ref_label = [%e Ast_builder.Default.estring ~loc label];
341+
solved_dim = None;
342+
}]
343+
in
344+
let binding =
345+
Ast_builder.Default.value_binding ~loc
346+
~pat:(Ast_builder.Default.pvar ~loc label)
347+
~expr:ref_expr
348+
in
349+
(Ast_builder.Default.evar ~loc label, (label, binding)))
350+
|> List.unzip
351+
in
352+
let capture_dims_expr = Ast_builder.Default.elist ~loc (errors @ capture_refs) in
353+
let capture_vbs = Map.of_alist_exn (module String) capture_bindings in
354+
(capture_vbs, capture_dims_expr)

lib/row.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ type environment = { dim_env : dim_env; row_env : row_env } [@@deriving sexp_of]
220220
let get_dim_val env var =
221221
match Map.find env.dim_env var with Some (Solved_dim (Dim { d; _ })) -> Some d | _ -> None
222222

223+
let get_dim_from_env env var = get_dim_val env var
224+
225+
let get_row_from_env env var =
226+
match Map.find env.row_env var with Some (Solved_row row) -> Some row | _ -> None
227+
223228
type constraint_ =
224229
| Dim_eq of { d1 : dim; d2 : dim }
225230
| Row_eq of { r1 : t; r2 : t }

lib/row.mli

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ val subst_row : environment -> t -> t
136136
val unify_row : stage:stage -> t * t -> environment -> constraint_ list * environment
137137
val empty_env : environment
138138

139+
val get_dim_from_env : environment -> dim_var -> int option
140+
val get_row_from_env : environment -> row_var -> t option
141+
139142
val solve_inequalities :
140143
stage:stage -> constraint_ list -> environment -> constraint_ list * environment
141144

lib/shape.ml

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,7 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
570570
Row_eq { r1 = cur_sh.input; r2 = sh.input };
571571
Row_eq { r1 = cur_sh.output; r2 = sh.output };
572572
] )
573-
| Transpose (Permute (spec, _dim_refs), sh) ->
574-
(* FIXME: support dim_refs *)
573+
| Transpose (Permute (spec, dim_refs), sh) ->
575574
let ls_rhs, ls_lhs =
576575
match einsum_of_spec spec with
577576
| ls_rhs, None, ls_lhs -> (ls_rhs, ls_lhs)
@@ -590,6 +589,18 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
590589
let extras_lhs, proj_env_lhs, (b_lhs, i_lhs, o_lhs) =
591590
einsum_slot_spec_to_dims_bio ~generative ~sh_id:cur_sh.id ~row_var_env ~dim_var_env ls_lhs
592591
in
592+
(* Bind delayed_var_refs to the variables after they are created *)
593+
List.iter dim_refs ~f:(fun delayed_ref ->
594+
let label = delayed_ref.var_ref.ref_label in
595+
(* Check if it's in one of the environments *)
596+
match Hashtbl.find dim_var_env label with
597+
| Some var -> delayed_ref.var <- `Dim var
598+
| None -> (
599+
match Hashtbl.find row_var_env label with
600+
| Some var -> delayed_ref.var <- `Row var
601+
| None -> ()
602+
)
603+
);
593604
let proj_env =
594605
let combine ~key:_ _ _ = assert false in
595606
Map.merge_skewed ~combine proj_env_rhs proj_env_lhs
@@ -621,8 +632,7 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
621632
{ numerator = Row.Strided_var { coeff; var; denom = 1 }; divided_by = [] };
622633
};
623634
] )
624-
| Broadcast (Einsum (spec, _dim_refs), sh1, sh2) ->
625-
(* FIXME: support dim_refs *)
635+
| Broadcast (Einsum (spec, dim_refs), sh1, sh2) ->
626636
let ls_rhs1, ls_rhs2, ls_lhs =
627637
match einsum_of_spec spec with
628638
| ls_rhs1, Some ls_rhs2, ls_lhs -> (ls_rhs1, ls_rhs2, ls_lhs)
@@ -643,6 +653,18 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
643653
let extras_lhs, proj_env_lhs, (b_lhs, i_lhs, o_lhs) =
644654
einsum_slot_spec_to_dims_bio ~generative ~sh_id:cur_sh.id ~row_var_env ~dim_var_env ls_lhs
645655
in
656+
(* Bind delayed_var_refs to the variables after they are created *)
657+
List.iter dim_refs ~f:(fun delayed_ref ->
658+
let label = delayed_ref.var_ref.ref_label in
659+
(* Check if it's in one of the environments *)
660+
match Hashtbl.find dim_var_env label with
661+
| Some var -> delayed_ref.var <- `Dim var
662+
| None -> (
663+
match Hashtbl.find row_var_env label with
664+
| Some var -> delayed_ref.var <- `Row var
665+
| None -> ()
666+
)
667+
);
646668
let proj_env =
647669
let combine ~key:_ _ _ = assert false in
648670
Map.merge_skewed ~combine proj_env_rhs1
@@ -701,6 +723,52 @@ let apply_env_t env sh =
701723
sh.input <- Row.subst_row env sh.input;
702724
sh.output <- Row.subst_row env sh.output
703725

726+
let rec compute_row_product env (row : Row.t) : int =
727+
match row.dims with
728+
| [] -> 1
729+
| dim :: rest ->
730+
let dim_val =
731+
match dim with
732+
| Row.Dim { d; _ } -> d
733+
| Row.Var v -> (
734+
match Row.get_dim_from_env env v with
735+
| Some d -> d
736+
| None -> 1 (* Variable not yet resolved *)
737+
)
738+
| Row.Conv_input _ -> 1 (* TODO: handle convolution input dimensions *)
739+
in
740+
dim_val * compute_row_product env { row with dims = rest }
741+
742+
let update_delayed_var_refs env update_step =
743+
let update_var_ref_list var_refs =
744+
List.iter var_refs ~f:(fun delayed_ref ->
745+
match delayed_ref.var with
746+
| `Not_set_yet -> () (* Variable not bound yet, will be set later *)
747+
| `Dim dim_var -> (
748+
match Row.get_dim_from_env env dim_var with
749+
| Some d -> delayed_ref.var_ref.solved_dim <- Some d
750+
| None -> () (* Not yet resolved *)
751+
)
752+
| `Row row_var -> (
753+
match Row.get_row_from_env env row_var with
754+
| Some row ->
755+
let product = compute_row_product env row in
756+
delayed_ref.var_ref.solved_dim <- Some product
757+
| None -> () (* Not yet resolved *)
758+
)
759+
)
760+
in
761+
match update_step.logic with
762+
| Transpose (Permute (_, var_refs), _) ->
763+
update_var_ref_list var_refs
764+
| Broadcast (Einsum (_, var_refs), _, _) ->
765+
update_var_ref_list var_refs
766+
| _ -> ()
767+
768+
let apply_env_step env update_step =
769+
iter_shapes update_step ~f:(apply_env_t env);
770+
update_delayed_var_refs env update_step
771+
704772
let%debug4_sexp propagate_shapes (update_step : update_step) : unit =
705773
(* Allow the derivation of constraints to depend on the shapes (currently, only Batch_slice
706774
does). *)
@@ -711,8 +779,7 @@ let%debug4_sexp propagate_shapes (update_step : update_step) : unit =
711779
active_constraints := ineqs @ !active_constraints;
712780
let ineqs', env = Row.solve_inequalities ~stage:Row.Stage1 ineqs !state in
713781
let _debug_remaining_constraints : Row.constraint_ list = ineqs' in
714-
(* FIXME: call apply_env_step instead *)
715-
iter_shapes update_step ~f:(apply_env_t env);
782+
apply_env_step env update_step;
716783
state := env
717784

718785
let%debug4_sexp finish_inference (() : unit) : unit =
@@ -732,8 +799,7 @@ let%debug4_sexp finish_inference (() : unit) : unit =
732799
let unsolved, env = Row.solve_inequalities ~stage:Stage7 unsolved env in
733800
assert (List.is_empty unsolved);
734801
let _active_update_steps : update_step list = !active_update_steps in
735-
(* FIXME: call apply_env_step instead *)
736-
List.iter ~f:(iter_shapes ~f:(apply_env_t env)) !active_update_steps;
802+
List.iter ~f:(apply_env_step env) !active_update_steps;
737803
let _applied_update_steps : update_step list = !active_update_steps in
738804
active_constraints := [];
739805
active_update_steps := [];

test/operations/dune

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,17 @@
303303
(preprocess
304304
(pps ppx_here ppx_ocannl)))
305305

306+
(test
307+
(name test_einsum_capture)
308+
(package neural_nets_lib)
309+
(deps
310+
ocannl_config
311+
(env_var OCANNL_BACKEND))
312+
(modules test_einsum_capture)
313+
(libraries base ocannl stdio)
314+
(preprocess
315+
(pps ppx_here ppx_ocannl)))
316+
306317
(library
307318
(name operations_tutorials)
308319
(package neural_nets_lib)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Retrieving commandline, environment, or config file variable ocannl_log_level
2+
Found 0, in the config file
3+
Dimension a: 2
4+
Dimension b: 3
5+
Dimension c: 4
6+
Dimension i: 5
7+
Dimension j: 7
8+
HERE: test/operations/test_einsum_capture.ml:39:21
9+
┌───────────────────────────┐
10+
│[25]: +_dim_calc shape 0:1 │
11+
│┌┬──────┐ │
12+
│││axis 0│ │
13+
│├┼──────┤ │
14+
│││ 9.00 │ │
15+
│└┴──────┘ │
16+
└───────────────────────────┘
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
open Base
2+
open Ocannl
3+
4+
let () =
5+
let module TDSL = Operation.TDSL in
6+
let module PDSL = Operation.PDSL in
7+
let module Backend = (val Backends.fresh_backend ()) in
8+
let%op x = { x = uniform1 (); o = [ 2; 3 ] } in
9+
let%op y = { y = uniform1 (); o = [ 3; 4 ] } in
10+
let%op z = x *+ "ab;bc=>ac" [ "a"; "b"; "c" ] y in
11+
12+
(* Trigger shape inference by accessing the tensor node *)
13+
let ctx = Train.forward_once (module Backend) z in
14+
15+
(* Check if dimensions were captured *)
16+
Stdio.printf "Dimension a: %s\n"
17+
(match a.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
18+
Stdio.printf "Dimension b: %s\n"
19+
(match b.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
20+
Stdio.printf "Dimension c: %s\n"
21+
(match c.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
22+
23+
let%op x2 = { x2 = uniform1 (); o = [ 5; 7 ] } in
24+
(* Manually call einsum1 with capture_dims for now *)
25+
let%op y2 = x2 ++ "ij=>ji" [ "i"; "j" ] in
26+
27+
(* Trigger shape inference by accessing the tensor node *)
28+
let ctx = Train.forward_once (module Backend) ~ctx y2 in
29+
30+
(* Check if dimensions were captured *)
31+
Stdio.printf "Dimension i: %s\n"
32+
(match i.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
33+
Stdio.printf "Dimension j: %s\n"
34+
(match j.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
35+
36+
let%op dim_calc = dim a + dim j in
37+
let _ctx = Train.forward_once (module Backend) ~ctx dim_calc in
38+
39+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false dim_calc

0 commit comments

Comments
 (0)