In [34]:
from mxnet import nd
import mxnet as mx
from mxnet import autograd
from mxnet import gluon

import numpy as np
import time
import random

In [13]:
class _FWLayer(gluon.nn.Block):
    """Basic fast-weights layer.

    Implements the inner-loop recurrence given in the paper:

        h_{s+1}(t + 1) = f([W h(t) + C x(t)] + A(t) h_s(t + 1))

    The layer keeps the current inner state ``h_s`` on the instance and
    overwrites it on every call to :meth:`forward`.

    Parameters
    ----------
    f : callable
        Non-linearity applied to the summed pre-activation.
    hidden_base : NDArray
        ``h(t)`` in the formula; multiplied by ``W``.
    hidden_state : NDArray
        Initial ``h_s(t + 1)``; multiplied by ``A`` and replaced in place
        of the stored state after each inner step.
    C : NDArray
        Matrix multiplied with the input ``x(t)``.
    A : NDArray
        Fast-weights matrix ``A(t)`` multiplied with the inner state.
    W : NDArray
        Matrix multiplied with ``hidden_base``.
    s : int
        Number of inner-loop iterations executed by :meth:`step`.
    prefix, params
        Forwarded unchanged to ``gluon.nn.Block``.

    Notes
    -----
    No shape checks are performed here; the operands only need to be
    compatible for ``nd.dot`` and element-wise addition.
    """

    def __init__(self, f, hidden_base, hidden_state, C, A, W, s, prefix=None, params=None):
        super(_FWLayer, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_state.shape
        self._f = f
        self._ht = hidden_base
        self._hs = hidden_state
        self._C = C
        self._A = A
        self._W = W
        self._s = s

    def forward(self, inputs):
        """Run one inner-loop update and return the new inner state.

        The stored state is overwritten with the result, so consecutive
        calls continue the recurrence from where the last call stopped.
        """
        # Same evaluation order as the formula: (W h(t) + C x(t)) + A h_s.
        slow_term = nd.dot(self._W, self._ht) + nd.dot(self._C, inputs)
        self._hs = self._f(slow_term + nd.dot(self._A, self._hs))
        # Returning the state makes ``layer(inputs)`` usable; previously the
        # call produced ``None``.
        return self._hs

    def step(self, inputs, verbose=True):
        """Run ``s`` inner-loop updates for a single input.

        Parameters
        ----------
        inputs : NDArray
            ``x(t)``, passed to every inner iteration.
        verbose : bool, optional
            When true (the default) the state is printed after each
            iteration.

        Returns
        -------
        NDArray
            The inner state after the last iteration (the unchanged stored
            state when ``s`` is 0).
        """
        for _ in range(self._s):
            self.forward(inputs)
            if verbose:
                print(self._hs)
        return self._hs

In [41]:
# Fix both RNGs so the sampled tensors below are reproducible.
mx.random.seed(1)
random.seed(1)

# Problem sizes used by this demo.
INPUT_DIM = 3
HIDDEN_DIM = 10


def identity(x):
    """Pass-through activation: returns its argument unchanged."""
    return x


# Input column vector and the two hidden vectors.
# The sampling order (x, hs, C, A, W) determines the drawn values, so keep it.
x = nd.random.normal(shape=(INPUT_DIM, 1))
hs = nd.random.normal(shape=(HIDDEN_DIM, 1))
ht = nd.zeros((HIDDEN_DIM, 1))

# Weight matrices for the input, inner-state and base-state terms.
C = nd.random.normal(shape=(HIDDEN_DIM, INPUT_DIM))
A = nd.random.normal(shape=(HIDDEN_DIM, HIDDEN_DIM))
W = nd.random.normal(shape=(HIDDEN_DIM, HIDDEN_DIM))

layer = _FWLayer(
    f=identity,
    hidden_base=ht,
    hidden_state=hs,
    C=C,
    A=A,
    W=W,
    s=10,
)
layer.initialize()

# step() prints the state after every inner iteration; the trailing ';'
# keeps the cell from additionally echoing any return value.
layer.step(x);


[[-6.0638256 ]
 [ 2.8693247 ]
 [-7.9947515 ]
 [-2.327079  ]
 [-1.3456746 ]
 [ 0.38783777]
 [ 6.924112  ]
 [ 2.0592256 ]
 [ 0.10559458]
 [ 2.0353913 ]]
<NDArray 10x1 @cpu(0)>

[[-40.43789   ]
 [-42.17067   ]
 [ -0.24119651]
 [ -3.998911  ]
 [-12.306088  ]
 [ 10.265751  ]
 [  4.2345486 ]
 [  5.065038  ]
 [ 11.713766  ]
 [ 12.851218  ]]
<NDArray 10x1 @cpu(0)>

[[ -12.179483 ]
 [-121.08273  ]
 [ -40.412746 ]
 [ -28.201056 ]
 [ -58.478764 ]
 [   1.6772883]
 [  40.252537 ]
 [  40.172512 ]
 [   2.7554612]
 [ -19.785248 ]]
<NDArray 10x1 @cpu(0)>

[[ 227.38136  ]
 [-109.64625  ]
 [   5.940396 ]
 [  23.423113 ]
 [  -2.6004677]
 [ 167.06367  ]
 [-186.27754  ]
 [ 215.13727  ]
 [-154.5283   ]
 [-105.13273  ]]
<NDArray 10x1 @cpu(0)>

[[ 704.95416]
 [ 350.39587]
 [ 362.30975]
 [ 644.53674]
 [ 274.13397]
 [ 209.13283]
 [-498.0943 ]
 [ 208.65565]
 [-269.19598]
 [-291.29495]]
<NDArray 10x1 @cpu(0)>

[[  478.74957]
 [ 1712.4463 ]
 [ -134.65001]
 [ 1297.274  ]
 [ 1330.845  ]
 [  305.22607]
 [ -360.92807]