In [10]:
import torch
from torch import nn
import sys
sys.path.append('library/src')
import rnnbuilder as rb
from rnnbuilder import custom

In [11]:
from rnnbuilder.nn import ReLU, Conv2d, Linear, Tanh, Sigmoid


from functools import reduce
import operator
def prod(iterable):
    """Multiply all items of ``iterable`` together, starting from ``1``.

    Returns ``1`` for an empty iterable. Items only need to support ``*``.
    """
    result = 1
    for factor in iterable:
        # Rebind instead of using ``*=`` so no operand is modified in place.
        result = result * factor
    return result

class HadamardModule(rb.custom.CustomModule):
    """Custom rnnbuilder module that multiplies all of its inputs together."""

    def get_out_shape(self, in_shapes):
        """Report the output shape as the shape of the first input."""
        first_shape = in_shapes[0]
        return first_shape

    def forward(self, inputs, _):
        """Return the product over ``inputs`` paired with an empty tuple.

        The second argument is ignored. The empty tuple is presumably the
        (non-existent) state expected by rnnbuilder -- TODO confirm.
        """
        product = prod(inputs)
        return product, ()

# Register HadamardModule with rnnbuilder; all three option flags are False.
# `Hadamard()` is what the network definition passes to `.apply(...)`.
Hadamard = custom.register_recurrent(
    module_class=HadamardModule,
    flatten_input=False,
    single_step=False,
    unroll_full_state=False,
)




In [12]:
# Hand-built LSTM-style cell expressed as an rnnbuilder network.
# Statement order matters here: placeholders are declared first and the same
# attribute names are assigned their real definition further down.
hidden_size = 32
n = rb.Network()

# Declare `output` before it is defined so it can be referenced below;
# presumably this is how rnnbuilder expresses a recurrent connection -- TODO confirm.
n.output = rb.Placeholder()
# Stack the network input with the output placeholder (input is listed first).
n.h_and_i = n.input.stack(n.output)

# Four parallel Linear(hidden_size) projections of the stacked layer:
# Sigmoid for i, f and o, Tanh for g.
n.i = n.h_and_i.apply(Linear(hidden_size), Sigmoid())
n.f = n.h_and_i.apply(Linear(hidden_size), Sigmoid())
n.o = n.h_and_i.apply(Linear(hidden_size), Sigmoid())
n.g = n.h_and_i.apply(Linear(hidden_size), Tanh())

# Same placeholder-then-define pattern for `c`.
n.c = rb.Placeholder()
n.c_1 = n.f.append(n.c).apply(Hadamard())  # product of f and c
n.c_2 = n.i.append(n.g).apply(Hadamard())  # product of i and g

n.c = n.c_1.sum(n.c_2)  # c = f * c + i * g
n.tan_c = n.c.apply(Tanh())
n.output = n.o.append(n.tan_c).apply(Hadamard())  # output = o * tanh(c)

In [31]:
# Random example input: leading dimensions are (seq, batch), followed by inp_shape.
inp_shape = (100,)
seq, batch = 8, 32
example = torch.rand(seq, batch, *inp_shape)

In [32]:
# Model under test, followed by the reference torch LSTM with matching sizes.
model = n.make_model(inp_shape)
lstm = nn.LSTM(input_size=inp_shape[0], hidden_size=hidden_size)

In [33]:
# Copy the reference LSTM's parameters into the hand-built cell.
#
# torch.nn.LSTM row-stacks its four gate blocks in the order i, f, g, o inside
# weight_ih_l0 / weight_hh_l0 (and both bias vectors), so chunk(4) along dim 0
# yields one block per gate in exactly that order.
#
# Column order matters: every gate's Linear consumes n.input.stack(n.output),
# i.e. (assuming stack() keeps argument order -- TODO confirm) the 100 input
# features come first and the 32 hidden features second. The fused weight must
# therefore be [W_ih | W_hh]; a [W_hh | W_ih] layout would apply the
# hidden-to-hidden weights to the input columns and the outputs would not match.
with torch.no_grad():
    fused_weight = torch.cat((lstm.weight_ih_l0, lstm.weight_hh_l0), dim=-1)
    # The LSTM adds both bias vectors, so a single Linear bias is their sum.
    fused_bias = lstm.bias_ih_l0 + lstm.bias_hh_l0

    gate_layers = model.inner.layers.c0.layers
    for gate, gate_weight, gate_bias in zip(
        ("i", "f", "g", "o"), fused_weight.chunk(4), fused_bias.chunk(4)
    ):
        linear = gate_layers[gate].mlist[0].inner
        # copy_ writes into the existing parameter storage and fails on a
        # non-broadcastable shape instead of silently rebinding `.data`.
        linear.weight.copy_(gate_weight)
        linear.bias.copy_(gate_bias)

