In [1]:
from syft import nn
from syft import PhiTensor
from syft import GammaTensor
from syft import DataSubjectList
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Try out the `nn.utils` helper functions

In [2]:
from jax import numpy as jnp

In [3]:
# Sanity check: NumPy's log applied to a JAX array of ones should be all zeros.
np.log(jnp.ones(5))



array([0., 0., 0., 0., 0.])

In [4]:
# log(e) should be 1.
np.log(np.e)

1.0

In [5]:
# Lookup with two values (0 and 1); the index array is ten zeros followed by
# ten ones, i.e. 20 entries in total.
dsl = DataSubjectList(
    one_hot_lookup=np.array([0, 1]),
    data_subjects_indexed=np.concatenate(
        (np.zeros(10), np.ones(10))
    ),
)

In [6]:
# Reference value: the log of an array of ones is all zeros.
np.log(np.ones(10))

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [7]:
# Ten entries all equal to e, with declared bounds min_val=1 and max_val=5.
# NOTE(review): `dsl` was built with 20 indexed entries while `child` has 10 —
# confirm this size mismatch is intended.
gt = GammaTensor(
    child=np.ones(10) * np.e,
    data_subjects=dsl,
    min_val=1,
    max_val=5,
)

In [8]:
# Log via the tensor's own method. In the repr below the child values are 1
# (log e) and the bounds are 0.0 (log 1) and ~1.609 (log 5).
gt.log()

GammaTensor(child=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), data_subjects=<syft.core.adp.data_subject_list.DataSubjectList object at 0x7feeb3549ee0>, min_val=<lazyrepeatarray data: 0.0 -> shape: (10,)>, max_val=<lazyrepeatarray data: 1.6094379124341003 -> shape: (10,)>, is_linear=True, func=<function GammaTensor.log.<locals>._log at 0x7feeb35859d0>, id='1103955744', state={'908042406': GammaTensor(child=array([2.71828183, 2.71828183, 2.71828183, 2.71828183, 2.71828183,
       2.71828183, 2.71828183, 2.71828183, 2.71828183, 2.71828183]), data_subjects=<syft.core.adp.data_subject_list.DataSubjectList object at 0x7feeb3549ee0>, min_val=1, max_val=5, is_linear=True, func=<function no_op at 0x7feeb2878040>, id='908042406', state={})})

In [9]:
# The nn.utils helper; its displayed child values and bounds match those of gt.log().
nn.utils.dp_log(gt)

GammaTensor(child=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), data_subjects=<syft.core.adp.data_subject_list.DataSubjectList object at 0x7feeb3549ee0>, min_val=<lazyrepeatarray data: 0.0 -> shape: (10,)>, max_val=<lazyrepeatarray data: 1.6094379124341003 -> shape: (10,)>, is_linear=True, func=<function GammaTensor.log.<locals>._log at 0x7feeb3585d30>, id='2045560941', state={'908042406': GammaTensor(child=array([2.71828183, 2.71828183, 2.71828183, 2.71828183, 2.71828183,
       2.71828183, 2.71828183, 2.71828183, 2.71828183, 2.71828183]), data_subjects=<syft.core.adp.data_subject_list.DataSubjectList object at 0x7feeb3549ee0>, min_val=1, max_val=5, is_linear=True, func=<function no_op at 0x7feeb2878040>, id='908042406', state={})})

In [10]:
# Build a zero-filled GammaTensor from gt's shape and data subjects; the repr
# below shows child, min_val and max_val all equal to 0.
nn.utils.dp_zeros(gt.shape, gt.data_subjects)

GammaTensor(child=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), data_subjects=<syft.core.adp.data_subject_list.DataSubjectList object at 0x7feeb3587880>, min_val=<lazyrepeatarray data: 0.0 -> shape: (10,)>, max_val=<lazyrepeatarray data: 0.0 -> shape: (10,)>, is_linear=True, func=<function no_op at 0x7feeb2878040>, id='944352444', state={})

## DataSubjectList Combination tests


**Case 1: Non-overlapping DSL**

