# https://qwen.readthedocs.io/zh-cn/stable/inference/chat.html#

# 加载模型

In [3]:
# 修改实现 参考官方prompt格式
import argparse
from dataclasses import dataclass, field
from typing import Optional, List, Dict
import sys
import torch
from transformers import TrainingArguments, HfArgumentParser, Trainer, AutoTokenizer, AutoModelForCausalLM
import re
import os
# 设置可见的 GPU 设备为 cuda:0
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def load_single_model(model_path,torch_dtype,trust_remote_code,device_map,use_cache):
    """Load a causal language model from a checkpoint.

    Every argument is forwarded unchanged, by keyword, to
    ``AutoModelForCausalLM.from_pretrained``.

    Args:
        model_path: Passed as ``pretrained_model_name_or_path``.
        torch_dtype: Passed as ``torch_dtype``.
        trust_remote_code: Passed as ``trust_remote_code``; the original
            author notes that Qwen models need this parameter.
        device_map: Passed as ``device_map``; the original author notes it
            is optional and used to place the model on devices
            automatically.
        use_cache: Passed as ``use_cache``.

    Returns:
        Whatever ``AutoModelForCausalLM.from_pretrained`` returns.
    """
    load_kwargs = {
        "pretrained_model_name_or_path": model_path,
        "torch_dtype": torch_dtype,
        "trust_remote_code": trust_remote_code,
        "device_map": device_map,
        "use_cache": use_cache,
    }
    loaded_model = AutoModelForCausalLM.from_pretrained(**load_kwargs)
    return loaded_model

# Module-level setup: choose a device, then load the tokenizer and the model.
# Use the first CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Checkpoint location handed to both the tokenizer loader and the model loader.
model_path = "/ssd/xiaxinyuan/code/CS3602_NLP_Final_Project/output/peft_3b/checkpoint-30000"
max_position_embeddings = 4096 # token-count limit that generate_response compares the prompt against; original note: maximum length supported by the model
# NOTE(review): device_map is passed to the tokenizer loader here; it is not
# clear that a tokenizer makes use of it - confirm, or drop the argument.
tokenizer = AutoTokenizer.from_pretrained(model_path,device_map="auto" )
# Positional arguments map to load_single_model's parameters:
# torch_dtype="bfloat16", trust_remote_code=True, device_map="auto", use_cache=False.
model = load_single_model(model_path,"bfloat16",True,"auto",False)
# NOTE(review): the model was already loaded with device_map="auto"; confirm
# that this explicit move to `device` is still required.
model = model.to(device)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.19s/it]


# 实现对话函数

In [10]:
# Generate one reply and record the finished round in the conversation history.
def _encode_prompt(tokenizer, conversation_history, device):
    """Render the history with the chat template and tokenize it onto `device`."""
    prompt = tokenizer.apply_chat_template(
        conversation_history, tokenize=False, add_generation_prompt=True
    )
    return tokenizer([prompt], return_tensors="pt").to(device)


def generate_response(model, tokenizer, conversation_history, user_input,
                      max_input_tokens=None, max_new_tokens=100):
    """Generate the model's reply to `user_input` given the dialogue so far.

    `conversation_history` is modified in place. On return it has gained two
    messages, the user turn and the assistant turn, both prefixed with
    ``[Round N]:`` where N is one more than the round number found in the
    previous last message (or 1 when none is found). The prompt sent to the
    model contains the current user turn *without* that prefix.

    Args:
        model: Object providing ``device`` and ``generate(...)``.
        tokenizer: Object providing ``apply_chat_template``, ``decode``,
            ``eos_token_id`` and tokenization via ``__call__``.
        conversation_history: List of ``{"role", "content"}`` dicts. When the
            prompt is too long, the two messages after the first entry are
            dropped, so the first entry is expected to be the system message.
        user_input: Text of the current user turn.
        max_input_tokens: Prompt length limit in tokens. Defaults to the
            module-level ``max_position_embeddings``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply with surrounding whitespace stripped.
    """
    if max_input_tokens is None:
        max_input_tokens = max_position_embeddings

    # Temporarily add the raw user turn so it is part of the prompt.
    conversation_history.append({"role": "user", "content": user_input})
    print(user_input)

    inputs = _encode_prompt(tokenizer, conversation_history, model.device)
    # Drop the oldest complete round (entries 1 and 2) until the prompt fits.
    # Entry 0 and the pending user turn are always kept, so trimming needs at
    # least four messages; if the prompt still does not fit, it is sent as is.
    while (inputs["input_ids"].shape[-1] > max_input_tokens
           and len(conversation_history) >= 4):
        del conversation_history[1:3]
        inputs = _encode_prompt(tokenizer, conversation_history, model.device)

    # Decoding options left disabled by the original author: num_beams=5,
    # temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2.
    outputs = model.generate(
        **inputs,
        pad_token_id=tokenizer.eos_token_id,  # pad with EOS during generation
        max_new_tokens=max_new_tokens,
    )

    # Decode only the tokens that follow the prompt, skipping special tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    print("Bot:", response)
    reply = response.strip()

    # Replace the raw user turn with round-tagged user/assistant entries.
    conversation_history.pop()
    last_round = 0
    if conversation_history:
        match = re.search(r'\[Round (\d+)\]', conversation_history[-1]["content"])
        if match:
            last_round = int(match.group(1))
    current_round = last_round + 1
    conversation_history.append({"role": "user", "content": f"[Round {current_round}]:{user_input}"})
    conversation_history.append({"role": "assistant", "content": f"[Round {current_round}]:{reply}"})
    return reply

# Main chat loop.
def chat(model, tokenizer):
    """Run an interactive chat session on standard input/output.

    Recognised commands (compared case-insensitively after stripping):
        ``\\quit``        end the session.
        ``\\newsession``  reset the history to the system prompt only.

    Any other non-empty line is passed to ``generate_response``, which prints
    the reply and appends the round to the history. Blank lines are ignored.
    End-of-input or an interrupt at the prompt ends the session.

    The history is stored in the module-level name ``conversation_history``
    so it can be inspected after the session.

    Args:
        model: Forwarded to ``generate_response``.
        tokenizer: Forwarded to ``generate_response``.
    """
    global conversation_history

    def fresh_history():
        # A new list each time, so a reset never shares state with an old session.
        return [{"role": "system", "content": "You are a helpful assistant."}]

    conversation_history = fresh_history()
    print("chat start")
    while True:
        # Read the next user line; only the read itself is guarded.
        try:
            user_input = input("User: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("Session ended. Bye!")
            break

        command = user_input.lower()
        if command == "\\quit":
            print("Session ended. Bye!")
            break
        if command == "\\newsession":
            conversation_history = fresh_history()
            print("Conversation history cleaned.")
            continue
        if not user_input:
            # Nothing to send to the model.
            continue
        generate_response(model, tokenizer, conversation_history, user_input)

In [11]:
chat(model, tokenizer)

chat start
hello
Bot: Hi there! How can I assist you today? Is there something specific you would like to know or do?
Session ended. Bye!
