<a href="https://colab.research.google.com/github/P3109/Public/blob/main/Value%20Tables/make-value-tables.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Value tables of 8-bit floating point formats

This code allows the reader to experiment with properties of FP8 formats, and is produced in conjunction with the work of the IEEE P3109 working group on
floating point formats for machine learning, although it is *not* an official output of that group.

The public outputs of the group are available at https://github.com/P3109/Public and the interim report is at [PDF](https://github.com/P3109/Public/blob/main/Shared%20Reports/P3109%20WG%20Interim%20report.pdf)

The code here is low quality experimental code intended to allow quick experimentation across a range of formats.

Higher quality code is at https://github.com/awf/ml_dtypes_p3109/pull/1.

In [None]:
%pip install airium

In [None]:
from enum import Enum
from dataclasses import dataclass
import numpy as np
import airium
import pandas
from IPython.display import HTML

# Throw on overflow: numpy raises FloatingPointError instead of only warning,
# so accidental overflows surface immediately.  Code that expects overflow
# (e.g. conversions to float16) opts out locally with np.errstate(over=...).
np.seterr(over='raise')

def fstr(v):
  """
  Render a float as a string, spelling specials as 'NaN', '+Inf' or '-Inf'.

  Finite values use str(); anything that is neither finite, NaN nor an
  infinity raises ValueError.
  """
  if np.isfinite(v):
    return str(v)
  if np.isnan(v):
    return 'NaN'
  if v in (np.inf, -np.inf):
    return '+Inf' if v > 0 else '-Inf'
  raise ValueError(f"Bad {v=}")

## FormatInfo

A small dataclass holding format information.  


In [None]:
@dataclass
class FormatInfo:
  """
  Static description of an 8-bit floating-point format.

  Instances are also constructed positionally (see format_info_p3109), so
  the field order below is part of the interface.
  """
  name: str             # Short identifier, e.g. 'p3109_p3' or 'ocp_e5m2'
  precision: int        # Significand bits, including the implicit leading bit
  emax: int             # Maximum exponent of finite values
  significandBits: int  # Explicitly stored (trailing) significand bits
  expBits: int          # Width of the exponent field in bits
  expBias: int          # Bias subtracted from the raw exponent field
  all_bits_one_full: bool  # Set if all-bits-one exponent is all nonfinite
  has_nz: bool  # Set if format has negative 0. If false, assume 0x80 is NaN
  #               following e.g. LLVM Float8E5M2FNUZ
  has_infs : bool # Set if format has +/- infinity


def format_info_p3109(precision) -> FormatInfo:
  """
  Build the FormatInfo for the P3109 8-bit format of the given precision.

  `precision` counts the implicit leading bit, so the encoding has
  precision-1 trailing significand bits and 8-precision exponent bits.
  These formats have +/-Inf and no negative zero (0x80 decodes as NaN).
  """
  exp_bits = 8 - precision
  sig_bits = precision - 1

  # emax = fix(2**(7-p) - 1); int() truncates toward zero exactly like np.fix
  emax = int(2 ** (7 - precision) - 1)

  # The all-bits-one exponent is entirely special only when precision <= 1.
  all_ones_full = precision <= 1

  # IEEE 754-style bias emax + 1, except when the all-ones exponent is fully
  # special or there is no exponent field at all, where the bias is emax.
  bias = emax + 1 if exp_bits > 0 and not all_ones_full else emax

  return FormatInfo(
    name=f'p3109_p{precision}',
    precision=precision,
    emax=emax,
    significandBits=sig_bits,
    expBits=exp_bits,
    expBias=bias,
    all_bits_one_full=all_ones_full,
    has_nz=False,
    has_infs=True,
  )

# OCP FP8 formats, described with the same fields as the P3109 formats.
# e5m2 reserves its whole all-ones exponent for Inf/NaN; e4m3 has no
# infinities, so only its all-ones codes 0x7f/0xff decode as NaN.
format_info_ocp_e5m2 = FormatInfo(
  name='ocp_e5m2',
  precision=3,
  emax=15,
  significandBits=2,
  expBits=5,
  expBias=15,
  all_bits_one_full=True,
  has_nz=True,
  has_infs=True,
)
format_info_ocp_e4m3 = FormatInfo(
  name='ocp_e4m3',
  precision=4,
  emax=8,
  significandBits=3,
  expBits=4,
  expBias=7,
  all_bits_one_full=False,
  has_nz=True,
  has_infs=False,
)

