1+ open Base
2+ open Ocannl.Nn_blocks.DSL_modules
3+ open Stdio
4+
5+ let create_histogram values ~num_bins ~min_val ~max_val =
6+ let bins = Array. create ~len: num_bins 0 in
7+ let bin_width = (max_val -. min_val) /. Float. of_int num_bins in
8+ Array. iter values ~f: (fun x ->
9+ let bin_idx =
10+ Int. min (num_bins - 1 ) (Int. max 0 (Float. to_int ((x -. min_val) /. bin_width)))
11+ in
12+ bins.(bin_idx) < - bins.(bin_idx) + 1 );
13+ bins
14+
15+ let print_histogram bins ~title ~max_width =
16+ printf " \n %s\n " title;
17+ printf " %s\n " (String. make (String. length title) '=' );
18+ let max_count = Array. max_elt bins ~compare: Int. compare |> Option. value ~default: 0 in
19+ let total = Array. fold bins ~init: 0 ~f: ( + ) in
20+ Array. iteri bins ~f: (fun i count ->
21+ let bar_width = (count * max_width) / max_count in
22+ let bar = String. make bar_width '#' in
23+ let percentage = (Float. of_int count /. Float. of_int total) *. 100.0 in
24+ printf " Bin %2d: %s %4d (%.1f%%)\n " i bar count percentage)
25+
26+ let test_uniform_at_histogram () =
27+ Tensor. unsafe_reinitialize () ;
28+ let ctx = Context. auto () in
29+ let module O = TDSL. O in
30+
31+ (* Generate a large batch of random numbers using uniform_at *)
32+ (* Note: uniform_at produces 4 values per counter input (from uint4x32) *)
33+ let num_counters = 2500 in
34+ let counter = TDSL. range num_counters in
35+
36+ (* Generate uniform random values using uniform_at *)
37+ let uniform_values = O. uniform_at counter in
38+ Ir.Tnode. update_prec uniform_values.value Ir.Ops. single;
39+
40+ (* Compile and run *)
41+ Ocannl.Train. set_hosted uniform_values.value;
42+ ignore (Ocannl.Train. forward_once ctx uniform_values);
43+ let result = Ir.Tnode. get_values uniform_values.value in
44+
45+ printf " Generated %d values from %d counters (%.1fx expansion)\n "
46+ (Array. length result) num_counters
47+ (Float. of_int (Array. length result) /. Float. of_int num_counters);
48+
49+ (* Create and print histogram *)
50+ let num_bins = 20 in
51+ let bins = create_histogram result ~num_bins ~min_val: 0.0 ~max_val: 1.0 in
52+ print_histogram bins ~title: " Uniform Distribution [0, 1) Histogram" ~max_width: 40 ;
53+
54+ (* Statistical tests *)
55+ let mean = Array. fold result ~init: 0.0 ~f: (+. ) /. Float. of_int (Array. length result) in
56+ let variance =
57+ Array. fold result ~init: 0.0 ~f: (fun acc x ->
58+ acc +. ((x -. mean) *. (x -. mean)))
59+ /. Float. of_int (Array. length result)
60+ in
61+ let std_dev = Float. sqrt variance in
62+
63+ printf " \n Statistics:\n " ;
64+ printf " Mean: %.4f (expected: ~0.5)\n " mean;
65+ printf " Std Dev: %.4f (expected: ~%.4f)\n " std_dev (Float. sqrt (1.0 /. 12.0 ));
66+ printf " Min: %.4f\n " (Array. min_elt result ~compare: Float. compare |> Option. value ~default: 0.0 );
67+ printf " Max: %.4f\n " (Array. max_elt result ~compare: Float. compare |> Option. value ~default: 0.0 );
68+
69+ (* Check uniformity with chi-square test *)
70+ let expected_per_bin = Float. of_int (Array. length result) /. Float. of_int num_bins in
71+ let chi_square =
72+ Array. fold bins ~init: 0.0 ~f: (fun acc observed ->
73+ let diff = Float. of_int observed -. expected_per_bin in
74+ acc +. (diff *. diff /. expected_per_bin))
75+ in
76+ printf " Chi-square statistic: %.2f (df=%d, critical value at 0.05: ~%.2f)\n "
77+ chi_square (num_bins - 1 ) 30.14 ;
78+
79+ (* Check if all values are in range *)
80+ let all_in_range = Array. for_all result ~f: (fun x -> Float. (x > = 0.0 && x < 1.0 )) in
81+ printf " All values in [0, 1) range: %b\n " all_in_range
82+
83+ let test_normal_at_histogram () =
84+ Tensor. unsafe_reinitialize () ;
85+ let ctx = Context. auto () in
86+ let module O = TDSL. O in
87+
88+ (* Generate a large batch of random numbers using normal_at *)
89+ (* Note: normal_at also produces 4 values per counter input *)
90+ let num_counters = 2500 in
91+ let counter = TDSL. range num_counters in
92+
93+ (* Generate normal random values using normal_at *)
94+ let normal_values = O. normal_at counter in
95+ Ir.Tnode. update_prec normal_values.value Ir.Ops. single;
96+
97+ (* Compile and run *)
98+ Ocannl.Train. set_hosted normal_values.value;
99+ ignore (Ocannl.Train. forward_once ctx normal_values);
100+ let result = Ir.Tnode. get_values normal_values.value in
101+
102+ (* Calculate statistics *)
103+ let mean = Array. fold result ~init: 0.0 ~f: (+. ) /. Float. of_int (Array. length result) in
104+ let variance =
105+ Array. fold result ~init: 0.0 ~f: (fun acc x ->
106+ acc +. ((x -. mean) *. (x -. mean)))
107+ /. Float. of_int (Array. length result)
108+ in
109+ let std_dev = Float. sqrt variance in
110+ let min_val = Array. min_elt result ~compare: Float. compare |> Option. value ~default: 0.0 in
111+ let max_val = Array. max_elt result ~compare: Float. compare |> Option. value ~default: 0.0 in
112+
113+ (* Create histogram with dynamic range *)
114+ let histogram_min = Float. max (- 4.0 ) (min_val -. 0.5 ) in
115+ let histogram_max = Float. min 4.0 (max_val +. 0.5 ) in
116+ let num_bins = 30 in
117+ let bins = create_histogram result ~num_bins ~min_val: histogram_min ~max_val: histogram_max in
118+ print_histogram bins ~title: " Normal Distribution N(0,1) Histogram" ~max_width: 40 ;
119+
120+ printf " \n Statistics:\n " ;
121+ printf " Mean: %.4f (expected: ~0.0)\n " mean;
122+ printf " Std Dev: %.4f (expected: ~1.0)\n " std_dev;
123+ printf " Min: %.4f\n " min_val;
124+ printf " Max: %.4f\n " max_val;
125+
126+ (* Check what percentage falls within standard deviations *)
127+ let within_1_std =
128+ Array. count result ~f: (fun x -> Float. (abs x < = 1.0 ))
129+ in
130+ let within_2_std =
131+ Array. count result ~f: (fun x -> Float. (abs x < = 2.0 ))
132+ in
133+ let within_3_std =
134+ Array. count result ~f: (fun x -> Float. (abs x < = 3.0 ))
135+ in
136+
137+ printf " Within 1 std dev: %.1f%% (expected: ~68.3%%)\n "
138+ (Float. of_int within_1_std /. Float. of_int (Array. length result) *. 100.0 );
139+ printf " Within 2 std dev: %.1f%% (expected: ~95.4%%)\n "
140+ (Float. of_int within_2_std /. Float. of_int (Array. length result) *. 100.0 );
141+ printf " Within 3 std dev: %.1f%% (expected: ~99.7%%)\n "
142+ (Float. of_int within_3_std /. Float. of_int (Array. length result) *. 100.0 );
143+
144+ (* Normality test using skewness and kurtosis *)
145+ let skewness =
146+ let sum_cubed = Array. fold result ~init: 0.0 ~f: (fun acc x ->
147+ let diff = x -. mean in
148+ acc +. (diff *. diff *. diff))
149+ in
150+ sum_cubed /. (Float. of_int (Array. length result) *. std_dev *. std_dev *. std_dev)
151+ in
152+
153+ let kurtosis =
154+ let sum_fourth = Array. fold result ~init: 0.0 ~f: (fun acc x ->
155+ let diff = x -. mean in
156+ let diff2 = diff *. diff in
157+ acc +. (diff2 *. diff2))
158+ in
159+ (sum_fourth /. (Float. of_int (Array. length result) *. std_dev *. std_dev *. std_dev *. std_dev)) -. 3.0
160+ in
161+
162+ printf " Skewness: %.4f (expected: ~0.0)\n " skewness;
163+ printf " Excess Kurtosis: %.4f (expected: ~0.0)\n " kurtosis
164+
165+ let test_batched_generation_consistency () =
166+ Tensor. unsafe_reinitialize () ;
167+ let ctx = Context. auto () in
168+ let module O = TDSL. O in
169+
170+ (* Test that batched generation gives consistent results *)
171+ let batch_size = 100 in
172+ let num_batches = 10 in
173+
174+ printf " \n Batched Generation Consistency Test\n " ;
175+ printf " ====================================\n " ;
176+
177+ (* Generate values in batches and check they don't repeat across batches *)
178+ let all_uniform_values = ref [||] in
179+ let all_normal_values = ref [||] in
180+
181+ for _batch = 0 to num_batches - 1 do
182+ (* Each batch uses its own counter range - values are just seeds *)
183+ let counter = TDSL. range batch_size in
184+
185+ (* Generate uniform batch *)
186+ let uniform_batch = O. uniform_at counter in
187+ Ir.Tnode. update_prec uniform_batch.value Ir.Ops. single;
188+ Ocannl.Train. set_hosted uniform_batch.value;
189+ ignore (Ocannl.Train. forward_once ctx uniform_batch);
190+ let uniform_result = Ir.Tnode. get_values uniform_batch.value in
191+ all_uniform_values := Array. append ! all_uniform_values uniform_result;
192+
193+ (* Generate normal batch *)
194+ let normal_batch = O. normal_at counter in
195+ Ir.Tnode. update_prec normal_batch.value Ir.Ops. single;
196+ Ocannl.Train. set_hosted normal_batch.value;
197+ ignore (Ocannl.Train. forward_once ctx normal_batch);
198+ let normal_result = Ir.Tnode. get_values normal_batch.value in
199+ all_normal_values := Array. append ! all_normal_values normal_result
200+ done ;
201+
202+ (* Check for uniqueness (with small tolerance for floating point) *)
203+ let count_unique arr =
204+ let sorted = Array. copy arr in
205+ Array. sort sorted ~compare: Float. compare;
206+ let unique = ref 1 in
207+ for i = 1 to Array. length sorted - 1 do
208+ let diff = Float. abs (sorted.(i) -. sorted.(i-1 )) in
209+ if Float. (diff > 1e-7 ) then
210+ unique := ! unique + 1
211+ done ;
212+ ! unique
213+ in
214+
215+ let total_values = batch_size * num_batches in
216+ let unique_uniform = count_unique ! all_uniform_values in
217+ let unique_normal = count_unique ! all_normal_values in
218+
219+ printf " Generated %d values in %d batches of %d\n " total_values num_batches batch_size;
220+ printf " Uniform values: %d unique out of %d (%.1f%%)\n "
221+ unique_uniform total_values
222+ (Float. of_int unique_uniform /. Float. of_int total_values *. 100.0 );
223+ printf " Normal values: %d unique out of %d (%.1f%%)\n "
224+ unique_normal total_values
225+ (Float. of_int unique_normal /. Float. of_int total_values *. 100.0 );
226+
227+ (* Verify batch consistency of statistical properties *)
228+ let batch_means_uniform = Array. create ~len: num_batches 0.0 in
229+ let batch_means_normal = Array. create ~len: num_batches 0.0 in
230+
231+ for batch = 0 to num_batches - 1 do
232+ let start_idx = batch * batch_size in
233+ let uniform_batch = Array. sub ! all_uniform_values ~pos: start_idx ~len: batch_size in
234+ let normal_batch = Array. sub ! all_normal_values ~pos: start_idx ~len: batch_size in
235+
236+ batch_means_uniform.(batch) < -
237+ Array. fold uniform_batch ~init: 0.0 ~f: (+. ) /. Float. of_int batch_size;
238+ batch_means_normal.(batch) < -
239+ Array. fold normal_batch ~init: 0.0 ~f: (+. ) /. Float. of_int batch_size
240+ done ;
241+
242+ let mean_of_means_uniform =
243+ Array. fold batch_means_uniform ~init: 0.0 ~f: (+. ) /. Float. of_int num_batches
244+ in
245+ let mean_of_means_normal =
246+ Array. fold batch_means_normal ~init: 0.0 ~f: (+. ) /. Float. of_int num_batches
247+ in
248+
249+ let std_of_means_uniform =
250+ let diff_sum = Array. fold batch_means_uniform ~init: 0.0 ~f: (fun acc x ->
251+ let diff = x -. mean_of_means_uniform in
252+ acc +. (diff *. diff)) in
253+ Float. sqrt (diff_sum /. Float. of_int num_batches)
254+ in
255+ let std_of_means_normal =
256+ let diff_sum = Array. fold batch_means_normal ~init: 0.0 ~f: (fun acc x ->
257+ let diff = x -. mean_of_means_normal in
258+ acc +. (diff *. diff)) in
259+ Float. sqrt (diff_sum /. Float. of_int num_batches)
260+ in
261+
262+ printf " \n Batch means consistency:\n " ;
263+ printf " Uniform: mean of batch means = %.4f, std = %.4f\n "
264+ mean_of_means_uniform std_of_means_uniform;
265+ printf " Normal: mean of batch means = %.4f, std = %.4f\n "
266+ mean_of_means_normal std_of_means_normal
267+
268+ let () =
269+ test_uniform_at_histogram () ;
270+ printf " \n " ;
271+ test_normal_at_histogram () ;
272+ printf " \n " ;
273+ test_batched_generation_consistency ()
0 commit comments