# LSTM
source: https://arxiv.org/abs/1909.09586

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from resnet import *

In [3]:
#export
class LSTMCell(nn.Module):
    def __init__(self, i, h):
        '''LSTM cell (naive implementation, one weight matrix per gate, no bias).
            i: input data dimension
            h: number of hidden units in the lstm cell

        For each of the four gates (input `i`, forget `f`, output `o`,
        candidate `g`) the cell keeps an input matrix `U*` created from shape
        (i, h) and a recurrent matrix `W*` created from shape (h, h).
        '''
        super().__init__()
        self.i, self.h = i, h
        # Input-to-hidden weights.
        # NOTE(review): init_2d_weight comes from the `resnet` star import;
        # presumably it returns an initialised tensor of the given shape.
        self.Ui = nn.Parameter(init_2d_weight((i, h)))
        self.Uf = nn.Parameter(init_2d_weight((i, h)))
        self.Uo = nn.Parameter(init_2d_weight((i, h)))
        self.Ug = nn.Parameter(init_2d_weight((i, h)))
        # Hidden-to-hidden (recurrent) weights.
        self.Wi = nn.Parameter(init_2d_weight((h, h)))
        self.Wf = nn.Parameter(init_2d_weight((h, h)))
        self.Wo = nn.Parameter(init_2d_weight((h, h)))
        self.Wg = nn.Parameter(init_2d_weight((h, h)))

    def forward(self, x, state):
        '''Advance the cell by one time step.
            x: input for this step (right-multiplied by the `U*` matrices)
            state: tuple (h, c) with the previous hidden and cell state
            returns: (h, (h, c)) - the new hidden state, followed by the new
                     state tuple to feed into the next step
        '''
        h, c = state

        # Gate pre-activations: input projection + recurrent projection.
        i = (x @ self.Ui + h @ self.Wi).sigmoid()
        f = (x @ self.Uf + h @ self.Wf).sigmoid()
        o = (x @ self.Uo + h @ self.Wo).sigmoid()
        g = (x @ self.Ug + h @ self.Wg).tanh()

        # The cell state is the plain gated sum and must not be squashed:
        # passing it through a sigmoid would confine it to (0, 1) and make it
        # disagree with FastLSTMCell. Only the read-out below goes through tanh.
        c = f*c + i*g
        h = o * c.tanh()
        return h, (h, c)

    def __repr__(self):
        return f'LSTM({self.i}, {self.h})'

In [4]:
#export
class LSTMLayer(nn.Module):
    def __init__(self, i, h):
        '''Runs a single LSTMCell over every time step of a sequence.
            i: input data dimension
            h: number of hidden units in the lstm cell
        '''
        super().__init__()
        self.cell = LSTMCell(i, h)

    def forward(self, inps, state):
        '''Feed the slices of `inps` along dim 1 through the cell, in order.
            inps: tensor whose dim 1 indexes the time steps
            state: initial state handed to the cell at the first step
            returns: (per-step cell outputs stacked along dim 1, final state)
        '''
        carry, per_step = state, []
        for x_t in inps.unbind(dim=1):
            # Each step consumes the state produced by the previous one.
            h_t, carry = self.cell(x_t, carry)
            per_step.append(h_t)
        return torch.stack(per_step, dim=1), carry

    def __repr__(self):
        # The layer displays as its wrapped cell.
        return str(self.cell)

