# In [None]:
# Copyright (c) Alibaba Cloud.

import copy
import re
from threading import Thread

import gradio as gr
import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer

DEFAULT_CKPT_PATH = 'Qwen/Qwen2.5-VL-7B-Instruct'


class Args:
    """Hard-coded launch options, standing in for argparse flags.

    Adjust these attributes (or the ``args`` instance below) before
    ``_load_model_processor(args)`` runs to change the checkpoint, device
    placement, or Gradio server binding.
    """

    # --- model loading ---
    checkpoint_path: str = DEFAULT_CKPT_PATH  # passed to from_pretrained
    cpu_only: bool = False  # True selects device_map='cpu', else 'auto'
    flash_attn2: bool = False  # True requests flash_attention_2 attention
    # --- Gradio server ---
    share: bool = False
    inbrowser: bool = False
    server_port: int = 7860
    server_name: str = "127.0.0.1"


args = Args()

# Load Model and Processor
def _load_model_processor(args):
    """Load the Qwen2.5-VL model and its processor from ``args.checkpoint_path``.

    Args:
        args: Options object exposing ``checkpoint_path``, ``cpu_only`` and
            ``flash_attn2`` (see ``Args``).

    Returns:
        A ``(model, processor)`` tuple.
    """
    load_kwargs = {
        # "auto" keeps the checkpoint's own dtype instead of the library's
        # fallback (float32 in many transformers versions), which would
        # roughly double memory for a 7B model. Previously only the
        # flash-attention path passed this.
        'torch_dtype': 'auto',
        'device_map': 'cpu' if args.cpu_only else 'auto',
    }
    if args.flash_attn2:
        load_kwargs['attn_implementation'] = 'flash_attention_2'

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(args.checkpoint_path, **load_kwargs)
    processor = AutoProcessor.from_pretrained(args.checkpoint_path)
    return model, processor


# Loaded once when this cell runs; predict() uses these module-level objects.
model, processor = _load_model_processor(args)

# Define text and image parsing functions
def _parse_text(text):
    lines = text.split('\n')
    lines = [line for line in lines if line != '']
    count = 0
    for i, line in enumerate(lines):
        if '```' in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace('`', r'\`').replace('<', '&lt;').replace('>', '&gt;')
                    line = line.replace(' ', '&nbsp;').replace('*', '&ast;').replace('_', '&lowbar;')
                    line = line.replace('-', '&#45;').replace('.', '&#46;').replace('!', '&#33;')
                    line = line.replace('(', '&#40;').replace(')', '&#41;').replace('$', '&#36;')
                lines[i] = '<br>' + line
    text = ''.join(lines)
    return text

def _remove_image_special(text):
    return re.sub(r'<box>.*?(</box>|$)', '', text.replace('<ref>', '').replace('</ref>', ''))

# Define chatbot processing functions
def call_local_model(model, processor, messages):
    messages = _transform_messages(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt')
    inputs = inputs.to(model.device)

    tokenizer = processor.tokenizer
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = {"max_new_tokens": 512, "streamer": streamer, **inputs}
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

def predict(chatbot, task_history):
    """Generate the reply to the newest turn and stream it into the chatbot.

    Args:
        chatbot: Chatbot value; a list of ``(user_html, bot_html)`` pairs whose
            last entry is the pending turn appended by ``add_text``.
        task_history: Raw ``(query, answer)`` pairs parallel to ``chatbot``;
            the last pair's answer is ``None`` until this function fills it in.

    Yields:
        ``chatbot`` after each streamed update (or once, after dropping an
        empty turn).
    """
    chat_query = chatbot[-1][0]
    if not chat_query:
        # Drop the empty turn. This is a generator, so the update must be
        # yielded: a ``return value`` here would be silently discarded.
        chatbot.pop()
        task_history.pop()
        yield chatbot
        return

    *past_turns, (query, _) = task_history
    messages = []
    for past_query, past_answer in past_turns:
        messages.append({"role": "user", "content": [{"text": past_query}]})
        # A past turn can lack an answer (e.g. its generation raised).
        messages.append({"role": "assistant", "content": [{"text": past_answer or ""}]})
    messages.append({"role": "user", "content": [{"text": query}]})

    response = ""
    for response in call_local_model(model, processor, messages):
        chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
        yield chatbot

    # Record the raw answer so later turns send it back to the model.
    # task_history is not an output of this event, so this presumably relies
    # on Gradio passing the session's State list itself (in-place mutation).
    task_history[-1] = (query, response)

def add_text(history, task_history, text):
    """Append a new, not-yet-answered user turn to both histories.

    Args:
        history: Chatbot value (list of display pairs) or ``None``.
        task_history: Raw ``(query, answer)`` list or ``None``.
        text: The user's input.

    Returns:
        ``(new_history, new_task_history, "")``. Both lists are new objects;
        the empty string is meant to clear the input box.
    """
    if history is None:
        history = []
    if task_history is None:
        task_history = []
    display_turn = (_parse_text(text), None)
    raw_turn = (text, None)
    return history + [display_turn], task_history + [raw_turn], ""

def reset_state(chatbot, task_history):
    """Empty both conversation histories in place.

    Returns:
        A fresh empty list to use as the new chatbot value.
    """
    for container in (task_history, chatbot):
        container.clear()
    return []

# Launch Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("""\
    <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 80px"/></p>"""
               )
    gr.Markdown("""<center><font size=8>Qwen2.5-VL</font></center>""")
    chatbot = gr.Chatbot(label="Qwen2.5-VL", height=500)
    query = gr.Textbox(lines=2, label="Input")
    # Raw (query, answer) pairs kept alongside the rendered chatbot.
    task_history = gr.State([])

    with gr.Row():
        submit_btn = gr.Button("🚀 Submit")
        empty_bin = gr.Button("🧹 Clear History")

    # add_text returns (history, task_history, ""). The third value clears the
    # input box, so ``query`` must be listed as an output for it to apply.
    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history, query]).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    )
    empty_bin.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)

# Run the Gradio app inside the notebook
demo.queue().launch(share=args.share, inbrowser=args.inbrowser, server_port=args.server_port, server_name=args.server_name)
