# Reqs

SPEED:
- automatic testing of fast serde of 10M tensor
- automatic testing of fast .private() creation
- direct conversion of the generated primes into a numpy array for use here.
- automatic testing of fast conversion between Phi and Gamma
- testing of compatibility with jax?

CORRECTNESS:
- automatic testing that polynomial evaluation is correct
- automatic testing that derivative calculation is correct (or perhaps figure out a way to use jax's jacobian function?)
- automatic testing that DP guarantees seem to work (min/max extremes and inner samples)

In [1]:
%load_ext jupyterflame

In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from numpy.typing import ArrayLike
from typing import List
from typing import Union
import numpy as np
from primesieve import primes
import syft as sy
import time
import pyarrow

In [27]:
class PrimeFactory:
    """Growing cache of primes that is addressed by prime index.

    IMPORTANT: it's very important that two tensors be able to tell that
    they are indeed referencing the EXACT same PrimeFactory. At present this is done
    by ensuring that it is literally the same python object. In the future, we will
    probably need to formalize this. However, the main way this could go wrong is if we
    created some alternate way for checking to see if two prime factories 'sortof looked
    the same' but which in fact weren't the EXACT same object. This could lead to
    security leaks wherein two tensors think two different symbols in fact are the
    same symbol."""

    def __init__(self) -> None:
        # the cache holds the result of primes(10**exp); exp is bumped on demand
        self.exp = 2
        self._prime_number_cache: np.ndarray = primes(10**self.exp)

    def __getitem__(self, indices):
        """Return the prime(s) at the given prime index or slice as a numpy array.

        indices: an int (python or numpy integer) or a slice of prime indices.

        The cache is regenerated with a larger bound until it is long enough to
        cover the request. An open-ended slice (stop=None) is answered from
        whatever is currently cached.

        Raises TypeError for any other kind of index.
        """
        if isinstance(indices, slice):
            max_index = indices.stop
        elif isinstance(indices, (int, np.integer)):
            max_index = indices
        else:
            raise TypeError(
                f"PrimeFactory indices must be int or slice, not {type(indices).__name__}"
            )

        # stop=None has no upper bound to grow towards, so skip the growth loop
        if max_index is not None:
            while max_index > len(self._prime_number_cache) - 1:
                self.exp += 1
                self._prime_number_cache = primes(10**self.exp)

        return np.array(self._prime_number_cache.__getitem__(indices))
    
# Create one shared PrimeFactory and publish it as sy.pf; DataSubjectRegistry
# falls back to sy.pf when it is constructed without an explicit factory.
pf = PrimeFactory()
sy.pf = pf

In [28]:
# from multiprocessing import Pool

# out = 5

# def f(x):
#     return x*x + out

# if __name__ == '__main__':
#     with Pool(5) as p:
#         print(p.map(f, [1, 2, 3]))

In [29]:
# import multiprocessing as mp

In [30]:
# x = {"3":3,"4":4}

In [31]:
# mp.Value('lookup', x)

In [32]:
class DataSubjectRegistry():
    """This registry needs to run on the Domain server with the accountant. It is assumed that there
    is only one globally maintained list of entities, mapped to primes.

    Every data subject string is assigned the next free prime index the first
    time it is seen and keeps that index until reset() is called."""

    def __init__(self, prime_factory=None):
        """
        prime_factory: object indexable by a slice of prime indices (defaults to sy.pf)
        """

        if prime_factory is None:
            prime_factory = sy.pf

        # tries to make it as efficient as possible to fetch arrays of primes
        self.prime_factory = prime_factory

        self.reset()

    def reset(self):
        """Forget every registered data subject."""
        # maps indices in the prime_factory to the data subject strings uniquely identifying them
        self.prime_index2data_subjects = list()
        self.data_subjects_set = set()
        # reverse lookup: data subject string -> index in the prime_factory
        self._data_subject2prime_index = dict()

    def __getitem__(self, data_subject_strings):
        """Return one prime per requested data subject, in request order.

        data_subject_strings: a single string or a list of strings. Unknown
        subjects are registered on the fly; a subject that appears more than
        once (in this call or in an earlier one) always maps to the same prime.

        Returns a numpy array with len(data_subject_strings) entries.
        Raises Exception for any other input type.
        """

        if isinstance(data_subject_strings, str):
            data_subject_strings = [data_subject_strings]

        if not isinstance(data_subject_strings, list):
            raise Exception("don't know how to handle this input type")

        subject2index = self._data_subject2prime_index
        registered = self.prime_index2data_subjects

        prime_indices = []
        for subject in data_subject_strings:
            index = subject2index.get(subject)
            if index is None:
                # first sighting: hand out the next unused prime index
                index = len(registered)
                subject2index[subject] = index
                registered.append(subject)
                self.data_subjects_set.add(subject)
            prime_indices.append(index)

        if not prime_indices:
            return np.asarray(self.prime_factory[0:0])

        # fetch only the contiguous window of primes that covers the request,
        # then pick out the requested entries relative to the window start
        first = min(prime_indices)
        stop = max(prime_indices) + 1
        window = np.asarray(self.prime_factory[first:stop])
        return window[np.asarray(prime_indices) - first]

In [196]:
class lazyrepeatarray:
    """when data is repeated along one or more dimensions, store it using this lazyrepeatarray so that
    you can save on RAM and CPU when computing with it."""

    def __init__(self, data, shape):
        """
        data: the raw data values without repeats
        shape: the shape of 'data' if repeats were included
        """

        # NOTE: all additional arguments are assumed to be broadcast if dims are shorter
        # than that of data. Example: if data.shape == (2,3,4) and min_vals.shape == (2,3),
        # then it's assumed that the full min_vals.shape is actually (2,3,4) where the last
        # dim is simply copied. Example2: if data.shape == (2,3,4) and min_vals.shape == (2,1,4),
        # then the middle dimension is supposed to be copied to be min_vals.shape == (2,3,4) if
        # necessary. This is just to keep the memory footprint (and computation) as small as
        # possible.

        # python scalars and plain sequences have no .shape/.size, which the
        # arithmetic and sum() below rely on, so wrap them in an ndarray
        if isinstance(data, (bool, int, float, list, tuple)):
            data = np.array(data)

        self.data = data
        self.shape = shape

    def __sub__(self, other):
        """
        THIS MIGHT LOOK LIKE COPY-PASTED CODE!
        Don't touch it. It's going to get more complicated.

        Elementwise difference; both operands need the same logical shape and
        the same stored data shape.
        """
        if self.shape != other.shape:
            raise Exception("cannot subtract tensors with different shapes")

        if self.data.shape == other.data.shape:
            return lazyrepeatarray(data = self.data - other.data, shape=self.shape)

        raise Exception("not sure how to do this yet")

    def __mul__(self, other):
        """
        THIS MIGHT LOOK LIKE COPY-PASTED CODE!
        Don't touch it. It's going to get more complicated.

        Elementwise product; both operands need the same logical shape and
        the same stored data shape.
        """
        if self.shape != other.shape:
            raise Exception("cannot multiply tensors with different shapes")

        if self.data.shape == other.data.shape:
            return lazyrepeatarray(data = self.data * other.data, shape=self.shape)

        raise Exception("not sure how to do this yet")

    def __pow__(self, exponent):
        """Square the array; any exponent other than 2 is not supported yet."""
        if exponent == 2:
            return self * self
        raise Exception("not sure how to do this yet")

    def simple_assets_for_serde(self):
        """Return [data, shape], the two values needed to rebuild this object."""
        return [self.data, self.shape]

    @staticmethod
    def deserialize_from_simple_assets(assets):
        """Rebuild an instance from the list produced by simple_assets_for_serde."""
        return lazyrepeatarray(data=assets[0],
                              shape=assets[1])

    @property
    def size(self):
        """Number of elements implied by the logical shape."""
        return np.prod(self.shape)

    def sum(self, axis=None):
        """Sum over all elements; only a single stored value and axis=None are supported.

        Returns a one-element array holding value * size.
        """
        if axis is None:
            if self.data.size == 1:
                return np.array(self.data * self.size).flatten()
            else:
                raise Exception("not sure how to do this yet")
        else:
            raise Exception("not sure how to do this yet")

# Example: the single value 2 with a logical shape of (1000000, 4, 4); only the
# scalar and the shape tuple are stored.
x = lazyrepeatarray(2, shape=(1000000,4,4))

In [197]:
class array:
    """Thin wrapper around raw array data whose private() method attaches
    value bounds and data subjects."""

    def __init__(self, data: ArrayLike):
        """
        data: anything array-like; objects without a .shape attribute (lists,
              tuples, python scalars) are converted to a numpy array
        """

        if not hasattr(data, "shape"):
            data = np.array(data)

        self.child = data

    @property
    def shape(self):
        """Shape of the wrapped data."""
        return self.child.shape

    def _as_bound(self, vals):
        """Return vals unchanged if it already matches the child's shape, else
        wrapped in a lazyrepeatarray that carries the child's shape."""
        if not hasattr(vals, "shape"):
            vals = np.array(vals)

        if vals.shape != self.child.shape:
            vals = lazyrepeatarray(vals, self.child.shape)

        return vals

    def private(self,
                min_vals: ArrayLike,
                max_vals: ArrayLike,
                data_subjects: Union[str,List[str]]):
        """Attach bounds and data subjects to the wrapped data.

        min_vals / max_vals: lower and upper bounds, either matching the data's
            shape or a smaller value that is stored as a lazyrepeatarray
        data_subjects: one name (whole tensor belongs to one subject) or one
            name per row of the data

        Returns a pharray for a single subject and a rowarray for one subject
        per row; raises Exception for any other count.

        NOTE: relies on the notebook-level global `dsr` (a DataSubjectRegistry)
        being defined before this method is called.
        """

        ####################################################################
        # Step 1: Ensure data_subjects are ready to used to create phi array
        ####################################################################

        # if it's a string - convert to list of length 1
        if isinstance(data_subjects, str):
            data_subjects = [data_subjects]

        # for every data subject, get its affiliated prime
        data_subjects = dsr[data_subjects]

        ######################################################################
        # Step 2: Store min_vals and max_vals as lazyarray objects if possible
        ######################################################################

        min_vals = self._as_bound(min_vals)
        max_vals = self._as_bound(max_vals)

        # entire tensor refers to one entity
        if len(data_subjects) == 1:
            return pharray(data=self.child,
                           min_vals=min_vals,
                           max_vals=max_vals,
                           data_subject_prime=data_subjects)

        # each row corresponds to a unique entity
        elif len(data_subjects) == len(self.child):
            return rowarray(rows=self.child,
                            min_vals=min_vals,
                            max_vals=max_vals,
                            data_subjects=data_subjects,
                            row_type=pharray)
        else:
            raise Exception("not sure how to initialize")
        
# Publish the wrapper on the syft module so later cells can write sy.array(...)
sy.array = array

In [198]:
class pharray:
    """Data that belongs to a single data subject, stored together with its
    per-value bounds and the subject's prime."""

    def __init__(self,
                 data: np.ndarray,
                 min_vals: np.ndarray,
                 max_vals: np.ndarray,
                 data_subject_prime: np.ndarray,
                 ignore_minmax_checks=True):
        """
        data: the raw values
        min_vals / max_vals: lower and upper bounds for the values
        data_subject_prime: the prime identifying the data subject
        ignore_minmax_checks: when False, assert that data lies within the bounds

        Python scalars passed for the bounds or the prime are wrapped in
        0-d numpy arrays.
        """

        if isinstance(min_vals, (bool, int, float)):
            min_vals = np.array(min_vals)

        if isinstance(max_vals, (bool, int, float)):
            max_vals = np.array(max_vals)

        if isinstance(data_subject_prime, (bool, int, float)):
            data_subject_prime = np.array(data_subject_prime)

        if not ignore_minmax_checks:
            assert (data >= min_vals).all()
            assert (data <= max_vals).all()

        self.data = data
        self.minv = min_vals
        self.maxv = max_vals
        self.sub = data_subject_prime

    @property
    def gamma(self):
        """Convert to a gmarray, mirroring how rowarray.gamma builds one.

        The result carries the data's own shape and is marked linear.
        """
        # use this instance's subject prime (not a notebook global) and pass the
        # shape / is_linear arguments that gmarray requires
        return gmarray(input2value=self.data,
                       input2minval=self.minv,
                       input2maxval=self.maxv,
                       input2subjectprime=self.sub,
                       shape=self.data.shape,
                       is_linear=True)

In [199]:
class lazyprimearray:
    """Describes an array of primes by a prime-index range plus a shape; the
    code here only stores that description and never materialises the primes."""

    def __init__(self, start, stop, shape):
        """ an array made of primes which is lazily evaluated

        start: the prime index this array begins with
        stop: the prime index this array ends with
        shape: the shape of the array
        """
        self._data_cache = None
        self.shape = shape
        self.start, self.stop = start, stop

    def reshape(self, *new_shape):
        """Return a new lazyprimearray over the same index range with new_shape.

        The dimensions are passed as separate arguments. Raises Exception when
        the element count of new_shape differs from the current one.
        """
        if np.prod(self.shape) != np.prod(new_shape):
            raise Exception("New shape not compatible")

        return lazyprimearray(start=self.start, stop=self.stop, shape=new_shape)
    
    

In [200]:
        
class rowarray:
    """Tensor in which every row belongs to its own data subject."""

    def __init__(self,
                 rows,
                 min_vals,
                 max_vals,
                 data_subjects,
                 row_type=pharray):
        """
        rows: the raw data, one row per data subject
        min_vals / max_vals: lower and upper bounds for the values
        data_subjects: the primes of the data subjects, one per row
        row_type: the class each row is meant to be (only pharray is handled)
        """

        self.rows = rows
        self.minv = min_vals
        self.maxv = max_vals
        self.subs = data_subjects
        self.row_type = row_type

    def sum(self, axis=None):
        """Sum by first converting to a gmarray and delegating to its sum()."""
        return self.gamma.sum(axis=axis)

    def serialize(self):
        """Pack all attributes into a pyarrow buffer.

        The bounds are flattened through simple_assets_for_serde(), so both are
        expected to be lazyrepeatarray objects here.
        NOTE(review): pyarrow.serialize may be deprecated or removed in newer
        pyarrow releases - confirm against the pinned version.
        """
        assets = [self.rows,
                  self.minv.simple_assets_for_serde(),
                  self.maxv.simple_assets_for_serde(),
                  self.subs,
                  self.row_type]
        return pyarrow.serialize(assets).to_buffer()

    @staticmethod
    def deserialize(blob):
        """Rebuild a rowarray from a buffer produced by serialize()."""
        assets = pyarrow.deserialize(blob)
        rows = assets[0]
        minv = lazyrepeatarray.deserialize_from_simple_assets(assets[1])
        maxv = lazyrepeatarray.deserialize_from_simple_assets(assets[2])
        subs = assets[3]
        row_type = assets[4]
        return rowarray(rows=rows,
                        min_vals=minv,
                        max_vals=maxv,
                        data_subjects=subs,
                        row_type=row_type)

    @property
    def shape(self):
        """Shape of the underlying rows."""
        return self.rows.shape

    @property
    def gamma(self):
        """Convert to a linear gmarray with the same shape.

        Raises Exception when row_type is anything other than pharray.
        """
        if self.row_type == pharray:
            return gmarray(input2value=self.rows,
                           input2minval=self.minv,
                           input2maxval=self.maxv,
                           input2subjectprime=self.subs,
                           shape=self.shape,
                           is_linear=True)
        else:
            raise Exception("Sorry don't know how to convert this to gamma yet.")
            
            
class gmarray:
    """Tensor that keeps the original inputs (values, bounds, subject primes)
    next to the polynomial bookkeeping (term, coeff, bias) derived from them."""

    def __init__(self,
                 input2value,
                 input2minval,
                 input2maxval,
                 input2subjectprime,
                 shape,
                 is_linear,
                 input2scalarprime=None,
                 input2scalarprime_id=None,
                 value_cache=None,
                 minval_cache=None,
                 maxval_cache=None,
                 term=None,
                 coeff=None,
                 bias=None):
        """
        input2value / input2minval / input2maxval: per-input values and bounds
        input2subjectprime: per-input data subject primes
        shape: the shape of this tensor
        is_linear: whether the tensor is linear in its inputs (forced to True
                   when input2scalarprime_id is None)
        input2scalarprime_id: identifier of the input lookup tables; None means
                   the tensor is being created for the first time
        value_cache / minval_cache / maxval_cache: cached results; replaced by
                   the inputs themselves on first creation
        term / coeff / bias: polynomial parts, None meaning the lazy defaults
        """

        # REPLACING SCALAR MANAGER ARE THE FOLLOWING NDARRAY LOOKUP TABLES
        self.input2value = input2value
        self.input2minval = input2minval
        self.input2maxval = input2maxval
        self.shape = shape

        if input2subjectprime.shape == self.input2value.shape:
            # add a trailing axis of length 1; .shape is a tuple, so the
            # extension has to be a tuple as well (tuple + list is a TypeError)
            input2subjectprime = input2subjectprime.reshape(tuple(input2subjectprime.shape) + (1,))

        # if an integer, it's assumed to be elementwise
        self.input2subjectprime = input2subjectprime

        # None == elementwise, unique primes for freshly created gammatensor, starting at 1
        self.input2scalarprime = input2scalarprime

        if input2scalarprime_id is None:
            # given no i2s id, ASSUME we're initializing this tensor for the first time!
            # which means all the caches are just copies of the data
            input2scalarprime_id = sy.core.common.UID()
            value_cache=input2value
            minval_cache=input2minval
            maxval_cache=input2maxval
            is_linear=True

        self.input2scalarprime_id = input2scalarprime_id

        self.value_cache = value_cache
        self.minval_cache = minval_cache
        self.maxval_cache = maxval_cache

        self.is_linear=is_linear

        # tensor of polynomial terms - primes representing variables
        # None == elementwise, unique primes for freshly created gammatensor, starting at 1
        self._term = term

        # a tensor of coefficients - the floats which multiply by variables in polys
        # None == np.ones_like(term)
        self._coeff = coeff

        # a tensor of bias terms - scalars which are added to polys
        # None == np.zeros_like(term)
        self._bias = bias

    def serialize(self):
        """Pack the lookup tables, caches and polynomial parts into a pyarrow buffer.

        input2scalarprime_id is not included (its line is commented out below).
        The bounds go through simple_assets_for_serde(), so both are expected
        to be lazyrepeatarray objects here.
        """
        assets = list()
        assets.append(self.input2value)
        assets.append(self.input2minval.simple_assets_for_serde())
        assets.append(self.input2maxval.simple_assets_for_serde())
        assets.append(self.shape)
        assets.append(self.input2subjectprime)
        assets.append(self.input2scalarprime)
#         assets.append(self.input2scalarprime_id)
        assets.append(self.value_cache)
        assets.append(self.minval_cache)
        assets.append(self.maxval_cache)
        assets.append(self._term)
        assets.append(self._coeff)
        assets.append(self._bias)
        return pyarrow.serialize(assets).to_buffer()

    @property
    def size(self):
        """Number of elements implied by shape."""
        return np.prod(self.shape)

    @property
    def term(self):
        """Polynomial terms; created on first access as a lazyprimearray with one prime per element."""
        if self._term is None:
            self._term = lazyprimearray(start=0, stop=np.prod(self.shape), shape=list(self.shape)+[1])
        return self._term

    @property
    def coeff(self):
        """Coefficients; created on first access as an all-ones lazyrepeatarray."""
        if self._coeff is None:
            self._coeff = lazyrepeatarray(data=1, shape=self.shape)
        return self._coeff

    @property
    def bias(self):
        """Bias terms; created on first access as an all-zeros lazyrepeatarray."""
        if self._bias is None:
            self._bias = lazyrepeatarray(data=0, shape=self.shape)
        return self._bias

    def sum(self, axis=None):
        """Sum over all elements, returning a new gmarray of shape ().

        The input lookup tables are shared with the result, the caches are
        summed and the terms are reshaped to a single row. Only axis=None is
        supported; anything else raises Exception.
        """
        if axis is None:
            return gmarray(input2value = self.input2value,
                            input2minval = self.input2minval,
                            input2maxval = self.input2maxval,
                            input2subjectprime = self.input2subjectprime,
                            shape=(),
                            is_linear=self.is_linear,
                            input2scalarprime = self.input2scalarprime,
                            input2scalarprime_id = self.input2scalarprime_id,
                            value_cache=self.value_cache.sum(),
                            minval_cache=self.minval_cache.sum(),
                            maxval_cache=self.maxval_cache.sum(),
                            term=self.term.reshape(1,self.size),
                            coeff = None if self._coeff is None else self.coeff.reshape(1, self.size),
                            bias = None if self._bias is None else self.bias.sum())

        else:
            raise Exception("Not sure how to run this yet")

    def deriv(self, inputs, input_mask=None):
        """Derivative for the given inputs.

        inputs: must have the same shape as input2value
        input_mask: 1s correspond to data passed in and 0s to values from
                    self.input2value (currently not used in the computation)

        Only the linear case with default coefficients is handled, which
        returns ones of this tensor's shape; everything else raises Exception.
        """

        assert inputs.shape == self.input2value.shape

        # if someone doesn't pass in a mask we assume they
        # want to use all the inputs they're passing in
        # NOTE(review): zeros contradict the "1s == data passed in" convention
        # described above - confirm before the mask is actually used
        if input_mask is None:
            input_mask = np.zeros_like(inputs)

        if self.is_linear:

            # TODO: lazyarray should know how to find the max coeff very
            # efficient instead of needing to hardcode this here
            if self._coeff is None:
                return np.ones(self.shape)

        raise Exception("Ooops... can't compute deriv of this yet...")

    def max_deriv(self, inputs, input_mask=None):
        """Maximum derivative for the given inputs.

        inputs: must have the same shape as input2value
        input_mask: 1s correspond to data passed in and 0s to values from
                    self.input2value (currently not used in the computation)

        Only the linear case with default coefficients is handled, which
        returns ones of this tensor's shape; everything else raises Exception.
        """

        assert inputs.shape == self.input2value.shape

        # if someone doesn't pass in a mask we assume they
        # want to use all the inputs they're passing in
        # NOTE(review): zeros contradict the "1s == data passed in" convention
        # described above - confirm before the mask is actually used
        if input_mask is None:
            input_mask = np.zeros_like(inputs)

        if self.is_linear:

            # TODO: lazyarray should know how to find the max coeff very
            # efficient instead of needing to hardcode this here
            if self._coeff is None:
                return np.ones(self.shape)

        raise Exception("Ooops... can't compute max deriv of this yet...")
        
#     def max_deriv_wrt_entity(self, entity_prime):

In [201]:
# %%timeit
# five data subjects named "0".."4", and one random value per subject
data_subjects = [str(subject_id) for subject_id in range(5)]
data = sy.array(np.random.rand(5, 1))

In [202]:
# `dsr` must exist as a global before private() runs: array.private looks the
# data subjects up through it.
dsr = DataSubjectRegistry()
# One data subject per row of `data`, so private() takes its per-row branch;
# the scalar bounds 0 and 1 are stored lazily rather than expanded.
x = data.private(min_vals=0, 
                 max_vals=1, 
                 data_subjects=data_subjects)

In [203]:
# Sum over every element; sum() goes through the gamma representation and
# yields a tensor of shape ().
out = x.sum()

In [204]:
# Derivative of the summed tensor; with no explicit coefficients the linear
# path returns ones of the output's shape (see the recorded result below).
out.deriv(data)

array(1.)

In [205]:
# Maximum derivative; currently takes the same linear shortcut as deriv().
out.max_deriv(data)

array(1.)

In [206]:
from syft.core.adp.idp_gaussian_mechanism import iDPGaussianMechanism

In [211]:
# Inputs for the Gaussian mechanism built in a later cell.
# NOTE(review): presumably sigma is the noise scale - confirm against syft.
sigma = 3
# sum of the squared input values
squared_l2_norm = np.sum(out.input2value**2)
# sum of the squared (max - min) ranges, computed on the stored bound objects
squared_l2_norm_upper_bound = ((out.input2maxval - out.input2minval)**2).sum()
# maximum derivative of the output, passed to the mechanism as L
L = out.max_deriv(data)

In [213]:
# def max_deriv_wrt_entity(self, entity_prime):

In [None]:
# NOTE(review): unfinished, never-executed cell - the right-hand side is missing,
# so this line does not parse. Presumably it should select one data subject's
# prime for the commented-out max_deriv_wrt_entity idea - confirm or remove.
entity_prime = 

In [210]:
# Build syft's iDPGaussianMechanism from the quantities computed in the cell
# that defines sigma, the two norms and L (run that cell first).
# NOTE(review): "dunno" is a placeholder entity name - replace with the real
# data subject once per-entity handling exists.
m = iDPGaussianMechanism(sigma=sigma,
                         squared_l2_norm=squared_l2_norm,
                         squared_l2_norm_upper_bound=squared_l2_norm_upper_bound,
                         L=L,
                         entity_name="dunno")

In [26]:
import jax
import jax.numpy as jnp
def f(x):
    """Return the sum of the squared entries of x."""
    squared = x ** 2  # same operator syntax as with numpy arrays
    return jnp.sum(squared)
# Small JAX autodiff demo: differentiate f and evaluate both at a sample point.
# NOTE: this rebinds the notebook-global name `x` to a JAX array.
grad_f = jax.grad(f) # compute the gradient function
x = jnp.array([0., 1., 2.]) # use JAX arrays!
print('x: ', x)
print('f(x): ', f(x))
# for f(x) = sum(x**2) the gradient is 2*x, matching the recorded output below
print('grad_f(x):', grad_f(x))

x:  [0. 1. 2.]
f(x):  5.0
grad_f(x): [0. 2. 4.]