other_formats = [format_info_ocp_e5m2, format_info_ocp_e4m3]
p3109_formats = [format_info_p3109(p) for p in range(1, 9)]
p3109_and_ocp_formats = [*p3109_formats, *other_formats]

# Summary of every format, ordered by precision, plus a legend for the less obvious columns
display(pandas.DataFrame(p3109_and_ocp_formats).sort_values(by='precision'))
print('precision: Number of significand bits (including implicit leading bit encoded by exponent>0)')
print('has_nz: Has negative zero (following e.g. LLVM Float8E5M2FNUZ)')
print('all_bits_one_full: An all-bits-one exponent corresponds to 2**(p-1) codes.',
      'If all of those codes are used for specials, e.g. Inf/NaN, this field is True.')

## FloatValue

A floating-point value in detail: bit fields, classification, printing


In [None]:
class FloatClass(Enum):
  """Classification of a decoded floating-point code point."""
  NORMAL = 1     # Nonzero finite value with a nonzero exponent field, either sign
  SUBNORMAL = 2  # Nonzero finite value with a zero exponent field, either sign
  ZERO = 3       # Positive or negative zero
  INFINITE = 4   # Positive or negative infinity
  NAN = 5        # Not-a-Number


@dataclass
class FloatValue:
  """
  A floating-point value in detail: bit fields, classification and
  printable forms of one decoded code point (as produced by decode_f8).
  """
  ival: int            # Integer code point
  fval: float          # Value, with specials as +/-inf or nan [Note 1]
  valstr: str          # Value as string; '~'-prefixed if the decimal is inexact, str(fval) for specials
  exp: int             # Raw exponent without bias
  expval: int          # Exponent, bias subtracted
  significand: int     # Significand (mantissa) as an integer
  fsignificand: float  # Significand as a float in the range [0,2)
  signbit: int         # Sign bit: 1 => negative, 0 => positive
  signstr: str         # String representation of sign
  fclass: FloatClass   # Classification, see enum FloatClass
  fi: FormatInfo       # Backlink to FormatInfo

  # [Note 1]
  # Values are assumed to be exactly round-trippable to python float == float64.
  # This is true for all <64bit formats known in 2023.

