In [None]:
'''
Limitations
- If input_size != hidden_size, GRUD does not work.
- dropout and bidirectional are not working.
- bias=True/False handling has not been checked.
'''

In [1]:
import torch # GRUD
import math
import warnings
import itertools
import numbers
import numpy as np
import torch.utils.data as utils

In [15]:
class GRUD(torch.nn.Module):
    """GRU-D style recurrent model with element-wise (diagonal) weights.

    The module consumes a stacked tensor ``(values, mask, delta)`` and walks
    over its columns.  NOTE: in this implementation every "layer" processes
    exactly one column (time step) of the input and owns its own set of
    weights, so ``num_layers`` is effectively the number of steps consumed.

    Known limitations (unchanged from the original design):

    * all recurrent weights are vectors applied element-wise, so
      ``input_size`` must equal ``hidden_size`` for the shapes to line up;
    * ``bidirectional=True`` only allocates ``*_reverse`` parameters; the
      forward pass never uses them;
    * ``batch_first`` is stored but not used by ``forward`` (no batch axis).

    Args:
        input_size: number of input features (rows of each input matrix).
        hidden_size: size of the hidden state.
        output_size: size of the vector returned by ``forward``.
        num_layers: number of per-step layers (see note above), must be >= 1.
        x_mean: empirical mean used to impute missing values; a scalar or
            anything broadcastable against a vector of ``input_size``.
        bias: when False no bias parameters are created or applied.
        batch_first: stored for ``extra_repr``/``check_forward_args`` only.
        bidirectional: see limitations above.
        dropout_type: one of ``'Moon'`` (dropout on the new hidden state),
            ``'Gal'`` (dropout on the previous hidden state) or ``'mloss'``
            (dropout on the candidate state).  Any other value disables
            dropout.
        dropout: drop probability in ``[0, 1]``; ``0`` disables dropout.

    Raises:
        ValueError: if ``dropout`` is not a number in ``[0, 1]`` or
            ``num_layers`` is smaller than 1.
    """

    # Parameter name templates, in registration order.  '{}{}' is filled with
    # (layer index, direction suffix); 'weight_hy' / 'bias_y' have no
    # placeholders and are therefore shared by all layers.
    _LAYER_WEIGHT_NAMES = (
        'weight_dg_x_l{}{}', 'weight_dg_h_l{}{}',
        'weight_xz_l{}{}', 'weight_hz_l{}{}', 'weight_mz_l{}{}',
        'weight_xr_l{}{}', 'weight_hr_l{}{}', 'weight_mr_l{}{}',
        'weight_xh_l{}{}', 'weight_hh_l{}{}', 'weight_mh_l{}{}',
        'weight_hy',
    )
    _LAYER_BIAS_NAMES = (
        'bias_dg_x_l{}{}', 'bias_dg_h_l{}{}',
        'bias_z_l{}{}', 'bias_r_l{}{}', 'bias_h_l{}{}',
        'bias_y',
    )

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, x_mean=0,
                 bias=True, batch_first=False, bidirectional=False, dropout_type='mloss', dropout=0):
        super(GRUD, self).__init__()

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if num_layers < 1:
            raise ValueError("num_layers must be at least 1, got {}".format(num_layers))
        if input_size != hidden_size:
            warnings.warn("GRUD uses element-wise weights and expects "
                          "input_size == hidden_size, but got input_size={} and "
                          "hidden_size={}".format(input_size, hidden_size))
        if bidirectional:
            warnings.warn("bidirectional=True only allocates reverse parameters; "
                          "GRUD.forward does not use them")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        # Kept for backward compatibility; forward() clamps at zero instead.
        self.zeros = torch.zeros(input_size)
        # Copy so that later in-place edits of the caller's tensor don't leak in.
        self.x_mean = torch.as_tensor(x_mean).detach().clone()
        self.bias = bias
        self.batch_first = batch_first
        self.dropout_type = dropout_type
        self.dropout = dropout
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1

        # Shapes parallel to _LAYER_WEIGHT_NAMES / _LAYER_BIAS_NAMES.
        weight_shapes = (
            (input_size,), (hidden_size,),                 # decay rates gamma
            (input_size,), (hidden_size,), (input_size,),  # z
            (input_size,), (hidden_size,), (input_size,),  # r
            (input_size,), (hidden_size,), (input_size,),  # h_tilde
            (output_size, hidden_size),                    # y (output)
        )
        # Bias shapes are kept as in the original (hidden_size) so existing
        # state dicts still load.
        bias_shapes = ((hidden_size,),) * 5 + ((output_size,),)
        shapes = weight_shapes + (bias_shapes if bias else ())

        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                names = self._param_names(layer, direction, bias)
                for name, shape in zip(names, shapes):
                    # Shared parameters (weight_hy / bias_y) are created once.
                    if name not in self._parameters:
                        self.register_parameter(
                            name, torch.nn.Parameter(torch.empty(*shape)))
                self._all_weights.append(names)

        self.flatten_parameters()
        self.reset_parameters()

    @classmethod
    def _param_names(cls, layer, direction, bias):
        """Return the parameter names of one (layer, direction) pair."""
        suffix = '_reverse' if direction == 1 else ''
        templates = cls._LAYER_WEIGHT_NAMES
        if bias:
            templates = templates + cls._LAYER_BIAS_NAMES
        return [template.format(layer, suffix) for template in templates]

    def flatten_parameters(self):
        """
        No-op kept for API compatibility with ``torch.nn.RNNBase``.

        The element-wise GRU-D weights do not follow the cuDNN RNN weight
        layout, so there is nothing to flatten.
        """
        return None

    def _apply(self, fn, *args, **kwargs):
        # Extra arguments are forwarded so the override works with torch
        # versions whose Module._apply accepts more than ``fn``.
        ret = super(GRUD, self)._apply(fn, *args, **kwargs)
        self.flatten_parameters()
        return ret

    def reset_parameters(self):
        """Initialise every parameter uniformly in ``[-1/sqrt(H), 1/sqrt(H)]``."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def check_forward_args(self, input, hidden, batch_sizes):
        """Validate RNN-style ``input`` / ``hidden`` shapes.

        Not called by ``forward``; kept as a helper for callers.

        Raises:
            RuntimeError: if the input rank, the input feature size or the
                hidden state shape does not match the module configuration.
        """
        is_input_packed = batch_sizes is not None
        expected_input_dim = 2 if is_input_packed else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

        if is_input_packed:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)

        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)

        if tuple(hidden.size()) != expected_hidden_size:
            raise RuntimeError('Expected hidden size {}, got {}'.format(
                expected_hidden_size, tuple(hidden.size())))

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        return s.format(**self.__dict__)

    def __setstate__(self, d):
        """Restore pickled state and rebuild the parameter-name table if needed."""
        super(GRUD, self).__setstate__(d)
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        names = getattr(self, '_all_weights', None)
        if names and names[0] and isinstance(names[0][0], str):
            return
        num_directions = 2 if self.bidirectional else 1
        self._all_weights = [
            self._param_names(layer, direction, self.bias)
            for layer in range(self.num_layers)
            for direction in range(num_directions)
        ]

    @property
    def _flat_weights(self):
        return list(self._parameters.values())

    @property
    def all_weights(self):
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]

    def _feature_by_step(self, tensor):
        """Squeeze ``tensor`` to a 2-D (features, steps) matrix."""
        matrix = torch.squeeze(tensor)
        if matrix.dim() < 2:
            # squeeze() also removes a size-1 step/feature axis; restore it.
            matrix = matrix.reshape(self.input_size, -1)
        if matrix.dim() != 2:
            raise ValueError(
                'expected a (features, steps) matrix after squeezing, got shape {}'.format(
                    tuple(matrix.shape)))
        return matrix

    def _weight(self, stem, layer):
        """Fetch the forward-direction weight ``weight_<stem>_l<layer>``."""
        return getattr(self, 'weight_{}_l{}'.format(stem, layer))

    def _bias_or_zero(self, stem, layer):
        """Fetch ``bias_<stem>_l<layer>``, or 0.0 when the module has no biases."""
        if not self.bias:
            return 0.0
        return getattr(self, 'bias_{}_l{}'.format(stem, layer))

    def _drop(self, tensor):
        """Apply dropout with the configured rate, only while training."""
        return torch.nn.functional.dropout(tensor, p=self.dropout, training=self.training)

    def forward(self, input):
        """Run the per-step layers and project the final hidden state.

        Args:
            input: tensor stacking ``(values, mask, delta)`` along dim 0, each
                squeezable to ``(input_size, steps)`` with
                ``steps >= num_layers``.  ``mask`` is 1 where a value was
                observed and 0 where it is missing; ``delta`` is the time
                since the last observation.

        Returns:
            Tensor of shape ``(output_size,)``.

        Raises:
            ValueError: if the inputs are not 2-D after squeezing or provide
                fewer steps than ``num_layers``.
        """
        X = self._feature_by_step(input[0])
        Mask = self._feature_by_step(input[1])
        Delta = self._feature_by_step(input[2])

        if X.size(1) < self.num_layers:
            raise ValueError(
                'input provides {} steps but num_layers={}'.format(X.size(1), self.num_layers))

        w_hy = getattr(self, 'weight_hy')
        # Follow the parameters' device/dtype so the module works after .to()/.cuda().
        h = torch.zeros(self.hidden_size, dtype=w_hy.dtype, device=w_hy.device)
        x_mean = self.x_mean.to(device=w_hy.device, dtype=w_hy.dtype)

        use_dropout = self.dropout != 0
        for layer in range(self.num_layers):
            x = X[:, layer]
            m = Mask[:, layer]
            d = Delta[:, layer]

            # (4) decay rates: a function of the time gap delta, clamped at 0
            # so that gamma stays in (0, 1].
            gamma_x = torch.exp(-torch.clamp(
                self._weight('dg_x', layer) * d + self._bias_or_zero('dg_x', layer), min=0))
            gamma_h = torch.exp(-torch.clamp(
                self._weight('dg_h', layer) * d + self._bias_or_zero('dg_h', layer), min=0))

            # (5) impute missing inputs by decaying towards the empirical mean
            x = m * x + (1 - m) * (gamma_x * x + (1 - gamma_x) * x_mean)

            # 'Gal': dropout on the previous hidden state (Gal & Ghahramani, 2015)
            if use_dropout and self.dropout_type == 'Gal':
                h = self._drop(h)

            # (6) GRU update on the decayed hidden state
            h = gamma_h * h
            z = torch.sigmoid(self._weight('xz', layer) * x + self._weight('hz', layer) * h
                              + self._weight('mz', layer) * m + self._bias_or_zero('z', layer))
            r = torch.sigmoid(self._weight('xr', layer) * x + self._weight('hr', layer) * h
                              + self._weight('mr', layer) * m + self._bias_or_zero('r', layer))
            h_tilde = torch.tanh(self._weight('xh', layer) * x + self._weight('hh', layer) * (r * h)
                                 + self._weight('mh', layer) * m + self._bias_or_zero('h', layer))

            # 'mloss': dropout on the candidate only (arXiv 1603.05118)
            if use_dropout and self.dropout_type == 'mloss':
                h_tilde = self._drop(h_tilde)

            h = (1 - z) * h + z * h_tilde

            # 'Moon': dropout on the new hidden state (RNNDROP, 2015)
            if use_dropout and self.dropout_type == 'Moon':
                h = self._drop(h)

        output = torch.matmul(w_hy, h)
        if self.bias:
            output = output + getattr(self, 'bias_y')
        return output