merge custom_u and custom_v
taolei87 committed Feb 11, 2020
1 parent 6f749e3 commit 79a8e34
Showing 1 changed file with 69 additions and 83 deletions.
152 changes: 69 additions & 83 deletions sru/sru_functional.py
@@ -222,14 +222,12 @@ def __init__(self,
bidirectional=False,
n_proj=0,
use_tanh=0,
#is_input_normalized=False,
highway_bias=0,
has_skip_term=True,
layer_norm=False,
rescale=True,
v1=False,
custom_u=None,
custom_v=None):
custom_m=None):

super(SRUCell, self).__init__()
self.input_size = input_size
@@ -244,8 +242,7 @@ def __init__(self,
self.rescale = rescale
self.activation_type = 0
self.activation = 'none'
self.custom_u = custom_u
self.custom_v = custom_v
self.custom_m = custom_m
if use_tanh:
self.activation_type = 1
self.activation = 'tanh'
@@ -261,7 +258,7 @@ def __init__(self,
self.num_matrices = 4

# make parameters
if self.custom_u is None:
if self.custom_m is None:
if self.projection_size == 0:
self.weight = nn.Parameter(torch.Tensor(
input_size,
@@ -273,11 +270,9 @@ def __init__(self,
self.projection_size,
self.output_size * self.num_matrices
))

if self.custom_v is None:
self.weight_c = nn.Parameter(torch.Tensor(2 * self.output_size))

self.weight_c = nn.Parameter(torch.Tensor(2 * self.output_size))
self.bias = nn.Parameter(torch.Tensor(2 * self.output_size))

# scaling constant used in highway connections when rescale=True
self.register_buffer('scale_x', torch.FloatTensor([0]))

@@ -304,7 +299,7 @@ def reset_parameters(self):
scale_val = (1 + math.exp(bias_val) * 2)**0.5
self.scale_x.data[0] = scale_val

if self.custom_u is None:
if self.custom_m is None:
# initialize weights such that E[w_ij]=0 and Var[w_ij]=1/d
d = self.weight.size(0)
val_range = (3.0 / d)**0.5
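A step the comment leaves implicit (a standard identity, not part of the diff): for w drawn uniformly from (-a, a), E[w] = 0 and Var[w] = a^2 / 3, so choosing a = (3/d)**0.5 — the `val_range` computed above — gives exactly Var[w_ij] = 1/d.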
@@ -333,21 +328,20 @@ def reset_parameters(self):
scale_val = (1 + math.exp(bias_val) * 2)**0.5
w[:, :, :, 3].mul_(scale_val)

if self.custom_v is None:
if not self.v1:
# initialize weight_c such that E[w]=0 and Var[w]=1
self.weight_c.data.uniform_(-3.0**0.5, 3.0**0.5)
if not self.v1:
# initialize weight_c such that E[w]=0 and Var[w]=1
self.weight_c.data.uniform_(-3.0**0.5, 3.0**0.5)

# rescale weight_c and the weight of sigmoid gates with a factor of sqrt(0.5)
if self.custom_u is None:
w[:, :, :, 1].mul_(0.5**0.5)
w[:, :, :, 2].mul_(0.5**0.5)
self.weight_c.data.mul_(0.5**0.5)
else:
self.weight_c.data.zero_()
self.weight_c.requires_grad = False
# rescale weight_c and the weight of sigmoid gates with a factor of sqrt(0.5)
if self.custom_m is None:
w[:, :, :, 1].mul_(0.5**0.5)
w[:, :, :, 2].mul_(0.5**0.5)
self.weight_c.data.mul_(0.5**0.5)
else:
self.weight_c.data.zero_()
self.weight_c.requires_grad = False

def forward(self, input, c0=None, mask_pad=None):
def forward(self, input, c0=None, mask_pad=None, **kwargs):
"""
This method computes `U`. In addition, it computes the remaining components
in `SRU_Compute_GPU` or `SRU_Compute_CPU` and returns the results.
@@ -371,27 +365,33 @@ def forward(self, input, c0=None, mask_pad=None):
mask = self.get_dropout_mask_((batch_size, input.size(-1)), self.rnn_dropout)
input = input * mask.expand_as(input)

# compute U that's (length, batch_size, output_size * num_matrices)
if self.custom_u is not None:
U = self.custom_u(input)
else:
# compute U, V
# U is (length, batch_size, output_size * num_matrices)
# V is (output_size*2,) or (length, batch_size, output_size * 2) if provided
if self.custom_m is None:
U = self.compute_U(input)
if U.size(-1) != self.output_size * self.num_matrices:
raise ValueError("U must have a last dimension of {} but got {}.".format(
self.output_size * self.num_matrices,
U.size(-1)
))

# V is (length, batch_size, output_size * 2) if customized, otherwise (output_size * 2,)
if self.custom_v is not None:
V = self.custom_v(input)
else:
V = self.weight_c
if V.size(-1) != self.output_size * 2:
raise ValueError("V must have a last dimension of {} but got {}.".format(
self.output_size * 2,
V.size(-1)
))
else:
ret = self.custom_m(input, c0=c0, mask_pad=mask_pad, **kwargs)
if isinstance(ret, tuple) or isinstance(ret, list):
if len(ret) > 2:
raise Exception("Custom module must return 1 or 2 tensors but got {}.".format(
len(ret)
))
U, V = ret[0], ret[1] + self.weight_c
else:
U, V = ret, self.weight_c

if U.size(-1) != self.output_size * self.num_matrices:
raise ValueError("U must have a last dimension of {} but got {}.".format(
self.output_size * self.num_matrices,
U.size(-1)
))
if V.size(-1) != self.output_size * 2:
raise ValueError("V must have a last dimension of {} but got {}.".format(
self.output_size * 2,
V.size(-1)
))

# get the scaling constant; scale_x is a scalar
scale_val = self.scale_x if self.rescale else None
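To make the new `custom_m` contract concrete, here is a minimal sketch of a module that could be passed as `custom_m` (the class name and inner projections are illustrative assumptions, not part of this commit). Per the `forward` logic above, the module is called as `custom_m(input, c0=c0, mask_pad=mask_pad, **kwargs)` and may return either U alone or a (U, V) pair; a returned V is added to `weight_c`, while returning U alone falls back to `weight_c` itself:

import torch
import torch.nn as nn

class MyCustomM(nn.Module):
    # Hypothetical custom_m module: maps an input of shape
    # (length, batch_size, input_size) to U of shape
    # (length, batch_size, output_size * num_matrices).
    # num_matrices defaults to 3 here, the common case documented
    # in the SRU docstring further down.
    def __init__(self, input_size, output_size, num_matrices=3):
        super().__init__()
        self.linear_u = nn.Linear(input_size, output_size * num_matrices, bias=False)
        self.linear_v = nn.Linear(input_size, output_size * 2, bias=False)

    def forward(self, input, c0=None, mask_pad=None, **kwargs):
        U = self.linear_u(input)
        # V is optional; returning U alone makes SRUCell use its own
        # weight_c parameter instead.
        V = self.linear_v(input)
        return U, V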
@@ -456,10 +456,8 @@ def extra_repr(self):
s += ", has_skip_term={has_skip_term}"
if self.layer_norm:
s += ", layer_norm=True"
if self.custom_u is not None:
s += ",\n custom_u=" + str(self.custom_u)
if self.custom_v is not None:
s += ",\n custom_v=" + str(self.custom_v)
if self.custom_m is not None:
s += ",\n custom_m=" + str(self.custom_m)
return s.format(**self.__dict__)

def __repr__(self):
@@ -472,36 +470,31 @@ def __repr__(self):

class SRU(nn.Module):
"""
PyTorch SRU model. In effect, simply wraps an arbitrary number of
contiguous `SRUCell`s, and returns the matrix and hidden states,
as well as final memory cell (`c_t`), from the last of these `SRUCell`s.
PyTorch SRU model. In effect, simply wraps an arbitrary number of contiguous `SRUCell`s, and
returns the matrix and hidden states, as well as the final memory cell (`c_t`), from the last of
these `SRUCell`s.
Args:
input_size (int) : the number of dimensions in a single
input sequence element. For example, if the input sequence
is a sequence of word embeddings, `input_size` is the
dimensionality of a single word embedding, e.g. 300.
hidden_size (int) : the dimensionality of the hidden state
of the SRU cell.
input_size (int) : the number of dimensions in a single input sequence element. For example,
if the input sequence is a sequence of word embeddings, `input_size` is the dimensionality
of a single word embedding, e.g. 300.
hidden_size (int) : the dimensionality of the hidden state of the SRU cell.
num_layers (int) : number of `SRUCell`s to use in the model.
dropout (float) : a number between 0.0 and 1.0. The amount of dropout
applied to `g(c_t)` internally in each `SRUCell`.
rnn_dropout (float) : the amount of dropout applied to the input of
each `SRUCell`.
dropout (float) : a number between 0.0 and 1.0. The amount of dropout applied to `g(c_t)`
internally in each `SRUCell`.
rnn_dropout (float) : the amount of dropout applied to the input of each `SRUCell`.
use_tanh (bool) : use tanh activation
layer_norm (bool) : whether or not to use layer normalization on the output of each layer
bidirectional (bool) : whether or not to use bidirectional `SRUCell`s.
is_input_normalized (bool) : whether the input is normalized (e.g. batch norm / layer norm)
highway_bias (float) : initial bias of the highway gate, typically <= 0
nn_rnn_compatible_return (bool) : set to True to change the layout of returned state to match
that of pytorch nn.RNN, ie (num_layers * num_directions, batch, hidden_size)
(this will be slower, but can make SRU a dropin replacement for nn.RNN and nn.GRU)
custom_u (nn.Module) : use a custom module to compute the U matrix given the input.
The module must take as input a tensor of shape (seq_len, batch_size, hidden_size) and
return a tensor of shape (seq_len, batch_size, hidden_size * 3)
custom_v (nn.Module) : use a custom module to compute the V matrix given the input.
The module must take as input a tensor of shape (seq_len, batch_size, hidden_size) and
return a tensor of shape (seq_len, batch_size, hidden_size * 2)
nn_rnn_compatible_return (bool) : set to True to change the layout of returned state to
match that of pytorch nn.RNN, ie (num_layers * num_directions, batch, hidden_size)
(this will be slower, but can make SRU a drop-in replacement for nn.RNN and nn.GRU)
custom_m (nn.Module or List[nn.Module]) : use a custom module to compute the U matrix (and V
matrix) given the input. The module must take as input a tensor of shape (seq_len,
batch_size, hidden_size).
It must return a tensor U of shape (seq_len, batch_size, hidden_size * 3) and, optionally, a
second tensor V of shape (seq_len, batch_size, hidden_size * 2).
"""

def __init__(self,
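A usage sketch building on the argument descriptions above (variable names are hypothetical; `from sru import SRU` follows the package's usual import). A single module is deep-copied into every layer, while a list supplies one module per layer; layers after the first receive input of size hidden_size (times 2 if bidirectional), so per-layer modules must be sized accordingly:

import torch
from sru import SRU

seq_len, batch_size = 20, 4
input_size = hidden_size = 128  # equal sizes, so each layer's module sees 128-d input
num_layers = 2

# one module per layer, reusing the MyCustomM sketch from earlier
custom_m = [MyCustomM(input_size, hidden_size) for _ in range(num_layers)]
rnn = SRU(input_size, hidden_size, num_layers=num_layers, custom_m=custom_m)

x = torch.randn(seq_len, batch_size, input_size)
output, state = rnn(x)  # output: (seq_len, batch_size, hidden_size)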
@@ -514,14 +507,12 @@ def __init__(self,
projection_size=0,
use_tanh=False,
layer_norm=False,
#is_input_normalized=False,
highway_bias=0,
has_skip_term=True,
rescale=False,
v1=False,
nn_rnn_compatible_return=False,
custom_u=None,
custom_v=None,
custom_m=None,
proj_input_to_hidden_first=False):

super(SRU, self).__init__()
@@ -547,11 +538,9 @@ def __init__(self,

for i in range(num_layers):
# get custom modules when provided
custom_u_i, custom_v_i = None, None
if custom_u is not None:
custom_u_i = custom_u[i] if isinstance(custom_u, list) else copy.deepcopy(custom_u)
if custom_v is not None:
custom_v_i = custom_v[i] if isinstance(custom_v, list) else copy.deepcopy(custom_v)
custom_m_i = None
if custom_m is not None:
custom_m_i = custom_m[i] if isinstance(custom_m, list) else copy.deepcopy(custom_m)
# create the i-th SRU layer
l = SRUCell(
first_layer_input_size if i == 0 else self.output_size,
@@ -561,14 +550,12 @@ def __init__(self,
bidirectional=bidirectional,
n_proj=projection_size,
use_tanh=use_tanh,
#is_input_normalized=is_input_normalized or (i > 0 and self.use_layer_norm),
layer_norm=layer_norm,
highway_bias=highway_bias,
has_skip_term=has_skip_term,
rescale=rescale,
v1=v1,
custom_u=custom_u_i,
custom_v=custom_v_i
custom_m=custom_m_i
)
self.rnn_lst.append(l)

@@ -667,5 +654,4 @@ def make_backward_compatible(self):
if not hasattr(self, 'input_to_hidden'):
self.input_to_hidden = None
for cell in self.rnn_lst:
cell.custom_u = None
cell.custom_v = None
cell.custom_m = None
