In [None]:
class LSTM(nn.Module):
    """Stacked, optionally bidirectional LSTM with explicit per-gate layers.

    Every layer owns four input projections (``wfx``, ``wix``, ``wox``,
    ``wcx``) and four bias-free recurrent projections (``ufh``, ``uih``,
    ``uoh``, ``uch``) for the forget, input and output gates and the cell
    candidate.  Batch-norm and layer-norm modules are created for every
    layer, whether or not they are enabled, so the module lists stay aligned
    with the layer index.

    ``options`` is a mapping of strings read in ``__init__``:

    * ``lstm_lay``: comma-separated hidden sizes, one per layer.
    * ``lstm_drop``: comma-separated dropout probabilities, one per layer.
    * ``lstm_use_batchnorm`` / ``lstm_use_laynorm``: comma-separated boolean
      flags, one per layer.
    * ``lstm_use_batchnorm_inp`` / ``lstm_use_laynorm_inp``: boolean flags
      for normalising the network input.
    * ``lstm_act``: comma-separated activation names passed to ``act_fun``.
    * ``lstm_orthinit``: orthogonal init of the recurrent weights.
    * ``lstm_bidir``: run every layer in both time directions.
    * ``use_cuda``: move the initial state and dropout mask to CUDA.
    * ``to_do``: ``'train'`` samples dropout masks; any other value selects
      the test-time behaviour.

    NOTE(review): ``strtobool``, ``act_fun``, ``LayerNorm`` and ``flip`` are
    defined outside this class; ``flip(t, 0)`` is presumably a reversal along
    the time axis -- confirm against its definition.
    """

    def __init__(self, options, inp_dim):
        """Parse ``options`` and build all layers for input size ``inp_dim``."""
        super(LSTM, self).__init__()

        # Reading parameters
        self.input_dim = inp_dim
        self.lstm_lay = list(map(int, options['lstm_lay'].split(',')))
        self.lstm_drop = list(map(float, options['lstm_drop'].split(',')))
        self.lstm_use_batchnorm = list(map(strtobool, options['lstm_use_batchnorm'].split(',')))
        self.lstm_use_laynorm = list(map(strtobool, options['lstm_use_laynorm'].split(',')))
        self.lstm_use_laynorm_inp = strtobool(options['lstm_use_laynorm_inp'])
        self.lstm_use_batchnorm_inp = strtobool(options['lstm_use_batchnorm_inp'])
        self.lstm_act = options['lstm_act'].split(',')
        self.lstm_orthinit = strtobool(options['lstm_orthinit'])

        self.bidir = strtobool(options['lstm_bidir'])
        self.use_cuda = strtobool(options['use_cuda'])
        self.to_do = options['to_do']

        # The dropout mode is fixed here from `to_do`; forward() reads this
        # flag rather than nn.Module's own `training` attribute.
        self.test_flag = self.to_do != 'train'

        # List initialization (registration order is kept stable so that
        # state-dict keys and parameter ordering do not change).
        self.wfx = nn.ModuleList([])  # Forget
        self.ufh = nn.ModuleList([])  # Forget

        self.wix = nn.ModuleList([])  # Input
        self.uih = nn.ModuleList([])  # Input

        self.wox = nn.ModuleList([])  # Output
        self.uoh = nn.ModuleList([])  # Output

        self.wcx = nn.ModuleList([])  # Cell state
        self.uch = nn.ModuleList([])  # Cell state

        self.ln = nn.ModuleList([])  # Layer Norm
        self.bn_wfx = nn.ModuleList([])  # Batch Norm
        self.bn_wix = nn.ModuleList([])  # Batch Norm
        self.bn_wox = nn.ModuleList([])  # Batch Norm
        self.bn_wcx = nn.ModuleList([])  # Batch Norm

        self.act = nn.ModuleList([])  # Activations

        # Input layer normalization
        if self.lstm_use_laynorm_inp:
            self.ln0 = LayerNorm(self.input_dim)

        # Input batch normalization
        if self.lstm_use_batchnorm_inp:
            self.bn0 = nn.BatchNorm1d(self.input_dim, momentum=0.05)

        self.N_lstm_lay = len(self.lstm_lay)

        current_input = self.input_dim

        # Initialization of hidden layers
        for i in range(self.N_lstm_lay):
            hidden = self.lstm_lay[i]

            # Activations
            self.act.append(act_fun(self.lstm_act[i]))

            # The input projections carry no bias when the layer uses
            # layer norm or batch norm.
            add_bias = not (self.lstm_use_laynorm[i] or self.lstm_use_batchnorm[i])

            # Feed-forward connections
            self.wfx.append(nn.Linear(current_input, hidden, bias=add_bias))
            self.wix.append(nn.Linear(current_input, hidden, bias=add_bias))
            self.wox.append(nn.Linear(current_input, hidden, bias=add_bias))
            self.wcx.append(nn.Linear(current_input, hidden, bias=add_bias))

            # Recurrent connections
            self.ufh.append(nn.Linear(hidden, hidden, bias=False))
            self.uih.append(nn.Linear(hidden, hidden, bias=False))
            self.uoh.append(nn.Linear(hidden, hidden, bias=False))
            self.uch.append(nn.Linear(hidden, hidden, bias=False))

            if self.lstm_orthinit:
                nn.init.orthogonal_(self.ufh[i].weight)
                nn.init.orthogonal_(self.uih[i].weight)
                nn.init.orthogonal_(self.uoh[i].weight)
                nn.init.orthogonal_(self.uch[i].weight)

            # Batch norm initialization
            self.bn_wfx.append(nn.BatchNorm1d(hidden, momentum=0.05))
            self.bn_wix.append(nn.BatchNorm1d(hidden, momentum=0.05))
            self.bn_wox.append(nn.BatchNorm1d(hidden, momentum=0.05))
            self.bn_wcx.append(nn.BatchNorm1d(hidden, momentum=0.05))

            self.ln.append(LayerNorm(hidden))

            # A bidirectional layer hands both directions, concatenated on
            # the feature axis, to the next layer.
            current_input = 2 * hidden if self.bidir else hidden

        # Size of the last layer's output (doubled when bidirectional).
        self.out_dim = self.lstm_lay[-1] * (2 if self.bidir else 1)

    @staticmethod
    def _batchnorm_over_time(bn, t):
        """Apply the 1-D batch norm ``bn`` to a 3-D tensor over its last axis.

        The first two axes are merged so that ``bn`` sees a 2-D input, then
        the original shape is restored.  ``reshape`` is used for the merge so
        non-contiguous inputs are accepted as well.
        """
        steps, batch, feats = t.shape
        return bn(t.reshape(steps * batch, feats)).view(steps, batch, feats)

    def forward(self, x):
        """Run all LSTM layers over ``x`` and return the last layer's states.

        ``x`` is indexed as ``(time, batch, features)``: the recurrence
        iterates over axis 0 and the initial state has ``x.shape[1]`` rows.
        The result has the same two leading axes and ``out_dim`` features.
        """
        # Applying Layer/Batch Norm
        if self.lstm_use_laynorm_inp:
            x = self.ln0(x)

        if self.lstm_use_batchnorm_inp:
            x = self._batchnorm_over_time(self.bn0, x)

        for i in range(self.N_lstm_lay):
            hidden = self.lstm_lay[i]

            # Initial state and concatenation.  In bidirectional mode the
            # flipped sequence is appended along the batch axis, so both
            # directions go through the same recurrence in a single pass.
            if self.bidir:
                h_init = torch.zeros(2 * x.shape[1], hidden)
                x = torch.cat([x, flip(x, 0)], 1)
            else:
                h_init = torch.zeros(x.shape[1], hidden)

            # Drop mask initialization (same mask for all time steps).
            # Training samples a Bernoulli keep-mask; otherwise the cell
            # candidate is scaled by the keep probability.
            keep_prob = 1 - self.lstm_drop[i]
            if not self.test_flag:
                drop_mask = torch.bernoulli(torch.full(h_init.shape, keep_prob))
            else:
                drop_mask = torch.FloatTensor([keep_prob])

            if self.use_cuda:
                h_init = h_init.cuda()
                drop_mask = drop_mask.cuda()

            # Feed-forward affine transformations (all steps in parallel)
            wfx_out = self.wfx[i](x)
            wix_out = self.wix[i](x)
            wox_out = self.wox[i](x)
            wcx_out = self.wcx[i](x)

            # Apply batch norm if needed (all steps in parallel)
            if self.lstm_use_batchnorm[i]:
                wfx_out = self._batchnorm_over_time(self.bn_wfx[i], wfx_out)
                wix_out = self._batchnorm_over_time(self.bn_wix[i], wix_out)
                wox_out = self._batchnorm_over_time(self.bn_wox[i], wox_out)
                wcx_out = self._batchnorm_over_time(self.bn_wcx[i], wcx_out)

            # Per-layer modules are looked up once, outside the time loop.
            ufh, uih, uoh, uch = self.ufh[i], self.uih[i], self.uoh[i], self.uch[i]
            act = self.act[i]
            use_laynorm = self.lstm_use_laynorm[i]

            # Processing time steps
            hiddens = []
            ct = h_init
            ht = h_init

            for k in range(x.shape[0]):

                # LSTM equations
                ft = torch.sigmoid(wfx_out[k] + ufh(ht))
                it = torch.sigmoid(wix_out[k] + uih(ht))
                ot = torch.sigmoid(wox_out[k] + uoh(ht))
                ct = it * act(wcx_out[k] + uch(ht)) * drop_mask + ft * ct
                ht = ot * act(ct)

                if use_laynorm:
                    ht = self.ln[i](ht)

                hiddens.append(ht)

            # Stacking hidden states
            h = torch.stack(hiddens)

            # Bidirectional concatenations: the first half of the batch axis
            # holds the forward pass, the second half the pass over the
            # flipped input, which is flipped back before joining features.
            if self.bidir:
                half = x.shape[1] // 2
                h_f = h[:, :half]
                h_b = flip(h[:, half:x.shape[1]].contiguous(), 0)
                h = torch.cat([h_f, h_b], 2)

            # Setup x for the next hidden layer
            x = h

        return x