def decode_f8(i : int, fi : FormatInfo) -> FloatValue:
  """
  Decode the 8-bit code point `i` in format `fi` into a FloatValue.

  The sign/exponent/significand fields are extracted and the value is first
  computed as if every code point were finite.  The format's special
  encodings are then applied on top, overriding fclass, fval and valstr:
    - all_bits_one_full formats with more than one trailing significand bit
      (e.g. ocp_e5m2): the all-ones exponent is +/-Inf when the significand
      is 0 and the format has infinities, and NaN otherwise;
    - otherwise 0x7f/0xff are +Inf/-Inf if fi.has_infs, else NaN;
    - 0x80 is -0 if fi.has_nz, else NaN.

  `i` is expected to be a code point in range(256).
  """
  # Bit 7 is the sign bit
  signbit = 1 if i & 0x80 else 0
  sign = -1 if signbit else 1
  signstr = '-' if sign == -1 else '+'

  # Below the sign bit: the exponent field, then the low significandBits bits
  exp = (i & 0x7f) >> fi.significandBits
  significand = i & ((1 << fi.significandBits) - 1)

  # A zero exponent field encodes zero/subnormals: no implicit leading 1,
  # and the exponent is pinned at 1 - bias (the same as the smallest normal)
  isnormal = exp != 0
  if isnormal:
    expval = exp-fi.expBias
    fsignificand = 1.0 + significand * 2 ** -fi.significandBits
  else:
    expval = 1-fi.expBias
    fsignificand = significand * 2 ** -fi.significandBits

  # val: the raw value excluding specials
  val = sign * fsignificand * 2 ** expval

  # valstr: string representation of value in base 10
  # If the representation does not roundtrip to the value,
  # it is preceded by a "~" to indicate "approximately equal to"
  valstr = f'{val}'
  if len(valstr) > 14:
    valstr = f'{val:.8}'
  if float(valstr) != val:
    valstr = '~'+valstr

  # Classification ignoring specials (always overwritten below)
  fclass = None
  if val == 0:
    fclass = FloatClass.ZERO
  elif isnormal:
    fclass = FloatClass.NORMAL
  else:
    fclass = FloatClass.SUBNORMAL
  
  fval = val
  # NOTE(review): formats with 0 or 1 trailing significand bits (e.g. p3109_p1,
  # whose all_bits_one_full is set) skip this branch and use the 0x7f/0xff rules.
  if fi.all_bits_one_full and fi.significandBits > 1:
    # all_bits_one exponent has NaNs where it doesn't have infs
    if expval == fi.emax + 1:
      if fi.has_infs and significand == 0:
        fclass = FloatClass.INFINITE
        fval = -np.inf if signbit else np.inf
      else:
        fclass = FloatClass.NAN
        fval = np.nan
  elif fi.has_infs:
    # Only the two all-ones codes are special: 0xff is -Inf, 0x7f is +Inf
    if i == 0xff:
      fclass = FloatClass.INFINITE
      fval = -np.inf
    elif i == 0x7f:
      fclass = FloatClass.INFINITE
      fval = np.inf
  else:
    # No infinities (e.g. ocp_e4m3): 0x7f and 0xff are NaN
    if i == 0xff:
      fclass = FloatClass.NAN
      fval = np.nan
    elif i == 0x7f:
      fclass = FloatClass.NAN
      fval = np.nan

  # 0x80 (sign bit only) is -0 if the format has a negative zero,
  # otherwise it is repurposed as NaN
  if i == 0x80:
    if fi.has_nz:
      fval = -0.0
    else:
      fclass = FloatClass.NAN
      fval = np.nan
  
  # update valstr if a special value (str(fval) gives 'inf', '-inf' or 'nan')
  if fclass not in (FloatClass.ZERO, FloatClass.SUBNORMAL, FloatClass.NORMAL):
    valstr = str(fval)

  return FloatValue(i,fval,valstr,
                    exp,expval,significand,fsignificand,signbit,signstr,
                    fclass,fi)

# Print a few sample decodings for each of two formats.
for fi in (format_info_p3109(precision=3), format_info_ocp_e5m2):
  print(fi.name)
  for ival in (0x00, 0x01, 0x40, 0x80, 0x7e, 0x7f):
    # Decode with the format named in the heading just printed.
    print(decode_f8(ival, fi))

# Spot-check decoding against hand-computed values.  dec() and fclass()
# look up the module-level `fi` when called, so rebinding `fi` selects the
# format under test.
def dec(ival):
  return decode_f8(ival, fi).fval

def fclass(ival):
  return decode_f8(ival, fi).fclass

# p3109_p3: two trailing significand bits, exponent bias 16
fi = format_info_p3109(3)
assert dec(0x01) == 2.0 ** -17                  # smallest subnormal
assert dec(0x40) == 1.0
assert np.isnan(dec(0x80))                      # no negative zero: 0x80 is NaN
assert dec(0xff) == -np.inf
assert np.floor(np.log2(dec(0x7e))) == fi.emax  # largest finite value

# ocp_e5m2: IEEE-like, with -0 and a full Inf/NaN exponent
fi = format_info_ocp_e5m2
assert dec(0x01) == 2.0 ** -16
assert dec(0x40) == 2.0
assert dec(0x80) == 0.0 and np.signbit(dec(0x80))
assert dec(0xfc) == -np.inf
assert np.isnan(dec(0x7f))
assert dec(0x7c) == np.inf
assert np.floor(np.log2(dec(0x7b))) == fi.emax
assert fclass(0x80) == FloatClass.ZERO
assert fclass(0x00) == FloatClass.ZERO

