Skip to content

Commit 2304c99

Browse files
lukstafi and claude committed
Fix normal distribution test to be deterministic across machines
The Box-Muller transform uses transcendental functions (log, cos) that produce slightly different floating-point results across CPU architectures and math libraries. The test now prints only PASS/FAIL results for statistical property checks with defined tolerances, instead of printing exact histogram values.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent a4ffb68 commit 2304c99

File tree

2 files changed

+58
-69
lines changed

2 files changed

+58
-69
lines changed

test/operations/test_random_histograms.expected

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -34,49 +34,20 @@ Statistics:
3434
All values in [0, 1) range: true
3535

3636

37-
Normal Distribution N(0,1) Histogram
38-
====================================
39-
Bin 0: 0 (0.0%)
40-
Bin 1: 1 (0.0%)
41-
Bin 2: 3 (0.0%)
42-
Bin 3: 10 (0.1%)
43-
Bin 4: 22 (0.2%)
44-
Bin 5: # 43 (0.4%)
45-
Bin 6: ## 83 (0.8%)
46-
Bin 7: #### 143 (1.4%)
47-
Bin 8: ####### 236 (2.4%)
48-
Bin 9: ############ 381 (3.8%)
49-
Bin 10: ################## 548 (5.5%)
50-
Bin 11: ####################### 716 (7.2%)
51-
Bin 12: ########################## 799 (8.0%)
52-
Bin 13: ############################ 881 (8.8%)
53-
Bin 14: ######################################## 1216 (12.2%)
54-
Bin 15: ###################################### 1168 (11.7%)
55-
Bin 16: ########################## 815 (8.2%)
56-
Bin 17: ########################## 812 (8.1%)
57-
Bin 18: ###################### 698 (7.0%)
58-
Bin 19: ################ 514 (5.1%)
59-
Bin 20: ########### 349 (3.5%)
60-
Bin 21: ####### 226 (2.3%)
61-
Bin 22: ##### 158 (1.6%)
62-
Bin 23: ### 96 (1.0%)
63-
Bin 24: # 48 (0.5%)
64-
Bin 25: 18 (0.2%)
65-
Bin 26: 8 (0.1%)
66-
Bin 27: 3 (0.0%)
67-
Bin 28: 3 (0.0%)
68-
Bin 29: 2 (0.0%)
69-
70-
Statistics:
71-
Mean: -0.0073 (expected: ~0.0)
72-
Std Dev: 1.0022 (expected: ~1.0)
73-
Min: -3.5704
74-
Max: 4.0677
75-
Within 1 std dev: 67.6% (expected: ~68.3%)
76-
Within 2 std dev: 95.4% (expected: ~95.4%)
77-
Within 3 std dev: 99.8% (expected: ~99.7%)
78-
Skewness: 0.0449 (expected: ~0.0)
79-
Excess Kurtosis: -0.0068 (expected: ~0.0)
37+
Normal Distribution N(0,1) Statistical Test
38+
============================================
39+
Generated 10000 values
40+
Mean (expected: ~0.0, tolerance: 0.10): PASS
41+
Std Dev (expected: ~1.0, tolerance: 0.10): PASS
42+
Within 1 std dev %% (expected: ~68.3, tolerance: 3.00): PASS
43+
Within 2 std dev %% (expected: ~95.4, tolerance: 2.00): PASS
44+
Within 3 std dev %% (expected: ~99.7, tolerance: 1.00): PASS
45+
Skewness (expected: ~0.0, tolerance: 0.15): PASS
46+
Excess Kurtosis (expected: ~0.0, tolerance: 0.15): PASS
47+
Min (should be < -3.0): PASS
48+
Max (should be > 3.0): PASS
49+
50+
Overall: ALL TESTS PASSED
8051

8152

8253
Batched Generation Consistency Test

test/operations/test_random_histograms.ml

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -97,39 +97,24 @@ let test_normal_at_histogram () =
9797
let result = Ir.Tnode.get_values normal_values.value in
9898

9999
(* Calculate statistics *)
100-
let mean = Array.fold result ~init:0.0 ~f:( +. ) /. Float.of_int (Array.length result) in
100+
let n = Array.length result in
101+
let mean = Array.fold result ~init:0.0 ~f:( +. ) /. Float.of_int n in
101102
let variance =
102103
Array.fold result ~init:0.0 ~f:(fun acc x -> acc +. ((x -. mean) *. (x -. mean)))
103-
/. Float.of_int (Array.length result)
104+
/. Float.of_int n
104105
in
105106
let std_dev = Float.sqrt variance in
106107
let min_val = Array.min_elt result ~compare:Float.compare |> Option.value ~default:0.0 in
107108
let max_val = Array.max_elt result ~compare:Float.compare |> Option.value ~default:0.0 in
108109

