In [21]:
!pip install -q openai

28772.02s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


In [1]:
# get BERT model
from transformers import BertTokenizer, BertModel
import torch
model_name = "bert-base-uncased"
local_bert_path = "./bert_model"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizer.from_pretrained(local_bert_path)
model = BertModel.from_pretrained(local_bert_path, output_hidden_states=True).to(device)

model.eval()

OSError: Can't load tokenizer for './bert_model'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure './bert_model' is the correct path to a directory containing all relevant files for a BertTokenizer tokenizer.

In [23]:
# model.save_pretrained(local_bert_path)
# tokenizer.save_pretrained(local_bert_path)

In [72]:
import json
from typing import Union
import numpy as np

def bert_emb(text: str):
    sentence = text
    inputs = tokenizer(sentence, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    sentence_emb = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return sentence_emb

def cos_sim(a: np.ndarray, b: np.ndarray):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

class MemoryBank(object):
    def __init__(self):
        super(MemoryBank, self).__init__()
        self.query: list[tuple[str, dict[str, Union[np.ndarray|int]]]] = [0]*50
        # self.answer: dict[str, np.ndarray] = dict()
        self.q_threshold = 0.9
        self.max_memory_size = 50
        self.max_threshold = 0.98
        self.leth_ptr = 0
    
    def pre_memory_compute(self, question: str):
        """
        Used to pre-compute and find the most simlarity prompt, thus to reduce model generate time
        also will proceed on insertion
        """
        self.oversize_deletion()
        query_, query_emb = question, bert_emb(question)
        query_sim_list = []
        for idx in range(self.leth_ptr%self.max_memory_size):
            q, ptr = self.query[idx]
            query_sim = cos_sim(query_emb, ptr["q_emb"])
            if query_sim > self.q_threshold:
                self.query[idx][1]["query_sim_list"].append((query_, query_sim))
                self.query[idx][1]["query_sim_list"] = sorted(self.query[idx][1]["query_sim_list"], key=lambda x: x[1], reverse=True)
            query_sim_list.append((idx, query_sim))
        # self.leth_ptr+=1
        if len(query_sim_list) < 1:
            self.query[self.leth_ptr%self.max_memory_size] = (query_, {"q_emb": query_emb, "ans": None, "query_sim_list": []})
            return False
        query_sim_list = sorted(query_sim_list, key=lambda x: x[1], reverse=True)

        self.query[self.leth_ptr%self.max_memory_size] = (query_, {"q_emb": query_emb, "ans": None, "query_sim_list": query_sim_list})
        most_sim = max(query_sim_list, key=lambda x: x[1])
        if most_sim[1]>self.max_threshold:
            return self.query[most_sim[0]][1]["ans"]
        else:
            return False

    def answer_insert(self, answer:str):
        self.query[self.leth_ptr%self.max_memory_size][1]["ans"] = answer
        self.leth_ptr+=1

    def oversize_deletion(self):
        if self.leth_ptr > self.max_memory_size:
            overfeat_idx = self.leth_ptr%self.max_memory_size
            for idx in range(overfeat_idx, self.max_memory_size):
                self.query[idx]["query_sim_list"] = list(filter(lambda x: x[0]!=overfeat_idx, self.query[idx]["query_sim_list"]))
        return

In [73]:
class Answer_sheet(object):
    def __init__(self, memory_path: str):
        super(Answer_sheet, self).__init__()
        self.path = memory_path
        self.memory_size = 50
        self.cache_: MemoryBank = MemoryBank()
        self.load_memory()
    
    def load_memory(self):
        # read the json file
        with open(self.path) as f:
            data = json.load(f)
        
        for itm in data:
            query = itm["user_input"]
            answer = itm["response"]
            return_data = self.cache_.pre_memory_compute(query)
            self.cache_.answer_insert(answer)
            
        return data
    
    def get_cache(self):
        return self.cache_
    
    def retrieve_confirm(self, question: str):
        quick_answer = self.cache_.pre_memory_compute(question)
        # if quick_answer:
        #     answer = quick_answer
        # else:
        #     answer = model_answer(question)
        # self.cache_.answer_insert(answer)
        # return answer
        return quick_answer

    def retrieve(self, question: str):
        return self.retrieve_confirm(question)

In [67]:
def get_local_file(type: str="A"):
    file_name: str = f"./chat_history_system_{type.upper()}.json"
    return file_name

In [74]:
fnm = get_local_file("A")
answer_method = Answer_sheet(fnm)

In [51]:
from openai import OpenAI

client = OpenAI(
    api_key="sk-805KuY00oiV8JRl2ZKgwL7ANQdtdKZMBwPZvl1XqcYvdZqbQ",
    base_url="https://api.chatanywhere.tech/v1"
)        
# 发送请求
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role":"user", "content":""}
    ]
)
# 打印响应内容
print(response.choices[0].message.content)

Hello! How can I assist you today?


In [75]:
import gradio as gr
import json
from datetime import datetime
from typing import Union, List

