In [1]:
# Load IPython's autoreload extension and set it to mode 2 so that edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
#export
import sys

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from resnet import *

In [3]:
#export
def init_2d_weight(shape, leak=1.):
    """Sample a 2-D weight matrix from a scaled zero-mean normal distribution.

    Values are drawn with ``torch.randn`` and multiplied by
    ``sqrt(gain_sq / fan)``, where ``fan = shape[0]`` and
    ``gain_sq = 2 / (1 + leak**2)``.  With the default ``leak=1.`` the squared
    gain is 1, i.e. the standard deviation is ``1 / sqrt(fan)``; passing
    ``leak=0`` yields the He/Kaiming squared gain of 2.

    Args:
        shape: Sequence of exactly two sizes; the first entry is used as the fan.
        leak: Slope plugged into the gain formula above.

    Returns:
        A new ``torch.Tensor`` with the requested shape.

    Raises:
        AssertionError: If ``shape`` does not contain exactly two entries.
    """
    assert len(shape) == 2
    rows = shape[0]
    # Same arithmetic as gain_sq / fan, just folded into one expression.
    std = ((2.0 / (1 + leak**2)) / rows)**0.5
    return torch.randn(*shape) * std

In [4]:
#export
class LSTMCell(nn.Module):
    """Single-timestep LSTM cell with a separate weight matrix per gate.

    Each gate has an input weight ``U*`` of shape ``(i_dim, h_dim)`` and a
    recurrent weight ``W*`` of shape ``(h_dim, h_dim)``; no bias terms are used.
    All weights are created with ``init_2d_weight`` using its default ``leak``.

    Args:
        i_dim: Size of the last dimension of the input ``x``.
        h_dim: Size of the last dimension of the hidden and cell state.
    """

    def __init__(self, i_dim, h_dim):
        super().__init__()
        self.i_dim, self.h_dim = i_dim, h_dim
        # Input-to-hidden weights for the input, forget, output and candidate gates.
        self.Ui = nn.Parameter(init_2d_weight((i_dim, h_dim)))
        self.Uf = nn.Parameter(init_2d_weight((i_dim, h_dim)))
        self.Uo = nn.Parameter(init_2d_weight((i_dim, h_dim)))
        self.Ug = nn.Parameter(init_2d_weight((i_dim, h_dim)))
        # Hidden-to-hidden (recurrent) weights, in the same gate order.
        self.Wi = nn.Parameter(init_2d_weight((h_dim, h_dim)))
        self.Wf = nn.Parameter(init_2d_weight((h_dim, h_dim)))
        self.Wo = nn.Parameter(init_2d_weight((h_dim, h_dim)))
        self.Wg = nn.Parameter(init_2d_weight((h_dim, h_dim)))

    def forward(self, x, state):
        """Advance the cell by one step.

        Args:
            x: Input tensor whose last dimension is ``i_dim``.
            state: Tuple ``(h, c)`` of previous hidden and cell state, each with
                last dimension ``h_dim``.

        Returns:
            ``(h, (h, c))``: the new hidden state as the step output, plus the
            updated state tuple to feed into the next step.
        """
        h, c = state

        i   = (x @ self.Ui + h @ self.Wi).sigmoid()  # input gate
        f   = (x @ self.Uf + h @ self.Wf).sigmoid()  # forget gate
        o   = (x @ self.Uo + h @ self.Wo).sigmoid()  # output gate
        c_t = (x @ self.Ug + h @ self.Wg).tanh()     # candidate cell values

        # The cell state is the plain gated sum and is deliberately NOT passed
        # through a squashing function: squashing it every step would confine
        # the memory to (0, 1) and break the additive carry path. Only the
        # tanh below bounds it when producing the hidden state.
        c = f*c + i*c_t
        h = c.tanh() * o
        return h, (h, c)

    def __repr__(self):
        return f'LSTM({self.i_dim}, {self.h_dim})'

In [5]:
#export
class LSTMLayer(nn.Module):
    """Runs an ``LSTMCell`` over every slice of the input along dimension 1.

    Args:
        i_dim: Input feature size handed to the cell.
        h_dim: Hidden size handed to the cell.
    """

    def __init__(self, i_dim, h_dim):
        super().__init__()
        self.cell = LSTMCell(i_dim, h_dim)

    def forward(self, inps, state):
        """Feed the slices of ``inps`` along dim 1 through the cell in order.

        Args:
            inps: Tensor that is split along dimension 1, one slice per step.
            state: Initial state, passed to the cell as-is on the first step.

        Returns:
            A tuple of the per-step outputs stacked back along dimension 1 and
            the state returned by the cell after the last step.
        """
        per_step = []
        for step_input in inps.unbind(1):
            # The cell returns its output and the state to carry forward.
            step_out, state = self.cell(step_input, state)
            per_step.append(step_out)
        stacked = torch.stack(per_step, 1)
        return stacked, state

    def __repr__(self):
        # Present the layer exactly as its wrapped cell presents itself.
        return str(self.cell)

