|
# -*- coding: utf-8 -*-
"""Flask service exposing the ModelCache query/insert/remove HTTP API."""
import json
import logging
import configparser
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

from flask import Flask, request

from modelcache import cache
from modelcache.adapter import adapter
from modelcache.manager import CacheBase, VectorBase, get_data_manager
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
from modelcache.processor.pre import query_multi_splicing
from modelcache.processor.pre import insert_multi_splicing
from modelcache.utils.model_filter import model_blacklist_filter
from modelcache.embedding import Data2VecAudio
# from modelcache.maya_embedding_service.maya_embedding_service import get_cache_embedding_text2vec


# Flask application instance
app = Flask(__name__)
| 23 | + |
| 24 | + |
def response_text(cache_resp):
    """Return the answer text stored under the 'data' key of *cache_resp*."""
    answer = cache_resp['data']
    return answer
| 27 | + |
| 28 | + |
def save_query_info(result, model, query, delta_time_log):
    """Persist one query/response record through the global cache data manager.

    The query payload is JSON-serialized (non-ASCII preserved) before being
    handed to the storage layer.
    """
    serialized_query = json.dumps(query, ensure_ascii=False)
    cache.data_manager.save_query_resp(
        result,
        model=model,
        query=serialized_query,
        delta_time=delta_time_log,
    )
| 32 | + |
| 33 | + |
def response_hitquery(cache_resp):
    """Return the cached question that produced the hit ('hitQuery' key)."""
    hit = cache_resp['hitQuery']
    return hit
| 36 | + |
| 37 | + |
# Shared embedding model; its output dimension sizes the Milvus collection.
data2vec = Data2VecAudio()

# Storage backends: MySQL holds scalar cache rows, Milvus holds the vectors.
mysql_config = configparser.ConfigParser()
mysql_config.read('modelcache/config/mysql_config.ini')
milvus_config = configparser.ConfigParser()
milvus_config.read('modelcache/config/milvus_config.ini')
data_manager = get_data_manager(
    CacheBase("mysql", config=mysql_config),
    VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config),
)


cache.init(
    embedding_func=data2vec.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    query_pre_embedding_func=query_multi_splicing,
    insert_pre_embedding_func=insert_multi_splicing,
)

# cache.set_openai_key()
# Thread pool used to persist query logs asynchronously so HTTP responses
# are not delayed by storage writes.  (The original `global executor`
# statement was a no-op at module scope and has been removed.)
executor = ThreadPoolExecutor(max_workers=6)
| 58 | + |
| 59 | + |
@app.route('/welcome')
def first_flask():
    """Liveness endpoint: returns a fixed greeting to confirm the service is up."""
    return 'hello, modelcache!'
| 63 | + |
| 64 | + |
@app.route('/modelcache', methods=['GET', 'POST'])
def user_backend():
    """Main cache endpoint dispatching 'query', 'insert' and 'remove' requests.

    The POST body (or GET query string) must be a JSON-encoded string with a
    "type" field and a "scope" object carrying the model name.  The response
    is always a JSON string whose "errorCode" is 0 on success.
    """
    # ---- payload decoding -------------------------------------------------
    try:
        if request.method == 'POST':
            request_data = request.json
        elif request.method == 'GET':
            request_data = request.args
        param_dict = json.loads(request_data)
    except Exception as e:
        result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
                  "answer": ''}
        cache.data_manager.save_query_resp(result, model='', query='', delta_time=0)
        return json.dumps(result)

    # ---- param parsing ----------------------------------------------------
    # FIX: default `model` so later references (blacklist filter, error
    # logging) never raise NameError when "scope" is absent from the payload.
    model = ''
    try:
        request_type = param_dict.get("type")
        scope = param_dict.get("scope")
        if scope is not None:
            model = scope.get('model')
            # Model names are used as storage identifiers downstream, so
            # normalize characters the backends cannot accept.
            model = model.replace('-', '_')
            model = model.replace('.', '_')
        query = param_dict.get("query")
        chat_info = param_dict.get("chat_info")
        if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
            result = {"errorCode": 102,
                      "errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
                      "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
            cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
            return json.dumps(result)
    except Exception as e:
        result = {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
                  "answer": ''}
        return json.dumps(result)

    # ---- model blacklist filter ------------------------------------------
    filter_resp = model_blacklist_filter(model, request_type)
    if isinstance(filter_resp, dict):
        return json.dumps(filter_resp)

    # ---- cache lookup -----------------------------------------------------
    if request_type == 'query':
        try:
            start_time = time.time()
            response = adapter.ChatCompletion.create_query(
                scope={"model": model},
                query=query
            )
            # FIX: delta_time / delta_time_log were referenced but never
            # defined, so every query previously died with a NameError.
            delta_time = '{}s'.format(round(time.time() - start_time, 2))
            delta_time_log = round(time.time() - start_time, 3)
            if response is None:
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            elif response in ['adapt_query_exception']:
                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            else:
                answer = response_text(response)
                hit_query = response_hitquery(response)
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
                          "hit_query": hit_query, "answer": answer}
            # Persist the query log asynchronously; the response is not blocked.
            executor.submit(save_query_info, result, model, query, delta_time_log)
        except Exception as e:
            # FIX: serialize the exception — json.dumps cannot encode
            # Exception objects, which previously crashed this error path.
            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
                      "hit_query": '', "answer": ''}
            logging.info('result: {}'.format(result))

        return json.dumps(result, ensure_ascii=False)

    # ---- cache write -------------------------------------------------------
    if request_type == 'insert':
        try:
            try:
                response = adapter.ChatCompletion.create_insert(
                    model=model,
                    chat_info=chat_info
                )
            except Exception as e:
                # FIX: str(e) — Exception objects are not JSON-serializable.
                result = {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
                return json.dumps(result, ensure_ascii=False)

            if response in ['adapt_insert_exception']:
                result = {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
            elif response == 'success':
                result = {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
            else:
                result = {"errorCode": 302, "errorDesc": response,
                          "writeStatus": "exception"}
            return json.dumps(result, ensure_ascii=False)
        except Exception as e:
            result = {"errorCode": 304, "errorDesc": str(e), "writeStatus": "exception"}
            return json.dumps(result, ensure_ascii=False)

    # ---- cache invalidation -----------------------------------------------
    if request_type == 'remove':
        remove_type = param_dict.get("remove_type")
        id_list = param_dict.get("id_list", [])

        response = adapter.ChatCompletion.create_remove(
            model=model,
            remove_type=remove_type,
            id_list=id_list
        )

        if not isinstance(response, dict):
            result = {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
            return json.dumps(result)

        state = response.get('status')

        if state == 'success':
            result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
        return json.dumps(result)

    # FIX: 'detox' passes validation but has no handler; previously the view
    # fell off the end and returned None, making Flask raise a 500.  Return
    # an explicit, serializable error instead.
    result = {"errorCode": 104,
              "errorDesc": "request_type '{}' is not supported".format(request_type)}
    return json.dumps(result)
| 175 | + |
| 176 | + |
# Development entry point.
# NOTE(review): debug=True enables the Werkzeug interactive debugger and
# auto-reloader — this must not be exposed publicly; confirm the production
# deployment uses a real WSGI server instead of app.run().
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)