-
Notifications
You must be signed in to change notification settings - Fork 74
/
JavaSemanticKernelChainsApproach.java
178 lines (160 loc) · 7.41 KB
/
JavaSemanticKernelChainsApproach.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.openai.samples.rag.ask.approaches.semantickernel;
import com.azure.ai.openai.OpenAIAsyncClient;
import com.microsoft.openai.samples.rag.approaches.ContentSource;
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
import com.microsoft.openai.samples.rag.proxy.AzureAISearchProxy;
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
import com.microsoft.openai.samples.rag.retrieval.semantickernel.AzureAISearchPlugin;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion;
import com.microsoft.semantickernel.orchestration.FunctionResult;
import com.microsoft.semantickernel.plugin.KernelPluginFactory;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionArguments;
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Uses the Java Semantic Kernel framework with semantic and native function chaining. It follows an
 * imperative style of AI orchestration by invoking semantic kernel functions sequentially:
 * the InformationFinder.SearchFromQuestion native function is called first, followed by the
 * RAG.AnswerQuestion semantic function. Several Azure AI Search retrieval options are available:
 * Text, Vector, Hybrid.
 */
@Component
public class JavaSemanticKernelChainsApproach implements RAGApproach<String, RAGResponse> {

    private static final Logger LOGGER =
            LoggerFactory.getLogger(JavaSemanticKernelChainsApproach.class);

    private final AzureAISearchProxy azureAISearchProxy;
    private final OpenAIProxy openAIProxy;
    private final OpenAIAsyncClient openAIAsyncClient;

    // Azure OpenAI chat deployment (model) id used by the chat completion service.
    @Value("${openai.chatgpt.deployment}")
    private String gptChatDeploymentModelId;

    public JavaSemanticKernelChainsApproach(
            AzureAISearchProxy azureAISearchProxy,
            OpenAIAsyncClient openAIAsyncClient,
            OpenAIProxy openAIProxy) {
        this.azureAISearchProxy = azureAISearchProxy;
        this.openAIAsyncClient = openAIAsyncClient;
        this.openAIProxy = openAIProxy;
    }

    /**
     * Answers the user question by chaining two Semantic Kernel functions:
     * InformationFinder.SearchFromQuestion (native retrieval) followed by
     * RAG.AnswerQuestion (semantic prompt).
     *
     * @param question the user question to answer
     * @param options retrieval options (Text, Vector, Hybrid search)
     * @return the generated answer together with the sources used to ground it
     */
    @Override
    public RAGResponse run(String question, RAGOptions options) {
        // Build semantic kernel context
        Kernel semanticKernel = buildSemanticKernel(options);

        // STEP 1: Retrieve relevant documents using the user question. It reuses the
        // AzureAISearchRetriever approach through the AzureAISearchPlugin native function.
        FunctionResult<String> searchContext =
                semanticKernel
                        .getPlugin("InformationFinder")
                        .get("SearchFromQuestion")
                        .invokeAsync(semanticKernel)
                        .withArguments(
                                KernelFunctionArguments.builder().withInput(question).build())
                        .withResultType(String.class)
                        .block();
        // block() returns null when the reactive pipeline completes empty; fail fast with a
        // clear message instead of an anonymous NPE on getResult() below.
        Objects.requireNonNull(searchContext, "InformationFinder.SearchFromQuestion returned no result");

        var sources = formSourcesList(searchContext.getResult());

        // STEP 2: Build the SK function arguments with the retrieved sources and the user
        // question.
        var answerVariables =
                KernelFunctionArguments.builder()
                        .withVariable("sources", searchContext.getResult())
                        .withVariable("input", question)
                        .build();

        /*
         * STEP 3: Invoke the semantic function [AnswerQuestion] of the [RAG] plugin
         * (a.k.a. skill) registered in the kernel, providing the pre-built arguments.
         * This triggers Open AI to generate the answer.
         */
        FunctionResult<String> answerExecutionContext =
                semanticKernel
                        .invokeAsync("RAG", "AnswerQuestion")
                        .withArguments(answerVariables)
                        .withResultType(String.class)
                        .block();
        Objects.requireNonNull(answerExecutionContext, "RAG.AnswerQuestion returned no result");

        return new RAGResponse.Builder()
                .prompt("Prompt is managed by Semantic Kernel")
                .answer(answerExecutionContext.getResult())
                .sources(sources)
                .sourcesAsText(searchContext.getResult())
                .question(question)
                .build();
    }

    /**
     * Streaming is not implemented for this approach.
     *
     * @throws IllegalStateException always
     */
    @Override
    public void runStreaming(
            String questionOrConversation, RAGOptions options, OutputStream outputStream) {
        throw new IllegalStateException("Streaming not supported for this approach");
    }

    /**
     * Parses the raw search output (one "name: content" entry per line) into a list of
     * {@link ContentSource}. Lines without a ':' separator are skipped.
     *
     * @param result newline-separated search output; may be null
     * @return parsed sources; empty when result is null
     */
    private List<ContentSource> formSourcesList(String result) {
        if (result == null) {
            return Collections.emptyList();
        }
        return Arrays.stream(result.split("\n"))
                .map(
                        source -> {
                            // Split only on the first ':' so content containing ':' stays intact.
                            String[] split = source.split(":", 2);
                            if (split.length >= 2) {
                                var sourceName = split[0].trim();
                                var sourceContent = split[1].trim();
                                return new ContentSource(sourceName, sourceContent);
                            } else {
                                return null;
                            }
                        })
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Builds a semantic kernel with the chat completion AI service, the
     * InformationFinder.SearchFromQuestion native plugin (implemented by
     * AzureAISearchPlugin.searchFromConversation) and the RAG.AnswerQuestion semantic plugin
     * imported from src/main/resources/semantickernel/Plugins.
     *
     * @param options retrieval options forwarded to the search plugin
     * @return a configured {@link Kernel}
     */
    private Kernel buildSemanticKernel(RAGOptions options) {
        return Kernel.builder()
                .withAIService(
                        ChatCompletionService.class,
                        OpenAIChatCompletion.builder()
                                .withModelId(gptChatDeploymentModelId)
                                .withOpenAIAsyncClient(this.openAIAsyncClient)
                                .build())
                .withPlugin(
                        KernelPluginFactory.createFromObject(
                                new AzureAISearchPlugin(
                                        this.azureAISearchProxy, this.openAIProxy, options),
                                "InformationFinder"))
                .withPlugin(
                        KernelPluginFactory.importPluginFromResourcesDirectory(
                                "semantickernel/Plugins",
                                "RAG",
                                "AnswerQuestion",
                                null,
                                String.class))
                .build();
    }
}