In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from gru_v import *
from test_utils import *

# GRU comparison

## Vanilla GRU
<center>
<img src="assets/GRUV.png" width="80%">
</center>

In [5]:
# Built-in nn.GRU vs. the hand-written, step-by-step GRU_Cell
# (gru_tests prints the timings and the shape check shown below).
gru_tests(nn.GRU, GRU_Cell)

B=64 , T=200 , input dimension=128, hidden dimemnsion=256
[GRU] Time taken: 0.046936 seconds
[recoded GRU_Cell] Time taken: 0.051591 seconds
input :  torch.Size([64, 200, 128])
output :  torch.Size([64, 200, 256])
Shapes matched.


## minGRU: A Minimal GRU

<center>
<img src="assets/GRUV2.png" width="80%" />
</center>

In [6]:
# Built-in nn.GRU vs. the minimal GRU cell (minGRU_Cell)
# (gru_tests prints the timings and the shape check shown below).
gru_tests(nn.GRU, minGRU_Cell)

B=64 , T=200 , input dimension=128, hidden dimemnsion=256
[GRU] Time taken: 0.048377 seconds
[recoded minGRU_Cell] Time taken: 0.021233 seconds
input :  torch.Size([64, 200, 128])
output :  torch.Size([64, 200, 256])
Shapes matched.


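## Log-space minGRU

The sequential cell below uses the activation $g(x) = x + 0.5$ for $x \ge 0$ and $g(x) = \sigma(x)$ for $x < 0$. Candidate states are therefore always positive, which is the property the log-space parallel version further down relies on.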

In [7]:
# minGRU_log

def g(x):
    """Positive activation used by minGRU for the candidate hidden state.

    Elementwise: ``x + 0.5`` where ``x >= 0`` and ``sigmoid(x)`` elsewhere.
    Both pieces equal 0.5 at x == 0, so the function is continuous. The
    result is positive for every input, although in floating point
    ``sigmoid`` can underflow to 0 for very negative x.
    """
    non_negative = x >= 0
    return torch.where(non_negative, x + 0.5, torch.sigmoid(x))

class log_minGRU_Cell(nn.Module):
    """Single-step (sequential) minGRU cell using the positive activation ``g``.

    Both the update gate and the candidate state are functions of the
    current input ``x_t`` only; ``h_prev`` enters solely through the final
    interpolation.

    Args:
        input_dim: size of the last dimension of ``x_t``.
        hidden_dim: size of the produced hidden state.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Independent input projections: gate first, then candidate state
        # (creation order kept so parameter initialisation draws match).
        self.linear_z = nn.Linear(input_dim, hidden_dim)
        self.linear_h = nn.Linear(input_dim, hidden_dim)

    def forward(self, x_t, h_prev):
        """Advance the recurrence by one step.

        Computes ``h_t = (1 - z_t) * h_prev + z_t * g(linear_h(x_t))`` with
        ``z_t = sigmoid(linear_z(x_t))``. ``h_prev`` is used as-is; it is
        not passed through ``g``.
        """
        gate = self.linear_z(x_t).sigmoid()
        candidate = g(self.linear_h(x_t))
        return (1 - gate) * h_prev + gate * candidate

In [17]:
# Built-in nn.GRU vs. the sequential minGRU cell with the positive activation g
# (gru_tests prints the timings and the shape check shown below).
gru_tests(nn.GRU, log_minGRU_Cell)

B=64 , T=200 , input dimension=128, hidden dimemnsion=256
[GRU] Time taken: 0.045435 seconds
[recoded log_minGRU_Cell] Time taken: 0.031606 seconds
input :  torch.Size([64, 200, 128])
output :  torch.Size([64, 200, 256])
Shapes matched.


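## Parallel minGRU

Because $z_t$ and $\tilde h_t$ depend only on $x_t$, the recurrence $h_t = (1 - z_t)\, h_{t-1} + z_t\, \tilde h_t$ is linear in $h_{t-1}$. It can therefore be evaluated for all $t$ at once with a prefix scan, which is done in log space for numerical stability.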

In [12]:
# parallel 
def log_g(x):
    """Log of the minGRU activation ``g``, computed without NaN gradients.

    Elementwise:
      * ``log(x + 0.5)`` where ``x >= 0``
      * ``log(sigmoid(x)) == -softplus(-x)`` elsewhere

    ``torch.where`` evaluates BOTH branches, and autograd sends a zero
    gradient into the unselected one. At x == -0.5 the unselected
    ``log(x + 0.5)`` is ``log(0)``, whose backward is 0 / 0 = NaN, and that
    NaN leaks into the input's gradient. For x < -0.5 it also produces NaN
    values in the forward pass. Clamping the argument to that branch's own
    domain (x >= 0) keeps it finite everywhere. Every selected value is
    unchanged because the clamp is the identity where x >= 0.
    """
    log_pos = (x.clamp(min=0) + 0.5).log()
    log_neg = -F.softplus(-x)
    return torch.where(x >= 0, log_pos, log_neg)


def parallel_scan_log(log_coeffs, log_values):
    """Solve ``h_t = a_t * h_{t-1} + b_t`` for every t at once, in log space.

    Args:
        log_coeffs: ``log a_1 .. log a_T`` stacked along dim 1,
            shape (batch_size, seq_len, input_size).
        log_values: ``log h_0`` followed by ``log b_1 .. log b_T`` along
            dim 1, shape (batch_size, seq_len + 1, input_size).

    Returns:
        ``h_1 .. h_T`` (h_0 dropped), shape (batch_size, seq_len, input_size).
    """
    # a*_t = sum_{k<=t} log a_k, with a*_0 = 0 prepended on the time axis.
    a_star = F.pad(log_coeffs.cumsum(dim=1), (0, 0, 1, 0))
    # Running log-sum-exp of (log v_k - a*_k) = log(h_0 + sum_{k<=t} b_k / A_k).
    log_acc = (log_values - a_star).logcumsumexp(dim=1)
    return (a_star + log_acc).exp()[:, 1:]

# minGRU_log parallel
class parallel_log_minGRU(nn.Module):
    """minGRU over a whole sequence in one call, via a log-space parallel scan.

    Evaluates ``h_t = (1 - z_t) * h_{t-1} + z_t * g(linear_h(x_t))`` with
    ``z_t = sigmoid(linear_z(x_t))`` for all t at once, working with
    log-values throughout.

    Args:
        input_dim: size of the last dimension of ``x``.
        hidden_dim: size of the hidden state.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear_z = nn.Linear(input_dim, hidden_dim)
        self.linear_h = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, h_prev):
        """Run the recurrence over the full input sequence.

        Args:
            x: input sequence. Presumably (batch, T, input_dim), since the
                scan works along dim 1 (TODO confirm for other layouts).
            h_prev: initial state, either (batch, hidden_dim) or
                (batch, 1, hidden_dim). NOTE: it is passed through
                ``log_g``, so the effective h_0 is ``g(h_prev)``, not
                ``h_prev`` itself.

        Returns:
            Hidden states h_1..h_T stacked along dim 1.
        """
        # Project once: the gate and its complement share the same logits.
        k = self.linear_z(x)
        log_z = -F.softplus(-k)        # log(sigmoid(k))     = log z_t
        log_coeffs = -F.softplus(k)    # log(1 - sigmoid(k)) = log(1 - z_t)
        # A 2-D initial state gets the time axis the concatenation below needs.
        if h_prev.dim() == x.dim() - 1:
            h_prev = h_prev.unsqueeze(1)
        log_h0 = log_g(h_prev)
        log_h_tilde = log_g(self.linear_h(x))
        return parallel_scan_log(
            log_coeffs, torch.cat([log_h0, log_z + log_h_tilde], dim=1)
        )

In [19]:
B, T, input_dim, hidden_dim = 64, 200, 128, 256

# Seed the RNG so the random inputs and the layer initialisation below are
# reproducible across Restart & Run All.
torch.manual_seed(0)

# Random sequence input (batch_size, seq_len, input_dim)
x = torch.randn(B, T, input_dim)
# Initial hidden state (batch_size, hidden_dim)
h0 = torch.randn(B, hidden_dim)

parallel_minGRU = parallel_log_minGRU(input_dim, hidden_dim)
parallel_minGRU

parallel_log_minGRU(
  (linear_z): Linear(in_features=128, out_features=256, bias=True)
  (linear_h): Linear(in_features=128, out_features=256, bias=True)
)

In [30]:
# PyTorch's built-in single-step GRU cell, unrolled manually in the loop below.
v_gru = nn.GRUCell(input_dim, hidden_dim)
v_gru

GRUCell(128, 256)

In [21]:
# One call processes the whole sequence; the result stacks h_1..h_T along dim 1.
parallel_minGRU(x, h0).shape

torch.Size([64, 200, 256])

In [48]:
# Unroll the built-in nn.GRUCell (v_gru) over the T time steps of x, starting
# from h0, and collect every hidden state into `out` (batch, T, hidden_dim).
# NOTE(review): v_gru's weights are independent of parallel_minGRU's, so `out`
# is only comparable to parallel_minGRU(x, h0) in shape, not in values.
h = h0
outputs = []
for t in range(T):
    h = v_gru(x[:, t, :], h)  # (batch, hidden_dim) -> (batch, hidden_dim)
    outputs.append(h.unsqueeze(1))  # add a time axis so cat below yields (batch, T, hidden_dim)
out = torch.cat(outputs, dim=1)
