In [None]:
# langchainLLM模板
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk


class CustomLLM(LLM):
    """A custom chat model that echoes the first `n` characters of the input.

    When contributing an implementation to LangChain, carefully document
    the model including the initialization parameters, include
    an example of how to initialize the model and include any relevant
    links to the underlying models documentation or API.

    Example:

        .. code-block:: python

            model = CustomChatModel(n=2)
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                 [HumanMessage(content="world")]])
    """

    n: int
    """The number of characters from the last message of the prompt to be echoed."""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Override this method to implement the LLM logic.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            The model output as a string. Actual completions SHOULD NOT include the prompt.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return prompt[: self.n]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the LLM on the given prompt.

        This method should be overridden by subclasses that support streaming.

        If not implemented, the default behavior of calls to stream will be to
        fallback to the non-streaming version of the model and return
        the output as a single chunk.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            An iterator of GenerationChunks.
        """
        for char in prompt[: self.n]:
            chunk = GenerationChunk(text=char)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

            yield chunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": "CustomChatModel",
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model. Used for logging purposes only."""
        return "custom"

# 我的LLM模板

In [3]:
from typing import Any, Dict, Mapping
from langchain_core.language_models.llms import LLM
from typing import Dict, Any
from pydantic import Field
class CustomLLM(LLM):
    """自定义语言模型类，继承自LLM基类。

    属性:
        model_name: 模型名称。
        url: 模型平台的URL，需要先去某个url获得access_token,然后再去另一个url获得模型的结果。（可选）
        api_key: API密钥。（可选）
        appid: 应用ID。（可选）
        api_secret: API密码。（可选）
        access_token: 模型平台的访问令牌。（可选）
        timeout: 模型平台的超时时间。
        temperature: 控制生成文本的随机性的参数，值越大，输出的随机性越大。
        model_kwargs: 模型的其他可选参数。
    """

    url: str = None

    api_key: str = None

    appid: str = None

    api_secret: str = None

    access_token: str = None

    request_timeout: float = 50.0

    temperature: float = 0.36

    # # 必备的可选参数
    # model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    # 定义一个返回默认参数的方法
    @property
    def _default_params(self) -> Dict[str, Any]:
        """获取调用默认参数。"""
        normal_params = {
            "temperature": self.temperature,
            "request_timeout": self.request_timeout,
            }
        return {**normal_params}
    
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

he


# 千帆平台自定义LLM

In [2]:
# 向授权服务地址 https://aip.baidubce.com/oauth/2.0/token 发送请求（推荐使用POST）。
import requests
import json
from dotenv import load_dotenv, find_dotenv
from typing import Dict, List, Optional, Any
from test_MYLLM import CustomLLM
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForLLMRun
_ = load_dotenv(find_dotenv())

