In [198]:
import names
from syft.core.common import UID
from sympy import symbols
from scipy import optimize
import sympy as sym
import numpy as np
import random
from sympy.solvers import solve

from functools import lru_cache

# ordered_symbols = list()
# for i in range(100):
#     ordered_symbols.append(symbols("s"+str(i)))

@lru_cache(maxsize=None)
def maximize_flattened_poly(flattened_poly, *rranges, force_all_searches=False, **s2i):

    search_fun = create_searchable_function(flattened_poly, s2i)
    
    return minimize_function(f=search_fun, rranges=rranges, force_all_searches=False)

def flatten_and_maximize_poly(poly, force_all_searches=False):

    i2s = list(poly.free_symbols)
    s2i = {s:i for i,s in enumerate(i2s)}
    
    # this code seems to make things slower - although there might be a memory improvement (i haven't checked)
#     flattened_poly = poly.copy().subs({k:v for k,v in zip(i2s, ordered_symbols[0:len(i2s)])})
#     flattened_s2i = {str(ordered_symbols[i]):i for s,i in s2i.items()}

    flattened_poly = poly
    flattened_s2i = {str(s):i for s,i in s2i.items()}    

    rranges = [(ssid2obj[i2s[i].name].min_val, ssid2obj[i2s[i].name].max_val) for i in range(len(s2i))]

    return maximize_flattened_poly(flattened_poly, *rranges, force_all_searches=force_all_searches, **flattened_s2i)

def create_lookup_tables_for_symbol(polynomial):

    index2symbol = [str(x) for x in polynomial.free_symbols]
    symbol2index = {sym: i for i, sym in enumerate(index2symbol)}

    return index2symbol, symbol2index

def create_searchable_function(f, symbol2index):

        # Tudor: Here you weren't using *params
        # Tudor: If I understand correctly, .subs returns
        def _run_specific_args(tuple_of_args: tuple):
            kwargs = {sym: tuple_of_args[i] for sym, i in symbol2index.items()}
            output =  f.subs(kwargs)
            return output

        return _run_specific_args

def minimize_function(f, rranges, constraints=[], force_all_searches=False):
    
    results = list()
    
    # Step 1: try simplicial
    shgo_results = optimize.shgo(f, rranges, sampling_method='simplicial', constraints=constraints)
    results.append(shgo_results)
    
    if not shgo_results.success or force_all_searches:
        # sometimes simplicial has trouble as a result of initialization
        # see: https://github.com/scipy/scipy/issues/10429 for details
#         if not force_all_searches:
#             print("Simplicial search didn't work... trying sobol")
        shgo_results = optimize.shgo(f, rranges, sampling_method='sobol', constraints=constraints)
        results.append(shgo_results)
        
    if not shgo_results.success:
        raise Exception("Search algorithm wasn't solvable... abort")
        
    return results

def max_lipschitz_wrt_entity(scalars, entity):
    result = max_lipschitz_via_jacobian(scalars, input_entity=entity)[0][-1]
    if isinstance(result, float):
        return -result
    else:
        return -float(result.fun)

