In [None]:
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from pathlib import Path
import os
import re
from peft import (
    LoraModel,
    LoraConfig,
    get_peft_model,
    PeftModel,
    AutoPeftModelForCausalLM,
    PeftMixedModel,
)

In [None]:
# Checkpoint directory used by the loading cell below for both the tokenizer
# and the model. The path is relative to the notebook's working directory.
# NOTE(review): the directory name suggests a zephyr-7b-beta model with the
# calculator LoRA already merged in — confirm before relying on it.
# The commented-out constants belong to the alternative "base model + LoRA
# adapter" loading path that is also commented out further down.
# _BASE_MODEL_PATH = Path("../../models/Mistral-7B-Instruct-v0.2")
# _LORA_MODEL_PATH = Path("../output/loras/checkpoint-540")
_MERGED_MODEL_PATH = Path("output/merged/zephyr-7b-beta-calculator_v1-(200_1)-2024-01-23-17-53-54/")
# _ADAPTER_NAME = "jarvis-calculator-v0_1"

In [None]:
# Choose the device to load the model on. The second GPU ("cuda:1") is kept as
# the preferred target, but the notebook no longer fails on machines that have
# only one GPU (falls back to "cuda:0") or none at all (falls back to "cpu").
_DEVICE = (
    "cuda:1" if torch.cuda.device_count() > 1
    else "cuda:0" if torch.cuda.is_available()
    else "cpu"
)

# Tokenizer and model are both read from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(_MERGED_MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    _MERGED_MODEL_PATH,
    low_cpu_mem_usage=True,
    device_map=_DEVICE,
    torch_dtype=torch.bfloat16,  # load the weights in bfloat16
)

In [None]:
# _LORA_MODEL_PATH = Path("../output/loras/checkpoint-540")
# peft_model = PeftModel.from_pretrained(
#     model,
#     os.path.join(_LORA_MODEL_PATH, _ADAPTER_NAME),
#     is_trainable=False,
#     from_transformers=True,
#     device_map="auto",
# )
# peft_model.merge_adapter()

