In [36]:
import math
import numpy as np
import torch

from dahuffman import HuffmanCodec

class ECUQCompressor():
    """Entropy-constrained uniform quantizer whose indices are Huffman coded.

    A tensor is mapped onto K equally spaced levels between its minimum and
    maximum.  K starts at ``2 ** bit`` and, when the empirical entropy of the
    resulting indices is more than ``tolerance`` below ``bit``, a search looks
    for a larger K whose entropy lies in ``[bit - tolerance, bit]``.
    """

    def __init__(self, device, quantum_num, tolerance=0.1, max_levels=2**20):
        """
        Args:
            device: device on which ``decompress`` materialises its result.
            quantum_num: number of quanta; the per-element bit budget is
                ``ceil(log2(quantum_num + 1))``.
            tolerance: accepted entropy shortfall (in bits) below the budget.
            max_levels: upper bound on the number of levels the search may
                try, so it terminates even when the target entropy cannot be
                reached.  Never smaller than ``2 ** bit``.
        """
        self.device = device
        self.quantum_num = quantum_num
        self.bit = math.ceil(np.log2(quantum_num + 1))
        self.tolerance = tolerance
        self.max_levels = max(int(max_levels), 2 ** self.bit)

    @staticmethod
    def _levels(x_min, delta, K, device):
        """Return the K bin centres ``x_min + (k + 0.5) * delta``, k = 0..K-1."""
        return x_min + (torch.linspace(0, K - 1, K, device=device) + 0.5) * delta

    @staticmethod
    def _entropy(p_Q):
        """Shannon entropy (bits) of a density; empty bins contribute zero."""
        occupied = p_Q[p_Q > 0]
        return -torch.sum(occupied * torch.log2(occupied))

    def quantize(self, x, Q):
        """Return, for every element of ``x``, the index of its nearest level.

        ``Q`` must be a 1-D tensor sorted in non-decreasing order.  The lookup
        is a binary search against the midpoints between neighbouring levels,
        so memory stays proportional to ``x.numel() + Q.numel()`` instead of
        their product.  A value exactly on a midpoint goes to the lower index.
        """
        if Q.numel() < 2:
            return torch.zeros(x.shape, dtype=torch.long, device=x.device)
        # bucketize needs boundaries with the same dtype/device as the input.
        midpoints = ((Q[:-1] + Q[1:]) / 2).to(dtype=x.dtype, device=x.device)
        return torch.bucketize(x, midpoints)

    def double_binary_search_num_quantization_levels(self, x: torch.Tensor):
        """Search for a level count whose index entropy fits the bit budget.

        First grows K as ``2 ** bit + 2 ** p`` (p = 0, 1, ...) until the
        entropy exceeds ``bit`` or K would pass ``max_levels``; then bisects
        between the bounds found so far.

        Args:
            x: 1-D floating point tensor.

        Returns:
            ``(indices, delta)``: the int64 level index of every element and
            the bin width as a 0-dim tensor.  If no K lands inside the
            tolerance window, the last candidate whose entropy did not exceed
            ``bit`` is returned.
        """
        x_max = torch.max(x)
        x_min = torch.min(x)
        span = x_max - x_min
        base = 2 ** self.bit

        if span == 0:
            # A constant tensor has zero entropy for every K; searching would
            # never converge, and a single level already reproduces it.
            return torch.zeros(x.shape, dtype=torch.long, device=x.device), span / base

        low = base
        high = math.inf
        p = -1
        best = None

        while low <= high:
            if high == math.inf:
                p += 1
                mid = base + 2 ** p
                if mid > self.max_levels:
                    # Stop doubling and bisect what is left below the cap.
                    high = self.max_levels
                    continue
            else:
                # Integer midpoint keeps K and delta consistent with each other.
                mid = (low + high) // 2

            K = mid
            delta = span / K
            print(f"low: {low} high: {high} mid: {mid} p: {p} b: {self.bit} K: {K} Delta: {delta}")

            Q = self._levels(x_min, delta, K, x.device)
            x_hat_Q = self.quantize(x, Q)

            p_Q = torch.bincount(x_hat_Q) / x.numel()
            H_p_Q = self._entropy(p_Q)
            print(f"Current entropy: {H_p_Q}")

            if H_p_Q > self.bit:
                high = mid - 1
                continue

            # Within budget: remember it so a failed search still returns a
            # candidate that respects the budget.
            best = (x_hat_Q, delta)
            if H_p_Q < self.bit - self.tolerance:
                low = mid + 1
            else:
                return best

        print("ECUQ cannot find solution within the given bit budget")
        if best is None:
            # Nothing tried was within budget; fall back to the 2**bit grid.
            delta = span / base
            best = (self.quantize(x, self._levels(x_min, delta, base, x.device)), delta)
        return best

    def compress(self, x: torch.Tensor, name=""):
        """Quantize ``x`` and Huffman-encode the level indices.

        Args:
            x: floating point tensor of any shape.
            name: unused; kept for interface compatibility.

        Returns:
            ``(encoded, ctx)`` where ``encoded`` is the Huffman byte string
            and ``ctx = (shape, x_min, delta, codec)`` is what ``decompress``
            needs to rebuild the tensor.
        """
        shape = x.size()
        x = x.flatten()

        x_max = torch.max(x)
        x_min = torch.min(x)
        K = 2 ** self.bit
        delta = (x_max - x_min) / K

        # Start from the fixed 2**bit grid.
        Q = self._levels(x_min, delta, K, x.device)
        x_hat_Q = self.quantize(x, Q)

        # Empirical density of the quantized values.
        p_Q = torch.bincount(x_hat_Q) / x.numel()
        print(f"Current empirical density: {p_Q}")

        H_p_Q = self._entropy(p_Q)
        print(f"Current entropy: {H_p_Q}")

        # Entropy well below the budget means a finer grid is affordable.
        if H_p_Q < self.bit - self.tolerance:
            x_hat_Q, delta = self.double_binary_search_num_quantization_levels(x)

        # Plain Python ints for the codec; symbols that never occur are skipped.
        bin_count = torch.bincount(x_hat_Q).tolist()
        freq_table = {symbol: count for symbol, count in enumerate(bin_count) if count > 0}
        codec = HuffmanCodec.from_frequencies(freq_table)
        encoded = codec.encode(x_hat_Q.tolist())
        # codec.print_code_table()
        print(f"Encoded size bits {len(encoded) * 8} bit budget {x_hat_Q.numel() * self.bit} ratio {len(encoded) * 8 / (x_hat_Q.numel() * self.bit)}")
        return encoded, (shape, x_min, delta, codec)

    def decompress(self, tensor_compressed: bytes, ctx):
        """Rebuild a tensor from ``compress`` output.

        Each decoded index k becomes the bin centre ``x_min + (k + 0.5) * delta``
        and the result is reshaped to the original shape on ``self.device``.
        """
        shape, x_min, delta, codec = ctx
        x = torch.tensor(codec.decode(tensor_compressed), device=self.device)

        # The context scalars may live on another device than self.device.
        x_min = torch.as_tensor(x_min, device=self.device)
        delta = torch.as_tensor(delta, device=self.device)

        tensor_decompressed = x_min + (x + 0.5) * delta
        tensor_decompressed = tensor_decompressed.view(shape)
        return tensor_decompressed

    def calculate_size(self, numel):
        """Return (and cache in ``_compressed_size``) a size estimate in bits.

        The estimate is ``32 + numel * (1 + ceil(log2(quantum_num + 1)))``.
        NOTE(review): the extra bit per element and the 32-bit constant are not
        explained by the visible code -- confirm what they account for.
        """
        res_bits = numel * (1 + np.ceil(np.log2(self.quantum_num + 1)))
        self._compressed_size  = 32 + res_bits # in bits
        return self._compressed_size
    
    