In [20]:
# Lookup values 0..9 with a 10x10 array of randomly drawn indices into it.
# A seeded generator is used instead of the global np.random state so the
# cell yields the same data on every run (Restart & Run All reproducible).
dsl1 = DataSubjectList(
    one_hot_lookup=np.arange(10),
    data_subjects_indexed=np.random.default_rng(seed=1).choice(
        np.arange(10), size=(10, 10)
    ),
)

In [27]:
# Lookup values 100..109, which share nothing with a 0..9 lookup.
w = np.arange(100, 110)
# Indices are positions 0..len(w)-1, drawn from a seeded generator so the
# cell is reproducible across kernel restarts.
dsl2 = DataSubjectList(
    one_hot_lookup=w,
    data_subjects_indexed=np.random.default_rng(seed=2).choice(
        np.arange(len(w)), size=(10, 10)
    ),
)

In [28]:
# Show both lookups side by side; they have no values in common.
dsl1.one_hot_lookup, dsl2.one_hot_lookup

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([100, 101, 102, 103, 104, 105, 106, 107, 108, 109]))

In [29]:
# Inspect dsl1's 10x10 index array.
dsl1.data_subjects_indexed

array([[1, 3, 8, 3, 9, 2, 9, 3, 6, 1],
       [6, 7, 6, 7, 5, 9, 0, 3, 7, 5],
       [7, 8, 4, 2, 6, 9, 8, 3, 9, 5],
       [6, 6, 5, 9, 3, 4, 5, 9, 8, 7],
       [1, 1, 4, 1, 5, 7, 1, 8, 7, 9],
       [1, 4, 0, 2, 1, 5, 1, 5, 7, 6],
       [3, 4, 8, 9, 4, 5, 4, 8, 1, 7],
       [2, 7, 4, 6, 0, 5, 7, 9, 4, 4],
       [2, 0, 5, 5, 1, 6, 4, 0, 3, 6],
       [0, 3, 3, 7, 6, 1, 4, 5, 0, 6]])

In [30]:
# Inspect dsl2's 10x10 index array (values are positions 0..9, not 100..109).
dsl2.data_subjects_indexed

array([[4, 9, 5, 6, 2, 0, 9, 2, 1, 1],
       [1, 3, 1, 6, 5, 6, 6, 2, 9, 7],
       [5, 0, 0, 3, 9, 9, 0, 6, 4, 9],
       [9, 5, 0, 0, 0, 5, 0, 1, 5, 2],
       [5, 1, 6, 1, 9, 0, 8, 8, 6, 5],
       [8, 9, 2, 7, 8, 1, 6, 7, 3, 0],
       [6, 3, 6, 4, 7, 5, 0, 4, 1, 5],
       [0, 5, 1, 7, 9, 6, 5, 9, 6, 6],
       [6, 7, 1, 7, 5, 0, 7, 9, 2, 7],
       [7, 4, 2, 8, 7, 3, 6, 8, 6, 0]])

In [31]:
# Merge two lists whose lookups are disjoint.
dsl3 = DataSubjectList.combine(dsl1, dsl2)

In [32]:
# Expected: all 20 values from both lookups (0-9 followed by 100-109).
dsl3.one_hot_lookup

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9, 100, 101, 102,
       103, 104, 105, 106, 107, 108, 109])

**Case 2: Fully overlapping DSL**

In [33]:
# Combine a list with itself, so every lookup value overlaps.
dsl4 = DataSubjectList.combine(dsl1, dsl1)

In [34]:
# The self-combined list should compare equal to the original.
dsl4 == dsl1

True

In [36]:
# The index arrays should also match element for element.
(dsl4.data_subjects_indexed == dsl1.data_subjects_indexed).all()

True

**Case 3: Partially overlapping DSL**

In [20]:
# Fresh dsl1 for this case: lookup values 0..9 with a 10x10 array of randomly
# drawn indices. A seeded generator replaces the global np.random state so
# re-running the cell (or the whole notebook) always gives the same data.
dsl1 = DataSubjectList(
    one_hot_lookup=np.arange(10),
    data_subjects_indexed=np.random.default_rng(seed=1).choice(
        np.arange(10), size=(10, 10)
    ),
)