In [None]:
class CustomStoppingCriteria(StoppingCriteria):
    """Stop generation when the newest token is one of a set of stop token ids.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, so
    this criterion is meant for single-prompt generation.

    Args:
        stops: Token ids that should end generation when they are the most
            recently generated token. ``None`` (the default) means no stop
            tokens, i.e. the criterion never fires.
    """

    def __init__(self, stops: list = None):
        StoppingCriteria.__init__(self)
        # Avoid a mutable default argument: a literal `[]` default would be
        # shared by every instance created without an explicit `stops`.
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the last token of the first sequence is a stop token.

        ``**kwargs`` is accepted (and ignored) so the signature matches the
        ``StoppingCriteria`` base class.
        """
        # Read the last token once (single device-to-host transfer) instead of
        # repeating the transfer for every candidate stop token.
        last_token = input_ids[0][-1].item()
        return any(last_token == token for token in self.stops)

In [None]:
def generate(prompt: str, max_tool_calls: int = 8):
    """Generate a completion for ``prompt``, resolving ``<calculator>`` calls.

    The model is run greedily until it emits EOS, the ``<stop>`` token, or 256
    new tokens. If the last assistant turn ends with ``<stop>`` and contains a
    ``<calculator>`` tag that has not been closed, the text between the tag
    and ``<stop>`` is evaluated as a Python expression, the result plus
    `` </calculator>`` is appended to the decoded text, and generation is
    resumed from there.

    Args:
        prompt: Full prompt text, including the chat role markers.
        max_tool_calls: Upper bound on calculator round-trips for one call.
            Once exhausted, the raw model output is returned as is. This
            replaces the previous unbounded recursion.

    Returns:
        The decoded text of the final generation (prompt included, special
        tokens kept).
    """
    import warnings

    open_tag, close_tag, stop_tag = '<calculator>', '</calculator>', '<stop>'

    # The stop token id and criteria do not change between rounds; build once.
    stop_token_id = tokenizer.encode(stop_tag, add_special_tokens=False)[-1]
    stopping_criteria = StoppingCriteriaList([
        CustomStoppingCriteria(stops=[stop_token_id])
    ])

    text = prompt
    remaining_calls = max_tool_calls
    while True:
        # NOTE(review): from the second round on, `text` is a decoded string
        # that still contains special-token text; re-tokenizing it may add a
        # second BOS token depending on the tokenizer — confirm.
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        generated = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
        )
        text = tokenizer.decode(generated[0], skip_special_tokens=False)

        last_assistant_resp = text.split("<|assistant|>")[-1].strip()
        open_at = last_assistant_resp.rfind(open_tag)
        # A call is pending when the turn stopped on <stop> and the last
        # <calculator> tag comes after the last </calculator> tag.
        call_pending = (
            last_assistant_resp.endswith(stop_tag)
            and open_at > last_assistant_resp.rfind(close_tag)
        )
        if not call_pending or remaining_calls <= 0:
            return text
        remaining_calls -= 1

        expr = last_assistant_resp[open_at + len(open_tag):last_assistant_resp.rfind(stop_tag)]
        try:
            # SECURITY: `expr` is model-generated text and eval() executes it
            # as arbitrary Python. Only use with trusted models/prompts, or
            # replace with a restricted arithmetic evaluator.
            result = eval(expr)
        except Exception as exc:
            # A malformed expression should not crash the caller (e.g. the
            # chat UI callback); report it and return what we have so far.
            warnings.warn(f"calculator expression {expr!r} could not be evaluated: {exc!r}")
            return text

        text = f'{text} {result} </calculator>'

In [None]:
# Chat UI: a raw (no markdown rendering, no HTML sanitizing) chat window, a
# text input and a button that clears both.
with gr.Blocks() as iface:
    chatbox = gr.Chatbot(render_markdown=False, sanitize_html=False)
    msg = gr.Textbox()
    clear = gr.ClearButton([chatbox, msg])

    def respond(message, chat_history):
        """Handle one submitted message and extend the chat history.

        The user text is wrapped in the chat role markers (the very first
        turn also gets the ``<|system|>`` header), the whole conversation so
        far is re-assembled from ``chat_history`` and passed to ``generate``.
        The prompt text is then removed from the generated output, and the
        remainder is stored as the bot message.

        Returns an empty string (to clear the textbox) and the history.
        An empty ``message`` leaves the history untouched.
        """
        if not message:
            return '', chat_history

        # First turn opens the conversation with the system header.
        if chat_history == []:
            user_turn = '<|system|>\n<|user|>\n' + message
        else:
            user_turn = '\n<|user|>\n' + message

        # Rebuild the conversation text from the stored (user, bot) pairs.
        transcript = ''
        for turn in chat_history:
            transcript += turn[0] + turn[1]

        conversation = transcript + user_turn
        generated = generate(conversation + '\n<|assistant|>\n')
        bot_reply = generated.replace(conversation, '')

        chat_history.append((user_turn, bot_reply))
        return '', chat_history

    msg.submit(respond, inputs=[msg, chatbox], outputs=[msg, chatbox])

In [None]:
# Close the interface before launching it.
# NOTE(review): presumably this lets the cell be re-run without a server from a
# previous run still holding the port — confirm.
iface.close()
# Start the Gradio server and open it in a browser tab.
iface.launch(inbrowser=True)

In [None]:
# Hand-written single-turn prompt for a manual test of the model.
# NOTE(review): the role markers here are followed by a space before the
# newline ("<|user|> \n"), whereas the chat UI's `respond` builds
# "<|user|>\n" with no space — confirm which form the model was trained on.
prompt = """<|system|>
<|user|> 
Hey, I am trying to calculate the total cost of 10 items that cost 5 rupees each. Can you help me?
<|assistant|> 
"""
# `prompt` is rebound from the raw string to the tokenized inputs, moved to the
# model's device; the generation cell below unpacks it with `**prompt`.
prompt = tokenizer(prompt, return_tensors="pt").to(model.device)

In [None]:
# Display the tokenizer's `add_eos_token` / `add_bos_token` settings as a tuple
# (cell output only; nothing is modified).
tokenizer.add_eos_token, tokenizer.add_bos_token

In [None]:
# Manual greedy generation for the tokenized `prompt` above. No custom stopping
# criterion is installed here, so generation ends only on EOS or after 128 new
# tokens. The decoded text is printed with special tokens kept.
output = model.generate(
    **prompt,
    do_sample=False,
    max_new_tokens=128,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    stopping_criteria=StoppingCriteriaList(),
)
print(tokenizer.decode(output[0], skip_special_tokens=False))