In [None]:
import torch

# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs end to end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## LSTM Definitions

### Vanilla

In [None]:
class LSTMCell(torch.nn.Module):
    """Single-step LSTM cell built from explicit tensor operations.

    Parameter names, shapes and the (input, forget, cell, output) gate layout
    match ``torch.nn.LSTMCell``, so state dicts are interchangeable.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        gate_rows = 4 * hidden_size  # the four gates are stacked along dim 0
        # Draw from U(-1/sqrt(H), 1/sqrt(H)) as torch.nn.LSTMCell does.
        # Unit-variance weights make the pre-activations grow with the layer
        # width and saturate every gate from the very first step.
        bound = hidden_size ** -0.5 if hidden_size > 0 else 0.0
        self.weight_ih = torch.nn.Parameter(
            torch.empty(gate_rows, input_size).uniform_(-bound, bound))
        self.weight_hh = torch.nn.Parameter(
            torch.empty(gate_rows, hidden_size).uniform_(-bound, bound))
        self.bias_ih = torch.nn.Parameter(
            torch.empty(gate_rows).uniform_(-bound, bound))
        self.bias_hh = torch.nn.Parameter(
            torch.empty(gate_rows).uniform_(-bound, bound))

    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        """Advance the cell by one time step.

        Args:
            input: 2-D tensor of shape (batch, input_size).
            state: Pair ``(hx, cx)`` of 2-D tensors of shape (batch, hidden_size).

        Returns:
            ``(hy, (hy, cy))``: the new hidden state, followed by the new
            (hidden, cell) state pair.
        """
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        # Split the stacked pre-activations into the four gates along dim 1.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)
    
    
class LSTM(torch.nn.Module):
    """Single-layer LSTM that unrolls an ``LSTMCell`` over the time dimension.

    Args:
        input_size: Number of features in each input step.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = LSTMCell(input_size, hidden_size)

    def forward(self, x):
        """Run the cell over every step of ``x``, starting from a zero state.

        Args:
            x: Tensor of shape (seq_len, batch, input_size); time is dim 0.

        Returns:
            ``(outputs, (hy, cy))`` where ``outputs`` stacks the hidden state
            of every step along dim 0 and ``(hy, cy)`` is the final state.
        """
        batch_size = x.size(1)
        # Match the input's dtype as well as its device, so half- or
        # double-precision inputs do not meet a float32 state.
        state = (
            torch.zeros((batch_size, self.cell.hidden_size), dtype=x.dtype, device=x.device),
            torch.zeros((batch_size, self.cell.hidden_size), dtype=x.dtype, device=x.device)
        )
        xs = x.unbind(0)
        y = []
        for t in range(len(xs)):
            hy, state = self.cell(xs[t], state)
            y += [hy]
        return torch.stack(y), state

### Hard

In [None]:
class HardLSTMCell(torch.nn.Module):
    """LSTM cell using piecewise-linear ("hard") activations.

    The gates use the hard sigmoid ``clamp(0.2 * x + 0.5, 0, 1)`` and the
    candidate/output non-linearity is ``hardtanh``. Parameter names and shapes
    are the same as those of the vanilla cell.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        gate_rows = 4 * hidden_size  # the four gates are stacked along dim 0
        # Draw from U(-1/sqrt(H), 1/sqrt(H)). Unit-variance weights put nearly
        # every pre-activation outside the linear region of the clamps, where
        # the gradient is exactly zero.
        bound = hidden_size ** -0.5 if hidden_size > 0 else 0.0
        self.weight_ih = torch.nn.Parameter(
            torch.empty(gate_rows, input_size).uniform_(-bound, bound))
        self.weight_hh = torch.nn.Parameter(
            torch.empty(gate_rows, hidden_size).uniform_(-bound, bound))
        self.bias_ih = torch.nn.Parameter(
            torch.empty(gate_rows).uniform_(-bound, bound))
        self.bias_hh = torch.nn.Parameter(
            torch.empty(gate_rows).uniform_(-bound, bound))

    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        """Advance the cell by one time step.

        Args:
            input: 2-D tensor of shape (batch, input_size).
            state: Pair ``(hx, cx)`` of 2-D tensors of shape (batch, hidden_size).

        Returns:
            ``(hy, (hy, cy))``: the new hidden state, followed by the new
            (hidden, cell) state pair. ``cy`` is returned unclamped.
        """
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.clamp(0.2*ingate + 0.5, min=0., max=1.)
        forgetgate = torch.clamp(0.2*forgetgate + 0.5, min=0., max=1.)
        # Out-of-place on purpose: `cellgate` is a view produced by chunk(),
        # and autograd rejects in-place edits of such multi-output views.
        cellgate = torch.nn.functional.hardtanh(cellgate)
        outgate = torch.clamp(0.2*outgate + 0.5, min=0., max=1.)

        cy = (forgetgate * cx) + (ingate * cellgate)
        # Out-of-place again: an in-place hardtanh_ would also clamp the cell
        # state that is handed to the next time step.
        hy = outgate * torch.nn.functional.hardtanh(cy)

        return hy, (hy, cy)
    
    
class HardLSTM(torch.nn.Module):
    """Single-layer LSTM that unrolls a ``HardLSTMCell`` over the time dimension.

    Args:
        input_size: Number of features in each input step.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = HardLSTMCell(input_size, hidden_size)

    def forward(self, x):
        """Run the cell over every step of ``x``, starting from a zero state.

        Args:
            x: Tensor of shape (seq_len, batch, input_size); time is dim 0.

        Returns:
            ``(outputs, (hy, cy))`` where ``outputs`` stacks the hidden state
            of every step along dim 0 and ``(hy, cy)`` is the final state.
        """
        batch_size = x.size(1)
        # Match the input's dtype as well as its device, so half- or
        # double-precision inputs do not meet a float32 state.
        state = (
            torch.zeros((batch_size, self.cell.hidden_size), dtype=x.dtype, device=x.device),
            torch.zeros((batch_size, self.cell.hidden_size), dtype=x.dtype, device=x.device)
        )
        xs = x.unbind(0)
        y = []
        for t in range(len(xs)):
            hy, state = self.cell(xs[t], state)
            y += [hy]
        return torch.stack(y), state

