In [1]:
import torch
from torch import nn
import sys
sys.path.append('library/src')
import rnnbuilder as rb
from rnnbuilder import custom

In [2]:
from rnnbuilder.nn import ReLU, Conv2d, Linear, Tanh, Sigmoid


from functools import reduce
import operator
def prod(iterable):
    """Return the product of all items in *iterable*, multiplied left to right.

    An empty iterable yields 1. Works for any items supporting ``*``
    (e.g. numbers or tensors).
    """
    result = 1
    for item in iterable:
        result = result * item
    return result

class HadamardModule(rb.custom.CustomModule):
    """Custom rnnbuilder module returning the element-wise product of its inputs.

    NOTE(review): the CustomModule contract is inferred from this usage only.
    The second ``forward`` argument is ignored and ``()`` is returned alongside
    the result; presumably these are the per-module recurrent state, which this
    module does not keep. TODO confirm against rnnbuilder.custom.
    """

    def get_out_shape(self, in_shapes):
        # The output is reported with the first input's shape (inputs are
        # assumed to share a shape; not checked here).
        out_shape = in_shapes[0]
        return out_shape

    def forward(self, inputs, _):
        # Multiply all inputs together, starting from the scalar 1.
        product = 1
        for value in inputs:
            product = product * value
        return product, ()

# Register HadamardModule with rnnbuilder. The returned factory is used as a
# layer in the network graphs below, e.g. `.apply(Hadamard())`.
Hadamard = custom.register_recurrent(
    module_class=HadamardModule,
    flatten_input=False,
    single_step=False,
    unroll_full_state=False,
)


In [None]:
# Variant 1: LSTM cell built with separate input-to-hidden (i_*) and
# hidden-to-hidden (h_*) Linear layers per gate. The weight-copy cell maps
# torch.nn.LSTM's weight_ih_l0 onto i_* and weight_hh_l0 onto h_*.
# NOTE(review): this cell is unexecuted in the saved notebook (In [None]) and
# its `n` is overwritten by the next cell, so the model below is NOT built
# from this graph when the notebook runs top to bottom.
hidden_size = 32
n = rb.Network()

# Input projections, one per gate: input (i), forget (f), output (o), candidate (g).
n.i_i = n.input.apply(Linear(hidden_size))
n.i_f = n.input.apply(Linear(hidden_size))
n.i_o = n.input.apply(Linear(hidden_size))
n.i_g = n.input.apply(Linear(hidden_size))

# `output` is declared before it is defined (at the end of the cell) so the
# hidden projections can read it; presumably rnnbuilder feeds the previous
# step's output here — TODO confirm.
n.output = rb.Placeholder()
n.h_i = n.output.apply(Linear(hidden_size))
n.h_f = n.output.apply(Linear(hidden_size))
n.h_o = n.output.apply(Linear(hidden_size))
n.h_g = n.output.apply(Linear(hidden_size))

# Gate activations from input + hidden projections (`.sum` presumably element-wise).
n.i = n.i_i.sum(n.h_i).apply(Sigmoid())
n.f = n.i_f.sum(n.h_f).apply(Sigmoid())
n.o = n.i_o.sum(n.h_o).apply(Sigmoid())
n.g = n.i_g.sum(n.h_g).apply(Tanh())

# Cell state c = f * c_prev + i * g. `append` groups two nodes as inputs to
# Hadamard, which multiplies them; c_prev presumably comes from the placeholder.
n.c = rb.Placeholder()
n.c_1 = n.f.append(n.c).apply(Hadamard())
n.c_2 = n.i.append(n.g).apply(Hadamard())

n.c = n.c_1.sum(n.c_2)
# Hidden output h = o * tanh(c); this resolves the `output` placeholder.
n.tan_c = n.c.apply(Tanh())
n.output = n.o.append(n.tan_c).apply(Hadamard())




In [10]:
# Variant 2: LSTM cell where each gate is a single Linear over the stacked
# input and previous-output features. The printed parameter shapes show each
# gate Linear takes 132 = 100 (input) + 32 (hidden) features. This is the
# graph the model is built from when the notebook runs top to bottom.
hidden_size = 32
n = rb.Network()

# `output` is declared before it is defined (at the end of the cell) so it can
# be stacked with the input; presumably rnnbuilder feeds the previous step's
# output here — TODO confirm. NOTE(review): the feature order produced by
# `stack` must match the column order used when loading LSTM weights.
n.output = rb.Placeholder()
n.h_and_i = n.input.stack(n.output)

# One Linear + activation per gate (modules presumably applied in sequence;
# the Linear appears as mlist[0] in the parameter listing).
n.i = n.h_and_i.apply(Linear(hidden_size), Sigmoid())
n.f = n.h_and_i.apply(Linear(hidden_size), Sigmoid())
n.o = n.h_and_i.apply(Linear(hidden_size), Sigmoid())
n.g = n.h_and_i.apply(Linear(hidden_size), Tanh())

