TODO:
- Rewrite the over-budget-entities code to be vectorized (replace those entities' rows with 0s)
- Figure out how to read from / write to databases effectively
    - Make the ledger remember which rows it loaded from the DB and which it added locally (e.g. a diff object that can itself be read back efficiently)

Later:
- Give each DataSubject a unique integer ID
- Custom deltas? (please no)

In [101]:
import numpy as np
from scipy.optimize import minimize_scalar

class DataSubjectLedger:
    """Vectorized privacy ledger for data subjects.

    Each row records one release of information about a single entity
    (``entity_ids``) together with the quantities used to price it:
    ``sigmas``, ``l2_norms`` (observed), ``l2_norm_bounds`` (worst case),
    ``Ls`` and ``coeffs``.  A row contributes an RDP curve ``alpha * c`` with
    ``c = Ls**2 * norm**2 / (2 * sigmas**2) * coeffs``; the per-entity sum of
    ``c`` is converted to an epsilon (at ``self.delta``) via a precomputed
    cache.
    """

    def __init__(self, default_cache_size=1e3):
        self.delta = 1e-6  # WARNING: CHANGING DELTA INVALIDATES THE CACHE
        self.reset()
        # cache_constant2epsilon[k] holds the optimal epsilon for RDP constant k + 1.
        # Start from an ndarray (not a list) so .take() works even if the
        # requested initial size is 0.
        self.cache_constant2epsilon = np.array([], dtype=np.float64)
        self.increase_max_cache(int(default_cache_size))

    def reset(self):
        """Drop every ledger row (the epsilon cache is left untouched)."""
        self.sigmas = np.array([])
        self.l2_norms = np.array([])
        self.l2_norm_bounds = np.array([])
        self.Ls = np.array([])
        self.coeffs = np.array([])
        self.entity_ids = np.array([])
        self.entity2budget = np.array([])

    def batch_append(self,
                     sigmas: np.ndarray,
                     l2_norms: np.ndarray,
                     l2_norm_bounds: np.ndarray,
                     Ls: np.ndarray,
                     coeffs: np.ndarray,
                     entity_ids: np.ndarray):
        """Append one ledger row per element of the input arrays.

        The inputs are expected to have equal length (this is not checked).
        """
        self.sigmas = np.concatenate([self.sigmas, sigmas])
        self.l2_norms = np.concatenate([self.l2_norms, l2_norms])
        self.l2_norm_bounds = np.concatenate([self.l2_norm_bounds, l2_norm_bounds])
        self.Ls = np.concatenate([self.Ls, Ls])
        self.coeffs = np.concatenate([self.coeffs, coeffs])
        self.entity_ids = np.concatenate([self.entity_ids, entity_ids])

    def increase_max_cache(self, new_size):
        """Extend the epsilon cache so it covers RDP constants 1..new_size.

        Does nothing if the cache already holds at least ``new_size`` entries.
        """
        current_size = len(self.cache_constant2epsilon)
        new_entries = [
            self.get_optimal_alpha_for_constant(constant)[1]
            for constant in range(current_size + 1, new_size + 1)
        ]
        self.cache_constant2epsilon = np.concatenate(
            [self.cache_constant2epsilon, np.array(new_entries, dtype=np.float64)]
        )

    def get_fake_rdp_func(self, constant):
        """Return the RDP curve ``alpha -> alpha * constant``."""

        def func(alpha):
            return alpha * constant

        return func

    def get_alpha_search_function(self, rdp_compose_func):
        """Return the objective minimised over the RDP order ``alpha``.

        For alpha > 1 the objective is
        ``rdp(alpha) + log((alpha-1)/alpha) - (log(delta) + log(alpha)) / (alpha-1)``,
        clamped below at 0.  It returns +inf for alpha <= 1, which acts as a
        wall for the unbounded Brent search.
        """
        log_delta = np.log(self.delta)

        def fun(alpha):  # the input is the RDP's \alpha
            if alpha <= 1:
                return np.inf
            alpha_minus_1 = alpha - 1
            return np.maximum(rdp_compose_func(alpha) + np.log(alpha_minus_1 / alpha)
                              - (log_delta + np.log(alpha)) / alpha_minus_1, 0)

        return fun

    def get_optimal_alpha_for_constant(self, constant=3):
        """Return ``(alpha, epsilon)`` minimising epsilon for RDP curve ``alpha * constant``."""
        f = self.get_fake_rdp_func(constant)
        f2 = self.get_alpha_search_function(rdp_compose_func=f)
        # Brent's method is unbounded and never used ``bounds``; recent SciPy
        # rejects ``bounds`` combined with method='Brent'.  The +inf returned
        # for alpha <= 1 already keeps the search to the right of 1.
        results = minimize_scalar(f2, method='Brent', bracket=(1, 2))
        return results.x, results.fun

    def get_batch_rdp_constants(self, entity_ids_query, private=True):
        """Return the summed RDP constant of each queried entity, aligned with the query.

        ``private=True`` prices rows with the observed ``l2_norms``; otherwise the
        worst-case ``l2_norm_bounds`` are used.  Entities with no ledger rows get 0.
        """
        query = np.asarray(entity_ids_query, dtype=np.int64)
        if query.size == 0:
            return np.zeros(0, dtype=np.float64)

        # Only rows belonging to a queried entity matter.
        mask = np.isin(self.entity_ids, query)
        batch_entity_ids = self.entity_ids[mask].astype(np.int64)
        batch_norms = (self.l2_norms if private else self.l2_norm_bounds)[mask]

        per_row = (self.Ls[mask] ** 2 * batch_norms ** 2
                   / (2 * self.sigmas[mask] ** 2)) * self.coeffs[mask]
        # minlength guarantees every queried id is a valid index, even ids with
        # no ledger rows (or an empty batch), so .take() cannot go out of range.
        totals = np.bincount(batch_entity_ids, weights=per_row,
                             minlength=int(query.max()) + 1)
        return totals.take(query)

    def get_epsilon_spend(self, entity_ids_query):
        """Return the epsilon spent by each queried entity (aligned with the query).

        Entities whose RDP constant is 0 (e.g. no ledger rows) get epsilon 0.
        The epsilon cache is grown on demand when a constant exceeds its size.
        """
        rdp_constants = self.get_batch_rdp_constants(entity_ids_query=entity_ids_query)
        # Round UP: epsilon grows with the RDP constant, so ceil keeps the reported
        # spend an upper bound.  Truncating under-reports spend and maps constants
        # in [0, 1) to index -1, i.e. the *largest* cached epsilon.
        rdp_constants = np.ceil(rdp_constants).astype(np.int64)
        eps_spend = np.zeros(rdp_constants.shape, dtype=np.float64)
        spent = rdp_constants > 0
        if not spent.any():
            return eps_spend
        lookup = rdp_constants[spent] - 1  # cache index k <-> constant k + 1
        max_index = int(lookup.max())
        if max_index >= len(self.cache_constant2epsilon):
            # Grow with ~10% headroom, but always far enough to cover max_index.
            self.increase_max_cache(max(int(max_index * 1.1), max_index + 1))
        eps_spend[spent] = self.cache_constant2epsilon.take(lookup)
        return eps_spend

In [102]:
# Fresh ledger with an epsilon cache covering RDP constants 1..1000.
ledger = DataSubjectLedger(default_cache_size=1000)

In [103]:
n = 10 ** 5  # number of synthetic entities to simulate
ledger.reset()  # start from an empty ledger (cache is kept)

In [104]:
# One ledger row per entity id 0..n-1, all with identical parameters.
ledger.batch_append(
    entity_ids=np.arange(n),
    sigmas=np.ones(n),
    Ls=np.full(n, 5.0),
    coeffs=np.ones(n),
    l2_norms=np.full(n, 10.0),
    l2_norm_bounds=np.full(n, 40.0),
)

In [105]:
# Query every entity id from 0 up to (but excluding) n.
query = np.arange(0, n)

In [106]:
%%time
# Benchmark get_epsilon_spend over every id in `query`. A call may also grow the
# epsilon cache (when a constant exceeds its size), which makes that run slower
# than repeat runs.
eps = ledger.get_epsilon_spend(entity_ids_query=query)

CPU times: user 126 ms, sys: 28 µs, total: 126 ms
Wall time: 124 ms


In [107]:
# Per-entity epsilon spend, aligned with `query` (passed positionally).
eps = ledger.get_epsilon_spend(query)

In [108]:
# Last 100 cache entries, i.e. the epsilons for the 100 largest cached RDP constants.
ledger.cache_constant2epsilon[-100:]

array([1536.02310579, 1537.12684825, 1538.23055018, 1539.33421162,
       1540.43783261, 1541.54141321, 1542.64495347, 1543.74845342,
       1544.85191313, 1545.95533262, 1547.05871196, 1548.16205119,
       1549.26535035, 1550.36860949, 1551.47182866, 1552.5750079 ,
       1553.67814727, 1554.7812468 , 1555.88430654, 1556.98732655,
       1558.09030686, 1559.19324752, 1560.29614858, 1561.39901007,
       1562.50183206, 1563.60461459, 1564.70735769, 1565.81006141,
       1566.91272581, 1568.01535092, 1569.11793679, 1570.22048347,
       1571.322991  , 1572.42545942, 1573.52788878, 1574.63027913,
       1575.7326305 , 1576.83494295, 1577.93721652, 1579.03945125,
       1580.14164719, 1581.24380438, 1582.34592286, 1583.44800268,
       1584.55004389, 1585.65204652, 1586.75401062, 1587.85593624,
       1588.95782341, 1590.05967219, 1591.16148261, 1592.26325472,
       1593.36498856, 1594.46668417, 1595.56834161, 1596.6699609 ,
       1597.7715421 , 1598.87308524, 1599.97459038, 1601.07605

In [82]:
# Cache size = largest RDP constant that currently has a precomputed epsilon.
len(ledger.cache_constant2epsilon)

1373

In [31]:
# True if any queried entity has a non-zero epsilon spend.
# NOTE(review): the output shown below comes from execution count 31, earlier
# than the [101]-[108] cells above, so it is likely stale.
eps.any()

False

In [None]:
# Number of ledger rows divided by 1e8 — presumably progress toward a
# 100M-row scale target; TODO confirm the intended reference size.
len(ledger.sigmas)/1e8

In [None]:
%%time
# Re-time get_epsilon_spend over `query`; the cache may already cover every
# constant by now, so no growth is expected unless the ledger changed.
eps = ledger.get_epsilon_spend(entity_ids_query=query)

In [None]:
# Epsilon spend of the first 20 queried entities.
eps[0:20]