In [38]:
# Lookup values 7..16, which overlap a 0..9 lookup on 7, 8 and 9.
w = np.arange(7, 17)
# Indices are positions 0..len(w)-1, drawn from a seeded generator so the
# cell is reproducible across kernel restarts.
dsl5 = DataSubjectList(
    one_hot_lookup=w,
    data_subjects_indexed=np.random.default_rng(seed=5).choice(
        np.arange(len(w)), size=(10, 10)
    ),
)

In [39]:
# Show both lookups side by side; they share only the values 7, 8 and 9.
dsl1.one_hot_lookup, dsl5.one_hot_lookup

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([ 7,  8,  9, 10, 11, 12, 13, 14, 15, 16]))

In [43]:
# Inspect dsl5's raw 10x10 index array (values drawn from 0..len(w)-1).
dsl5.data_subjects_indexed

array([[8, 7, 8, 4, 7, 7, 8, 5, 7, 5],
       [3, 6, 5, 5, 2, 6, 3, 5, 9, 6],
       [0, 6, 8, 7, 6, 9, 3, 5, 9, 4],
       [7, 3, 8, 4, 0, 0, 9, 4, 4, 7],
       [5, 9, 9, 0, 1, 8, 6, 9, 0, 6],
       [9, 4, 8, 1, 5, 6, 8, 1, 5, 2],
       [1, 8, 9, 7, 4, 1, 5, 5, 1, 7],
       [2, 6, 0, 4, 2, 5, 8, 1, 0, 4],
       [0, 6, 7, 3, 8, 1, 9, 4, 3, 7],
       [8, 1, 8, 4, 4, 9, 9, 3, 6, 7]])

In [40]:
# Merge two lists whose lookups partially overlap.
dsl6 = DataSubjectList.combine(dsl1, dsl5)

In [41]:
# Expected: the union 0..16, with the shared values 7-9 appearing only once.
dsl6.one_hot_lookup

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16])

In [42]:
# The combined index array has one block per input list. In the displayed
# output the second block equals dsl5's indices shifted by 7.
dsl6.data_subjects_indexed

array([[[ 1,  3,  8,  3,  9,  2,  9,  3,  6,  1],
        [ 6,  7,  6,  7,  5,  9,  0,  3,  7,  5],
        [ 7,  8,  4,  2,  6,  9,  8,  3,  9,  5],
        [ 6,  6,  5,  9,  3,  4,  5,  9,  8,  7],
        [ 1,  1,  4,  1,  5,  7,  1,  8,  7,  9],
        [ 1,  4,  0,  2,  1,  5,  1,  5,  7,  6],
        [ 3,  4,  8,  9,  4,  5,  4,  8,  1,  7],
        [ 2,  7,  4,  6,  0,  5,  7,  9,  4,  4],
        [ 2,  0,  5,  5,  1,  6,  4,  0,  3,  6],
        [ 0,  3,  3,  7,  6,  1,  4,  5,  0,  6]],

       [[15, 14, 15, 11, 14, 14, 15, 12, 14, 12],
        [10, 13, 12, 12,  9, 13, 10, 12, 16, 13],
        [ 7, 13, 15, 14, 13, 16, 10, 12, 16, 11],
        [14, 10, 15, 11,  7,  7, 16, 11, 11, 14],
        [12, 16, 16,  7,  8, 15, 13, 16,  7, 13],
        [16, 11, 15,  8, 12, 13, 15,  8, 12,  9],
        [ 8, 15, 16, 14, 11,  8, 12, 12,  8, 14],
        [ 9, 13,  7, 11,  9, 12, 15,  8,  7, 11],
        [ 7, 13, 14, 10, 15,  8, 16, 11, 10, 14],
        [15,  8, 15, 11, 11, 16, 16, 10, 13, 14]]])

In [46]:
# The first block of the combined index array should be dsl1's indices unchanged.
(dsl6.data_subjects_indexed[0] == dsl1.data_subjects_indexed).all()

True