In [2]:
# get BERT model
from transformers import BertTokenizer, BertModel
import torch
model_name = "bert-base-uncased"
local_bert_path = "./bert"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizer.from_pretrained(local_bert_path)
model = BertModel.from_pretrained(local_bert_path, output_hidden_states=True).to(device)

model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [4]:
import json
from typing import Union
import numpy as np

def bert_emb(text: str):
    sentence = text
    inputs = tokenizer(sentence, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    sentence_emb = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return sentence_emb

def cos_sim(a: np.ndarray, b: np.ndarray):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

class MemoryBank(object):
    def __init__(self):
        super(MemoryBank, self).__init__()
        self.query: list[tuple[str, dict[str, Union[np.ndarray|int]]]] = [0]*50
        # self.answer: dict[str, np.ndarray] = dict()
        self.q_threshold = 0.9
        self.max_memory_size = 50
        self.max_threshold = 0.98
        self.leth_ptr = 0
    
    def pre_memory_compute(self, question: str):
        """
        Used to pre-compute and find the most simlarity prompt, thus to reduce model generate time
        also will proceed on insertion
        """
        self.oversize_deletion()
        query_, query_emb = question, bert_emb(question)
        query_sim_list = []
        for idx in range(self.leth_ptr%self.max_memory_size):
            q, ptr = self.query[idx]
            query_sim = cos_sim(query_emb, ptr["q_emb"])
            if query_sim > self.q_threshold:
                self.query[idx][1]["query_sim_list"].append((query_, query_sim))
                self.query[idx][1]["query_sim_list"] = sorted(self.query[idx][1]["query_sim_list"], key=lambda x: x[1], reverse=True)
            query_sim_list.append((idx, query_sim))
        # self.leth_ptr+=1
        if len(query_sim_list) < 1:
            self.query[self.leth_ptr%self.max_memory_size] = (query_, {"q_emb": query_emb, "ans": None, "query_sim_list": []})
            return False
        query_sim_list = sorted(query_sim_list, key=lambda x: x[1], reverse=True)

        self.query[self.leth_ptr%self.max_memory_size] = (query_, {"q_emb": query_emb, "ans": None, "query_sim_list": query_sim_list})
        most_sim = max(query_sim_list, key=lambda x: x[1])
        if most_sim[1]>self.max_threshold:
            return self.query[most_sim[0]][1]["ans"]
        else:
            return False

    def answer_insert(self, answer:str):
        self.query[self.leth_ptr%self.max_memory_size][1]["ans"] = answer
        self.leth_ptr+=1

    def oversize_deletion(self):
        if self.leth_ptr > self.max_memory_size:
            overfeat_idx = self.leth_ptr%self.max_memory_size
            for idx in range(overfeat_idx, self.max_memory_size):
                self.query[idx]["query_sim_list"] = list(filter(lambda x: x[0]!=overfeat_idx, self.query[idx]["query_sim_list"]))
        return

In [5]:
class Answer_sheet(object):
    def __init__(self, memory_path: str):
        super(Answer_sheet, self).__init__()
        self.path = memory_path
        self.memory_size = 50
        self.cache_: MemoryBank = MemoryBank()
        self.load_memory()
    
    def load_memory(self):
        # read the json file
        with open(self.path) as f:
            data = json.load(f)
        
        for itm in data:
            query = itm["user_input"]
            answer = itm["response"]
            return_data = self.cache_.pre_memory_compute(query)
            self.cache_.answer_insert(answer)
            
        return data
    
    def get_cache(self):
        return self.cache_
    
    def retrieve_confirm(self, question: str):
        quick_answer = self.cache_.pre_memory_compute(question)
        return quick_answer

    def retrieve(self, question: str):
        return self.retrieve_confirm(question)

In [6]:
def get_local_file(type: str="system_GPT"):
    file_name: str = f"./chat_history_{type}.json"
    return file_name

In [11]:
from openai import OpenAI

client = OpenAI(
    api_key="sk-r3o5oQfe0V8ZlRsnO2PFoYlDayEUwgAoIEPi3QkzFgiXSmUm",
    base_url="https://api.wlai.vip/v1"
)        
# 发送请求
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role":"user", "content":""}
    ]
)
# 打印响应内容
print(response.choices[0].message.content)

Hello! How can I assist you today?


In [12]:
from openai import OpenAI
client2 = OpenAI(api_key="sk-r3o5oQfe0V8ZlRsnO2PFoYlDayEUwgAoIEPi3QkzFgiXSmUm",base_url="http://0.0.0.0:8001/v1")
messages = [{"role": "user", "content": ""}]
result = client2.chat.completions.create(messages=messages, model="/mnt/workspace/LLaMA-Factory/model_med")
print(result.choices[0].message.content)

