In [1]:
import numpy as np
import pandas as pd
import math

import src.hdmm.workload as workload
import src.census_workloads as census
from src.workload_selection import workload_selection
import online_workloads as online_workloads

# Database

In [2]:
data_path = "migration_tworace.csv"
x_data = pd.read_csv(data_path, header=None).to_numpy().T[1]
n = x_data.shape[0]
n

86

# Workloads

In [3]:
W_name = ['identity', 'total', 'race1', 'race2', 'race3', 'custom', 'prefix_sum']
W_lst = [online_workloads.identity(n), online_workloads.total(n), online_workloads.race1(), online_workloads.race2(), online_workloads.race3(), online_workloads.custom(n), online_workloads.prefix_sum(n)]

def print_workload(workload_name, workload):
    print(f'---Workload: {workload_name}---')
    print(f'Shape: {workload.shape}')
    print(f'Workload: \n{workload}\n')
    
def print_workloads():
    for i in range(7):
        print_workload(W_name[i], W_lst[i])
        
print_workloads()


---Workload: identity---
Shape: (86, 86)
Workload: 
[[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]

---Workload: total---
Shape: (1, 86)
Workload: 
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]

---Workload: race1---
Shape: (7, 64)
Workload: 
[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0.

In [4]:
len(online_workloads.identity(n))

86

In [5]:
x_data

array([412, 333, 285, 231, 202, 174, 160, 142, 146, 149, 145, 181, 174,
       190, 213, 287, 372, 499, 619, 715, 785, 821, 822, 816, 799, 742,
       717, 697, 658, 593, 564, 519, 447, 403, 388, 365, 336, 306, 311,
       289, 261, 231, 213, 196, 194, 170, 175, 168, 149, 142, 131, 119,
       112, 118, 114, 116, 112, 114, 106, 111, 109, 112, 113, 109, 104,
       108, 108,  94,  91,  81,  81,  72,  68,  63,  56,  46,  41,  38,
        34,  28,  23,  22,  18,  18,  16,  41])

In [6]:
# Todo: Add code to make epsilon adjustable proportionally. 
# Todo: Fix divide by zero error (<ipython-input-51-46063bd03d8a>:37: RuntimeWarning: divide by zero encountered in double_scalars delta = 1 / (n * math.log(n, np.e)))

def pmw_split(workload, x, analyst_labels, T=5, eps=0.01, k=5, show_messages=True, to_return='pd', show_plot=False, show_failure_step=True):
    """
    Implement Private Multiplicative Weights Mechanism (PMW) on a workload of
    linear queries. 

    Algorithm Parameters: 
    - workload = workload of queries (M x k numpy array)
    - x = true database (M x 1 numpy array)
    - T = update threshold
    - eps = privacy budget
    - k = number of update steps PER ANALYST
    - analyst_labels = list of analyst names corresponding to each query in the workload
    
    Output Controls: 
    - show_messages argument determines whether the function will print information such as 
    error scale, threshold, update steps used, etc.
    - to_return argument determines what the function will return. 
        - if 'pd', pmw() returns pandas df with test data for each 
        query in the workload(showing query, d_t_hat, updated, algo_ans, real_ans, 
        abs_error, rel_error). 
        - if 'update_count', pmw() returns the update count for the total
        amount of queries.
    - show_plot - T/F whether the function will display a plot
    - show_failure_step - T/F whether function prints what step failure mode is reached
    """ 
    
    update_steps = {}
    for analyst in list(set(analyst_labels)): 
        update_steps[analyst] = k # each analyst starts with k update steps
    
    # initialize constants
    m = x.size  # database len
    n = x.sum()
    eta = (math.log(m, np.e) ** (1 / 4)) / (math.sqrt(n))
    delta = 1 / (n * math.log(n, np.e))
    x_norm = x / np.sum(x)
    
    # initialize synthetic databases at time 0 (prior to any queries)
    x_t = np.ones(m) / m
    y_t = np.ones(m) / m

    # initialize tracker lists to construct pandas dataframe at the end 
    x_list = [x_t] # create a list of x_t synthetic database at every time step
    update_list = []
    update_count = 0
    pmw_answers = []
    update_times = [] # record times that database is updated
    d_t_hat_list = []
    
    def lazy_round():
        """
        "Lazy Round" of querying using the stored synthetic database, x_t, in list x_list.
        p
        We call this the lazy round because it is contrasted with the updated step where we update the 
        sythetic database and answer the query using the real database.
        """
        update_list.append('no')
        pmw_answers.append(np.dot(query, x_list[time]))
        x_list.append(x_list[time].round(3))
    
    # inititate first instance of SVT with half the budget and k updates; will be reset in the main loop
    SVTtrigger = False 
    SVTepsilon1 = ((eps/2)/2)
    SVTepsilon2 = ((eps/2)/2)
    rho = np.random.laplace(loc=0, scale=(1/SVTepsilon1), size=1)[0]
    
    for time, query in enumerate(workload):
        
        analyst = analyst_labels[time]
        
        # Do one round of sparse vector technique 
        
        # Compute noisy answer by adding Laplacian noise
        a_t = np.random.laplace(loc=0, scale=(2*k/SVTepsilon2), size=1)[0]
    
        a_t_hat = (np.dot(query, x_norm)*n ) + a_t
        
        # Difference between noisy and maintained histogram answer
        d_t_hat = a_t_hat - (n*np.dot(query, x_list[time]))
        
        # Lazy round: use synthetic base to answer the query
        if (abs(d_t_hat) <= T + rho):
            d_t_hat_list.append(d_t_hat)
            lazy_round()
            continue

        # update round: update histogram and return noisy answer
        else:
            #make a new noisy query answer using some of the leftover budget
            a_t = np.random.laplace(loc=0, scale=(2*k/eps), size=1)[0]
            a_t_hat = (np.dot(query, x_norm)*n ) + a_t
            d_t_hat = a_t_hat - (n*np.dot(query, x_list[time]))
            d_t_hat_list.append(d_t_hat)
            update_times.append(time)
            
            # step a
            if d_t_hat < 0:
                r_t = query
            else:
                r_t = np.ones(m) - query
            for i, v in enumerate(y_t):
                y_t[i] = x_list[time][i] * math.exp((d_t_hat/(2*n)) * query[i]) * 20 # 20 is the learning rate
            
            # step b
            x_t = y_t / np.sum(y_t)
            update_count = update_list.count('yes')
            
            # if threshold for num updates is reached, just do a lazy round (synthetic database) answer
            if update_steps[analyst] == 0: 
                if show_failure_step:
                    print(f'Failure mode reached at query number {time}: {query}')
                lazy_round()
                
            # if there are still update steps that the analyst can use, 
            # 1. update the synthetic database
            # 2. answer the query using the noisy answer from the database itself 
            else: 
                x_list.append(x_t.round(3))
                update_list.append('yes') # increment number of updates counter
                pmw_answers.append(a_t_hat / np.sum(x))
                update_steps[analyst] -= 1 # use one of analyst's update steps

    update_count = update_list.count('yes')      

    # calculate error
    real_ans = np.matmul(workload, x_norm)
    abs_error = np.abs(pmw_answers - real_ans)
    rel_error = np.abs(abs_error / np.where(real_ans == 0, 0.000001,
                                                real_ans))
    
    if show_messages:
        np.set_printoptions(suppress=True)
        """Print inputes/outputs to analyze each query"""
        print(f'Original database: {x}\n')
        print(f'Normalized database: {x_norm}\n')
        print(f'Updated Database = {x_t}\n')
        print(f'Update Count = {update_count}\n')
        print(f'{T=}\n')
        print(f'Error Scale Query Answer= {2*((2*k/eps)**2)}\n')
        print(f'Error Scale SVT= {2*((2*k/SVTepsilon2)**2)}\n')
        print(f'Update Parameter Scale = {eta}\n')
        print(f'{delta=}\n')
        
    if show_plot: 
        plt.title('Error across queries:')
        rel_line, = plt.plot(rel_error, label='Relative Error')
        abs_line, = plt.plot(abs_error, label='Absolute Error')
        for xc in update_times:
            plt.axvline(x=xc, color='red', label='Update Times', linestyle='dashed')
        plt.legend(handles=[abs_line,rel_line])
        plt.xticks(range(0, len(workload), round(len(workload)/5)))
    
    if to_return == "pd":
        # hacky fix: remove the first synthetic database to keep length of lists consistent with the
        # other lists that comprise of the pandas dataframe
        x_list.pop(0).tolist() 
        d = {
            'queries': workload.tolist(), 
            'synthetic database (after query)': x_list,
            'algo_ans': pmw_answers,
            'real_ans': real_ans.tolist(),
            'updated': update_list,
            'abs_error': abs_error,               
            'rel_error': rel_error,
            'analyst': analyst_labels,
            'd_t_hat': d_t_hat_list, 

             }
        test_data = pd.DataFrame(data=d)
        test_data = test_data.round(3)
        return test_data
    
    # return dictionary of absolute errors
    if to_return == "error":
        d = {'analyst': analyst_labels,
             'abs_error': abs_error,               
             'rel_error': rel_error,}
        data = pd.DataFrame(data=d)
        data = data.round(3)
        
        analyst_error = {}
        for analyst in list(set(analyst_labels)):
            analyst_error[analyst] = data[data.analyst==analyst]['abs_error'].sum()
        return analyst_error
    
x_example = np.array([1, 2, 3, 4, 5])
pmw_split(np.vstack((online_workloads.identity(5), online_workloads.identity(5))), x_example, ['Alice'] * 10, eps=0.01)



Failure mode reached at query number 5: [1. 0. 0. 0. 0.]
Failure mode reached at query number 6: [0. 1. 0. 0. 0.]
Failure mode reached at query number 7: [0. 0. 1. 0. 0.]
Failure mode reached at query number 8: [0. 0. 0. 1. 0.]
Failure mode reached at query number 9: [0. 0. 0. 0. 1.]
Original database: [1 2 3 4 5]

Normalized database: [0.06666667 0.13333333 0.2        0.26666667 0.33333333]

Updated Database = [0. 0. 1. 0. 0.]

Update Count = 5

T=5

Error Scale Query Answer= 2000000.0

Error Scale SVT= 32000000.0

Update Parameter Scale = 0.29081910083756185

delta=0.024617958204590337



Unnamed: 0,queries,synthetic database (after query),algo_ans,real_ans,updated,abs_error,rel_error,analyst,d_t_hat
0,"[1.0, 0.0, 0.0, 0.0, 0.0]","[0.007, 0.248, 0.248, 0.248, 0.248]",-6.841,0.067,yes,6.907,103.61,Alice,-105.61
1,"[0.0, 1.0, 0.0, 0.0, 0.0]","[0.009, 0.0, 0.33, 0.33, 0.33]",-66.87,0.133,yes,67.004,502.527,Alice,-1006.774
2,"[0.0, 0.0, 1.0, 0.0, 0.0]","[0.0, 0.0, 1.0, 0.0, 0.0]",70.016,0.2,yes,69.816,349.078,Alice,1045.283
3,"[0.0, 0.0, 0.0, 1.0, 0.0]","[0.0, 0.0, 1.0, 0.0, 0.0]",-98.53,0.267,yes,98.796,370.486,Alice,-1477.944
4,"[0.0, 0.0, 0.0, 0.0, 1.0]","[0.0, 0.0, 1.0, 0.0, 0.0]",-14.395,0.333,yes,14.729,44.186,Alice,-215.929
5,"[1.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 1.0, 0.0, 0.0]",0.0,0.067,no,0.067,1.0,Alice,958.767
6,"[0.0, 1.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 1.0, 0.0, 0.0]",0.0,0.133,no,0.133,1.0,Alice,372.748
7,"[0.0, 0.0, 1.0, 0.0, 0.0]","[0.0, 0.0, 1.0, 0.0, 0.0]",1.0,0.2,no,0.8,4.0,Alice,-2022.99
8,"[0.0, 0.0, 0.0, 1.0, 0.0]","[0.0, 0.0, 1.0, 0.0, 0.0]",0.0,0.267,no,0.267,1.0,Alice,-101.752
9,"[0.0, 0.0, 0.0, 0.0, 1.0]","[0.0, 0.0, 1.0, 0.0, 0.0]",0.0,0.333,no,0.333,1.0,Alice,927.228


It seems to update too much during the update steps (step 0-5). Using the example database of [1, 2, 3, 4, 5] and a workload of two stacked 5x5 identity matrices, during the first 5 queries consisting of the first identity matrix, the database updates the first few vectors way too much. 

In [216]:
a = 5
a == None

False

In [181]:
# pmw_split
def pmw2(workload, x, analyst_labels, T, eps=0.01, k=None, 
         show_messages=False, to_return='pd', show_plot=False, show_failure_step=False):
    """
    Implement Private Multiplicative Weights Mechanism (PMW) on a workload of
    linear queries. 

    Algorithm Parameters: 
    - workload = workload of queries (M x k numpy array)
    - x = true database (M x 1 numpy array)
    - T = update threshold
    - eps = privacy budget
    - k = number of update steps PER ANALYST
    - analyst_labels = list of analyst names corresponding to each query in the workload
    
    Output Controls: 
    - show_messages argument determines whether the function will print information such as 
    error scale, threshold, update steps used, etc.
    - to_return argument determines what the function will return. 
        - if 'pd', pmw() returns pandas df with test data for each 
        query in the workload(showing query, d_t_hat, updated, algo_ans, real_ans, 
        abs_error, rel_error). 
        - if 'update_count', pmw() returns the update count for the total
        amount of queries.
    - show_plot - T/F whether the function will display a plot
    - show_failure_step - T/F whether function prints what step failure mode is reached
    """ 
    
    update_steps = {}
    for analyst in list(set(analyst_labels)): 
        update_steps[analyst] = k # each analyst starts with k update steps
    
    # initialize constants
    m = x.size  # database len
    n = x.sum()
    eta = (math.log(m, np.e) ** (1 / 4)) / (math.sqrt(n))
    delta = 1 / (n * math.log(n, np.e))
    x_norm = x / np.sum(x)
    
    # initialize synthetic databases at time 0 (prior to any queries)
    x_t = np.ones(m) / m
    y_t = np.ones(m) / m

    # initialize tracker lists to construct pandas dataframe at the end 
    x_list = [x_t] # create a list of x_t synthetic database at every time step
    update_list = []
    update_count = 0
    pmw_answers = []
    update_times = [] # record times that database is updated
    d_t_hat_list = []
    
    def lazy_round():
        """
        "Lazy Round" of querying using the stored synthetic database, x_t, in list x_list.
        
        We call this the lazy round because it is contrasted with the updated step where we update the 
        sythetic database and answer the query using the real database.
        """
        update_list.append('no')
        answer = np.dot(query, x_list[time])
        if answer < 0:
            pmw_answers.append(0)
        else: 
            pmw_answers.append(answer)
        x_list.append(x_list[time].round(3))
    
    # inititate first instance of SVT with half the budget and k updates; will be reset in the main loop
    SVTtrigger = False 
    SVTepsilon1 = ((eps/2)/2)
    SVTepsilon2 = ((eps/2)/2)
    rho = np.random.laplace(loc=0, scale=(1/SVTepsilon1), size=1)[0]
    
    for time, query in enumerate(workload):
        
        analyst = analyst_labels[time]
        
        # Do one round of sparse vector technique; compute noisy answer by adding Laplacian noise
        A_t = np.random.laplace(loc=0, scale=(k/SVTepsilon2), size=1)[0]
        a_t_hat = (np.dot(query, x_norm)*n ) + A_t
        d_t_hat = a_t_hat - (n*np.dot(query, x_list[time]))
        
        # LAZY ROUND: QUERY USING THE SYNTHETIC DATABASE
        if (abs(d_t_hat) <= T + rho):
            d_t_hat_list.append(d_t_hat)
            lazy_round()

        # UPDATE ROUND: UPDATE SYNTHETIC DATABASE AND RETURN NOISY ANSWER, A_T-HAT
        else:
            # noise
            A_t = np.random.laplace(loc=0, scale=(2*k/eps), size=1)[0]
            
            # noisy answer
            a_t_hat = (np.dot(query, x_norm)*n ) + A_t
            d_t_hat = a_t_hat - (n*np.dot(query, x_list[time]))
            d_t_hat_list.append(d_t_hat)
            update_times.append(time)
            
            # step a
            if d_t_hat < 0:
                r_t = query
            else:
                r_t = np.ones(m) - query
            for i in range(len(y_t)):
                y_t[i] = x_list[time][i] * math.exp((d_t_hat/(2*n)) * query[i]) * 20 # 20 is the learning rate
            
            # step b
            x_t = y_t / np.sum(y_t)
            update_count = update_list.count('yes')
            
            # if threshold for num updates is reached, just do a lazy round (synthetic database) answer
            if update_steps[analyst] == 0: 
                if show_failure_step:
                    print(f'Failure mode reached at query number {time}: {query}')
                lazy_round()
                
            # if there are still update steps that the analyst can use, 
            # 1. update the synthetic database
            # 2. answer the query using the noisy answer from the database itself 
            else: 
                x_list.append(x_t.round(3))
                update_list.append('yes') # increment number of updates counter
                answer = a_t_hat / np.sum(x)
                
                if answer < 0:
                    pmw_answers.append(0)
                else: 
                    pmw_answers.append(answer)
                
                update_steps[analyst] -= 1 # use one of analyst's update steps

    update_count = update_list.count('yes')      

    # calculate error
    real_ans = np.matmul(workload, x_norm)
    abs_error = np.abs(pmw_answers - real_ans)
    rel_error = np.abs(abs_error / np.where(real_ans == 0, 0.000001,
                                                real_ans))
    
    if show_messages:
        np.set_printoptions(suppress=True)
        """Print inputes/outputs to analyze each query"""
        print(f'Original database: {x}\n')
        print(f'Normalized database: {x_norm}\n')
        print(f'Updated Database = {x_t}\n')
        print(f'Update Count = {update_count}\n')
        print(f'{T=}\n')
        print(f'Error Scale Query Answer= {2*((2*k/eps)**2)}\n')
        print(f'Error Scale SVT= {2*((2*k/SVTepsilon2)**2)}\n')
        print(f'Update Parameter Scale = {eta}\n')
        print(f'{delta=}\n')
        
    if show_plot: 
        plt.title('Error across queries:')
        rel_line, = plt.plot(rel_error, label='Relative Error')
        abs_line, = plt.plot(abs_error, label='Absolute Error')
        for xc in update_times:
            plt.axvline(x=xc, color='red', label='Update Times', linestyle='dashed')
        plt.legend(handles=[abs_line,rel_line])
        plt.xticks(range(0, len(workload), round(len(workload)/5)))
    
    if to_return == "pd":
        # hacky fix: remove the first synthetic database to keep length of lists consistent with the
        # other lists that comprise of the pandas dataframe
        x_list.pop(0).tolist() 
        d = {
            'algo_ans': pmw_answers,
            'real_ans': real_ans.tolist(),
            'queries': workload.tolist(), 
            'updated': update_list,
            'abs_error': abs_error,               
            'rel_error': rel_error,
            'synthetic database': x_list,
            'analyst': analyst_labels,
            'd_t_hat': d_t_hat_list, 

             }
        test_data = pd.DataFrame(data=d)
        test_data = test_data.round(3)
        return test_data
    
    if to_return == "error":
        d = {'analyst': analyst_labels,
             'abs_error': abs_error,               
             'rel_error': rel_error,}
        data = pd.DataFrame(data=d)
        data = data.round(3)
        
        analyst_error = {}
        for analyst in list(sorted(analyst_labels)):
            analyst_error[analyst] = data[data.analyst==analyst]['abs_error'].sum()
        return analyst_error
    
x_example = np.array([1000, 2000, 3000, 4000, 5000])

pmw2(np.vstack((online_workloads.identity(5), online_workloads.identity(5))), 
     x_example, ['A'] * 5 + ['B'] * 5, eps=1, T=40, k = 5, to_return='error_vec')


In [213]:
# pmw_independent: write pmw for one person. 
# create wrapper function called pmw_independent() that takes in the workloads and workload labels. Run PMW for each analyst, separate their workloads based on analysts. 

def pmw_independent(w, input_x, analyst_labels, input_T, input_eps=0.01, input_k=5):
    """
    Wrapper function that calls pmw2() to simulate PMW for each independent person. 
    
    Takes a stream of workloads and analyst labels and separates them into distinct workloads for each analyst. Runs
    pmw2() on that particular workload for each analyst. Returns a dictionary o 
    """
    indices = {} # k: analyst, v: row indices of queries in the workloads
    for i, analyst in enumerate(analyst_labels):
        if analyst not in indices.keys(): 
            indices[analyst] = []
        indices[analyst].append(i)

    workloads = {} # k: analyst, v: the analyst's workload
    for analyst in indices.keys():
        workloads[analyst] = w[indices[analyst], :]
    print(workloads)

    all_analyst_error_dic = {}
    
    for analyst in workloads.keys():
        single_analyst_error = pmw2(workload=workloads[analyst], 
                                    x=input_x, 
                                    T=input_T, 
                                    k = input_k,
                                    analyst_labels=[analyst]*len(workloads[analyst]), 
                                    to_return="error",
                                    show_messages=False)
        all_analyst_error_dic.update(single_analyst_error)
    return all_analyst_error_dic
             
        
pmw_independent(np.vstack((online_workloads.identity(5), 
                           online_workloads.identity(5))), 
                input_x=x_example, 
                input_T=40, 
                input_eps=1, 
                analyst_labels=['A'] * 2 + ['B'] * 6 + ['A'] * 2, 
                )
    

{'A': array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]]), 'B': array([[0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.]])}


