import math
import torch
from torch.nn import Module, Parameter
import torch.nn.functional as F
def clip_grad(v, min_value, max_value):
    '''Clamp the gradient flowing back through `v` to [min_value, max_value].
    The forward value of `v` is unchanged; only its gradient is clipped during
    the backward pass, which helps avoid exploding gradients.'''
    v_tmp = v.expand_as(v)
    v_tmp.register_hook(lambda g: g.clamp(min_value, max_value))
    return v_tmp
class RNNCellBase(Module):
    def __repr__(self):
        s = '{name}({input_size}, {hidden_size}'
        if 'bias' in self._parameters and self._parameters['bias'] is None:
            s += ', bias=False'
        if getattr(self, 'nonlinearity', 'tanh') != 'tanh':
            s += ', nonlinearity={nonlinearity}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class RNNCell(RNNCellBase):
def __init__(self, input_size, hidden_size, bias=True, grad_clip=None):
super(RNNCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.grad_clip = grad_clip
self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(hidden_size, hidden_size))
if bias:
self.bias = Parameter(torch.Tensor(hidden_size))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, h):
output = F.linear(input, self.weight_ih, self.bias) + F.linear(h, self.weight_hh)
if self.grad_clip:
            output = clip_grad(output, -self.grad_clip, self.grad_clip)  # clamp gradients to avoid exploding gradients
output = F.relu(output)
return output
class GRUCell(RNNCellBase):
def __init__(self, input_size, hidden_size, bias=True, grad_clip=None):
super(GRUCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.grad_clip = grad_clip
self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size))
self.weight_hh_rz = Parameter(torch.Tensor(2 * hidden_size, hidden_size))
self.weight_hh = Parameter(torch.Tensor(hidden_size, hidden_size))
if bias:
self.bias = Parameter(torch.Tensor(3 * hidden_size))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, h):
ih = F.linear(input, self.weight_ih, self.bias)
hh_rz = F.linear(h, self.weight_hh_rz)
if self.grad_clip:
ih = clip_grad(ih, -self.grad_clip, self.grad_clip)
hh_rz = clip_grad(hh_rz, -self.grad_clip, self.grad_clip)
        # r: reset gate, i: update gate
        r = torch.sigmoid(ih[:, :self.hidden_size] + hh_rz[:, :self.hidden_size])
        i = torch.sigmoid(ih[:, self.hidden_size: self.hidden_size * 2] + hh_rz[:, self.hidden_size:])
        hhr = F.linear(h * r, self.weight_hh)
        if self.grad_clip:
            hhr = clip_grad(hhr, -self.grad_clip, self.grad_clip)
        # candidate state; ReLU is used here in place of the conventional tanh
        n = F.relu(ih[:, self.hidden_size * 2:] + hhr)
h = (1 - i) * n + i * h
return h
class LSTMCell(RNNCellBase):
def __init__(self, input_size, hidden_size, bias=True, grad_clip=None):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.grad_clip = grad_clip
self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
if bias:
self.bias = Parameter(torch.Tensor(4 * hidden_size))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, hx):
h, c = hx
pre = F.linear(input, self.weight_ih, self.bias) \
+ F.linear(h, self.weight_hh)
if self.grad_clip:
pre = clip_grad(pre, -self.grad_clip, self.grad_clip)
        # i: input gate, f: forget gate, g: candidate cell state, o: output gate
        i = torch.sigmoid(pre[:, :self.hidden_size])
        f = torch.sigmoid(pre[:, self.hidden_size: self.hidden_size * 2])
        g = torch.tanh(pre[:, self.hidden_size * 2: self.hidden_size * 3])
        o = torch.sigmoid(pre[:, self.hidden_size * 3:])
        c = f * c + i * g
        h = o * torch.tanh(c)
return h, c
def cumax(logits, dim=-1):
    '''Cumulative softmax: cumsum(softmax(logits)) along `dim`.'''
    return torch.cumsum(F.softmax(logits, dim=dim), dim=dim)
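# A minimal illustration (assumed values, not from the original code): with three
# equal logits, softmax([0., 0., 0.]) = [1/3, 1/3, 1/3], so cumax returns the
# monotonically non-decreasing [1/3, 2/3, 1.0]. LSTMONCell below uses cumax to
# build "master" forget/input gates whose activations rise toward 1 along the
# hidden dimension, which is what imposes the ordering over neurons.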
class LSTMONCell(RNNCellBase):
    '''
    Reference:
        Shen et al. [Ordered Neurons: Integrating Tree Structures into Recurrent Neural Networks](https://arxiv.org/abs/1810.09536).
    '''
def __init__(self, input_size, hidden_size, bias=True, grad_clip=None):
super(LSTMONCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.grad_clip = grad_clip
self.weight_ih = Parameter(torch.Tensor(6 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(6 * hidden_size, hidden_size))
if bias:
self.bias = Parameter(torch.Tensor(6 * hidden_size))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, hx):
h, c = hx
pre = F.linear(input, self.weight_ih, self.bias) \
+ F.linear(h, self.weight_hh)
if self.grad_clip:
pre = clip_grad(pre, -self.grad_clip, self.grad_clip)
        # standard LSTM gates
        i = torch.sigmoid(pre[:, :self.hidden_size])
        f = torch.sigmoid(pre[:, self.hidden_size: self.hidden_size * 2])
        g = torch.tanh(pre[:, self.hidden_size * 2: self.hidden_size * 3])
        o = torch.sigmoid(pre[:, self.hidden_size * 3: self.hidden_size * 4])
        # master forget/input gates: cumax makes them monotonic along the hidden dimension
        ff = cumax(pre[:, self.hidden_size * 4: self.hidden_size * 5])
        ii = 1 - cumax(pre[:, self.hidden_size * 5: self.hidden_size * 6])
        w = ff * ii  # overlap between the two master gates
        f = f * w + (ff - w)
        i = i * w + (ii - w)
        c = f * c + i * g
        h = o * torch.tanh(c)
return h, c
class LSTMPCell(RNNCellBase):
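    '''LSTM with a recurrent projection layer (LSTMP): the hidden state is
    projected from hidden_size down to recurrent_size before being fed back.
    Reference:
        Sak et al. [Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech Recognition](https://arxiv.org/abs/1402.1128).
    '''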
def __init__(self, input_size, hidden_size, recurrent_size, bias=True, grad_clip=None):
super(LSTMPCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.recurrent_size = recurrent_size
self.grad_clip = grad_clip
self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, recurrent_size))
self.weight_rec = Parameter(torch.Tensor(recurrent_size, hidden_size))
if bias:
self.bias = Parameter(torch.Tensor(4 * hidden_size))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, hx):
h, c = hx
pre = F.linear(input, self.weight_ih, self.bias) \
+ F.linear(h, self.weight_hh)
if self.grad_clip:
pre = clip_grad(pre, -self.grad_clip, self.grad_clip)
        i = torch.sigmoid(pre[:, :self.hidden_size])
        f = torch.sigmoid(pre[:, self.hidden_size: self.hidden_size * 2])
        g = torch.tanh(pre[:, self.hidden_size * 2: self.hidden_size * 3])
        o = torch.sigmoid(pre[:, self.hidden_size * 3:])
        c = f * c + i * g
        h = o * torch.tanh(c)
        h = F.linear(h, self.weight_rec)  # project h from hidden_size down to recurrent_size
return h, c
class MGRUCell(RNNCellBase):
'''Minimal GRU
Reference:
Ravanelli et al. [Improving speech recognition by revising gated recurrent units](https://arxiv.org/abs/1710.00641).
'''
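    # The update computed in forward() below, in equation form:
    #   z_t = sigmoid(W_z x_t + b_z + U_z h_{t-1})
    #   n_t = relu(W_n x_t + b_n + U_n h_{t-1})
    #   h_t = (1 - z_t) * n_t + z_t * h_{t-1}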
def __init__(self, input_size, hidden_size, bias=True, grad_clip=None):
super(MGRUCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.grad_clip = grad_clip
self.weight_ih = Parameter(torch.Tensor(2 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(2 * hidden_size, hidden_size))
if bias:
self.bias = Parameter(torch.Tensor(2 * hidden_size))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, h):
ih = F.linear(input, self.weight_ih, self.bias)
hh = F.linear(h, self.weight_hh)
if self.grad_clip:
ih = clip_grad(ih, -self.grad_clip, self.grad_clip)
hh = clip_grad(hh, -self.grad_clip, self.grad_clip)
        z = torch.sigmoid(ih[:, :self.hidden_size] + hh[:, :self.hidden_size])
n = F.relu(ih[:, self.hidden_size:] + hh[:, self.hidden_size:])
h = (1 - z) * n + z * h
return h
class IndRNNCell(RNNCellBase):
'''
References:
Li et al. [Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN](https://arxiv.org/abs/1803.04831).
'''
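    # The update computed in forward() below, in equation form:
    #   h_t = relu(W x_t + b + u * h_{t-1})
    # where u (weight_hh) is a single vector with one recurrent weight per
    # hidden unit, rather than a full hidden-to-hidden matrix.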
def __init__(self, input_size, hidden_size, bias=True, grad_clip=None):
super(IndRNNCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.grad_clip = grad_clip
self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(hidden_size))
if bias:
self.bias = Parameter(torch.Tensor(hidden_size))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, h):
output = F.linear(input, self.weight_ih, self.bias) + h * self.weight_hh
if self.grad_clip:
            output = clip_grad(output, -self.grad_clip, self.grad_clip)  # clamp gradients to avoid exploding gradients
output = F.relu(output)
return output
class RNNBase(Module):
def __init__(self, mode, input_size, hidden_size, recurrent_size=None, num_layers=1, bias=True,
return_sequences=True, grad_clip=None):
super(RNNBase, self).__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.recurrent_size = recurrent_size
self.num_layers = num_layers
self.bias = bias
self.return_sequences = return_sequences
self.grad_clip = grad_clip
mode2cell = {'RNN': RNNCell,
'IndRNN': IndRNNCell,
'GRU': GRUCell,
                     'MGRU': MGRUCell,
'LSTM': LSTMCell,
'LSTMON': LSTMONCell,
'LSTMP': LSTMPCell}
Cell = mode2cell[mode]
kwargs = {'input_size': input_size,
'hidden_size': hidden_size,
'bias': bias,
'grad_clip': grad_clip}
if self.mode == 'LSTMP':
kwargs['recurrent_size'] = recurrent_size
        self.cell0 = Cell(**kwargs)
for i in range(1, num_layers):
kwargs['input_size'] = recurrent_size if self.mode == 'LSTMP' else hidden_size
cell = Cell(**kwargs)
setattr(self, 'cell{}'.format(i), cell)
def forward(self, input, initial_states=None):
        if initial_states is None:
            # create zero states on the same device and dtype as the input
            zeros = input.new_zeros(input.size(0), self.hidden_size)
            if self.mode == 'LSTM' or self.mode == 'LSTMON':
                initial_states = [(zeros, zeros), ] * self.num_layers
            elif self.mode == 'LSTMP':
                zeros_h = input.new_zeros(input.size(0), self.recurrent_size)
                initial_states = [(zeros_h, zeros), ] * self.num_layers
            else:
                initial_states = [zeros] * self.num_layers
        assert len(initial_states) == self.num_layers
        states = list(initial_states)  # copy so the caller's list is not modified in place
outputs = []
time_steps = input.size(1)
for t in range(time_steps):
x = input[:, t, :]
for l in range(self.num_layers):
hx = getattr(self, 'cell{}'.format(l))(x, states[l])
states[l] = hx
if self.mode.startswith('LSTM'):
x = hx[0]
else:
x = hx
outputs.append(hx)
if self.return_sequences:
if self.mode.startswith('LSTM'):
hs, cs = zip(*outputs)
h = torch.stack(hs).transpose(0, 1)
c = torch.stack(cs).transpose(0, 1)
output = (h, c)
else:
output = torch.stack(outputs).transpose(0, 1)
else:
output = outputs[-1]
return output
class RNN(RNNBase):
def __init__(self, *args, **kwargs):
super(RNN, self).__init__('RNN', *args, **kwargs)
class GRU(RNNBase):
def __init__(self, *args, **kwargs):
super(GRU, self).__init__('GRU', *args, **kwargs)
class MGRU(RNNBase):
def __init__(self, *args, **kwargs):
super(MGRU, self).__init__('MGRU', *args, **kwargs)
class LSTM(RNNBase):
def __init__(self, *args, **kwargs):
super(LSTM, self).__init__('LSTM', *args, **kwargs)
class LSTMON(RNNBase):
def __init__(self, *args, **kwargs):
super(LSTMON, self).__init__('LSTMON', *args, **kwargs)
class LSTMP(RNNBase):
def __init__(self, *args, **kwargs):
super(LSTMP, self).__init__('LSTMP', *args, **kwargs)
class IndRNN(RNNBase):
'''
References:
Li et al. [Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN](https://arxiv.org/abs/1803.04831).
'''
def __init__(self, *args, **kwargs):
super(IndRNN, self).__init__('IndRNN', *args, **kwargs)
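

if __name__ == '__main__':
    # A minimal usage sketch. Shapes and hyperparameter values are assumptions
    # read off RNNBase.forward, which indexes input[:, t, :], i.e. batch-first
    # input of shape (batch, time, input_size).
    batch, time_steps, input_size, hidden_size = 4, 7, 10, 16
    x = torch.randn(batch, time_steps, input_size)

    # stacked LSTM with gradient clipping inside each cell
    lstm = LSTM(input_size, hidden_size, num_layers=2, grad_clip=5.0)
    h, c = lstm(x)  # return_sequences=True -> each is (batch, time, hidden_size)
    print(h.size(), c.size())

    # GRU returning only the last time step
    gru = GRU(input_size, hidden_size, return_sequences=False)
    last = gru(x)  # (batch, hidden_size)
    print(last.size())

    # LSTM with a recurrent projection: h is projected to recurrent_size
    lstmp = LSTMP(input_size, hidden_size, recurrent_size=8)
    h, c = lstmp(x)  # h: (batch, time, recurrent_size), c: (batch, time, hidden_size)
    print(h.size(), c.size())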