def get_access_token(API_KEY: str, SECRET_KEY: str):
    """
    使用 AK，SK 生成鉴权签名（Access Token）
    :return: access_token，或是None(如果错误)
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    headers = {"Content-Type": "application/json"}
    params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY}
    return str(requests.post(url, params=params).json().get("access_token"))

class QianFanLLM(CustomLLM):
    model_name: str = None
    """
    model_name: 可选择
    - 
    """
    url: str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat?access_token={}"
    api_key: str = None
    appid: str = None
    api_secret: str = None
    access_token: str = None
    request_timeout: float = 50.0
    # model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    def init_access_token(self):
        # 当模型初始化时，获取access_token
        # 同时需要有api_key和api_secret
        if self.api_key is None or self.api_secret is None:
            raise ValueError("api_key and api_secret are required.")
        else:
            try:
                self.access_token = get_access_token(self.api_key, self.api_secret)
                # print("access_token: ", self.access_token)
            except Exception as e:
                print(e)
                raise  # 重新抛出异常
    
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # 调用模型的逻辑
        if self.access_token is None:
            self.init_access_token()
        # 发送请求到api调用的url
        url = self.url.format(self.access_token)

        # 配置POST参数
        payload = json.dumps({
                "messages": [
                    {"role": "user", 
                    "content": prompt}
                    ],
                'temperature': self.temperature
        })


        headers = {"Content-Type": "application/json"}

        response = requests.request("POST", url, headers=headers, data=payload, timeout=self.request_timeout)
        # 返回结果
        # print(response)
        if response.status_code != 200:
            raise ValueError("Request failed with status code: {}".format(response.status_code))
        else:
            # 返回的是一个 Json 字符串
            text = json.loads(response.text)
            # print(js)
            return text["result"]
        
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"model": self.model_name},
            **super()._identifying_params,
        }
    
    @property
    def _llm_type(self) -> str:
        return "QianFanLLM"
            

# 测试QianFanLLM
model = QianFanLLM(model_name="Yi-34B-Chat", 
                   api_key="HnngI6IOybeRJyLgvtowIRUt", 
                   api_secret="Vx7tFMIdGDrouCsyBRuChmbNBHhlMM2e")
result = model.invoke("你好")
print(result)

你好！我是零一万物开发的智能助手，我叫 Yi，我是由零一万物的优秀工程师们一起开发的。我可以回答你的问题，提供信息，帮助你解决问题。请问有什么我可以帮助你的？


*千帆测试成功*

In [None]:
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token="

"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro_preemptible?access_token="

"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={}"

"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat?access_token="

"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="

"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={access_token}"

# 智谱LLM

In [12]:
import requests
import json
from dotenv import load_dotenv, find_dotenv
from typing import Dict, List, Optional, Any
from test_MYLLM import CustomLLM
import zhipuai
from langchain.pydantic_v1 import Field, root_validator
from langchain.utils import get_from_dict_or_env
from zhipuai import ZhipuAI
_ = load_dotenv(find_dotenv())

class ZhiPuLLM(CustomLLM):
    url: str = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
    model_name: str = None
    api_key: str = None
    appid: str = None
    api_secret: str = None
    access_token: str = None
    client: Any
    request_timeout: float = 50.0
    temperature: float = 0.36
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    # 验证zhipu_api_key
    @root_validator
    def validate_api_key(cls, values: Dict) -> Dict:
        values["ZHIPUAI_API_KEY"] = get_from_dict_or_env(values,"zhipuai_api_key", "ZHIPUAI_API_KEY")

        try:
            zhipuai.api_key = values["ZHIPUAI_API_KEY"]
            # print(zhipuai.api_key)
            values["client"] = ZhipuAI(api_key=zhipuai.api_key)
            # print(values["client"])
        except Exception as e:
            raise ValueError("api_key is required.")

        return values
    


    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters."""
        normal_params = {
            "temperature": self.temperature,
            "request_timeout": self.request_timeout,
            "model": self.model_name,
        }
        return {**normal_params, **self.model_kwargs}


    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        messages = [{"role": "user", "content": prompt}]
        return {
            **{"model": self.model_name, "messages": messages},
            **kwargs,
        }

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:     

                # 配置POST参数
        payload = json.dumps({

            "model": self.model_name,

            "messages": [
                    {"role": "user", 
                    "content": prompt}
                    ],
                'temperature': self.temperature
        })


        headers = {
            'Authorization': 'Bearer ' + self.api_key,
            "Content-Type": "application/json"
               }

        response = requests.request("POST", self.url, headers=headers, data=payload, timeout=self.request_timeout)
        # print(response)
        if response.status_code != 200:
            raise ValueError("Request failed with status code: {}".format(response.status_code))
        else:
            # 返回的是一个 Json 字符串
            text = json.loads(response.text)
            return text["choices"][0]["message"]["content"]
        

        # return response.choices[0].message.content
            

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **super()._identifying_params,
        }
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ZhiPuLLM"

# 测试ZhiPuLLM
model = ZhiPuLLM(model_name = "glm-3-turbo",
                  api_key = "26128e545ddd8a44c6588c4d530a5fbe.hj6klSWNSFWnli9p",
                  )

