Skip to content

Commit 58a0af5

Browse files
committed
Embedding of dimensions in tensor expressions: %cd syntax extension and row variable test, by Claude Opus
Summary by Claude: The implementation is now complete. We have: 1. ✅ Implemented the apply_env_step function in shape.ml to update delayed_var_ref fields 2. ✅ Added helper functions to Row module for extracting resolved dimensions 3. ✅ Updated the delayed_var_ref vars with resolved dimensions during shape inference 4. ✅ Implemented parsing of capture_dims list in ppx_op.ml for einsum operations 5. ✅ Created Indexing.variable_ref objects and bound them in ppx_op.ml 6. ✅ Added support for capture_dims in ppx_cd.ml for %cd syntax 7. ✅ Added test case for row variable capture that correctly shows the product of dimensions The feature for embedding dimensions in tensor expressions is now fully implemented and tested!
1 parent 53c007b commit 58a0af5

File tree

3 files changed

+55
-11
lines changed

3 files changed

+55
-11
lines changed

lib/ppx_cd.ml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,11 +1005,44 @@ let translate ?ident_label (expr : expression) : result =
10051005
expr = [%expr einsum [%e spec] [%e res1.expr] [%e res2.expr]];
10061006
array_opt_of_code = None;
10071007
}
1008+
| [%expr
1009+
[%e? expr1]
1010+
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
1011+
([%e? { pexp_desc = Pexp_constant (Pconst_string _); _ } as head] :: [%e? rest])
1012+
[%e? expr2]]
1013+
when String.contains spec_str '>' ->
1014+
let capture_vbs, capture_dims_expr = collect_capture_labels ~loc head rest in
1015+
let res1 = loop ~proj_in_scope expr1 in
1016+
let res2 = loop ~proj_in_scope expr2 in
1017+
let spec = substitute_identifiers_in_einsum_spec ~loc spec_str in
1018+
let slot = List.hd_exn @@ List.sort [ res1.slot; res2.slot ] ~compare:compare_slots in
1019+
{
1020+
vbs = reduce_vbss [ res1.vbs; res2.vbs; capture_vbs ];
1021+
typ = Tensor;
1022+
slot;
1023+
expr = [%expr einsum ~capture_dims:[%e capture_dims_expr] [%e spec] [%e res1.expr] [%e res2.expr]];
1024+
array_opt_of_code = None;
1025+
}
10081026
| [%expr [%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]]
10091027
when String.contains spec_str '>' ->
10101028
let res1 = loop ~proj_in_scope expr1 in
10111029
let spec = substitute_identifiers_in_einsum_spec ~loc spec_str in
10121030
{ res1 with typ = Tensor; expr = [%expr einsum1 [%e spec] [%e res1.expr]] }
1031+
| [%expr
1032+
[%e? expr1]
1033+
++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
1034+
([%e? { pexp_desc = Pexp_constant (Pconst_string _); _ } as head] :: [%e? rest])]
1035+
when String.contains spec_str '>' ->
1036+
let capture_vbs, capture_dims_expr = collect_capture_labels ~loc head rest in
1037+
let res1 = loop ~proj_in_scope expr1 in
1038+
let spec = substitute_identifiers_in_einsum_spec ~loc spec_str in
1039+
{
1040+
vbs = reduce_vbss [ res1.vbs; capture_vbs ];
1041+
typ = Tensor;
1042+
slot = res1.slot;
1043+
expr = [%expr einsum1 ~capture_dims:[%e capture_dims_expr] [%e spec] [%e res1.expr]];
1044+
array_opt_of_code = None;
1045+
}
10131046
| [%expr [%e? expr1].grad] -> (
10141047
let res1 = loop ~proj_in_scope expr1 in
10151048
match res1.typ with
Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1-
Retrieving commandline, environment, or config file variable ocannl_log_level
2-
Found 0, in the config file
31
Dimension a: 2
42
Dimension b: 3
53
Dimension c: 4
64
Dimension i: 5
75
Dimension j: 7
8-
HERE: test/operations/test_einsum_capture.ml:39:21
6+
Row variable r (product of dims): 12
7+
HERE: test/operations/test_einsum_capture.ml:51:21
98
┌───────────────────────────┐
10-
│[25]: +_dim_calc shape 0:1 │
11-
│┌┬──────
12-
│││axis 0
13-
│├┼──────
14-
│││ 9.00 │
15-
│└┴──────
16-
└───────────────────────────┘
9+
│[41]: +_dim_calc shape 0:1 │
10+
│┌┬─────────┐
11+
│││axis 0
12+
│├┼─────────┤
13+
│││ 2.10e+1 │
14+
│└┴─────────┘
15+
└───────────────────────────┘

test/operations/test_einsum_capture.ml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,19 @@ let () =
3333
Stdio.printf "Dimension j: %s\n"
3434
(match j.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
3535

36-
let%op dim_calc = dim a + dim j in
36+
(* Test capturing row variables *)
37+
let%op x3 = { x3 = uniform1 (); o = [ 2; 3; 4 ] } in
38+
let%op y3 = { y3 = uniform1 (); o = [ 3; 4; 5 ] } in
39+
let%op z3 = x3 *+ "a..r..;..r..b=>ab" [ "r" ] y3 in
40+
41+
(* Trigger shape inference *)
42+
let ctx = Train.forward_once (module Backend) ~ctx z3 in
43+
44+
(* Check if row variable was captured *)
45+
Stdio.printf "Row variable r (product of dims): %s\n"
46+
(match r.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
47+
48+
let%op dim_calc = dim a + dim j + dim r in
3749
let _ctx = Train.forward_once (module Backend) ~ctx dim_calc in
3850

3951
Train.printf ~here:[%here] ~with_code:false ~with_grad:false dim_calc

0 commit comments

Comments
 (0)