Skip to content

Commit

Permalink
fix concat tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
ShigekiKarita committed Jul 31, 2015
1 parent bc9e541 commit 73fd609
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
sizes = OptimizationSizes(
epoch_size=1000,
train_size=1,
eval_size=8,
eval_size=16,
mini_batch_size=1
)
model = GravesPredictionNet(nhidden=100)
Expand Down
8 changes: 4 additions & 4 deletions src/gravesnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def split_args(m, y, t_x, t_e):
return gps + t_x, (y_e, t_e)


def concat_losses(p, e, t_e):
def concat_losses(p, e):
loss_x = -F.sum(F.log(sum_axis(p))) / numpy.float32(p.data.shape[0])
loss_e = F.sigmoid_cross_entropy(*e)
return loss_x + loss_e
Expand All @@ -32,7 +32,7 @@ def concat_losses(p, e, t_e):
def loss_func(m, y, t_x, t_e):
x, e = split_args(m, y, t_x, t_e)
p = gaussian_mixture_2d_ref(*x)
return concat_losses(p, e, t_e)
return concat_losses(p, e)


# TODO: implement nice gaussian function to plot
Expand Down Expand Up @@ -104,9 +104,9 @@ def forward_one_step(self, hidden_state, lstm_cells, x_data, t_x_data, t_e_data,

gps, y_e, hidden_state, lstm_cells = self.bottle_neck(hidden_state, lstm_cells, x_data, train)
t_x = split_axis_by_widths(t_x, [1, 1])
gi, e = gps + t_x, (y_e, t_e)
gi, e = (gps + tuple(t_x)), (y_e, t_e)
p = gaussian_mixture_2d_ref(*gi)
loss = concat_losses(p, e, t_e)
loss = concat_losses(p, e)

return hidden_state, lstm_cells, loss

Expand Down
5 changes: 3 additions & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def evaluate(context, model, lstm_cells: chainer.Variable,

for t in range(len(es[i]) - 1):
ci, cx, ce = create_inout(context, x, e, t, mean, stddev)
hidden_state, _, loss = model.forward_one_step(hidden_state, lstm_cells, ci, cx, ce, train=False)
hidden_state, lstm_cells, loss = model.forward_one_step(hidden_state, lstm_cells, ci, cx, ce, train=False)
total += loss.data.reshape(())

set_volatile(lstm_cells, False)
Expand Down Expand Up @@ -153,7 +153,8 @@ def optimize(model, sizes: OptimizationSizes, data_dir: str):
loss_point_train / sizes.eval,
loss_seq_train / sizes.eval))
sys.stdout.flush()
loss_point, loss_seq = evaluate(context, model, lstm_cells, sizes, txs, tes, mean, stddev)
lstm_copy = lstm_cells.copy()
loss_point, loss_seq = evaluate(context, model, lstm_copy, sizes, txs, tes, mean, stddev)
print('\ttest: [loss/point: {:.6f}, loss/seq: {:.6f}]'.format(loss_point, loss_seq))
sys.stdout.flush()
loss_point_train = 0.0
Expand Down

0 comments on commit 73fd609

Please sign in to comment.