In [12]:
from copy import deepcopy

import torch

from main import seed_it
from client import EncodeMaskedSparser, EncodeSparser, EncodeQuantizer
from server import DecodeMaskedSparser, DecodeSparser, DecodeQuantizer

In [13]:
SEED = 42
# Project helper from `main`; presumably seeds torch (and possibly other RNGs)
# so the random tensors drawn below are reproducible — confirm in main.seed_it.
seed_it(SEED)

# Compression error (average L2 reconstruction error per scheme)

In [None]:
# Average L2 reconstruction error of each encode/decode scheme over ITER random
# non-negative tensors of shape SHAPE.
ITER = 100
SHAPE = (256, 64, 32, 32)

QU_encode = EncodeQuantizer(bit=3)
QU_decode = DecodeQuantizer(bit=3)
SP_encode = EncodeSparser(ratio=25 / 32)
SP_decode = DecodeSparser(ratio=25 / 32)
MS_encode = EncodeMaskedSparser(bit=2, ratio=31 / 32)
MS_decode = DecodeMaskedSparser(bit=2, ratio=31 / 32)


def reconstruction_error(encoder, decoder, reference):
    """Round-trip ``reference`` through ``encoder``/``decoder`` and return the L2 error.

    ``encoder.encode`` receives a deep copy of ``reference``. Its outputs are
    unpacked into ``decoder.decode(recon, *payload)``, which is expected to write
    the reconstruction into the zero-initialised ``recon`` in place.

    The error is measured against the untouched ``reference``. This means an
    encoder that modifies its input in place cannot distort the metric.

    Args:
        encoder: object exposing ``encode(tensor)`` returning an iterable payload.
        decoder: object exposing ``decode(out, *payload)``.
        reference: the original tensor; never passed to ``encode`` directly.

    Returns:
        float: ``||recon - reference||_2``.
    """
    # encode() may mutate its argument; keep `reference` pristine for comparison.
    work = deepcopy(reference)
    recon = torch.zeros_like(reference)
    decoder.decode(recon, *encoder.encode(work))
    return (recon - reference).norm(2).item()


c1, c2, c3 = 0, 0, 0
for _ in range(ITER):
    t = torch.relu(torch.randn(SHAPE))
    # Same call order as before (QU, SP, MS) so any RNG use inside encode() is unchanged.
    c1 += reconstruction_error(QU_encode, QU_decode, t)
    c2 += reconstruction_error(SP_encode, SP_decode, t)
    c3 += reconstruction_error(MS_encode, MS_decode, t)

print(f"The avg of quantization error: {c1 / ITER}")
print(f"The avg of sparsification error: {c2 / ITER}")
print(f"The avg of mask sparsification error: {c3 / ITER}")


# Overhead test (wall-clock time of encode and decode per scheme)

In [None]:
# Timing configuration: a larger tensor than the compression-error cell and
# different encoder/decoder settings.
QU_encode, QU_decode = EncodeQuantizer(bit=2), DecodeQuantizer(bit=2)
SP_encode, SP_decode = EncodeSparser(ratio=0.96), DecodeSparser(ratio=0.96)
MS_encode, MS_decode = (
    EncodeMaskedSparser(bit=2, ratio=0.99),
    DecodeMaskedSparser(bit=2, ratio=0.99),
)

OVERHEAD_SHAPE = (256, 64, 64, 64)
t = torch.abs(torch.randn(OVERHEAD_SHAPE))

# One private input copy and one zero-filled output buffer per scheme, so each
# timing cell below works on its own tensors.
# NOTE(review): seven tensors of OVERHEAD_SHAPE are alive at once (t, t1-t3,
# x1-x3); the copies could be dropped if encode() is known not to mutate its input.
t1, t2, t3 = (deepcopy(t) for _ in range(3))
x1, x2, x3 = (torch.zeros_like(src) for src in (t1, t2, t3))

In [None]:
%timeit -r 10 -n 1 tmp = QU_encode.encode(t1)
tmp = QU_encode.encode(t1)
%timeit -r 10 -n 1 QU_decode.decode(x1, *tmp)

In [None]:
%timeit -r 10 -n 1 tmp = SP_encode.encode(t2)
tmp = SP_encode.encode(t2)
%timeit -r 10 -n 1 SP_decode.decode(x2, *tmp)

In [None]:
%timeit -r 10 -n 1 tmp = MS_encode.encode(t3)
tmp = MS_encode.encode(t3)
%timeit -r 10 -n 1 MS_decode.decode(x3, *tmp)