def max_lipschitz_via_jacobian(scalars, input_entity=None, data_dependent=True, force_all_searches=False, try_hessian_shortcut=False):
        
        polys = [x.poly for x in scalars]
        input_scalars = set()
        for s in scalars:
            for i_s in s.input_scalars:
                input_scalars.add(i_s)
        
        # the numberator of the partial derivative
        out = sym.Matrix([x.poly for x in scalars])
        

        if input_entity is None:
            j = out.jacobian([x.poly for x in input_scalars])
        else:
            
            # In general it doesn't make sense to consider the max partial derivative over all inputs because we dont' want the 
            # Lipschiptz bound of the entire jacobian, we want the lipschitz bound with respect to entity "i" (see https://arxiv.org/abs/2008.11193).
            # For example, if I had a polynomial y = a + b**2 + c**3 + d**4 where each a,b,c,d variable was from a different entity, 
            # the fact taht d has a big derivative should change the max lipschitz bound of y with respect to "a". Thus, we're only interested
            # in searching for the maximum partial derivative with respect to the variables from the focus entity "i".
            
            # And if we're looking to compute the max parital derivative with respect to input scalars from only one entity, then
            # we select only the variables corresponding to that entity here. 
            relevant_scalars = list(filter(lambda s:s.entity == input_entity, input_scalars))
            relevant_inputs = [x.poly for x in relevant_scalars]
            j = out.jacobian(relevant_inputs)
            
            # for higher order functions - it's possible that some of the partial derivatives are conditioned
            # on data from the input entity. The philosophy of input DP is that when producing an epsilon
            # guarantee for entity[i] that we don't need to search over the possible range of data for that entity
            # but can instead use the data itself - this results in an epsilon for each entity which is private
            # but it also means the bound is tighter. So we could skip this step but it would in some cases
            # make the bound looser than it needs to be.
            if data_dependent:
                j = j.subs({x.poly:x.value for x in relevant_scalars})

        neg_l2_j = -(np.sum(np.square(j)))**0.5
        
        if(len(np.sum(j).free_symbols) == 0):
            result = -float(np.max(j))
            return [result], neg_l2_j
        
        if(try_hessian_shortcut):
            h = j.jacobian([x.poly for x in input_scalars])
            if(len(solve(np.sum(h**2), *[x.poly for x in input_scalars], dict=True)) == 0):
                print("The gradient is linear - solve through brute force search over edges of domain")

                i2s,s2i = create_lookup_tables_for_symbol(neg_l2_j)
                search_fun = create_searchable_function(f=neg_l2_j, symbol2index=s2i)        

                constant = 0.000001
                rranges = [(x.min_val, x.max_val, x.max_val - x.min_val) for x in input_scalars]
                skewed_results = optimize.brute(search_fun, rranges, finish=None, full_output=False)
                result_inputs = skewed_results + constant
                result_output = search_fun(result_inputs)
                return [float(result_output)], neg_l2_j
        
        return flatten_and_maximize_poly(neg_l2_j), neg_l2_j    
    

def get_mechanism_for_entity(scalars, entity, sigma=1.5):
    
    m_id = "ms_"
    for s in scalars:
        m_id += str(s.id).split(" ")[1][:-1]+"_"
    
    return iDPGaussianMechanism(sigma=sigma,
                                value=np.sqrt(np.sum(np.square(np.array([float(s.value) for s in scalars])))),
                                L=float(max_lipschitz_wrt_entity(scalars, entity=entity)),
                                entity=entity.unique_name,
                                name=m_id)

def get_all_entity_mechanisms(scalars, sigma:float=1.5):
    entities = set()
    for s in scalars:
        for i_s in s.input_scalars:
            entities.add(i_s.entity)
    return {e:[get_mechanism_for_entity(scalars=scalars,entity=e,sigma=sigma)] for e in entities}



def publish(scalars, acc, sigma: float = 1.5) -> float:

    acc_original = acc

    acc_temp = deepcopy(acc_original)

    ms = get_all_entity_mechanisms(scalars=scalars, sigma=sigma)
    acc_temp.append(ms)

    overbudgeted_entities = acc_temp.overbudgeted_entities

    # so that we don't modify the original polynomial
    # it might be fine to do so but just playing it safe
    if len(overbudgeted_entities) > 0:
        scalars = deepcopy(scalars)

    while len(overbudgeted_entities) > 0:

        input_scalars = set()
        for s in scalars:
            for i_s in s.input_scalars:
                input_scalars.add(i_s)

        for symbol in input_scalars:
            if symbol.entity in overbudgeted_entities:
                self.poly = self.poly.subs(symbol.poly, 0)
                break

        acc_temp = deepcopy(acc_original)

        # get mechanisms for new publish event
        ms = self.get_all_entity_mechanisms(sigma=sigma)
        acc_temp.append(ms)

        overbudgeted_entities = acc_temp.overbudgeted_entities

    output = [s.value + random.gauss(0, sigma) for s in scalars]

    acc_original.entity2ledger = deepcopy(acc_temp.entity2ledger)

    return output


class Scalar():
        
    def publish(self, acc, sigma: float = 1.5) -> float:
        return publish([self], acc=acc, sigma=sigma)
    
    
    def __str__(self) -> str:
        return "<"+str(type(self).__name__) + ": (" + str(self.min_val)+" < "+str(self.value)+" < " + str(self.max_val) + ")>"

    def __repr__(self) -> str:
        return str(self)