print(model.invoke("你好"))

你好👋！我是人工智能助手智谱清言（ChatGLM），很高兴见到你，欢迎问我任何问题。


*智谱测试成功*

# 星火spark

In [35]:
# coding: utf-8
import _thread as thread
import os
import queue
import time
import base64

import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import websocket
# import openpyxl
from concurrent.futures import ThreadPoolExecutor, as_completed
import os
from test_MYLLM import CustomLLM
from langchain.callbacks.manager import CallbackManagerForLLMRun

class Spark_LLM(CustomLLM):
    # 讯飞星火大模型的自定义 LLM
    # URL
    url : str = "wss://spark-api.xf-yun.com/v3.5/chat"
    # APPID
    appid : str = None
    # APISecret
    api_secret : str = None
    # Domain
    domain :str = "generalv3.5"
    # max_token
    max_tokens : int = 4096

    def getText(self, role, content, text = []):
        # role 是指定角色，content 是 prompt 内容
        jsoncon = {}
        jsoncon["role"] = role
        jsoncon["content"] = content
        text.append(jsoncon)
        print(text)
        return text

    def _call(self, 
              prompt : str, 
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any):
        if self.api_key == None or self.appid == None or self.api_secret == None:
            # 三个 Key 均存在才可以正常调用
            print("请填入 Key")
            raise ValueError("Key 不存在")
        # 将 Prompt 填充到星火格式
        question = prompt
        print(question)
        # 发起请求
        try:
            response = main(self.appid, self.api_key, self.api_secret, self.url, self.domain, question)
            # print(response)
            return response
        except Exception as e:
            print(e)
            print("请求失败")
            return "请求失败"
        
    @property
    def _llm_type(self) -> str:
        return "Spark"


class Ws_Param(object):
    # 初始化
    def __init__(self, APPID, APIKey, APISecret, gpt_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(gpt_url).netloc
        self.path = urlparse(gpt_url).path
        self.gpt_url = gpt_url

    # 生成url
    def create_url(self):
        # 生成RFC1123格式的时间戳
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # 拼接字符串
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # 进行hmac-sha256进行加密
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # 将请求的鉴权参数组合为字典
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # 拼接鉴权参数，生成url
        url = self.gpt_url + '?' + urlencode(v)
        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释，比对相同参数时生成的url与自己代码生成的url是否一致
        return url


# 收到websocket错误的处理
def on_error(ws, error):
    print("### error:", error)


# 收到websocket关闭的处理
def on_close(ws):
    print("### closed ###")


# 收到websocket连接建立的处理
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, query=ws.query, domain=ws.domain))
    ws.send(data)



def gen_params(appid, query, domain):
    """
    通过appid和用户的提问来生成请参数
    """

    data = {
        "header": {
            "app_id": appid,
            "uid": "1234",           
            # "patch_id": []    #接入微调模型，对应服务发布后的resourceid          
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": 0.5,
                "max_tokens": 4096,
                "auditing": "default",
            }
        },
        "payload": {
            "message": {
                "text": [{"role": "user", "content": query}]
            }
        }
    }
    return data


def main(appid, api_secret, api_key, gpt_url, domain, query):
    output_queue = queue.Queue()
    wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    # 收到websocket消息的处理
    def on_message(ws, message):
        # print(message)
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            print(f'请求错误: {code}, {data}')
            ws.close()
        else:
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"]
            #print(content,end='')
            output_queue.put(content)
            if status == 2:
                print("#### 关闭会话")
                ws.close()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.query = query
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
    return ''.join([output_queue.get() for _ in range(output_queue.qsize())])



import os
apiid = os.getenv("SPARK_APPID")
api_key = os.getenv("SPARK_API_KEY")
api_secret = os.getenv("SPARK_API_SECRET")
llm=Spark_LLM(appid=apiid, api_key=api_key, api_secret=api_secret)
llm.invoke("你好")

