# In [120]:
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class NeuralAccumulatorCell(nn.Module):
    """Bias-free linear layer whose weight is ``tanh(W_hat) * sigmoid(M_hat)``.

    Only ``W_hat`` and ``M_hat`` are learnable. The effective weight ``W`` is
    recomputed from them on every access, so gradients reach both parameters
    and every entry of ``W`` stays inside (-1, 1).

    Note: ``W`` is a derived tensor, not a registered parameter, so it does not
    appear in ``parameters()`` or ``state_dict()``.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Uninitialised storage; both tensors are filled by the init calls below.
        self.W_hat = Parameter(torch.Tensor(out_dim, in_dim))
        self.M_hat = Parameter(torch.Tensor(out_dim, in_dim))
        # Registered as None so ``self.bias`` exists and F.linear runs bias-free.
        self.register_parameter('bias', None)

        init.kaiming_uniform_(self.W_hat, a=math.sqrt(5))
        init.kaiming_uniform_(self.M_hat, a=math.sqrt(5))

    @property
    def W(self):
        """Effective weight of shape (out_dim, in_dim), built from the current
        ``W_hat`` and ``M_hat`` so it always reflects their latest values."""
        return torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)

    def forward(self, input):
        """Return ``input @ W.T`` where ``W = tanh(W_hat) * sigmoid(M_hat)``."""
        return F.linear(input, self.W, self.bias)


class NAC(nn.Module):
    """A chain of NeuralAccumulatorCell layers applied one after another."""

    def __init__(self, dims):
        '''
        dims: sequence of layer widths, i.e. [input_dim, *hidden_dims, output_dim].
        Each consecutive pair (dims[k], dims[k + 1]) becomes one cell.
        '''
        super().__init__()
        self.num_layers = len(dims) - 1

        cells = (
            NeuralAccumulatorCell(width_in, width_out)
            for width_in, width_out in zip(dims[:-1], dims[1:])
        )
        self.model = nn.Sequential(*cells)

    def forward(self, x):
        """Feed ``x`` through every cell in order and return the result."""
        return self.model(x)

class NeuralArithmeticLogicUnitCell(nn.Module):
    """Gated mix of an additive path and a log-space (multiplicative) path.

    Both paths go through the same ``self.nac`` sub-module; a sigmoid gate
    computed from ``G`` blends them element-wise.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        # Added to |input| before the log so that log(0) is never evaluated.
        self.eps = 1e-10

        self.G = Parameter(torch.Tensor(out_dim, in_dim))
        # NOTE: W is registered and initialised but forward() never reads it.
        self.W = Parameter(torch.Tensor(out_dim, in_dim))
        self.register_parameter('bias', None)
        self.nac = NeuralAccumulatorCell(in_dim, out_dim)

        for weight in (self.G, self.W):
            init.kaiming_uniform_(weight, a=math.sqrt(5))

    def forward(self, input):
        """Return ``g * nac(input) + (1 - g) * exp(nac(log(|input| + eps)))``
        with ``g = sigmoid(input @ G.T)``."""
        gate = torch.sigmoid(F.linear(input, self.G, self.bias))
        additive = self.nac(input)

        log_magnitude = torch.log(torch.abs(input) + self.eps)
        multiplicative = torch.exp(self.nac(log_magnitude))

        return gate * additive + (1 - gate) * multiplicative


class NALU(nn.Module):
    """A chain of NeuralArithmeticLogicUnitCell layers applied one after another."""

    def __init__(self, dims):
        """
        dims: sequence of layer widths; each consecutive pair
        (dims[k], dims[k + 1]) becomes one cell.
        """
        super().__init__()
        self.num_layers = len(dims) - 1
        cells = (
            NeuralArithmeticLogicUnitCell(width_in, width_out)
            for width_in, width_out in zip(dims[:-1], dims[1:])
        )
        self.model = nn.Sequential(*cells)

    def forward(self, x):
        """Feed ``x`` through every cell in order and return the result."""
        return self.model(x)

