diff --git a/example/warpctc/lstm.py b/example/warpctc/lstm.py index 97fda6b9c9d4..32ba2455e11d 100644 --- a/example/warpctc/lstm.py +++ b/example/warpctc/lstm.py @@ -72,7 +72,7 @@ def lstm_unroll(num_lstm_layer, seq_len, hidden_concat = mx.sym.Concat(*hidden_all, dim=0) pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) - label = mx.sym.Reshape(data=label, target_shape=(0,)) + label = mx.sym.Reshape(data=label, shape=(-1,)) label = mx.sym.Cast(data = label, dtype = 'int32') sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len) return sm