你好
### error: Handshake status 401 Unauthorized -+-+- {'date': 'Thu, 25 Apr 2024 13:56:06 GMT', 'content-type': 'application/json; charset=utf-8', 'connection': 'keep-alive', 'content-length': '76', 'server': 'kong/1.3.0'} -+-+- b'{"message":"HMAC signature cannot be verified: fail to retrieve credential"}'
### error: on_close() takes 1 positional argument but 3 were given



''

In [33]:
# coding: utf-8
import _thread as thread
import os
import time
import base64

import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import websocket

from concurrent.futures import ThreadPoolExecutor, as_completed
import os


class Ws_Param(object):
    # 初始化
    def __init__(self, APPID, APIKey, APISecret, gpt_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(gpt_url).netloc
        self.path = urlparse(gpt_url).path
        self.gpt_url = gpt_url

    # 生成url
    def create_url(self):
        # 生成RFC1123格式的时间戳
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # 拼接字符串
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # 进行hmac-sha256进行加密
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # 将请求的鉴权参数组合为字典
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # 拼接鉴权参数，生成url
        url = self.gpt_url + '?' + urlencode(v)
        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释，比对相同参数时生成的url与自己代码生成的url是否一致
        return url


# 收到websocket错误的处理
def on_error(ws, error):
    print("### error:", error)


# 收到websocket关闭的处理
def on_close(ws):
    print("### closed ###")


# 收到websocket连接建立的处理
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, query=ws.query, domain=ws.domain))
    ws.send(data)


# 收到websocket消息的处理
def on_message(ws, message):
    # print(message)
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        print(content,end='')
        if status == 2:
            print("#### 关闭会话")
            ws.close()


def gen_params(appid, query, domain):
    """
    通过appid和用户的提问来生成请参数
    """

    data = {
        "header": {
            "app_id": appid,
            "uid": "1234",           
            # "patch_id": []    #接入微调模型，对应服务发布后的resourceid          
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": 0.5,
                "max_tokens": 4096,
                "auditing": "default",
            }
        },
        "payload": {
            "message": {
                "text": [{"role": "user", "content": query}]
            }
        }
    }
    return data


def main(appid, api_secret, api_key, gpt_url, domain, query):
    wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()

    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.query = query
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})


if __name__ == "__main__":
    
    import os
    apiid = os.getenv("SPARK_APPID")
    api_key = os.getenv("SPARK_API_KEY")
    api_secret = os.getenv("SPARK_API_SECRET")
    main(
        appid=apiid,
        api_secret=api_secret,
        api_key=api_key,
        #appid、api_secret、api_key三个服务认证信息请前往开放平台控制台查看（https://console.xfyun.cn/services/bm35）
        gpt_url="wss://spark-api.xf-yun.com/v3.5/chat",
        # Spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"  # v3.0环境的地址
        # Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"  # v2.0环境的地址
        # Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat"  # v1.5环境的地址
        domain="generalv3.5",
        # domain = "generalv3"    # v3.0版本
        # domain = "generalv2"    # v2.0版本
        # domain = "general"    # v2.0版本
        query="给我写一篇100字的作文"
    )


标题：秋日私语

秋风轻抚，枫叶如火。我在林间小径漫步，思绪飘渺。落叶铺就金黄地毯，每一步都踏出沙沙的旋律。树梢低语，仿佛在诉说着季节的秘密。我闭上眼睛，让心灵与自然交融，感受那份宁静与和谐。这一刻，我仿佛听见了秋天的私语。#### 关闭会话
### error: on_close() takes 1 positional argument but 3 were given


In [4]:
from langchain.llms.base import LLM
from typing import Any, List, Mapping, Optional, Dict, Union, Tuple
from pydantic import Field
from test_MYLLM import CustomLLM
import json
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket  # 使用websocket_client
import queue