In [37]:
def find_NMSE(x1, x2):
    """Return the normalized mean squared error of ``x2`` relative to ``x1``.

    Computed as ``sum((x1 - x2) ** 2) / sum(x1 ** 2)``, i.e. the squared error
    scaled by the energy of the reference tensor ``x1``.  The result is a
    0-dim tensor; it is ``nan``/``inf`` when ``x1`` is all zeros.
    """
    error = x1 - x2
    return torch.sum(error * error) / torch.sum(x1 * x1)

In [38]:
numel = 1000
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNG so that Restart & Run All reproduces the same sample.
SEED = 0
torch.manual_seed(SEED)

# Alternative test inputs, kept for experimentation:
# x = torch.randint(0, 100, [5, 5], dtype=torch.float, device=device)
# x[0,0] = 0
# x[0,1] = 100

# m = torch.distributions.Dirichlet(torch.tensor([0.5, 0.5], device=device))
# x = m.sample([numel])[:,0]

# m = torch.distributions.Normal(0, 1)
# x = m.sample([numel]).to(device=device)

# m = torch.distributions.Normal(0, 0.1)
# x = m.sample([numel]).to(device=device)

# Heavy-tailed input: a numel x numel matrix of LogNormal(0, 1) samples.
m = torch.distributions.LogNormal(0, 1)
x = m.sample([numel, numel]).to(device=device)

x.shape

torch.Size([1000, 1000])

In [39]:
# Round-trip the sample tensor through ECUQ with quantum_num=4.
compressor = ECUQCompressor(device=device, quantum_num=4)
print(f"Bit: {compressor.bit}")

x_compressed, ctx = compressor.compress(x)
x_decompressed = compressor.decompress(x_compressed, ctx)

# The reconstruction must come back with the original shape.
assert x_decompressed.shape == x.shape

