In [1]:
from syft import nn
from syft import PhiTensor
from syft import GammaTensor
from syft import DataSubjectList
import numpy as np
from jax import numpy as jnp

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# third party
import numpy as np
from typing import Union

# relative
# from ....common.serde.serializable import serializable
# from ...autodp.phi_tensor import PhiTensor
# Short local aliases for the syft.nn classes that the Linear layer below relies on.
XavierInitialization = nn.initializations.XavierInitialization
Layer = nn.layers.base.Layer

# @serializable(recursive_serde=True)
class Linear(Layer):
    """Fully connected layer whose ``forward`` returns ``input.dot(W) + b``.

    ``W`` and ``b`` are only created by ``connect_to``; the gradients ``dW``
    and ``db`` are only created by ``backward``. Until those calls happen the
    corresponding attributes are ``None``.
    """

    def __init__(self, n_out, n_in=None):
        """Record the layer sizes; no parameters are allocated here.

        Args:
            n_out: size of the last axis of ``W`` and length of ``b``.
            n_in: size of the first axis of ``W``. Optional; it is only read
                when ``connect_to`` is called without a previous layer.
        """
        # Static configuration.
        self.n_in = n_in
        self.n_out = n_out
        # NOTE(review): the leading None is presumably the batch axis — confirm.
        self.out_shape = (None, n_out)
        self.init = XavierInitialization()

        # Parameters (set by connect_to) and their gradients (set by backward).
        self.W, self.b = None, None
        self.dW, self.db = None, None

        # Input remembered by forward so that backward can use it.
        self.last_input = None

    def connect_to(self, prev_layer=None):
        """Create ``W`` with shape ``(n_in, n_out)`` and a zero ``b`` of shape ``(n_out,)``.

        The input width is ``prev_layer.out_shape[-1]`` when a previous layer
        is supplied, otherwise ``self.n_in``.

        Raises:
            AssertionError: if ``prev_layer.out_shape`` does not have exactly
                two entries, or if no previous layer is given and
                ``self.n_in`` is ``None``.
        """
        if prev_layer is not None:
            assert len(prev_layer.out_shape) == 2
            fan_in = prev_layer.out_shape[-1]
        else:
            assert self.n_in is not None
            fan_in = self.n_in

        weight_shape = (fan_in, self.n_out)
        self.W = self.init(weight_shape)
        self.b = np.zeros((self.n_out,))

    def forward(self, input: Union[PhiTensor, GammaTensor], *args, **kwargs):
        """Cache ``input`` on the layer and return ``input.dot(W) + b``.

        Extra positional and keyword arguments are accepted but unused.
        """
        self.last_input = input
        projected = input.dot(self.W)
        return projected + self.b

    def backward(self, pre_grad: PhiTensor, *args, **kwargs):
        """Store ``dW`` and ``db`` and return the gradient for the previous layer.

        ``dW`` is the transposed cached input dotted with ``pre_grad``; ``db``
        is the mean of ``pre_grad`` over axis 0. Returns
        ``pre_grad.dot(W.T)``, or ``None`` when ``self.first_layer`` is truthy.
        Extra positional and keyword arguments are accepted but unused.
        """
        # NOTE(review): open question from the original author — should only the
        # last two axes be swapped here (e.g. swapaxes(-1, -2))?
        cached_input_t = self.last_input.transpose()
        self.dW = cached_input_t.dot(pre_grad)
        self.db = pre_grad.mean(axis=0)

        # NOTE(review): first_layer is not assigned anywhere in this class;
        # presumably it comes from the Layer base class — confirm.
        if self.first_layer:
            return None
        return pre_grad.dot(self.W.T)

    @property
    def params(self):
        """The ``(W, b)`` pair."""
        return (self.W, self.b)

    @property
    def grads(self):
        """The ``(dW, db)`` pair."""
        return (self.dW, self.db)


In [3]:
from syft import lazyrepeatarray as lra

# Seeded local generator so the synthetic fixture is identical on every
# Restart & Run All; a local Generator leaves NumPy's global RNG state alone.
rng = np.random.default_rng(0)

sh = (10, 10)

# Every entry of the (10, 10) grid is assigned one of the indices 0..9 at random.
dsl = DataSubjectList(
    one_hot_lookup=np.arange(10),
    data_subjects_indexed=rng.choice(np.arange(10), size=sh),
)