# ocp_e4m3: -0 present, no infinities
fi = format_info_ocp_e4m3
assert dec(0x01) == 2.0 ** -9
assert dec(0x40) == 2.0
assert dec(0x80) == 0.0 and np.signbit(dec(0x80))
assert np.isnan(dec(0x7f))
assert np.floor(np.log2(dec(0x7e))) == fi.emax

## Enumerate all values for a given format

In [None]:
# Decode every code point of p3109_p4 and show them as a table keyed by code point.
fi = format_info_p3109(precision=4)
print(fi)

values = [decode_f8(code, fi) for code in range(256)]

# The FormatInfo backlink is identical on every row, so drop it from the table
pandas.DataFrame(values).drop(columns=["fi"]).set_index('ival')

## Printing

String formatting for binary16 and F8 values

In [None]:
import struct
def b16_str(val) -> tuple[str,str]:
  """
  Represent VAL in binary16 (IEEE half precision).

  Returns a pair (bitstr, pow2str):
    bitstr:  the bit pattern as sign_exponent_significand,
             e.g. '0_00000_1100000000'
    pow2str: the value as a signed binary fraction times a power of two,
             e.g. '+0b0.1100000000*2^-14', or '' for Inf/NaN.

  If val does not convert exactly to binary16 (overflow, underflow or
  rounding), returns ('<Not16:{val}>', '') instead.  NaN is accepted.
  """
  with np.errstate(over="ignore"):
    b16 = np.float16(val)

  # Any inexact conversion is rejected.  NaN never compares equal to
  # itself, so it is let through explicitly.
  if float(b16) != val and not np.isnan(val):
    return f'<Not16:{val}>', ''
  b16_int = struct.unpack('!H', struct.pack('!e', b16))[0]

  # bitstr is of the form 0_00000_1100000000
  s = f'{b16_int:016b}'
  e_str = s[1:6]
  m_str = s[6:]
  bitstr = f'{s[0]}_{e_str}_{m_str}'

  if not np.isfinite(b16):
    return bitstr, ''

  # pow2str is of the form '+0b0.1100000000*2^-14'.  A zero exponent field
  # means subnormal (or zero): no implicit leading 1, and the exponent is
  # 1 - 15 = -14, the same as the smallest normal.
  raw_exp = int(e_str, 2)
  leading_bit = 1 if raw_exp else 0
  e = max(raw_exp, 1) - 15
  m = int(m_str, 2)
  signstr = '-' if s[0] == '1' else '+'
  pow2str = f'{signstr}0b{leading_bit}.{m:010b}*2^{e}'
  return bitstr, pow2str

for v in [3*2**-14, 3*2**-15, 3*2**-16, 3*2**-18]:
  print(b16_str(v))
print(b16_str(-np.inf))
print(b16_str(2 ** 16))
# 3*2^-16 = 0.75 * 2^-14: subnormal with significand bits 11
assert b16_str(3*2**-16) == ('0_00000_1100000000', '+0b0.1100000000*2^-14')
assert b16_str(1.25) == ('0_01111_0100000000', '+0b1.0100000000*2^0')
# Underflows to zero in binary16, so it is not representable
assert b16_str(2.0 ** -30) == (f'<Not16:{2.0 ** -30}>', '')

#### Render with underscores separating s_e_m

E.g. `0_1011_110`.  For formats with zero significand bits or zero exponent bits, we use `0_1011110_` or `0__1011110` respectively.

In [None]:
def str_bits_with_underscores(fv):
  """
  Return fv's bit pattern as 'sign_exponent_significand', e.g. '0_1011_110'.

  Field widths come from fv's own format (fv.fi), not from any global.
  An absent field is left empty: '0_1011110_' when the format has no
  significand bits and '0__1011110' when it has no exponent bits.
  """
  fmt = fv.fi

  # 0_1011110_
  if fmt.significandBits == 0:
    return f'{fv.signbit}_{fv.exp:0{fmt.expBits}b}_'

  # 0__1011110
  if fmt.expBits == 0:
    return f'{fv.signbit}__{fv.significand:0{fmt.significandBits}b}'

  # 0_101_1110
  return f'{fv.signbit}_{fv.exp:0{fmt.expBits}b}_{fv.significand:0{fmt.significandBits}b}'

