Skip to content

Commit c3e7b6d

Browse files
committed
Experiment to test random number discrepancy on threefry4x32_demo and test_threefry4x32, by Claude Opus
1 parent 230b65d commit c3e7b6d

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

arrayjit/lib/builtins.c

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,34 @@ static uint16_t float_to_half_emulated(float f) {
9191
/* Too small - flush to zero */
9292
return sign << 15;
9393
}
94-
/* Subnormal */
94+
/* Subnormal - with proper rounding */
9595
uint32_t shift = -new_exp + 1;
96-
mantissa = (mantissa | 0x800000) >> shift;
97-
return (sign << 15) | (mantissa >> 13);
96+
mantissa = (mantissa | 0x800000);
97+
if (shift < 13) {
98+
/* Round before final shift */
99+
mantissa = (mantissa + (1 << (shift + 12))) >> (shift + 13);
100+
} else {
101+
/* Shift is large, need to be careful with rounding */
102+
mantissa = mantissa >> shift;
103+
mantissa = (mantissa + 0x1000) >> 13;
104+
}
105+
return (sign << 15) | mantissa;
98106
} else if (new_exp >= 0x1F) {
99107
/* Overflow to infinity */
100108
return (sign << 15) | (0x1F << 10);
101109
} else {
102-
/* Normal number */
103-
return (sign << 15) | (new_exp << 10) | (mantissa >> 13);
110+
/* Normal number - with proper rounding */
111+
uint32_t rounded_mantissa = (mantissa + 0x1000) >> 13;
112+
if (rounded_mantissa > 0x3FF) {
113+
/* Rounding caused overflow in mantissa */
114+
new_exp++;
115+
rounded_mantissa = 0;
116+
if (new_exp >= 0x1F) {
117+
/* Overflow to infinity */
118+
return (sign << 15) | (0x1F << 10);
119+
}
120+
}
121+
return (sign << 15) | (new_exp << 10) | rounded_mantissa;
104122
}
105123
}
106124

@@ -306,29 +324,15 @@ extern uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
306324
float f = uint32_to_single_uniform(x.v[0]);
307325
uint32_t bits;
308326
memcpy(&bits, &f, sizeof(float));
309-
return (uint16_t)(bits >> 16);
327+
/* Add proper rounding for bfloat16 */
328+
return (uint16_t)((bits + 0x8000) >> 16);
310329
}
311330

312331
/* Uint4x32 to float16 uniform - uses first 16 bits */
313332
extern uint16_t uint4x32_to_half_uniform(uint4x32_t x) {
314-
/* Simplified conversion - proper fp16 would need more complex handling */
315-
/* This creates a uniform distribution in [0, 1) for fp16 */
316-
uint16_t raw = (uint16_t)(x.v[0] & 0xFFFF);
317-
/* Map to [0, 1) range in fp16 format */
318-
/* Sign bit = 0, exponent between -14 and 0, mantissa from raw bits */
319-
if (raw == 0) return 0;
320-
321-
/* Find the highest set bit */
322-
int shift = 0;
323-
uint16_t temp = raw;
324-
while (temp >>= 1) shift++;
325-
326-
/* Normalize mantissa */
327-
uint16_t mantissa = (raw << (10 - shift)) & 0x3FF;
328-
/* Exponent for values in [0, 1) */
329-
uint16_t exponent = 14 - shift;
330-
331-
return (exponent << 10) | mantissa;
333+
/* Convert through float for consistent behavior */
334+
float f = (x.v[0] & 0xFFFF) * (1.0f / 65536.0f);
335+
return FLOAT_TO_HALF(f);
332336
}
333337

334338
/* Uint4x32 to fp8 uniform - uses first 8 bits */
@@ -407,8 +411,9 @@ extern uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
407411
uint32_t bits1, bits2;
408412
memcpy(&bits1, &f1, sizeof(float));
409413
memcpy(&bits2, &f2, sizeof(float));
410-
result.v[i*2 + 0] = (uint16_t)(bits1 >> 16);
411-
result.v[i*2 + 1] = (uint16_t)(bits2 >> 16);
414+
// Add proper rounding for bfloat16 (round to nearest even)
415+
result.v[i*2 + 0] = (uint16_t)((bits1 + 0x8000) >> 16);
416+
result.v[i*2 + 1] = (uint16_t)((bits2 + 0x8000) >> 16);
412417
}
413418
return result;
414419
}

0 commit comments

Comments
 (0)