Commit 2c757bd: fix bugs; better naming
taoleicn committed Oct 23, 2019
1 parent c23741d
Showing 3 changed files with 50 additions and 53 deletions.
32 changes: 16 additions & 16 deletions classification/train_classifier.py
@@ -62,21 +62,21 @@ def forward(self, input):
         return self.out(output)
 
 def eval_model(niter, model, valid_x, valid_y):
-    model.eval()
-    N = len(valid_x)
-    criterion = nn.CrossEntropyLoss()
-    correct = 0.0
-    cnt = 0
-    total_loss = 0.0
-    for x, y in zip(valid_x, valid_y):
-        x, y = Variable(x, volatile=True), Variable(y)
-        output = model(x)
-        loss = criterion(output, y)
-        total_loss += loss.item()*x.size(1)
-        pred = output.data.max(1)[1]
-        correct += pred.eq(y.data).cpu().sum()
-        cnt += y.numel()
-    model.train()
+    with torch.no_grad():
+        model.eval()
+        N = len(valid_x)
+        criterion = nn.CrossEntropyLoss()
+        correct = 0.0
+        cnt = 0
+        total_loss = 0.0
+        for x, y in zip(valid_x, valid_y):
+            output = model(x)
+            loss = criterion(output, y)
+            total_loss += loss.item()*x.size(1)
+            pred = output.data.max(1)[1]
+            correct += pred.eq(y).sum().item()
+            cnt += y.numel()
+        model.train()
     return 1.0-correct/cnt
 
 def train_model(epoch, model, optimizer,
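Note on the hunk above: Variable(x, volatile=True) has had no effect since PyTorch 0.4; torch.no_grad() is the supported way to disable graph construction during evaluation, and pred.eq(y).sum().item() yields a Python int where the old .cpu().sum() kept a tensor (which silently made the returned error rate a tensor as well). A minimal sketch of the pattern, with the model and batch names being illustrative:

    import torch

    def error_rate(model, batches):
        # evaluation loop in the post-0.4 style this commit adopts
        model.eval()
        correct = total = 0
        with torch.no_grad():              # no autograd graph is built here
            for x, y in batches:
                pred = model(x).argmax(dim=1)
                correct += pred.eq(y).sum().item()   # Python int, not a tensor
                total += y.numel()
        model.train()
        return 1.0 - correct / total

The same deprecation explains the deletion in the next hunk: training inputs no longer need Variable wrapping at all.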
@@ -95,7 +95,6 @@ def train_model(epoch, model, optimizer,
         niter += 1
         cnt += 1
         model.zero_grad()
-        x, y = Variable(x), Variable(y)
         output = model(x)
         loss = criterion(output, y)
         loss.backward()
@@ -178,6 +177,7 @@ def main(args):
         filter(need_grad, model.parameters()),
         lr = args.lr
     )
+    print (model)
 
     best_valid = 1e+8
     test_err = 1e+8
12 changes: 6 additions & 6 deletions sru/sru_cuda_kernel.cu
@@ -69,7 +69,7 @@ __global__ void sru_cuda_forward_kernel(
         const auto u0 = *up;
         const auto u1 = *(up + 1);
         const auto u2 = *(up + 2);
-        const auto x_val = *xp;
+        const auto x_val = (skip_type) ? (*xp) : (scalar_t)0.f;
         const auto g1 = sigmoidf(u1 + wc1*cur + bias1);
         const auto g2 = sigmoidf(u2 + wc2*cur + bias2);
         cur = (cur-u0)*g1 + u0;
@@ -271,7 +271,7 @@ __global__ void sru_cuda_bi_forward_kernel(
         const auto u0 = *up;
         const auto u1 = *(up + 1);
         const auto u2 = *(up + 2);
-        const auto x_val = *xp;
+        const auto x_val = (skip_type) ? (*xp) : (scalar_t)0.f;
         const auto g1 = sigmoidf(u1 + wc1*cur + bias1);
         const auto g2 = sigmoidf(u2 + wc2*cur + bias2);
         cur = (cur-u0)*g1 + u0;
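Both kernels now read x only when skip_type is nonzero, because the host launchers (below) pass a NULL x pointer when no skip term exists. In Python form, one time step of the recurrence looks roughly like the sketch below; the final highway line is inferred from the SRU formulation rather than shown in this hunk, and an identity activation is assumed:

    import torch

    def sru_step(cur, u0, u1, u2, x_t, wc1, wc2, bias1, bias2, skip_type):
        # hedged Python mirror of one kernel time step
        g1 = torch.sigmoid(u1 + wc1 * cur + bias1)   # forget gate
        g2 = torch.sigmoid(u2 + wc2 * cur + bias2)   # reset gate
        cur = (cur - u0) * g1 + u0                   # internal state c_t
        x_val = x_t if skip_type else 0.0            # the guarded read
        h_t = (cur - x_val) * g2 + x_val             # highway output (assumed form)
        return h_t, cur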
@@ -459,7 +459,7 @@ void sru_cuda_forward(
         h.data<scalar_t>(),
         c.data<scalar_t>(),
         U.data<scalar_t>(),
-        x.data<scalar_t>(),
+        x.numel() ? x.data<scalar_t>() : NULL,
         weight_c.data<scalar_t>(),
         bias.data<scalar_t>(),
         c_init.data<scalar_t>(),
@@ -501,7 +501,7 @@ void sru_cuda_bi_forward(
         h.data<scalar_t>(),
         c.data<scalar_t>(),
         U.data<scalar_t>(),
-        x.data<scalar_t>(),
+        x.numel() ? x.data<scalar_t>() : NULL,
         weight_c.data<scalar_t>(),
         bias.data<scalar_t>(),
         c_init.data<scalar_t>(),
@@ -552,7 +552,7 @@ void sru_cuda_backward(
         grad_bias.data<scalar_t>(),
         grad_init.data<scalar_t>(),
         U.data<scalar_t>(),
-        x.data<scalar_t>(),
+        x.numel() ? x.data<scalar_t>() : NULL,
         weight_c.data<scalar_t>(),
         bias.data<scalar_t>(),
         c_init.data<scalar_t>(),
@@ -606,7 +606,7 @@ void sru_cuda_bi_backward(
         grad_bias.data<scalar_t>(),
         grad_init.data<scalar_t>(),
         U.data<scalar_t>(),
-        x.data<scalar_t>(),
+        x.numel() ? x.data<scalar_t>() : NULL,
         weight_c.data<scalar_t>(),
         bias.data<scalar_t>(),
         c_init.data<scalar_t>(),
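All four launchers apply the same guard: calling .data<scalar_t>() on a zero-element tensor may hand the kernel an unusable pointer, so an explicit NULL is passed instead, and the skip_type check in the kernels keeps it from being dereferenced. A Python-side sketch of the assumed calling convention:

    import torch

    x = torch.empty(0)                         # stand-in for the absent skip input
    x_ptr = x.data_ptr() if x.numel() else 0   # mirrors `x.numel() ? ... : NULL`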
59 changes: 28 additions & 31 deletions sru/sru_functional.py
@@ -143,7 +143,8 @@ def sru_compute_cpu(u, x, weight_c, bias, init=None, mask_c=None):
         u0 = u[:, :, di, :, 0].chunk(length)
         u1 = (u[:, :, di, :, 1] + fb).chunk(length)
         u2 = (u[:, :, di, :, 2] + rb).chunk(length)
-        xp = x_prime[:, :, di, :].chunk(length)
+        if x_prime is not None:
+            xp = x_prime[:, :, di, :].chunk(length)
         for t in time_seq:
             forget_t = (u1[t] + c_prev*fw).sigmoid()
             reset_t = (u2[t] + c_prev*rw).sigmoid()
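This is the CPU-path twin of the CUDA fix: x_prime is assumed to be None whenever the cell carries no skip term (and xp is then presumably never read inside the time loop). Indexing None is what crashed before the guard, as this hypothetical minimal reproduction shows:

    x_prime = None                 # how "no skip term" reaches this function
    if x_prime is not None:        # the added guard
        xp = x_prime[:, :, 0, :].chunk(10)
    # without the guard: TypeError: 'NoneType' object is not subscriptable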
@@ -181,28 +182,24 @@ class SRUCell(nn.Module):
     as per `LSTMCell`, `GRUCell` and `RNNCell` in PyTorch.
     Args:
-        n_in (int) : the number of dimensions in a single
+        input_size (int) : the number of dimensions in a single
             input sequence element. For example, if the input sequence
             is a sequence of word embeddings, `input_size` is the
             dimensionality of a single word embedding, e.g. 300.
-        n_out (int) : the dimensionality of the hidden state
+        hidden_size (int) : the dimensionality of the hidden state
             of this cell.
         dropout (float) : a number between 0.0 and 1.0. The amount of dropout
             applied to `g(c_t)` internally in this cell.
         rnn_dropout (float) : the amount of dropout applied to the input of
             this cell.
-        use_tanh (bool) : use tanh activation
-        use_relu (bool) : use relu activation
-        use_selu (bool) : use selu activation
-        weight_norm (bool) : whether applyweight normalization on self.weight
-        is_input_normalized (bool) : whether the input is normalized (e.g. batch norm / layer norm)
         bidirectional (bool) : whether or not to employ a bidirectional cell.
         index (int) : index of this cell when multiple layers are stacked in SRU()
     """
 
     def __init__(self,
-                 n_in,
-                 n_out,
+                 input_size,
+                 hidden_size,
                  dropout=0,
                  rnn_dropout=0,
                  bidirectional=False,
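With the rename, SRUCell uses the same argument names as the built-in PyTorch cells. A usage sketch with illustrative sizes:

    import torch
    from sru import SRUCell

    cell = SRUCell(input_size=300, hidden_size=512, bidirectional=True)
    x = torch.randn(35, 4, 300)    # (length, batch, input_size)
    h, c = cell(x)                 # h: (35, 4, 1024), since output_size = 2 * hidden_size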
@@ -216,15 +213,15 @@ def __init__(self,
                  v1=False):
 
         super(SRUCell, self).__init__()
-        self.input_size = n_in
-        self.hidden_size = n_out # hidden size per direction
-        self.output_size = n_out * 2 if bidirectional else n_out
+        self.input_size = input_size
+        self.hidden_size = hidden_size # hidden size per direction
+        self.output_size = hidden_size * 2 if bidirectional else hidden_size
         self.rnn_dropout = rnn_dropout
         self.dropout = dropout
         self.bidirectional = bidirectional
-        #self.is_input_normalized = is_input_normalized
-        self.highway_bias = highway_bias
         self.has_skip_term = has_skip_term
+        self.highway_bias = highway_bias
         self.v1 = v1
         self.rescale = rescale
         self.activation_type = 0
@@ -235,9 +232,9 @@ def __init__(self,

         # projection dimension
         self.projection_size = 0
-        if n_proj > 0 and n_proj < n_in and n_proj < n_out:
+        if n_proj > 0 and n_proj < input_size and n_proj < self.output_size:
             self.projection_size = n_proj
 
         # number of sub-matrices used in SRU
         self.num_matrices = 3
         if has_skip_term and self.input_size != self.output_size:
@@ -257,12 +254,12 @@ def __init__(self,
         ))
         self.weight_c = nn.Parameter(torch.Tensor(self.output_size * 2))
         self.bias = nn.Parameter(torch.Tensor(self.output_size * 2))
 
         # scaling constant used in highway connections when rescale=True
         self.register_buffer('scale_x', torch.FloatTensor([0]))
 
         if layer_norm:
-            self.layer_norm = nn.LayerNorm(n_in)
+            self.layer_norm = nn.LayerNorm(input_size)
         else:
             self.layer_norm = None
         self.reset_parameters()
@@ -279,15 +276,15 @@ def reset_parameters(self):
         val_range = (3.0 / d)**0.5
         self.weight.data.uniform_(-val_range, val_range)
         if self.projection_size > 0:
-            val_range_2 = (3.0 / self.weight_proj.size(0))**0.5
-            self.weight_proj.data.uniform_(-val_range_2, val_range_2)
+            val_range = (3.0 / self.weight_proj.size(0))**0.5
+            self.weight_proj.data.uniform_(-val_range, val_range)
 
         # initialize bias
         self.bias.data.zero_()
         bias_val, output_size = self.highway_bias, self.output_size
         self.bias.data[output_size:].zero_().add_(bias_val)
 
         # projection matrix as a tensor of size:
         #   (input_size, bidirection, hidden_size, num_matrices)
         w = self.weight.data.view(d, -1, self.hidden_size, self.num_matrices)
         if not self.v1:
@@ -314,7 +311,7 @@ def reset_parameters(self):
         #self.weight_c.data.mul_(0.25)
 
         self.scale_x.data[0] = 1
-        if not self.rescale:
+        if not (self.rescale and self.has_skip_term):
             return
         # scalar used to properly scale the highway output
         scale_val = (1 + math.exp(bias_val) * 2)**0.5
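The extra has_skip_term condition is consistent with the comment above it: scale_val corrects the scale of the highway (skip) output, so without a skip term there is nothing to rescale. For reference, with a highway bias of 0 (assumed default) the constant works out to:

    import math

    bias_val = 0.0                                   # assumed default highway_bias
    scale_val = (1 + math.exp(bias_val) * 2) ** 0.5  # sqrt(3), roughly 1.732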
@@ -331,7 +328,7 @@ def forward(self, input, c0=None, mask_pad=None):
         if input.dim() != 2 and input.dim() != 3:
             raise ValueError("Input must be 2 or 3 dimensional")
 
-        n_in, n_out = self.input_size, self.hidden_size
+        input_size, hidden_size = self.input_size, self.hidden_size
         batch_size = input.size(-2)
         if c0 is None:
             c0 = Variable(input.data.new(
@@ -345,7 +342,7 @@ def forward(self, input, c0=None, mask_pad=None):

         # apply dropout for multiplication
         if self.training and (self.rnn_dropout > 0):
-            mask = self.get_dropout_mask_((batch_size, n_in), self.rnn_dropout)
+            mask = self.get_dropout_mask_((batch_size, input_size), self.rnn_dropout)
             input = input * mask.expand_as(input)
 
         # compute U that's (length, batch_size, output_size, num_matrices)
@@ -359,7 +356,7 @@ def forward(self, input, c0=None, mask_pad=None):
         if input.is_cuda:
             SRU_Compute = _lazy_load_cuda_kernel()(
                 self.activation_type,
-                n_out,
+                hidden_size,
                 self.bidirectional,
                 self.has_skip_term,
                 scale_val,
@@ -368,7 +365,7 @@ def forward(self, input, c0=None, mask_pad=None):
         else:
             SRU_Compute = SRU_CPU_class(
                 self.activation_type,
-                n_out,
+                hidden_size,
                 self.bidirectional,
                 self.has_skip_term,
                 scale_val,
@@ -383,13 +380,13 @@ def forward(self, input, c0=None, mask_pad=None):
         h, c = SRU_Compute(U, residual, self.weight_c, self.bias, c0)
 
         return h, c
 
     def compute_U(self, input):
         """
         SRU performs grouped matrix multiplication to transform
         the input (length, batch_size, input_size) into a tensor
         U of size (length, batch_size, output_size, num_matrices)
         """
         # collapse (length, batch_size) into one dimension if necessary
         x = input if input.dim() == 2 else input.contiguous().view(-1, self.input_size)
         if self.projection_size > 0:
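The docstring describes a grouped multiplication; as a shape sketch (unidirectional, no projection, illustrative sizes):

    import torch

    length, batch, input_size, hidden_size, k = 35, 4, 300, 512, 3
    x = torch.randn(length * batch, input_size)        # collapsed (length, batch)
    weight = torch.randn(input_size, hidden_size * k)  # k = num_matrices
    U = x.mm(weight).view(length, batch, hidden_size, k)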
@@ -489,8 +486,8 @@ def __init__(self,

         for i in range(num_layers):
             l = SRUCell(
-                n_in=self.input_size if i == 0 else self.output_size,
-                n_out=self.hidden_size,
+                self.input_size if i == 0 else self.output_size,
+                self.hidden_size,
                 dropout=dropout if i + 1 != num_layers else 0,
                 rnn_dropout=rnn_dropout,
                 bidirectional=bidirectional,
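The SRUCell call above switches to positional arguments because the n_in/n_out keywords no longer exist. Assuming the top-level SRU constructor mirrors the cell's renamed parameters, a stacked model is built as:

    import torch
    from sru import SRU

    rnn = SRU(input_size=300, hidden_size=512, num_layers=2, bidirectional=True)
    x = torch.randn(35, 4, 300)    # (length, batch, input_size)
    h, state = rnn(x)              # h: (35, 4, 1024)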