In [34]:
# Run the hand-built model on the example. The second return value is
# discarded under an explicit name: binding `_` at notebook top level would
# shadow IPython's "last output" variable.
out1, _model_state = model(example)

In [35]:
# Run the reference LSTM on the same example. The returned (h_n, c_n) pair is
# discarded under an explicit name: binding `_` at notebook top level would
# shadow IPython's "last output" variable.
out2, _lstm_state = lstm(example)

In [36]:
# Element-wise closeness of both outputs, reduced to one boolean tensor.
# NOTE(review): atol=1e-7 is about float32 rounding level; rtol keeps its default.
torch.all(torch.isclose(out1, out2, atol=1e-7))

tensor(False)

In [9]:
# Display the repr of the `c0` entry of model.inner.layers.
model.inner.layers.c0

NestedNetworkModule(
  (layers): ModuleDict(
    (i): SequentialModule(
      (mlist): ModuleList(
        (0): StatelessWrapper(
          (inner): Linear(in_features=132, out_features=32, bias=True)
        )
        (1): StatelessWrapper(
          (inner): Sigmoid()
        )
      )
    )
    (f): SequentialModule(
      (mlist): ModuleList(
        (0): StatelessWrapper(
          (inner): Linear(in_features=132, out_features=32, bias=True)
        )
        (1): StatelessWrapper(
          (inner): Sigmoid()
        )
      )
    )
    (o): SequentialModule(
      (mlist): ModuleList(
        (0): StatelessWrapper(
          (inner): Linear(in_features=132, out_features=32, bias=True)
        )
        (1): StatelessWrapper(
          (inner): Sigmoid()
        )
      )
    )
    (g): SequentialModule(
      (mlist): ModuleList(
        (0): StatelessWrapper(
          (inner): Linear(in_features=132, out_features=32, bias=True)
        )
        (1): StatelessWrapper(
        

In [7]:
# List every parameter of the reference LSTM next to its shape.
for param_name, param in lstm.named_parameters():
    print(param_name, param.shape)

weight_ih_l0 torch.Size([128, 100])
weight_hh_l0 torch.Size([128, 32])
bias_ih_l0 torch.Size([128])
bias_hh_l0 torch.Size([128])


In [8]:
# List every parameter of the hand-built model next to its shape.
for param_name, param in model.named_parameters():
    print(param_name, param.shape)

inner.layers.c0.layers.i.mlist.0.inner.weight torch.Size([32, 132])
inner.layers.c0.layers.i.mlist.0.inner.bias torch.Size([32])
inner.layers.c0.layers.f.mlist.0.inner.weight torch.Size([32, 132])
inner.layers.c0.layers.f.mlist.0.inner.bias torch.Size([32])
inner.layers.c0.layers.o.mlist.0.inner.weight torch.Size([32, 132])
inner.layers.c0.layers.o.mlist.0.inner.bias torch.Size([32])
inner.layers.c0.layers.g.mlist.0.inner.weight torch.Size([32, 132])
inner.layers.c0.layers.g.mlist.0.inner.bias torch.Size([32])


In [10]:
# First of four row-chunks of the column-wise [weight_ih_l0 | weight_hh_l0] concatenation.
torch.chunk(torch.cat([lstm.weight_ih_l0, lstm.weight_hh_l0], dim=-1), 4)[0]

torch.Size([32, 132])