{'A': 0.344, 'B': 0.371}

In [221]:
# pmw_naive
# write version of PMW where analysts can run out of privacy budget if they use too much of others' budgets
# this version everyone shares the same number of update steps 
def pmw_naive(workload, x, analyst_labels, T, eps=0.01, total_k=None, 
         show_messages=False, to_return='pd', show_plot=False, show_failure_step=False):
    """
    Implement Private Multiplicative Weights Mechanism (PMW) on a workload of
    linear queries where analysts can run out of privacy budget if they use too much of others'. 
    
    In other words, 

    Algorithm Parameters: 
    - workload = workload of queries (M x k numpy array)
    - x = true database (M x 1 numpy array)
    - T = update threshold
    - eps = privacy budget
    - total_k = total number of update steps alloted for the entire group
    - analyst_labels = list of analyst names corresponding to each query in the workload
    
    Output Controls: 
    - show_messages argument determines whether the function will print information such as 
    error scale, threshold, update steps used, etc.
    - to_return argument determines what the function will return. 
        - if 'pd', pmw() returns pandas df with test data for each 
        query in the workload(showing query, d_t_hat, updated, algo_ans, real_ans, 
        abs_error, rel_error). 
        - if 'update_count', pmw() returns the update count for the total
        amount of queries.
    - show_plot - T/F whether the function will display a plot
    - show_failure_step - T/F whether function prints what step failure mode is reached
    """ 
    
    # initialize constants
    m = x.size  # database len
    n = x.sum()
    eta = (math.log(m, np.e) ** (1 / 4)) / (math.sqrt(n))
    delta = 1 / (n * math.log(n, np.e))
    x_norm = x / np.sum(x)
    
    # initialize synthetic databases at time 0 (prior to any queries)
    x_t = np.ones(m) / m
    y_t = np.ones(m) / m

    # initialize tracker lists to construct pandas dataframe at the end 
    x_list = [x_t] # create a list of x_t synthetic database at every time step
    update_list = []
    update_count = 0
    pmw_answers = []
    update_times = [] # record times that database is updated
    d_t_hat_list = []
    
    # initialize total_k, the total number of update steps if not default
    if total_k == None:
        total_k = round(n * math.log(math.sqrt(m)))
        print(f'{total_k=}')
    
    def lazy_round():
        """
        "Lazy Round" of querying using the stored synthetic database, x_t, in list x_list.
        
        We call this the lazy round because it is contrasted with the updated step where we update the 
        sythetic database and answer the query using the real database.
        """
        update_list.append('no')
        answer = np.dot(query, x_list[time])
        if answer < 0:
            pmw_answers.append(0)
        else: 
            pmw_answers.append(answer)
        x_list.append(x_list[time].round(3))
    
    # inititate first instance of SVT with half the budget and k updates; will be reset in the main loop
    SVTtrigger = False 
    SVTepsilon1 = ((eps/2)/2)
    SVTepsilon2 = ((eps/2)/2)
    rho = np.random.laplace(loc=0, scale=(1/SVTepsilon1), size=1)[0]
    
    for time, query in enumerate(workload):
        
        analyst = analyst_labels[time]
        
        # Do one round of sparse vector technique; compute noisy answer by adding Laplacian noise
        A_t = np.random.laplace(loc=0, scale=(total_k/SVTepsilon2), size=1)[0]
        a_t_hat = (np.dot(query, x_norm)*n ) + A_t
        d_t_hat = a_t_hat - (n*np.dot(query, x_list[time]))
        
        # LAZY ROUND: QUERY USING THE SYNTHETIC DATABASE
        if (abs(d_t_hat) <= T + rho):
            d_t_hat_list.append(d_t_hat)
            lazy_round()

        # UPDATE ROUND: UPDATE SYNTHETIC DATABASE AND RETURN NOISY ANSWER, A_T-HAT
        else:
            # noise
            A_t = np.random.laplace(loc=0, scale=(2*total_k/eps), size=1)[0]
            
            # noisy answer
            a_t_hat = (np.dot(query, x_norm)*n ) + A_t
            d_t_hat = a_t_hat - (n*np.dot(query, x_list[time]))
            d_t_hat_list.append(d_t_hat)
            update_times.append(time)
            
            # step a
            if d_t_hat < 0:
                r_t = query
            else:
                r_t = np.ones(m) - query
            for i in range(len(y_t)):
                y_t[i] = x_list[time][i] * math.exp((d_t_hat/(2*n)) * query[i]) * 20 # 20 is the learning rate
            
            # step b
            x_t = y_t / np.sum(y_t)
            update_count = update_list.count('yes')
            
            # if threshold for num updates is reached, just do a lazy round (synthetic database) answer
            if total_k == 0: 
                if show_failure_step:
                    print(f'Failure mode reached at query number {time}: {query}')
                lazy_round()
                
            # if there are still update steps that the analyst can use, 
            # 1. update the synthetic database
            # 2. answer the query using the noisy answer from the database itself 
            else: 
                x_list.append(x_t.round(3))
                update_list.append('yes') # increment number of updates counter
                answer = a_t_hat / np.sum(x)
                
                if answer < 0:
                    pmw_answers.append(0)
                else: 
                    pmw_answers.append(answer)
                
                total_k -= 1 # use one of the total update steps

    update_count = update_list.count('yes')      

    # calculate error
    real_ans = np.matmul(workload, x_norm)
    abs_error = np.abs(pmw_answers - real_ans)
    rel_error = np.abs(abs_error / np.where(real_ans == 0, 0.000001,
                                                real_ans))
    
    if show_messages:
        np.set_printoptions(suppress=True)
        """Print inputes/outputs to analyze each query"""
        print(f'Original database: {x}\n')
        print(f'Normalized database: {x_norm}\n')
        print(f'Updated Database = {x_t}\n')
        print(f'Update Count = {update_count}\n')
        print(f'{T=}\n')
        print(f'Error Scale Query Answer= {2*((2*k/eps)**2)}\n')
        print(f'Error Scale SVT= {2*((2*k/SVTepsilon2)**2)}\n')
        print(f'Update Parameter Scale = {eta}\n')
        print(f'{delta=}\n')
        
    if show_plot: 
        plt.title('Error across queries:')
        rel_line, = plt.plot(rel_error, label='Relative Error')
        abs_line, = plt.plot(abs_error, label='Absolute Error')
        for xc in update_times:
            plt.axvline(x=xc, color='red', label='Update Times', linestyle='dashed')
        plt.legend(handles=[abs_line,rel_line])
        plt.xticks(range(0, len(workload), round(len(workload)/5)))
    
    if to_return == "pd":
        # hacky fix: remove the first synthetic database to keep length of lists consistent with the
        # other lists that comprise of the pandas dataframe
        x_list.pop(0).tolist() 
        d = {
            'algo_ans': pmw_answers,
            'real_ans': real_ans.tolist(),
            'queries': workload.tolist(), 
            'updated': update_list,
            'abs_error': abs_error,               
            'rel_error': rel_error,
            'synthetic database': x_list,
            'analyst': analyst_labels,
            'd_t_hat': d_t_hat_list, 

             }
        test_data = pd.DataFrame(data=d)
        test_data = test_data.round(3)
        return test_data
    
    if to_return == "error":
        d = {'analyst': analyst_labels,
             'abs_error': abs_error,               
             'rel_error': rel_error,}
        data = pd.DataFrame(data=d)
        data = data.round(3)
        
        analyst_error = {}
        for analyst in list(sorted(analyst_labels)):
            analyst_error[analyst] = data[data.analyst==analyst]['abs_error'].sum()
        return analyst_error
    