## Profiling

In [None]:
# Problem dimensions shared by every benchmark cell below.
batch, seq_len = 128, 256
input_size = hidden_size = 1024

### Vanilla

#### LSTMCell

In [None]:
# Three single-step cells to compare: the hand-written eager module, the
# TorchScript compilation of that same module, and PyTorch's built-in cell.
cell_hand = LSTMCell(input_size, hidden_size).to(device)
cell_jit = torch.jit.script(cell_hand)
cell = torch.nn.LSTMCell(input_size, hidden_size).to(device)

In [None]:
# Sample the benchmark inputs directly on the target device instead of
# filling host tensors and copying them over.
x = torch.randn(batch, input_size, device=device)
state = (
    torch.randn(batch, hidden_size, device=device),  # hidden state
    torch.randn(batch, hidden_size, device=device),  # cell state
)

In [None]:
# Access the TorchScript graph of the scripted cell. The trailing `;`
# suppresses the cell output, so nothing is displayed as written; drop the
# `;` on the line you want to inspect.
cell_jit.graph;
# NOTE(review): graph_for presumably returns the graph specialized/optimized
# for these example inputs -- confirm against the installed PyTorch version.
cell_jit.graph_for(x, state);

In [None]:
%timeit -n 100 cell_hand(x, state)

In [None]:
%timeit -n 100 cell_jit(x, state)

In [None]:
%timeit -n 100 cell(x, state)

#### LSTM

In [None]:
# Three full-sequence models to compare: the hand-written eager module, the
# TorchScript compilation of that same module, and PyTorch's built-in LSTM.
lstm_hand = LSTM(input_size, hidden_size).to(device)
lstm_jit = torch.jit.script(lstm_hand)
lstm = torch.nn.LSTM(input_size, hidden_size).to(device)

In [None]:
# Sequence input laid out as (seq_len, batch, input_size). The feature
# dimension must be `input_size`: using `hidden_size` only works while the
# two happen to be equal.
x = torch.randn(seq_len, batch, input_size, device=device)

In [None]:
%timeit -n 100 lstm_hand(x)

In [None]:
%timeit -n 100 lstm_jit(x)

In [None]:
%timeit -n 100 lstm(x)

### Hard

#### HardLSTMCell

In [None]:
# Hard-activation cell in eager form and as a TorchScript compilation of the
# same module. NOTE: this rebinds `cell_hand` / `cell_jit` from the Vanilla
# section, so cells must be run top to bottom.
cell_hand = HardLSTMCell(input_size, hidden_size).to(device)
cell_jit = torch.jit.script(cell_hand)

In [None]:
# Sample the benchmark inputs directly on the target device instead of
# filling host tensors and copying them over.
x = torch.randn(batch, input_size, device=device)
state = (
    torch.randn(batch, hidden_size, device=device),  # hidden state
    torch.randn(batch, hidden_size, device=device),  # cell state
)

In [None]:
%timeit -n 100 cell_hand(x, state)

In [None]:
%timeit -n 100 cell_jit(x, state)

#### HardLSTM

In [None]:
# Hard-activation sequence model in eager form and as a TorchScript
# compilation of the same module. NOTE: this rebinds `lstm_hand` / `lstm_jit`
# from the Vanilla section, so cells must be run top to bottom.
lstm_hand = HardLSTM(input_size, hidden_size).to(device)
lstm_jit = torch.jit.script(lstm_hand)

In [None]:
# Sequence input laid out as (seq_len, batch, input_size). The feature
# dimension must be `input_size`: using `hidden_size` only works while the
# two happen to be equal.
x = torch.randn(seq_len, batch, input_size, device=device)

In [None]:
%timeit -n 100 lstm_hand(x)

In [None]:
%timeit -n 100 lstm_jit(x)