class Spark_LLM(CustomLLM):
    # 讯飞星火大模型的自定义 LLM
    # URL
    url : str = "wss://spark-api.xf-yun.com/v3.5/chat"
    # APPID
    appid : str = None
    # APISecret
    api_secret : str = None
    # Domain
    domain :str = "generalv3.5"
    # max_token
    max_tokens : int = 4096

    def getText(self, role, content, text = []):
        # role 是指定角色，content 是 prompt 内容
        jsoncon = {}
        jsoncon["role"] = role
        jsoncon["content"] = content
        text.append(jsoncon)
        return text

    def _call(self, prompt : str, stop: Optional[List[str]] = None,
                run_manager: Optional[CallbackManagerForLLMRun] = None,
                **kwargs: Any):
        if self.api_key == None or self.appid == None or self.api_secret == None:
            # 三个 Key 均存在才可以正常调用
            print("请填入 Key")
            raise ValueError("Key 不存在")
        # 将 Prompt 填充到星火格式
        question = self.getText("user", prompt)
        # 发起请求
        try:
            response = spark_main(self.appid,self.api_key,self.api_secret,self.url,self.domain,question, self.temperature, self.max_tokens)
            return response
        except Exception as e:
            print(e)
            print("请求失败")
            return "请求失败"
        
    @property
    def _llm_type(self) -> str:
        return "Spark"

class Ws_Param(object):
    # 初始化
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url
        # 自定义
        self.temperature = 0
        self.max_tokens = 2048

    # 生成url
    def create_url(self):
        # 生成RFC1123格式的时间戳
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # 拼接字符串
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # 进行hmac-sha256进行加密
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # 将请求的鉴权参数组合为字典
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # 拼接鉴权参数，生成url
        url = self.Spark_url + '?' + urlencode(v)
        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释，比对相同参数时生成的url与自己代码生成的url是否一致
        return url


# 收到websocket错误的处理
def on_error(ws, error):
    print("### error:", error)


# 收到websocket关闭的处理
def on_close(ws,one,two):
    print(" ")


# 收到websocket连接建立的处理
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature = ws.temperature, max_tokens = ws.max_tokens))
    ws.send(data)


# 收到websocket消息的处理
def on_message(ws, message):
    # print(message)
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        print(content,end ="")
        global answer
        answer += content
        # print(1)
        if status == 2:
            ws.close()