APIConnectionError: Connection error.

In [16]:
import gradio as gr
import json
from datetime import datetime

def save_chat_history(system_name, user_input, response):
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    chat_record = {
        "timestamp": timestamp,
        "system": system_name,
        "user_input": user_input,
        "response": response
    }
    
    # 将对话记录追加到文件
    try:
        with open(f"chat_history_{system_name}.json", "r", encoding="utf-8") as f:
            chat_history = json.load(f)
            if not isinstance(chat_history, list):
                chat_history = []
    except FileNotFoundError:
        chat_history = []

    chat_history.append(chat_record)

    with open(f"chat_history_{system_name}.json", "w", encoding="utf-8") as f:
        json.dump(chat_history, f, ensure_ascii=False, indent=4)

def chatbot_response_1(user_input: str, memory_bank: Answer_sheet):
    response_content = memory_bank.cache_.pre_memory_compute(user_input)

    if not response_content:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": user_input}
            ]
        )
        response_content = response.choices[0].message.content
        
    memory_bank.cache_.answer_insert(response_content)
    
    # 保存对话记录
    save_chat_history("system_GPT", user_input, response_content)
    return response_content

def chatbot_response_2(user_input: str, memory_bank: Answer_sheet):
    response_content = memory_bank.cache_.pre_memory_compute(user_input)

    if not response_content:
        response = client2.chat.completions.create(
            model="/mnt/workspace/LLaMA-Factory/model_med",
            messages=[
                {"role": "user", "content": user_input}
            ]
        )
        response_content = response.choices[0].message.content

    memory_bank.cache_.answer_insert(response_content)
        
    # 保存对话记录
    save_chat_history("system_LLAMA_Med", user_input, response_content)
    return response_content

with gr.Blocks(css=".gradio-container {background-color: #f9f9f9; font-family: 'Arial', sans-serif; max-width: 800px; margin: auto;} .btn {background-color: #4CAF50; color: white; border: none; padding: 5px 10px; text-align: center; text-decoration: none; display: inline-block; font-size: 14px; margin: 2px 1px; cursor: pointer;} .btn:hover {background-color: #45a049;}") as demo:
    with gr.Row() as initial_page:
        gr.Markdown("<h1 style='text-align: center; color: #333;'>选择一个问答系统</h1>")
        btn1 = gr.Button("GPT",elem_classes=["btn"])
        btn2 = gr.Button("LLAMA_Med",elem_classes=["btn"])

    with gr.Column(visible=False) as qa_page_1:
        gr.Markdown("<h2 style='text-align: center; color: #333;'>GPT</h2>")
        chatbot1 = gr.Chatbot()
        msg1 = gr.Textbox(label="输入你的问题")
        submit1 = gr.Button("发送",elem_classes=["btn"])
        back1 = gr.Button("返回",elem_classes=["btn"])
        
        file_path = get_local_file("system_GPT")
        retrieve_component = Answer_sheet(file_path)

        def respond_1(message, chat_history):
            response = chatbot_response_1(message, retrieve_component)
            chat_history.append((message, response))
            return "", chat_history
        # 添加回车触发提交
        msg1.submit(respond_1, [msg1, chatbot1], [msg1, chatbot1])
        submit1.click(respond_1, [msg1, chatbot1], [msg1, chatbot1])

    with gr.Column(visible=False) as qa_page_2:
        gr.Markdown("<h2 style='text-align: center; color: #333;'>LLAMA_Med</h2>")
        chatbot2 = gr.Chatbot()
        msg2 = gr.Textbox(label="输入你的问题")
        submit2 = gr.Button("发送",elem_classes=["btn"])
        back2 = gr.Button("返回",elem_classes=["btn"])
        
        file_path = get_local_file("system_LLAMA_Med")
        retrieve_component = Answer_sheet(file_path)

        def respond_2(message, chat_history):
            response = chatbot_response_2(message, retrieve_component)
            chat_history.append((message, response))
            return "", chat_history

        msg2.submit(respond_2, [msg2, chatbot2], [msg2, chatbot2])
        submit2.click(respond_2, [msg2, chatbot2], [msg2, chatbot2])

    # click button to implement page jumps 
    def show_qa_page_1():
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)

    def show_qa_page_2():
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

    def go_back():
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    btn1.click(show_qa_page_1, [], [initial_page, qa_page_1, qa_page_2])
    btn2.click(show_qa_page_2, [], [initial_page, qa_page_1, qa_page_2])
    back1.click(go_back, [], [initial_page, qa_page_1, qa_page_2])
    back2.click(go_back, [], [initial_page, qa_page_1, qa_page_2])

demo.launch()



* Running on local URL:  http://127.0.0.1:7868

To create a public link, set `share=True` in `launch()`.


