In [1]:
import numpy as np
from scipy.linalg import hadamard
from scipy.stats import bernoulli

In [2]:
def Bi_rot(zi_mins, zi_maxs, k):
    """Generates an array for Quantization for each iteration.

    Args:
        zi_mins (ndarray): np.min(zi) for each iteration. shape=(iters,1)
        zi_maxs (ndarray): np.max(zi) for each iteration. shape=(iters,1)
        k (int): no. of quantization levels.

    Returns:
        ndarray: arrays of quantization levels for each iteration.
    """
    si = zi_maxs - zi_mins # shape=(iters,)
    arr = np.arange(k)*(si/(k-1)) # shape=(iters, k)
    return zi_mins + arr # shape=(iters, k)

def encoder_rot(zi, bi, iters, d, k):
    """Encoding Function

    Args:
        zi (ndarray): preprocessed xi. shape=(iters, d)
        bi (ndarray): an array of quantization. shape=(iters, k)
        iters (int): no. of iterations.
        d (int): length of array with each clients.
        k (int): no. of quantization levels.

    Returns:
        ndarray: an array of 0s and 1s. shape=(iters, d)
    """


    # initialize array for storing indices.
    brs = np.zeros((iters, d, 2))
   
    for i in range(iters):
        # np.searchsorted finds the B(r)s of each zi.
        # i.e, left point of quantizing interval in which each zi belong to.
        ids = np.searchsorted(bi[i], zi[i], side='right')-1

        # stacking them to make them look like points.
        # (left point, right point)
        idid = np.vstack((ids, ids+1)).T

        # as the np.searchsorted outputs the last index of bi for np.max(zi),
        # ids+1 will be out of index.
        # outputing last index of bi as both left and right index, will solve the problem.
        idid[idid==k] = k-1

        # replacing indices with quantization values.
        brs[i] = bi[i][idid]
    
    # finds probabilities of the elements,
    # outputs 0 for np.max(xi).
    probs = np.where(
        ((brs[..., 1] - brs[..., 0]) != 0), 
        np.divide((zi - brs[..., 0]),(brs[..., 1]- brs[..., 0] + (1e-100))), 
        0
    )
    # outputs 1s and 0s based on the probabilities,
    return bernoulli.rvs(probs), brs

def decoder_rot(enc, brs):
    """Decoding Function

    Args:
        enc (ndarray): output of encoding function. shape=(iters, d)
        brs (ndarray): an array of quantization levels. shape=(iters, k)

    Returns:
        ndarray: a decoded array. shape=(iters, d)
    """

    # for each iteration:
        #  replaces 1s with their corresponding B(r+1)s
        #  and 0s with corresponding B(r)s.
    # np.max(xi) prob is 0. so, this function output with last level...
    # value of Bi as the decoded value.
    return np.where(enc,  brs[..., 1], brs[..., 0])

In [3]:
def sto_rot_k(n, d, k, ITERS=1024):
    """Stochastic K-level Rotation Quantization.

    Args:
        n (int): no. of clients.
        d (int): len of array each client has.
        k (int): no. of quantization levels.
        ITERS (int, optional): no. of iterations. Defaults to 1024.
    Returns:
        None.
        prints:
            lemma-7 for each iteration.
            lemma-6.

    """
    
    # generate a uniform distributin of X with the shown shape
    X = np.random.rand(n, ITERS, d)
    
    # generating Rotation Matrix.
    R = hadamard(d) @ np.diag(
        np.random.choice([1, -1], size=d, p=[0.5, 0.5])
    ) / np.sqrt(d)
    
    # verifying orthogonal property of generated Rotation Matrix.
    assert np.allclose(np.eye(d), R @ R.T), "R is not orthogonal."
    
    # intilizing Y.
    Y = np.zeros((ITERS, d))
    tot_expec = 0
    
    for xi in X:
        # xi.shape=(iters,d)
        # doing R@xi for each iteration.
        zi = np.transpose(R @ xi.T)

        # checking a property: norm(zi)==norm(xi).
        assert np.allclose(
            np.linalg.norm(zi),
            np.linalg.norm(xi)
        ), "l2 norm zi != l2 norm xi"

        # finding maxs, mins for each iteration.
        # both of them have shape=(iters,1)        
        zi_maxs = np.max(zi, axis=1, keepdims=True)
        zi_mins = np.min(zi, axis=1, keepdims=True)
        
        # proving lemma-7
        expec_zmin = np.mean(zi_mins**2)
        expec_zmax = np.mean(zi_maxs**2)
        # randomly choosing one xi.
        up_bnd_minmax = (np.linalg.norm(
            xi[np.random.randint(ITERS)])**2/d) * ((np.log(d**2) + 2))

        print("expected_val_sq_zmax:\t\t   ", expec_zmax)
        print("expected_val_sq_zmin:\t\t   ", expec_zmin)
        print("upper bound for sq_zmax or sq_zmin:", up_bnd_minmax)
        print("-----------------------------------")
        
        # defining an array of quantizations. shape=(iters,k)
        bi = Bi_rot(zi_maxs=zi_maxs, zi_mins=zi_mins, k=k)

        encs, brs = encoder_rot(zi=zi, bi=bi, k=k,
                            iters=ITERS, d=d)

        yi = decoder_rot(encs, brs)

        # adding the encoded and decoded xi of each xi (in order to take their mean).
        Y += yi
    
        tot_expec += expec_zmin+expec_zmax
    # dividing by number of client (inorder to take their mean).
    Z_hat_mean = Y/n 
    # doing inv(R)@zi for each iteration.
    x_hat_mean = np.transpose(np.linalg.inv(R) @ Z_hat_mean.T)

    # calculating mean along dimensions of several users
    x_mean = np.mean(X, axis=0)

    # expected error: (using sec-1.2 in DME)
    # calculating norm for each iteration. output_shape=(iters,)
    # and then calculating mean of the iterations. output_type: scalar.
    error_obs = np.mean((np.linalg.norm((x_mean - x_hat_mean), axis=1)**2))

    # lemma-6
    # calaculating the norm for each client, squaring and adding them.
    error_cal_bnd = (d/(2*((n*(k-1))**2))) * tot_expec
    print("\n\nexpected error:\t\t", error_obs)
    print("upper bound for error:\t", error_cal_bnd)