# Cell state c = f * c_prev + i * g via the element-wise Hadamard module;
# c_prev presumably comes from the placeholder.
n.c = rb.Placeholder()
n.c_1 = n.f.append(n.c).apply(Hadamard())
n.c_2 = n.i.append(n.g).apply(Hadamard())

n.c = n.c_1.sum(n.c_2)
# Hidden output h = o * tanh(c); this resolves the `output` placeholder.
n.tan_c = n.c.apply(Tanh())
n.output = n.o.append(n.tan_c).apply(Hadamard())

In [4]:
# Random input batch used to compare the rnnbuilder model with torch.nn.LSTM.
# Seeding makes this input reproducible, as well as any randomly initialised
# weights created after it when the notebook is run top to bottom.
SEED = 0
torch.manual_seed(SEED)

seq = 8             # sequence length: first dim (torch.nn.LSTM default batch_first=False)
batch = 32          # batch size: second dim
inp_shape = (100,)  # per-step feature shape
example = torch.rand((seq, batch) + inp_shape)

In [11]:
# Build the rnnbuilder model from the network graph `n` for the per-step input
# shape, plus a reference torch.nn.LSTM with matching input and hidden sizes.
model = n.make_model(inp_shape)
lstm = torch.nn.LSTM(input_size=inp_shape[0], hidden_size=hidden_size)

In [6]:
def load_lstm_weights_separate(model, lstm):
    """Copy ``lstm``'s parameters into the separate-Linear LSTM network (variant 1).

    torch.nn.LSTM packs its four gates row-wise in the order input, forget,
    cell candidate (g), output. ``chunk(4)`` therefore yields them in that
    order, and each chunk is copied into the matching ``i_*`` (input) or
    ``c0.layers.h_*`` (hidden) Linear.

    ``copy_`` is used instead of assigning ``.data``. The chunks are views of
    ``lstm``'s parameters, so assigning ``.data`` would make the model share
    storage with ``lstm``; it would also accept any shape silently.

    Args:
        model: rnnbuilder model built from the network with per-gate ``i_*``
            and ``h_*`` Linear layers.
        lstm: single-layer ``torch.nn.LSTM`` whose weights are copied.

    Raises:
        ValueError: if a source chunk's shape differs from the target parameter.
    """
    input_layers = model.inner.layers
    hidden_layers = model.inner.layers.c0.layers
    with torch.no_grad():
        chunks = zip(
            ('i', 'f', 'g', 'o'),  # torch.nn.LSTM's packed gate order
            lstm.weight_ih_l0.chunk(4, 0),
            lstm.bias_ih_l0.chunk(4),
            lstm.weight_hh_l0.chunk(4, 0),
            lstm.bias_hh_l0.chunk(4),
        )
        for gate, w_ih, b_ih, w_hh, b_hh in chunks:
            input_linear = getattr(input_layers, f'i_{gate}').inner
            hidden_linear = getattr(hidden_layers, f'h_{gate}').inner
            for dst, src in (
                (input_linear.weight, w_ih),
                (input_linear.bias, b_ih),
                (hidden_linear.weight, w_hh),
                (hidden_linear.bias, b_hh),
            ):
                if dst.shape != src.shape:
                    raise ValueError(
                        f"gate {gate}: cannot copy {tuple(src.shape)} into {tuple(dst.shape)}"
                    )
                dst.copy_(src)


# This cell targets the separate-Linear network. A model built from the stacked
# network has no `i_*` layers, so skip explicitly instead of raising under
# Restart & Run All.
if hasattr(model.inner.layers, 'i_i'):
    load_lstm_weights_separate(model, lstm)
else:
    print("Skipping separate-Linear weight copy: model has no 'i_i' layer.")

In [28]:
# Run the rnnbuilder model and keep only the output sequence; the second return
# value (presumably the final state) is discarded. This is an inference-only
# comparison, so no autograd graph is built.
with torch.no_grad():
    out1, _ = model(example)

In [29]:
# Reference output from torch.nn.LSTM; the returned (h_n, c_n) state is discarded.
# No autograd graph is needed for the comparison.
with torch.no_grad():
    out2, _ = lstm(example)

In [38]:
# Equivalence check. The shapes must match exactly, because isclose would
# otherwise broadcast. Every element must then agree within atol=1e-7 (default
# rtol=1e-5). Assert so a regression fails Restart & Run All instead of only
# printing False.
assert out1.shape == out2.shape, f"shape mismatch: {tuple(out1.shape)} vs {tuple(out2.shape)}"
outputs_match = torch.isclose(out1, out2, atol=1e-7).all()
assert outputs_match, f"outputs differ; max abs diff = {(out1 - out2).abs().max().item():.3e}"
outputs_match

