diff --git a/be/src/exprs/aggregate/aggregate_function_ai_agg.cpp b/be/src/exprs/aggregate/aggregate_function_ai_agg.cpp index 44cbff4301be49..5b7e9efcb67dc8 100644 --- a/be/src/exprs/aggregate/aggregate_function_ai_agg.cpp +++ b/be/src/exprs/aggregate/aggregate_function_ai_agg.cpp @@ -21,7 +21,6 @@ #include "exprs/aggregate/helpers.h" namespace doris { -QueryContext* AggregateFunctionAIAggData::_ctx = nullptr; void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("ai_agg", diff --git a/be/src/exprs/aggregate/aggregate_function_ai_agg.h b/be/src/exprs/aggregate/aggregate_function_ai_agg.h index ae58216b451422..fd532c49d742b5 100644 --- a/be/src/exprs/aggregate/aggregate_function_ai_agg.h +++ b/be/src/exprs/aggregate/aggregate_function_ai_agg.h @@ -146,7 +146,7 @@ class AggregateFunctionAIAggData { } } - static void set_query_context(QueryContext* context) { _ctx = context; } + void set_query_context(QueryContext* context) { _ctx = context; } const std::string& get_task() const { return _task; } @@ -197,7 +197,7 @@ class AggregateFunctionAIAggData { process_current_context(); } - static size_t get_ai_context_window_size() { + size_t get_ai_context_window_size() const { DORIS_CHECK(_ctx); return static_cast(_ctx->query_options().ai_context_window_size); @@ -247,7 +247,7 @@ class AggregateFunctionAIAggData { inited = !data.empty(); } - static QueryContext* _ctx; + QueryContext* _ctx = nullptr; AIResource _ai_config; std::shared_ptr _ai_adapter; std::string _task; @@ -264,7 +264,7 @@ class AggregateFunctionAIAgg final void set_query_context(QueryContext* context) override { if (context) { - AggregateFunctionAIAggData::set_query_context(context); + _ctx = context; } } @@ -274,6 +274,11 @@ class AggregateFunctionAIAgg final bool is_blockable() const override { return true; } + void create(AggregateDataPtr __restrict place) const override { + new (place) AggregateFunctionAIAggData; + data(place).set_query_context(_ctx); + } + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena&) const override { data(place).prepare( @@ -303,7 +308,10 @@ class AggregateFunctionAIAgg final } } - void reset(AggregateDataPtr place) const override { data(place).reset(); } + void reset(AggregateDataPtr place) const override { + data(place).reset(); + data(place).set_query_context(_ctx); + } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena&) const override { @@ -317,6 +325,7 @@ class AggregateFunctionAIAgg final void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena&) const override { data(place).read(buf); + data(place).set_query_context(_ctx); } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { @@ -324,6 +333,9 @@ class AggregateFunctionAIAgg final DCHECK(!result.empty()) << "AI returns an empty result"; assert_cast(to).insert_data(result.data(), result.size()); } + +private: + QueryContext* _ctx = nullptr; }; } // namespace doris diff --git a/be/test/ai/aggregate_function_ai_agg_test.cpp b/be/test/ai/aggregate_function_ai_agg_test.cpp index a5ebbd8fb79b52..4fb9c7969d0273 100644 --- a/be/test/ai/aggregate_function_ai_agg_test.cpp +++ b/be/test/ai/aggregate_function_ai_agg_test.cpp @@ -66,7 +66,7 @@ class AggregateFunctionAIAggTest : public ::testing::Test { _agg_function->set_query_context(_query_ctx.get()); } - void TearDown() override { AggregateFunctionAIAggData::_ctx = nullptr; } + void TearDown() override {} protected: std::unique_ptr _runtime_state; @@ -424,6 +424,56 @@ TEST_F(AggregateFunctionAIAggTest, ai_context_window_size_session_variable_test) _agg_function->destroy(place); } +TEST_F(AggregateFunctionAIAggTest, query_context_is_isolated_between_function_instances_test) { + TQueryOptions first_query_options = create_fake_query_options(); + first_query_options.__set_ai_context_window_size(8); + auto first_query_ctx = + MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), first_query_options); + first_query_ctx->set_mock_ai_resource(); + + TQueryOptions second_query_options = create_fake_query_options(); + second_query_options.__set_ai_context_window_size(1024); + auto second_query_ctx = + MockQueryContext::create(TUniqueId(), ExecEnv::GetInstance(), second_query_options); + second_query_ctx->set_mock_ai_resource(); + + AggregateFunctionSimpleFactory factory; + register_aggregate_function_ai_agg(factory); + auto first_agg_function = factory.get("ai_agg", _data_types, nullptr, false, -1); + auto second_agg_function = factory.get("ai_agg", _data_types, nullptr, false, -1); + ASSERT_TRUE(first_agg_function != nullptr); + ASSERT_TRUE(second_agg_function != nullptr); + + first_agg_function->set_query_context(first_query_ctx.get()); + second_agg_function->set_query_context(second_query_ctx.get()); + + auto resource_col = ColumnString::create(); + auto text_col = ColumnString::create(); + auto task_col = ColumnString::create(); + + resource_col->insert_data("mock_resource", 13); + text_col->insert_data("abcd", 4); + task_col->insert_data("summarize", 9); + + resource_col->insert_data("mock_resource", 13); + text_col->insert_data("efgh", 4); + task_col->insert_data("summarize", 9); + + std::unique_ptr memory(new char[first_agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + first_agg_function->create(place); + + const IColumn* columns[3] = {resource_col.get(), text_col.get(), task_col.get()}; + first_agg_function->add(place, columns, 0, _arena); + first_agg_function->add(place, columns, 1, _arena); + + const auto& data = *reinterpret_cast(place); + std::string actual(reinterpret_cast(data.data.data()), data.data.size()); + EXPECT_EQ(actual, "this is a mock response\nefgh"); + + first_agg_function->destroy(place); +} + TEST_F(AggregateFunctionAIAggTest, gemini_endpoint_normalize_to_generate_content_test) { AIResource resource; resource.provider_type = "GEMINI";