@@ -97,39 +97,24 @@ let test_normal_at_histogram () =
9797 let result = Ir.Tnode. get_values normal_values.value in
9898
9999 (* Calculate statistics *)
100- let mean = Array. fold result ~init: 0.0 ~f: ( +. ) /. Float. of_int (Array. length result) in
100+ let n = Array. length result in
101+ let mean = Array. fold result ~init: 0.0 ~f: ( +. ) /. Float. of_int n in
101102 let variance =
102103 Array. fold result ~init: 0.0 ~f: (fun acc x -> acc +. ((x -. mean) *. (x -. mean)))
103- /. Float. of_int ( Array. length result)
104+ /. Float. of_int n
104105 in
105106 let std_dev = Float. sqrt variance in
106107 let min_val = Array. min_elt result ~compare: Float. compare |> Option. value ~default: 0.0 in
107108 let max_val = Array. max_elt result ~compare: Float. compare |> Option. value ~default: 0.0 in
108109
109- (* Create histogram with dynamic range *)
110- let histogram_min = Float. max (- 4.0 ) (min_val -. 0.5 ) in
111- let histogram_max = Float. min 4.0 (max_val +. 0.5 ) in
112- let num_bins = 30 in
113- let bins = create_histogram result ~num_bins ~min_val: histogram_min ~max_val: histogram_max in
114- print_histogram bins ~title: " Normal Distribution N(0,1) Histogram" ~max_width: 40 ;
115-
116- printf " \n Statistics:\n " ;
117- printf " Mean: %.4f (expected: ~0.0)\n " mean;
118- printf " Std Dev: %.4f (expected: ~1.0)\n " std_dev;
119- printf " Min: %.4f\n " min_val;
120- printf " Max: %.4f\n " max_val;
121-
122110 (* Check what percentage falls within standard deviations *)
123111 let within_1_std = Array. count result ~f: (fun x -> Float. (abs x < = 1.0 )) in
124112 let within_2_std = Array. count result ~f: (fun x -> Float. (abs x < = 2.0 )) in
125113 let within_3_std = Array. count result ~f: (fun x -> Float. (abs x < = 3.0 )) in
126114
127- printf " Within 1 std dev: %.1f%% (expected: ~68.3%%)\n "
128- (Float. of_int within_1_std /. Float. of_int (Array. length result) *. 100.0 );
129- printf " Within 2 std dev: %.1f%% (expected: ~95.4%%)\n "
130- (Float. of_int within_2_std /. Float. of_int (Array. length result) *. 100.0 );
131- printf " Within 3 std dev: %.1f%% (expected: ~99.7%%)\n "
132- (Float. of_int within_3_std /. Float. of_int (Array. length result) *. 100.0 );
115+ let pct_1_std = Float. of_int within_1_std /. Float. of_int n *. 100.0 in
116+ let pct_2_std = Float. of_int within_2_std /. Float. of_int n *. 100.0 in
117+ let pct_3_std = Float. of_int within_3_std /. Float. of_int n *. 100.0 in
133118
134119 (* Normality test using skewness and kurtosis *)
135120 let skewness =
@@ -138,7 +123,7 @@ let test_normal_at_histogram () =
138123 let diff = x -. mean in
139124 acc +. (diff *. diff *. diff))
140125 in
141- sum_cubed /. (Float. of_int ( Array. length result) *. std_dev *. std_dev *. std_dev)
126+ sum_cubed /. (Float. of_int n *. std_dev *. std_dev *. std_dev)
142127 in
143128
144129 let kurtosis =
@@ -148,12 +133,45 @@ let test_normal_at_histogram () =
148133 let diff2 = diff *. diff in
149134 acc +. (diff2 *. diff2))
150135 in
151- (sum_fourth /. (Float. of_int (Array. length result) *. std_dev *. std_dev *. std_dev *. std_dev))
152- -. 3.0
136+ (sum_fourth /. (Float. of_int n *. std_dev *. std_dev *. std_dev *. std_dev)) -. 3.0
137+ in
138+
139+ (* Note: Box-Muller transformation uses transcendental functions (log, cos) which may
140+ produce slightly different results across different CPU architectures and math libraries.
141+ We only verify statistical properties are within acceptable bounds, not exact values. *)
142+ printf " \n Normal Distribution N(0,1) Statistical Test\n " ;
143+ printf " ============================================\n " ;
144+ printf " Generated %d values\n " n;
145+
146+ (* Verify statistical properties - only print PASS/FAIL to avoid machine-specific output *)
147+ let check name value expected tolerance =
148+ let passed = Float. (abs (value -. expected) < = tolerance) in
149+ printf " %s (expected: ~%.1f, tolerance: %.2f): %s\n " name expected tolerance
150+ (if passed then " PASS" else Printf. sprintf " FAIL (got %.4f)" value);
151+ passed
152+ in
153+
154+ let check_bound name value bound is_lower =
155+ let passed = if is_lower then Float. (value < bound) else Float. (value > bound) in
156+ let op = if is_lower then " <" else " >" in
157+ printf " %s (should be %s %.1f): %s\n " name op bound
158+ (if passed then " PASS" else Printf. sprintf " FAIL (got %.4f)" value);
159+ passed
160+ in
161+
162+ let all_passed =
163+ check " Mean" mean 0.0 0.1
164+ && check " Std Dev" std_dev 1.0 0.1
165+ && check " Within 1 std dev %%" pct_1_std 68.3 3.0
166+ && check " Within 2 std dev %%" pct_2_std 95.4 2.0
167+ && check " Within 3 std dev %%" pct_3_std 99.7 1.0
168+ && check " Skewness" skewness 0.0 0.15
169+ && check " Excess Kurtosis" kurtosis 0.0 0.15
170+ && check_bound " Min" min_val (- 3.0 ) true
171+ && check_bound " Max" max_val 3.0 false
153172 in
154173
155- printf " Skewness: %.4f (expected: ~0.0)\n " skewness;
156- printf " Excess Kurtosis: %.4f (expected: ~0.0)\n " kurtosis
174+ printf " \n Overall: %s\n " (if all_passed then " ALL TESTS PASSED" else " SOME TESTS FAILED" )
157175
158176let test_batched_generation_consistency () =
159177 Tensor. unsafe_reinitialize () ;
0 commit comments