109-
(* Create histogram with dynamic range *)
110-
let histogram_min = Float.max (-4.0) (min_val -. 0.5) in
111-
let histogram_max = Float.min 4.0 (max_val +. 0.5) in
112-
let num_bins = 30 in
113-
let bins = create_histogram result ~num_bins ~min_val:histogram_min ~max_val:histogram_max in
114-
print_histogram bins ~title:"Normal Distribution N(0,1) Histogram" ~max_width:40;
115-
116-
printf "\nStatistics:\n";
117-
printf " Mean: %.4f (expected: ~0.0)\n" mean;
118-
printf " Std Dev: %.4f (expected: ~1.0)\n" std_dev;
119-
printf " Min: %.4f\n" min_val;
120-
printf " Max: %.4f\n" max_val;
121-
122110
(* Check what percentage falls within standard deviations *)
123111
let within_1_std = Array.count result ~f:(fun x -> Float.(abs x <= 1.0)) in
124112
let within_2_std = Array.count result ~f:(fun x -> Float.(abs x <= 2.0)) in
125113
let within_3_std = Array.count result ~f:(fun x -> Float.(abs x <= 3.0)) in
126114

127-
printf " Within 1 std dev: %.1f%% (expected: ~68.3%%)\n"
128-
(Float.of_int within_1_std /. Float.of_int (Array.length result) *. 100.0);
129-
printf " Within 2 std dev: %.1f%% (expected: ~95.4%%)\n"
130-
(Float.of_int within_2_std /. Float.of_int (Array.length result) *. 100.0);
131-
printf " Within 3 std dev: %.1f%% (expected: ~99.7%%)\n"
132-
(Float.of_int within_3_std /. Float.of_int (Array.length result) *. 100.0);
115+
let pct_1_std = Float.of_int within_1_std /. Float.of_int n *. 100.0 in
116+
let pct_2_std = Float.of_int within_2_std /. Float.of_int n *. 100.0 in
117+
let pct_3_std = Float.of_int within_3_std /. Float.of_int n *. 100.0 in
133118

134119
(* Normality test using skewness and kurtosis *)
135120
let skewness =
@@ -138,7 +123,7 @@ let test_normal_at_histogram () =
138123
let diff = x -. mean in
139124
acc +. (diff *. diff *. diff))
140125
in
141-
sum_cubed /. (Float.of_int (Array.length result) *. std_dev *. std_dev *. std_dev)
126+
sum_cubed /. (Float.of_int n *. std_dev *. std_dev *. std_dev)
142127
in
143128

144129
let kurtosis =
@@ -148,12 +133,45 @@ let test_normal_at_histogram () =
148133
let diff2 = diff *. diff in
149134
acc +. (diff2 *. diff2))
150135
in
151-
(sum_fourth /. (Float.of_int (Array.length result) *. std_dev *. std_dev *. std_dev *. std_dev))
152-
-. 3.0
136+
(sum_fourth /. (Float.of_int n *. std_dev *. std_dev *. std_dev *. std_dev)) -. 3.0
137+
in
138+
139+
(* Note: Box-Muller transformation uses transcendental functions (log, cos) which may
140+
produce slightly different results across different CPU architectures and math libraries.
141+
We only verify statistical properties are within acceptable bounds, not exact values. *)
142+
printf "\nNormal Distribution N(0,1) Statistical Test\n";
143+
printf "============================================\n";
144+
printf "Generated %d values\n" n;
145+
146+
(* Verify statistical properties - only print PASS/FAIL to avoid machine-specific output *)
147+
let check name value expected tolerance =
148+
let passed = Float.(abs (value -. expected) <= tolerance) in
149+
printf " %s (expected: ~%.1f, tolerance: %.2f): %s\n" name expected tolerance
150+
(if passed then "PASS" else Printf.sprintf "FAIL (got %.4f)" value);
151+
passed
152+
in
153+
154+
let check_bound name value bound is_lower =
155+
let passed = if is_lower then Float.(value < bound) else Float.(value > bound) in
156+
let op = if is_lower then "<" else ">" in
157+
printf " %s (should be %s %.1f): %s\n" name op bound
158+
(if passed then "PASS" else Printf.sprintf "FAIL (got %.4f)" value);
159+
passed
160+
in
161+
162+
let all_passed =
163+
check "Mean" mean 0.0 0.1
164+
&& check "Std Dev" std_dev 1.0 0.1
165+
&& check "Within 1 std dev %%" pct_1_std 68.3 3.0
166+
&& check "Within 2 std dev %%" pct_2_std 95.4 2.0
167+
&& check "Within 3 std dev %%" pct_3_std 99.7 1.0
168+
&& check "Skewness" skewness 0.0 0.15
169+
&& check "Excess Kurtosis" kurtosis 0.0 0.15
170+
&& check_bound "Min" min_val (-3.0) true
171+
&& check_bound "Max" max_val 3.0 false
153172
in
154173

155-
printf " Skewness: %.4f (expected: ~0.0)\n" skewness;
156-
printf " Excess Kurtosis: %.4f (expected: ~0.0)\n" kurtosis
174+
printf "\nOverall: %s\n" (if all_passed then "ALL TESTS PASSED" else "SOME TESTS FAILED")
157175

158176
let test_batched_generation_consistency () =
159177
Tensor.unsafe_reinitialize ();

0 commit comments

Comments
 (0)