Commit 075bc6a

The previous change is_bijective -> is_surjective was a mistake, but we can do better altogether; a collaboration with Claude.

Claude: The problem is that when the surjectivity check was fixed, it now correctly identifies that this mapping is non-surjective (because we're writing to the same ho position multiple times as i12 varies from 0 to 3). But instead of switching to accumulation mode, it zero-initializes and then just does assignments, which means only the last value (when i12 = 3) survives.

Me: The problem is with lowering. The high-level representation shows accumulating assignment, with the syntax I invented: `ho =:+ r2x4x3`, but the unoptimized lowered representation doesn't expand it as expected; it should use two binary operators, the `accum` operator and the `op` operator.

Claude: Perfect! Now I found the bug. `is_total` at line 99 returns true when both:
1. `initialize_neutral` is true, and
2. `is_surjective` returns true.

The bug is that the code assumes surjective mappings don't need accumulation, but that's wrong. We need to check for bijectivity (a one-to-one mapping) to use plain assignment: non-bijective mappings need accumulation even if they are surjective. The fix should be to use `is_bijective` instead of `is_surjective` in the `is_total` function, or better yet, to check specifically whether multiple source positions map to the same target position (which requires checking injectivity).

Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent c76d1c5 commit 075bc6a
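To make the failure mode concrete, here is a minimal plain-OCaml sketch of the scenario described above (toy arrays, not OCANNL code; the names mirror the `ho =:+ r2x4x3` example restricted to a single batch). Each target position ho.(i) is written once per value of the reduced iterator o, so plain assignment keeps only the last write while accumulation sums all of them:

(* Toy model: r.(o).(i) projects onto ho.(i), so 4 source positions
   (o = 0..3) hit each target position. *)
let r = [| [| 0.; 1.; 2. |]; [| 3.; 4.; 5. |]; [| 6.; 7.; 8. |]; [| 9.; 10.; 11. |] |]

let () =
  let ho_assign = Array.make 3 0. and ho_accum = Array.make 3 0. in
  for o = 0 to 3 do
    for i = 0 to 2 do
      ho_assign.(i) <- r.(o).(i);                 (* plain  = : only the o = 3 write survives *)
      ho_accum.(i) <- ho_accum.(i) +. r.(o).(i)   (* accum += : sums over o *)
    done
  done;
  assert (ho_assign = [| 9.; 10.; 11. |]);
  assert (ho_accum = [| 18.; 22.; 26. |])

This is why the lowering must emit the accumulating form (the `accum` binary operator applied to the old LHS value and the `op` result) whenever the projection is non-injective.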

File tree: 7 files changed, +299 -47 lines changed


arrayjit/lib/assignments.ml

Lines changed: 15 additions & 2 deletions

@@ -98,6 +98,10 @@ let get_name_exn asgns =
 let is_total ~initialize_neutral ~projections =
   initialize_neutral && Indexing.is_surjective projections
 
+let can_skip_accumulation ~projections =
+  (* We can skip accumulation (use = instead of +=) only if the projection is injective *)
+  Indexing.is_injective projections
+
 (** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it must be called
     after compilation; otherwise, it will disrupt memory mode inference. *)
 let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_set =
@@ -226,7 +230,7 @@ let%track4_sexp to_low_level code =
         let lhs_ll = get (Node lhs) lhs_idcs in
         let rhses_ll = Array.mapi rhses_idcs ~f:(fun i rhs_idcs -> get rhses.(i) rhs_idcs) in
         let rhs2 = apply_op op rhses_ll in
-        if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
+        if can_skip_accumulation ~projections then set lhs lhs_idcs rhs2
         else set lhs lhs_idcs @@ apply_op (Ops.Binop accum) [| lhs_ll; rhs2 |]
       in
 
@@ -243,7 +247,16 @@ let%track4_sexp to_low_level code =
       }
     in
     let for_loops = for_loop [] (Array.to_list projections.product_space) in
-    if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
+    (* Need initialization if:
+       - initialize_neutral is true AND
+       - (not surjective OR not injective)
+       Not surjective: some positions never written (need init to avoid garbage)
+       Not injective: accumulation needed (need init for first += operation) *)
+    let needs_init =
+      initialize_neutral &&
+      not (Indexing.is_surjective projections && Indexing.is_injective projections)
+    in
+    if needs_init then
       let dims = lazy projections.lhs_dims in
       let fetch_op = Constant (Ops.neutral_elem accum) in
       Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
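Read together, `can_skip_accumulation` and `needs_init` induce a simple decision table for the lowering. A paraphrase of the patched logic (a sketch assuming `initialize_neutral` is true, with bare booleans standing in for the actual projection analysis):

(* Returns (needs_init, plain_assign) for the two projection properties. *)
let lowering_mode ~surjective ~injective =
  (not (surjective && injective), injective)

let () =
  (* Bijective, e.g. a transpose: no init, plain =. *)
  assert (lowering_mode ~surjective:true ~injective:true = (false, true));
  (* Surjective but not injective, e.g. a reduction: init to neutral, then +=. *)
  assert (lowering_mode ~surjective:true ~injective:false = (true, false));
  (* Injective but not surjective, e.g. a diagonal or fixed-index write: init, then plain =. *)
  assert (lowering_mode ~surjective:false ~injective:true = (true, true))

These three cases correspond to the transpose, reduction, and diagonal/sparse scenarios exercised by the new test below.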

arrayjit/lib/indexing.ml

Lines changed: 50 additions & 41 deletions

@@ -150,66 +150,61 @@ let opt_symbol d = if iterated d then Some (get_symbol ()) else None
 let opt_iterator = function None -> Fixed_idx 0 | Some sym -> Iterator sym
 
 let is_surjective proj =
-  (* For surjectivity, we check if all target (LHS) positions will be written to.
-     This is used to determine if we need to zero-initialize before assignment. *)
-
+  (* For surjectivity, we check if all target (LHS) positions will be written to. This is used to
+     determine if we need to zero-initialize before assignment. *)
+
   (* Check if there are any fixed indices (except Fixed_idx 0 when dim is 1) *)
   let has_non_trivial_fixed =
     Array.exists2_exn proj.project_lhs proj.lhs_dims ~f:(fun idx dim ->
         match idx with
-        | Fixed_idx i -> not (i = 0 && dim <= 1)  (* Fixed_idx 0 is OK only when dim is 0 or 1 *)
+        | Fixed_idx i -> not (i = 0 && dim <= 1) (* Fixed_idx 0 is OK only when dim is 0 or 1 *)
        | _ -> false)
   in
   if has_non_trivial_fixed then false
   else
     (* Collect symbols used in LHS *)
-    let lhs_symbols, has_affine, has_sub_axis =
-      Array.fold proj.project_lhs ~init:([], false, false)
-        ~f:(fun (syms, has_aff, has_sub) idx ->
+    let lhs_symbols, has_affine, has_sub_axis =
+      Array.fold proj.project_lhs ~init:([], false, false) ~f:(fun (syms, has_aff, has_sub) idx ->
          match idx with
          | Iterator s -> (s :: syms, has_aff, has_sub)
          | Fixed_idx _ -> (syms, has_aff, has_sub)
          | Affine { symbols; _ } ->
-              let coeff1_syms =
-                List.filter_map symbols ~f:(fun (coeff, s) ->
-                    if coeff = 1 then Some s else None)
+              let coeff1_syms =
+                List.filter_map symbols ~f:(fun (coeff, s) -> if coeff = 1 then Some s else None)
              in
              (coeff1_syms @ syms, true, has_sub)
          | Sub_axis -> (syms, has_aff, true))
    in
    let lhs_symbol_set = Set.of_list (module Symbol) lhs_symbols in
    let product_symbol_set = Set.of_array (module Symbol) proj.product_iterators in
-
+
    (* All lhs symbols must be from product iterators (no bound symbols) *)
    if not (Set.is_subset lhs_symbol_set ~of_:product_symbol_set) then false
-    else if has_sub_axis then
-      (* Conservative: Sub_axis case is complex, so assume non-surjective.
-         This is pessimistic but safe - Sub_axis would require comparing
-         lhs_dims and product_space dimensions carefully. *)
+    else if has_sub_axis then
+      (* Conservative: Sub_axis case is complex, so assume non-surjective. This is pessimistic but
+         safe - Sub_axis would require comparing lhs_dims and product_space dimensions carefully. *)
      false
    else if has_affine then
-      (* For Affine indices with strides: check coefficient compatibility.
-         A strided access pattern may skip elements. *)
-      let symbol_dims =
+      (* For Affine indices with strides: check coefficient compatibility. A strided access pattern
+         may skip elements. *)
+      let symbol_dims =
        Array.filter_mapi proj.product_iterators ~f:(fun i sym ->
-            if Set.mem lhs_symbol_set sym then Some (sym, proj.product_space.(i))
-            else None)
+            if Set.mem lhs_symbol_set sym then Some (sym, proj.product_space.(i)) else None)
        |> Array.to_list
        |> Map.of_alist_exn (module Symbol)
      in
      let check_affine_surjective =
        Array.for_all proj.project_lhs ~f:(function
          | Affine { symbols; _ } ->
              (* Find max dimension of coeff=1 symbols *)
-              let max_coeff1_dim =
+              let max_coeff1_dim =
                List.filter_map symbols ~f:(fun (coeff, s) ->
-                    if coeff = 1 then Map.find symbol_dims s else None)
+                    if coeff = 1 then Map.find symbol_dims s else None)
                |> List.max_elt ~compare:Int.compare
                |> Option.value ~default:Int.max_value
              in
              (* Check that coeff=1 dimension is not smaller than any stride *)
-              List.for_all symbols ~f:(fun (coeff, _) ->
-                  coeff = 1 || max_coeff1_dim >= coeff)
+              List.for_all symbols ~f:(fun (coeff, _) -> coeff = 1 || max_coeff1_dim >= coeff)
          | _ -> true)
      in
      if not check_affine_surjective then false
@@ -220,24 +215,38 @@ let is_surjective proj =
       (* Simple case: only Iterator and Fixed_idx *)
       (* Need enough unique symbols to cover all dimensions *)
       Set.length lhs_symbol_set >= Array.length proj.project_lhs
-
-(* For backwards compatibility, keep is_bijective as an alias that checks
-   both surjectivity and injectivity (stricter than just surjectivity) *)
-let is_bijective proj =
-  is_surjective proj &&
-  let lhs_symbols =
-    Array.concat_map proj.project_lhs ~f:(function
-      | Iterator s -> [| s |]
-      | Fixed_idx _ -> [||]
-      | Affine { symbols; _ } ->
-          List.filter_map symbols ~f:(fun (coeff, s) ->
-              if coeff = 1 then Some s else None)
-          |> Array.of_list
-      | Sub_axis -> [||])
-    |> Set.of_array (module Symbol)
+
+let is_injective proj =
+  let product_iterator_set = Set.of_array (module Symbol) proj.product_iterators in
+
+  (* Check each LHS index for injectivity *)
+  let lhs_symbols, is_injective_mapping =
+    Array.fold proj.project_lhs ~init:([], true) ~f:(fun (syms, still_injective) idx ->
+        if not still_injective then (syms, false)
+        else
+          match idx with
+          | Iterator s -> (s :: syms, true)
+          | Fixed_idx _ -> (syms, true)
+          | Affine { symbols; _ } ->
+              (* Filter for symbols that are product iterators *)
+              let product_symbols =
+                List.filter symbols ~f:(fun (_coeff, s) ->
+                    Set.mem product_iterator_set s)
+              in
+              (* If more than one product iterator in this Affine index, not injective *)
+              if List.length product_symbols > 1 then
+                (syms, false)
+              else
+                (* (coefficients don't matter for injectivity) *)
+                (List.map product_symbols ~f:snd @ syms, true)
+          | Sub_axis -> (syms, true))
   in
-  (* For bijectivity, also need exact match of symbols *)
-  Set.equal lhs_symbols (Set.of_array (module Symbol) proj.product_iterators)
+
+  if not is_injective_mapping then false
+  else
+    let lhs_symbol_set = Set.of_list (module Symbol) lhs_symbols in
+    (* For injectivity, each product iterator must map to at most one position *)
+    Set.is_subset (Set.of_array (module Symbol) proj.product_iterators) ~of_:lhs_symbol_set
 
 (** Projections for a pointwise unary operator. Provide only one of [debug_info] or [derived_for].
 *)
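For intuition, here is a self-contained toy model of the criterion implemented above (hypothetical mini-types and a string-keyed set standing in for the library's actual projections record and Symbol module, so this is a sketch, not the real API):

(* Toy index type mirroring the constructors matched above. *)
type idx = Iterator of string | Fixed_idx of int | Affine of (int * string) list | Sub_axis

module SS = Set.Make (String)

let toy_is_injective ~product_iterators ~project_lhs =
  let iters = SS.of_list product_iterators in
  let exception Not_inj in
  try
    let lhs_syms =
      List.fold_left
        (fun acc -> function
          | Iterator s -> SS.add s acc
          | Fixed_idx _ | Sub_axis -> acc
          | Affine symbols -> (
              match List.filter (fun (_coeff, s) -> SS.mem s iters) symbols with
              | [] -> acc
              | [ (_coeff, s) ] -> SS.add s acc
              | _ -> raise Not_inj (* two product iterators fused into one index *)))
        SS.empty project_lhs
    in
    (* Every product iterator must show up somewhere on the LHS. *)
    SS.subset iters lhs_syms
  with Not_inj -> false

let () =
  (* ij=>ji transpose: injective. *)
  assert (toy_is_injective ~product_iterators:[ "i"; "j" ] ~project_lhs:[ Iterator "j"; Iterator "i" ]);
  (* Strided write with a single iterator per index: still injective. *)
  assert (toy_is_injective ~product_iterators:[ "i" ] ~project_lhs:[ Affine [ (2, "i") ]; Fixed_idx 0 ]);
  (* b|i->o => b|i reduction: 'o' never reaches the LHS, so not injective. *)
  assert (not (toy_is_injective ~product_iterators:[ "b"; "i"; "o" ] ~project_lhs:[ Iterator "b"; Iterator "i" ]))

The key property mirrored here: a reduced iterator that never reaches the LHS (like `o` in `b|i->o => b|i`) makes the mapping non-injective, which is what forces accumulation in assignments.ml.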

test/einsum/dune

Lines changed: 8 additions & 0 deletions

@@ -41,6 +41,14 @@
  (preprocess
   (pps ppx_here ppx_ocannl)))
 
+(test
+ (name test_accumulation_semantics)
+ (deps ocannl_config)
+ (modules test_accumulation_semantics)
+ (libraries ocannl)
+ (preprocess
+  (pps ppx_here ppx_ocannl)))
+
 (library
  (name einsum_tutorials)
  (package neural_nets_lib)
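Assuming the repository's standard dune workflow, the new expect test should run with `dune runtest test/einsum` (or plain `dune test`), and `dune promote` should update the corresponding `.expected` file after an intentional change to the output.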

test/einsum/moons_demo_variant.expected

Lines changed: 2 additions & 2 deletions

@@ -24,7 +24,7 @@ n19 w2 as w2: Host&stream/412410; single prec 1x16; mem in bytes: <not-in-yet>
 n20 grad_w2 as w2.grad: Local/26046; single prec 1x16; mem in bytes: <not-in-yet>
 n21 @|_moons_input as moons_input: Virt/15; single prec 10x2; mem in bytes: <not-in-yet>
 n24 @|_moons_class as moons_class: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
-n27 * as n27: Virt/151; single prec 10x16; mem in bytes: <not-in-yet>
+n27 * as n27: Virt/15; single prec 10x16; mem in bytes: <not-in-yet>
 n28 grad_* as n27.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
 n29 + as n29: Virt/15; single prec 10x16; mem in bytes: <not-in-yet>
 n30 grad_+ as n29.grad: Virt/151; single prec 10x16; mem in bytes: <not-in-yet>
@@ -65,7 +65,7 @@ n66 sgd_momentum_w2 as sgd_momentum_w2: unknown; single prec <not-in-yet>; mem i
 n67 0.0001 as n67: Virt/40; single prec 1; mem in bytes: <not-in-yet>
 n68 *. as n68: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
 n69 point_mlp_result as point_mlp_result: Host&shared/37039; single prec 2; mem in bytes: <not-in-yet>
-n70 * as n70: Virt/151; single prec 16; mem in bytes: <not-in-yet>
+n70 * as n70: Virt/15; single prec 16; mem in bytes: <not-in-yet>
 n71 grad_* as n70.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
 n72 + as n72: Virt/15; single prec 16; mem in bytes: <not-in-yet>
 n73 grad_+ as n72.grad: unknown; single prec 16; mem in bytes: <not-in-yet>

test/einsum/surjectivity.ml

Lines changed: 2 additions & 2 deletions

@@ -66,7 +66,7 @@ let%expect_test "sparse_assignment_with_fixed_indices" =
 
   Train.printf ~here:[%here] ~with_code:false ~with_grad:false sparse;
   [%expect {|
-    HERE: test/einsum/surjectivity.ml:64:21
+    HERE: test/einsum/surjectivity.ml:67:21
     ┌─────────────────────────────────┐
     │[1]: =>_sparse shape 0:5,1:1,2:1
     │┌──────┬──────┐ │
@@ -111,7 +111,7 @@ let%expect_test "multiple_sparse_axes" =
 
   Train.printf ~here:[%here] ~with_code:false ~with_grad:false sparse_multi;
   [%expect {|
-    HERE: test/einsum/surjectivity.ml:113:21
+    HERE: test/einsum/surjectivity.ml:112:21
     ┌───────────────────────────────────────────┐
     │[1]: =>_sparse_multi shape 0:3,1:2,2:4,3:3
     │┌──────┬──────────────────┐ │
Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+Retrieving commandline, environment, or config file variable ocannl_log_level
+Found 0, in the config file
+
+=== Testing einsum reduction (surjective but not injective) ===
+Input tensor (shape: batch=2, input=3, output=4):
+HERE: test/einsum/test_accumulation_semantics.ml:32:21
+┌─────────────────────────────────────────────────────────────┐
+│[0]: r2x4x3 shape 0:2|2:3->1:4 │
+│┌──────┬────────────────────────┬───────────────────────────┐│
+││ │0 @ 0 │1 @ 0 ││
+││ │axis 2 │axis 2 ││
+│├──────┼────────────────────────┼───────────────────────────┤│
+││axis 1│ 0.00 1.00 2.00 │ 1.20e+1 1.30e+1 1.40e+1 ││
+││ │ 3.00 4.00 5.00 │ 1.50e+1 1.60e+1 1.70e+1 ││
+││ │ 6.00 7.00 8.00 │ 1.80e+1 1.90e+1 2.00e+1 ││
+││ │ 9.00 1.00e+1 1.10e+1 │ 2.10e+1 2.20e+1 2.30e+1 ││
+│└──────┴────────────────────────┴───────────────────────────┘│
+└─────────────────────────────────────────────────────────────┘
+
+Result after reduction 'b|i->o => b|i' (should sum over output dimension):
+HERE: test/einsum/test_accumulation_semantics.ml:35:21
+┌────────────────────────────────────┐
+│[1]: =>_result shape 0:2|1:3 │
+│┌──────┬───────────────────────────┐│
+││ │axis 1 ││
+│├──────┼───────────────────────────┤│
+││axis 0│ 1.80e+1 2.20e+1 2.60e+1 ││
+││ │ 6.60e+1 7.00e+1 7.40e+1 ││
+│└──────┴───────────────────────────┘│
+└────────────────────────────────────┘
+
+Expected values (summing over output dimension):
+Batch 0: [0+3+6+9=18, 1+4+7+10=22, 2+5+8+11=26]
+Batch 1: [12+15+18+21=66, 13+16+19+22=70, 14+17+20+23=74]
+
+=== Testing diagonal tensor (not surjective) ===
+Input (1D tensor of size 5):
+HERE: test/einsum/test_accumulation_semantics.ml:59:21
+[0]: 0...5 shape 0:6 <not-hosted>
+
+Diagonal tensor 'i=>ii' (5x5 with zeros off-diagonal):
+HERE: test/einsum/test_accumulation_semantics.ml:62:21
+┌──────────────────────────────────────┐
+│[1]: =>_diagonal shape 0:6,1:6 │
+│┌──────┬─────────────────────────────┐│
+││ │axis 1 ││
+│├──────┼─────────────────────────────┤│
+││axis 0│ 0.00 0.00 ... 0.00 0.00 ││
+││ │ 0.00 1.00 ... 0.00 0.00 ││
+││ │ ... ... ... ... ... ││
+││ │ 0.00 0.00 ... 4.00 0.00 ││
+││ │ 0.00 0.00 ... 0.00 5.00 ││
+│└──────┴─────────────────────────────┘│
+└──────────────────────────────────────┘
+
+Note: Off-diagonal elements should be zero (initialized by Zero_out)
+
+=== Testing fixed index projection (not surjective) ===
+Input (1D tensor of size 4):
+HERE: test/einsum/test_accumulation_semantics.ml:84:21
+[0]: 0...4 shape 0:5 <not-hosted>
+
+Sparse tensor 'i=>i0' (only first column populated):
+HERE: test/einsum/test_accumulation_semantics.ml:87:21
+┌─────────────────────────────┐
+│[1]: =>_sparse shape 0:5,1:3 │
+│┌──────┬──────────────────┐ │
+││ │axis 1 │ │
+│├──────┼──────────────────┤ │
+││axis 0│ 0.00 0.00 0.00 │ │
+││ │ 1.00 0.00 0.00 │ │
+││ │ 2.00 0.00 0.00 │ │
+││ │ 3.00 0.00 0.00 │ │
+││ │ 4.00 0.00 0.00 │ │
+│└──────┴──────────────────┘ │
+└─────────────────────────────┘
+
+Note: Only column 0 should have values, others should be zero
+
+=== Testing bijective transpose (optimization case) ===
+Input (3x4 matrix):
+HERE: test/einsum/test_accumulation_semantics.ml:109:21
+[0]: r3x4 shape 0:3,1:4 <not-hosted>
+
+Transposed 'ij=>ji' (4x3 matrix):
+HERE: test/einsum/test_accumulation_semantics.ml:112:21
+┌─────────────────────────────────┐
+│[1]: =>_transposed shape 0:4,1:3 │
+│┌──────┬─────────────────────┐ │
+││ │axis 1 │ │
+│├──────┼─────────────────────┤ │
+││axis 0│ 0.00 4.00 8.00 │ │
+││ │ 1.00 5.00 9.00 │ │
+││ │ 2.00 6.00 1.00e+1 │ │
+││ │ 3.00 7.00 1.10e+1 │ │
+│└──────┴─────────────────────┘ │
+└─────────────────────────────────┘
+
+Note: Simple bijective mapping - no initialization or accumulation needed
+
+=== All accumulation semantics tests completed ===
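As a quick cross-check of the reduction figures above (plain OCaml, independent of OCANNL; the expression b*12 + o*3 + i reproduces the printed r2x4x3 layout):

let () =
  for b = 0 to 1 do
    Printf.printf "Batch %d: [" b;
    for i = 0 to 2 do
      (* Sum over the reduced axis o, as in 'b|i->o => b|i'. *)
      let sum = ref 0 in
      for o = 0 to 3 do
        sum := !sum + ((b * 12) + (o * 3) + i)
      done;
      Printf.printf " %d" !sum
    done;
    print_endline " ]"
  done
(* Prints: Batch 0: [ 18 22 26 ] and Batch 1: [ 66 70 74 ],
   matching the =>_result tensor in the expected output. *)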
