In [1]:
from pytorch_toolz.functools import Parallel, Reduce, Sequential
from pytorch_toolz.operator import Apply
from pytorch_toolz.itertools import Accumulate
from torch import nn
from torch.nn.parameter import Parameter
from itertools import islice
import torch

**LSTMCell.forward** takes its arguments as `(input, state)`, whereas a reduce-style function is called as `(state, input)`; to use the cell with reduce, the arguments need to be reversed:

In [2]:
def _reverse_args(*args):
    """Return the positional arguments in reverse order (as an iterator)."""
    return reversed(args)


# Pipeline: reverse the incoming arguments, then unpack them into an 8 -> 8 LSTM cell,
# so the cell can be driven by a reduce-style caller.
lstm_swap_args = Sequential(
    Apply(_reverse_args),
    nn.LSTMCell(8, 8),
    unpack=True,
)

A very simple RNN: reduce the incoming sequence with a single cell, using `Accumulate` so that every intermediate state is kept. To make sure that the initial cell/hidden states are learnable, pass them as `nn.Parameter`s.

In [3]:
# Learnable initial state: two separate zero-initialised Parameters of size 8,
# registered through a ParameterList so that rnn.parameters() exposes them.
initial_state = nn.ParameterList(
    [Parameter(torch.zeros(8)) for _ in range(2)]
)

# Run the argument-swapped LSTM cell over a sequence, starting from `initial_state`.
rnn = Accumulate(lstm_swap_args, initial_state)

Test with a single training step

In [4]:
# Optimiser over every parameter reachable from `rnn`.
opt = torch.optim.Adam(rnn.parameters())

# Random toy data: 10 time steps with 8 features each.
seq = torch.randn(10, 8)
targ = torch.randn(10, 8)

# Clear gradients left over from a previous run of this cell. `rnn` is created in an
# earlier cell, so without this a re-run would add the new gradients onto the stale
# ones before stepping. On a fresh kernel this is a no-op.
opt.zero_grad()

# rnn(seq) yields one (h, c) pair per state; zip(*...) regroups them into all h's and all c's.
h_all, c_all = zip(*rnn(seq))
# Drop the first entry so the remaining hidden states line up with the 10 targets
# (presumably the first entry is the initial state -- confirm against Accumulate).
output = torch.stack(h_all[1:])
loss = nn.functional.mse_loss(output, targ)
loss.backward()
opt.step()

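Check that the initial cell and hidden states have changed. Both were initialised to zeros, so any non-zero value means the training step updated them; the first entry is shown here: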

In [5]:
# Display the first entry of the initial-state ParameterList (the `h` slot in the
# `h_all, c_all` unpacking of the training cell). It was created as zeros, so
# non-zero values here show that the optimiser step reached it.
rnn.initial[0]

Parameter containing:
tensor([-0.0010,  0.0010,  0.0010, -0.0010, -0.0010,  0.0010, -0.0010,  0.0010],
       requires_grad=True)