In [1]:
import os
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import BitsAndBytesConfig
import torch
from transformers import pipeline
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.indices.vector_store import VectorIndexRetriever
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Local Hugging Face cache directory. Written as a raw string so the Windows
# backslashes stay literal instead of depending on the unrecognized escape
# sequences "\C" and "\H", which newer Python versions warn about.
cache_dir = r"E:\Cache\Hugging_Face"

# Project root: three directory levels above the current working directory.
project_path = os.path.abspath('../../../')
data_path = os.path.join(project_path, 'FT4LLM/Data')
knowledge_path = os.path.join(data_path, 'articles')
prompt_path = os.path.join(data_path, 'prompt')

In [3]:
# --- Retrieval (RAG) configuration ---
# Embedding model used by llama_index, cached under cache_dir.
embedding_model = HuggingFaceEmbedding(
    model_name="ekorman-strive/bge-large-en-v1.5",
    cache_folder=cache_dir,
)
Settings.embed_model = embedding_model
# No LLM is registered with llama_index; generation is done by the
# transformers pipelines defined elsewhere in this notebook.
Settings.llm = None
Settings.chunk_size = 356
Settings.chunk_overlap = 50

# Load the index from the storage context rooted at <cache_dir>/vector_cache.
vector_store_dir = cache_dir + '/vector_cache'
storage_context = StorageContext.from_defaults(persist_dir=vector_store_dir)
index = load_index_from_storage(storage_context)

# Retrieve the 3 most similar nodes, then post-process them with a 0.39
# similarity cutoff.
retriever = VectorIndexRetriever(index=index, similarity_top_k=3)
similarity_filter = SimilarityPostprocessor(similarity_cutoff=0.39)
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[similarity_filter],
)

In [2]:
# --- LLM and tokenizer setup ---
# 4-bit loading with float16 compute and double quantization enabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

adapter_repo = "CocoNutZENG/Epipaca"
config = PeftConfig.from_pretrained(adapter_repo)

base_model = AutoModelForCausalLM.from_pretrained(
    "hfl/llama-3-chinese-8b-instruct",
    quantization_config=bnb_config,
    device_map='auto',
    cache_dir=cache_dir,
)
# NOTE(review): PEFT wraps `base_model` rather than copying it, and for
# LoRA-style adapters the adapter layers are injected into the wrapped model's
# modules in place. The pipeline built on `base_model` below may therefore
# also run with the adapter active unless it is explicitly disabled - confirm.
model = PeftModel.from_pretrained(base_model, adapter_repo)

tokenizer = AutoTokenizer.from_pretrained(
    adapter_repo,
    padding_side="right",
    cache_dir=cache_dir,
)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Both text-generation pipelines share the same tokenizer and device mapping.
shared_pipeline_kwargs = dict(tokenizer=tokenizer, device_map="auto")
Coversation_epipaca = pipeline("text-generation", model=model, **shared_pipeline_kwargs)
Coversation_original = pipeline("text-generation", model=base_model, **shared_pipeline_kwargs)

# Token ids passed as `eos_token_id` when generating: the tokenizer's EOS id
# plus the id of the "<|eot_id|>" token.
eos_id = Coversation_epipaca.tokenizer.eos_token_id
eot_id = Coversation_original.tokenizer.convert_tokens_to_ids("<|eot_id|>")
terminators = [eos_id, eot_id]



In [19]:
def get_RAG_promt(query):
    """Build the retrieval context block for ``query``.

    Runs ``query`` through the module-level ``query_engine`` and concatenates
    the text of up to three returned source nodes, one per line, between the
    "Begin of Context:" and "End of Context" markers.

    Args:
        query: The user question to retrieve context for.

    Returns:
        The context string. When no source node is returned, only the two
        markers are present.
    """
    query_response = query_engine.query(query)
    # The query engine post-processes nodes with a similarity cutoff, so fewer
    # than three nodes (possibly none) can come back. Slicing tolerates that,
    # whereas indexing 0..2 raised IndexError.
    top_nodes = query_response.source_nodes[:3]
    context = "Begin of Context:\n" + "".join(node.text + "\n" for node in top_nodes)
    return context + '\n End of Context\n'


