In [1]:
import qtorch  
import qtorch.quant

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [11]:
import heapq
import torch

def q(x, man=3):
    return qtorch.quant.float_quantize(x, 8, man, rounding="nearest")

def sorted_add(xs, man):
    xs = xs.clone()

    xs = list(xs)
    # sorted by abs value
    xs.sort(key=lambda x: abs(x))
    acc = 0
    for x in xs:
        acc = q(acc + x, man)
    return acc.item()

def normal_add(xs, man):
    xs = xs.clone()
    xs = list(xs)
    acc = 0
    for x in xs:
        acc = q(acc + x, man)
    return acc.item()

def bitree_add(xs, man):
    xs = xs.clone()
    xs = list(xs)
    while len(xs) > 1:
        new_xs = []
        for i in range(0, len(xs), 2):
            if i + 1 < len(xs):
                new_xs.append(q(xs[i] + xs[i + 1], man))
            else:
                new_xs.append(xs[i])
        xs = new_xs
    return xs[0]


def insert_add(xs, man):
    xs = xs.clone()
    xs = list(xs)
    xs_heap = [(abs(x), x) for x in xs]
    heapq.heapify(xs_heap)

    while len(xs_heap) > 1:
        _, a = heapq.heappop(xs_heap)
        _, b = heapq.heappop(xs_heap)
        r = q(a + b, man)
        heapq.heappush(xs_heap, (abs(r), r))
    return xs_heap[0][1]



In [12]:
x = torch.randn(1000)

In [13]:
print(x.sum())

tensor(38.3111)


In [14]:
normal_add(x, 4)

54.0

In [15]:
sorted_add(x, 4)

38.0

In [16]:
insert_add(x, 10)

tensor(38.3125)

In [17]:
N_EXP = 100
fl_result = {}
for fl in [6, 7, 8, 9, 10]:
    mse_sorted_acc = 0
    mse_normal_acc = 0
    mse_bitree_acc = 0
    mse_insert_acc = 0
    for i in range(N_EXP):
        x = torch.randn(500)
        full_result = x.sum()
        sorted_result = sorted_add(x, fl)
        normal_result = normal_add(x, fl)
        bitree_result = bitree_add(x, fl)
        insert_result = insert_add(x, fl)
        mse_sorted_acc += (full_result - sorted_result) ** 2
        mse_normal_acc += (full_result - normal_result) ** 2
        mse_bitree_acc += (full_result - bitree_result) ** 2
        mse_insert_acc += (full_result - insert_result) ** 2
    fl_result[fl] = {
        "sorted": mse_sorted_acc / N_EXP,
        "normal": mse_normal_acc / N_EXP,
        "bitree": mse_bitree_acc / N_EXP,
        "insert": mse_insert_acc / N_EXP
    }
    
    

In [18]:
fl_result

{6: {'sorted': tensor(0.3556),
  'normal': tensor(1.6778),
  'bitree': tensor(0.1554),
  'insert': tensor(0.0814)},
 7: {'sorted': tensor(0.1553),
  'normal': tensor(0.3856),
  'bitree': tensor(0.0380),
  'insert': tensor(0.0196)},
 8: {'sorted': tensor(0.0385),
  'normal': tensor(0.1051),
  'bitree': tensor(0.0084),
  'insert': tensor(0.0052)},
 9: {'sorted': tensor(0.0065),
  'normal': tensor(0.0242),
  'bitree': tensor(0.0018),
  'insert': tensor(0.0011)},
 10: {'sorted': tensor(0.0023),
  'normal': tensor(0.0047),
  'bitree': tensor(0.0005),
  'insert': tensor(0.0002)}}

In [19]:
for fl, result in fl_result.items():
    print(fl, sum(result)/ len(result))


TypeError: unsupported operand type(s) for +: 'int' and 'str'