In [None]:
import os
import json
import time
import torch
import datetime
import random
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, LlamaForCausalLM


# Rebind __file__ to an absolute path named "file" under the current working
# directory, so os.path.dirname(__file__) further down resolves to the cwd.
# NOTE(review): presumably needed because this runs as a notebook cell where
# __file__ is not defined -- confirm.
__file__ = os.path.abspath('file')


# Run on CUDA device 0 when CUDA is available, otherwise fall back to the CPU.
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')

# Directory in which get_conv_log_filename() places the per-day vote log.
LOGDIR = "./logs"

# Shared button "update" payloads that event handlers return to change the
# state of the vote buttons without re-creating them.
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True, visible=True)
disable_btn = gr.Button.update(interactive=False)
# Built with gr.Button.update() like its siblings: calling gr.Button(...) here
# would instantiate a real component instead of an update payload.
invisible_btn = gr.Button.update(interactive=False, visible=False)

# Path handed to from_pretrained() for each model that load_model() can load.
# The values are empty placeholders here and have to be filled in before use.
MODEL_MAPPING_PATH = {
    "Qwen-7B-Chat": "",
    "LLaMA-2-7B-Chat": "",
}

# Set by load_model(); all three stay None until a model has been loaded.
tokenizer = model = config = None


# Names offered in the model dropdown. load_model() looks each one up in
# MODEL_MAPPING_PATH, so only names that are also keys there can be loaded.
models = [
    "LLaMA-2-7B-Chat",
    "LLaMA-2-13B-Chat",
    "LLaMA-2-70B-Chat",
    "ChatGLM2-6B",
    "Qwen-7B-Chat",
    "Qwen-14B-Chat",
    "Baichuan2-7B-Chat",
    "Baichuan2-13B-Chat",
]


def load_model(model_name: str):
    """Load the tokenizer, model and generation config for ``model_name``.

    The loaded objects are stored in the module-level ``tokenizer``, ``model``
    and ``config`` globals, which the ``*_chat`` functions read.

    Args:
        model_name: A key of ``MODEL_MAPPING_PATH``. Its lower-cased prefix
            ("llama", "qwen", "baichuan") selects how the model is loaded.

    Returns:
        A 5-tuple ``(None, [], "", enable_btn, enable_btn)`` which the caller
        wires to (state, chatbot, textbox, upvote button, downvote button):
        the conversation is reset and both vote buttons are enabled.

    Raises:
        KeyError: If ``model_name`` has no entry in ``MODEL_MAPPING_PATH``.
    """
    global tokenizer
    global model
    global config
    import gc

    # Resolve the path first so an unknown name fails before the currently
    # loaded model is thrown away.
    model_path = MODEL_MAPPING_PATH[model_name]
    family = model_name.lower()

    # Drop the previous model before loading the next one so both do not have
    # to fit in memory at once, and so a stale generation config from another
    # model family cannot survive the switch.
    tokenizer = None
    model = None
    config = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if family.startswith("llama"):
        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = LlamaForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
        # Only the Qwen and Baichuan chat functions read ``config``.
        if family.startswith(("qwen", "baichuan")):
            from transformers.generation import GenerationConfig
            config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
    return (None, [], "") + (enable_btn,) * 2
    