# Check code point 0x41 against each field layout: both fields present,
# no significand field (p1), and no exponent field (p8).
for fi, expected in ((format_info_p3109(3), '0_10000_01'),
                     (format_info_p3109(1), '0_1000001_'),
                     (format_info_p3109(8), '0__1000001')):
  assert str_bits_with_underscores(decode_f8(0x41, fi)) == expected


In [None]:
def str_tablerow(fv, show_b16_info=True):
  """
  Create a string of the form
    0x41 0_10000_01 = +0b1.01*2^0   = 1.25
  optionally adding binary16 info
    0x41 0_10000_01 = +0b1.01*2^0   = 0_01111_0100000000 +0b1.0100000000*2^0 = 1.25

  Zeros, infinities and NaNs omit the binary and binary16 parts.  Field
  widths come from fv's own format (fv.fi), not from any global.
  """
  fmt = fv.fi
  text = []

  # 0x45 0_1000_101
  text.append(f'0x{fv.ival:02x} {str_bits_with_underscores(fv)}')

  finite_nonzero = np.isfinite(fv.fval) and fv.fval != 0

  #  = +0b1.101*2^-7 =
  if finite_nonzero:
    # Subnormals have no implicit leading 1
    b = '0' if fv.fclass == FloatClass.SUBNORMAL else '1'
    binary_pow2 = f'{fv.signstr}0b{b}.{fv.significand:0{fmt.significandBits}b}*2^{fv.expval:<3}'
    text.append(binary_pow2)

  if show_b16_info and finite_nonzero:
    b16_binary_str, b16_bscistr = b16_str(fv.fval)
    text.append(f'{b16_binary_str} {b16_bscistr}')

  # 1.125
  text.append(fv.valstr)

  # Join the parts into a single string
  return " = ".join(text)

# Print sample table rows for p3109_p3, including binary16 details.
fi = format_info_p3109(3)
sample_codes = (0x00, 0x01, 0x07, 0x21, 0x40, 0x41, 0x7e, 0x7f, 0x80, 0x81, 0xfe, 0xff)
for code in sample_codes:
  row = str_tablerow(decode_f8(code, fi), show_b16_info=True)
  print(row)


## Enumerate all values for a given precision

In [None]:
# Decode all 256 code points of p3109_p4.
fi = format_info_p3109(precision=4)
values = [decode_f8(code, fi) for code in range(256)]

# Column legend, copied from the FloatValue dataclass
field_legend = """
  ival: int           # Integer code point
  fval: float         # Value [Note 1]
  valstr: str         # Value as string, assuming all code points finite
  exp: int            # Raw exponent without bias
  expval: int         # Exponent, bias subtracted
  significand: int    # Significand (mantissa) as an integer
  fsignificand: float # Significand as a float in the range [0,2)
  signbit: int        # Sign bit: 1 => negative, 0 => positive
  signstr: str        # String representation of sign
  fclass: FloatClass  # See FloatClass
"""
print(field_legend)
pandas.DataFrame(values).set_index('ival').drop(columns=['fi'])

In [None]:

import collections
def nt(**kwargs):
  """
  Build an ad-hoc namedtuple (type name 'NT') whose fields are the keyword
  argument names, in call order, populated with the argument values.
  """
  record_type = collections.namedtuple('NT', list(kwargs))
  return record_type(**kwargs)