class IntermediateScalar(Scalar):
    
    def __init__(self, poly, id=None):
        self.poly = poly
        self._gamma = None
        self.id = id if id else UID()        
    
    def __rmul__(self, other: "Scalar") -> "Scalar":
        return self * other

    def __radd__(self, other: "Scalar") -> "Scalar":
        return self + other
    
    @property
    def input_scalars(self):
        phi_scalars = list()
        for ssid in self.input_polys:
            phi_scalars.append(ssid2obj[str(ssid)])
        return phi_scalars
    
    @property
    def input_entities(self):
        return list(set([x.entity for x in self.input_scalars]))
    
    @property
    def input_polys(self):
        return self.poly.free_symbols
    
    @property
    def max_val(self):
        return -flatten_and_maximize_poly(-self.poly)[-1].fun
    @property
    def min_val(self):
        return flatten_and_maximize_poly(self.poly)[-1].fun

    @property
    def value(self):
        return self.poly.subs({obj.poly:obj.value for obj in self.input_scalars})
    
class IntermediatePhiScalar(IntermediateScalar):

    def __init__(self, poly, entity):
        super().__init__(poly=poly)
        self.entity = entity
    
    def max_lipschitz_wrt_entity(self, *args, **kwargs):
        return self.gamma.max_lipschitz_wrt_entity(*args, **kwargs)
    
    @property
    def max_lipschitz(self):
        return self.gamma.max_lipschitz
    
    def __mul__(self, other: "Scalar") -> "Scalar":

        if isinstance(other, IntermediateGammaScalar):          
            return self.gamma * other
        
        if not isinstance(other, IntermediatePhiScalar):
            return IntermediatePhiScalar(poly=self.poly * other, entity=self.entity)
            
        # if other is referencing the same individual
        if self.entity == other.entity:
            return IntermediatePhiScalar(poly=self.poly * other.poly, entity=self.entity)

        return self.gamma * other.gamma

    def __add__(self, other: "Scalar") -> "Scalar":

        if isinstance(other, IntermediateGammaScalar):
            return self.gamma + other
        
        # if other is a public value
        if not isinstance(other, Scalar):
            return IntermediatePhiScalar(poly=self.poly + other, entity=self.entity)
        
        # if other is referencing the same individual
        if self.entity == other.entity:
            return IntermediatePhiScalar(poly=self.poly + other.poly, entity=self.entity)
        
        return self.gamma + other.gamma
    
    
    def __sub__(self, other: "Scalar") -> "Scalar":

        if isinstance(other, IntermediateGammaScalar):
            return self.gamma - other
        
        # if other is a public value
        if not isinstance(other, IntermediatePhiScalar):
            return IntermediatePhiScalar(poly=self.poly - other, entity=self.entity)

        # if other is referencing the same individual
        if self.entity == other.entity:
            return IntermediatePhiScalar(poly=self.poly - other.poly, entity=self.entity)

        return self.gamma - other.gamma
    
    @property
    def gamma(self):
        
        if self._gamma is None:
            self._gamma = GammaScalar(min_val=self.min_val,
                               value=self.value,
                               max_val=self.max_val,
                               entity=self.entity)
        return self._gamma
    
class OriginScalar(Scalar):
    """A scalar which stores the root polynomial values. When this is a superclass of
    PhiScalar it represents data that was loaded in by a data owner. When this is a superclass
    of GammaScalar this represents the node at which point data from mulitple entities was combined."""
    
    def __init__(self, min_val, value, max_val, entity=None, id=None):

        self.id = id if id else UID()
        self._value = value
        self._min_val = min_val
        self._max_val = max_val
        self.entity = entity if entity is not None else Entity()
        
    @property
    def value(self):
        return self._value
        
    @property
    def max_val(self):
        return self._max_val
    
    @property
    def min_val(self):
        return self._min_val
    