def gen_params(appid, domain,question, temperature, max_tokens):
    """
    通过appid和用户的提问来生成请参数
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": max_tokens,
                "temperature" : temperature,
                "auditing": "default"
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def spark_main(appid, api_key, api_secret, Spark_url,domain, question, temperature, max_tokens):
    # print("星火:")
    output_queue = queue.Queue()
    def on_message(ws, message):
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            print(f'请求错误: {code}, {data}')
            ws.close()
        else:
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"]
            # print(content, end='')
            # 将输出值放入队列
            output_queue.put(content)
            if status == 2: 
                ws.close()

    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.temperature = temperature
    ws.max_tokens = max_tokens
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
    return ''.join([output_queue.get() for _ in range(output_queue.qsize())])

import os
apiid = os.getenv("SPARK_APPID")
api_key = os.getenv("SPARK_API_KEY")
api_secret = os.getenv("SPARK_API_SECRET")
llm=Spark_LLM(appid=apiid, api_key=api_key, api_secret=api_secret)
llm.invoke("你好, 你是谁？") 

 


'您好，我是科大讯飞研发的认知智能大模型，我的名字叫讯飞星火认知大模型。我可以和人类进行自然交流，解答问题，高效完成各领域认知智能需求。'

In [35]:
from langchain.llms.base import LLM
from typing import Any, List, Mapping, Optional, Dict, Union, Tuple
from pydantic import Field
from test_MYLLM import CustomLLM
import json
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket  # 使用websocket_client
import queue

class Spark_LLM(CustomLLM):
    # 讯飞星火大模型的自定义 LLM
    # URL
    url : str = "wss://spark-api.xf-yun.com/v3.1/chat"
    # APPID
    appid : str = None
    # APISecret
    api_secret : str = None
    # Domain
    domain :str = "generalv3"
    # max_token
    max_tokens : int = 4096

    def getText(self, role, content, text = []):
        # role 是指定角色，content 是 prompt 内容
        jsoncon = {}
        jsoncon["role"] = role
        jsoncon["content"] = content
        text.append(jsoncon)
        return text

    def _call(self, prompt : str, stop: Optional[List[str]] = None,
                run_manager: Optional[CallbackManagerForLLMRun] = None,
                **kwargs: Any):
        if self.api_key == None or self.appid == None or self.api_secret == None:
            # 三个 Key 均存在才可以正常调用
            print("请填入 Key")
            raise ValueError("Key 不存在")
        # 将 Prompt 填充到星火格式
        question = self.getText("user", prompt)
        # 发起请求
        try:
            response = spark_main(self.appid,self.api_key,self.api_secret,self.url,self.domain,question, self.temperature, self.max_tokens)
            return response
        except Exception as e:
            print(e)
            print("请求失败")
            return "请求失败"
        
    @property
    def _llm_type(self) -> str:
        return "Spark"

class Ws_Param(object):
    # 初始化
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url
        # 自定义
        self.temperature = 0
        self.max_tokens = 2048

    # 生成url
    def create_url(self):
        # 生成RFC1123格式的时间戳
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # 拼接字符串
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # 进行hmac-sha256进行加密
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # 将请求的鉴权参数组合为字典
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # 拼接鉴权参数，生成url
        url = self.Spark_url + '?' + urlencode(v)
        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释，比对相同参数时生成的url与自己代码生成的url是否一致
        return url


# 收到websocket错误的处理
def on_error(ws, error):
    print("### error:", error)


# 收到websocket关闭的处理
def on_close(ws,one,two):
    pass


# 收到websocket连接建立的处理
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature = ws.temperature, max_tokens = ws.max_tokens))
    ws.send(data)


# # 收到websocket消息的处理
# def on_message(ws, message):
#     # print(message)
#     data = json.loads(message)
#     code = data['header']['code']
#     if code != 0:
#         print(f'请求错误: {code}, {data}')
#         ws.close()
#     else:
#         choices = data["payload"]["choices"]
#         status = choices["status"]
#         content = choices["text"][0]["content"]
#         print(content,end ="")
#         global answer
#         answer += content
#         # print(1)
#         if status == 2:
#             ws.close()


def gen_params(appid, domain,question, temperature, max_tokens):
    """
    通过appid和用户的提问来生成请参数
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": max_tokens,
                "temperature" : temperature,
                "auditing": "default"
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def spark_main(appid, api_key, api_secret, Spark_url,domain, question, temperature, max_tokens):
    # print("星火:")
    output_queue = queue.Queue()
    def on_message(ws, message):
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            print(f'请求错误: {code}, {data}')
            ws.close()
        else:
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"].strip()

            # content = '\n'.join(line for line in content.splitlines() if line.strip())
            if content:
                # print(content, end='')
                output_queue.put(content)
            if status == 2: 
                # print(f"status == 2, output_queue = {output_queue.qsize()}")
                ws.close()

    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.temperature = temperature
    ws.max_tokens = max_tokens
    print(f"output_queue.qsize() = {output_queue.qsize()}")
    # print(f"wsUrl = {wsUrl}")
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
    
    # print([output_queue.get() for _ in range(output_queue.qsize())])
    return (''.join([output_queue.get() for _ in range(output_queue.qsize())]).strip())

import os
apiid = os.getenv("SPARK_APPID")
api_key = os.getenv("SPARK_API_KEY")
api_secret = os.getenv("SPARK_API_SECRET")
llm=Spark_LLM(appid=apiid, api_key=api_key, api_secret=api_secret)
print(llm.invoke("你好"))

output_queue.qsize() = 0


你好！有什么我可以帮助你的吗？