x_example = np.array([1000, 2000, 3000, 4000, 5000])

pmw_naive(np.vstack((online_workloads.identity(5), online_workloads.identity(5))), 
     x_example, ['A'] * 5 + ['B'] * 5, eps=1, T=40, total_k = 5)


total_k=12071


Unnamed: 0,algo_ans,real_ans,queries,updated,abs_error,rel_error,synthetic database,analyst,d_t_hat
0,0.0,0.067,"[1.0, 0.0, 0.0, 0.0, 0.0]",yes,0.067,1.0,"[0.082, 0.229, 0.229, 0.229, 0.229]",A,-30789.864
1,0.162,0.133,"[0.0, 1.0, 0.0, 0.0, 0.0]",yes,0.029,0.217,"[0.083, 0.224, 0.231, 0.231, 0.231]",A,-1000.901
2,0.0,0.2,"[0.0, 0.0, 1.0, 0.0, 0.0]",yes,0.2,1.0,"[0.098, 0.265, 0.09, 0.273, 0.273]",A,-33444.208
3,4.911,0.267,"[0.0, 0.0, 0.0, 1.0, 0.0]",yes,4.644,17.416,"[0.028, 0.076, 0.026, 0.793, 0.078]",A,69569.955
4,0.78,0.333,"[0.0, 0.0, 0.0, 0.0, 1.0]",yes,0.447,1.34,"[0.027, 0.074, 0.025, 0.767, 0.107]",A,10528.852
5,0.057,0.067,"[1.0, 0.0, 0.0, 0.0, 0.0]",yes,0.01,0.152,"[0.027, 0.074, 0.025, 0.767, 0.107]",B,442.58
6,0.735,0.133,"[0.0, 1.0, 0.0, 0.0, 0.0]",yes,0.601,4.51,"[0.026, 0.1, 0.024, 0.745, 0.104]",B,9910.453
7,0.697,0.2,"[0.0, 0.0, 1.0, 0.0, 0.0]",yes,0.497,2.485,"[0.026, 0.099, 0.033, 0.739, 0.103]",B,10095.999
8,0.109,0.267,"[0.0, 0.0, 0.0, 1.0, 0.0]",yes,0.158,0.592,"[0.032, 0.124, 0.041, 0.674, 0.129]",B,-9451.76
9,0.0,0.333,"[0.0, 0.0, 0.0, 0.0, 1.0]",yes,0.333,1.0,"[0.034, 0.131, 0.043, 0.712, 0.08]",B,-16075.884


PMW independent

Initialize an instance of PMW for each analyst with alpha =T and their share of the privacy budget
Analysts don’t share anything - PB or Synthetic database
Each analyst is only allowed to query their instance of PMW

Naive PMW

Initialize an instance of PMW with alpha = T and the whole privacy budget 
All analysts sharing everything - privacy budget and synthetic database
Allow every analyst to query that instance of PMW

Split PMW

Initialize a single instance of PMW for each analyst with alpha = T and the entire privacy budget
Split the update steps proportionally to each analyst based on their share of the privacy budget
There exists cases where some analysts have more privacy budgets than others - i.e. alice owns 50 percent of the data
The difference between split and PMW - is that in Split, everyone shares a synthetic database
Inference steps are infamously non-monotonic
Allow any analyst to answer from the PMW instance and only allow them to cause an update step if they own any unused update steps
