Skip to content

Commit 153e04b

Browse files
committed
Fixed: correct surjectivity testing for initialization; problem spotted and fixed by Claude Opus with my guidance on the surjectivity algo/heuristic
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 2312d1d commit 153e04b

File tree

7 files changed

+363
-22
lines changed

7 files changed

+363
-22
lines changed

arrayjit/lib/assignments.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ let get_name_exn asgns =
9696
if String.is_empty result then invalid_arg "Assignments.get_name: no comments in code" else result
9797

9898
let is_total ~initialize_neutral ~projections =
99-
initialize_neutral && Indexing.is_bijective projections
99+
initialize_neutral && Indexing.is_surjective projections
100100

101101
(** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it must be called
102102
after compilation; otherwise, it will disrupt memory mode inference. *)

arrayjit/lib/indexing.ml

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,9 @@ let iterated dim = dim > 1
149149
let opt_symbol d = if iterated d then Some (get_symbol ()) else None
150150
let opt_iterator = function None -> Fixed_idx 0 | Some sym -> Iterator sym
151151

152-
let is_bijective proj =
153-
(* For bijection, we need the assignment to be both injective and surjective.
154-
We check surjectivity (all target positions are written) and that each source
155-
position maps to exactly one target position. *)
152+
let is_surjective proj =
153+
(* For surjectivity, we check if all target (LHS) positions will be written to.
154+
This is used to determine if we need to zero-initialize before assignment. *)
156155

