Skip to content

feat: 增加Gemini大模型支持 #439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider


class ModelProvideConstants(Enum):
Expand All @@ -29,3 +30,4 @@ class ModelProvideConstants(Enum):
model_zhipu_provider = ZhiPuModelProvider()
model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider()
model_gemini_provider = GeminiModelProvider()
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,6 @@
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
from smartdoc.conf import PROJECT_DIR

"""
class AzureLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = AzureModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

for key in ['api_base', 'api_key', 'deployment_name']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
else:
return False

return True

def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_base = forms.TextInputField('API 版本 (api_version)', required=True)

api_key = forms.PasswordInputField("API Key(API 密钥)", required=True)

deployment_name = forms.TextInputField("部署名(deployment_name)", required=True)
"""


class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):

Expand Down Expand Up @@ -97,8 +60,6 @@ def encryption_dict(self, model: Dict[str, object]):
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)


# azure_llm_model_credential: AzureLLMModelCredential = AzureLLMModelCredential()

base_azure_llm_model_credential = DefaultAzureLLMModelCredential()

model_dict = {
Expand All @@ -114,7 +75,6 @@ def get_dialogue_number(self):
return 3

def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatModel:
model_info: ModelInfo = model_dict.get(model_name)
azure_chat_open_ai = AzureChatModel(
azure_endpoint=model_credential.get('api_base'),
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@

class AzureChatModel(AzureChatOpenAI):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :__init__.py.py
@Author :Brian Yang
@Date :5/13/24 7:40 AM
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :gemini_model_provider.py
@Author :Brian Yang
@Date :5/13/24 7:47 AM
"""
import os
from typing import Dict

from langchain.schema import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
ModelInfo, ModelTypeConst, ValidCode
from setting.models_provider.impl.gemini_model_provider.model.gemini_chat_model import GeminiChatModel
from smartdoc.conf import PROJECT_DIR


class GeminiLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = GeminiModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = GeminiModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True

def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_key = forms.PasswordInputField('API Key', required=True)


gemini_llm_model_credential = GeminiLLMModelCredential()

model_dict = {
'gemini-1.0-pro': ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
ModelTypeConst.LLM,
gemini_llm_model_credential,
),
'gemini-1.0-pro-vision': ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新',
ModelTypeConst.LLM,
gemini_llm_model_credential,
),
}


class GeminiModelProvider(IModelProvider):

def get_dialogue_number(self):
return 3

def get_model(self, model_type, model_name, model_credential: Dict[str, object],
**model_kwargs) -> GeminiChatModel:
gemini_chat = GeminiChatModel(
model=model_name,
google_api_key=model_credential.get('api_key')
)
return gemini_chat

def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
return gemini_llm_model_credential

def get_model_provide_info(self):
return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'gemini_model_provider', 'icon',
'gemini_icon_svg')))

def get_model_list(self, model_type: str):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return [model_dict.get(key).to_dict() for key in
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]

def get_model_type_list(self):
return [{'key': "大语言模型", 'value': "LLM"}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<svg width="100%" height="100%" viewBox="0 0 28 28" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M14 28C14 26.0633 13.6267 24.2433 12.88 22.54C12.1567 20.8367 11.165 19.355 9.905 18.095C8.645 16.835 7.16333 15.8433 5.46 15.12C3.75667 14.3733 1.93667 14 0 14C1.93667 14 3.75667 13.6383 5.46 12.915C7.16333 12.1683 8.645 11.165 9.905 9.905C11.165 8.645 12.1567 7.16333 12.88 5.46C13.6267 3.75667 14 1.93667 14 0C14 1.93667 14.3617 3.75667 15.085 5.46C15.8317 7.16333 16.835 8.645 18.095 9.905C19.355 11.165 20.8367 12.1683 22.54 12.915C24.2433 13.6383 26.0633 14 28 14C26.0633 14 24.2433 14.3733 22.54 15.12C20.8367 15.8433 19.355 16.835 18.095 18.095C16.835 19.355 15.8317 20.8367 15.085 22.54C14.3617 24.2433 14 26.0633 14 28Z" fill="url(#paint0_radial_16771_53212)"/>
<defs>
<radialGradient id="paint0_radial_16771_53212" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(2.77876 11.3795) rotate(18.6832) scale(29.8025 238.737)">
<stop offset="0.0671246" stop-color="#9168C0"/>
<stop offset="0.342551" stop-color="#5684D1"/>
<stop offset="0.672076" stop-color="#1BA1E3"/>
</radialGradient>
</defs>
</svg>
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :gemini_chat_model.py
@Author :Brian Yang
@Date :5/13/24 7:40 AM
"""
from typing import List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_google_genai import ChatGoogleGenerativeAI

from common.config.tokenizer_manage_config import TokenizerManage


class GeminiChatModel(ChatGoogleGenerativeAI):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

def get_num_tokens(self, text: str) -> int:
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ zhipuai = "^2.0.1"
httpx = "^0.27.0"
httpx-sse = "^0.4.0"
websocket-client = "^1.7.0"
langchain-google-genai = "^1.0.3"

[build-system]
requires = ["poetry-core"]
Expand Down