Skip to content

Commit a75880c

Browse files
committed
add flask example file
1 parent 267b9f2 commit a75880c

File tree

7 files changed

+271
-1
lines changed

7 files changed

+271
-1
lines changed

examples/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Alipay.com Inc.
4+
Copyright (c) 2004-2023 All Rights Reserved.
5+
------------------------------------------------------
6+
File Name : __init__.py.py
7+
Author : fuhui.phe
8+
Create Time : 2023/11/9 14:21
9+
Description : description what the main function of this file
10+
Change Activity:
11+
version0 : 2023/11/9 14:21 by fuhui.phe init
12+
"""

examples/flask/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# -*- coding: utf-8 -*-

examples/flask/data_insert.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
import json
3+
import requests
4+
5+
6+
def run():
7+
url = 'http://127.0.0.1:5000/modelcache'
8+
type = 'insert'
9+
scope = {"model": "CODEGPT-1109"}
10+
chat_info = [{"query": [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}],
11+
"answer": "你好,我是智能助手,请问有什么能帮您!"}]
12+
data = {'type': type, 'scope': scope, 'chat_info': chat_info}
13+
headers = {"Content-Type": "application/json"}
14+
res = requests.post(url, headers=headers, json=json.dumps(data))
15+
res_text = res.text
16+
print('res_text: {}'.format(res_text))
17+
18+
19+
if __name__ == '__main__':
20+
run()

examples/flask/data_query.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
import json
3+
import requests
4+
5+
6+
def run():
7+
url = 'http://127.0.0.1:5000/modelcache'
8+
type = 'query'
9+
scope = {"model": "CODEGPT-1109"}
10+
query = [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}]
11+
data = {'type': type, 'scope': scope, 'query': query}
12+
13+
headers = {"Content-Type": "application/json"}
14+
res = requests.post(url, headers=headers, json=json.dumps(data))
15+
res_text = res.text
16+
print('res_text: {}'.format(res_text))
17+
18+
19+
if __name__ == '__main__':
20+
run()

examples/flask/data_query_long.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# -*- coding: utf-8 -*-
2+
import json
3+
import requests
4+
5+
6+
def run():
7+
url = 'http://127.0.0.1:5000/modelcache'
8+
type = 'query'
9+
scope = {"model": "CODEGPT-1109"}
10+
system_conten = """
11+
<|role_start|>system<|role_end|>你是python助手, 你必须提供中立的、无害的答案帮助用户解决代码相关的问题,在回答用户问题过程中,你必须遵守如下准则:
12+
以用户选择的语言(如中文、英语)进行理解和交流
13+
回答应该是信息丰富的、直观的、合乎逻辑的和可操作的
14+
不泄漏模型的架构和内部实现细节
15+
不收集、存储或共享用户的个人信息或敏感信息,不使用未经许可的数据集,遵守数据集的许可协议和规定,并且不能改变数据集的原始内容
16+
不能生成涉及诽谤、歧视、侵犯知识产权等的内容,不违反法律和道德规范
17+
不能通过生成内容引起身体或精神上的伤害,例如,不包含暴力、恐怖、色情等内容
18+
不能使用或生成不准确、误导或伪造的信息,不能改变数据集的原始内容
19+
努力消除或减少内容中的偏见和歧视,包括种族、性别、性取向、宗教和政治观点等方面的偏见
20+
回答不会伤害人类、损害社会、危害环境和生态系统等方面
21+
<|end|><|role_start|>human<|role_end|>Analyze data from a survey and create visualizations to present the results.
22+
<|end|><|role_start|>bot<|role_end|>Sure thing! Let me just check if I can access the survey data.
23+
<|end|><|role_start|>human<|role_end|>What kind of visualizations can you create?
24+
<|end|><|role_start|>bot<|role_end|>I can create various types of visualizations such as bar charts, line graphs, scatter plots, pie charts, and more. I can also customize the visualizations according to your requirements.
25+
<|end|>
26+
"""
27+
user_content = """xP3(Crosslingual Public Pool of Prompts)是一个多语言指令数据集,由46种语言的16种不同的自然语言任务组成。数据集中的每个实例都有两个组件:“inputs”和“targets”。“inputs”是一种自然语言的任务描述。“targets”是正确遵循“inputs”指令的文本结果。xP3中的原始数据来自三个来源:英语指令数据集P3, P3中的4个英语未见任务(例如,翻译,程序合成)和30个多语言NLP数据集。作者通过从PromptSource中采样人工编写的任务模板,然后填充模板,将不同的NLP任务转换为统一的形式化,构建了xP3数据集。Unnatural Instructions是一个包含大约24万个实例的指令数据集,使用InstructGPT构建。数据集中的每个实例都有四个组件: INSTRUCTION,INPUT, CONSTRAINTS,OUTPUT。“INSTRUCTION”是用自然语言对教学任务的描述。“INPUT”是自然语言中的参数,用于实例化指令任务。“CONSTRAINTS”是任务输出空间的限制。“OUTPUT”是在给定输入参数和约束条件下正确执行指令的文本序列。
28+
"""
29+
30+
query = [{"role": "system", "content": system_conten}, {"role": "user", "content": user_content}]
31+
data = {'type': type, 'scope': scope, 'query': query}
32+
33+
headers = {"Content-Type": "application/json"}
34+
res = requests.post(url, headers=headers, json=json.dumps(data))
35+
res_text = res.text
36+
print('res_text: {}'.format(res_text))
37+
38+
39+
if __name__ == '__main__':
40+
run()

flask4modelcache.py

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# -*- coding: utf-8 -*-
2+
import json
3+
from flask import Flask, request
4+
import logging
5+
from datetime import datetime
6+
import configparser
7+
import time
8+
import json
9+
from modelcache import cache
10+
from modelcache.adapter import adapter
11+
from modelcache.manager import CacheBase, VectorBase, get_data_manager
12+
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
13+
from modelcache.processor.pre import query_multi_splicing
14+
from modelcache.processor.pre import insert_multi_splicing
15+
from concurrent.futures import ThreadPoolExecutor
16+
from modelcache.utils.model_filter import model_blacklist_filter
17+
from modelcache.embedding import Data2VecAudio
18+
# from modelcache.maya_embedding_service.maya_embedding_service import get_cache_embedding_text2vec
19+
20+
21+
# 创建一个Flask实例
22+
app = Flask(__name__)
23+
24+
25+
def response_text(cache_resp):
26+
return cache_resp['data']
27+
28+
29+
def save_query_info(result, model, query, delta_time_log):
30+
cache.data_manager.save_query_resp(result, model=model, query=json.dumps(query, ensure_ascii=False),
31+
delta_time=delta_time_log)
32+
33+
34+
def response_hitquery(cache_resp):
35+
return cache_resp['hitQuery']
36+
37+
38+
data2vec = Data2VecAudio()
39+
mysql_config = configparser.ConfigParser()
40+
mysql_config.read('modelcache/config/mysql_config.ini')
41+
milvus_config = configparser.ConfigParser()
42+
milvus_config.read('modelcache/config/milvus_config.ini')
43+
data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
44+
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
45+
46+
47+
cache.init(
48+
embedding_func=data2vec.to_embeddings,
49+
data_manager=data_manager,
50+
similarity_evaluation=SearchDistanceEvaluation(),
51+
query_pre_embedding_func=query_multi_splicing,
52+
insert_pre_embedding_func=insert_multi_splicing,
53+
)
54+
55+
# cache.set_openai_key()
56+
global executor
57+
executor = ThreadPoolExecutor(max_workers=6)
58+
59+
60+
@app.route('/welcome')
61+
def first_flask(): # 视图函数
62+
return 'hello, modelcache!'
63+
64+
65+
@app.route('/modelcache', methods=['GET', 'POST'])
66+
def user_backend():
67+
try:
68+
if request.method == 'POST':
69+
request_data = request.json
70+
elif request.method == 'GET':
71+
request_data = request.args
72+
param_dict = json.loads(request_data)
73+
except Exception as e:
74+
result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
75+
"answer": ''}
76+
cache.data_manager.save_query_resp(result, model='', query='', delta_time=0)
77+
return json.dumps(result)
78+
79+
# param parsing
80+
try:
81+
request_type = param_dict.get("type")
82+
83+
scope = param_dict.get("scope")
84+
if scope is not None:
85+
model = scope.get('model')
86+
model = model.replace('-', '_')
87+
model = model.replace('.', '_')
88+
query = param_dict.get("query")
89+
chat_info = param_dict.get("chat_info")
90+
if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
91+
result = {"errorCode": 102,
92+
"errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
93+
"cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
94+
cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
95+
return json.dumps(result)
96+
except Exception as e:
97+
result = {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
98+
"answer": ''}
99+
return json.dumps(result)
100+
101+
# model filter
102+
filter_resp = model_blacklist_filter(model, request_type)
103+
if isinstance(filter_resp, dict):
104+
return json.dumps(filter_resp)
105+
106+
if request_type == 'query':
107+
try:
108+
response = adapter.ChatCompletion.create_query(
109+
scope={"model": model},
110+
query=query
111+
)
112+
if response is None:
113+
result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '',
114+
"answer": ''}
115+
elif response in ['adapt_query_exception']:
116+
result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
117+
"hit_query": '', "answer": ''}
118+
else:
119+
answer = response_text(response)
120+
hit_query = response_hitquery(response)
121+
result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
122+
"hit_query": hit_query, "answer": answer}
123+
future = executor.submit(save_query_info, result, model, query, delta_time_log)
124+
except Exception as e:
125+
result = {"errorCode": 202, "errorDesc": e, "cacheHit": False, "delta_time": 0,
126+
"hit_query": '', "answer": ''}
127+
logging.info('result: {}'.format(result))
128+
129+
return json.dumps(result, ensure_ascii=False)
130+
131+
if request_type == 'insert':
132+
try:
133+
try:
134+
response = adapter.ChatCompletion.create_insert(
135+
model=model,
136+
chat_info=chat_info
137+
)
138+
except Exception as e:
139+
result = {"errorCode": 303, "errorDesc": e, "writeStatus": "exception"}
140+
return json.dumps(result, ensure_ascii=False)
141+
142+
if response in ['adapt_insert_exception']:
143+
result = {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
144+
elif response == 'success':
145+
result = {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
146+
else:
147+
result = {"errorCode": 302, "errorDesc": response,
148+
"writeStatus": "exception"}
149+
return json.dumps(result, ensure_ascii=False)
150+
except Exception as e:
151+
result = {"errorCode": 304, "errorDesc": e, "writeStatus": "exception"}
152+
return json.dumps(result, ensure_ascii=False)
153+
154+
if request_type == 'remove':
155+
remove_type = param_dict.get("remove_type")
156+
id_list = param_dict.get("id_list", [])
157+
158+
response = adapter.ChatCompletion.create_remove(
159+
model=model,
160+
remove_type=remove_type,
161+
id_list=id_list
162+
)
163+
164+
if not isinstance(response, dict):
165+
result = {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
166+
return json.dumps(result)
167+
168+
state = response.get('status')
169+
170+
if state == 'success':
171+
result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
172+
else:
173+
result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
174+
return json.dumps(result)
175+
176+
177+
if __name__ == '__main__':
178+
app.run(host='0.0.0.0', port=5000, debug=True)

modelcache/embedding/data2vec.py

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch
66
from transformers import BertTokenizer, BertModel
77
from modelcache.embedding.base import BaseEmbedding
8-
# from modelcache.utils.env_config import get_data2vec_model
98

109

1110
def mean_pooling(model_output, attention_mask):

0 commit comments

Comments
 (0)