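安装依赖：Kaggle上已经为我们预装了其他常用包，这里只需额外安装知识检索与模型加载所需的包（langchain_community、mediawikiapi、wikibase-rest-api-client、spaCy、SPARQLWrapper、peft，并将 packaging 固定为 21.3）。如果不在Kaggle上运行，则还需要安装其他必要包。注意：关键词提取依赖 spaCy 的 en_core_web_sm 模型，若环境中没有该模型，需要先运行 `python -m spacy download en_core_web_sm`。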

In [5]:
!pip install langchain_community mediawikiapi wikibase-rest-api-client spacy
!pip install SPARQLWrapper
!pip install peft 
!pip uninstall packaging --no-input
!pip install packaging==21.3
# python -m spacy download en_core_web_sm

'!pip install langchain_community mediawikiapi wikibase-rest-api-client spacy\n!pip install SPARQLWrapper\n!pip install peft \n!pip uninstall packaging --no-input\n!pip install packaging==21.3\n# python -m spacy download en_core_web_sm'

# 包导入

In [6]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from langchain_community.tools.wikidata.tool import WikidataAPIWrapper, WikidataQueryRun
import spacy
import requests
from SPARQLWrapper import SPARQLWrapper, JSON
from peft import PeftModel

# 联网搜索外部知识

In [7]:
class KnowledgeSearcher:
    """Look up background knowledge for a list of keywords on Wikidata and Wikipedia.

    Both lookups are best-effort: network/HTTP failures are reported as an
    error string in place of the result instead of being raised.
    """

    # SPARQL 1.1 string-literal escapes (ECHAR). Without them a quote or
    # backslash in a keyword would terminate the literal and break (or alter)
    # the query text.
    _SPARQL_ESCAPES = str.maketrans({
        "\\": "\\\\",
        '"': '\\"',
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
    })

    def __init__(self, wikidata_endpoint="https://query.wikidata.org/sparql", timeout=10):
        """Create a searcher.

        Args:
            wikidata_endpoint: URL of the Wikidata SPARQL endpoint.
            timeout: Seconds to wait for each remote request before giving up,
                so a stalled service cannot hang the chat loop forever.
        """
        self.sparql = SPARQLWrapper(wikidata_endpoint)
        self.sparql.setTimeout(timeout)
        self.wikipedia_api = "https://en.wikipedia.org/api/rest_v1/page/summary/"
        self.timeout = timeout

    def search_knowledge(self, keywords):
        """
        Search for detailed knowledge from Wikidata and Wikipedia.

        Args:
            keywords: List of keyword strings. All of them form the Wikidata
                label; only the first is used as the Wikipedia page title.

        Returns:
            Dict with "wikidata_details" (dict of bindings, or an error string)
            and "wikipedia_summary" (summary text, or an error string).
        """
        if not keywords:
            # Nothing to look up; avoid querying an empty label / indexing [0].
            return {
                "wikidata_details": {},
                "wikipedia_summary": "No summary available.",
            }

        entity_data = self.query_wikidata(keywords)
        wiki_summary = self.query_wikipedia(keywords[0])

        return {
            "wikidata_details": entity_data,
            "wikipedia_summary": wiki_summary
        }

    def query_wikidata(self, keywords):
        """
        Query Wikidata for detailed properties of the entity whose English
        label equals the space-joined keywords.

        Returns the formatted bindings dict, or an error string on failure.
        """
        label = self._escape_sparql_literal(" ".join(keywords))
        query = f"""
        SELECT ?item ?itemLabel ?description ?instance_ofLabel ?countryLabel ?inception ?industryLabel
        WHERE {{
          ?item rdfs:label "{label}"@en.
          OPTIONAL {{ ?item schema:description ?description FILTER(LANG(?description) = "en") }}
          OPTIONAL {{ ?item wdt:P31 ?instance_of. }}
          OPTIONAL {{ ?item wdt:P17 ?country. }}
          OPTIONAL {{ ?item wdt:P571 ?inception. }}
          OPTIONAL {{ ?item wdt:P452 ?industry. }}
          SERVICE wikibase:label {{ bd:serviceParam wikibase:language "en". }}
        }}
        LIMIT 1
        """
        self.sparql.setQuery(query)
        self.sparql.setReturnFormat(JSON)

        try:
            results = self.sparql.query().convert()
            return self.format_wikidata_results(results)
        except Exception as e:
            return f"Error querying Wikidata: {e}"

    def query_wikipedia(self, keyword):
        """
        Query the Wikipedia REST API for an article summary.

        Returns the "extract" field, "No summary available." when it is
        missing, or an error string when the request fails.
        """
        try:
            response = requests.get(self._wikipedia_url(keyword), timeout=self.timeout)
            response.raise_for_status()
            data = response.json()
            return data.get("extract", "No summary available.")
        except requests.RequestException as e:
            return f"Error querying Wikipedia: {e}"

    def _wikipedia_url(self, keyword):
        """Build the summary URL for *keyword* as a safely encoded page title."""
        from urllib.parse import quote

        # Wikipedia titles use underscores for spaces. Percent-encode everything
        # else (including "/", "?" and "#") so the title stays one path segment.
        title = quote(keyword.replace(" ", "_"), safe="")
        return f"{self.wikipedia_api}{title}"

    @staticmethod
    def _escape_sparql_literal(text):
        """Escape *text* for use inside a double-quoted SPARQL string literal."""
        return text.translate(KnowledgeSearcher._SPARQL_ESCAPES)

    def format_wikidata_results(self, results):
        """
        Flatten SPARQL JSON results into a {variable: value} dict.

        If several rows are returned, later rows overwrite earlier ones.
        """
        return {
            name: cell["value"]
            for row in results["results"]["bindings"]
            for name, cell in row.items()
        }



# 部署机器人

In [8]:
class Chatbot:
    """Minimal interactive chat loop around a causal language model.

    The dialogue is kept as plain-text turns in ``self.history`` and the whole
    history is fed to the model on every user turn.
    """

    def __init__(self, model, tokenizer, max_input_length, device, max_new_tokens=512):
        """Initialize the chatbot.

        Args:
            model: Model exposing ``to(device)`` and ``generate(...)``.
            tokenizer: Matching tokenizer (callable, plus ``decode``).
            max_input_length: Token budget shared by the prompt and the
                generated reply.
            device: torch device the model and inputs are moved to.
            max_new_tokens: Maximum number of tokens generated per reply.
        """
        self.model = model.to(device)  # Move model to the target device
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_new_tokens = max_new_tokens
        self.history = []
        self.device = device

    def get_input(self):
        """Get user input."""
        user_input = input("You: ")
        return user_input

    def handle_input(self, user_input):
        """Handle one line of user input.

        ``\\quit`` ends the conversation, ``\\newsession`` clears the history,
        and anything else is sent to the model.

        Returns:
            False when the conversation should end, True otherwise.
        """
        if user_input == "\\quit":
            print("Ending the conversation.")
            return False  # End the conversation
        if user_input == "\\newsession":
            print("Starting a new session.")
            self.history = []  # Clear conversation history
            return True

        self.history.append(f"User: {user_input}")
        conversation = " ".join(self.history)

        input_ids = self.tokenizer(conversation, return_tensors="pt")["input_ids"]
        # Leave room for the reply, and keep the *most recent* tokens so the
        # latest user turn is never the part that gets cut off.
        budget = max(1, self.max_input_length - self.max_new_tokens)
        input_ids = input_ids[:, -budget:].to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=torch.ones_like(input_ids),  # single unpadded sequence
                max_new_tokens=self.max_new_tokens,
                num_return_sequences=1,
            )
        # generate() returns prompt + continuation; decode only the continuation.
        response = self.tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
        print(f"Bot: {response}")

        self.history.append(f"Bot: {response}")
        return True  # Continue the conversation

    def start(self):
        """Run the read/respond loop until the user types \\quit."""
        print("Chatbot started, type \\quit to end the conversation.")
        while True:
            user_input = self.get_input()
            if not self.handle_input(user_input):
                break
from SPARQLWrapper import SPARQLWrapper, JSON


class KnowledgeEnhancedChatbot(Chatbot):
    """Chatbot that prefixes each user turn with a Wikipedia summary for its keywords."""

    def __init__(self, model, tokenizer, max_input_length, device):
        """Initialize the knowledge-enhanced chatbot (arguments as for ``Chatbot``)."""
        super().__init__(model, tokenizer, max_input_length, device)
        # Not used by the current retrieval path; kept as a public attribute.
        self.vectorizer = TfidfVectorizer()
        self.searcher = KnowledgeSearcher()
        # spaCy pipeline, loaded lazily on first use and reused afterwards.
        self._nlp = None

    def extract_keywords(self, query: str):
        """Return the tokens of *query* that are neither stop words nor punctuation."""
        if getattr(self, "_nlp", None) is None:
            # Loading a spaCy pipeline is slow; do it once, not on every query.
            self._nlp = spacy.load("en_core_web_sm")
        doc = self._nlp(query)
        return [token.text for token in doc if not token.is_stop and not token.is_punct]

    def enhance_knowledge(self, query):
        """Return a Wikipedia summary relevant to *query*.

        Returns "no information" when the query has no usable keywords. Lookup
        failures come back as the searcher's error string.
        """
        keywords = self.extract_keywords(query)
        if not keywords:
            return "no information"
        return self.searcher.search_knowledge(keywords)["wikipedia_summary"]

    def _build_conversation(self, this_chat_input):
        """Assemble the prompt: optional token-truncated history, then the current turn."""
        tag_history = "history chat:\n"
        this_chat_input_with_tag = f"this chat input:\n {this_chat_input}"
        if not self.history:
            return this_chat_input_with_tag

        history = " ".join(self.history)
        # Tokens left for the history once the tag and the current turn are counted.
        reserved = len(self.tokenizer(tag_history + this_chat_input_with_tag)["input_ids"])
        budget = self.max_input_length - reserved
        if budget <= 0:
            # No room for any history; send the current turn alone.
            return this_chat_input_with_tag

        history_ids = self.tokenizer(history, add_special_tokens=False)["input_ids"]
        if len(history_ids) > budget:
            # Drop the oldest tokens, keeping the most recent exchanges.
            history = self.tokenizer.decode(history_ids[-budget:])
        return tag_history + history + this_chat_input_with_tag

    def handle_input(self, user_input):
        """Handle user input, prefixing each turn with retrieved knowledge.

        ``\\quit`` ends the conversation and ``\\newsession`` clears the history.

        Returns:
            False when the conversation should end, True otherwise.
        """
        if user_input == "\\quit":
            print("Ending the conversation.")
            return False  # End the conversation
        if user_input == "\\newsession":
            print("Starting a new session.")
            self.history = []  # Clear conversation history
            return True

        # Query external knowledge and prepend it to the user's text.
        knowledge = self.enhance_knowledge(user_input)
        this_chat_input = f"{knowledge}\n{user_input}\n"
        conversation = self._build_conversation(this_chat_input)

        input_ids = self.tokenizer(conversation, return_tensors="pt")["input_ids"]
        # Keep the *last* tokens so the current question survives any truncation.
        input_ids = input_ids[:, -self.max_input_length:].to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),  # single unpadded sequence
                do_sample=True,
                temperature=0.5,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens. Slicing the decoded text by the
        # prompt's character length breaks whenever decode() does not reproduce
        # the prompt exactly.
        response = self.tokenizer.decode(
            outputs[0][input_ids.shape[-1]:], skip_special_tokens=True
        ).strip()
        print(f"Bot: {response}\n")

        # Store the knowledge-augmented user turn and the bot reply.
        self.history.append(this_chat_input)
        self.history.append(f"{response}\n")
        return True  # Continue the conversation

# Prefer a visible GPU; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Kaggle input locations: the base checkpoint (a Qwen2.5 1.5B model, per the
# path) and the fine-tuned PEFT weights applied on top of it.
base_model_path = "/kaggle/input/qwen2.5/transformers/1.5b/1"
finetuned_model_path = "/kaggle/input/1.5b/pytorch/default/1"

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
base_model = AutoModelForCausalLM.from_pretrained(base_model_path)
model = PeftModel.from_pretrained(base_model, finetuned_model_path)

# The model's positional-embedding limit doubles as the chatbot's input budget.
max_input_length = model.config.max_position_embeddings

# Build the retrieval-augmented chatbot and enter the interactive loop.
enhanced_chatbot = KnowledgeEnhancedChatbot(model, tokenizer, max_input_length, device)
enhanced_chatbot.start()


Chatbot started, type \quit to end the conversation.


You:  Chen Yongjun is a student in SJTU


Bot: Chen is a surname, which is a common surname in China, Japan, and Korea. It is derived from the character for "arm" or "army" and can be pronounced in several ways, including "Chen" and "Ch'en". There are many people with this surname in China, including Chen Yongjun, who is a student at Shanghai Jiao Tong University.



You:  Who is Chen Yongjun


Bot: Chen Yongjun is a student in SJTU



You:  What is kaggle


Bot: Kaggle is a data science competition platform and online community for data scientists and machine learning practitioners under Google LLC. It enables users to find and publish datasets, explore and build models in a web-based data science environment, work with other data scientists and machine learning engineers, and enter competitions to solve data science challenges.



You:  \quit


Ending the conversation.
