@@ -91,16 +91,34 @@ static uint16_t float_to_half_emulated(float f) {
9191 /* Too small - flush to zero */
9292 return sign << 15 ;
9393 }
94- /* Subnormal */
94+ /* Subnormal - with proper rounding */
9595 uint32_t shift = - new_exp + 1 ;
96- mantissa = (mantissa | 0x800000 ) >> shift ;
97- return (sign << 15 ) | (mantissa >> 13 );
96+ mantissa = (mantissa | 0x800000 );
97+ if (shift < 13 ) {
98+ /* Round before final shift */
99+ mantissa = (mantissa + (1 << (shift + 12 ))) >> (shift + 13 );
100+ } else {
101+ /* Large shift: truncate by `shift` first, then round the remaining 13 bits (the bits dropped by the first shift do not take part in the rounding) */
102+ mantissa = mantissa >> shift ;
103+ mantissa = (mantissa + 0x1000 ) >> 13 ;
104+ }
105+ return (sign << 15 ) | mantissa ;
98106 } else if (new_exp >= 0x1F ) {
99107 /* Overflow to infinity */
100108 return (sign << 15 ) | (0x1F << 10 );
101109 } else {
102- /* Normal number */
103- return (sign << 15 ) | (new_exp << 10 ) | (mantissa >> 13 );
110+ /* Normal number - with proper rounding */
111+ uint32_t rounded_mantissa = (mantissa + 0x1000 ) >> 13 ;
112+ if (rounded_mantissa > 0x3FF ) {
113+ /* Rounding caused overflow in mantissa */
114+ new_exp ++ ;
115+ rounded_mantissa = 0 ;
116+ if (new_exp >= 0x1F ) {
117+ /* Overflow to infinity */
118+ return (sign << 15 ) | (0x1F << 10 );
119+ }
120+ }
121+ return (sign << 15 ) | (new_exp << 10 ) | rounded_mantissa ;
104122 }
105123}
106124
@@ -306,29 +324,15 @@ extern uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
306324 float f = uint32_to_single_uniform (x .v [0 ]);
307325 uint32_t bits ;
308326 memcpy (& bits , & f , sizeof (float ));
309- return (uint16_t )(bits >> 16 );
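327+ /* Round to nearest bfloat16 by adding half an output ULP (ties round away from zero, not to even); a carry can propagate into the exponent */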
328+ return (uint16_t )((bits + 0x8000 ) >> 16 );
310329}
311330
312331/* Uint4x32 to float16 uniform - uses first 16 bits */
313332extern uint16_t uint4x32_to_half_uniform (uint4x32_t x ) {
314- /* Simplified conversion - proper fp16 would need more complex handling */
315- /* This creates a uniform distribution in [0, 1) for fp16 */
316- uint16_t raw = (uint16_t )(x .v [0 ] & 0xFFFF );
317- /* Map to [0, 1) range in fp16 format */
318- /* Sign bit = 0, exponent between -14 and 0, mantissa from raw bits */
319- if (raw == 0 ) return 0 ;
320-
321- /* Find the highest set bit */
322- int shift = 0 ;
323- uint16_t temp = raw ;
324- while (temp >>= 1 ) shift ++ ;
325-
326- /* Normalize mantissa */
327- uint16_t mantissa = (raw << (10 - shift )) & 0x3FF ;
328- /* Exponent for values in [0, 1) */
329- uint16_t exponent = 14 - shift ;
330-
331- return (exponent << 10 ) | mantissa ;
333+ /* Convert through float for consistent behavior */
334+ float f = (x .v [0 ] & 0xFFFF ) * (1.0f / 65536.0f );
335+ return FLOAT_TO_HALF (f );
332336}
333337
334338/* Uint4x32 to fp8 uniform - uses first 8 bits */
@@ -407,8 +411,9 @@ extern uint16x8_t uint4x32_to_bfloat16_uniform_vec(uint4x32_t x) {
407411 uint32_t bits1 , bits2 ;
408412 memcpy (& bits1 , & f1 , sizeof (float ));
409413 memcpy (& bits2 , & f2 , sizeof (float ));
410- result .v [i * 2 + 0 ] = (uint16_t )(bits1 >> 16 );
411- result .v [i * 2 + 1 ] = (uint16_t )(bits2 >> 16 );
414+ // Round to nearest bfloat16 by adding half an output ULP (ties round away from zero, not to even)
415+ result .v [i * 2 + 0 ] = (uint16_t )((bits1 + 0x8000 ) >> 16 );
416+ result .v [i * 2 + 1 ] = (uint16_t )((bits2 + 0x8000 ) >> 16 );
412417 }
413418 return result ;
414419}
0 commit comments