def save_chat_history(system_name, user_input, response):
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    chat_record = {
        "timestamp": timestamp,
        "system": system_name,
        "user_input": user_input,
        "response": response
    }

    try:
        with open(f"chat_history_{system_name}.json", "r", encoding="utf-8") as f:
            chat_history = json.load(f)
            if not isinstance(chat_history, list):
                chat_history = []
    except FileNotFoundError:
        chat_history = []

    chat_history.append(chat_record)

    with open(f"chat_history_{system_name}.json", "w", encoding="utf-8") as f:
        json.dump(chat_history, f, ensure_ascii=False, indent=4)


def chatbot_response_1(user_input: str, memory_bank: Answer_sheet):
    response_content = memory_bank.cache_.pre_memory_compute(user_input)
    
    if not response_content:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": user_input}
            ]
        )
        response_content = response.choices[0].message.content
    memory_bank.cache_.answer_insert(response_content)
    # 保存对话记录
    save_chat_history("system_A", user_input, response_content)
    return response_content

def chatbot_response_2(user_input, memory_bank: Answer_sheet):
    response_content = memory_bank.cache_.pre_memory_compute(user_input)

    if not response_content:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": user_input}
            ]
        )
        response_content = response.choices[0].message.content
    memory_bank.cache_.answer_insert(response_content)
    # 保存对话记录
    save_chat_history("system_B", user_input, response_content)
    return response_content
    
with gr.Blocks() as demo:
    with gr.Row() as initial_page:
        gr.Markdown("## 选择一个问答系统")
        btn1 = gr.Button("A")
        btn2 = gr.Button("B")

    with gr.Column(visible=False) as qa_page_1:
        gr.Markdown("## 问答系统:A")
        chatbot1 = gr.Chatbot()
        msg1 = gr.Textbox(label="输入你的问题")
        submit1 = gr.Button("发送")
        back1 = gr.Button("返回")

        file_path = get_local_file("A")
        retrieve_component = Answer_sheet(file_path)

        def respond_1(message, chat_history):
            response = chatbot_response_1(message, retrieve_component)
            chat_history.append((message, response))
            return "", chat_history
        # 添加回车触发提交
        msg1.submit(respond_1, [msg1, chatbot1], [msg1, chatbot1])
        submit1.click(respond_1, [msg1, chatbot1], [msg1, chatbot1])

    with gr.Column(visible=False) as qa_page_2:
        gr.Markdown("## 问答系统:B")
        chatbot2 = gr.Chatbot()
        msg2 = gr.Textbox(label="输入你的问题")
        submit2 = gr.Button("发送")
        back2 = gr.Button("返回")

        file_path = get_local_file("B")
        retrieve_component = Answer_sheet(file_path)

        def respond_2(message, chat_history):
            response = chatbot_response_2(message, retrieve_component)
            chat_history.append((message, response))
            return "", chat_history

        msg2.submit(respond_2, [msg2, chatbot2], [msg2, chatbot2])
        submit2.click(respond_2, [msg2, chatbot2], [msg2, chatbot2])

    # click button to implement page jumps 
    def show_qa_page_1():
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)

    def show_qa_page_2():
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

    def go_back():
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    btn1.click(show_qa_page_1, [], [initial_page, qa_page_1, qa_page_2])
    btn2.click(show_qa_page_2, [], [initial_page, qa_page_1, qa_page_2])
    back1.click(go_back, [], [initial_page, qa_page_1, qa_page_2])
    back2.click(go_back, [], [initial_page, qa_page_1, qa_page_2])

demo.launch()



* Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.




In [78]:
retrieve_component.cache_.query[4]

('hi there, do you know how can I get a girlfriend?',
 {'q_emb': array([ 5.65255396e-02, -2.19656825e-01,  1.10276386e-01, -3.13620232e-02,
          3.65164876e-01, -1.46088347e-01,  9.27966833e-02,  7.21849978e-01,
         -1.20279223e-01,  1.29396811e-01,  1.16444364e-01, -7.41654634e-01,
         -2.93072551e-01,  6.14998281e-01, -1.63040385e-02,  6.06041551e-02,
         -1.74761117e-01,  1.92809924e-01, -1.18943015e-02,  5.56279063e-01,
          5.10691591e-02,  3.80209118e-01, -2.70124316e-01,  1.08444847e-01,
          4.07317698e-01,  5.03449216e-02,  1.73942536e-01,  1.59422740e-01,
          2.99515426e-01, -3.95339310e-01,  2.00941786e-01, -2.71441601e-02,
         -2.13072002e-01, -2.65172273e-01,  1.24241613e-01, -2.99234279e-02,
          2.23774230e-03, -2.34813422e-01, -2.13881820e-01,  2.02464119e-01,
         -7.72613287e-01, -2.30122492e-01, -1.32707763e-03,  2.50830740e-01,
         -2.84246266e-01, -7.73672044e-01,  4.06138629e-01, -4.54660118e-01,
         -2.3