def llama_chat(prompt, history, temperature, top_p, max_new_tokens):
    """Generate one reply with the loaded LLaMA model and append it to history.

    Only ``prompt`` is tokenized and sent to the model; earlier turns in
    ``history`` are kept for display but are not part of the model input.

    Args:
        prompt: The user's message.
        history: List of ``(prompt, response)`` pairs, or a falsy value to
            start a new conversation.
        temperature: Sampling temperature; a value of 0 or below selects
            greedy decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        ``(history, history, enable_btn, enable_btn)``.
    """
    if not history:
        history = []
    # The UI sliders may deliver whole numbers as ints; normalise the types
    # before handing them to generate().
    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)
    if temperature > 0:
        generation_config = dict(
            temperature=temperature,
            top_k=0,
            top_p=float(top_p),
            do_sample=True,
            max_new_tokens=max_new_tokens
        )
    else:
        # Sampling requires a strictly positive temperature, but the slider
        # allows 0; treat that as a request for greedy decoding.
        generation_config = dict(
            do_sample=False,
            max_new_tokens=max_new_tokens
        )
    with torch.no_grad():
        tokenized_data = tokenizer(prompt, return_tensors="pt")
        input_ids = tokenized_data["input_ids"].to(device)
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=tokenized_data['attention_mask'].to(device),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            **generation_config)
        # For a causal LM the generated sequence starts with the prompt
        # tokens; decode only what follows them so the reply does not echo
        # the prompt.
        new_tokens = generation_output[0][input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    history.append((prompt, response))
    return (history, history,) + (enable_btn,) * 2


def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f"<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
    
    
def qwen_chat(prompt, history, temperature, top_p, max_new_tokens):
    """Generate one reply with the loaded Qwen model and append it to history.

    The sampling settings are written onto the module-level ``config`` before
    the model's ``chat_stream`` is called with the existing history.

    Args:
        prompt: The user's message.
        history: List of ``(prompt, response)`` pairs, or a falsy value to
            start a new conversation.
        temperature: Sampling temperature stored on ``config``.
        top_p: Nucleus-sampling threshold stored on ``config``.
        max_new_tokens: Generation length limit stored on ``config``.

    Returns:
        ``(history, history, enable_btn, enable_btn)``.
    """
    if not history:
        history = []
    config.temperature = temperature
    config.top_p = top_p
    config.max_new_tokens = max_new_tokens
    # Only the last item yielded by the stream is kept, so drain the stream
    # first and convert to HTML once instead of re-parsing every chunk.
    last_chunk = ""
    for last_chunk in model.chat_stream(tokenizer, prompt, history=history, generation_config=config):
        pass
    full_response = _parse_text(last_chunk)
    # NOTE(review): the HTML-converted reply is what gets stored, so later
    # turns pass HTML back to the model as history -- confirm this is intended.
    history.append((prompt, full_response))
    return (history, history,) + (enable_btn,) * 2


def baichuan_chat(prompt, history, temperature, top_p, max_new_tokens):
    """Generate one reply with the loaded Baichuan model and append it to history.

    History is kept as ``(prompt, response)`` pairs, the same shape the other
    chat functions store, because the returned value feeds both the chatbot
    component and the shared state that the vote handlers read. The
    role/content message list that ``model.chat`` is called with is rebuilt
    from those pairs on every call.

    Args:
        prompt: The user's message.
        history: List of ``(prompt, response)`` pairs, or a falsy value to
            start a new conversation.
        temperature: Sampling temperature stored on ``config``.
        top_p: Nucleus-sampling threshold stored on ``config``.
        max_new_tokens: Generation length limit stored on ``config``.

    Returns:
        ``(history, history, enable_btn, enable_btn)``.
    """
    if not history:
        history = []
    messages = []
    for user_text, bot_text in history:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": bot_text})
    messages.append({"role": "user", "content": prompt})
    config.temperature = temperature
    config.top_p = top_p
    config.max_new_tokens = max_new_tokens
    # Each streamed item replaces the previous one; keep only the last.
    full_response = ""
    for full_response in model.chat(tokenizer, messages, stream=True, generation_config=config):
        pass
    history.append((prompt, full_response))
    return (history, history,) + (enable_btn,) * 2
    

def get_conv_log_filename():
    """Return the path of today's log file: ``LOGDIR/<year>-<MM>-<DD>-conv.json``."""
    today = datetime.datetime.now()
    date_stamp = "-".join((str(today.year), f"{today.month:02d}", f"{today.day:02d}"))
    return os.path.join(LOGDIR, date_stamp + "-conv.json")


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one vote record as a JSON line to today's conversation log.

    Args:
        state: The conversation entry being voted on; must be JSON-serializable.
        vote_type: Label stored under ``"type"`` (callers pass "upvote" or
            "downvote").
        model_selector: Name of the model the vote applies to.
        request: The request of the voting client; its host is logged.
    """
    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state,
        "ip": request.client.host,
    }
    log_path = get_conv_log_filename()
    # Opening a file in append mode does not create missing parent
    # directories, so make sure the log directory exists first.
    os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True)
    with open(log_path, mode="a", encoding='utf-8') as f:
        f.write(json.dumps(data) + "\n")

        
def upvote_last_response(state, model_selector, request:gr.Request):
    """Log an upvote for the most recent conversation entry.

    Args:
        state: The conversation history; may be None or empty right after a
            model load or a history reset.
        model_selector: Name of the model the vote applies to.
        request: The request of the voting client.

    Returns:
        ``(state, state, btn, btn)`` where ``btn`` disables both vote buttons
        after a vote was recorded, or leaves them unchanged when there was
        nothing to vote on.
    """
    if not state:
        # Nothing has been generated yet; state[-1] would fail here.
        return (state, state,) + (no_change_btn,) * 2
    vote_last_response(state[-1], "upvote", model_selector, request)
    return (state, state,) + (disable_btn,) * 2


def downvote_last_response(state, model_selector, request:gr.Request):
    """Log a downvote for the most recent conversation entry.

    Args:
        state: The conversation history; may be None or empty right after a
            model load or a history reset.
        model_selector: Name of the model the vote applies to.
        request: The request of the voting client.

    Returns:
        ``(state, state, btn, btn)`` where ``btn`` disables both vote buttons
        after a vote was recorded, or leaves them unchanged when there was
        nothing to vote on.
    """
    if not state:
        # Nothing has been generated yet; state[-1] would fail here.
        return (state, state,) + (no_change_btn,) * 2
    vote_last_response(state[-1], "downvote", model_selector, request)
    return (state, state,) + (disable_btn,) * 2

    
def clear_history(request: gr.Request):
    """Reset the conversation.

    Returns ``(None, [], "", enable_btn, enable_btn)``, which the caller wires
    to (state, chatbot, textbox, upvote button, downvote button). ``request``
    is accepted but not used.
    """
    cleared = (None, [], "")
    vote_buttons = (enable_btn, enable_btn)
    return cleared + vote_buttons


def chat(model_selector, prompt, history, temperature, top_p, max_new_tokens):
    """Route a chat request to the backend matching the selected model.

    The backend is chosen by the lower-cased prefix of ``model_selector``
    ("llama", "qwen" or "baichuan").

    Args:
        model_selector: Name of the selected model.
        prompt: The user's message.
        history: Conversation history passed through to the backend.
        temperature: Sampling temperature passed through to the backend.
        top_p: Nucleus-sampling threshold passed through to the backend.
        max_new_tokens: Generation length limit passed through to the backend.

    Returns:
        Whatever the selected backend returns: ``(history, history,
        enable_btn, enable_btn)``.

    Raises:
        gr.Error: If no backend exists for the selected model. Previously
            this case fell through and failed with an UnboundLocalError.
    """
    backends = (
        ("llama", llama_chat),
        ("qwen", qwen_chat),
        ("baichuan", baichuan_chat),
    )
    family = model_selector.lower()
    for prefix, backend in backends:
        if family.startswith(prefix):
            return backend(prompt, history, temperature, top_p, max_new_tokens)
    raise gr.Error(f"No chat backend is available for model '{model_selector}'.")


# Build the UI: model dropdown, chat window, input row, vote/regenerate/clear
# buttons and the sampling-parameter sliders, then wire up the event handlers.
with gr.Blocks() as demo:

    # Conversation history handed to and returned by the handlers below.
    state = gr.State([])
    notice_markdown = """# <center>⚔️ 大语言模型竞技场 ⚔️</center>"""
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Row(elem_id="model_selector_row"):

        model_selector = gr.Dropdown(
            choices=models,
            value=models[0] if models else "",
            interactive=True,
            show_label=False,
            container=False,
        )

    chatbot = gr.Chatbot(
        [],
        label="向下滚动并开始聊天",
        avatar_images=((os.path.join(os.path.dirname(__file__), "human.png")),
            (os.path.join(os.path.dirname(__file__), "bot.png"))),).style(height=350)

    with gr.Row():

        with gr.Column(scale=0.85):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="请在此输入您的提示词并按Enter键:",
                container=False,
                elem_id="input_box"
            )

        with gr.Column(scale=0.15, min_width=0):
            send = gr.Button(value="发送", variant="primary")

    with gr.Row():
        upvote_btn = gr.Button(value="👍 赞成")
        downvote_btn = gr.Button(value="👎 否决")
        regenerate_btn = gr.Button(value="🔄 重新生成")
        clear_btn = gr.Button(value="🗑️ 清除历史")

    with gr.Accordion("Parameters", open=False):

        temperature = gr.Slider(minimum=0.0,
                                maximum=1.0,
                                value=0.7,
                                step=0.1,
                                interactive=True,
                                label="Temperature")

        # The keyword was misspelled "interactivate" on the next two sliders,
        # so the intended interactive=True never reached them.
        top_p = gr.Slider(minimum=0.0,
                          maximum=1.0,
                          value=1.0,
                          step=0.1,
                          interactive=True,
                          label="Top p")

        max_new_tokens = gr.Slider(minimum=16,
                                   maximum=2048,
                                   value=512,
                                   step=1,
                                   interactive=True,
                                   label="Max new tokens")

    # Load the initially selected model while the UI is being built.
    load_model(model_selector.value)

    # Component lists shared by several handlers; the order matches the
    # tuples the handler functions return.
    chat_inputs = [model_selector, textbox, state, temperature, top_p, max_new_tokens]
    chat_outputs = [chatbot, state, upvote_btn, downvote_btn]
    vote_inputs = [state, model_selector]
    reset_outputs = [state, chatbot, textbox, upvote_btn, downvote_btn]

    model_selector.change(load_model,
                          inputs=[model_selector],
                          outputs=reset_outputs,
                          show_progress=False,
                         )

    textbox.submit(
        chat,
        inputs=chat_inputs,
        outputs=chat_outputs,
    )

    send.click(
        chat,
        inputs=chat_inputs,
        outputs=chat_outputs,
    )

    upvote_btn.click(
        upvote_last_response,
        inputs=vote_inputs,
        outputs=chat_outputs,
    )

    downvote_btn.click(
        downvote_last_response,
        inputs=vote_inputs,
        outputs=chat_outputs,
    )

    # NOTE(review): this is wired exactly like the send button, so it submits
    # the current textbox content as a new turn rather than replacing the
    # previous reply -- confirm this is the intended "regenerate" behaviour.
    regenerate_btn.click(
        chat,
        inputs=chat_inputs,
        outputs=chat_outputs,
    )

    clear_btn.click(
        clear_history,
        inputs=None,
        outputs=reset_outputs,
    )

# Enable request queuing on the app, then start serving it.
demo.queue().launch()