-
Notifications
You must be signed in to change notification settings - Fork 335
/
llm_helper.py
156 lines (144 loc) · 6.07 KB
/
llm_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from openai import AzureOpenAI
from typing import List, Union, cast
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.azure_chat_prompt_execution_settings import (
AzureChatPromptExecutionSettings,
)
from .env_helper import EnvHelper
class LLMHelper:
    """Factory for Azure OpenAI clients used across the app.

    Builds chat/embedding clients for three stacks — the raw ``openai`` SDK,
    LangChain, and Semantic Kernel — honoring one of two auth modes resolved
    from the environment:

    * API-key auth (``auth_type_keys`` is True): pass ``api_key``.
    * Azure AD / RBAC auth: pass the token provider instead of a key.
    """

    def __init__(self):
        self.env_helper: EnvHelper = EnvHelper()
        self.auth_type_keys = self.env_helper.is_auth_type_keys()
        self.token_provider = self.env_helper.AZURE_TOKEN_PROVIDER

        if self.auth_type_keys:
            self.openai_client = AzureOpenAI(
                azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
                api_key=self.env_helper.OPENAI_API_KEY,
            )
        else:
            self.openai_client = AzureOpenAI(
                azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
                azure_ad_token_provider=self.token_provider,
            )

        self.llm_model = self.env_helper.AZURE_OPENAI_MODEL
        # AZURE_OPENAI_MAX_TOKENS comes from the environment as a string;
        # empty string means "no explicit limit" (None), otherwise parse as int.
        self.llm_max_tokens = (
            int(self.env_helper.AZURE_OPENAI_MAX_TOKENS)
            if self.env_helper.AZURE_OPENAI_MAX_TOKENS != ""
            else None
        )
        self.embedding_model = self.env_helper.AZURE_OPENAI_EMBEDDING_MODEL

    def get_llm(self):
        """Return a non-streaming LangChain ``AzureChatOpenAI`` chat model."""
        # Use the env-provided API version rather than reaching into the
        # private ``openai_client._api_version`` attribute (same value:
        # __init__ passed it to the client).
        if self.auth_type_keys:
            return AzureChatOpenAI(
                deployment_name=self.llm_model,
                temperature=0,
                max_tokens=self.llm_max_tokens,
                openai_api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
                azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                api_key=self.env_helper.OPENAI_API_KEY,
            )
        else:
            return AzureChatOpenAI(
                deployment_name=self.llm_model,
                temperature=0,
                max_tokens=self.llm_max_tokens,
                openai_api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
                azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                azure_ad_token_provider=self.token_provider,
            )

    # TODO: This needs to have a custom callback to stream back to the UI
    def get_streaming_llm(self):
        """Return a streaming LangChain ``AzureChatOpenAI`` chat model.

        Tokens are streamed to stdout via ``StreamingStdOutCallbackHandler``
        until a UI-specific callback replaces it (see TODO above).
        """
        if self.auth_type_keys:
            return AzureChatOpenAI(
                azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                api_key=self.env_helper.OPENAI_API_KEY,
                streaming=True,
                # Callbacks must be handler *instances*, not the class itself.
                callbacks=[StreamingStdOutCallbackHandler()],
                deployment_name=self.llm_model,
                temperature=0,
                max_tokens=self.llm_max_tokens,
                openai_api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
            )
        else:
            # RBAC mode: authenticate via the AD token provider only — do not
            # also pass an API key (it may not exist in key-less deployments).
            return AzureChatOpenAI(
                azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                streaming=True,
                callbacks=[StreamingStdOutCallbackHandler()],
                deployment_name=self.llm_model,
                temperature=0,
                max_tokens=self.llm_max_tokens,
                openai_api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
                azure_ad_token_provider=self.token_provider,
            )

    def get_embedding_model(self):
        """Return a LangChain ``AzureOpenAIEmbeddings`` client."""
        if self.auth_type_keys:
            return AzureOpenAIEmbeddings(
                azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                api_key=self.env_helper.OPENAI_API_KEY,
                azure_deployment=self.embedding_model,
                chunk_size=1,
            )
        else:
            return AzureOpenAIEmbeddings(
                azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                azure_deployment=self.embedding_model,
                chunk_size=1,
                azure_ad_token_provider=self.token_provider,
            )

    def generate_embeddings(self, input: Union[str, list[int]]) -> List[float]:
        """Embed a single text (or pre-tokenized token list) and return its vector.

        Args:
            input: Raw text, or a list of token ids, to embed. (Name kept for
                backward compatibility with keyword callers, despite shadowing
                the builtin.)

        Returns:
            The embedding vector of the single input.
        """
        return (
            self.openai_client.embeddings.create(
                input=[input], model=self.embedding_model
            )
            .data[0]
            .embedding
        )

    def get_chat_completion_with_functions(
        self, messages: list[dict], functions: list[dict], function_call: str = "auto"
    ):
        """Run a chat completion with (legacy) OpenAI function calling.

        Args:
            messages: Chat messages in OpenAI wire format.
            functions: Function (tool) schemas the model may call.
            function_call: ``"auto"``, ``"none"``, or a specific function spec.

        Returns:
            The raw SDK chat-completion response.
        """
        return self.openai_client.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            functions=functions,
            function_call=function_call,
        )

    def get_chat_completion(
        self, messages: list[dict], model: str | None = None, **kwargs
    ):
        """Run a plain chat completion via the raw SDK client.

        Args:
            messages: Chat messages in OpenAI wire format.
            model: Optional deployment override; defaults to the configured model.
            **kwargs: Extra arguments forwarded to ``chat.completions.create``.

        Returns:
            The raw SDK chat-completion response.
        """
        return self.openai_client.chat.completions.create(
            model=model or self.llm_model,
            messages=messages,
            max_tokens=self.llm_max_tokens,
            **kwargs,
        )

    def get_sk_chat_completion_service(self, service_id: str):
        """Return a Semantic Kernel ``AzureChatCompletion`` service.

        Args:
            service_id: Identifier under which SK registers this service.
        """
        if self.auth_type_keys:
            return AzureChatCompletion(
                service_id=service_id,
                deployment_name=self.llm_model,
                endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
                api_key=self.env_helper.OPENAI_API_KEY,
            )
        else:
            return AzureChatCompletion(
                service_id=service_id,
                deployment_name=self.llm_model,
                endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
                ad_token_provider=self.token_provider,
            )

    def get_sk_service_settings(self, service: AzureChatCompletion):
        """Build prompt-execution settings matching this helper's defaults.

        Args:
            service: The SK chat-completion service to derive settings from.

        Returns:
            ``AzureChatPromptExecutionSettings`` with temperature 0 and the
            configured max-token limit.
        """
        return cast(
            AzureChatPromptExecutionSettings,
            service.instantiate_prompt_execution_settings(
                service_id=service.service_id,
                temperature=0,
                max_tokens=self.llm_max_tokens,
            ),
        )