Skip to content

Commit 6c8c8df

Browse files
lukstaficlaude
andcommitted
Implement shape equality constraints with set_dim and set_equal
Complete the implementation of Shape.set_equal to handle all cases of equality constraints between delayed variable references: - Both solved dimensions (validation) - One solved, one unsolved (propagation) - Dimension variable pairs (Dim_eq constraint) - Row variable pairs (Row_eq constraint) - Mixed dimension/row variables (Total_elems constraint) - Proper error handling for conflicting constraints Add comprehensive test coverage in test_einsum_capture.ml: - Low-level functionality tests (set_dim, set_equal variants) - Shape validation integration (constraint checking) - Pure shape inference (constraint-driven shape resolution) This enables powerful constraint-driven tensor shape specification and validation integrated with OCANNL's shape inference system. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent f5369c2 commit 6c8c8df

File tree

3 files changed

+467
-8
lines changed

3 files changed

+467
-8
lines changed

lib/shape.ml

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -723,18 +723,60 @@ let set_dim delayed_var_ref dim =
723723
:: !active_constraints
724724

725725
let set_equal delayed_ref1 delayed_ref2 =
726-
match delayed_ref1, delayed_ref2 with
727-
| { var_ref = { solved_dim = Some dim1; _ }; _ },
728-
{ var_ref = { solved_dim = Some dim2; _ }; _ } ->
726+
match (delayed_ref1, delayed_ref2) with
727+
| { var_ref = { solved_dim = Some dim1; _ }; _ }, { var_ref = { solved_dim = Some dim2; _ }; _ }
728+
->
729729
if dim1 = dim2 then ()
730730
else
731731
raise
732732
@@ Row.Shape_error
733733
( "Cannot set equal dimensions for variable references with different values",
734734
[ Row.Dim_mismatch [ Row.get_dim ~d:dim1 (); Row.get_dim ~d:dim2 () ] ] )
735-
| _ ->
736-
(* FIXME: NOT IMPLEMENTED YET *)
737-
()
735+
| { var_ref = { solved_dim = Some dim; _ }; _ }, delayed_ref2 ->
736+
(* First is solved, second is not - set the second to match the first *)
737+
set_dim delayed_ref2 dim
738+
| delayed_ref1, { var_ref = { solved_dim = Some dim; _ }; _ } ->
739+
(* Second is solved, first is not - set the first to match the second *)
740+
set_dim delayed_ref1 dim
741+
| ( { var_ref = { solved_dim = None; ref_label = ref_label1; _ }; var = _ },
742+
{ var_ref = { solved_dim = None; ref_label = ref_label2; _ }; var = `Not_set_yet } )
743+
| ( { var_ref = { solved_dim = None; ref_label = ref_label1; _ }; var = `Not_set_yet },
744+
{ var_ref = { solved_dim = None; ref_label = ref_label2; _ }; var = _ } ) ->
745+
raise
746+
@@ Row.Shape_error
747+
( "set_equal: insufficient information between labels " ^ ref_label1 ^ " and "
748+
^ ref_label2,
749+
[] )
750+
| { var = `Dim dim_var1; _ }, { var = `Dim dim_var2; _ } ->
751+
(* Both are dimension variables - create equality constraint *)
752+
active_constraints :=
753+
Row.Dim_eq { d1 = Row.Var dim_var1; d2 = Row.Var dim_var2 } :: !active_constraints
754+
| { var = `Row row_var1; _ }, { var = `Row row_var2; _ } ->
755+
(* Both are row variables - create row equality constraint *)
756+
active_constraints :=
757+
Row.Row_eq { r1 = Row.get_row_for_var row_var1; r2 = Row.get_row_for_var row_var2 }
758+
:: !active_constraints
759+
| { var = `Dim dim_var; _ }, { var = `Row row_var; _ }
760+
| { var = `Row row_var; _ }, { var = `Dim dim_var; _ } ->
761+
(* One is dim var, one is row var - equality via Total_elems constraint *)
762+
active_constraints :=
763+
Row.Rows_constr
764+
{
765+
r = [ Row.get_row_for_var row_var ];
766+
constr =
767+
Total_elems
768+
{
769+
numerator =
770+
Strided_var
771+
{
772+
coeff = Utils.safe_lazy "set_equal_dim_row" (fun () -> 1);
773+
var = dim_var;
774+
denom = 1;
775+
};
776+
divided_by = [];
777+
};
778+
}
779+
:: !active_constraints
738780

