In [1]:
import numpy as np
from scipy.stats import bernoulli
%matplotlib inline
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter("ignore")

In [2]:
def Bi(xi, k):
    """Generate k evenly spaced quantization levels spanning the range of xi.

    Args:
        xi (ndarray): parameters to quantize. shape=(d,).
        k (int): no. of quantization levels.

    Returns:
        ndarray: quantization levels running from np.min(xi) to np.max(xi)
                 inclusive. shape=(k,).
    """

    # Same levels as the explicit formula
    #     min + r * (max - min) / (k - 1),   r = 0 .. k-1
    # but np.linspace pins the last level to np.max(xi) exactly (no rounding
    # drift) and, for k == 1, returns the single level np.min(xi) instead of
    # dividing by zero and producing NaN.
    return np.linspace(np.min(xi), np.max(xi), num=k)

In [3]:
def encoding(xi, k, iters, d, random_state=None):
    """Stochastically encode xi against k evenly spaced quantization levels.

    Every coordinate of xi is bracketed by two neighbouring levels from
    Bi(xi, k). A Bernoulli draw with success probability
    (xi - left) / (right - left) then yields 1 (right point) or 0 (left
    point), repeated independently for each iteration.

    Args:
        xi (ndarray): parameters to quantize. shape=(d,).
        k (int): no. of quantization levels.
        iters (int): no. of iterations (independent encodings of xi).
        d (int): dimensions of the array (length of xi); used as the second
            axis of the sample shape.
        random_state (optional): forwarded to bernoulli.rvs so the draws can
            be made reproducible. Defaults to None, which keeps the previous
            behaviour of drawing from the global NumPy random state.

    Returns:
        ndarray: an array of 1s and 0s with shape (iters, d).
        ndarray: an array of left and right quantizing points of each xi.
                 shape=(d, 2)
    """
    bi = Bi(xi, k)  # array of quantization levels. shape=(k,)

    # np.searchsorted(..., side='right') - 1 gives, for each xi, the index of
    # the left point of the quantizing interval that contains it.
    br_ids = np.searchsorted(bi, xi, side='right') - 1

    # pair every left index with its right neighbour: (left, right).
    brs = np.vstack((br_ids, br_ids + 1)).T

    # for np.max(xi) the left index is already the last level, so left + 1
    # would run past the end of bi; use the last level for both points.
    brs[brs == k] = k - 1

    # replace indices with the quantization values themselves.
    brs = bi[brs]

    # probability of choosing the right point. Where both points coincide
    # (zero-width interval, e.g. np.max(xi)) the probability is set to 0
    # without ever evaluating the 0/0 division.
    offset = xi - brs[:, 0]
    width = brs[:, 1] - brs[:, 0]
    probs = np.divide(
        offset, width,
        out=np.zeros_like(offset, dtype=float),
        where=width != 0)

    # one row of 1s and 0s per iteration, drawn with the probabilities above.
    encs = bernoulli.rvs(probs, size=(iters, d), random_state=random_state)
    return encs, brs


def decoding(brs, encs,):
    """Map encoded bits back to quantization values.

    Args:
        brs (ndarray): left and right quantizing points of each coordinate
                        (second output of the encoding function).
                        shape = (d, 2)
        encs (ndarray): an array of 0s and 1s (first output of the encoding
                        function). shape = (iters, d)

    Returns:
        ndarray: a decoded array. shape=(iters, d)
    """
    lower = brs[:, 0]  # B(r)   for every coordinate
    upper = brs[:, 1]  # B(r+1) for every coordinate

    # In every iteration a 1 selects the coordinate's upper point and a 0 its
    # lower point. Where both points coincide (as encoding arranges for
    # np.max(xi)) the decoded value is that level regardless of the bit.
    decoded = np.where(encs, upper, lower)
    return decoded

In [4]:
def sto_k(k, n=8, d=128, ITERS=1024, seed=None):
    """Stochastic K-level Quantization

    Draws one uniform random vector per client, encodes and decodes each of
    them ITERS times, averages the decoded vectors over the clients and
    compares the resulting mean squared error with the upper bound. Both
    values and the outcome of the comparison are printed.

    Args:
        k (int): no. of Quantization levels. Must be at least 2.
        n (int, optional): no. of clients. Defaults to 8.
        d (int, optional): length of array for each user. Defaults to 128.
        ITERS (int, optional): no. of iterations. Defaults to 1024.
        seed (int, optional): if given, seeds the global NumPy random state
            before any sampling so the run can be repeated. Defaults to
            None, which leaves the random state untouched.

    Returns:
        float: Mean Squared Error.

    Raises:
        ValueError: if k is smaller than 2 (the bound divides by k - 1).
    """
    if k < 2:
        raise ValueError("k must be at least 2 quantization levels")

    if seed is not None:
        np.random.seed(seed)

    # considering uniform distribution
    x = np.random.rand(n, d)

    # running sum of the decoded vectors of all clients. shape=(ITERS, d)
    total = np.zeros((ITERS, d))
    for xi in x:  # for each client:
        # xi.shape = (d,)
        encs, brs = encoding(xi=xi, k=k, iters=ITERS, d=d)
        total += decoding(brs=brs, encs=encs)

    # dividing by number of clients (in order to take their mean).
    x_hat_mean = total / n  # shape = (ITERS, d)

    # true mean of the client vectors along the client axis. shape=(d,)
    x_mean = np.mean(x, axis=0)

    # expected error: (using sec-1.2 in DME)
    # squared norm of the estimation error per iteration, shape=(ITERS,),
    # then averaged over the iterations into a scalar.
    err = np.mean(np.linalg.norm((x_hat_mean - x_mean), axis=1)**2)

    # lemma-5 in DME
    # d / (2 * n^2 * (k-1)^2) times the sum of squared norms of the clients.
    bnd = ((0.5 * d)/(n*(k-1))**2)*np.sum(np.linalg.norm(x, axis=1)**2)

    print("error:", err)
    print("up bound:", bnd)
    if err <= bnd:
        print("relation holds")
    else:
        print("FAILURE")
    return err
    

In [5]:
# k=2: the only levels are the min and max of each client's vector
# (default n, d and ITERS).
sto_k(k=2)

error: 2.5096941193824365
up bound: 349.289573033618
relation holds


2.5096941193824365

In [6]:
# k=4 quantization levels, default n, d and ITERS.
sto_k(k=4)

error: 0.2868429965841466
up bound: 39.14730825654012
relation holds


0.2868429965841466

In [7]:
# k=32 quantization levels, default n, d and ITERS.
sto_k(k=32)

error: 0.0025835141028086755
up bound: 0.35351666208062493
relation holds


0.0025835141028086755

In [15]:
# k=40 quantization levels, default n, d and ITERS.
sto_k(k=40)

error: 0.0016546350852397397
up bound: 0.21961607866470131
relation holds


0.0016546350852397397

In [16]:
# k=64 quantization levels, default n, d and ITERS.
sto_k(k=64)

error: 0.0006456990149832766
up bound: 0.09060712564285647
relation holds


0.0006456990149832766