# Values are uniform in [0, 255), consistent with the declared min/max bounds.
smol_data = GammaTensor(
    child=jnp.array(rng.random(sh) * 255),
    data_subjects=dsl,
    min_vals=lra(0, shape=sh),
    max_vals=lra(255, shape=sh),
)



In [4]:
# Multiply smol_data by its own transpose and list the keys of the result's `state`
# (the recorded run shows two keys).
smol_data.dot(smol_data.transpose()).state.keys()

dict_keys(['613830793', '1276741108'])

In [5]:
# Sanity check: a dot product with a plain NumPy array still yields a GammaTensor.
assert isinstance(smol_data.dot(np.ones((10,10))), GammaTensor)

In [6]:
# Build the layer with only its output width; n_in is left at its default of None.
lin = Linear(n_out=5)

In [7]:
# There is no previous layer, so connect_to() reads n_in from the layer itself;
# set it by hand first so W is created with shape (10, 5) alongside the zero bias.
lin.n_in = 10
lin.connect_to()

In [8]:
# Forward pass; this also caches smol_data on the layer as last_input.
res = lin.forward(smol_data)

In [9]:
# Shape of the forward output.
res.shape

(10, 5)

In [10]:
# Weight matrix shape, i.e. (n_in, n_out).
lin.W.shape

(10, 5)

In [11]:
# Backward pass, reusing the forward output as a stand-in for the upstream gradient.
# This sets lin.dW and lin.db; res2 is pre_grad.dot(W.T).
res2 = lin.backward(res)

In [23]:
# NOTE(review): the execution count (23) shows this cell was run out of order, after
# the scratch cells further down. It only needs res2 from the backward-pass cell.
res2.data_subjects.data_subjects_indexed.shape

(10, 10)

In [12]:
# Display the weights produced by the initializer in connect_to().
lin.W

array([[ 0.49379314,  0.21496861, -0.45342833, -0.35527468, -0.4657647 ],
       [ 0.45522281, -0.22787395,  0.25857645, -0.23341571,  0.3831928 ],
       [ 0.31659098, -0.62084241,  0.10331025,  0.58979858,  0.41826245],
       [ 0.51796771, -0.51816149,  0.15696   ,  0.42490617, -0.54732432],
       [ 0.00547995,  0.32491967, -0.43517279,  0.11515994,  0.16317752],
       [ 0.38824337, -0.63169924,  0.54700098, -0.49441273, -0.58148438],
       [ 0.10088685,  0.04274169, -0.59989856,  0.40738686, -0.17396011],
       [-0.20324239, -0.1218242 ,  0.62339347,  0.26855927, -0.58264807],
       [ 0.32528151, -0.56007097, -0.45544377, -0.36854267, -0.08653191],
       [ 0.25307497,  0.20967178, -0.02153192,  0.46369819, -0.23672348]])

In [13]:
# The bias is still the all-zero vector created by connect_to().
lin.b

array([0., 0., 0., 0., 0.])

In [14]:
# Scratch data: a 5x5 matrix of ones and the mean of each of its rows.
a = np.ones(shape=(5, 5))
b = np.mean(a, axis=1)

In [15]:
# Zero-filled array with the same shape and dtype as a.
c = np.zeros_like(a)

In [16]:
# Reshape c so its leading axis matches b's shape; -1 lets NumPy infer the rest.
d = c.reshape((*b.shape, -1))
d.shape

(5, 5)

In [17]:
# Row means of the all-ones matrix a.
b

array([1., 1., 1., 1., 1.])

In [18]:
# b is one-dimensional.
b.shape

(5,)

In [19]:
# Prepend a length-1 axis to b via reshape.
b.reshape((1, *b.shape)).shape

(1, 5)

In [20]:
# expand_dims inserts a length-1 axis at position 0, so b_shape is (1, 5).
b_shape = np.expand_dims(b, 0).shape

In [21]:
# With b_shape == (1, 5) the target shape is (-1, 1, 5); NumPy infers 5 for the leading axis.
# This rebinds d, replacing the (5, 5) array assigned in the earlier cell.
d = c.reshape((-1, *b_shape))
d.shape

(5, 1, 5)

In [22]:
# Display the reshaped (5, 1, 5) zero array.
d

array([[[0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.]]])