def collect_stats(fi : FormatInfo):
  """
  Summarize all 256 code points of format `fi` as an `nt` record.

  Fields: name, P (precision), E (exponent bits), M (trailing significand
  bits), lt1 (count of values in (0, 1)), ge1 (count of values in
  [1, +Inf)), rt16 (every finite value round-trips through float16),
  max/minFinite (most positive / most negative finite values), and the
  largest/smallest positive normal and subnormal values (NaN when the
  format has none).
  """
  decoded = pandas.DataFrame([decode_f8(code, fi) for code in range(256)])
  fval = decoded['fval']
  fclasses = decoded['fclass']

  def extremes(vals):
    # (max, min) of a Series, or (nan, nan) when it has no nonzero entry
    if not vals.any():
      return np.nan, np.nan
    return vals.loc[vals.idxmax()], vals.loc[vals.idxmin()]

  # Counts: strictly between 0 and 1, and from 1 (inclusive) up to, excluding, +Inf
  lt1 = fval.between(0, 1, inclusive='neither').sum()
  ge1 = fval.between(1, np.inf, inclusive='left').sum()

  # Extreme finite values (minFinite is the most negative one)
  finite = fval[np.isfinite(fval)]
  maxFinite = finite.loc[finite.idxmax()]
  minFinite = finite.loc[finite.idxmin()]

  # Positive normal / subnormal ranges
  positive = fval > 0
  maxNormal, minNormal = extremes(fval[(fclasses == FloatClass.NORMAL) & positive])
  maxSubnormal, minSubnormal = extremes(fval[(fclasses == FloatClass.SUBNORMAL) & positive])

  # Does every finite value survive a round trip through float16 / float32?
  nonfinite = ~np.isfinite(fval)
  with np.errstate(over='ignore'):
    rt16 = ((np.float64(np.float16(fval)) == np.float64(fval)) | nonfinite).all()
    rt32 = ((np.float64(np.float32(fval)) == np.float64(fval)) | nonfinite).all()
  assert rt32 # If not, we should include rt32 in the table

  return nt(name=fi.name, P=fi.precision, E=fi.expBits, M=fi.significandBits,
            lt1=lt1, ge1=ge1,
            rt16=rt16,
            maxFinite=maxFinite, minFinite=minFinite,
            maxNormal=maxNormal, minNormal=minNormal,
            minSubnormal=minSubnormal, maxSubnormal=maxSubnormal)

stats = [collect_stats(fmt) for fmt in p3109_and_ocp_formats]
df = pandas.DataFrame(stats)
df

## Render values as exact fractions * 2^e


In [None]:
import fractions

def pow2str(v):
  """
  Render v exactly as '(p/q)*2^e', e.g. 57344.0 -> '7/4*2^15'.

  The fraction p/q is the significand normalized to [1, 2), in lowest terms
  ('1' when it is exactly one); negative values get a '-' prefix.  Zero
  renders as '0' and non-finite values as str(v), e.g. 'inf' or 'nan'.
  """
  if not np.isfinite(v):
    return str(v)
  if v == 0:
    # No exponent is meaningful for zero (log2(0) = -inf)
    return '0'

  # frexp is exact: |v| = m * 2**k with m in [0.5, 1).  Rescale m to [1, 2).
  m, k = np.frexp(np.abs(v))
  mant = fractions.Fraction(float(m) * 2)
  e = int(k) - 1
  return ('-' if v < 0 else '') + f'{mant}*2^{e:d}'

# Same statistics, with every extremal value rendered exactly via pow2str
extremal_fields = ("maxFinite", "minFinite", "maxNormal", "minNormal", "minSubnormal", "maxSubnormal")
df_pow2 = df.copy()
for col in extremal_fields:
  df_pow2[col] = df[col].map(pow2str)
df_pow2

## Generate LaTeX table of min/max values

In [None]:
# Just the extremal-value columns, one row per format
df_pow2.loc[:, ["name", "minSubnormal", "maxSubnormal", "minNormal", "maxNormal", "maxFinite"]]


In [None]:
# LaTeX table of extremal values; numbers are rendered exactly as p/q*2^e by pow2str
tbl3 = df[["name","minSubnormal","maxSubnormal","minNormal","maxNormal","maxFinite"]]
header=["Format","minSubnormal","maxSubnormal","minNormal","maxNormal","maxFinite"]
latex = tbl3.to_latex(header=header, index=False, float_format=pow2str)

import re
# p/q*2^e -> $\binaryfloat{p/q}{e}$.  Raw strings: '\d', '\*' and '\_' are
# not valid Python string escapes (SyntaxWarning on Python 3.12+).
latex = re.sub(r'([\d./]+)\*2\^([\d-]+)', r'$\\binaryfloat{\1}{\2}$', latex)
latex = latex.replace('p3109_p', 'p')
latex = latex.replace('ocp_', r'ocp\_')
print('% File: tbl-extremalvalues.tex')
print('% Generated from https://github.com/P3109/Public/blob/main/Value%20Tables/make-value-tables.ipynb')
print(latex)