In [5]:
#export
class FastLSTMCell(nn.Module):
    def __init__(self, i, h):
        '''LSTM cell (fast implementation using linear layers).
            i: input data dimension
            h: number of hidden units in the lstm cell

        Two linear layers each produce the pre-activations of all four gates
        at once (output width 4*h); their results are summed and then split.
        '''
        super().__init__()
        self.i, self.h = i, h
        # nn.Linear also adds a bias, so each projection carries its own.
        self.x_gates = nn.Linear(i, 4*h)
        # The recurrent projection consumes the hidden state, so its input
        # width must be h (not i); using i only works when i == h.
        self.h_gates = nn.Linear(h, 4*h)

    def forward(self, x, state):
        '''Advance the cell by one time step.
            x: input for this step, last dimension i
            state: tuple (h, c) with the previous hidden and cell state
            returns: (h, (h, c)) - the new hidden state, followed by the new
                     state tuple to feed into the next step
        '''
        h, c = state
        # Split the summed projections along dim 1 into the four gate blocks,
        # in the order: input, forget, output, candidate.
        i, f, o, g = (self.x_gates(x) + self.h_gates(h)).chunk(4, 1)

        i = i.sigmoid()
        f = f.sigmoid()
        o = o.sigmoid()
        g = g.tanh()

        c = f*c + i*g
        h = o * c.tanh()
        return h, (h, c)

    def __repr__(self): return f'LSTM({self.i}, {self.h})'

In [6]:
#export
class FastLSTMLayer(nn.Module):
    def __init__(self, i, h):
        '''Runs a single FastLSTMCell over every time step of a sequence.
            i: input data dimension
            h: number of hidden units in the lstm cell
        '''
        super().__init__()
        self.cell = FastLSTMCell(i, h)

    def forward(self, inps, state):
        '''Feed the slices of `inps` along dim 1 through the cell, in order.
            inps: tensor whose dim 1 indexes the time steps
            state: initial state handed to the cell at the first step
            returns: (per-step cell outputs stacked along dim 1, final state)
        '''
        step_outputs = []
        current = state
        for step_input in inps.unbind(dim=1):
            # Thread the state from one step into the next.
            step_out, current = self.cell(step_input, current)
            step_outputs.append(step_out)
        stacked = torch.stack(step_outputs, dim=1)
        return stacked, current

    def __repr__(self):
        # The layer displays as its wrapped cell.
        return str(self.cell)

# Tests

In [7]:
# One layer of each implementation, both with input dim 1024 and 1024 hidden
# units, built in the same order as before (fast first, then naive).
fastlstm, lstm = FastLSTMLayer(1024, 1024), LSTMLayer(1024, 1024)

In [8]:
# Dummy batch: 128 sequences of 100 time steps with 1024 features each,
# plus all-zero initial hidden and cell states.
x = torch.randn(128, 100, 1024)
h, c = torch.zeros(128, 1024), torch.zeros(128, 1024)

# Run each layer on the same input and print the shape of the stacked
# outputs, the final hidden state and the final cell state.
y, (h_out, c_out) = fastlstm(x, (h, c))
print('FastLSTM shapes:', y.shape, h_out.shape, c_out.shape, sep='\n')
y, (h_out, c_out) = lstm(x, (h, c))
print('\nLSTM shapes:', y.shape, h_out.shape, c_out.shape, sep='\n')

FastLSTM shapes:
torch.Size([128, 100, 1024])
torch.Size([128, 1024])
torch.Size([128, 1024])

LSTM shapes:
torch.Size([128, 100, 1024])
torch.Size([128, 1024])
torch.Size([128, 1024])


In [9]:
# Time a single forward pass of each layer on the same input and initial
# state; the result is discarded into `_`.
# NOTE(review): %time measures one run only, so the numbers are noisy;
# %timeit would average over several runs.
%time _ = fastlstm(x, (h, c))
%time _ = lstm(x, (h, c))

CPU times: user 6.09 s, sys: 91.1 ms, total: 6.18 s
Wall time: 1.04 s
CPU times: user 6.36 s, sys: 154 ms, total: 6.51 s
Wall time: 1.09 s


In [10]:
# Both layers delegate __repr__ to their cell, and both cells render as
# 'LSTM(i, h)', so the two printed lines look the same.
print(lstm, fastlstm, sep='\n')

LSTM(1024, 1024)
LSTM(1024, 1024)