# Original, reconstruction and their error, one per line.
print(x, x_decompressed, find_NMSE(x, x_decompressed), sep="\n")


Bit: 3
Current empirical density: tensor([9.9430e-01, 5.0570e-03, 5.0400e-04, 1.0200e-04, 1.8000e-05, 1.4000e-05,
        6.0000e-06, 3.0000e-06], device='cuda:0')
Current entropy: 0.0543198436498642
low: 8 high: inf mid: 9 p: 0 b: 3 K: 9 Delta: 11.056159019470215
Current entropy: 0.07240406423807144
low: 10 high: inf mid: 10 p: 1 b: 3 K: 10 Delta: 9.950543403625488
Current entropy: 0.09232079237699509
low: 11 high: inf mid: 12 p: 2 b: 3 K: 12 Delta: 8.292119979858398
Current entropy: 0.1373068243265152
low: 13 high: inf mid: 16 p: 3 b: 3 K: 16 Delta: 6.219089508056641
Current entropy: 0.24158763885498047
low: 17 high: inf mid: 24 p: 4 b: 3 K: 24 Delta: 4.146059989929199
Current entropy: 0.4752896726131439
low: 25 high: inf mid: 40 p: 5 b: 3 K: 40 Delta: 2.487635850906372
Current entropy: 0.9372392892837524
low: 41 high: inf mid: 72 p: 6 b: 3 K: 72 Delta: 1.3820198774337769
Current entropy: 1.6614880561828613
low: 73 high: inf mid: 136 p: 7 b: 3 K: 136 Delta: 0.7316575646400452
Current

In [41]:
import sys

sys.path.insert(1, "/home/qfyan/FedScale")

from fedscale.utils.compressor.lfl import LFLCompressor
from fedscale.utils.compressor.qsgd import QSGDCompressor
from fedscale.utils.compressor.qsgd_bucket import QSGDBucketCompressor
from fedscale.utils.compressor.eden_wrapper import EDENCompressor

def _roundtrip_and_report(compressor, tensor):
    """Compress and decompress ``tensor``, then print the result and its error.

    Prints the reconstruction followed by ``find_NMSE(tensor, reconstruction)``
    and returns ``(compressed, ctx, reconstruction)``.
    """
    compressed, ctx = compressor.compress(tensor)
    restored = compressor.decompress(compressed, ctx)
    print(restored)
    print(find_NMSE(tensor, restored))
    return compressed, ctx, restored


print(x)

# NOTE(review): the second EDENCompressor argument is presumably the device
# (it was the literal "cuda"); it now follows `device` -- confirm.
c2 = EDENCompressor(3, device)
x_c2, ctx_c2, x_c2_d = _roundtrip_and_report(c2, x)

c3 = QSGDBucketCompressor(4)
x_c3, ctx_c3, x_c3_d = _roundtrip_and_report(c3, x)

c5 = QSGDCompressor(4)
x_c5, ctx_c5, x_c5_d = _roundtrip_and_report(c5, x)

c4 = LFLCompressor(4)
x_c4, ctx_c4, x_c4_d = _roundtrip_and_report(c4, x)


tensor([[0.4555, 0.7558, 0.4367,  ..., 2.3310, 3.3497, 0.4724],
        [0.4093, 4.7010, 0.8158,  ..., 0.7639, 0.1884, 1.7461],
        [1.5182, 0.5217, 0.5123,  ..., 8.7076, 0.6022, 0.6836],
        ...,
        [2.7874, 1.6192, 0.0992,  ..., 0.6495, 0.8002, 1.9223],
        [1.8472, 0.7047, 0.1636,  ..., 1.2500, 0.1844, 3.3420],
        [0.9197, 1.0456, 0.1886,  ..., 0.2873, 3.7068, 0.2299]],
       device='cuda:0')
2
tensor([[ 1.0826e+00,  1.7376e+00,  7.7652e-01,  ...,  2.1328e+00,
          2.1752e+00,  1.1264e+00],
        [ 2.3750e-01,  4.4754e+00,  9.4954e-01,  ...,  5.9449e-01,
          1.4436e+00,  2.6600e+00],
        [ 2.7480e+00,  9.4431e-02,  1.0424e+00,  ...,  8.3671e+00,
          1.0610e+00, -9.5120e-01],
        ...,
        [ 1.3314e+00,  9.1829e-01,  1.0100e-03,  ...,  1.2476e+00,
          2.5730e-01,  3.3508e+00],
        [ 2.3726e+00,  1.7598e+00,  1.1429e+00,  ...,  1.8251e+00,
         -6.2444e-01,  3.5601e+00],
        [ 8.5903e-01, -2.5779e-01,  1.7625e-01, 

In [53]:
# Release cached GPU memory. Comparing the parsed device type also matches
# values such as "cuda:0" or a torch.device, which `device == "cuda"` misses.
if torch.device(device).type == "cuda":
    torch.cuda.empty_cache()