In [None]:
# type: ignore
import torch
import torch.nn as nn
import torch.nn.functional as F

from test_models import *

# Vanilla Version

## GRU
<center>
<img src="assets/GRUV.png" width="80%">
</center>

In [2]:
class GRU_Cell(nn.Module):
    """Single-step GRU cell built from three dense layers over ``[x_t, h]``.

    With ``[a, b]`` denoting concatenation along dim 1:

        r_t = sigmoid(W_r [x_t, h_prev] + b_r)          # reset gate
        z_t = sigmoid(W_z [x_t, h_prev] + b_z)          # update gate
        n_t = tanh(W_h [x_t, r_t * h_prev] + b_h)       # candidate state
        h_t = (1 - z_t) * h_prev + z_t * n_t

    Args:
        input_dim: size of each input vector ``x_t``.
        hidden_dim: size of the hidden state.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Every layer reads the concatenation [x_t, h], hence the summed width.
        joint_dim = input_dim + hidden_dim
        self.linear_r = nn.Linear(joint_dim, hidden_dim)  # reset gate
        self.linear_z = nn.Linear(joint_dim, hidden_dim)  # update gate
        self.linear_h = nn.Linear(joint_dim, hidden_dim)  # candidate hidden state

    def forward(self, x_t, h_prev):
        """Advance the recurrence by one time step.

        Args:
            x_t: input at this step; concatenated with ``h_prev`` along dim 1
                (presumably shape ``(batch, input_dim)`` — TODO confirm).
            h_prev: previous hidden state (presumably ``(batch, hidden_dim)``).

        Returns:
            The new hidden state ``h_t``.
        """
        gate_input = torch.cat((x_t, h_prev), dim=1)
        reset = torch.sigmoid(self.linear_r(gate_input))
        update = torch.sigmoid(self.linear_z(gate_input))

        # The reset gate masks the previous state before it feeds the candidate.
        candidate_input = torch.cat((x_t, reset * h_prev), dim=1)
        candidate = torch.tanh(self.linear_h(candidate_input))

        # update -> 1 adopts the candidate, update -> 0 keeps h_prev.
        return (1 - update) * h_prev + update * candidate

In [3]:
# Compare nn.GRU with the hand-written GRU_Cell.
# gru_tests comes from the star import of test_models (not shown here); judging by the
# printed output, it reports timings for both and a shape check only — a shape match
# alone would not catch a wrong gate equation.
gru_tests(nn.GRU,GRU_Cell)

B=64 , T=200 , input dimension=128, hidden dimemnsion=256
[Standard GRU] Time taken: 0.099236 seconds
[recoded GRU] Time taken: 0.074645 seconds
input :  torch.Size([64, 200, 128])
output :  torch.Size([64, 200, 256])
Shapes matched.


## LSTM

<center>
<img src="assets/LSTMV.png" width="80%">
</center>

In [4]:
class LSTM_Cell(nn.Module):
    """Single-step LSTM cell built from four dense layers over ``[x_t, h_prev]``.

        f_t  = sigmoid(W_f [x_t, h_prev] + b_f)     # forget gate
        i_t  = sigmoid(W_i [x_t, h_prev] + b_i)     # input gate
        c~_t = tanh(W_c [x_t, h_prev] + b_c)        # candidate cell state
        o_t  = sigmoid(W_o [x_t, h_prev] + b_o)     # output gate
        c_t  = f_t * c_prev + i_t * c~_t
        h_t  = o_t * tanh(c_t)

    These are the same equations as ``torch.nn.LSTMCell`` (with the input and hidden
    weights merged into one Linear per gate).

    Args:
        input_dim: size of each input vector ``x_t``.
        hidden_dim: size of the hidden and cell states.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Each gate reads the concatenation [x_t, h_prev].
        joint_dim = input_dim + hidden_dim
        self.linear_f = nn.Linear(joint_dim, hidden_dim)        # forget gate
        self.linear_i = nn.Linear(joint_dim, hidden_dim)        # input gate
        self.linear_c_tilda = nn.Linear(joint_dim, hidden_dim)  # candidate cell state
        self.linear_o = nn.Linear(joint_dim, hidden_dim)        # output gate

    def forward(self, x_t, h_prev, c_prev):
        """Advance the recurrence by one time step.

        Args:
            x_t: input at this step (presumably ``(batch, input_dim)`` — TODO confirm).
            h_prev: previous hidden state, concatenated with ``x_t`` along dim 1.
            c_prev: previous cell state.

        Returns:
            Tuple ``(h_t, c_t)`` of the new hidden and cell states.
        """
        combined = torch.cat((x_t, h_prev), dim=1)

        f_t = torch.sigmoid(self.linear_f(combined))
        i_t = torch.sigmoid(self.linear_i(combined))
        # tanh (range [-1, 1]) lets the candidate both raise and lower the cell state;
        # a sigmoid here could only ever add.
        c_tilda_t = torch.tanh(self.linear_c_tilda(combined))
        o_t = torch.sigmoid(self.linear_o(combined))

        # The forget gate scales the previous cell state directly.
        c_t = f_t * c_prev + i_t * c_tilda_t
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t

In [5]:
# Compare nn.LSTM with the hand-written LSTM_Cell.
# lstm_tests comes from the star import of test_models (not shown here); judging by the
# printed output, it reports timings and a shape check only — it would not detect a
# wrong activation or cell-update formula, so values are not verified here.
lstm_tests(nn.LSTM,LSTM_Cell)

B=64 , T=200 , input dimension=128, hidden dimemnsion=256
[Standard LSTM] Time taken: 0.097613 seconds
[recoded LSTM] Time taken: 0.101432 seconds
input :  torch.Size([64, 200, 128])
output :  torch.Size([64, 200, 256])
Shapes matched.


# Pseudocode: Min Version

## minGRU: A Minimal GRU

<center>
<img src="assets/GRUV2.png" width="80%" />
</center>

In [6]:
# minGRU
class minGRU_Cell(nn.Module):
    """Sequential (one step at a time) minGRU cell.

    Unlike GRU_Cell, neither the gate nor the candidate looks at the previous state:

        z_t  = sigmoid(W_z x_t + b_z)
        h~_t = W_h x_t + b_h
        h_t  = (1 - z_t) * h_prev + z_t * h~_t

    Args:
        input_dim: size of each input vector ``x_t``.
        hidden_dim: size of the hidden state.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear_z = nn.Linear(input_dim, hidden_dim)  # update gate (input only)
        self.linear_h = nn.Linear(input_dim, hidden_dim)  # candidate state, no activation

    def forward(self, x_t, h_prev):
        """Return the next hidden state given input ``x_t`` and state ``h_prev``."""
        update = self.linear_z(x_t).sigmoid()
        candidate = self.linear_h(x_t)
        return (1 - update) * h_prev + update * candidate


In [7]:
# Run minGRU_Cell through the same harness as nn.GRU.
# minGRU is a different model from nn.GRU, so only speed and output shapes are
# comparable here (per the printed output, that is all gru_tests reports).
gru_tests(nn.GRU,minGRU_Cell)

B=64 , T=200 , input dimension=128, hidden dimemnsion=256
[Standard GRU] Time taken: 0.067931 seconds
[recoded GRU] Time taken: 0.030765 seconds
input :  torch.Size([64, 200, 128])
output :  torch.Size([64, 200, 256])
Shapes matched.


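## Log-space-compatible implementation

The parallel scan further down works with logarithms, so candidate states must be strictly positive. This version therefore passes $\tilde h_t$ through $g(x) = x + 0.5$ for $x \ge 0$ and $g(x) = \sigma(x)$ for $x < 0$ instead of using it raw. The recurrence itself is still evaluated one step at a time.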

In [8]:
def g(x):
    """Strictly positive activation applied to minGRU/minLSTM candidate states.

    g(x) = x + 0.5   for x >= 0
    g(x) = sigmoid(x) for x < 0

    Both pieces equal 0.5 at x = 0. The output is always > 0, so its logarithm is
    defined (the log-space version is log_g).
    """
    non_negative_branch = x + 0.5
    negative_branch = x.sigmoid()
    return torch.where(x >= 0, non_negative_branch, negative_branch)

In [9]:
# minGRU_log
class log_minGRU_Cell(nn.Module):
    """Sequential minGRU whose candidate state is passed through ``g``.

    Same recurrence as minGRU_Cell, but h~_t = g(W_h x_t + b_h). Because g is strictly
    positive, the candidates can also be handled in log space by a parallel scan.

        z_t = sigmoid(W_z x_t + b_z)
        h_t = (1 - z_t) * h_prev + z_t * h~_t

    Args:
        input_dim: size of each input vector ``x_t``.
        hidden_dim: size of the hidden state.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear_z = nn.Linear(input_dim, hidden_dim)  # update gate
        self.linear_h = nn.Linear(input_dim, hidden_dim)  # pre-activation candidate

    def forward(self, x_t, h_prev):
        """Return the next hidden state given input ``x_t`` and state ``h_prev``."""
        update = self.linear_z(x_t).sigmoid()
        candidate = g(self.linear_h(x_t))
        keep = 1 - update
        return keep * h_prev + update * candidate

In [10]:
# Run log_minGRU_Cell (minGRU with a g-activated candidate, still stepped sequentially)
# through the same harness as nn.GRU; as above, only speed and output shapes are compared.
gru_tests(nn.GRU,log_minGRU_Cell)

B=64 , T=200 , input dimension=128, hidden dimemnsion=256
[Standard GRU] Time taken: 0.057156 seconds
[recoded GRU] Time taken: 0.053628 seconds
input :  torch.Size([64, 200, 128])
output :  torch.Size([64, 200, 256])
Shapes matched.


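## Parallel implementation

The min-models are linear recurrences $h_t = a_t \odot h_{t-1} + b_t$ with $a_t, b_t > 0$, so all steps can be computed at once in log space:

$$\log h_t = a^*_t + \log \sum_{k=0}^{t} \exp\left(\log b_k - a^*_k\right), \qquad a^*_t = \sum_{j=1}^{t} \log a_j,\quad a^*_0 = 0,\quad b_0 = h_0.$$

That is a `cumsum` followed by a `logcumsumexp` (see `parallel_scan_log`).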

In [None]:
# parallel 
def log_g(x):
    """Elementwise ``log(g(x))`` that never takes the log of a non-positive number.

    log_g(x) = log(x + 0.5)                        for x >= 0
    log_g(x) = log(sigmoid(x)) = -softplus(-x)     for x < 0

    ``torch.where`` evaluates both branches for every element. With a plain
    ``(x + 0.5).log()``, the unselected branch is NaN for x < -0.5. At x == -0.5 it
    is -inf, and its backward pass (0 / 0) poisons the gradient with NaN. Feeding
    ``relu(x) + 0.5`` (always >= 0.5) keeps that branch finite without changing the
    selected values.
    """
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))


def parallel_scan_log(log_coeffs, log_values):
    """Evaluate h_t = a_t * h_{t-1} + b_t for every t at once, in log space.

    Args:
        log_coeffs: log a_t for t = 1..T, time on dim 1
            (per the original note: (batch_size, seq_len, input_size)).
        log_values: log h_0 at index 0 followed by log b_t for t = 1..T, time on dim 1
            (per the original note: (batch_size, seq_len + 1, input_size)).

    Returns:
        h_1..h_T (the initial state at index 0 is dropped), i.e. same length as log_coeffs.
    """
    # a*_t = sum_{j<=t} log a_j; a zero row is prepended so a*_0 = 0 lines up with h_0.
    running_log_a = torch.cumsum(log_coeffs, dim=1)
    a_star = F.pad(running_log_a, (0, 0, 1, 0))

    # log h_t = a*_t + log sum_{k<=t} exp(log b_k - a*_k)
    log_h = a_star + torch.logcumsumexp(log_values - a_star, dim=1)

    states = torch.exp(log_h)
    return states[:, 1:]

# minGRU_log parallel
class parallel_log_minGRU(nn.Module):
    """minGRU evaluated over a whole sequence at once with a log-space parallel scan.

    Computes, for every step t in a single call,

        z_t = sigmoid(k_t),  k_t = W_z x_t + b_z
        h_t = (1 - z_t) * h_{t-1} + z_t * g(W_h x_t + b_h)

    using log z_t = -softplus(-k_t) and log(1 - z_t) = -softplus(k_t).

    NOTE(review): the initial state enters the scan as ``log_g(h_prev)``, so the scan
    starts from g(h_prev), not h_prev itself. With shared weights it should therefore
    match stepping log_minGRU_Cell from g(h_prev), up to float error. By contrast,
    parallel_log_minLSTM uses ``torch.log(h_0)`` directly. Confirm which convention
    is intended.

    Args:
        input_dim: size of each input vector.
        hidden_dim: size of the hidden state.
    """

    def __init__(self, input_dim, hidden_dim):
        super(parallel_log_minGRU, self).__init__()

        self.linear_z = nn.Linear(input_dim, hidden_dim)  # update gate
        self.linear_h = nn.Linear(input_dim, hidden_dim)  # pre-activation candidate

    def forward(self, x, h_prev):
        """Run the recurrence over the full input sequence.

        Args:
            x: input sequence with time on dim 1
                (presumably (batch, seq_len, input_dim) — TODO confirm).
            h_prev: initial hidden state; unsqueezed at dim 1 to become step 0 of the
                scan, so presumably (batch, hidden_dim).

        Returns:
            Hidden states for steps 1..T (parallel_scan_log drops step 0).
        """
        # Compute the gate pre-activation once; both log-gates derive from it.
        k = self.linear_z(x)
        log_z = -F.softplus(-k)       # log z_t
        log_coeffs = -F.softplus(k)   # log (1 - z_t)

        log_h = log_g(h_prev).unsqueeze(1)
        log_h_tilda = log_g(self.linear_h(x))
        h_t = parallel_scan_log(log_coeffs, torch.cat([log_h, log_z + log_h_tilda], dim=1))
        return h_t


## minLSTM: A Minimal LSTM

<center>
<img src="assets/LSTMV2.png" width="80%" />
</center>

In [16]:
# minLSTM
class minLSTM_Cell(nn.Module):
    """Sequential minLSTM cell: input-only gates, normalised so they sum to 1.

        f_t  = sigmoid(W_f x_t + b_f),  i_t = sigmoid(W_i x_t + b_i)
        f'_t = f_t / (f_t + i_t),       i'_t = i_t / (f_t + i_t)
        h_t  = f'_t * h_0 + i'_t * (W_h x_t + b_h)

    Here ``h_0`` plays the role of the previous hidden state h_{t-1}.

    Args:
        input_dim: size of each input vector ``x_t``.
        output_dim: size of the hidden state.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear_f = nn.Linear(input_dim, output_dim)  # forget gate
        self.linear_i = nn.Linear(input_dim, output_dim)  # input gate
        self.linear_h = nn.Linear(input_dim, output_dim)  # candidate, no activation

    def forward(self, x_t, h_0):
        """Return the next hidden state given input ``x_t`` and previous state ``h_0``."""
        forget = self.linear_f(x_t).sigmoid()
        admit = self.linear_i(x_t).sigmoid()
        candidate = self.linear_h(x_t)

        # Normalising makes the update a convex mix of old state and candidate.
        gate_total = forget + admit
        return (forget / gate_total) * h_0 + (admit / gate_total) * candidate


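## Log-space-compatible implementation

As with minGRU, the candidate state is passed through $g$ so that it stays strictly positive; the recurrence is still stepped sequentially.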

In [30]:
# log_space minLSTM
class log_minLSTM_Cell(nn.Module):
    """Sequential minLSTM whose candidate state is passed through ``g``.

    Same as minLSTM_Cell except h~_t = g(W_h x_t + b_h); g is strictly positive, so the
    candidate can also be handled in log space.

        f'_t = f_t / (f_t + i_t),  i'_t = i_t / (f_t + i_t)
        h_t  = f'_t * h_0 + i'_t * h~_t

    Args:
        input_dim: size of each input vector ``x_t``.
        output_dim: size of the hidden state.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear_f = nn.Linear(input_dim, output_dim)  # forget gate
        self.linear_i = nn.Linear(input_dim, output_dim)  # input gate
        self.linear_h = nn.Linear(input_dim, output_dim)  # pre-activation candidate

    def forward(self, x_t, h_0):
        """Return the next hidden state given input ``x_t`` and previous state ``h_0``."""
        forget = torch.sigmoid(self.linear_f(x_t))
        admit = torch.sigmoid(self.linear_i(x_t))
        candidate = g(self.linear_h(x_t))

        gate_total = forget + admit
        return (forget / gate_total) * h_0 + (admit / gate_total) * candidate


## Parallel implementation

In [31]:
# parallel log-space minLSTM
class parallel_log_minLSTM(nn.Module):
    """minLSTM evaluated over a whole sequence at once with a log-space parallel scan.

    With k_f = W_f x + b_f and k_i = W_i x + b_i:

        log f_t  = -softplus(-k_f),   log i_t = -softplus(-k_i)
        diff     = log i_t - log f_t = softplus(-k_f) - softplus(-k_i)
        log f'_t = log(f_t / (f_t + i_t)) = -softplus(diff)
        log i'_t = log(i_t / (f_t + i_t)) = -softplus(-diff)

    and the scan evaluates h_t = f'_t * h_{t-1} + i'_t * g(W_h x_t + b_h) for all t.

    Args:
        input_dim: size of each input vector.
        output_dim: size of the hidden state.
    """

    def __init__(self, input_dim, output_dim):
        super(parallel_log_minLSTM, self).__init__()

        self.linear_f = nn.Linear(input_dim, output_dim)  # forget gate
        self.linear_i = nn.Linear(input_dim, output_dim)  # input gate
        self.linear_h = nn.Linear(input_dim, output_dim)  # pre-activation candidate

    def forward(self, x, h_0):
        """Run the recurrence over the full input sequence.

        Args:
            x: input sequence with time on dim 1
                (presumably (batch, seq_len, input_dim) — TODO confirm).
            h_0: initial hidden state, unsqueezed at dim 1 to become step 0 of the scan
                (presumably (batch, output_dim)). Must be strictly positive: it enters
                the scan as ``torch.log(h_0)``, which is NaN for negative entries and
                -inf for zeros. A state produced by ``g`` satisfies this.

        Returns:
            Hidden states for steps 1..T (parallel_scan_log drops step 0).
        """
        # diff = log i_t - log f_t: a difference of log-gates, not a ratio.
        diff = F.softplus(-self.linear_f(x)) - F.softplus(-self.linear_i(x))
        log_f = -F.softplus(diff)    # log f'_t
        log_i = -F.softplus(-diff)   # log i'_t
        log_h_0 = torch.log(h_0).unsqueeze(1)
        log_tilde_h = log_g(self.linear_h(x))
        h = parallel_scan_log(log_f, torch.cat([log_h_0, log_i + log_tilde_h], dim=1))
        return h



In [32]:
# Sanity check: the parallel (log-space scan) minLSTM and the sequential cell should give
# the same hidden states when they share weights and start from the same positive h0.
torch.manual_seed(0)  # reproducible inputs and weight initialisation

B, T, input_dim, hidden_dim = 64, 200, 128, 256

# Random sequence input (batch_size, seq_len, input_dim)
x = torch.randn(B, T, input_dim)
# Initial hidden state (batch_size, hidden_dim). The parallel model takes torch.log(h0),
# so h0 must be strictly positive; g maps every real value into (0, inf).
h0 = g(torch.randn(B, hidden_dim))

# Initialize models. The instance gets its own name so it does not overwrite the
# class parallel_log_minLSTM (re-running this cell would otherwise fail).
p_log_minLSTM = parallel_log_minLSTM(input_dim, hidden_dim)
log_minLSTM = log_minLSTM_Cell(input_dim, hidden_dim)
# Both modules define linear_f / linear_i / linear_h, so the weights can be shared;
# independently initialised models could never agree.
log_minLSTM.load_state_dict(p_log_minLSTM.state_dict())

p_l_out = p_log_minLSTM(x, h0)

In [33]:
# Step the sequential cell through time, feeding each output back as the next state,
# then report how far the result is from the parallel-scan output p_l_out.
h = h0
outputs = []
for t in range(T):
    h = log_minLSTM(x[:, t, :], h)
    outputs.append(h)
l_out = torch.stack(outputs, dim=1)  # (batch, seq_len, hidden)

# Max absolute difference. It should be tiny (float32 round-off) when both models
# share weights and h0 > 0.
(l_out - p_l_out).abs().max()