diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 7cb65ae22ba2..a2578ea469a0 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -106,3 +106,4 @@ List of Contributors * [Max Kuhn](https://github.com/topepo) * [Yuqi Li](https://github.com/ziyeqinghan) * [Depeng Liang](https://github.com/Ldpe2G) +* [Kiko Qiu](https://github.com/kikoqiu) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index b5ed84f2330f..013c86f8d2cc 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -607,9 +607,6 @@ def predict(self, X, num_batch=None, return_data=False, reset=True): i = 0 for batch in X: - if num_batch is not None and i == num_batch: - break - i += 1 _load_data(batch, data_arrays) self._pred_exec.forward(is_train=False) @@ -624,6 +621,9 @@ def predict(self, X, num_batch=None, return_data=False, reset=True): data_list[j].append(x[0:real_size].asnumpy()) for j, x in enumerate(batch.label): label_list[j].append(x[0:real_size].asnumpy()) + i += 1 + if num_batch is not None and i == num_batch: + break outputs = [np.concatenate(x) for x in output_list] if len(outputs) == 1: