In [1]:
# AutoDL academic-resource acceleration: copy the proxy variables exported by
# /etc/network_turbo into this kernel's environment.
import subprocess
import os
import warnings


def load_network_turbo_env(script_path="/etc/network_turbo"):
    """Source ``script_path`` in bash and copy its proxy variables into ``os.environ``.

    Only variables whose *name* contains "proxy" (case-insensitive) are copied.

    Args:
        script_path: Shell script that exports the proxy settings.

    Returns:
        dict mapping each variable that was set to its value. It is empty when
        the script does not exist (e.g. when not running on AutoDL) or when
        sourcing it fails; the latter also emits a warning.
    """
    if not os.path.isfile(script_path):
        return {}
    # Pass the path as a positional argument ($1) instead of splicing it into
    # the command string, so no shell quoting/injection issues arise.
    result = subprocess.run(
        ["bash", "-c", 'source "$1" && env -0', "bash", script_path],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        warnings.warn(
            f"sourcing {script_path} failed (exit {result.returncode}): "
            f"{result.stderr.strip()}"
        )
        return {}
    applied = {}
    # `env -0` terminates each entry with NUL, so multi-line values cannot be
    # split into bogus "KEY=value" fragments the way line-based grep could.
    for entry in result.stdout.split("\0"):
        name, sep, value = entry.partition("=")
        if sep and "proxy" in name.lower():
            os.environ[name] = value
            applied[name] = value
    return applied


# Assigned (not left as the last expression) so proxy URLs are not echoed
# into the saved notebook output.
network_turbo_vars = load_network_turbo_env()

In [2]:
import sys
import os

# Make the project's `src` package importable. The location can be overridden
# with the GEMMA_PROJECT_ROOT environment variable. The membership check keeps
# the cell idempotent: re-running it does not append duplicate entries.
project_root = os.environ.get("GEMMA_PROJECT_ROOT", "/home/cuipeng/Gemma")
if project_root not in sys.path:
    sys.path.append(project_root)

# 导入必要模块
from src.core.model.model_initializer import initialize_model_and_tokenizer
from src.core.utils.model_utils import generate_response, apply_chat_template

In [3]:
# 导入必要模块
from src.core.model.model_initializer import initialize_model_and_tokenizer
from src.core.utils.model_utils import generate_response, apply_chat_template
import ipywidgets as widgets # type: ignore
from IPython.display import display, clear_output # type: ignore

In [4]:
def build_dialogue(system_text, user_text):
    """Build the two-message dialogue passed to ``apply_chat_template``.

    Args:
        system_text: System-prompt text (the contents of the settings textarea).
        user_text: The user's message.

    Returns:
        A list of two ``{"role", "content"}`` dicts: the system message first,
        then the user message.
    """
    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]


def format_message_html(label, body, style="margin: 10px 0", pre=False):
    """Render one transcript entry as an HTML ``<div>`` with ``body`` escaped.

    User input, the system prompt, the templated prompt and the model reply may
    contain ``<``, ``>`` or ``&`` (the templated prompt may include tag-like
    markers, depending on ``apply_chat_template``). Inserted raw into
    ``widgets.HTML``, that text would be swallowed or interpreted as markup;
    ``html.escape`` makes it display verbatim.

    Args:
        label: Bold heading shown before the body (trusted; not escaped).
        body: Untrusted text to display; converted with ``str`` then escaped.
        style: Inline CSS for the wrapping ``<div>``.
        pre: If True, wrap the body in ``<pre>``; otherwise separate it from
            the label with a single space.

    Returns:
        The HTML string.
    """
    import html

    escaped = html.escape(str(body))
    content = f"<pre>{escaped}</pre>" if pre else f" {escaped}"
    return f"<div style='{style}'><b>{label}</b>{content}</div>"


def create_chat_interface(
    model_path="google/gemma-2-9b",
    cache_dir="/root/autodl-tmp/gemma",
    lora_path="/root/autodl-tmp/models/stage1/checkpoints/gemma-base-zh/checkpoint-43500",
    use_quantization=True,
    max_new_tokens=1024,
    temperature=0.7,
):
    """Load the model and display an ipywidgets chat UI in the notebook.

    Each message is answered independently: the dialogue sent to the model is
    always the current system prompt plus the latest user message only; earlier
    turns are not included.

    Args:
        model_path: Forwarded to ``initialize_model_and_tokenizer``.
        cache_dir: Forwarded to ``initialize_model_and_tokenizer``.
        lora_path: Forwarded to ``initialize_model_and_tokenizer``.
        use_quantization: Forwarded to ``initialize_model_and_tokenizer``.
        max_new_tokens: Forwarded to ``generate_response``.
        temperature: Forwarded to ``generate_response``.

    Returns:
        None. The UI is rendered with ``display`` as a side effect.
    """
    # Load the model and tokenizer (synchronously, before any UI is shown).
    print("正在加载模型...")
    model, tokenizer = initialize_model_and_tokenizer(
        model_path=model_path,
        cache_dir=cache_dir,
        lora_path=lora_path,
        use_quantization=use_quantization
    )
    model.eval()
    print("模型加载完成!")

    # Inline CSS shared by the transcript entries.
    meta_style = "margin: 10px 0; color: gray; font-size: 0.8em;"
    reply_style = "margin: 10px 0; white-space: pre-wrap;"

    # System-prompt editor.
    system_prompt = widgets.Textarea(
        value='你是一个专业、友好的AI助手。请用简洁、准确的方式回答问题。',
        placeholder='请输入系统提示词...',
        description='系统提示词:',
        disabled=False,
        layout=widgets.Layout(
            width='100%',
            height='150px'
        )
    )

    # Scrollable transcript area.
    conversation_output = widgets.Output(
        layout=widgets.Layout(
            width='100%',
            max_width='800px',
            min_height='400px',
            border='1px solid #ddd',
            overflow='auto',
            padding='10px'
        )
    )

    # Single-line input; Enter submits (wired up via on_submit below).
    input_box = widgets.Text(
        value='',
        placeholder='请输入您的问题...(按Enter发送)',
        description='用户:',
        disabled=False,
        layout=widgets.Layout(
            width='100%',
            max_width='700px'
        )
    )

    # Action buttons.
    send_button = widgets.Button(
        description='发送',
        disabled=False,
        button_style='primary',
        tooltip='发送消息',
        icon='paper-plane',
        layout=widgets.Layout(width='100px')
    )

    clear_button = widgets.Button(
        description='清空对话',
        button_style='warning',
        tooltip='清空对话历史',
        layout=widgets.Layout(width='100px')
    )

    test_prompt_button = widgets.Button(
        description='测试系统提示词',
        button_style='info',
        tooltip='测试系统提示词是否生效',
        layout=widgets.Layout(width='150px')
    )

    # Toggle for echoing the fully templated prompt into the transcript.
    show_prompt_checkbox = widgets.Checkbox(
        value=False,
        description='显示完整Prompt',
        indent=False
    )

    # Layout: buttons row, input row, chat column (left), settings column (right).
    button_container = widgets.HBox(
        [send_button, clear_button, test_prompt_button],
        layout=widgets.Layout(
            width='auto',
            margin='0 0 0 10px'
        )
    )

    input_container = widgets.HBox(
        [input_box, button_container],
        layout=widgets.Layout(
            width='100%',
            max_width='1200px',
            justify_content='space-between'
        )
    )

    chat_container = widgets.VBox(
        [conversation_output, input_container],
        layout=widgets.Layout(
            width='70%',
            padding='10px'
        )
    )

    system_container = widgets.VBox(
        [
            widgets.HTML(value='<h4>系统设置</h4>'),
            system_prompt,
            show_prompt_checkbox
        ],
        layout=widgets.Layout(
            width='28%',
            padding='10px',
            border='1px solid #ddd',
            margin='0 0 0 10px'
        )
    )

    main_container = widgets.HBox(
        [chat_container, system_container],
        layout=widgets.Layout(
            width='100%',
            max_width='1200px',
            margin='0 auto'
        )
    )

    # Controls that are locked while a reply is being generated.
    interactive_widgets = (send_button, clear_button, test_prompt_button, input_box)

    def run_dialogue(user_text, reply_label, show_system_prompt=False):
        """Template ``user_text`` with the current system prompt, generate and render the reply.

        Must be called inside ``with conversation_output:`` so the displayed
        HTML lands in the transcript.
        """
        prompt = apply_chat_template(build_dialogue(system_prompt.value, user_text))

        if show_system_prompt:
            display(widgets.HTML(format_message_html(
                "当前系统提示词:", system_prompt.value, style=meta_style, pre=True
            )))

        if show_prompt_checkbox.value:
            display(widgets.HTML(format_message_html(
                "完整Prompt:", prompt, style=meta_style, pre=True
            )))

        # generate_response runs synchronously inside this widget callback;
        # disable the controls so repeated clicks / Enter presses during
        # generation cannot start extra generations, and always re-enable
        # them even if generation raises.
        for control in interactive_widgets:
            control.disabled = True
        try:
            response = generate_response(
                model,
                tokenizer,
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature
            )
        finally:
            for control in interactive_widgets:
                control.disabled = False

        display(widgets.HTML(format_message_html(reply_label, response, style=reply_style)))

    def send_message(text):
        """Echo ``text`` into the transcript and append the model's reply; ignores blank input."""
        if not text.strip():
            return

        input_box.value = ''

        with conversation_output:
            display(widgets.HTML(format_message_html("用户:", text)))
            run_dialogue(text, "助手:")

    def test_system_prompt(b):
        """Ask a fixed self-introduction question to check whether the system prompt takes effect."""
        with conversation_output:
            display(widgets.HTML("<div style='color: blue'>正在测试系统提示词...</div>"))
            run_dialogue("你是谁？请介绍一下你自己。", "助手的回答:", show_system_prompt=True)

    def on_send_button_clicked(b):
        send_message(input_box.value)

    def on_enter_pressed(widget):
        send_message(widget.value)

    def on_clear_button_clicked(b):
        conversation_output.clear_output()

    # Wire up event handlers.
    send_button.on_click(on_send_button_clicked)
    clear_button.on_click(on_clear_button_clicked)
    test_prompt_button.on_click(test_system_prompt)
    # NOTE(review): ipywidgets 8 flags Text.on_submit as deprecated; kept here
    # because a continuous_update=False + observe('value') replacement would
    # also need to avoid double-sending when the send button is clicked.
    input_box.on_submit(on_enter_pressed)

    # Render the UI.
    display(main_container)

In [5]:
# Load the model and render the chat UI; this cell blocks until loading finishes.
create_chat_interface();

正在加载模型...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



模型加载完成!


  input_box.on_submit(on_enter_pressed)  # 添加回车键支持


HBox(children=(VBox(children=(Output(layout=Layout(border_bottom='1px solid #ddd', border_left='1px solid #ddd…