739781
let unsafe_reinitialize () =
740782
update_uid := 0;

test/operations/test_einsum_capture.expected

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Dimension c: 4
66
Dimension i: 5
77
Dimension j: 7
88
Row variable r (product of dims): 12
9-
HERE: test/operations/test_einsum_capture.ml:50:21
9+
HERE: test/operations/test_einsum_capture.ml:51:21
1010
┌───────────────────────────┐
1111
│[41]: +_dim_calc shape 0:1 │
1212
│┌┬─────────┐ │
@@ -15,3 +15,74 @@ HERE: test/operations/test_einsum_capture.ml:50:21
1515
│││ 2.10e+1 │ │
1616
│└┴─────────┘ │
1717
└───────────────────────────┘
18+
19+
=== Testing set_dim and set_equal functionality ===
20+
Test 1 - set_dim: var1 set to 42: 42
21+
Test 2 - set_equal (one solved): var2 should now be 42: 42
22+
Test 3 - set_equal (both solved, equal): Success - no exception
23+
Test 4 - set_equal error case: Correctly caught exception: Cannot set equal dimensions for variable references with different values
24+
Test 5 - einsum variable capture:
25+
Dimension p: 3
26+
Dimension q: 4
27+
Dimension r: 5
28+
Expected dimensions (p=3, q=4, r=5): p=3, q=4, r=5
29+
Test 6 - row-dimension equality:
30+
Row variable s (product): 6
31+
Dimension variable test_dim: 6
32+
s == test_dim constraint satisfied: true (both should be 6)
33+
=== All tests completed ===
34+
35+
=== Testing shape validation integration with equality constraints ===
36+
Test 1 - Constraint i=k in matrix multiply:
37+
Input a1 shape: 4,6
38+
Input b1 shape: 6,4
39+
Output c1 shape: 4,4
40+
Dimension i: 4, k: 4 (should be equal)
41+
42+
Test 2 - Multiple constraints (a=d, c=7):
43+
Input x2 shape: 3,5,7
44+
Input y2 shape: 5,7,3
45+
Output z2 shape: 3,3
46+
Dimensions: a=3, b=5, c=7, d=3
47+
48+
Test 3 - Row variable constraints (row1=row2 total elements):
49+
Input r1 shape: 2,3,4
50+
Input r2 shape: 3,5
51+
Output r3 shape: 2,3,5
52+
Row1 total: 3, Row2 total: 3 (should be equal)
53+
54+
Test 5 - Constraint propagation across operations:
55+
Chain1 shape: 4,5
56+
Chain2 shape: 5,4
57+
Chain3 shape: 5,6
58+
Chain4 shape: 4,6
59+
Variables: a=4, b=5, a_chain=4, b_chain=5, c_chain=6
60+
=== Shape inference integration tests completed ===
61+
62+
=== Testing pure shape inference with equality constraints ===
63+
Test 1 - Matrix multiply with constraint-driven shapes:
64+
m1 inferred shape: 3,4
65+
m2 inferred shape: 4,5
66+
result1 inferred shape: 3,5
67+
Captured dimensions: i=3, j=4, k=5
68+
69+
Test 2 - Chain operations with constraint propagation:
70+
base inferred shape: 1,1
71+
transposed inferred shape: 8,6
72+
multiplied inferred shape: 8,10
73+
final inferred shape: 6,10
74+
Constraint propagation: a=6, b=8, c=10
75+
76+
Test 3 - Simple 3-tensor einsum with pure inference:
77+
tensor1 inferred shape: 7,8
78+
tensor2 inferred shape: 8,9
79+
result3 inferred shape: 7,9
80+
Constraints: x=7, y=8, z=9
81+
82+
Test 4 - Complex interdependent constraints:
83+
complex1 inferred shape: 4,5,6
84+
complex2 inferred shape: 6,5,4
85+
complex_result inferred shape: 4,5,5,4
86+
Constraint resolution: p=4, q=5, r=6, s=5, t=4
87+
Expected: p=4, q=5, r=6, s=5, t=4 (with q=s and p=t constraints satisfied)
88+
=== Pure shape inference tests completed ===

0 commit comments

Comments
 (0)