## Value Tables: LaTeX

In [None]:
def table_style(fv):
  """
  Style name for fv's table cell: 'subnormal' for subnormals, 'normal' for
  normals and +0, and 'special' for everything else (-0, Inf, NaN).
  """
  by_class = {FloatClass.SUBNORMAL: "subnormal", FloatClass.NORMAL: "normal"}
  if fv.fclass in by_class:
    return by_class[fv.fclass]

  is_positive_zero = fv.fclass == FloatClass.ZERO and not fv.signbit
  return "normal" if is_positive_zero else "special"

assert table_style(decode_f8(0x80, format_info_ocp_e5m2)) == "special"
assert table_style(decode_f8(0x01, format_info_ocp_e5m2)) == "subnormal"
assert table_style(decode_f8(0x7f, format_info_ocp_e5m2)) == "special"


def mktbl(fi : FormatInfo, file):
  r"""
  Write the LaTeX rows of the value table for format `fi` to `file`.

  The 256 code points are laid out in `cols` columns of `rows` rows, so
  column c holds codes c*rows .. c*rows+rows-1.  Each row is printed as
  `cols` cells joined by '&' and terminated by '\\'.  Each cell is wrapped
  in the macro named by table_style() (\normal, \subnormal or \special),
  and numbers are rewritten with the \pow, \e, \f and \neg macros, which
  presumably the including LaTeX document defines.
  """
  cols = 4
  rows = 256//cols

  for i in range(0,rows):
    out_row = []
    for n in range(i,256,rows):
      fv = decode_f8(n, fi)
      text = str_tablerow(fv, show_b16_info=False)
      style = table_style(fv)

      # 2^-7  -> \pow{-7}
      text = re.sub(r'\*2\^([-0-9]+)', r'\\pow{\1}', text)
      # 1.234E7  -> \e{1.234}{7}
      text = re.sub(r'([0-9.+-]+)[eE]([0-9]+)', r'\\e{\1}{\2}', text)
      # 1.234E-7  -> \e{1.234}{\neg{7}}
      text = re.sub(r'([0-9.+-]+)[eE]-([0-9]+)', r'\\e{\1}{\\neg{\2}}', text)
      # -1.234 EOL  -> \f{-1.234}
      text = re.sub(r'([0-9.+-]+)$', r'\\f{\1}', text)
      # \f{-1.234} -> \f{\neg{1.234}}, and likewise for \pow
      text = re.sub(r'\\(f|pow)\{-([0-9.]+)\}', r'\\\1{\\neg{\2}}', text)

      # "0x44 " -> "0x44 = "
      text = re.sub(r'^(0x[0-9a-f]+) ', r'\1 = ', text)
      # Escape underscores for LaTeX (raw string: '\_' is not a valid Python escape)
      text = text.replace('_', r'\_')
      # "= ~1.23" (approximate decimal) -> "\approx 1.23"
      text = text.replace('= ~', '\\approx ')
      out_row.append(f'\\{style}{{{text}}}')
    print(*out_row, sep="&\n", end=r'\\'+"\n", file=file)

import pathlib
# Output directory (named so as not to shadow the builtin `dir`)
latex_dir = pathlib.Path("latex")
latex_dir.mkdir(parents=True, exist_ok=True)

for fi in p3109_and_ocp_formats:
  path = latex_dir / f"value-table-{fi.name}.tex"
  print(f'Saving to {path}')
  with open(path, "w", encoding="utf-8") as f:
    print('% Autogenerated from make-value-tables.ipynb at https://github.com/P3109/Public', file=f)
    mktbl(fi, f)


## Generate HTML Value Tables