def get_beginning_prompt(language):
    """Return the system prompt that opens a conversation without retrieval.

    Both language branches currently return the same placeholder instruction
    telling the model to repeat the user input. The domain-specific prompts
    are kept below as disabled references.

    Args:
        language: Language code; "en" selects the English branch, any other
            value selects the fallback branch.

    Returns:
        The system prompt string.
    """
    # Disabled English prompt:
    #   "You are a practitioner in the epilepsy treatment industry. Try best to complete the user's instruction given to you. Be professional."
    # Disabled Chinese prompt (kept verbatim because it is prompt text):
    #   "你是癫痫康复行业的专业人士，请尽力完成用户的指令，并保持输出专业和简短"
    #   Roughly: "You are a professional in the epilepsy rehabilitation field;
    #   do your best to complete the user's instructions and keep the output
    #   professional and brief."
    english_prompt = "No matter what you receive, just repeat the user input."
    fallback_prompt = "No matter what you receive, just repeat the user input."
    return english_prompt if language == "en" else fallback_prompt
def get_beginning_prompt_RAG(language):
    """Return the system prompt that opens a retrieval-augmented conversation.

    Both language branches currently return the same placeholder instruction
    telling the model to repeat the user input. The domain-specific prompts
    are kept below as disabled references.

    Args:
        language: Language code; "en" selects the English branch, any other
            value selects the fallback branch.

    Returns:
        The system prompt string.
    """
    # Disabled English prompt:
    #   "You are a practitioner in the epilepsy treatment industry. Try best to complete the user's instruction given to you. Be professional. \n In each round of User questions, you will be given a Context by System for reference, and you should refer to the Context as detail as possible for your response. "
    # Disabled Chinese prompt (kept verbatim because it is prompt text):
    #   "你是癫痫康复行业的专业人士，请尽力完成用户的指令，并保持输出专业和简短。\n 在每一轮User提问时，你将会获得System给出的Context进行参考, 请参照Context进行回答。"
    #   Roughly: "You are a professional in the epilepsy rehabilitation field;
    #   do your best to complete the user's instructions and keep the output
    #   professional and brief.\n In each round of User questions you will be
    #   given a Context by System for reference; answer with reference to it."
    english_prompt = "No matter what you receive, just repeat the user input."
    fallback_prompt = "No matter what you receive, just repeat the user input."
    return english_prompt if language == "en" else fallback_prompt
def compare_models_NO_RAG(message):
    """Generate answers to ``message`` from the base model and from Epipaca, without retrieval.

    Both models receive the same chat-formatted prompt (system prompt plus the
    user message) and are sampled with identical generation settings.

    Args:
        message: The user's input text.

    Returns:
        A tuple ``(output1, output2)``: the base-model reply and the Epipaca
        reply, each with the prompt stripped. Empty input yields ``("", "")``.
    """
    # An empty message has no first character to inspect (message[0] raised
    # IndexError) and nothing to answer, so return empty replies.
    if not message:
        return "", ""

    # Language heuristic: an ASCII letter as the first character means English.
    first_char = message[0]
    starts_with_ascii_letter = (
        ('\u0041' <= first_char <= '\u005a') or ('\u0061' <= first_char <= '\u007a')
    )
    language = 'en' if starts_with_ascii_letter else 'zh'

    messages_map = [
        {"role": "system", "content": get_beginning_prompt(language=language)},
        {"role": "user", "content": message},
    ]
    input_processed = Coversation_epipaca.tokenizer.apply_chat_template(
        messages_map,
        tokenize=False,
        add_generation_prompt=True,
    )

    generation_kwargs = dict(
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.2,
        top_p=0.8,
    )

    # `model` is a PEFT wrapper around `base_model`, and PEFT injects adapter
    # layers into the wrapped model's modules. The baseline pipeline shares
    # those modules, so disable the adapter while it generates; otherwise the
    # "original" reply would come from the adapted weights as well.
    with model.disable_adapter():
        baseline_result = Coversation_original(input_processed, **generation_kwargs)
    output1 = baseline_result[0]['generated_text'][len(input_processed):]

    epipaca_result = Coversation_epipaca(input_processed, **generation_kwargs)
    output2 = epipaca_result[0]['generated_text'][len(input_processed):]
    return output1, output2


