# GRU
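Source: Chung et al., "Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling" (2014), https://arxiv.org/abs/1412.3555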

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
from os.path import join

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from lstm import *

In [3]:
#export
class GRUCell(nn.Module):
    def __init__(self, i, h):
        '''GRU cell.
            i: input data dimension
            h: number of hidden units in the gru cell
        '''
        super().__init__()
        self.i, self.h = i, h
        # Input-to-hidden weights, requested with shape (i, h), for the
        # update gate (z), reset gate (r) and candidate state (h).
        self.Wz = nn.Parameter(init_2d_weight((i, h)))
        self.Wr = nn.Parameter(init_2d_weight((i, h)))
        self.Wh = nn.Parameter(init_2d_weight((i, h)))
        # Hidden-to-hidden weights, requested with shape (h, h), same order.
        self.Uz = nn.Parameter(init_2d_weight((h, h)))
        self.Ur = nn.Parameter(init_2d_weight((h, h)))
        self.Uh = nn.Parameter(init_2d_weight((h, h)))
        # Gate biases, zero-initialised. The candidate state has no bias;
        # adding one would change the state_dict keys of saved models.
        self.bz = nn.Parameter(torch.zeros(h))
        self.br = nn.Parameter(torch.zeros(h))

    def forward(self, x, h):
        '''Compute the hidden state for one time step.
            x: input for this step (last dimension is expected to be i)
            h: previous hidden state (last dimension is expected to be h)
            returns: weighted_sum(h, h_t, z), i.e. the previous state and
                the candidate state combined using the update gate z
        '''
        # The gate biases were previously declared but never applied, so they
        # received no gradient. They start at zero, which keeps outputs of
        # already-saved weights unchanged.
        z = (x @ self.Wz + h @ self.Uz + self.bz).sigmoid()
        r = (x @ self.Wr + h @ self.Ur + self.br).sigmoid()
        # The reset gate scales the previous state before the recurrent
        # projection. `*` and `@` share precedence and associate to the left,
        # so the parentheses only make the existing evaluation order explicit.
        h_t = (x @ self.Wh + (r * h) @ self.Uh).tanh()
        return weighted_sum(h, h_t, z)

In [4]:
#export
class GRULayer(nn.Module):
    def __init__(self, i, h):
        '''Wrapper that feeds the time steps of a sequence into a GRU cell
            i: input data dimension
            h: number of hidden units in the gru cell
        '''
        super().__init__()
        self.cell = GRUCell(i, h)

    def forward(self, inps, h=None):
        '''Run the cell over every step along dimension 1 of `inps`.
            inps: input sequence; dimension 0 is taken as the batch and
                dimension 1 as the time axis
            h: initial hidden state; when None, a zero state of shape
                (inps.size(0), hidden units) is created with the same
                dtype and device as `inps`
            returns: (hidden states of all steps stacked along dimension 1,
                hidden state after the last step)
        '''
        if h is None:
            h = inps.new_zeros(inps.size(0), self.cell.h)
        outputs = []
        for inp in inps.unbind(1):
            # Each step consumes the state produced by the previous one.
            h = self.cell(inp, h)
            outputs.append(h)
        return torch.stack(outputs, 1), h

# Tests

In [5]:
# Layer under test: 1024-dimensional inputs and 1024 hidden units.
gru = GRULayer(i=1024, h=1024)

In [6]:
# Smoke test: a random batch of 128 sequences, each with 100 steps of
# 1024 features, starting from an all-zero hidden state.
x = torch.randn(128, 100, 1024)
h = torch.zeros(128, 1024)

# y holds the hidden state of every step; h_out is the state after the last one.
y, h_out = gru(x, h)
print('GRU shapes:', y.shape, h_out.shape, sep='\n')

GRU shapes:
torch.Size([128, 100, 1024])
torch.Size([128, 1024])
