From 5ba19845afed558b13815d7359f9378694df0ad3 Mon Sep 17 00:00:00 2001 From: Heklis Date: Tue, 21 May 2019 22:35:38 +0800 Subject: [PATCH] Update base_model.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改了当样本数是batch_size整数倍时的处理操作 --- kashgari/tasks/classification/base_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kashgari/tasks/classification/base_model.py b/kashgari/tasks/classification/base_model.py index 60489ec1..55137bf1 100644 --- a/kashgari/tasks/classification/base_model.py +++ b/kashgari/tasks/classification/base_model.py @@ -184,6 +184,7 @@ def get_data_generator(self, target_x.append(x[start_index: end_index]) target_y = y_data[start_index: end_index] if len(target_x[0]) == 0: + target_x.pop() for x in x_data: target_x.append(x[0: batch_size]) target_y = y_data[0: batch_size]