tensor(True)

In [37]:
# Element-wise difference between the rnnbuilder and torch.nn.LSTM outputs.
torch.sub(out1, out2)

tensor([[[ 0.0000e+00,  1.4901e-08,  1.4901e-08,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-1.4901e-08, -1.4901e-08,  0.0000e+00,  ...,  4.6566e-10,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -1.4901e-08,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 7.4506e-09,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  3.7253e-09],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  5.9605e-08,  ...,  3.7253e-09,
           4.4703e-08,  3.7253e-09],
         [ 7.4506e-09,  0.0000e+00,  4.4703e-08,  ...,  2.9802e-08,
           3.7253e-09, -7.4506e-09],
         [-3.7253e-09, -1.4901e-08, -1.4901e-08,  ...,  1.6298e-09,
           2.2352e-08, -2.2352e-08],
         ...,
         [ 1.4901e-08,  0

In [7]:
# Parameter names and shapes of the reference torch.nn.LSTM.
for name, p in lstm.named_parameters():
    print(f"{name} {p.shape}")

weight_ih_l0 torch.Size([128, 100])
weight_hh_l0 torch.Size([128, 32])
bias_ih_l0 torch.Size([128])
bias_hh_l0 torch.Size([128])


In [8]:
# Parameter names and shapes of the rnnbuilder model, used to locate the
# per-gate Linear layers for weight copying.
for name, p in model.named_parameters():
    print(f"{name} {p.shape}")

inner.layers.c0.layers.i.mlist.0.inner.weight torch.Size([32, 132])
inner.layers.c0.layers.i.mlist.0.inner.bias torch.Size([32])
inner.layers.c0.layers.f.mlist.0.inner.weight torch.Size([32, 132])
inner.layers.c0.layers.f.mlist.0.inner.bias torch.Size([32])
inner.layers.c0.layers.o.mlist.0.inner.weight torch.Size([32, 132])
inner.layers.c0.layers.o.mlist.0.inner.bias torch.Size([32])
inner.layers.c0.layers.g.mlist.0.inner.weight torch.Size([32, 132])
inner.layers.c0.layers.g.mlist.0.inner.bias torch.Size([32])


In [10]:
# Shape probe: one gate's slice of the concatenated [weight_hh | weight_ih]
# matrix is [hidden, hidden + input]. This should equal each gate Linear's
# weight shape in the stacked model. Column order is hidden-then-input, the
# same order the weight-loading code uses.
torch.cat((lstm.weight_hh_l0, lstm.weight_ih_l0), dim=-1).chunk(4)[0].shape

torch.Size([32, 132])

In [27]:

def load_lstm_weights_stacked(model, lstm):
    """Copy ``lstm``'s parameters into the stacked-input LSTM network (variant 2).

    Each gate here is one Linear over the stacked features. Its weight is
    therefore the column-wise concatenation of that gate's rows of
    ``weight_hh_l0`` and ``weight_ih_l0``, and its single bias is
    ``bias_ih_l0 + bias_hh_l0``. Gates are chunked in torch.nn.LSTM's packed
    order: input, forget, cell candidate (g), output.

    ``copy_`` (with an explicit shape check) is used instead of assigning
    ``.data``. Assigning ``.data`` accepts any shape silently and makes all four
    gate weights views of one shared temporary tensor.

    Args:
        model: rnnbuilder model built from the stacked-input network.
        lstm: single-layer ``torch.nn.LSTM`` whose weights are copied.

    Raises:
        ValueError: if a computed chunk's shape differs from the target parameter.
    """
    gate_layers = model.inner.layers.c0.layers
    with torch.no_grad():
        # NOTE(review): the hidden-then-input column order must match the
        # feature order produced by `n.input.stack(n.output)`. The saved isclose
        # output (tensor(True)) was obtained with this order — TODO confirm
        # against rnnbuilder's `stack`.
        weights = torch.cat((lstm.weight_hh_l0, lstm.weight_ih_l0), dim=-1).chunk(4)
        biases = (lstm.bias_ih_l0 + lstm.bias_hh_l0).chunk(4)
        for gate, weight, bias in zip(('i', 'f', 'g', 'o'), weights, biases):
            linear = getattr(gate_layers, gate).mlist[0].inner
            for dst, src in ((linear.weight, weight), (linear.bias, bias)):
                if dst.shape != src.shape:
                    raise ValueError(
                        f"gate {gate}: cannot copy {tuple(src.shape)} into {tuple(dst.shape)}"
                    )
                dst.copy_(src)


load_lstm_weights_stacked(model, lstm)