open Base
open Ocannl
module IDX = Train.IDX
module CDSL = Train.CDSL
module TDSL = Operation.TDSL
module NTDSL = Operation.NTDSL

module type Backend = Ir.Backend_intf.Backend

let%expect_test "diagonal_tensor_initialization" =
  Tensor.unsafe_reinitialize ();
  let module Backend = (val Backends.fresh_backend ()) in
  let backend =
    (module Backend : Backend
      with type buffer_ptr = Backend.buffer_ptr
       and type dev = Backend.dev
       and type runner = Backend.runner
       and type event = Backend.event
       and type optimize_ctx = Backend.optimize_ctx)
  in

  (* Create a diagonal tensor using einsum: i->ii *)
  let input = TDSL.range 5 in
  let%op diagonal = input ++ "i=>ii" in

  (* Ensure the diagonal tensor is hosted *)
  Train.set_hosted diagonal.value;
  ignore (Train.forward_once backend diagonal);

  (* Print the diagonal tensor *)
  Train.printf ~here:[%here] ~with_code:false ~with_grad:false diagonal;
  [%expect {|
    HERE: test/einsum/surjectivity.ml:31:21
    ┌──────────────────────────────────────┐
    │[1]: => _diagonal shape 0:6,1:6       │
    │┌──────┬─────────────────────────────┐│
    ││      │axis 1                       ││
    │├──────┼─────────────────────────────┤│
    ││axis 0│ 0.00  0.00  ...  0.00  0.00 ││
    ││      │ 0.00  1.00  ...  0.00  0.00 ││
    ││      │ ...   ...   ...   ...   ... ││
    ││      │ 0.00  0.00  ...  4.00  0.00 ││
    ││      │ 0.00  0.00  ...  0.00  5.00 ││
    │└──────┴─────────────────────────────┘│
    └──────────────────────────────────────┘
    |}]

let%expect_test "sparse_assignment_with_fixed_indices" =
  Tensor.unsafe_reinitialize ();
  let module Backend = (val Backends.fresh_backend ()) in
  let backend =
    (module Backend : Backend
      with type buffer_ptr = Backend.buffer_ptr
       and type dev = Backend.dev
       and type runner = Backend.runner
       and type event = Backend.event
       and type optimize_ctx = Backend.optimize_ctx)
  in

  (* Create a sparse tensor using fixed indices: i->i0j *)
  let input = TDSL.range 4 in
  let%op sparse = input ++ "i=>i0j" in

  Train.set_hosted sparse.value;
  ignore (Train.forward_once backend sparse);

  Train.printf ~here:[%here] ~with_code:false ~with_grad:false sparse;
  [%expect {|
    HERE: test/einsum/surjectivity.ml:64:21
    ┌──────────────────────────────────┐
    │[1]: => _sparse shape 0:5,1:1,2:1 │
    │┌──────┬──────┐                   │
    ││      │axis 2│                   │
    │├──────┼──────┤                   │
    ││0 @ 0 │ 0.00 │                   │
    ││axis 1│      │                   │
    │├──────┼──────┤                   │
    ││1 @ 0 │ 1.00 │                   │
    ││axis 1│      │                   │
    │├──────┼──────┤                   │
    ││2 @ 0 │ 2.00 │                   │
    ││axis 1│      │                   │
    │├──────┼──────┤                   │
    ││3 @ 0 │ 3.00 │                   │
    ││axis 1│      │                   │
    │├──────┼──────┤                   │
    ││4 @ 0 │ 4.00 │                   │
    ││axis 1│      │                   │
    │└──────┴──────┘                   │
    └──────────────────────────────────┘
    |}]

let%expect_test "multiple_sparse_axes" =
  Tensor.unsafe_reinitialize ();
  let module Backend = (val Backends.fresh_backend ()) in
  let backend =
    (module Backend : Backend
      with type buffer_ptr = Backend.buffer_ptr
       and type dev = Backend.dev
       and type runner = Backend.runner
       and type event = Backend.event
       and type optimize_ctx = Backend.optimize_ctx)
  in

  (* Test with multiple fixed indices: ij->i1j2 *)
  let input = TDSL.range_of_shape ~output_dims:[ 3; 4 ] () in
  let%op sparse_multi = input ++ "ij=>i1j2" in

  Train.set_hosted sparse_multi.value;
  ignore (Train.forward_once backend sparse_multi);

  Train.printf ~here:[%here] ~with_code:false ~with_grad:false sparse_multi;
  [%expect {|
    HERE: test/einsum/surjectivity.ml:113:21
    ┌────────────────────────────────────────────┐
    │[1]: => _sparse_multi shape 0:3,1:2,2:4,3:3 │
    │┌──────┬──────────────────┐                 │
    ││0 @ 0 │axis 3            │                 │
    │├──────┼──────────────────┤                 │
    ││0 @ 1 │ 0.00  0.00  0.00 │                 │
    ││axis 2│ 0.00  0.00  0.00 │                 │
    ││      │ 0.00  0.00  0.00 │                 │
    ││      │ 0.00  0.00  0.00 │                 │
    │├──────┼──────────────────┤                 │
    ││1 @ 1 │ 0.00  0.00  0.00 │                 │
    ││axis 2│ 0.00  0.00  1.00 │                 │
    ││      │ 0.00  0.00  2.00 │                 │
    ││      │ 0.00  0.00  3.00 │                 │
    │└──────┴──────────────────┘                 │
    ├────────────────────────────────────────────┤
    │┌──────┬──────────────────┐                 │
    ││1 @ 0 │axis 3            │                 │
    │├──────┼──────────────────┤                 │
    ││0 @ 1 │ 0.00  0.00  0.00 │                 │
    ││axis 2│ 0.00  0.00  0.00 │                 │
    ││      │ 0.00  0.00  0.00 │                 │
    ││      │ 0.00  0.00  0.00 │                 │
    │├──────┼──────────────────┤                 │
    ││1 @ 1 │ 0.00  0.00  4.00 │                 │
    ││axis 2│ 0.00  0.00  5.00 │                 │
    ││      │ 0.00  0.00  6.00 │                 │
    ││      │ 0.00  0.00  7.00 │                 │
    │└──────┴──────────────────┘                 │
    ├────────────────────────────────────────────┤
    │┌──────┬─────────────────────┐              │
    ││2 @ 0 │axis 3               │              │
    │├──────┼─────────────────────┤              │
    ││0 @ 1 │ 0.00  0.00  0.00    │              │
    ││axis 2│ 0.00  0.00  0.00    │              │
    ││      │ 0.00  0.00  0.00    │              │
    ││      │ 0.00  0.00  0.00    │              │
    │├──────┼─────────────────────┤              │
    ││1 @ 1 │ 0.00  0.00  8.00    │              │
    ││axis 2│ 0.00  0.00  9.00    │              │
    ││      │ 0.00  0.00  1.00e+1 │              │
    ││      │ 0.00  0.00  1.10e+1 │              │
    │└──────┴─────────────────────┘              │
    └────────────────────────────────────────────┘
    |}]