157156
(* Check if there are any fixed indices (except Fixed_idx 0 when dim is 1) *)
158157
let has_non_trivial_fixed =
@@ -163,7 +162,7 @@ let is_bijective proj =
163162
in
164163
if has_non_trivial_fixed then false
165164
else
166-
(* Collect symbols used in LHS with their properties *)
165+
(* Collect symbols used in LHS *)
167166
let lhs_symbols, has_affine, has_sub_axis =
168167
Array.fold proj.project_lhs ~init:([], false, false)
169168
~f:(fun (syms, has_aff, has_sub) idx ->
@@ -181,24 +180,24 @@ let is_bijective proj =
181180
let lhs_symbol_set = Set.of_list (module Symbol) lhs_symbols in
182181
let product_symbol_set = Set.of_array (module Symbol) proj.product_iterators in
183182

184-
(* Basic check: All lhs symbols must be from product iterators (no bound symbols) *)
183+
(* All lhs symbols must be from product iterators (no bound symbols) *)
185184
if not (Set.is_subset lhs_symbol_set ~of_:product_symbol_set) then false
186185
else if has_sub_axis then
187-
(* Conservative: Sub_axis case is complex, so assume non-bijective.
186+
(* Conservative: Sub_axis case is complex, so assume non-surjective.
188187
This is pessimistic but safe - Sub_axis would require comparing
189188
lhs_dims and product_space dimensions carefully. *)
190189
false
191190
else if has_affine then
192-
(* For Affine indices: check that coefficient=1 symbols don't have
193-
dimensions smaller than any stride coefficients used *)
191+
(* For Affine indices with strides: check coefficient compatibility.
192+
A strided access pattern may skip elements. *)
194193
let symbol_dims =
195194
Array.filter_mapi proj.product_iterators ~f:(fun i sym ->
196195
if Set.mem lhs_symbol_set sym then Some (sym, proj.product_space.(i))
197196
else None)
198197
|> Array.to_list
199198
|> Map.of_alist_exn (module Symbol)
200199
in
201-
let check_affine_valid =
200+
let check_affine_surjective =
202201
Array.for_all proj.project_lhs ~f:(function
203202
| Affine { symbols; _ } ->
204203
(* Find max dimension of coeff=1 symbols *)
@@ -208,22 +207,37 @@ let is_bijective proj =
208207
|> List.max_elt ~compare:Int.compare
209208
|> Option.value ~default:Int.max_value
210209
in
211-
(* Check that it's not smaller than any stride coefficient *)
210+
(* Check that coeff=1 dimension is not smaller than any stride *)
212211
List.for_all symbols ~f:(fun (coeff, _) ->
213212
coeff = 1 || max_coeff1_dim >= coeff)
214213
| _ -> true)
215214
in
216-
if not check_affine_valid then false
215+
if not check_affine_surjective then false
217216
else
218-
(* Final check: number of unique symbols must equal number of LHS dims
219-
AND the symbols must equal product iterators *)
220-
Set.length lhs_symbol_set = Array.length proj.project_lhs
221-
&& Set.equal lhs_symbol_set product_symbol_set
217+
(* Check that we have enough unique symbols to cover all LHS dimensions *)
218+
Set.length lhs_symbol_set >= Array.length proj.project_lhs
222219
else
223220
(* Simple case: only Iterator and Fixed_idx *)
224-
(* Need all dimensions covered and symbols to match exactly *)
225-
Set.length lhs_symbol_set = Array.length proj.project_lhs
226-
&& Set.equal lhs_symbol_set product_symbol_set
221+
(* Need enough unique symbols to cover all dimensions *)
222+
Set.length lhs_symbol_set >= Array.length proj.project_lhs
223+
224+
(* For backwards compatibility, keep is_bijective as an alias that checks
225+
both surjectivity and injectivity (stricter than just surjectivity) *)
226+
let is_bijective proj =
227+
is_surjective proj &&
228+
let lhs_symbols =
229+
Array.concat_map proj.project_lhs ~f:(function
230+
| Iterator s -> [| s |]
231+
| Fixed_idx _ -> [||]
232+
| Affine { symbols; _ } ->
233+
List.filter_map symbols ~f:(fun (coeff, s) ->
234+
if coeff = 1 then Some s else None)
235+
|> Array.of_list
236+
| Sub_axis -> [||])
237+
|> Set.of_array (module Symbol)
238+
in
239+
(* For bijectivity, also need exact match of symbols *)
240+
Set.equal lhs_symbols (Set.of_array (module Symbol) proj.product_iterators)
227241

228242
(** Projections for a pointwise unary operator. Provide only one of [debug_info] or [derived_for].
229243
*)

bin/einsum_trivia.ml

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ open Ocannl
33
module IDX = Train.IDX
44
module CDSL = Train.CDSL
55
module TDSL = Operation.TDSL
6+
module NTDSL = Operation.NTDSL
67

78
module type Backend = Ir.Backend_intf.Backend
89

@@ -37,7 +38,7 @@ let _suspended () =
3738
ignore (Train.forward_once backend ho2);
3839
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ho2
3940

40-
let () =
41+
let _suspended () =
4142
let module Backend = (val Backends.fresh_backend ()) in
4243
let backend =
4344
(module Backend : Backend
@@ -59,3 +60,30 @@ let () =
5960
"a|i->h; b|h->o => i->o" b in Utils.capture_stdout_logs (fun () -> ignore (Train.forward_once backend f)); *)
6061
(* Train.printf ~here:[%here] ~with_code:false ~with_grad:false a2; *)
6162
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c
63+
64+
let () =
65+
Tensor.unsafe_reinitialize ();
66+
let module Backend = (val Backends.fresh_backend ()) in
67+
let backend =
68+
(module Backend : Backend
69+
with type buffer_ptr = Backend.buffer_ptr
70+
and type dev = Backend.dev
71+
and type runner = Backend.runner
72+
and type event = Backend.event
73+
and type optimize_ctx = Backend.optimize_ctx)
74+
in
75+
76+
let ri = TDSL.range 3 in
77+
let%op ti = ri ++ "i=>i0" in
78+
(* Write position 2 of ti, otherwise shape inference concludes it's dim-1 and broadcasted. *)
79+
let%cd _ = ti =: 0 ++ "i=>i2" in
80+
let rj = TDSL.range 4 in
81+
let%op tj = rj ++ "j=>j1" in
82+
let rk = TDSL.range 5 in
83+
let%op tk = rk ++ "k=>k2" in
84+
let positions = TDSL.outer_sum "ijl;kl=>ijkl" (TDSL.outer_sum "il;jl=>ijl" ti tj ()) tk () in
85+
Train.set_hosted tk.value;
86+
ignore (Train.forward_once backend positions);
87+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false positions;
88+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ti;
89+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false tk

test/einsum/dune

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,21 @@
3333
(preprocess
3434
(pps ppx_here ppx_ocannl)))
3535

36+
(test
37+
(name test_surjectivity)
38+
(deps ocannl_config)
39+
(modules test_surjectivity)
40+
(libraries ocannl)
41+
(preprocess
42+
(pps ppx_here ppx_ocannl)))
43+
3644
(library
3745
(name einsum_tutorials)
3846
(package neural_nets_lib)
3947
(inline_tests
4048
(deps ocannl_config))
4149
(libraries base dynlink ocannl)
42-
(modules einsum_trivia)
50+
(modules einsum_trivia surjectivity)
4351
(preprocess
4452
(pps ppx_here ppx_expect ppx_inline_test ppx_ocannl))
4553
(modes best))

test/einsum/surjectivity.ml

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
open Base
2+
open Ocannl
3+
module IDX = Train.IDX
4+
module CDSL = Train.CDSL
5+
module TDSL = Operation.TDSL
6+
module NTDSL = Operation.NTDSL
7+
8+
module type Backend = Ir.Backend_intf.Backend
9+
10+
let%expect_test "diagonal_tensor_initialization" =
11+
Tensor.unsafe_reinitialize ();
12+
let module Backend = (val Backends.fresh_backend ()) in
13+
let backend =
14+
(module Backend : Backend
15+
with type buffer_ptr = Backend.buffer_ptr
16+
and type dev = Backend.dev
17+
and type runner = Backend.runner
18+
and type event = Backend.event
19+
and type optimize_ctx = Backend.optimize_ctx)
20+
in
21+
22+
(* Create a diagonal tensor using einsum: i->ii *)
23+
let input = TDSL.range 5 in
24+
let%op diagonal = input ++ "i=>ii" in
25+
26+
(* Ensure the diagonal tensor is hosted *)
27+
Train.set_hosted diagonal.value;
28+
ignore (Train.forward_once backend diagonal);
29+
30+
(* Print the diagonal tensor *)
31+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false diagonal;
32+
[%expect {|
33+
HERE: test/einsum/surjectivity.ml:31:21
34+
┌──────────────────────────────────────┐
35+
│[1]: =>_diagonal shape 0:6,1:6 │
36+
│┌──────┬─────────────────────────────┐│
37+
││ │axis 1 ││
38+
│├──────┼─────────────────────────────┤│
39+
││axis 0│ 0.00 0.00 ... 0.00 0.00 ││
40+
││ │ 0.00 1.00 ... 0.00 0.00 ││
41+
││ │ ... ... ... ... ... ││
42+
││ │ 0.00 0.00 ... 4.00 0.00 ││
43+
││ │ 0.00 0.00 ... 0.00 5.00 ││
44+
│└──────┴─────────────────────────────┘│
45+
└──────────────────────────────────────┘
46+
|}]
47+
48+
let%expect_test "sparse_assignment_with_fixed_indices" =
49+
Tensor.unsafe_reinitialize ();
50+
let module Backend = (val Backends.fresh_backend ()) in
51+
let backend =
52+
(module Backend : Backend
53+
with type buffer_ptr = Backend.buffer_ptr
54+
and type dev = Backend.dev
55+
and type runner = Backend.runner
56+
and type event = Backend.event
57+
and type optimize_ctx = Backend.optimize_ctx)
58+
in
59+
60+
(* Create a sparse tensor using fixed indices: i->i0j *)
61+
let input = TDSL.range 4 in
62+
let%op sparse = input ++ "i=>i0j" in
63+
64+
Train.set_hosted sparse.value;
65+
ignore (Train.forward_once backend sparse);
66+
67+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false sparse;
68+
[%expect {|
69+
HERE: test/einsum/surjectivity.ml:64:21
70+
┌─────────────────────────────────┐
71+
│[1]: =>_sparse shape 0:5,1:1,2:1 │
72+
│┌──────┬──────┐ │
73+
││ │axis 2│ │
74+
│├──────┼──────┤ │
75+
││0 @ 0│ 0.00 │ │
76+
││axis 1│ │ │
77+
│├──────┼──────┤ │
78+
││1 @ 0│ 1.00 │ │
79+
││axis 1│ │ │
80+
│├──────┼──────┤ │
81+
││2 @ 0│ 2.00 │ │
82+
││axis 1│ │ │
83+
│├──────┼──────┤ │
84+
││3 @ 0│ 3.00 │ │
85+
││axis 1│ │ │
86+
│├──────┼──────┤ │
87+
││4 @ 0│ 4.00 │ │
88+
││axis 1│ │ │
89+
│└──────┴──────┘ │
90+
└─────────────────────────────────┘
91+
|}]
92+
93+
let%expect_test "multiple_sparse_axes" =
94+
Tensor.unsafe_reinitialize ();
95+
let module Backend = (val Backends.fresh_backend ()) in
96+
let backend =
97+
(module Backend : Backend
98+
with type buffer_ptr = Backend.buffer_ptr
99+
and type dev = Backend.dev
100+
and type runner = Backend.runner
101+
and type event = Backend.event
102+
and type optimize_ctx = Backend.optimize_ctx)
103+
in
104+
105+
(* Test with multiple fixed indices: ij->i1j2 *)
106+
let input = TDSL.range_of_shape ~output_dims:[3; 4] () in
107+
let%op sparse_multi = input ++ "ij=>i1j2" in
108+
109+
Train.set_hosted sparse_multi.value;
110+
ignore (Train.forward_once backend sparse_multi);
111+
112+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false sparse_multi;
113+
[%expect {|
114+
HERE: test/einsum/surjectivity.ml:113:21
115+
┌───────────────────────────────────────────┐
116+
│[1]: =>_sparse_multi shape 0:3,1:2,2:4,3:3 │
117+
│┌──────┬──────────────────┐ │
118+
││0 @ 0 │axis 3 │ │
119+
│├──────┼──────────────────┤ │
120+
││0 @ 1│ 0.00 0.00 0.00 │ │
121+
││axis 2│ 0.00 0.00 0.00 │ │
122+
││ │ 0.00 0.00 0.00 │ │
123+
││ │ 0.00 0.00 0.00 │ │
124+
│├──────┼──────────────────┤ │
125+
││1 @ 10.00 0.00 0.00 │ │
126+
││axis 20.00 0.00 1.00 │ │
127+
││ │ 0.00 0.00 2.00 │ │
128+
││ │ 0.00 0.00 3.00 │ │
129+
│└──────┴──────────────────┘ │
130+
├───────────────────────────────────────────┤
131+
│┌──────┬──────────────────┐ │
132+
││1 @ 0 │axis 3 │ │
133+
│├──────┼──────────────────┤ │
134+
││0 @ 1│ 0.00 0.00 0.00 │ │
135+
││axis 2│ 0.00 0.00 0.00 │ │
136+
││ │ 0.00 0.00 0.00 │ │
137+
││ │ 0.00 0.00 0.00 │ │
138+
│├──────┼──────────────────┤ │
139+
││1 @ 1│ 0.00 0.00 4.00 │ │
140+
││axis 2│ 0.00 0.00 5.00 │ │
141+
││ │ 0.00 0.00 6.00 │ │
142+
││ │ 0.00 0.00 7.00 │ │
143+
│└──────┴──────────────────┘ │
144+
├───────────────────────────────────────────┤
145+
│┌──────┬─────────────────────┐ │
146+
││2 @ 0 │axis 3 │ │
147+
│├──────┼─────────────────────┤ │
148+
││0 @ 1│ 0.00 0.00 0.00 │ │
149+
││axis 2│ 0.00 0.00 0.00 │ │
150+
││ │ 0.00 0.00 0.00 │ │
151+
││ │ 0.00 0.00 0.00 │ │
152+
│├──────┼─────────────────────┤ │
153+
││1 @ 1│ 0.00 0.00 8.00 │ │
154+
││axis 2│ 0.00 0.00 9.00 │ │
155+
││ │ 0.00 0.00 1.00e+1 │ │
156+
││ │ 0.00 0.00 1.10e+1 │ │
157+
│└──────┴─────────────────────┘ │
158+
└───────────────────────────────────────────┘
159+
|}]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
Retrieving commandline, environment, or config file variable ocannl_log_level
2+
Found 0, in the config file
3+
4+
Testing diagonal tensor initialization:
5+
HERE: test/einsum/test_surjectivity.ml:24:21
6+
┌──────────────────────────────────────┐
7+
│[1]: =>_diagonal shape 0:6,1:6 │
8+
│┌──────┬─────────────────────────────┐│
9+
││ │axis 1 ││
10+
│├──────┼─────────────────────────────┤│
11+
││axis 0│ 0.00 0.00 ... 0.00 0.00 ││
12+
││ │ 0.00 1.00 ... 0.00 0.00 ││
13+
││ │ ... ... ... ... ... ││
14+
││ │ 0.00 0.00 ... 4.00 0.00 ││
15+
││ │ 0.00 0.00 ... 0.00 5.00 ││
16+
│└──────┴─────────────────────────────┘│
17+
└──────────────────────────────────────┘
18+
19+
20+
Testing sparse assignment with fixed index:
21+
HERE: test/einsum/test_surjectivity.ml:39:21
22+
┌─────────────────────────────┐
23+
│[1]: =>_sparse shape 0:5,1:1 │
24+
│┌──────┬──────┐ │
25+
││ │axis 1│ │
26+
│├──────┼──────┤ │
27+
││axis 0│ 0.00 │ │
28+
││ │ 1.00 │ │
29+
││ │ 2.00 │ │
30+
││ │ 3.00 │ │
31+
││ │ 4.00 │ │
32+
│└──────┴──────┘ │
33+
└─────────────────────────────┘
34+
35+
36+
Testing multiple sparse axes:
37+
HERE: test/einsum/test_surjectivity.ml:54:21
38+
┌───────────────────────────────────────┐
39+
│[1]: =>_result shape 0:3,1:2,2:4 │
40+
│┌──────┬──────────────────────────────┐│
41+
││ │axis 2 ││
42+
│├──────┼──────────────────────────────┤│
43+
││0 @ 0 │ 0.00 0.00 0.00 0.00 ││
44+
││axis 1│ 0.00 1.00 2.00 3.00 ││
45+
│├──────┼──────────────────────────────┤│
46+
││1 @ 0 │ 0.00 0.00 0.00 0.00 ││
47+
││axis 1│ 4.00 5.00 6.00 7.00 ││
48+
│├──────┼──────────────────────────────┤│
49+
││2 @ 0 │ 0.00 0.00 0.00 0.00 ││
50+
││axis 1│ 8.00 9.00 1.00e+1 1.10e+1 ││
51+
│└──────┴──────────────────────────────┘│
52+
└───────────────────────────────────────┘
53+
54+
All surjectivity tests completed.

0 commit comments

Comments
 (0)