In [4]:
sto_rot_k(n=8, d=64, k=8)

expected_val_sq_zmax:		    2.5332678589430406
expected_val_sq_zmin:		    1.257900176823522
upper bound for sq_zmax or sq_zmin: 3.460655945824374
-----------------------------------
expected_val_sq_zmax:		    2.585813068669737
expected_val_sq_zmin:		    1.2857760605312913
upper bound for sq_zmax or sq_zmin: 2.98019293282247
-----------------------------------
expected_val_sq_zmax:		    2.55935309100999
expected_val_sq_zmin:		    1.2726398941002066
upper bound for sq_zmax or sq_zmin: 3.414375283423076
-----------------------------------
expected_val_sq_zmax:		    2.55812561169597
expected_val_sq_zmin:		    1.2824979575426032
upper bound for sq_zmax or sq_zmin: 3.5594658554073257
-----------------------------------
expected_val_sq_zmax:		    2.5299509312005037
expected_val_sq_zmin:		    1.2713682834027284
upper bound for sq_zmax or sq_zmin: 2.9188680637044153
-----------------------------------
expected_val_sq_zmax:		    2.5354856753067248
expected_val_sq_zmin:		    1.2875441302677282
upp

In [5]:
sto_rot_k(n=8, d=512, k=16)

expected_val_sq_zmax:		    3.2334682963620818
expected_val_sq_zmin:		    3.919838746205841
upper bound for sq_zmax or sq_zmin: 5.098583500048262
-----------------------------------
expected_val_sq_zmax:		    3.2944685783449943
expected_val_sq_zmin:		    3.9248930256852774
upper bound for sq_zmax or sq_zmin: 4.998301313388171
-----------------------------------
expected_val_sq_zmax:		    3.2230210479675105
expected_val_sq_zmin:		    3.9312510657027353
upper bound for sq_zmax or sq_zmin: 4.664292489862427
-----------------------------------
expected_val_sq_zmax:		    3.238911547460232
expected_val_sq_zmin:		    3.929153693515752
upper bound for sq_zmax or sq_zmin: 4.851210401205554
-----------------------------------
expected_val_sq_zmax:		    3.249904814369916
expected_val_sq_zmin:		    3.895851444498983
upper bound for sq_zmax or sq_zmin: 5.07293966717923
-----------------------------------
expected_val_sq_zmax:		    3.24074886063252
expected_val_sq_zmin:		    3.958363407917499
upper b

In [6]:
sto_rot_k(n=8, d=512, k=8)

expected_val_sq_zmax:		    2.6787695684507975
expected_val_sq_zmin:		    3.367220101360331
upper bound for sq_zmax or sq_zmin: 4.633329955683656
-----------------------------------
expected_val_sq_zmax:		    2.689317608233672
expected_val_sq_zmin:		    3.3951115640614633
upper bound for sq_zmax or sq_zmin: 4.598873054911518
-----------------------------------
expected_val_sq_zmax:		    2.6724444779973764
expected_val_sq_zmin:		    3.335616140169188
upper bound for sq_zmax or sq_zmin: 4.8731645287395935
-----------------------------------
expected_val_sq_zmax:		    2.706671437428022
expected_val_sq_zmin:		    3.353581742607641
upper bound for sq_zmax or sq_zmin: 4.464973758616667
-----------------------------------
expected_val_sq_zmax:		    2.6890020010534195
expected_val_sq_zmin:		    3.336865767476686
upper bound for sq_zmax or sq_zmin: 4.991619477856375
-----------------------------------
expected_val_sq_zmax:		    2.665710303422432
expected_val_sq_zmin:		    3.3559820420891944
uppe