From 5fe0c184ac1389d6fd7c9550631be6073a54817e Mon Sep 17 00:00:00 2001 From: ycycse Date: Mon, 4 Nov 2024 15:14:41 +0800 Subject: [PATCH] fix built-in model inference error --- .../org/apache/iotdb/ainode/it/AINodeBasicIT.java | 13 +++++++++++++ .../iotdb/ainode/model/built_in_model_factory.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java index 17404a50e9dc0..0306b1a522559 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java @@ -179,6 +179,8 @@ public void ModelOperationTest() { public void callInferenceTest() { String sql = "CALL INFERENCE(identity, \"select s0,s1,s2 from root.AI.data\")"; String sql2 = "CALL INFERENCE(identity, \"select s2,s0,s1 from root.AI.data\")"; + String sql3 = + "CALL INFERENCE(_NaiveForecaster, \"select s0 from root.AI.data\", predict_length=3)"; try (Connection connection = EnvFactory.getEnv().getConnection(); Statement statement = connection.createStatement()) { @@ -215,6 +217,17 @@ public void callInferenceTest() { } assertEquals(7, count); } + + try (ResultSet resultSet = statement.executeQuery(sql3)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "output0,output1,output2"); + int count = 0; + while (resultSet.next()) { + count++; + } + assertEquals(3, count); + } + } catch (SQLException e) { fail(e.getMessage()); } diff --git a/iotdb-core/ainode/iotdb/ainode/model/built_in_model_factory.py b/iotdb-core/ainode/iotdb/ainode/model/built_in_model_factory.py index 3854ba4e367f4..82443012176b0 100644 --- a/iotdb-core/ainode/iotdb/ainode/model/built_in_model_factory.py +++ b/iotdb-core/ainode/iotdb/ainode/model/built_in_model_factory.py @@ -99,7 +99,7 @@ def fetch_built_in_model(model_id, inference_attributes): else: raise BuiltInModelNotSupportError(model_id) - return model, attributes + return model class Attribute(object):