In [6]:
#export
class FastLSTMCell(nn.Module):
    """Single-timestep LSTM cell that computes all four gates with two matmuls.

    ``x_gates`` maps the input to the concatenated pre-activations of the
    input, forget, output and candidate gates (``4 * h_dim`` features);
    ``h_gates`` does the same for the previous hidden state.  Both are
    ``nn.Linear`` layers with their default bias, so each gate receives the sum
    of two bias vectors.

    Args:
        i_dim: Size of the last dimension of the input ``x``.
        h_dim: Size of the last dimension of the hidden and cell state.
    """

    def __init__(self, i_dim, h_dim):
        super().__init__()
        self.i_dim, self.h_dim = i_dim, h_dim
        # also adds a small bias
        self.x_gates = nn.Linear(i_dim, 4*h_dim)
        # The recurrent projection consumes the hidden state, so its input
        # width must be h_dim (not i_dim); otherwise the cell only works when
        # the two sizes happen to be equal.
        self.h_gates = nn.Linear(h_dim, 4*h_dim)

    def forward(self, x, state):
        """Advance the cell by one step.

        Args:
            x: 2-D input of shape ``(batch, i_dim)`` (gates are chunked along
                dimension 1).
            state: Tuple ``(h, c)`` of previous hidden and cell state, each of
                shape ``(batch, h_dim)``.

        Returns:
            ``(h, (h, c))``: the new hidden state as the step output, plus the
            updated state tuple to feed into the next step.
        """
        h, c = state
        # Sum the two projections, then split into the four gate pre-activations.
        gates = (self.x_gates(x) + self.h_gates(h)).chunk(4, 1)

        i   = gates[0].sigmoid()  # input gate
        f   = gates[1].sigmoid()  # forget gate
        o   = gates[2].sigmoid()  # output gate
        c_t = gates[3].tanh()     # candidate cell values

        c = f*c + i*c_t
        h = o * c.tanh()
        return h, (h, c)

    def __repr__(self):
        return f'LSTM({self.i_dim}, {self.h_dim})'

In [7]:
#export
class FastLSTMLayer(nn.Module):
    """Applies a ``FastLSTMCell`` step by step along dimension 1 of the input.

    Args:
        i_dim: Input feature size forwarded to the cell.
        h_dim: Hidden size forwarded to the cell.
    """

    def __init__(self, i_dim, h_dim):
        super().__init__()
        self.cell = FastLSTMCell(i_dim, h_dim)

    def forward(self, inps, state):
        """Unroll the cell over ``inps`` and collect every step's output.

        Args:
            inps: Tensor that is split along dimension 1, one slice per step.
            state: Initial state, given to the cell unchanged on the first step.

        Returns:
            ``(outputs, state)`` where ``outputs`` is the per-step outputs
            stacked along dimension 1 and ``state`` is whatever the cell
            returned as state on the final step.
        """
        collected = []
        for x_t in inps.unbind(1):
            # Thread the recurrent state through consecutive steps.
            y_t, state = self.cell(x_t, state)
            collected.append(y_t)
        sequence = torch.stack(collected, 1)
        return sequence, state

    def __repr__(self):
        # The layer's textual form is simply that of its cell.
        return str(self.cell)

In [8]:
# Build one layer of each implementation with the same input and hidden size
# (1024) so the two can be exercised on identical inputs in the cells below.
fastlstm = FastLSTMLayer(1024, 1024)
lstm = LSTMLayer(1024, 1024)

In [9]:
# Random input of shape (128, 100, 1024); the layers step along dim 1.
x = torch.randn(128, 100, 1024)
# Zero initial hidden and cell state, one row per batch element.
h = torch.zeros(128, 1024)
c = torch.zeros(128, 1024)

# Run each layer once and print the shapes of the stacked outputs and of the
# final hidden / cell state, one per line.
y, (h_out, c_out) = fastlstm(x, (h, c))
print('FastLSTM shapes:', y.shape, h_out.shape, c_out.shape, sep='\n')
y, (h_out, c_out) = lstm(x, (h, c))
print('\nLSTM shapes:', y.shape, h_out.shape, c_out.shape, sep='\n')

FastLSTM shapes:
torch.Size([128, 100, 1024])
torch.Size([128, 1024])
torch.Size([128, 1024])

LSTM shapes:
torch.Size([128, 100, 1024])
torch.Size([128, 1024])
torch.Size([128, 1024])


In [10]:
# Time a single full forward pass of each layer on the same input and initial
# state; the result is discarded into `_`.
%time _ = fastlstm(x, (h, c))
%time _ = lstm(x, (h, c))

CPU times: user 5.34 s, sys: 61 ms, total: 5.4 s
Wall time: 902 ms
CPU times: user 5.46 s, sys: 79.1 ms, total: 5.53 s
Wall time: 925 ms


In [11]:
# Both layers delegate __repr__ to their cell, and both cells format
# themselves as 'LSTM(i_dim, h_dim)', so the two lines print identically.
print(lstm)
print(fastlstm)

LSTM(1024, 1024)
LSTM(1024, 1024)
