In [1]:
import torch as t
import math
import torch.nn.functional as F

In [3]:
class LLTM(t.nn.Module):
    """Single-step recurrent cell with an input gate, an output gate and an ELU candidate.

    One fused linear layer maps the concatenation of the previous hidden state
    and the current input to three equally sized gate pre-activations.

    Args:
        inputs: Number of features in each input sample.
        state_size: Width of the hidden state and of the cell state.

    Attributes:
        weights: Parameter of shape ``(3 * state_size, inputs + state_size)``.
        bias: Parameter of shape ``(3 * state_size,)``.
    """

    def __init__(self, inputs, state_size):
        super().__init__()
        self.inputs = inputs
        self.state_size = state_size
        self.weights = t.nn.Parameter(t.empty(3 * state_size, inputs + state_size))
        self.bias = t.nn.Parameter(t.empty(3 * state_size))
        # t.empty() hands back uninitialized memory, so the parameters must be
        # filled before the first forward pass; otherwise the outputs and
        # gradients are computed from arbitrary (possibly NaN/Inf) values.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(state_size), +1/sqrt(state_size))."""
        stdv = 1.0 / math.sqrt(self.state_size)
        for param in self.parameters():
            # nn.init.uniform_ fills in place without recording the write in autograd.
            t.nn.init.uniform_(param, -stdv, stdv)

    def forward(self, input, state):
        """Advance the cell by one time step.

        Args:
            input: Tensor of shape ``(batch, inputs)``.
            state: Pair ``(old_h, old_cell)``; ``old_h`` has shape
                ``(batch, state_size)`` and ``old_cell`` must be compatible
                with that shape (it is added element-wise to the gated candidate).

        Returns:
            Tuple ``(new_h, new_cell)`` holding the updated hidden and cell states.
        """
        old_h, old_cell = state
        # Hidden state first, then input, concatenated along the feature axis.
        X = t.cat([old_h, input], dim=1)

        # One matmul yields all three gate pre-activations; split them evenly.
        gate_weights = F.linear(X, self.weights, self.bias)
        gates = gate_weights.chunk(3, dim=1)

        input_gate = t.sigmoid(gates[0])
        output_gate = t.sigmoid(gates[1])
        candidate_cell = F.elu(gates[2])

        # Additive cell update: there is no forget gate in this cell.
        new_cell = old_cell + candidate_cell * input_gate
        new_h = t.tanh(new_cell) * output_gate

        return new_h, new_cell

In [4]:
# Problem dimensions used by the smoke test in the cells below.
state_size = 128      # width of the hidden state and of the cell state
input_features = 32   # features per input sample
batch_size = 16       # samples per forward pass

In [6]:
# Standard-normal test data: one batch of inputs plus the previous hidden (h)
# and cell (C) states, drawn in this order.
X = t.randn((batch_size, input_features))
h = t.randn((batch_size, state_size))
C = t.randn((batch_size, state_size))

In [9]:
# Build the cell under test with the dimensions chosen above.
rnn = LLTM(inputs=input_features, state_size=state_size)

In [10]:
# One forward step, then backpropagate a scalar (sum of both outputs) so that
# every parameter receives a gradient.
new_h, new_C = rnn(X, (h, C))
t.add(new_h.sum(), new_C.sum()).backward()

KeyboardInterrupt: 