class PhiScalar(OriginScalar, IntermediatePhiScalar):
    """A scalar over data from a single entity"""
    
    def __init__(self, min_val, value, max_val, entity=None, id=None, ssid=None):
        super().__init__(min_val=min_val, value=value, max_val=max_val, entity=entity,id=id)
        
        # the scalar string identifier (SSID) - because we're using polynomial libraries
        # we need to be able to reference this object in string form. the library doesn't
        # know how to process things that aren't strings
        if ssid is None:
            ssid = str(self.id).split(" ")[1][:-1]# + "_" + str(self.entity.id).split(" ")[1][:-1]
            
        self.ssid = ssid
        
        IntermediatePhiScalar.__init__(self, poly=symbols(self.ssid), entity=self.entity)
        
        ssid2obj[self.ssid] = self
    
    
class IntermediateGammaScalar(IntermediateScalar):
    """"""
    
    def __add__(self, other):
        if isinstance(other, Scalar):
            if isinstance(other, IntermediatePhiScalar):
                other = other.gamma
            return IntermediateGammaScalar(poly=self.poly + other.poly)
        return IntermediateGammaScalar(poly=self.poly + other)
    
    def __sub__(self, other):
        if isinstance(other, Scalar):
            if isinstance(other, IntermediatePhiScalar):
                other = other.gamma
            return IntermediateGammaScalar(poly=self.poly - other.poly)
        return IntermediateGammaScalar(poly=self.poly - other)    
    
    def __mul__(self, other):
        if isinstance(other, Scalar):
            if isinstance(other, IntermediatePhiScalar):
                other = other.gamma
            return IntermediateGammaScalar(poly=self.poly * other.poly)
        return IntermediateGammaScalar(poly=self.poly * other)    
    
    def max_lipschitz_via_explicit_search(self, force_all_searches=False):

        r1 = np.array([x.poly for x in self.input_scalars])

        r2_diffs = np.array([GammaScalar(x.min_val,x.value,x.max_val, entity=x.entity).poly for x in self.input_scalars])
        r2 = r1 + r2_diffs

        fr1 = self.poly
        fr2 = self.poly.copy().subs({x[0]:x[1] for x in list(zip(r1, r2))})

        left = np.sum(np.square(fr1 - fr2)) ** 0.5
        right = np.sum(np.square(r1 - r2)) ** 0.5

        C = -left / right

        i2s, s2i = create_lookup_tables_for_symbol(C)
        search_fun = create_searchable_function(C, s2i)

        r1r2diff_zip = list(zip(r1, r2_diffs))

        s2range = {}
        for _input_scalar, _additive_counterpart in r1r2diff_zip:

            input_scalar = ssid2obj[_input_scalar.name]
            additive_counterpart = ssid2obj[_additive_counterpart.name]

            s2range[input_scalar.ssid] = (input_scalar.min_val, input_scalar.max_val)
            s2range[additive_counterpart.ssid] = (input_scalar.min_val, input_scalar.max_val)

        rranges = list()
        for index,symbol in enumerate(i2s):
            rranges.append(s2range[symbol])

        r2_indices_list = list()
        min_max_list = list()
        for r2_val in r2:
            r2_syms = [ssid2obj[x.name] for x in r2_val.free_symbols]
            r2_indices = [s2i[x.ssid] for x in r2_syms]

            r2_indices_list.append(r2_indices)
            min_max_list.append((r2_syms[0].min_val, r2_syms[0].max_val))

        functions = list()
        for i in range(2):
            f1 = lambda x,i=i: x[r2_indices_list[i][0]]+x[r2_indices_list[i][1]] + min_max_list[i][0]
            f2 = lambda x,i=i: -(x[r2_indices_list[i][0]]+x[r2_indices_list[i][1]]) + min_max_list[i][1]

            functions.append(f1)
            functions.append(f2)

        constraints = [{'type':'ineq', 'fun':f} for f in functions]

        def non_negative_additive_terms(symbol_vector):
            out = 0
            for index in [s2i[x.name] for x in r2_diffs]:
                out += (symbol_vector[index]**2)
            # theres a small bit of rounding error from this constraint - this should
            # only be used as a double check or as a backup!!!
            return out**0.5 - 1/2**16 

        constraints.append({'type':'ineq', 'fun':non_negative_additive_terms})
        results = minimize_function(f=search_fun, rranges=rranges, constraints=constraints, force_all_searches=force_all_searches)
        
        return results, C

    def max_lipschitz_via_jacobian(self, input_entity=None, data_dependent=True, force_all_searches=False, try_hessian_shortcut=False):
        return max_lipschitz_via_jacobian(scalars=[self], input_entity=input_entity, data_dependent=data_dependent, force_all_searches=force_all_searches, try_hessian_shortcut=try_hessian_shortcut)  
    
    @property
    def max_lipschitz(self):
        result = self.max_lipschitz_via_jacobian()[0][-1]
        if isinstance(result, float):
            return -result
        else:
            return -float(result.fun)
    
    def max_lipschitz_wrt_entity(self, entity):
        result = self.max_lipschitz_via_jacobian(input_entity=entity)[0][-1]
        if isinstance(result, float):
            return -result
        else:
            return -float(result.fun)
    
