In [2]:
import math
import torch as t
import lltm_cpp

In [3]:
class LLLTMFunction(t.autograd.Function):
    """Custom autograd function that hands both passes to the ``lltm_cpp`` extension."""

    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        """Call ``lltm_cpp.forward`` and stash the tensors ``backward`` will receive.

        Returns:
            The first two entries of the extension's result, as
            ``(new_h, new_cell)``.
        """
        new_h, new_cell, *remaining = lltm_cpp.forward(
            input, weights, bias, old_h, old_cell
        )
        # Saved for backward: every forward result except new_h (so new_cell
        # is included), followed by the weights.
        ctx.save_for_backward(new_cell, *remaining, weights)
        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        """Call ``lltm_cpp.backward`` and return one gradient per ``forward`` input."""
        # The incoming gradients are made contiguous before being passed on.
        grad_h_contig = grad_h.contiguous()
        grad_cell_contig = grad_cell.contiguous()
        grads = lltm_cpp.backward(grad_h_contig, grad_cell_contig, *ctx.saved_tensors)
        d_old_h, d_input, d_weights, d_bias, d_old_cell = grads
        # Re-order the extension's result to line up with forward's argument
        # order: (input, weights, bias, old_h, old_cell).
        return d_input, d_weights, d_bias, d_old_h, d_old_cell

In [4]:
class LLTM(t.nn.Module):
    """Module that owns the weights and bias passed to ``LLLTMFunction``.

    Args:
        input_features: Contributes ``input_features`` columns to the weight
            matrix (the remaining ``state_size`` columns come from the state).
        state_size: Size used for the state; the weight matrix and bias both
            have ``3 * state_size`` rows/entries.

    Attributes:
        weights: Parameter of shape
            ``(3 * state_size, input_features + state_size)``.
        bias: Parameter of shape ``(3 * state_size,)``.
    """

    def __init__(self, input_features, state_size):
        super().__init__()
        self.input_features = input_features
        self.state_size = state_size
        # Allocated uninitialised; reset_parameters() fills them in below.
        self.weights = t.nn.Parameter(
            t.empty(3 * state_size, input_features + state_size)
        )
        self.bias = t.nn.Parameter(t.empty(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter uniformly from ``[-b, b)``, ``b = 1/sqrt(state_size)``."""
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            # nn.init.uniform_ fills in place under no_grad, replacing the
            # legacy `.data` access that sidesteps autograd's bookkeeping.
            t.nn.init.uniform_(weight, -stdv, +stdv)

    def forward(self, input, state):
        """Apply ``LLLTMFunction`` to ``input`` with this module's parameters.

        Args:
            input: Forwarded unchanged as the function's first argument.
            state: Iterable unpacked into the function's trailing positional
                arguments (``old_h``, ``old_cell``).

        Returns:
            Whatever ``LLLTMFunction.apply`` returns.
        """
        return LLLTMFunction.apply(input, self.weights, self.bias, *state)

In [5]:
# Problem dimensions used by the cells below.
batch_size = 16
input_features = 32
state_size = 128

# Seed torch's global RNG before anything random is drawn, so the random
# inputs and the parameter initialisation are identical on every
# Restart & Run All.
SEED = 0
t.manual_seed(SEED);

In [6]:
# Random demo tensors: the input batch first, then the two state tensors
# (h, then C), each of shape (batch_size, state_size).
X = t.randn((batch_size, input_features))
h, C = (t.randn((batch_size, state_size)) for _ in range(2))

In [7]:
# Build the module under test with the dimensions configured above.
rnn = LLTM(input_features=input_features, state_size=state_size)

In [8]:
# Smoke run: one forward step, then backpropagate the sum of both outputs.
new_h, new_C = rnn(X, state=(h, C))
t.add(new_h.sum(), new_C.sum()).backward()