class NALU_LSTMCell(nn.Module):
    """LSTM-style cell with a residual output and NALU-transformed states.

    One step computes the usual input/forget/output gates and candidate from
    ``i2h(x) + h2h(h)``, then returns
    ``x + out(h_t)`` as the step output and
    ``(nalu_h(h_t + h), nalu_c(c_t + c))`` as the next state.

    Args:
        input_size: feature size of ``x`` (and of the step output).
        hidden_size: feature size of the hidden and cell states.
        bias: whether the three ``nn.Linear`` layers carry a bias term.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Kept as single-layer nn.Sequential containers so parameter names
        # stay ``i2h.0.*`` / ``h2h.0.*``.
        self.i2h = nn.Sequential(
            nn.Linear(input_size, 4 * hidden_size, bias=bias),
        )
        self.h2h = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size, bias=bias),
        )
        self.nalu_h = NALU([hidden_size, hidden_size])
        self.nalu_c = NALU([hidden_size, hidden_size])
        self.out = nn.Linear(hidden_size, input_size, bias=bias)
        # weight_init already walks every nested parameter of the module it is
        # given, so a single call on ``self`` is enough; routing it through
        # ``self.apply`` would re-sample each parameter once per ancestor module.
        self.weight_init(self)

    def weight_init(self, m):
        """Fill every parameter of ``m`` (recursively) in place with samples
        from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in m.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden=None):
        """Run one time step.

        Args:
            x: tensor of shape (batch, input_size).
            hidden: optional ``(h, c)`` pair, each (batch, hidden_size).
                When omitted, both start as zeros on ``x``'s device and dtype.

        Returns:
            ``(output, (h_next, c_next))`` where ``output`` has the shape of ``x``.
        """
        if hidden is None:
            zeros = x.new_zeros(x.size(0), self.hidden_size, requires_grad=False)
            hidden = (zeros, zeros)

        h, c = hidden
        n = self.hidden_size

        preact = self.i2h(x) + self.h2h(h)

        # First 3*n columns are the sigmoid gates (input, forget, output);
        # the last n columns are the tanh candidate.
        gates = preact[:, :3 * n].sigmoid()
        g_t = preact[:, 3 * n:].tanh()
        i_t = gates[:, :n]
        f_t = gates[:, n:2 * n]
        o_t = gates[:, -n:]

        c_t = (c * f_t) + (i_t * g_t)
        h_t = o_t * c_t.tanh()

        # Residual NALU variant: residual connection on the output, and the
        # residual-summed states are passed through NALU layers.
        return x + self.out(h_t), (self.nalu_h(h_t + h), self.nalu_c(c_t + c))

class NALU_LSTM(nn.Module):
    """Stack of NALU_LSTMCell layers unrolled over a sequence.

    Every layer maps ``input_size`` features to ``input_size`` features, so
    layers can be chained. When ``bidirectional`` is true, the same cell of
    each layer is also run over the sequence in reverse, starting from its own
    learned initial state.

    Args:
        input_size: feature size of each time step.
        hidden_sizes: one hidden size per layer.
        bidirectional: also scan the sequence from the last step to the first.
    """

    def __init__(self, input_size, hidden_sizes, bidirectional=False):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_dir = 2 if self.bidirectional else 1

        self.input_size = input_size
        self.L = len(hidden_sizes)
        self.layers = nn.ModuleList()
        self.layers.extend([NALU_LSTMCell(input_size, i) for i in hidden_sizes])
        # Learned initial states, one (num_dir, 1, hidden) tensor per layer.
        self.c0 = nn.ParameterList([nn.Parameter(torch.randn(self.num_dir, 1, i)) for i in hidden_sizes])
        self.h0 = nn.ParameterList([nn.Parameter(torch.randn(self.num_dir, 1, i)) for i in hidden_sizes])

    def _run_direction(self, layer, outputs, direction, depth, h_init, c_init, steps):
        """Scan ``layer`` over ``steps``, reading slot ``depth`` of ``outputs``
        and writing slot ``depth + 1``; return the final ``(h, c)``."""
        batch = outputs.size(0)
        h = h_init.repeat(batch, 1)
        c = c_init.repeat(batch, 1)
        for t in steps:
            # clone() so the cell works on a copy, not on a view of the buffer
            # that is subsequently written in place.
            step_out, (h, c) = layer(outputs[:, direction, depth, t, :].clone(), (h, c))
            outputs[:, direction, depth + 1, t, :] = step_out
        return h, c

    def forward(self, input):
        '''
        input_shape = B, S, input_size
        output_shape = B, num_dir, L, S, input_size
        hiddens, cells = lists with one entry per layer (L entries); each entry
        is (B, hidden_dim), or (2, B, hidden_dim) when bidirectional
        '''

        B, S = input.shape[:-1]

        # Allocate on the input's device and dtype; a plain torch.zeros would
        # always land on CPU in the default dtype and break for CUDA or
        # non-default-dtype inputs.
        outputs = input.new_zeros(B, self.num_dir, self.L + 1, S, self.input_size)
        # Slot 0 along the layer axis holds the raw input for every direction.
        outputs[:, :, 0, :, :] = input.unsqueeze(1).expand_as(outputs[:, :, 0, :, :])
        hiddens = []
        cells = []

        for i, layer in enumerate(self.layers):
            f_h, f_c = self._run_direction(
                layer, outputs, 0, i, self.h0[i][0], self.c0[i][0], range(S))

            if self.bidirectional:
                i_h, i_c = self._run_direction(
                    layer, outputs, 1, i, self.h0[i][1], self.c0[i][1], reversed(range(S)))
                hiddens.append(torch.stack([f_h, i_h]))
                cells.append(torch.stack([f_c, i_c]))
            else:
                hiddens.append(f_h)
                cells.append(f_c)

        return outputs[:, :, 1:, :, :].contiguous(), (hiddens, cells)