class GammaScalar(OriginScalar, IntermediateGammaScalar):
    """A scalar over data from multiple entities"""
    
    def __init__(self, min_val, value, max_val, entity=None, id=None, ssid=None):
        super().__init__(min_val=min_val, value=value, max_val=max_val, entity=entity, id=id)
        
        # the scalar string identifier (SSID) - because we're using polynomial libraries
        # we need to be able to reference this object in string form. the library doesn't
        # know how to process things that aren't strings
        if ssid is None:
            ssid = str(self.id).split(" ")[1][:-1] + "_" + str(self.entity.id).split(" ")[1][:-1]
            
        self.ssid = ssid
        
        IntermediateGammaScalar.__init__(self, poly=symbols(self.ssid))
        
        ssid2obj[self.ssid] = self
        


In [199]:
from syft.core.adp.adversarial_accountant import AdversarialAccountant
from syft.core.adp.entity import Entity
from copy import deepcopy
from syft.core.adp.idp_gaussian_mechanism import iDPGaussianMechanism

In [200]:
# stdlib
from typing import Dict as TypeDict
from typing import KeysView as TypeKeysView
from typing import List as TypeList
from typing import Set as TypeSet

# third party
from autodp.autodp_core import Mechanism
from autodp.transformer_zoo import Composition


class AdversarialAccountant:
    def __init__(self, max_budget: float = 10, delta: float = 1e-6) -> None:
        self.entity2ledger: TypeDict[Entity, Mechanism] = {}
        self.max_budget = max_budget
        self.delta = delta

    def append(self, entity2mechanisms: TypeDict[str, TypeList[Mechanism]]) -> None:
        for key, ms in entity2mechanisms.items():
            if key not in self.entity2ledger.keys():
                self.entity2ledger[key] = list()
            for m in ms:
                self.entity2ledger[key].append(m)

    def get_eps_for_entity(self, entity: Entity) -> Scalar:
        # compose them with the transformation: compose.
        compose = Composition()
        mechanisms = self.entity2ledger[entity]
        composed_mech = compose(mechanisms, [1] * len(mechanisms))

        # Query for eps given delta
        return PhiScalar(
            value=composed_mech.get_approxDP(self.delta),
            min_val=0,
            max_val=self.max_budget,
            entity=entity,
        )

    def has_budget(self, entity_name: str) -> bool:
        eps = self.get_eps_for_entity(entity_name)
        if eps.value is not None:
            return eps.value < self.max_budget

    @property
    def entities(self) -> TypeKeysView[str]:
        return self.entity2ledger.keys()

    @property
    def overbudgeted_entities(self) -> TypeSet[str]:
        entities = set()

        for ent in self.entities:
            if not self.has_budget(ent):
                entities.add(ent)

        return entities

    def print_ledger(self, delta: float = 1e-6) -> None:
        for entity, mechanisms in self.entity2ledger.items():
            print(str(entity) + "\t" + str(self.get_eps_for_entity(entity)._value))


In [227]:


x = PhiScalar(0,0.01,1)
y = PhiScalar(0,0.02,1)
z = PhiScalar(0,0.02,1)

o = x*x + y*y + z
z = o * o * o

In [244]:
for k in sym.class_registry.all_classes:
    if (isinstance(z.poly, k)):
        print(k)

<class 'sympy.core.basic.Basic'>
<class 'sympy.core.expr.Expr'>
<class 'sympy.core.power.Pow'>


In [250]:
from p

NameError: name 'function' is not defined

In [232]:
isinstance(z.poly, sym.Symbol)

False

In [230]:
type(z.poly)

sympy.core.power.Pow

