In [1]:
import numpy as np

In [289]:
# for fun, extract individual components of fp16
# can view with hex() or bin()
def extract_bits_f16(x: np.float16):
    """Split a half-precision float into its sign, exponent and mantissa fields.

    Args:
        x: The value to decompose. Its 16 bits are reinterpreted (not
            converted) as an unsigned integer, so it must be a ``np.float16``.

    Returns:
        A ``(sign, exponent, mantissa)`` tuple of plain Python ints:
        ``sign`` is the top bit, ``exponent`` is the 5-bit exponent field
        minus the bias of 15, and ``mantissa`` is the low 10 bits exactly as
        stored (the implicit leading 1 of normal numbers is not added).
        Zero, subnormal, inf and NaN encodings get no special treatment.
    """
    # Reinterpret the argument's own bits. Converting to a Python int keeps
    # the arithmetic below free of fixed-width wraparound, so the unbiased
    # exponent can legitimately be negative (values below 1.0).
    raw = int(x.view(np.uint16))
    # highest bit
    sign = raw >> (5 + 10)
    # next 5 bits, with the exponent bias removed
    exponent = ((raw >> 10) & 0x1F) - 15
    # lowest 10 bits
    mantissa = raw & 0x3FF
    return (sign, exponent, mantissa)

b = np.float16(5.0)
sign, exponent, mantissa = extract_bits_f16(b)
print(sign, exponent, hex(mantissa))

# Reading the printed fields for 5.0:
#   mantissa field = 0x100 = 01_0000_0000 (in binary)
#   A "normal" floating point number has an implicit leading 1 in front of
#   the stored mantissa bits, so the significand is
#       1 + 0*(1/2) + 1*(1/2)^2 = 5/4
#
# Check: 5 = 2^exponent * significand = 2^2 * (5/4)

0 2 0x100


In [291]:
def split_single(x: np.float32):
    """Split an fp32 scalar into a pair of fp16 values ``(hi, lo)``.

    ``hi`` is ``x`` rounded to fp16. ``lo`` is the rounding remainder
    ``x - hi``, multiplied by 2**10 and then rounded to fp16.

    Args:
        x: Value to split; must be exactly of type ``np.float32``.

    Returns:
        Tuple ``(hi, lo)`` of ``np.float16`` scalars.

    Raises:
        AssertionError: If ``x`` is not a ``np.float32``.
    """
    assert type(x) == np.float32

    coarse = x.astype(np.float16)
    # What the fp16 rounding dropped, computed in fp32.
    residual = x - coarse.astype(np.float32)
    # Scale up while still in fp32 so the small residual does not underflow
    # the fp16 exponent range when it is narrowed below.
    scaled = residual * np.float32(2**10)
    return (coarse, scaled.astype(np.float16))

def combine_halves(hi, lo):
    """Recombine a ``(hi, lo)`` pair into one single-precision value.

    Computes ``hi + lo * 2**-10`` in fp32, undoing the 2**10 scaling that
    ``split_single`` applies to the low half.

    Args:
        hi: High part (fp16 or fp32 scalar).
        lo: Low part, scaled by 2**10 (fp16 or fp32 scalar).

    Returns:
        The recombined value as a ``np.float32``.
    """
    # Promote explicitly instead of relying on mixing with a Python float:
    # under NumPy 2 promotion rules (NEP 50) fp16 operands would stay fp16,
    # and the sum would round straight back to ``hi``, discarding ``lo``.
    return np.float32(hi) + np.float32(lo) * np.float32(2**(-10))

In [292]:
# Two fp32 operands, and their product computed directly in fp32.
x1 = np.float32(1.2345678901)
x2 = np.float32(-23000.456789012)
x3 = x1 * x2

In [296]:
# Split both operands into fp16 halves.
hi_1, lo_1 = split_single(x1)
hi_2, lo_2 = split_single(x2)

# Recombine each pair and print the relative round-trip error.
x1_r = combine_halves(hi_1, lo_1)
x2_r = combine_halves(hi_2, lo_2)
print((x1 - x1_r) / x1)
print((x2 - x2_r) / x2)

0.0
-0.0


In [297]:
# Multiply the two split values, doing every product and sum at f32 precision.
# With x = hi + lo * 2^-10:
#   x1 * x2 = hi_1*hi_2 + 2^-10 * (hi_1*lo_2 + lo_1*hi_2) + 2^-20 * lo_1*lo_2
# The final lo_1*lo_2 term is not included below.
hi_3 = np.float32(hi_1) * hi_2
lo_3 = np.float32(hi_1) * lo_2 + np.float32(lo_1) * hi_2
x3_r = combine_halves(hi_3, lo_3)
# Relative error of the recombined product x3_r against the direct f32 product x3.
print((x3 - x3_r) / x3)

-0.0