In [None]:
def mktbl(fi : FormatInfo):
  """
  Return a standalone HTML page (as a string) with the value table for `fi`.

  The 256 code points are laid out in `cols` columns of `rows` rows, column
  c holding codes c*rows .. c*rows+rows-1.  Each cell is a <pre> inside a
  <td> whose CSS class comes from table_style().  Every 16th row (after the
  first) gets the 'blankrow' class, which the stylesheet renders taller to
  group the rows visually.

  Note: this redefines the LaTeX mktbl(fi, file) from an earlier cell.
  """
  # Make tables
  cols = 4
  rows = 256//cols

  # Stylesheet.  Literal braces are doubled because this is an f-string.
  # NOTE(review): 'tiny' is not a valid CSS font-size keyword (valid ones
  # include 'x-small' and 'xx-small'), so browsers ignore that declaration.
  style = f'''
  body {{
  }}
  table {{
    width:100%;
    margin: 0pt;
    font-family: monospace;
    font-size: tiny;
    border-collapse: collapse;
  }}

  tr.blankrow {{
    height: 4ex;
    vertical-align: top;
  }}
  
  td {{
    text-align: left;
    border: solid 2px #ccc;
    width: {98/cols}%;
  }}
  
  .special {{
    color: #874723;
  }}
  
  .subnormal {{
    color: #012187;
  }}
  
  .normal {{
  }}
  
  pre {{
    margin: 1pt 1pt 13pt 13pt;
    display: inline;
  }}
'''

  title = f"FP8 Value Table, {fi.name}"
  a = airium.Airium()
  a('<!DOCTYPE html>')
  with a.html():
    with a.head():
        # a.meta('http-equiv="refresh" content="1"') # for rapid testing
        a.meta(charset="utf-8")
        a.title(_t=title)
        a.style(_t=style)

    with a.body():
        a.h3(_t=title)

        with a.table():
          for i in range(0,rows):
            # Every 16th row starts a visually separated group
            trklass='blankrow' if i > 0 and i%16 == 0 else ''
            with a.tr(klass=trklass):
              for n in range(i,256,rows):
                fv = decode_f8(n, fi)
                text = str_tablerow(fv, show_b16_info=False)
                a.td(klass=table_style(fv)).pre(_t=text)

  return str(a)

import pathlib
# Output directory (named so as not to shadow the builtin `dir`)
html_dir = pathlib.Path("html")
html_dir.mkdir(parents=True, exist_ok=True)

# Index page linking to one standalone value-table page per format.
# Pages declare <meta charset="utf-8">, so they are written as UTF-8.
a = airium.Airium()
a('<!DOCTYPE html>')
a('<!-- Autogenerated from make-value-tables.ipynb -->')
a.head().title(_t='F8 value tables')
with a.body():
  a.h2(_t='F8 value tables')
  with a.ol():
    for fi in p3109_and_ocp_formats:
      html_str = mktbl(fi)
      filename = f"value-table-{fi.name}.html"
      print(f'Saving to {html_dir / filename}')
      a.li().a(href=filename).code(_t=fi.name)
      (html_dir / filename).write_text(html_str, encoding="utf-8")

index_filename = html_dir / 'index.html'
print(f'Saving {index_filename}')
index_filename.write_text(str(a), encoding="utf-8")


In [None]:
# Inline previews of the P3109 value tables, one notebook cell per precision.
HTML(mktbl(format_info_p3109(1)))

In [None]:
HTML(mktbl(format_info_p3109(2)))

In [None]:
HTML(mktbl(format_info_p3109(3)))

In [None]:
HTML(mktbl(format_info_p3109(4)))

In [None]:
HTML(mktbl(format_info_p3109(5)))

In [None]:
HTML(mktbl(format_info_p3109(6)))

In [None]:
HTML(mktbl(format_info_p3109(7)))

In [None]:
HTML(mktbl(format_info_p3109(8)))

## Format tables: OCP formats

In [None]:
# Inline previews of the OCP FP8 value tables.
HTML(mktbl(format_info_ocp_e5m2))

In [None]:
HTML(mktbl(format_info_ocp_e4m3))

## Misc

In [None]:
# Demonstrate creating infinities via "log(0)": taking the log of a random
# boolean mask yields 0 where the mask is True and -inf where it is False.
import torch
random_mask = torch.rand((10,)) > 0.5
torch.log(random_mask)