In [203]:
z.max_lipschitz_via_explicit_search()

([     fun: -46.47853095774777
      funl: array([-46.47853096, -46.14593642, -46.14593642, -45.72380853])
   message: 'Optimization terminated successfully.'
      nfev: 395
       nit: 2
     nlfev: 366
     nlhev: 0
     nljev: 35
   success: True
         x: array([0.62855774, 0.62855775, 1.        , 1.        , 0.37144225,
         0.37144226])
        xl: array([[0.62855774, 0.62855775, 1.        , 1.        , 0.37144225,
          0.37144226],
         [0.64833262, 0.5       , 1.        , 1.        , 0.5       ,
          0.35166738],
         [0.5       , 0.64833265, 1.        , 1.        , 0.35166735,
          0.5       ],
         [0.5       , 0.5       , 1.        , 1.        , 0.5       ,
          0.5       ]])],
 -(4ae41313c3a941f4bc818d040bf5d1c6_7b7ecdd5de5d4f7da6215a8a97db5ab8**2 + 79584fa604334b038341109712bb3f32_f196c0224a6f4cc389c7f122ac2c9e8e**2 + dcad1295b5b345ab88856e2b91cc5146_16cf56ce29ef4408993ee71a0e51b9dd**2)**(-0.5)*(((69f93b3cedbf4087a607b592c7a0b711_f196

In [204]:
z.max_lipschitz_via_jacobian()

([     fun: -46.76537180435969
      funl: array([-46.7653718])
   message: 'Optimization terminated successfully.'
      nfev: 14
       nit: 2
     nlfev: 5
     nlhev: 0
     nljev: 1
   success: True
         x: array([1., 1., 1.])
        xl: array([[1., 1., 1.]])],
 -5.19615242270663*((69f93b3cedbf4087a607b592c7a0b711_f196c0224a6f4cc389c7f122ac2c9e8e + 7f150aa4dc774de59837714580f3656e_7b7ecdd5de5d4f7da6215a8a97db5ab8 + c73d8bc28f6b4c5881f976b8a657e61f_16cf56ce29ef4408993ee71a0e51b9dd)**4)**0.5)

In [192]:
%%timeit

acc = AdversarialAccountant(max_budget=10)
z.publish(acc=acc, sigma=0.2)
z2.publish(acc=acc, sigma=0.2)

13.9 ms ± 81.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [181]:
%%timeit
acc = AdversarialAccountant(max_budget=10)

publish([z,z2], acc=acc, sigma=0.2)

11 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [193]:
from syft.lib.autograd.value import to_values
from syft.lib.autograd.value import grad

def make_entities(n=100):
    ents = list()
    for i in range(n):
        ents.append(Entity(name=names.get_full_name().replace(" ", "_")))
    return ents


def private(input_data, min_val, max_val, entities=None, is_discrete=False):

    self = input_data
    if entities is None:
        flat_data = self.flatten()
        entities = make_entities(n=len(flat_data))

        scalars = list()
        for i in flat_data:
            value = max(min(float(i), max_val), min_val)
            s = Scalar(
                PhiScalar=value,
                min_val=min_val,
                max_val=max_val,
                entity=entities[len(scalars)],
#                 is_discrete=is_discrete
            )
            scalars.append(s)

        return to_values(np.array(scalars)).reshape(input_data.shape)

    elif isinstance(entities, list):
        if len(entities) == len(self):
            output_rows = list()
            for row_i, row in enumerate(self):
                row_of_entries = list()
                for item in row.flatten():
                    s = PhiScalar(
                        value=item,
                        min_val=min_val,
                        max_val=max_val,
                        entity=entities[row_i],
#                         is_discrete=is_discrete
                    )
                    row_of_entries.append(s)
                output_rows.append(np.array(row_of_entries).reshape(row.shape))
            return to_values(np.array(output_rows)).reshape(self.shape)
        else:
            print(len(entities))
            print(len(self))
            raise Exception("len(entities) must equal len(self)")


class Tensor(np.ndarray):
    def __new__(cls, input_array, min_val=None, max_val=None, entities=None, info=None, is_discrete=False):

        is_private = False

        if min_val is not None and max_val is not None:
            input_array = private(
                input_array, min_val=min_val, max_val=max_val, entities=entities, is_discrete=is_discrete
            )
            is_private = True
        else:
            input_array = to_values(input_array)

        obj = np.asarray(input_array).view(cls)
        obj.info = info
        obj.is_private = is_private

        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.info = getattr(obj, "info", None)
        self.is_private = getattr(obj, "is_private", None)

    def __array_wrap__(self, out_arr, context=None):
        output = out_arr.view(Tensor)

        is_private = False
        if context is not None:
            for arg in context[1]:
                if hasattr(arg, "is_private") and arg.is_private:
                    is_private = True
        output.is_private = is_private

        return output

    def backward(self):
        if self.shape == ():
            return grad(self.flatten()[0])
        else:
            raise Exception("Can only call .backward() on single-value tensor.")

    @property
    def grad(self):
        grads = list()
        for val in self.flatten().tolist():
            grads.append(val._grad)
        return Tensor(grads).reshape(self.shape)

    def slow_publish(self, **kwargs):
        grads = list()
        for val in self.flatten().tolist():
            grads.append(val.value.publish(**kwargs))
        return np.array(grads).reshape(self.shape)
    
    def publish(self, **kwargs):
        grads = list()
        for val in self.flatten().tolist():
            grads.append(val.value)
        grads = publish(scalars=grads, **kwargs)
        return np.array(grads).reshape(self.shape)    

    @property
    def value(self):
        values = list()
        for val in self.flatten().tolist():
            if hasattr(val.value, "value"):
                values.append(val.value.value)
            else:
                values.append(val.value)
        return np.array(values).reshape(self.shape)

    def private(self, min_val, max_val, entities=None, is_discrete=False):
        if self.is_private:
            raise Exception("Cannot call .private() on tensor which is already private")

        return Tensor(self.value, min_val=min_val, max_val=max_val, entities=entities, is_discrete=is_discrete)


In [194]:
acc = AdversarialAccountant(max_budget=3000000)

entities = [Entity(unique_name="Tudor"), Entity(unique_name="Madhava"), Entity(unique_name="Kritika"), Entity(unique_name="George")]

x = Tensor(np.array([[1,1],[1,0],[0,1],[0,0]])).private(min_val=0, max_val=1, entities=entities, is_discrete=True)
y = Tensor(np.array([[1],[1],[0],[0]])).private(min_val=0, max_val=1, entities=entities, is_discrete=False)

_weights = Tensor(np.random.uniform(size=(2,1)))

In [197]:
weights = _weights + 0
acc = AdversarialAccountant(max_budget=3000000)

for i in range(10):
    batch_loss = 0

    pred = x.dot(weights)
    loss = np.mean(np.square(y-pred))
    loss.backward()

    weight_grad = (weights.grad * 0.5)
    weight_grad = weight_grad.publish(acc=acc, sigma=0.1)

    weights = weights - weight_grad
    batch_loss += loss.value

    acc.print_ledger()
#     print(weights)

<Entity:George>	7.316032539740162
<Entity:Tudor>	7.316032539740162
<Entity:Madhava>	7.316032539740162
<Entity:Kritika>	7.316032539740162
<Entity:George>	9.520798620725053
<Entity:Tudor>	9.520798620725053
<Entity:Madhava>	9.520798620725053
<Entity:Kritika>	9.520798620725053
<Entity:George>	9.726133952597301
<Entity:Tudor>	9.726133952597301
<Entity:Madhava>	9.726133952597301
<Entity:Kritika>	9.726133952597301
<Entity:George>	10.043728493829402
<Entity:Tudor>	10.043728493829402
<Entity:Madhava>	10.043728493829402
<Entity:Kritika>	10.043728493829402
<Entity:George>	10.39935836365433
<Entity:Tudor>	10.39935836365433
<Entity:Madhava>	10.39935836365433
<Entity:Kritika>	10.39935836365433
<Entity:George>	10.413935084796428
<Entity:Tudor>	10.413935084796428
<Entity:Madhava>	10.413935084796428
<Entity:Kritika>	10.413935084796428
<Entity:George>	10.464816287127624
<Entity:Tudor>	10.464816287127624
<Entity:Madhava>	10.464816287127624
<Entity:Kritika>	10.464816287127624
<Entity:George>	11.4092937457

564 ms ± 4.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
