In [63]:
from __future__ import annotations
import typing
from dataclasses import dataclass
from contextlib import contextmanager
import time

In [64]:
import numpy as np

In [65]:
@contextmanager
def localize_globals(*exceptions: str, restore_values: bool = True):
    exceptions: typing.Set[str] = set(exceptions)

    old_globals: typing.Dict[str, typing.Any] = dict(globals())
    allowed: typing.Set[str] = set(old_globals.keys())
    allowed.update(exceptions)

    yield None

    new_globals: typing.Dict[str, typing.Any] = globals()

    for name in tuple(new_globals.keys()):
        if name not in allowed:
            del new_globals[name]
    
    if not restore_values:
        return
    
    new_globals.update(
        {k: v for k, v in old_globals.items() if k not in exceptions}
    )

In [66]:
class measure_time:
    _start: float
    _end: float
    
    def __init__(self):
        self._start = 0
        self._end = 0
    
    def __enter__(self) -> measure_time:
        self._start = time.perf_counter()
        return self
    
    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self._end = time.perf_counter()
    
    @property
    def elapsed(self) -> float:
        return self._end - self._start
    
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.elapsed:.2e})"

In [67]:
np.random.seed(42)

with localize_globals("dataset_64"):
    n: typing.Final[int] = 10 ** 6
    random_scaled = lambda scale, loc=0, size=n: np.random.normal(loc=loc, size=size)
    
    dataset_64: np.ndarray = np.hstack([
        random_scaled(1e-9),
        random_scaled(1e0),
        random_scaled(1e3),
    ])

In [68]:
np.random.choice(dataset_64, size=10)

array([-0.85033265, -0.66929662, -0.28918369,  0.8509891 , -0.71009343,
       -0.71079159, -0.46595553, -1.74863657, -0.07533348,  0.08785799])

In [69]:
dataset: np.ndarray = dataset_64.astype(np.float16)

In [70]:
np.random.choice(dataset, size=10)

array([ 0.7656 ,  0.586  , -0.696  ,  0.0932 ,  0.0483 , -0.519  ,
       -0.3389 , -1.314  ,  1.54   ,  0.01839], dtype=float16)

In [71]:
true_sum: float = np.sum(dataset_64)
true_sum

-1655.027614386756

In [72]:
def check_summing_method(do_sum: typing.Callable[[np.ndarray], float], name: str = "?") -> None:
    timer = measure_time()
    
    with timer:
        method_sum: float = np.float64(do_sum(dataset))
    
    error: float = np.abs(true_sum - method_sum)
    
    print(f"{name}: {error:.2e} (true sum: {true_sum}, method sum: {method_sum}) (done in {timer.elapsed:f}s)")


In [73]:
def do_sum_linear(arr: np.ndarray) -> float:
    result: float = np.float16(0)
    
    for item in arr:
        result += item
    
    return result

check_summing_method(do_sum_linear, "linear")

linear: 3.39e+02 (true sum: -1655.027614386756, method sum: -1316.0) (done in 0.382734s)


In [74]:
def do_sum_tree(arr: np.ndarray) -> float:
    if len(arr) == 0:
        return np.float16(0)
    
    if len(arr) == 1:
        return arr[0]
    
    return do_sum_tree(arr[::2]) + do_sum_tree(arr[1::2])

check_summing_method(do_sum_tree, "tree")

tree: 2.97e+00 (true sum: -1655.027614386756, method sum: -1658.0) (done in 2.704604s)


In [75]:
def do_kahan_sum(arr: np.ndarray) -> float:
    result: float = np.float16(0)
    compensation: float = np.float16(0)
    
    for item in arr:
        y: float = item - compensation
        t: float = result + y
        compensation = (t - result) - y
        result = t
    
    return result

check_summing_method(do_kahan_sum, "kahan")

kahan: 1.03e+00 (true sum: -1655.027614386756, method sum: -1654.0) (done in 0.842458s)
