Reference notebook: [dendritic_gated_network.ipynb (google-deepmind/deepmind-research)](https://github.com/google-deepmind/deepmind-research/blob/master/gated_linear_networks/colabs/dendritic_gated_network.ipynb)

In [1]:
import numpy as np
from sklearn import datasets
from sklearn import preprocessing
from sklearn import model_selection

In [2]:
# Load the breast-cancer dataset and split the bunch into inputs and labels.
dataset = datasets.load_breast_cancer()
features, targets = dataset.data, dataset.target

In [3]:
# (n_samples, n_features) of the raw dataset.
features.shape

(569, 30)

In [4]:
# Hold out 20% of the samples for testing; the fixed random_state keeps the
# split identical across runs.
x_train, x_test, y_train, y_test = model_selection.train_test_split(
    features,
    targets,
    test_size=0.2,
    random_state=0,
)

In [5]:
# Width of one input row (x_train is 2-D: samples x features).
n_features = x_train.shape[1]
n_features

30

In [6]:
# Standardise the inputs (zero mean, unit variance). The scaler is fitted on
# the training split only and then applied to both splits.
feature_encoder = preprocessing.StandardScaler()
feature_encoder.fit(x_train)
x_train = feature_encoder.transform(x_train)
x_test = feature_encoder.transform(x_test)

In [7]:
# Sample counts of the train and test splits.
x_train.shape[0], x_test.shape[0]

(455, 114)

In [8]:
# Number of neurons per layer; the last element must be 1.
n_neurons = np.array([100, 10, 1])
n_branches = 20  # number of dendritic branches per neuron

In [9]:
# Show the configured layer widths.
n_neurons

array([100,  10,   1])

In [10]:
# Input width of every layer: the raw feature count for the first layer and
# the preceding layer's neuron count for the others, each extended by one.
n_inputs = np.hstack((n_features, n_neurons[:-1])) + 1

In [11]:
# Per-layer input widths: n_features + 1 first, then each preceding layer's
# neuron count + 1.
n_inputs

array([ 31, 101,  11])

In [12]:
# One zero-initialised weight tensor per layer, laid out as
# (neurons, branches, inputs).
dgn_weights = [
    np.zeros(shape=(width, n_branches, fan_in))
    for width, fan_in in zip(n_neurons, n_inputs)
]

In [13]:
# First-layer weights: (neurons, branches, inputs).
dgn_weights[0].shape

(100, 20, 31)

In [14]:
# Second-layer weights: (neurons, branches, inputs).
dgn_weights[1].shape

(10, 20, 101)

In [15]:
# Last-layer weights: a single neuron, (1, branches, inputs).
dgn_weights[2].shape

(1, 20, 11)

In [16]:
# Random hyperplanes, one per (neuron, branch) in every layer, drawn from a
# unit-variance Gaussian. Each hyperplane has n_features + 1 coefficients
# regardless of the layer. A seeded generator is used so the sampled
# hyperplanes are the same on every Restart & Run All, in line with the
# fixed random_state used for the train/test split.
rng = np.random.default_rng(seed=0)
dgn_hyperplanes = [
    rng.normal(0, 1, size=(n_neuron, n_branches, n_features + 1))
    for n_neuron in n_neurons
]

In [17]:
# First-layer hyperplanes: (neurons, branches, n_features + 1).
dgn_hyperplanes[0].shape

(100, 20, 31)

In [18]:
# Rescale every neuron's hyperplane tensor by the joint norm of its entries,
# leaving the last coefficient of each hyperplane out of the norm.
# NOTE(review): axis=(1, 2) yields one norm per neuron (taken over all of its
# branches), not one per branch -- confirm that this is intended.
dgn_hyperplanes = [
    planes / np.linalg.norm(planes[..., :-1], axis=(1, 2), keepdims=True)
    for planes in dgn_hyperplanes
]

In [19]:
# Shape check after the rescaling: still (neurons, branches, n_features + 1).
dgn_hyperplanes[0].shape

(100, 20, 31)

In [20]:
# Pick one training sample and the first layer so that a single forward pass
# and weight update can be worked through by hand in the cells below.
# Direct indexing replaces the earlier loop-and-break constructs; the names
# bound (i, x_i, w, h) and their values are the same.
i = 2
x_i = x_train[i]

# w and h refer to the first layer's arrays themselves (not copies), so an
# in-place update of w also changes dgn_weights[0].
w, h = dgn_weights[0], dgn_hyperplanes[0]

In [21]:
# Target value used in the hand-worked gradient step below.
# NOTE(review): hard-coded rather than read from y_train[i] -- confirm this
# is intentional.
target = 1

In [22]:
# Gating input: the sample with a constant 1.0 placed in front of it.
side_info = np.hstack((1.0, x_i))
side_info.shape

(31,)

In [23]:
# First-layer input: the sample with a constant 1.0 placed in front of it.
r_in = np.hstack((1.0, x_i))
r_in

array([ 1.        ,  0.574121  , -1.03333557,  0.51394098,  0.40858627,
       -0.10616078, -0.36301886, -0.41799048, -0.08844569, -0.27182044,
       -0.57522132, -0.57672579, -1.05784511, -0.53856037, -0.38708923,
       -1.07211882, -0.72057496, -0.42362791, -0.49218988, -0.67484362,
       -0.80147288,  0.29761532, -0.97781783,  0.26213665,  0.11388819,
       -0.52472419, -0.52086645, -0.18298917, -0.02371948, -0.20050207,
       -0.75144254])

In [24]:
# First-layer hyperplanes: (neurons, branches, n_features + 1).
h.shape

(100, 20, 31)

In [25]:
# side_info has n_features + 1 entries, matching the last axis of h.
side_info.shape

(31,)

In [26]:
# The dot product contracts the last axis of h, leaving one value per
# (neuron, branch).
h.dot(side_info).shape

(100, 20)

In [27]:
# A branch's gate is True when the sample lies strictly on the positive side
# of its hyperplane (a projection of exactly 0 maps to False).
gate_values = np.heaviside(np.dot(h, side_info), 0).astype(np.bool_)
gate_values

array([[False,  True,  True, ..., False,  True,  True],
       [ True, False, False, ...,  True, False,  True],
       [False,  True,  True, ..., False, False,  True],
       ...,
       [False,  True, False, ..., False,  True, False],
       [ True, False, False, ..., False, False,  True],
       [ True, False, False, ..., False, False, False]])

In [28]:
# First-layer weights: (neurons, branches, inputs).
w.shape

(100, 20, 31)

In [29]:
# One boolean gate per (neuron, branch).
gate_values.shape

(100, 20)

In [30]:
# Collapse each neuron's branch weights into a single effective weight vector
# by summing only that neuron's own branches whose gate is open:
#     effective_weights[n] = sum_b gate_values[n, b] * w[n, b]
# gate_values.dot(w) is not used here: np.dot of a (neurons, branches) array
# with a (neurons, branches, inputs) array returns (neurons, neurons, inputs),
# and summing that over axis 1 would combine neuron n's gates with the
# weights of every neuron. The masked sum below is the forward pass that
# matches the per-neuron gated weight update applied later in the notebook.
effective_weights = (gate_values[:, :, None] * w).sum(axis=1)
effective_weights

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [48]:
# np.dot contracts the branch axis of gate_values with the branch axis of w
# but keeps both neuron axes, hence (neurons, neurons, inputs): entry [i, j]
# pairs neuron i's gates with neuron j's weights.
gate_values.dot(w).shape

(100, 100, 31)

In [31]:
# One effective weight vector per neuron: (neurons, inputs).
effective_weights.shape

(100, 31)

In [32]:
# r_in has as many entries as the last axis of effective_weights.
r_in.shape

(31,)

In [33]:
# Layer output: each neuron's effective weight vector applied to the input.
r_out = effective_weights @ r_in

In [34]:
# One output value per first-layer neuron.
r_out.shape

(100,)

In [35]:
# r_out viewed as a column, so it can broadcast against a row vector.
r_out[:, None].shape

(100, 1)

In [36]:
# r_in viewed as a row (leading axis of length 1).
r_in[None].shape

(1, 31)

In [37]:
# Per-neuron gradient: the output error (r_out - target) of each neuron
# multiplied by the layer input, giving one row per neuron.
grad = (r_out - target)[:, np.newaxis] * r_in[np.newaxis, :]
grad

array([[-1.        , -0.574121  ,  1.03333557, ...,  0.02371948,
         0.20050207,  0.75144254],
       [-1.        , -0.574121  ,  1.03333557, ...,  0.02371948,
         0.20050207,  0.75144254],
       [-1.        , -0.574121  ,  1.03333557, ...,  0.02371948,
         0.20050207,  0.75144254],
       ...,
       [-1.        , -0.574121  ,  1.03333557, ...,  0.02371948,
         0.20050207,  0.75144254],
       [-1.        , -0.574121  ,  1.03333557, ...,  0.02371948,
         0.20050207,  0.75144254],
       [-1.        , -0.574121  ,  1.03333557, ...,  0.02371948,
         0.20050207,  0.75144254]])

In [38]:
# One gradient row per neuron: (neurons, inputs).
grad.shape

(100, 31)

In [39]:
# Step size applied to the weight update below.
learning_rate = 1e-3

In [40]:
# Gates with a trailing length-1 axis, so they broadcast over the input axis.
gate_values[:, :, None].shape

(100, 20, 1)

In [41]:
# Original gate layout: (neurons, branches).
gate_values.shape

(100, 20)

In [42]:
# Gradient with a length-1 branch axis inserted: (neurons, 1, inputs).
grad[:, None].shape

(100, 1, 31)

In [43]:
# Gated gradient: each neuron's gradient row copied to every branch and
# zeroed on the branches whose gate is closed.
a = gate_values[..., np.newaxis] * grad[:, np.newaxis, :]

In [44]:
# The broadcast product has one entry per (neuron, branch, input).
a.shape

(100, 20, 31)

In [45]:
# w has the same (neurons, branches, inputs) layout as the gated gradient.
w.shape

(100, 20, 31)

In [46]:
# In-place gradient step on the first-layer weights: only branches whose gate
# is open receive an update, scaled by the learning rate.
w -= (learning_rate * gate_values)[:, :, np.newaxis] * grad[:, np.newaxis, :]

In [47]:
# r_out holds one value per first-layer neuron.
r_out.shape

(100,)