def compare_models_w_RAG(message):
    """Generate answers to ``message`` from the base model (no retrieval) and from Epipaca (with retrieval).

    The base model receives the plain system prompt plus the user message.
    Epipaca additionally receives a second system message holding the context
    built by ``get_RAG_promt``. Both are sampled with identical settings.

    Args:
        message: The user's input text.

    Returns:
        A tuple ``(output1, output2)``: the base-model reply and the Epipaca
        reply, each with its prompt stripped. Empty input yields ``("", "")``.
    """
    # An empty message has no first character to inspect (message[0] raised
    # IndexError) and nothing to answer, so return empty replies.
    if not message:
        return "", ""

    # Language heuristic: an ASCII letter as the first character means English.
    first_char = message[0]
    starts_with_ascii_letter = (
        ('\u0041' <= first_char <= '\u005a') or ('\u0061' <= first_char <= '\u007a')
    )
    language = 'en' if starts_with_ascii_letter else 'zh'

    messages_map_original = [
        {"role": "system", "content": get_beginning_prompt(language=language)},
        {"role": "user", "content": message},
    ]
    messages_map_epipaca = [
        {"role": "system", "content": get_beginning_prompt_RAG(language=language)},
        {"role": "system", "content": get_RAG_promt(message)},
        {"role": "user", "content": message},
    ]

    chat_tokenizer = Coversation_epipaca.tokenizer
    input_processed_original = chat_tokenizer.apply_chat_template(
        messages_map_original,
        tokenize=False,
        add_generation_prompt=True,
    )
    input_processed_epipaca = chat_tokenizer.apply_chat_template(
        messages_map_epipaca,
        tokenize=False,
        add_generation_prompt=True,
    )

    generation_kwargs = dict(
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.2,
        top_p=0.8,
    )

    # `model` is a PEFT wrapper around `base_model`, and PEFT injects adapter
    # layers into the wrapped model's modules. The baseline pipeline shares
    # those modules, so disable the adapter while it generates; otherwise the
    # "original" reply would come from the adapted weights as well.
    with model.disable_adapter():
        baseline_result = Coversation_original(input_processed_original, **generation_kwargs)
    output1 = baseline_result[0]['generated_text'][len(input_processed_original):]

    epipaca_result = Coversation_epipaca(input_processed_epipaca, **generation_kwargs)
    output2 = epipaca_result[0]['generated_text'][len(input_processed_epipaca):]
    return output1, output2



# Gradio UI: two side-by-side output panes (base model vs. Epipaca), one input
# box, and a button that runs the retrieval-augmented comparison.
with gr.Blocks() as demo:
    gr.Markdown("## Epilepsy LLM Comparison")

    # Result panes, one column per model.
    with gr.Row():
        with gr.Column():
            output1 = gr.Textbox(label="LLAMA-3 Chinese", lines=10)
        with gr.Column():
            output2 = gr.Textbox(label="Epipaca", lines=10)

    # User prompt entry.
    with gr.Row():
        user_input = gr.Textbox(label="Input your word here")

    compare_button = gr.Button("Generate")
    compare_button.click(
        fn=compare_models_w_RAG,
        inputs=[user_input],
        outputs=[output1, output2],
    )

demo.launch()



